from decimal import Decimal as D
from unittest.mock import patch

from django.test import TestCase

from core.lib.schemas.oanda_s import parse_time
from core.models import (
    Account,
    ActiveManagementPolicy,
    AssetGroup,
    AssetRule,
    Hook,
    Signal,
    User,
)
from core.tests.helpers import StrategyMixin, SymbolPriceMock
from core.trading.active_management import ActiveManagement


class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
    """
    Tests for ActiveManagement: individual violation checks, action bookkeeping
    (add/reduce/run actions, bulk helpers) and "mutation" tests verifying that a
    trade closed by one check is not passed on to later checks.

    ActiveManagement.get_trades / get_balance are replaced with in-memory fakes
    backed by ``self.trades``, ``self.balance`` and ``self.balance_usd``.
    """

    def setUp(self):
        self.user = User.objects.create_user(
            username="testuser", email="test@example.com", password="test"
        )
        self.account = Account.objects.create(
            user=self.user,
            name="Test Account",
            exchange="fake",
            currency="USD",
            initial_balance=100000,
        )
        self.account.supported_symbols = ["EUR_USD", "EUR_XXX", "USD_EUR", "XXX_EUR"]
        self.account.save()
        # The mixins' setUp runs after the account exists; self.strategy used
        # below is expected to come from it.
        super().setUp()

        # Every check is configured to close the offending trade by default;
        # individual tests switch single checks to "adjust" where needed.
        self.active_management_policy = ActiveManagementPolicy.objects.create(
            user=self.user,
            name="Test Policy",
            when_trading_time_violated="close",
            when_trends_violated="close",
            when_position_size_violated="close",
            when_protection_violated="close",
            when_asset_groups_violated="close",
            when_max_open_trades_violated="close",
            when_max_open_trades_per_symbol_violated="close",
            when_max_loss_violated="close",
            when_max_risk_violated="close",
            when_crossfilter_violated="close",
        )
        self.strategy.active_management_policy = self.active_management_policy
        self.strategy.save()
        self.ams = ActiveManagement(self.strategy)

        self.trades = [
            # Monday at 11:38
            self._make_trade(
                "20083", "EUR_USD", "long", "2023-02-13T11:38:06.302917985Z"
            ),
            # Monday at 11:39
            self._make_trade(
                "20084", "EUR_USD", "long", "2023-02-13T11:39:06.302917985Z"
            ),
        ]
        self.balance = 100000
        self.balance_usd = 120000
        # Swap the data sources for in-memory fakes so tests control them.
        self.ams.get_trades = self.fake_get_trades
        self.ams.get_balance = self.fake_get_balance

    def _make_trade(self, trade_id, symbol, side, open_time):
        """
        Build an open 10-unit trade dict with TP 1.07934 / SL 1.05276.

        :param trade_id: value for the "id" field
        :param symbol: instrument, e.g. "EUR_USD"
        :param side: "long" or "short"
        :param open_time: ISO timestamp string; converted with parse_time
        :return: the trade dict
        """
        trade = {
            "id": trade_id,
            "symbol": symbol,
            "price": "1.06331",
            "openTime": open_time,
            "initialUnits": "10",
            "initialMarginRequired": "0.2966",
            "state": "OPEN",
            "currentUnits": "10",
            "realizedPL": "0.0000",
            "financing": "0.0000",
            "dividendAdjustment": "0.0000",
            "unrealizedPL": "-0.0008",
            "marginUsed": "0.2966",
            "takeProfitOrder": {"price": "1.07934"},
            "stopLossOrder": {"price": "1.05276"},
            "trailingStopLossOrder": None,
            "trailingStopValue": None,
            "side": side,
        }
        trade["openTime"] = parse_time(trade)
        return trade

    def get_ids(self, trades):
        """Return the "id" of every trade in *trades*, preserving order."""
        return [trade["id"] for trade in trades]

    def add_trade(self, id, symbol, side, open_time):
        """Append a new trade (see _make_trade) to self.trades."""
        self.trades.append(self._make_trade(id, symbol, side, open_time))

    def amend_tp_sl_flip_side(self):
        """
        Amend the take profit and stop loss orders of short trades so they sit
        on the correct side of the price. This lets the protection check pass,
        so each test exercises only one violation.
        """
        for trade in self.trades:
            if trade["side"] == "short":
                trade["stopLossOrder"]["price"] = "1.07386"
                trade["takeProfitOrder"]["price"] = "1.04728"

    def fake_get_trades(self):
        """Stand-in for ActiveManagement.get_trades backed by self.trades."""
        self.ams.trades = self.trades
        return self.trades

    def fake_get_balance(self, return_usd=None):
        """Stand-in for ActiveManagement.get_balance."""
        if return_usd:
            return self.balance_usd
        return self.balance

    def fake_get_currencies(self, symbols):
        pass

    def test_get_trades(self):
        trades = self.ams.get_trades()
        self.assertEqual(trades, self.trades)

    def test_get_balance(self):
        balance = self.ams.get_balance()
        self.assertEqual(balance, self.balance)

    def check_violation(
        self, violation, calls, expected_action, expected_trades, expected_args=None
    ):
        """
        Check that the violation was called with the expected action and trades.
        Matches the first argument of the call to the violation name.

        :param: violation: type of the violation to check against
        :param: calls: list of calls to the violation
        :param: expected_action: expected action to be called, close, notify, etc.
        :param: expected_trades: list of expected trades to be passed to the violation
        :param: expected_args: optional, expected args to be passed to the violation
        """
        self.assertEqual(len(calls), len(expected_trades))
        calls = list(calls)
        violation_calls = [call for call in calls if call[0][0] == violation]
        self.assertEqual(len(violation_calls), len(expected_trades))

        if all(expected_trades):
            expected_trades = self.get_ids(expected_trades)

        for call in violation_calls:
            # Ensure the correct action has been called, like close
            self.assertEqual(call[0][1], expected_action)
            # Ensure the correct trade has been passed to the violation
            self.assertIn(call[0][2], expected_trades)
            if expected_args:
                _, kwargs = call
                self.assertEqual(kwargs, expected_args)

    def test_run_checks(self):
        self.ams.run_checks()
        self.assertEqual(len(self.ams.actions), 0)

    def test_trading_time_violated(self):
        self.trades[0]["openTime"] = "2023-02-17T11:38:06.302917Z"  # Friday
        self.ams.run_checks()
        self.assertEqual(
            self.ams.actions,
            {"close": [{"id": "20083", "check": "trading_time", "extra": {}}]},
        )

    def create_hook_signal(self):
        """Create a Hook and a "trend" Signal attached to it for self.user."""
        hook = Hook.objects.create(
            user=self.user,
            name="Test Hook",
        )
        signal = Signal.objects.create(
            user=self.user,
            name="Test Signal",
            hook=hook,
            type="trend",
        )
        return signal

    def test_trends_violated(self):
        signal = self.create_hook_signal()
        self.strategy.trend_signals.set([signal])
        self.strategy.trends = {"EUR_USD": "sell"}
        self.amend_tp_sl_flip_side()
        self.strategy.save()
        self.ams.run_checks()
        self.assertEqual(
            self.ams.actions,
            {
                "close": [
                    {"id": "20084", "check": "trends", "extra": {}},
                    {"id": "20083", "check": "trends", "extra": {}},
                ]
            },
        )

    def test_trends_violated_none(self):
        signal = self.create_hook_signal()
        self.strategy.trend_signals.set([signal])
        self.strategy.trends = {"EUR_USD": "buy"}
        self.strategy.save()
        self.ams.run_checks()
        self.assertEqual(self.ams.actions, {})

    # Mock crossfilter here since we want to allow this conflict in order to test that
    # trends only close trades that are in the wrong direction
    @patch(
        "core.trading.active_management.ActiveManagement.check_crossfilter",
        return_value=None,
    )
    def test_trends_violated_partial(self, check_crossfilter):
        signal = self.create_hook_signal()
        self.strategy.trend_signals.set([signal])
        self.strategy.trends = {"EUR_USD": "sell"}
        self.strategy.save()
        # Change the side of the first trade to match the trends
        self.trades[0]["side"] = "short"
        self.amend_tp_sl_flip_side()
        self.ams.run_checks()
        self.assertEqual(
            self.ams.actions,
            {"close": [{"id": "20084", "check": "trends", "extra": {}}]},
        )

    def test_position_size_violated(self):
        self.trades[0]["currentUnits"] = "100000"
        self.ams.run_checks()
        self.assertEqual(
            self.ams.actions,
            {
                "close": [
                    {
                        "id": "20083",
                        "check": "position_size",
                        "extra": {"size": D("500.000")},
                    }
                ]
            },
        )

    # Placeholders: report as skipped rather than silently passing.
    def test_position_size_violated_increase_only(self):
        self.skipTest("Not implemented yet")

    def test_position_size_violated_decrease_only(self):
        self.skipTest("Not implemented yet")

    def test_position_size_violated_increase_decrease(self):
        self.skipTest("Not implemented yet")

    def test_protection_violated(self):
        self.trades[0]["takeProfitOrder"] = {"price": "0.0001"}
        self.trades[0]["stopLossOrder"] = {"price": "0.0001"}
        self.ams.run_checks()
        expected_args = {
            "take_profit_price": D("1.07934"),
            "stop_loss_price": D("1.05276"),
        }
        self.assertEqual(
            self.ams.actions,
            {
                "close": [
                    {
                        "id": "20083",
                        "check": "protection",
                        "extra": {
                            **expected_args,
                        },
                    }
                ]
            },
        )

    def test_protection_violated_absent(self):
        self.trades[0]["takeProfitOrder"] = None
        self.trades[0]["stopLossOrder"] = None
        self.ams.run_checks()
        expected_args = {
            "take_profit_price": D("1.07934"),
            "stop_loss_price": D("1.05276"),
        }
        self.assertEqual(
            self.ams.actions,
            {
                "close": [
                    {
                        "id": "20083",
                        "check": "protection",
                        "extra": {
                            **expected_args,
                        },
                    }
                ]
            },
        )

    def test_protection_violated_absent_not_required(self):
        self.strategy.order_settings.take_profit_percent = 0
        self.strategy.order_settings.stop_loss_percent = 0
        self.strategy.order_settings.save()
        self.trades[0]["takeProfitOrder"] = None
        self.trades[0]["stopLossOrder"] = None
        self.ams.run_checks()
        self.assertEqual(self.ams.actions, {})

    def test_asset_groups_violated(self):
        asset_group = AssetGroup.objects.create(
            user=self.user,
            name="Test Asset Group",
        )
        AssetRule.objects.create(
            user=self.user,
            asset="USD",
            group=asset_group,
            status=2,  # Bullish
        )
        self.strategy.asset_group = asset_group
        self.strategy.save()
        self.ams.run_checks()
        self.assertEqual(
            self.ams.actions,
            {
                "close": [
                    {"id": "20084", "check": "asset_group", "extra": {}},
                    {"id": "20083", "check": "asset_group", "extra": {}},
                ]
            },
        )

    def test_asset_groups_violated_invert(self):
        self.trades[0]["side"] = "short"
        self.trades[1]["side"] = "short"
        self.amend_tp_sl_flip_side()
        asset_group = AssetGroup.objects.create(
            user=self.user,
            name="Test Asset Group",
        )
        AssetRule.objects.create(
            user=self.user,
            asset="USD",
            group=asset_group,
            status=3,  # Bearish (the opposite of the bullish rule above)
        )
        self.strategy.asset_group = asset_group
        self.strategy.save()
        self.ams.run_checks()
        self.assertEqual(
            self.ams.actions,
            {
                "close": [
                    {"id": "20084", "check": "asset_group", "extra": {}},
                    {"id": "20083", "check": "asset_group", "extra": {}},
                ]
            },
        )

    def test_crossfilter_violated_side(self):
        self.trades[1]["side"] = "short"
        self.amend_tp_sl_flip_side()
        self.ams.run_checks()
        self.assertEqual(
            self.ams.actions,
            {"close": [{"id": "20084", "check": "crossfilter", "extra": {}}]},
        )

    def test_crossfilter_violated_side_multiple(self):
        # index 2
        self.add_trade("20085", "EUR_USD", "short", "2023-02-13T12:39:06.302917985Z")
        self.add_trade("20086", "EUR_USD", "short", "2023-02-13T13:39:07.302917985Z")
        self.add_trade("20087", "EUR_USD", "short", "2023-02-13T14:39:06.302917985Z")
        self.amend_tp_sl_flip_side()
        self.ams.run_checks()
        self.assertEqual(
            self.ams.actions,
            {
                "close": [
                    {"id": "20087", "check": "crossfilter", "extra": {}},
                    {"id": "20086", "check": "crossfilter", "extra": {}},
                    {"id": "20085", "check": "crossfilter", "extra": {}},
                ]
            },
        )

    def test_crossfilter_violated_symbol(self):
        # Change symbol to conflict with long on EUR_USD
        self.trades[1]["symbol"] = "USD_EUR"
        self.ams.run_checks()
        self.assertEqual(
            self.ams.actions,
            {"close": [{"id": "20084", "check": "crossfilter", "extra": {}}]},
        )

    def test_crossfilter_violated_symbol_multiple(self):
        # index 2
        self.add_trade("20085", "USD_EUR", "long", "2023-02-13T12:39:06.302917985Z")
        self.add_trade("20086", "USD_EUR", "long", "2023-02-13T13:39:06.302917985Z")
        self.add_trade("20087", "USD_EUR", "long", "2023-02-13T14:39:06.302917985Z")
        self.ams.run_checks()
        self.assertEqual(
            self.ams.actions,
            {
                "close": [
                    {"id": "20087", "check": "crossfilter", "extra": {}},
                    {"id": "20086", "check": "crossfilter", "extra": {}},
                    {"id": "20085", "check": "crossfilter", "extra": {}},
                ]
            },
        )

    def test_max_open_trades_violated(self):
        for x in range(9):
            self.add_trade(
                str(x),
                f"EUR_USD{x}",  # Vary symbol to prevent max open trades per symbol
                "long",
                f"2023-02-13T12:39:1{x}.302917985Z",
            )
        self.ams.run_checks()
        self.assertEqual(
            self.ams.actions,
            {"close": [{"id": "8", "check": "max_open_trades", "extra": {}}]},
        )

    def test_max_open_trades_per_symbol_violated(self):
        for x in range(4):
            self.add_trade(
                str(x),
                "EUR_USD",
                "long",
                f"2023-02-13T12:39:1{x}.302917985Z",
            )
        self.ams.run_checks()
        self.assertEqual(
            self.ams.actions,
            {
                "close": [
                    {"id": "3", "check": "max_open_trades_per_symbol", "extra": {}}
                ]
            },
        )

    # Mock position size as we have no way of checking the balance at the start of the
    # trade.
    # TODO: Fix this when we have a way of checking the balance at the start of the
    # trade.
    @patch(
        "core.trading.active_management.ActiveManagement.check_position_size",
        return_value=None,
    )
    def test_max_loss_violated(self, check_position_size):
        self.balance = D("1")
        self.balance_usd = D("0.69")
        # fake_get_trades reads self.trades at call time, so rebinding works.
        self.trades = []
        self.ams.run_checks()
        self.assertEqual(
            self.ams.actions,
            {"close": [{"id": None, "check": "max_loss", "extra": {}}]},
        )

    @patch(
        "core.trading.active_management.ActiveManagement.check_position_size",
        return_value=None,
    )
    @patch(
        "core.trading.active_management.ActiveManagement.check_protection",
        return_value=None,
    )
    def test_max_risk_violated(self, check_protection, check_position_size):
        self.add_trade(
            "20085",
            "EUR_USD",
            "long",
            "2023-02-13T15:39:19.302917985Z",
        )
        self.trades[2]["stopLossOrder"]["price"] = "0.001"
        self.trades[2]["currentUnits"] = "13000"
        self.ams.run_checks()
        self.assertEqual(
            self.ams.actions,
            {"close": [{"id": "20085", "check": "max_risk", "extra": {}}]},
        )

    @patch(
        "core.trading.active_management.ActiveManagement.check_position_size",
        return_value=None,
    )
    @patch(
        "core.trading.active_management.ActiveManagement.check_protection",
        return_value=None,
    )
    def test_max_risk_violated_multiple(self, check_protection, check_position_size):
        self.add_trade(
            "20085",
            "EUR_USD",
            "long",
            "2023-02-13T15:39:19.302917985Z",
        )
        self.add_trade(
            "20086",
            "EUR_USD",
            "long",
            "2023-02-13T15:45:19.302917985Z",
        )
        self.trades[2]["stopLossOrder"]["price"] = "0.001"
        self.trades[2]["currentUnits"] = "13000"
        self.trades[3]["stopLossOrder"]["price"] = "0.001"
        self.trades[3]["currentUnits"] = "13000"
        self.ams.run_checks()
        self.assertEqual(
            self.ams.actions,
            {
                "close": [
                    {"id": "20086", "check": "max_risk", "extra": {}},
                    {"id": "20085", "check": "max_risk", "extra": {}},
                ]
            },
        )

    @patch(
        "core.trading.active_management.ActiveManagement.add_action",
    )
    def test_handle_violation(self, add_action):
        self.ams.handle_violation("max_loss", "close", "trade_id")
        self.ams.handle_violation("position_size", "adjust", "trade_id2", size=1000)

        self.assertEqual(add_action.call_count, 2)
        add_action.assert_any_call("close", "max_loss", "trade_id")
        add_action.assert_any_call("adjust", "position_size", "trade_id2", size=1000)

    def test_add_action(self):
        protection_args = {
            "take_profit_price": D("1.07934"),
            "stop_loss_price": D("1.05276"),
        }
        self.ams.add_action("close", "trading_time", "fake_trade_id")
        self.ams.add_action("close", "protection", "fake_trade_id2", **protection_args)

        self.assertEqual(
            self.ams.actions,
            {
                "close": [
                    {"id": "fake_trade_id", "check": "trading_time", "extra": {}},
                    {
                        "id": "fake_trade_id2",
                        "check": "protection",
                        "extra": protection_args,
                    },
                ],
            },
        )

    def test_reduce_actions(self):
        """
        Test that closing actions precede adjusting actions.
        """
        self.ams.add_action("close", "trading_time", "fake_trade_id")
        self.ams.add_action("adjust", "position_size", "fake_trade_id", size=1000)

        self.assertEqual(len(self.ams.actions["close"]), 1)
        self.assertEqual(len(self.ams.actions["adjust"]), 1)

        self.ams.reduce_actions()

        self.assertEqual(len(self.ams.actions["close"]), 1)
        self.assertEqual(len(self.ams.actions["adjust"]), 0)

    @patch("core.trading.active_management.ActiveManagement.bulk_close_trades")
    @patch("core.trading.active_management.ActiveManagement.bulk_notify")
    def test_run_actions(self, bulk_notify, bulk_close_trades):
        protection_args = {
            "take_profit_price": D("1.07934"),
            "stop_loss_price": D("1.05276"),
        }
        self.ams.add_action("close", "trading_time", "fake_trade_id")
        self.ams.add_action("close", "protection", "fake_trade_id2", **protection_args)
        self.ams.run_actions()
        expected_action_cast = [
            {"id": "fake_trade_id", "check": "trading_time", "extra": {}},
            {
                "id": "fake_trade_id2",
                "check": "protection",
                "extra": {
                    **protection_args,
                },
            },
        ]
        bulk_notify.assert_called_once_with(
            "close",
            expected_action_cast,
        )
        bulk_close_trades.assert_called_once_with(
            ["fake_trade_id", "fake_trade_id2"],
        )

    @patch("core.trading.active_management.ActiveManagement.bulk_close_trades")
    @patch("core.trading.active_management.ActiveManagement.bulk_notify")
    def test_run_actions_notify_only(self, bulk_notify, bulk_close_trades):
        protection_args = {
            "take_profit_price": D("1.07934"),
            "stop_loss_price": D("1.05276"),
        }
        self.ams.add_action("notify", "trading_time", "fake_trade_id")
        self.ams.add_action("notify", "protection", "fake_trade_id2", **protection_args)
        self.ams.run_actions()
        expected_action_cast = [
            {"id": "fake_trade_id", "check": "trading_time", "extra": {}},
            {
                "id": "fake_trade_id2",
                "check": "protection",
                "extra": {
                    **protection_args,
                },
            },
        ]
        bulk_notify.assert_called_once_with(
            "notify",
            expected_action_cast,
        )
        bulk_close_trades.assert_not_called()

    @patch("core.trading.active_management.ActiveManagement.bulk_close_trades")
    @patch("core.trading.active_management.ActiveManagement.bulk_adjust")
    @patch("core.trading.active_management.ActiveManagement.bulk_notify")
    def test_run_actions_notify_adjust_only(
        self, bulk_notify, bulk_adjust, bulk_close_trades
    ):
        protection_args = {
            "take_profit_price": D("1.07934"),
            "stop_loss_price": D("1.05276"),
        }
        self.ams.add_action("adjust", "position_size", "fake_trade_id", size=1000)
        self.ams.add_action("adjust", "protection", "fake_trade_id2", **protection_args)
        self.ams.run_actions()
        expected_action_cast = [
            {"id": "fake_trade_id", "check": "position_size", "extra": {"size": 1000}},
            {
                "id": "fake_trade_id2",
                "check": "protection",
                "extra": {
                    **protection_args,
                },
            },
        ]
        bulk_notify.assert_called_once_with(
            "adjust",
            expected_action_cast,
        )
        bulk_adjust.assert_called_once_with(expected_action_cast)
        bulk_close_trades.assert_not_called()

    @patch("core.trading.active_management.ActiveManagement.adjust_position_size")
    @patch("core.trading.active_management.ActiveManagement.adjust_protection")
    def test_bulk_adjust(self, adjust_protection, adjust_position_size):
        expected_protection = {"take_profit_price": 1.10, "stop_loss_price": 1.05}
        cast_list = [
            {"id": "id1", "check": "position_size", "extra": {"size": 1000}},
            {"id": "id2", "check": "protection", "extra": expected_protection},
        ]
        self.ams.bulk_adjust(cast_list)
        adjust_position_size.assert_called_once_with("id1", 1000)
        adjust_protection.assert_called_once_with("id2", expected_protection)

    @patch("core.trading.active_management.ActiveManagement.close_trade")
    def test_bulk_close_trades(self, close_trade):
        self.ams.bulk_close_trades(["id1", "id2"])
        self.assertEqual(close_trade.call_count, 2)
        close_trade.assert_any_call("id1")
        close_trade.assert_any_call("id2")

    @patch("core.trading.active_management.sendmsg")
    def test_bulk_notify_plain(self, sendmsg):
        self.ams.bulk_notify("close", [{"id": "id1", "check": "check1", "extra": {}}])
        sendmsg.assert_called_once_with(
            self.user,
            "ACTION: 'close' on trade ID 'id1'\nVIOLATION: 'check1'\n=========\n",
            title="AMS: close",
        )

    @patch("core.trading.active_management.sendmsg")
    def test_bulk_notify_extra(self, sendmsg):
        self.ams.bulk_notify(
            "close", [{"id": "id1", "check": "check1", "extra": {"field1": "value1"}}]
        )
        sendmsg.assert_called_once_with(
            self.user,
            (
                "ACTION: 'close' on trade ID 'id1'\nVIOLATION: 'check1'\nEXTRA:"
                " field1: value1\n=========\n"
            ),
            title="AMS: close",
        )

    @patch("core.trading.active_management.ActiveManagement.check_protection")
    def test_position_size_reduced(self, check_protection):
        self.active_management_policy.when_position_size_violated = "adjust"
        self.active_management_policy.save()

        self.trades[0]["currentUnits"] = "100000"
        self.ams.run_checks()

        # The adjusted trade is what the next check (protection) receives.
        call_args = check_protection.call_args[0][0]
        self.assertEqual(call_args["amount"], 500)
        self.assertEqual(call_args["units"], 500)

    @patch("core.trading.active_management.ActiveManagement.check_asset_groups")
    def test_protection_added(self, check_asset_groups):
        self.active_management_policy.when_protection_violated = "adjust"
        self.active_management_policy.save()

        self.trades[0]["takeProfitOrder"] = None
        self.trades[0]["stopLossOrder"] = None
        self.ams.run_checks()

        call_args = check_asset_groups.call_args[0][0]
        self.assertEqual(call_args["take_profit_price"], D("1.07934"))
        self.assertEqual(call_args["stop_loss_price"], D("1.05276"))

    @patch("core.trading.active_management.ActiveManagement.check_asset_groups")
    def test_protection_amended(self, check_asset_groups):
        self.active_management_policy.when_protection_violated = "adjust"
        self.active_management_policy.save()

        self.trades[0]["takeProfitOrder"] = {"price": "0.0001"}
        self.trades[0]["stopLossOrder"] = {"price": "0.0001"}
        self.ams.run_checks()

        call_args = check_asset_groups.call_args[0][0]
        self.assertEqual(call_args["take_profit_price"], D("1.07934"))
        self.assertEqual(call_args["stop_loss_price"], D("1.05276"))

    @patch("core.trading.active_management.ActiveManagement.close_trade")
    def test_max_risk_not_violated_after_adjusting_protection(self, close_trade):
        """
        Ensure the max risk check is not violated after adjusting the protection.
        """
        self.active_management_policy.when_protection_violated = "adjust"
        self.active_management_policy.save()
        self.add_trade(
            "20085",
            "EUR_USD",
            "long",
            "2023-02-13T15:39:19.302917985Z",
        )
        self.trades[2]["stopLossOrder"]["price"] = "0.001"
        self.trades[2]["currentUnits"] = "13000"
        self.ams.run_checks()
        self.assertEqual(close_trade.call_count, 0)

    @patch("core.trading.active_management.ActiveManagement.close_trade")
    def test_max_risk_not_violated_after_adjusting_position_size(self, close_trade):
        """
        Ensure the max risk check is not violated after adjusting the position size.
        """
        self.active_management_policy.when_position_size_violated = "adjust"
        self.active_management_policy.save()
        self.trades[0]["currentUnits"] = "100000"
        self.ams.run_checks()
        self.assertEqual(close_trade.call_count, 0)

    # "*_mutation" tests: trigger one violation under the "close" policy and patch
    # a later check to verify the closed trade is no longer passed to it.

    @patch("core.trading.active_management.ActiveManagement.check_crossfilter")
    @patch("core.trading.active_management.ActiveManagement.check_trends")
    def test_trading_time_mutation(self, check_trends, check_crossfilter):
        self.trades[0]["openTime"] = "2023-02-17T11:38:06.302917Z"  # Friday
        self.ams.run_checks()
        self.assertEqual(check_trends.call_count, 1)
        call_args = check_trends.call_args[0][0]
        # Only run trends on the second trade
        self.assertEqual(call_args["id"], "20084")

        # Same for crossfilter
        crossfilter_call_args = check_crossfilter.call_args[0][0]
        self.assertEqual(len(crossfilter_call_args), 1)
        self.assertEqual(crossfilter_call_args[0]["id"], "20084")

    @patch("core.trading.active_management.ActiveManagement.check_crossfilter")
    @patch("core.trading.active_management.ActiveManagement.check_position_size")
    def test_check_trends_mutation(self, check_position_size, check_crossfilter):
        signal = self.create_hook_signal()
        self.strategy.trend_signals.set([signal])
        self.strategy.trends = {"EUR_USD": "sell"}
        self.amend_tp_sl_flip_side()
        self.strategy.save()
        self.ams.run_checks()

        self.assertEqual(check_position_size.call_count, 0)

        crossfilter_call_args = check_crossfilter.call_args[0][0]
        self.assertEqual(len(crossfilter_call_args), 0)

    @patch("core.trading.active_management.ActiveManagement.check_crossfilter")
    @patch("core.trading.active_management.ActiveManagement.check_protection")
    def test_check_position_size_mutation(self, check_protection, check_crossfilter):
        self.trades[0]["currentUnits"] = "100000"
        self.ams.run_checks()
        self.assertEqual(check_protection.call_count, 1)
        call_args = check_protection.call_args[0][0]
        # Only run protection on the second trade
        self.assertEqual(call_args["id"], "20084")

        # Same for crossfilter
        crossfilter_call_args = check_crossfilter.call_args[0][0]
        self.assertEqual(len(crossfilter_call_args), 1)
        self.assertEqual(crossfilter_call_args[0]["id"], "20084")

    @patch("core.trading.active_management.ActiveManagement.check_crossfilter")
    @patch("core.trading.active_management.ActiveManagement.check_asset_groups")
    def test_check_protection_mutation(self, check_asset_groups, check_crossfilter):
        # Invalid protection on the first trade; the policy closes it.
        self.trades[0]["takeProfitOrder"] = {"price": "0.0001"}
        self.trades[0]["stopLossOrder"] = {"price": "0.0001"}
        self.ams.run_checks()
        self.assertEqual(check_asset_groups.call_count, 1)
        call_args = check_asset_groups.call_args[0][0]
        # Only run asset groups on the second trade
        self.assertEqual(call_args["id"], "20084")

        # Same for crossfilter
        crossfilter_call_args = check_crossfilter.call_args[0][0]
        self.assertEqual(len(crossfilter_call_args), 1)
        self.assertEqual(crossfilter_call_args[0]["id"], "20084")

    # This may look similar but check_crossfilter is called with the whole trade list.
    # Check that the trade that is removed from the list is not checked.
    @patch("core.trading.active_management.ActiveManagement.check_crossfilter")
    def test_check_asset_groups_mutation(self, check_crossfilter):
        asset_group = AssetGroup.objects.create(
            user=self.user,
            name="Test Asset Group",
        )
        AssetRule.objects.create(
            user=self.user,
            asset="USD",
            group=asset_group,
            status=2,  # Bullish
        )
        self.strategy.asset_group = asset_group
        self.strategy.save()
        self.ams.run_checks()
        check_crossfilter.assert_called_once_with([])

    @patch("core.trading.active_management.ActiveManagement.check_max_open_trades")
    def test_check_crossfilter_mutation(self, check_max_open_trades):
        self.trades[1]["side"] = "short"
        self.amend_tp_sl_flip_side()
        self.ams.run_checks()
        self.assertEqual(check_max_open_trades.call_count, 1)
        call_args = check_max_open_trades.call_args[0][0]
        self.assertEqual(len(call_args), 1)
        self.assertEqual(call_args[0]["id"], "20083")

    @patch(
        # Target split across two literals to stay within the line length limit
        (
            "core.trading.active_management.ActiveManagement."
            "check_max_open_trades_per_symbol"
        )
    )
    def test_check_max_open_trades_mutation(self, check_max_open_trades_per_symbol):
        for x in range(9):
            self.add_trade(
                str(x),
                f"EUR_USD{x}",  # Vary symbol to prevent max open trades per symbol
                "long",
                f"2023-02-13T12:39:1{x}.302917985Z",
            )
        self.ams.run_checks()
        self.assertEqual(check_max_open_trades_per_symbol.call_count, 1)
        call_args = check_max_open_trades_per_symbol.call_args[0][0]
        called_with_ids = [x["id"] for x in call_args]
        self.assertListEqual(
            called_with_ids, ["20083", "20084", "0", "1", "2", "3", "4", "5", "6", "7"]
        )

    @patch("core.trading.active_management.ActiveManagement.check_max_loss")
    def test_check_max_open_trades_per_symbol_mutation(self, check_max_loss):
        for x in range(4):
            self.add_trade(
                str(x),
                "EUR_USD",
                "long",
                f"2023-02-13T12:39:1{x}.302917985Z",
            )
        self.ams.run_checks()
        self.assertEqual(check_max_loss.call_count, 1)
        call_args = check_max_loss.call_args[0][0]
        called_with_ids = [x["id"] for x in call_args]
        self.assertListEqual(called_with_ids, ["20083", "20084", "0", "1", "2"])

    @patch("core.trading.active_management.ActiveManagement.check_max_risk")
    def test_check_max_loss_mutation(self, check_max_risk):
        self.balance = D("1")
        self.balance_usd = D("0.69")
        self.trades = []
        self.ams.run_checks()
        self.assertEqual(check_max_risk.call_count, 1)
        call_args = check_max_risk.call_args[0][0]
        called_with_ids = [x["id"] for x in call_args]
        self.assertListEqual(called_with_ids, [])

    def test_check_max_risk_mutation(self):
        """
        This cannot be tested as there are no hooks after it.
        """
        self.skipTest("This cannot be tested as there are no hooks after it.")