from decimal import Decimal as D
from unittest.mock import patch

from django.test import TestCase

from core.lib.schemas.oanda_s import parse_time
from core.models import (
    Account,
    ActiveManagementPolicy,
    AssetGroup,
    AssetRule,
    Hook,
    Signal,
    User,
)
from core.tests.helpers import StrategyMixin, SymbolPriceMock
from core.trading.active_management import ActiveManagement


class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
    """
    Tests for the violation checks run by ``ActiveManagement.run_checks``.

    ``get_trades`` and ``get_balance`` on the ``ActiveManagement`` instance are
    replaced with in-memory fakes, so each test fully controls which trades and
    balances the checks see. Most tests patch ``handle_violation`` and then
    inspect its recorded calls via ``check_violation``.
    """

    def setUp(self):
        self.user = User.objects.create_user(
            username="testuser", email="test@example.com", password="test"
        )
        self.account = Account.objects.create(
            user=self.user,
            name="Test Account",
            exchange="fake",
            currency="USD",
            initial_balance=100000,
        )
        self.account.supported_symbols = ["EUR_USD", "EUR_XXX", "USD_EUR", "XXX_EUR"]
        self.account.save()
        # Mixin setUp runs only once user/account exist; presumably
        # StrategyMixin builds self.strategy from them (TODO confirm).
        super().setUp()
        self.active_management_policy = ActiveManagementPolicy.objects.create(
            user=self.user,
            name="Test Policy",
            when_trading_time_violated="close",
            when_trends_violated="close",
            when_position_size_violated="close",
            when_protection_violated="close",
            when_asset_groups_violated="close",
            when_max_open_trades_violated="close",
            when_max_open_trades_per_symbol_violated="close",
            when_max_loss_violated="close",
            when_max_risk_violated="close",
            when_crossfilter_violated="close",
        )
        self.strategy.active_management_policy = self.active_management_policy
        self.strategy.save()
        self.ams = ActiveManagement(self.strategy)

        # Baseline: two identical long EUR_USD trades opened a minute apart.
        self.trades = []
        # Monday at 11:38
        self.add_trade("20083", "EUR_USD", "long", "2023-02-13T11:38:06.302917985Z")
        # Monday at 11:39
        self.add_trade("20084", "EUR_USD", "long", "2023-02-13T11:39:06.302917985Z")

        self.balance = 100000
        self.balance_usd = 120000
        # Route data access through the fakes below so no exchange is needed.
        self.ams.get_trades = self.fake_get_trades
        self.ams.get_balance = self.fake_get_balance

    def get_ids(self, trades):
        """Return the ``"id"`` of each trade dict, preserving order."""
        return [trade["id"] for trade in trades]

    def add_trade(self, id, symbol, side, open_time):
        """
        Append an open trade to ``self.trades``.

        Every field except id, symbol, side and open time uses fixed test values
        (10 units, TP 1.07934, SL 1.05276). ``openTime`` is passed through
        ``parse_time`` like the trades created in ``setUp``.

        :param id: trade id (string)
        :param symbol: instrument, e.g. "EUR_USD"
        :param side: "long" or "short"
        :param open_time: open timestamp string, e.g. "2023-02-13T12:39:06Z"
        """
        trade = {
            "id": id,
            "symbol": symbol,
            "price": "1.06331",
            "openTime": open_time,
            "initialUnits": "10",
            "initialMarginRequired": "0.2966",
            "state": "OPEN",
            "currentUnits": "10",
            "realizedPL": "0.0000",
            "financing": "0.0000",
            "dividendAdjustment": "0.0000",
            "unrealizedPL": "-0.0008",
            "marginUsed": "0.2966",
            "takeProfitOrder": {"price": "1.07934"},
            "stopLossOrder": {"price": "1.05276"},
            "trailingStopLossOrder": None,
            "trailingStopValue": None,
            "side": side,
        }
        trade["openTime"] = parse_time(trade)
        self.trades.append(trade)

    def amend_tp_sl_flip_side(self):
        """
        Swap take-profit and stop-loss prices on every short trade.

        The default prices only make sense for long trades. Flipping them for
        shorts keeps the protection check from firing, so each test triggers
        just the one violation it is about.
        """
        for trade in self.trades:
            if trade["side"] == "short":
                trade["stopLossOrder"]["price"] = "1.07386"
                trade["takeProfitOrder"]["price"] = "1.04728"

    def fake_get_trades(self):
        """Stand-in for ``ActiveManagement.get_trades``; serves ``self.trades``."""
        self.ams.trades = self.trades
        return self.trades

    def fake_get_balance(self, return_usd=None):
        """Stand-in for ``ActiveManagement.get_balance``."""
        if return_usd:
            return self.balance_usd
        return self.balance

    def fake_get_currencies(self, symbols):
        pass

    def test_get_trades(self):
        trades = self.ams.get_trades()
        self.assertEqual(trades, self.trades)

    def test_get_balance(self):
        balance = self.ams.get_balance()
        self.assertEqual(balance, self.balance)

    def check_violation(
        self, violation, calls, expected_action, expected_trades, expected_args=None
    ):
        """
        Check that only ``violation`` fired, with the expected action and trades.

        Every recorded call must belong to ``violation``; no other violation may
        fire. The trade ids passed to it must match ``expected_trades`` exactly,
        and each expected trade must be reported exactly once.

        :param violation: violation name, matched against the first positional
            argument of each call
        :param calls: recorded calls, e.g. ``handle_violation.call_args_list``
        :param expected_action: expected action, e.g. "close" or "notify"
        :param expected_trades: trades expected to be passed to the violation.
            Use ``[None]`` for violations that carry no trade (e.g. max loss).
        :param expected_args: optional, expected fourth positional argument of
            every call
        """
        calls = list(calls)
        # Total call count equal to the filtered count below means no other
        # violation fired.
        self.assertEqual(len(calls), len(expected_trades))
        violation_calls = [call for call in calls if call[0][0] == violation]
        self.assertEqual(len(violation_calls), len(expected_trades))

        if all(expected_trades):
            # Real trades are compared by id. A [None] placeholder is compared
            # as-is.
            expected_trades = self.get_ids(expected_trades)

        for call in violation_calls:
            # Ensure the correct action has been called, like close
            self.assertEqual(call[0][1], expected_action)
            if expected_args is not None:
                self.assertEqual(call[0][3], expected_args)

        # Same multiset of trades: no trade missed, none reported twice.
        self.assertCountEqual(
            [call[0][2] for call in violation_calls], expected_trades
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_run_checks(self, handle_violation):
        self.ams.run_checks()
        self.assertEqual(handle_violation.call_count, 0)

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_trading_time_violated(self, handle_violation):
        # NOTE(review): unlike setUp/add_trade, this value is not passed through
        # parse_time -- confirm the trading-time check accepts a raw string.
        self.trades[0]["openTime"] = "2023-02-17T11:38:06.302917Z"  # Friday
        self.ams.run_checks()
        self.check_violation(
            "trading_time", handle_violation.call_args_list, "close", [self.trades[0]]
        )

    def create_hook_signal(self):
        """Create a Hook and a trend-type Signal on it; return the Signal."""
        hook = Hook.objects.create(
            user=self.user,
            name="Test Hook",
        )
        signal = Signal.objects.create(
            user=self.user,
            name="Test Signal",
            hook=hook,
            type="trend",
        )
        return signal

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_trends_violated(self, handle_violation):
        signal = self.create_hook_signal()
        self.strategy.trend_signals.set([signal])
        self.strategy.trends = {"EUR_USD": "sell"}
        self.amend_tp_sl_flip_side()
        self.strategy.save()
        self.ams.run_checks()
        self.check_violation(
            "trends", handle_violation.call_args_list, "close", self.trades
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_trends_violated_none(self, handle_violation):
        signal = self.create_hook_signal()
        self.strategy.trend_signals.set([signal])
        self.strategy.trends = {"EUR_USD": "buy"}
        self.strategy.save()
        self.ams.run_checks()
        self.check_violation("trends", handle_violation.call_args_list, "close", [])

    # Mock crossfilter here since we want to allow this conflict in order to test
    # that trends only close trades that are in the wrong direction
    @patch(
        "core.trading.active_management.ActiveManagement.check_crossfilter",
        return_value=None,
    )
    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_trends_violated_partial(self, handle_violation, check_crossfilter):
        signal = self.create_hook_signal()
        self.strategy.trend_signals.set([signal])
        self.strategy.trends = {"EUR_USD": "sell"}
        self.strategy.save()
        # Change the side of the first trade to match the trends
        self.trades[0]["side"] = "short"
        self.amend_tp_sl_flip_side()
        self.ams.run_checks()
        self.check_violation(
            "trends", handle_violation.call_args_list, "close", [self.trades[1]]
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_position_size_violated(self, handle_violation):
        self.trades[0]["currentUnits"] = "100000"
        self.ams.run_checks()
        self.check_violation(
            "position_size",
            handle_violation.call_args_list,
            "close",
            [self.trades[0]],
            {"size": 500},
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_protection_violated_absent(self, handle_violation):
        self.trades[0]["takeProfitOrder"] = None
        self.trades[0]["stopLossOrder"] = None
        self.ams.run_checks()
        expected_args = {
            "take_profit_price": D("1.07934"),
            "stop_loss_price": D("1.05276"),
        }
        self.check_violation(
            "protection",
            handle_violation.call_args_list,
            "close",
            [self.trades[0]],
            expected_args,
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_protection_violated_absent_not_required(self, handle_violation):
        self.strategy.order_settings.take_profit_percent = 0
        self.strategy.order_settings.stop_loss_percent = 0
        self.strategy.order_settings.save()
        self.trades[0]["takeProfitOrder"] = None
        self.trades[0]["stopLossOrder"] = None
        self.ams.run_checks()
        self.assertEqual(handle_violation.call_count, 0)

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_asset_groups_violated(self, handle_violation):
        asset_group = AssetGroup.objects.create(
            user=self.user,
            name="Test Asset Group",
        )
        AssetRule.objects.create(
            user=self.user,
            asset="USD",
            group=asset_group,
            status=2,  # Bullish
        )
        self.strategy.asset_group = asset_group
        self.strategy.save()
        self.ams.run_checks()
        self.check_violation(
            "asset_group",
            handle_violation.call_args_list,
            "close",
            self.trades,  # All trades should be closed, since all are USD quote
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_asset_groups_violated_invert(self, handle_violation):
        self.trades[0]["side"] = "short"
        self.trades[1]["side"] = "short"
        self.amend_tp_sl_flip_side()
        asset_group = AssetGroup.objects.create(
            user=self.user,
            name="Test Asset Group",
        )
        AssetRule.objects.create(
            user=self.user,
            asset="USD",
            group=asset_group,
            # Opposite of the bullish status=2 rule in test_asset_groups_violated,
            # presumably Bearish -- TODO confirm against AssetRule status choices.
            status=3,
        )
        self.strategy.asset_group = asset_group
        self.strategy.save()
        self.ams.run_checks()
        self.check_violation(
            "asset_group",
            handle_violation.call_args_list,
            "close",
            self.trades,  # All trades should be closed, since all are USD quote
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_crossfilter_violated_side(self, handle_violation):
        self.trades[1]["side"] = "short"
        self.amend_tp_sl_flip_side()
        self.ams.run_checks()
        self.check_violation(
            "crossfilter",
            handle_violation.call_args_list,
            "close",
            [self.trades[1]],  # Only close newer trade
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_crossfilter_violated_side_multiple(self, handle_violation):
        # The added shorts occupy self.trades[2:]
        self.add_trade("20085", "EUR_USD", "short", "2023-02-13T12:39:06.302917985Z")
        self.add_trade("20086", "EUR_USD", "short", "2023-02-13T13:39:07.302917985Z")
        self.add_trade("20087", "EUR_USD", "short", "2023-02-13T14:39:06.302917985Z")
        self.amend_tp_sl_flip_side()
        self.ams.run_checks()
        self.check_violation(
            "crossfilter",
            handle_violation.call_args_list,
            "close",
            self.trades[2:],  # Only close newer trades
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_crossfilter_violated_symbol(self, handle_violation):
        # Change symbol to conflict with long on EUR_USD
        self.trades[1]["symbol"] = "USD_EUR"
        self.ams.run_checks()
        self.check_violation(
            "crossfilter",
            handle_violation.call_args_list,
            "close",
            [self.trades[1]],  # Only close newer trade
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_crossfilter_violated_symbol_multiple(self, handle_violation):
        # The added USD_EUR longs occupy self.trades[2:]
        self.add_trade("20085", "USD_EUR", "long", "2023-02-13T12:39:06.302917985Z")
        self.add_trade("20086", "USD_EUR", "long", "2023-02-13T13:39:06.302917985Z")
        self.add_trade("20087", "USD_EUR", "long", "2023-02-13T14:39:06.302917985Z")
        self.ams.run_checks()
        self.check_violation(
            "crossfilter",
            handle_violation.call_args_list,
            "close",
            self.trades[2:],  # Only close newer trades
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_max_open_trades_violated(self, handle_violation):
        # 2 baseline + 9 added = 11 open trades; only the newest should close.
        for x in range(9):
            self.add_trade(
                str(x),
                f"EUR_USD{x}",  # Vary symbol to prevent max open trades per symbol
                "long",
                f"2023-02-13T12:39:1{x}.302917985Z",
            )
        self.ams.run_checks()
        self.check_violation(
            "max_open_trades",
            handle_violation.call_args_list,
            "close",
            self.trades[10:],  # Only close newer trades
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_max_open_trades_per_symbol_violated(self, handle_violation):
        # 2 baseline + 4 added = 6 EUR_USD trades; only the newest should close.
        for x in range(4):
            self.add_trade(
                str(x),
                "EUR_USD",
                "long",
                f"2023-02-13T12:39:1{x}.302917985Z",
            )
        self.ams.run_checks()
        self.check_violation(
            "max_open_trades_per_symbol",
            handle_violation.call_args_list,
            "close",
            self.trades[5:],  # Only close newer trades
        )

    # Mock position size as we have no way of checking the balance at the start
    # of the trade.
    # TODO: Fix this when we have a way of checking the balance at the start of
    # the trade.
    @patch(
        "core.trading.active_management.ActiveManagement.check_position_size",
        return_value=None,
    )
    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_max_loss_violated(self, handle_violation, check_position_size):
        self.balance = D("1")
        self.balance_usd = D("0.69")
        self.trades = []
        self.ams.run_checks()
        self.check_violation(
            "max_loss",
            handle_violation.call_args_list,
            "close",
            [None],  # Account-level violation: no trade is passed
        )

    @patch(
        "core.trading.active_management.ActiveManagement.check_position_size",
        return_value=None,
    )
    @patch(
        "core.trading.active_management.ActiveManagement.check_protection",
        return_value=None,
    )
    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_max_risk_violated(
        self, handle_violation, check_protection, check_position_size
    ):
        self.add_trade(
            "20085",
            "EUR_USD",
            "long",
            "2023-02-13T15:39:19.302917985Z",
        )
        # A distant stop loss on a large position makes this trade high-risk.
        self.trades[2]["stopLossOrder"]["price"] = "0.001"
        self.trades[2]["currentUnits"] = "13000"
        self.ams.run_checks()
        self.check_violation(
            "max_risk",
            handle_violation.call_args_list,
            "close",
            [self.trades[2]],
        )

    @patch(
        "core.trading.active_management.ActiveManagement.check_position_size",
        return_value=None,
    )
    @patch(
        "core.trading.active_management.ActiveManagement.check_protection",
        return_value=None,
    )
    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_max_risk_violated_multiple(
        self, handle_violation, check_protection, check_position_size
    ):
        self.add_trade(
            "20085",
            "EUR_USD",
            "long",
            "2023-02-13T15:39:19.302917985Z",
        )
        self.add_trade(
            "20086",
            "EUR_USD",
            "long",
            "2023-02-13T15:45:19.302917985Z",
        )
        self.trades[2]["stopLossOrder"]["price"] = "0.001"
        self.trades[2]["currentUnits"] = "13000"
        self.trades[3]["stopLossOrder"]["price"] = "0.001"
        self.trades[3]["currentUnits"] = "13000"
        self.ams.run_checks()
        self.check_violation(
            "max_risk",
            handle_violation.call_args_list,
            "close",
            [self.trades[2], self.trades[3]],
        )

    def test_max_risk_not_violated_after_adjusting_protection(self):
        """
        Ensure the max risk check is not violated after adjusting the protection.
        """
        # Report as skipped rather than silently passing until implemented.
        self.skipTest("TODO: not implemented yet")

    def test_max_risk_not_violated_after_adjusting_position_size(self):
        """
        Ensure the max risk check is not violated after adjusting the position size.
        """
        # Report as skipped rather than silently passing until implemented.
        self.skipTest("TODO: not implemented yet")