# fisk/core/tests/trading/test_active_management.py

from decimal import Decimal as D
from unittest import skip
from unittest.mock import patch

from django.test import TestCase

from core.lib.schemas.oanda_s import parse_time
from core.models import (
    Account,
    ActiveManagementPolicy,
    AssetGroup,
    AssetRule,
    Hook,
    Signal,
    User,
)
from core.tests.helpers import StrategyMixin, SymbolPriceMock
from core.trading.active_management import ActiveManagement


class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
    """
    Tests for the violation checks run by ``ActiveManagement.run_checks``.

    The ``ActiveManagement`` instance under test has ``get_trades`` and
    ``get_balance`` replaced by fakes serving ``self.trades``, ``self.balance``
    and ``self.balance_usd``, so each test shapes the account state directly.
    Most tests patch ``handle_violation`` and inspect the calls it received
    with ``check_violation``.
    """

    def setUp(self):
        self.user = User.objects.create_user(
            username="testuser", email="test@example.com", password="test"
        )
        self.account = Account.objects.create(
            user=self.user,
            name="Test Account",
            exchange="fake",
            currency="USD",
            initial_balance=100000,
        )
        self.account.supported_symbols = ["EUR_USD", "EUR_XXX", "USD_EUR", "XXX_EUR"]
        self.account.save()
        # The mixins' setUp runs only once the user and account exist, and is
        # what provides self.strategy used below.
        # NOTE(review): it presumably reads self.user / self.account — confirm
        # in core.tests.helpers.
        super().setUp()
        # Every violation type is configured to close the offending trade(s).
        self.active_management_policy = ActiveManagementPolicy.objects.create(
            user=self.user,
            name="Test Policy",
            when_trading_time_violated="close",
            when_trends_violated="close",
            when_position_size_violated="close",
            when_protection_violated="close",
            when_asset_groups_violated="close",
            when_max_open_trades_violated="close",
            when_max_open_trades_per_symbol_violated="close",
            when_max_loss_violated="close",
            when_max_risk_violated="close",
            when_crossfilter_violated="close",
        )
        self.strategy.active_management_policy = self.active_management_policy
        self.strategy.save()
        self.ams = ActiveManagement(self.strategy)

        # Baseline: two identical long EUR_USD trades opened a minute apart.
        self.trades = []
        # Monday at 11:38
        self.add_trade("20083", "EUR_USD", "long", "2023-02-13T11:38:06.302917985Z")
        # Monday at 11:39
        self.add_trade("20084", "EUR_USD", "long", "2023-02-13T11:39:06.302917985Z")
        self.balance = 100000
        self.balance_usd = 120000

        # Serve the in-memory trades and balances defined above.
        self.ams.get_trades = self.fake_get_trades
        self.ams.get_balance = self.fake_get_balance

    def get_ids(self, trades):
        """Return the ``id`` of every trade dict in ``trades``, in order."""
        return [trade["id"] for trade in trades]

    def add_trade(self, id, symbol, side, open_time):
        """
        Append an OPEN 10-unit trade to ``self.trades``.

        Prices, TP/SL and P&L fields are fixed; ``openTime`` is normalised with
        ``parse_time`` exactly like the baseline trades built in ``setUp``.

        :param id: trade id (string)
        :param symbol: instrument, e.g. ``"EUR_USD"``
        :param side: ``"long"`` or ``"short"``
        :param open_time: raw RFC 3339 open time string
        """
        trade = {
            "id": id,
            "symbol": symbol,
            "price": "1.06331",
            "openTime": open_time,
            "initialUnits": "10",
            "initialMarginRequired": "0.2966",
            "state": "OPEN",
            "currentUnits": "10",
            "realizedPL": "0.0000",
            "financing": "0.0000",
            "dividendAdjustment": "0.0000",
            "unrealizedPL": "-0.0008",
            "marginUsed": "0.2966",
            "takeProfitOrder": {"price": "1.07934"},
            "stopLossOrder": {"price": "1.05276"},
            "trailingStopLossOrder": None,
            "trailingStopValue": None,
            "side": side,
        }
        trade["openTime"] = parse_time(trade)
        self.trades.append(trade)

    def amend_tp_sl_flip_side(self):
        """
        Move TP/SL of every short trade to the short-appropriate side.

        With an entry price of 1.06331, the stop loss becomes 1.07386 (above
        entry) and the take profit 1.04728 (below entry). This keeps the
        protection check quiet, so each test exercises a single violation.
        Long trades are left untouched.
        """
        for trade in self.trades:
            if trade["side"] == "short":
                trade["stopLossOrder"]["price"] = "1.07386"
                trade["takeProfitOrder"]["price"] = "1.04728"

    def fake_get_trades(self):
        """Stand-in for ``ActiveManagement.get_trades``: serve ``self.trades``."""
        self.ams.trades = self.trades
        return self.trades

    def fake_get_balance(self, return_usd=None):
        """Stand-in for ``ActiveManagement.get_balance``; USD figure on request."""
        if return_usd:
            return self.balance_usd
        return self.balance

    def fake_get_currencies(self, symbols):
        pass

    def test_get_trades(self):
        trades = self.ams.get_trades()
        self.assertEqual(trades, self.trades)

    def test_get_balance(self):
        balance = self.ams.get_balance()
        self.assertEqual(balance, self.balance)

    def check_violation(
        self, violation, calls, expected_action, expected_trades, expected_args=None
    ):
        """
        Check that the violation was reported with the expected action and trades.

        A call matches ``violation`` when its first positional argument equals
        it. Positional args are read as (violation, action, trade id, extra args).

        :param violation: type of the violation to check against
        :param calls: ``call_args_list`` of the patched ``handle_violation``
        :param expected_action: expected action, e.g. ``"close"``
        :param expected_trades: trade dicts expected to be flagged, each exactly
            once; use ``[None]`` for a check reported without a trade (as the
            max_loss test does)
        :param expected_args: optional, expected extra args for every call
        """
        calls = list(calls)
        # No other check may fire in the same run: every call must be ours.
        self.assertEqual(len(calls), len(expected_trades))
        violation_calls = [call for call in calls if call[0][0] == violation]
        self.assertEqual(len(violation_calls), len(expected_trades))

        # Reduce real trade dicts to ids; a ``[None]`` expectation stays as is.
        if all(expected_trades):
            expected_trades = self.get_ids(expected_trades)

        for call in violation_calls:
            self.assertEqual(call[0][1], expected_action)
            if expected_args is not None:
                self.assertEqual(call[0][3], expected_args)

        # Each expected trade must be flagged exactly once. A plain membership
        # test would let a duplicate hide a missing trade.
        self.assertCountEqual(
            [call[0][2] for call in violation_calls], expected_trades
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_run_checks(self, handle_violation):
        self.ams.run_checks()
        self.assertEqual(handle_violation.call_count, 0)

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_trading_time_violated(self, handle_violation):
        self.trades[0]["openTime"] = "2023-02-17T11:38:06.302917Z"  # Friday
        self.ams.run_checks()
        self.check_violation(
            "trading_time", handle_violation.call_args_list, "close", [self.trades[0]]
        )

    def create_hook_signal(self):
        """Create a hook and a ``trend`` signal on it; return the signal."""
        hook = Hook.objects.create(
            user=self.user,
            name="Test Hook",
        )
        signal = Signal.objects.create(
            user=self.user,
            name="Test Signal",
            hook=hook,
            type="trend",
        )
        return signal

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_trends_violated(self, handle_violation):
        signal = self.create_hook_signal()
        self.strategy.trend_signals.set([signal])
        self.strategy.trends = {"EUR_USD": "sell"}
        self.amend_tp_sl_flip_side()
        self.strategy.save()
        self.ams.run_checks()
        self.check_violation(
            "trends", handle_violation.call_args_list, "close", self.trades
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_trends_violated_none(self, handle_violation):
        signal = self.create_hook_signal()
        self.strategy.trend_signals.set([signal])
        self.strategy.trends = {"EUR_USD": "buy"}
        self.strategy.save()
        self.ams.run_checks()
        self.check_violation("trends", handle_violation.call_args_list, "close", [])

    # Mock crossfilter here since we want to allow this conflict in order to test
    # that trends only close trades that are in the wrong direction.
    @patch(
        "core.trading.active_management.ActiveManagement.check_crossfilter",
        return_value=None,
    )
    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_trends_violated_partial(self, handle_violation, check_crossfilter):
        signal = self.create_hook_signal()
        self.strategy.trend_signals.set([signal])
        self.strategy.trends = {"EUR_USD": "sell"}
        self.strategy.save()

        # Change the side of the first trade to match the trends
        self.trades[0]["side"] = "short"
        self.amend_tp_sl_flip_side()

        self.ams.run_checks()
        self.check_violation(
            "trends", handle_violation.call_args_list, "close", [self.trades[1]]
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_position_size_violated(self, handle_violation):
        self.trades[0]["currentUnits"] = "100000"
        self.ams.run_checks()
        self.check_violation(
            "position_size",
            handle_violation.call_args_list,
            "close",
            [self.trades[0]],
            {"size": 500},
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_protection_violated_absent(self, handle_violation):
        self.trades[0]["takeProfitOrder"] = None
        self.trades[0]["stopLossOrder"] = None
        self.ams.run_checks()

        expected_args = {
            "take_profit_price": D("1.07934"),
            "stop_loss_price": D("1.05276"),
        }
        self.check_violation(
            "protection",
            handle_violation.call_args_list,
            "close",
            [self.trades[0]],
            expected_args,
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_protection_violated_absent_not_required(self, handle_violation):
        self.strategy.order_settings.take_profit_percent = 0
        self.strategy.order_settings.stop_loss_percent = 0
        self.strategy.order_settings.save()
        self.trades[0]["takeProfitOrder"] = None
        self.trades[0]["stopLossOrder"] = None
        self.ams.run_checks()
        self.assertEqual(handle_violation.call_count, 0)

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_asset_groups_violated(self, handle_violation):
        asset_group = AssetGroup.objects.create(
            user=self.user,
            name="Test Asset Group",
        )
        AssetRule.objects.create(
            user=self.user,
            asset="USD",
            group=asset_group,
            status=2,  # Bullish
        )
        self.strategy.asset_group = asset_group
        self.strategy.save()
        self.ams.run_checks()
        self.check_violation(
            "asset_group",
            handle_violation.call_args_list,
            "close",
            self.trades,  # All trades should be closed, since all are USD quote
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_asset_groups_violated_invert(self, handle_violation):
        self.trades[0]["side"] = "short"
        self.trades[1]["side"] = "short"
        self.amend_tp_sl_flip_side()
        asset_group = AssetGroup.objects.create(
            user=self.user,
            name="Test Asset Group",
        )
        AssetRule.objects.create(
            user=self.user,
            asset="USD",
            group=asset_group,
            # Bearish: the inverse of the bullish status=2 rule in the test above.
            # NOTE(review): confirm against the AssetRule status choices.
            status=3,
        )
        self.strategy.asset_group = asset_group
        self.strategy.save()
        self.ams.run_checks()
        self.check_violation(
            "asset_group",
            handle_violation.call_args_list,
            "close",
            self.trades,  # All trades should be closed, since all are USD quote
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_crossfilter_violated_side(self, handle_violation):
        self.trades[1]["side"] = "short"
        self.amend_tp_sl_flip_side()
        self.ams.run_checks()
        self.check_violation(
            "crossfilter",
            handle_violation.call_args_list,
            "close",
            [self.trades[1]],  # Only close newer trade
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_crossfilter_violated_side_multiple(self, handle_violation):
        # The shorts added here become self.trades[2:], all newer than the longs.
        self.add_trade("20085", "EUR_USD", "short", "2023-02-13T12:39:06.302917985Z")
        self.add_trade("20086", "EUR_USD", "short", "2023-02-13T13:39:07.302917985Z")
        self.add_trade("20087", "EUR_USD", "short", "2023-02-13T14:39:06.302917985Z")
        self.amend_tp_sl_flip_side()
        self.ams.run_checks()
        self.check_violation(
            "crossfilter",
            handle_violation.call_args_list,
            "close",
            self.trades[2:],  # Only close newer trades
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_crossfilter_violated_symbol(self, handle_violation):
        # Change symbol to conflict with long on EUR_USD
        self.trades[1]["symbol"] = "USD_EUR"
        self.ams.run_checks()
        self.check_violation(
            "crossfilter",
            handle_violation.call_args_list,
            "close",
            [self.trades[1]],  # Only close newer trade
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_crossfilter_violated_symbol_multiple(self, handle_violation):
        # The USD_EUR longs added here become self.trades[2:] and conflict with
        # the older EUR_USD longs.
        self.add_trade("20085", "USD_EUR", "long", "2023-02-13T12:39:06.302917985Z")
        self.add_trade("20086", "USD_EUR", "long", "2023-02-13T13:39:06.302917985Z")
        self.add_trade("20087", "USD_EUR", "long", "2023-02-13T14:39:06.302917985Z")
        self.ams.run_checks()
        self.check_violation(
            "crossfilter",
            handle_violation.call_args_list,
            "close",
            self.trades[2:],  # Only close newer trades
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_max_open_trades_violated(self, handle_violation):
        # 2 baseline trades + 9 added = 11 open trades in total.
        for x in range(9):
            self.add_trade(
                str(x),
                f"EUR_USD{x}",  # Vary symbol to prevent max open trades per symbol
                "long",
                f"2023-02-13T12:39:1{x}.302917985Z",
            )
        self.ams.run_checks()
        self.check_violation(
            "max_open_trades",
            handle_violation.call_args_list,
            "close",
            self.trades[10:],  # Only close the newest (11th) trade
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_max_open_trades_per_symbol_violated(self, handle_violation):
        # 2 baseline trades + 4 added = 6 open EUR_USD trades.
        for x in range(4):
            self.add_trade(
                str(x),
                "EUR_USD",
                "long",
                f"2023-02-13T12:39:1{x}.302917985Z",
            )
        self.ams.run_checks()
        self.check_violation(
            "max_open_trades_per_symbol",
            handle_violation.call_args_list,
            "close",
            self.trades[5:],  # Only close the newest (6th) trade
        )

    # Mock position size as we have no way of checking the balance at the start
    # of the trade.
    # TODO: Fix this when we have a way of checking the balance at the start of
    # the trade.
    @patch(
        "core.trading.active_management.ActiveManagement.check_position_size",
        return_value=None,
    )
    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_max_loss_violated(self, handle_violation, check_position_size):
        self.balance = D("1")
        self.balance_usd = D("0.69")
        self.trades = []
        self.ams.run_checks()

        # Reported once, without a trade.
        self.check_violation(
            "max_loss",
            handle_violation.call_args_list,
            "close",
            [None],
        )

    @patch(
        "core.trading.active_management.ActiveManagement.check_position_size",
        return_value=None,
    )
    @patch(
        "core.trading.active_management.ActiveManagement.check_protection",
        return_value=None,
    )
    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_max_risk_violated(
        self, handle_violation, check_protection, check_position_size
    ):
        self.add_trade(
            "20085",
            "EUR_USD",
            "long",
            "2023-02-13T15:39:19.302917985Z",
        )
        # A distant stop loss on a large position makes this trade the risky one.
        self.trades[2]["stopLossOrder"]["price"] = "0.001"
        self.trades[2]["currentUnits"] = "13000"
        self.ams.run_checks()
        self.check_violation(
            "max_risk",
            handle_violation.call_args_list,
            "close",
            [self.trades[2]],
        )

    @patch(
        "core.trading.active_management.ActiveManagement.check_position_size",
        return_value=None,
    )
    @patch(
        "core.trading.active_management.ActiveManagement.check_protection",
        return_value=None,
    )
    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_max_risk_violated_multiple(
        self, handle_violation, check_protection, check_position_size
    ):
        self.add_trade(
            "20085",
            "EUR_USD",
            "long",
            "2023-02-13T15:39:19.302917985Z",
        )
        self.add_trade(
            "20086",
            "EUR_USD",
            "long",
            "2023-02-13T15:45:19.302917985Z",
        )
        for trade in self.trades[2:4]:
            trade["stopLossOrder"]["price"] = "0.001"
            trade["currentUnits"] = "13000"
        self.ams.run_checks()
        self.check_violation(
            "max_risk",
            handle_violation.call_args_list,
            "close",
            [self.trades[2], self.trades[3]],
        )

    # Placeholders: skipped explicitly so they are not reported as passing.
    @skip("TODO: not implemented")
    def test_max_risk_not_violated_after_adjusting_protection(self):
        """
        Ensure the max risk check is not violated after adjusting the protection.
        """

    @skip("TODO: not implemented")
    def test_max_risk_not_violated_after_adjusting_position_size(self):
        """
        Ensure the max risk check is not violated after adjusting the position size.
        """