From 911ccde37b3781ed143cd88fb257e6090abba804 Mon Sep 17 00:00:00 2001 From: Mark Veidemanis Date: Sat, 18 Feb 2023 21:23:59 +0000 Subject: [PATCH] Implement trade mutation pipeline and active management actions --- core/tests/trading/test_active_management.py | 543 ++++++++++++++----- core/trading/active_management.py | 105 +++- 2 files changed, 490 insertions(+), 158 deletions(-) diff --git a/core/tests/trading/test_active_management.py b/core/tests/trading/test_active_management.py index 4b01c90..fa9b983 100644 --- a/core/tests/trading/test_active_management.py +++ b/core/tests/trading/test_active_management.py @@ -189,19 +189,20 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): # Ensure the correct trade has been passed to the violation self.assertIn(call[0][2], expected_trades) if expected_args: - self.assertEqual(call[0][3], expected_args) + _, kwargs = call + self.assertEqual(kwargs, expected_args) + # self.assertEqual(call[0][3], expected_args) - @patch("core.trading.active_management.ActiveManagement.handle_violation") - def test_run_checks(self, handle_violation): + def test_run_checks(self): self.ams.run_checks() - self.assertEqual(handle_violation.call_count, 0) + self.assertEqual(len(self.ams.actions), 0) - @patch("core.trading.active_management.ActiveManagement.handle_violation") - def test_trading_time_violated(self, handle_violation): + def test_trading_time_violated(self): self.trades[0]["openTime"] = "2023-02-17T11:38:06.302917Z" # Friday self.ams.run_checks() - self.check_violation( - "trading_time", handle_violation.call_args_list, "close", [self.trades[0]] + self.assertEqual( + self.ams.actions, + {"close": [{"id": "20083", "check": "trading_time", "extra": {}}]}, ) def create_hook_signal(self): @@ -217,8 +218,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): ) return signal - @patch("core.trading.active_management.ActiveManagement.handle_violation") - def test_trends_violated(self, handle_violation): + def test_trends_violated(self): signal = self.create_hook_signal() self.strategy.trend_signals.set([signal]) self.strategy.trends = {"EUR_USD": "sell"} @@ -226,18 +226,24 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): self.strategy.save() self.ams.run_checks() - self.check_violation( - "trends", handle_violation.call_args_list, "close", self.trades + self.assertEqual( + self.ams.actions, + { + "close": [ + {"id": "20084", "check": "trends", "extra": {}}, + {"id": "20083", "check": "trends", "extra": {}}, + ] + }, ) - @patch("core.trading.active_management.ActiveManagement.handle_violation") - def test_trends_violated_none(self, handle_violation): + def test_trends_violated_none(self): signal = self.create_hook_signal() self.strategy.trend_signals.set([signal]) self.strategy.trends = {"EUR_USD": "buy"} self.strategy.save() self.ams.run_checks() - self.check_violation("trends", handle_violation.call_args_list, "close", []) + + self.assertEqual(self.ams.actions, {}) # Mock crossfilter here since we want to allow this conflict in order to test that # trends only close trades that are in the wrong direction @@ -245,8 +251,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): "core.trading.active_management.ActiveManagement.check_crossfilter", return_value=None, ) - @patch("core.trading.active_management.ActiveManagement.handle_violation") - def test_trends_violated_partial(self, handle_violation, check_crossfilter): + def test_trends_violated_partial(self, check_crossfilter): signal = self.create_hook_signal() self.strategy.trend_signals.set([signal]) self.strategy.trends = {"EUR_USD": "sell"} @@ -257,25 +262,53 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): self.amend_tp_sl_flip_side() self.ams.run_checks() - self.check_violation( - "trends", handle_violation.call_args_list, "close", [self.trades[1]] + self.assertEqual( + self.ams.actions, + {"close": [{"id": "20084", "check": "trends", "extra": {}}]}, ) - @patch("core.trading.active_management.ActiveManagement.handle_violation") - def test_position_size_violated(self, handle_violation): + def test_position_size_violated(self): self.trades[0]["currentUnits"] = "100000" self.ams.run_checks() + self.assertEqual( + self.ams.actions, + { + "close": [ + { + "id": "20083", + "check": "position_size", + "extra": {"size": D("500.000")}, + } + ] + }, + ) - self.check_violation( - "position_size", - handle_violation.call_args_list, - "close", - [self.trades[0]], - {"size": 500}, + def test_protection_violated(self): + self.trades[0]["takeProfitOrder"] = {"price": "0.0001"} + self.trades[0]["stopLossOrder"] = {"price": "0.0001"} + self.ams.run_checks() + + expected_args = { + "take_profit_price": D("1.07934"), + "stop_loss_price": D("1.05276"), + } + + self.assertEqual( + self.ams.actions, + { + "close": [ + { + "id": "20083", + "check": "protection", + "extra": { + **expected_args, + }, + } + ] + }, ) - @patch("core.trading.active_management.ActiveManagement.handle_violation") - def test_protection_violated_absent(self, handle_violation): + def test_protection_violated_absent(self): self.trades[0]["takeProfitOrder"] = None self.trades[0]["stopLossOrder"] = None self.ams.run_checks() @@ -284,16 +317,22 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): "take_profit_price": D("1.07934"), "stop_loss_price": D("1.05276"), } - self.check_violation( - "protection", - handle_violation.call_args_list, - "close", - [self.trades[0]], - expected_args, + self.assertEqual( + self.ams.actions, + { + "close": [ + { + "id": "20083", + "check": "protection", + "extra": { + **expected_args, + }, + } + ] + }, ) - @patch("core.trading.active_management.ActiveManagement.handle_violation") - def test_protection_violated_absent_not_required(self, handle_violation): + def test_protection_violated_absent_not_required(self): self.strategy.order_settings.take_profit_percent = 0 self.strategy.order_settings.stop_loss_percent = 0 self.strategy.order_settings.save() @@ -301,10 +340,9 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): self.trades[0]["stopLossOrder"] = None self.ams.run_checks() - self.assertEqual(handle_violation.call_count, 0) + self.assertEqual(self.ams.actions, {}) - @patch("core.trading.active_management.ActiveManagement.handle_violation") - def test_asset_groups_violated(self, handle_violation): + def test_asset_groups_violated(self): asset_group = AssetGroup.objects.create( user=self.user, name="Test Asset Group", @@ -319,15 +357,17 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): self.strategy.save() self.ams.run_checks() - self.check_violation( - "asset_group", - handle_violation.call_args_list, - "close", - self.trades, # All trades should be closed, since all are USD quote + self.assertEqual( + self.ams.actions, + { + "close": [ + {"id": "20084", "check": "asset_group", "extra": {}}, + {"id": "20083", "check": "asset_group", "extra": {}}, + ] + }, ) - @patch("core.trading.active_management.ActiveManagement.handle_violation") - def test_asset_groups_violated_invert(self, handle_violation): + def test_asset_groups_violated_invert(self): self.trades[0]["side"] = "short" self.trades[1]["side"] = "short" self.amend_tp_sl_flip_side() @@ -345,28 +385,27 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): self.strategy.save() self.ams.run_checks() - self.check_violation( - "asset_group", - handle_violation.call_args_list, - "close", - self.trades, # All trades should be closed, since all are USD quote + self.assertEqual( + self.ams.actions, + { + "close": [ + {"id": "20084", "check": "asset_group", "extra": {}}, + {"id": "20083", "check": "asset_group", "extra": {}}, + ] + }, ) - @patch("core.trading.active_management.ActiveManagement.handle_violation") - def test_crossfilter_violated_side(self, handle_violation): + def test_crossfilter_violated_side(self): self.trades[1]["side"] = "short" self.amend_tp_sl_flip_side() self.ams.run_checks() - self.check_violation( - "crossfilter", - handle_violation.call_args_list, - "close", - [self.trades[1]], # Only close newer trade + self.assertEqual( + self.ams.actions, + {"close": [{"id": "20084", "check": "crossfilter", "extra": {}}]}, ) - @patch("core.trading.active_management.ActiveManagement.handle_violation") - def test_crossfilter_violated_side_multiple(self, handle_violation): + def test_crossfilter_violated_side_multiple(self): self.add_trade( "20085", "EUR_USD", "short", "2023-02-13T12:39:06.302917985Z" ) # 2: @@ -375,28 +414,28 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): self.amend_tp_sl_flip_side() self.ams.run_checks() - self.check_violation( - "crossfilter", - handle_violation.call_args_list, - "close", - self.trades[2:], # Only close newer trades + self.assertEqual( + self.ams.actions, + { + "close": [ + {"id": "20087", "check": "crossfilter", "extra": {}}, + {"id": "20086", "check": "crossfilter", "extra": {}}, + {"id": "20085", "check": "crossfilter", "extra": {}}, + ] + }, ) - @patch("core.trading.active_management.ActiveManagement.handle_violation") - def test_crossfilter_violated_symbol(self, handle_violation): + def test_crossfilter_violated_symbol(self): # Change symbol to conflict with long on EUR_USD self.trades[1]["symbol"] = "USD_EUR" self.ams.run_checks() - self.check_violation( - "crossfilter", - handle_violation.call_args_list, - "close", - [self.trades[1]], # Only close newer trade + self.assertEqual( + self.ams.actions, + {"close": [{"id": "20084", "check": "crossfilter", "extra": {}}]}, ) - @patch("core.trading.active_management.ActiveManagement.handle_violation") - def test_crossfilter_violated_symbol_multiple(self, handle_violation): + def test_crossfilter_violated_symbol_multiple(self): self.add_trade( "20085", "USD_EUR", "long", "2023-02-13T12:39:06.302917985Z" ) # 2: @@ -404,15 +443,18 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): self.add_trade("20087", "USD_EUR", "long", "2023-02-13T14:39:06.302917985Z") self.ams.run_checks() - self.check_violation( - "crossfilter", - handle_violation.call_args_list, - "close", - self.trades[2:], # Only close newer trades + self.assertEqual( + self.ams.actions, + { + "close": [ + {"id": "20087", "check": "crossfilter", "extra": {}}, + {"id": "20086", "check": "crossfilter", "extra": {}}, + {"id": "20085", "check": "crossfilter", "extra": {}}, + ] + }, ) - @patch("core.trading.active_management.ActiveManagement.handle_violation") - def test_max_open_trades_violated(self, handle_violation): + def test_max_open_trades_violated(self): for x in range(9): self.add_trade( str(x), @@ -422,15 +464,12 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): ) self.ams.run_checks() - self.check_violation( - "max_open_trades", - handle_violation.call_args_list, - "close", - self.trades[10:], # Only close newer trades + self.assertEqual( + self.ams.actions, + {"close": [{"id": "8", "check": "max_open_trades", "extra": {}}]}, ) - @patch("core.trading.active_management.ActiveManagement.handle_violation") - def test_max_open_trades_per_symbol_violated(self, handle_violation): + def test_max_open_trades_per_symbol_violated(self): for x in range(4): self.add_trade( str(x), @@ -440,11 +479,14 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): ) self.ams.run_checks() - self.check_violation( - "max_open_trades_per_symbol", - handle_violation.call_args_list, - "close", - self.trades[5:], # Only close newer trades + + self.assertEqual( + self.ams.actions, + { + "close": [ + {"id": "3", "check": "max_open_trades_per_symbol", "extra": {}} + ] + }, ) # Mock position size as we have no way of checking the balance at the start of the @@ -456,8 +498,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): "core.trading.active_management.ActiveManagement.check_position_size", return_value=None, ) - @patch("core.trading.active_management.ActiveManagement.handle_violation") - def test_max_loss_violated(self, handle_violation, check_position_size): + def test_max_loss_violated(self, check_position_size): self.balance = D("1") self.balance_usd = D("0.69") @@ -465,11 +506,9 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): self.ams.run_checks() - self.check_violation( - "max_loss", - handle_violation.call_args_list, - "close", - [None], + self.assertEqual( + self.ams.actions, + {"close": [{"id": None, "check": "max_loss", "extra": {}}]}, ) @patch( @@ -480,10 +519,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): "core.trading.active_management.ActiveManagement.check_protection", return_value=None, ) - @patch("core.trading.active_management.ActiveManagement.handle_violation") - def test_max_risk_violated( - self, handle_violation, check_protection, check_position_size - ): + def test_max_risk_violated(self, check_protection, check_position_size): self.add_trade( "20085", "EUR_USD", @@ -495,11 +531,9 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): self.ams.run_checks() - self.check_violation( - "max_risk", - handle_violation.call_args_list, - "close", - [self.trades[2]], + self.assertEqual( + self.ams.actions, + {"close": [{"id": "20085", "check": "max_risk", "extra": {}}]}, ) @patch( @@ -510,10 +544,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): "core.trading.active_management.ActiveManagement.check_protection", return_value=None, ) - @patch("core.trading.active_management.ActiveManagement.handle_violation") - def test_max_risk_violated_multiple( - self, handle_violation, check_protection, check_position_size - ): + def test_max_risk_violated_multiple(self, check_protection, check_position_size): self.add_trade( "20085", "EUR_USD", @@ -533,11 +564,14 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): self.ams.run_checks() - self.check_violation( - "max_risk", - handle_violation.call_args_list, - "close", - [self.trades[2], self.trades[3]], + self.assertEqual( + self.ams.actions, + { + "close": [ + {"id": "20086", "check": "max_risk", "extra": {}}, + {"id": "20085", "check": "max_risk", "extra": {}}, + ] + }, ) @patch( @@ -574,7 +608,17 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): ) def test_reduce_actions(self): - pass + """ + Test that closing actions precede adjusting actions. + """ + self.ams.add_action("close", "trading_time", "fake_trade_id") + self.ams.add_action("adjust", "position_size", "fake_trade_id", size=1000) + self.assertEqual(len(self.ams.actions["close"]), 1) + self.assertEqual(len(self.ams.actions["adjust"]), 1) + + self.ams.reduce_actions() + self.assertEqual(len(self.ams.actions["close"]), 1) + self.assertEqual(len(self.ams.actions["adjust"]), 0) @patch("core.trading.active_management.ActiveManagement.bulk_close_trades") @patch("core.trading.active_management.ActiveManagement.bulk_notify") @@ -688,26 +732,263 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): close_trade.assert_any_call("id1") close_trade.assert_any_call("id2") - def test_bulk_notify(self): - pass + @patch("core.trading.active_management.sendmsg") + def test_bulk_notify_plain(self, sendmsg): + self.ams.bulk_notify("close", [{"id": "id1", "check": "check1", "extra": {}}]) + + sendmsg.assert_called_once_with( + self.user, + "ACTION: 'close' on trade ID 'id1'\nVIOLATION: 'check1'\n=========\n", + title="AMS: close", + ) + + @patch("core.trading.active_management.sendmsg") + def test_bulk_notify_extra(self, sendmsg): + self.ams.bulk_notify( + "close", [{"id": "id1", "check": "check1", "extra": {"field1": "value1"}}] + ) + + sendmsg.assert_called_once_with( + self.user, + ( + "ACTION: 'close' on trade ID 'id1'\nVIOLATION: 'check1'\nEXTRA:" + " field1: value1\n=========\n" + ), + title="AMS: close", + ) + + @patch("core.trading.active_management.ActiveManagement.check_protection") + def test_position_size_reduced(self, check_protection): + self.active_management_policy.when_position_size_violated = "adjust" + self.active_management_policy.save() + + self.trades[0]["currentUnits"] = "100000" + self.ams.run_checks() + + call_args = check_protection.call_args[0][0] + + self.assertEqual(call_args["amount"], 500) + self.assertEqual(call_args["units"], 500) + + @patch("core.trading.active_management.ActiveManagement.check_asset_groups") + def test_protection_added(self, check_asset_groups): + self.active_management_policy.when_protection_violated = "adjust" + self.active_management_policy.save() + self.trades[0]["takeProfitOrder"] = None + self.trades[0]["stopLossOrder"] = None + self.ams.run_checks() + + call_args = check_asset_groups.call_args[0][0] + self.assertEqual(call_args["take_profit_price"], D("1.07934")) + self.assertEqual(call_args["stop_loss_price"], D("1.05276")) + + @patch("core.trading.active_management.ActiveManagement.check_asset_groups") + def test_protection_amended(self, check_asset_groups): + self.active_management_policy.when_protection_violated = "adjust" + self.active_management_policy.save() + self.trades[0]["takeProfitOrder"] = {"price": "0.0001"} + self.trades[0]["stopLossOrder"] = {"price": "0.0001"} + self.ams.run_checks() + + call_args = check_asset_groups.call_args[0][0] + self.assertEqual(call_args["take_profit_price"], D("1.07934")) + self.assertEqual(call_args["stop_loss_price"], D("1.05276")) - def test_max_risk_not_violated_after_adjusting_protection(self): + @patch("core.trading.active_management.ActiveManagement.close_trade") + def test_max_risk_not_violated_after_adjusting_protection(self, close_trade): """ Ensure the max risk check is not violated after adjusting the protection. """ - pass + self.active_management_policy.when_protection_violated = "adjust" + self.active_management_policy.save() - def test_max_risk_not_violated_after_adjusting_position_size(self): + self.add_trade( + "20085", + "EUR_USD", + "long", + "2023-02-13T15:39:19.302917985Z", + ) + self.trades[2]["stopLossOrder"]["price"] = "0.001" + self.trades[2]["currentUnits"] = "13000" + + self.ams.run_checks() + + self.assertEqual(close_trade.call_count, 0) + + @patch("core.trading.active_management.ActiveManagement.close_trade") + def test_max_risk_not_violated_after_adjusting_position_size(self, close_trade): """ Ensure the max risk check is not violated after adjusting the position size. """ - pass + self.active_management_policy.when_position_size_violated = "adjust" + self.active_management_policy.save() - def test_position_size_reduced(self): - pass + self.trades[0]["currentUnits"] = "100000" + self.ams.run_checks() - def test_protection_added(self): - pass + self.assertEqual(close_trade.call_count, 0) + + @patch("core.trading.active_management.ActiveManagement.check_crossfilter") + @patch("core.trading.active_management.ActiveManagement.check_trends") + def test_trading_time_mutation(self, check_trends, check_crossfilter): + self.trades[0]["openTime"] = "2023-02-17T11:38:06.302917Z" # Friday + self.ams.run_checks() - def test_protection_amended(self): + self.assertEqual(check_trends.call_count, 1) + call_args = check_trends.call_args[0][0] + self.assertEqual( + call_args["id"], "20084" + ) # Only run trends on the second trade + + crossfilter_call_args = check_crossfilter.call_args[0][0] + self.assertEqual(len(crossfilter_call_args), 1) + self.assertEqual( + crossfilter_call_args[0]["id"], "20084" + ) # Same for crossfilter + + @patch("core.trading.active_management.ActiveManagement.check_crossfilter") + @patch("core.trading.active_management.ActiveManagement.check_position_size") + def test_check_trends_mutation(self, check_position_size, check_crossfilter): + signal = self.create_hook_signal() + self.strategy.trend_signals.set([signal]) + self.strategy.trends = {"EUR_USD": "sell"} + self.amend_tp_sl_flip_side() + self.strategy.save() + + self.ams.run_checks() + + self.assertEqual(check_position_size.call_count, 0) + + crossfilter_call_args = check_crossfilter.call_args[0][0] + self.assertEqual(len(crossfilter_call_args), 0) + + @patch("core.trading.active_management.ActiveManagement.check_crossfilter") + @patch("core.trading.active_management.ActiveManagement.check_protection") + def test_check_position_size_mutation(self, check_protection, check_crossfilter): + self.trades[0]["currentUnits"] = "100000" + self.ams.run_checks() + + self.assertEqual(check_protection.call_count, 1) + + call_args = check_protection.call_args[0][0] + self.assertEqual( + call_args["id"], "20084" + ) # Only run protection on the second trade + + crossfilter_call_args = check_crossfilter.call_args[0][0] + self.assertEqual(len(crossfilter_call_args), 1) + self.assertEqual( + crossfilter_call_args[0]["id"], "20084" + ) # Same for crossfilter + + @patch("core.trading.active_management.ActiveManagement.check_crossfilter") + @patch("core.trading.active_management.ActiveManagement.check_protection") + def test_check_protection_mutation(self, check_protection, check_crossfilter): + self.trades[0]["currentUnits"] = "100000" + self.ams.run_checks() + + self.assertEqual(check_protection.call_count, 1) + + call_args = check_protection.call_args[0][0] + self.assertEqual( + call_args["id"], "20084" + ) # Only run protection on the second trade + + crossfilter_call_args = check_crossfilter.call_args[0][0] + self.assertEqual(len(crossfilter_call_args), 1) + self.assertEqual( + crossfilter_call_args[0]["id"], "20084" + ) # Same for crossfilter + + # This may look similar but check_crossfilter is called with the whole trade list. + # Check that the trade that is removed from the list is not checked. + @patch("core.trading.active_management.ActiveManagement.check_crossfilter") + def test_check_asset_groups_mutation(self, check_crossfilter): + asset_group = AssetGroup.objects.create( + user=self.user, + name="Test Asset Group", + ) + AssetRule.objects.create( + user=self.user, + asset="USD", + group=asset_group, + status=2, # Bullish + ) + self.strategy.asset_group = asset_group + self.strategy.save() + self.ams.run_checks() + + check_crossfilter.assert_called_once_with([]) + + @patch("core.trading.active_management.ActiveManagement.check_max_open_trades") + def test_check_crossfilter_mutation(self, check_max_open_trades): + self.trades[1]["side"] = "short" + self.amend_tp_sl_flip_side() + self.ams.run_checks() + + self.assertEqual(check_max_open_trades.call_count, 1) + call_args = check_max_open_trades.call_args[0][0] + + self.assertEqual(len(call_args), 1) + self.assertEqual(call_args[0]["id"], "20083") + + @patch( # When the string is just too damn long + ( + "core.trading.active_management.ActiveManagement." + "check_max_open_trades_per_symbol" + ) + ) + def test_check_max_open_trades_mutation(self, check_max_open_trades_per_symbol): + for x in range(9): + self.add_trade( + str(x), + f"EUR_USD{x}", # Vary symbol to prevent max open trades per symbol + "long", + f"2023-02-13T12:39:1{x}.302917985Z", + ) + + self.ams.run_checks() + + self.assertEqual(check_max_open_trades_per_symbol.call_count, 1) + call_args = check_max_open_trades_per_symbol.call_args[0][0] + called_with_ids = [x["id"] for x in call_args] + self.assertListEqual( + called_with_ids, ["20083", "20084", "0", "1", "2", "3", "4", "5", "6", "7"] + ) + + @patch("core.trading.active_management.ActiveManagement.check_max_loss") + def test_check_max_open_trades_per_symbol_mutation(self, check_max_loss): + for x in range(4): + self.add_trade( + str(x), + "EUR_USD", + "long", + f"2023-02-13T12:39:1{x}.302917985Z", + ) + + self.ams.run_checks() + + self.assertEqual(check_max_loss.call_count, 1) + call_args = check_max_loss.call_args[0][0] + called_with_ids = [x["id"] for x in call_args] + self.assertListEqual(called_with_ids, ["20083", "20084", "0", "1", "2"]) + + @patch("core.trading.active_management.ActiveManagement.check_max_risk") + def test_check_max_loss_mutation(self, check_max_risk): + self.balance = D("1") + self.balance_usd = D("0.69") + + self.trades = [] + + self.ams.run_checks() + + self.assertEqual(check_max_risk.call_count, 1) + call_args = check_max_risk.call_args[0][0] + called_with_ids = [x["id"] for x in call_args] + self.assertListEqual(called_with_ids, []) + + def test_check_max_risk_mutation(self): + """ + This cannot be tested as there are no hooks after it. + """ pass diff --git a/core/trading/active_management.py b/core/trading/active_management.py index d25146d..236b8db 100644 --- a/core/trading/active_management.py +++ b/core/trading/active_management.py @@ -3,10 +3,12 @@ from datetime import datetime from decimal import Decimal as D from core.exchanges.convert import ( + annotate_trade_tp_sl_percent, convert_trades, sl_percent_to_price, tp_percent_to_price, ) +from core.lib.notify import sendmsg from core.trading import assetfilter, checks, market, risk from core.trading.crossfilter import crossfilter from core.trading.market import get_base_quote, get_trade_size_in_base @@ -15,6 +17,10 @@ from core.util import logs log = logs.get_logger("ams") +class TradeClosed(Exception): + pass + + class ActiveManagement(object): def __init__(self, strategy): self.strategy = strategy @@ -74,7 +80,6 @@ class ActiveManagement(object): self.close_trade(trade_id) def bulk_notify(self, action, action_cast_list): - print("CALL", action, action_cast_list) msg = "" for action_cast in action_cast_list: msg += f"ACTION: '{action}' on trade ID '{action_cast['id']}'\n" @@ -85,7 +90,7 @@ class ActiveManagement(object): msg += f"EXTRA: {extra}\n" msg += "=========\n" - print("NOTIFY", msg) + sendmsg(self.strategy.user, msg, title=f"AMS: {action}") def adjust_position_size(self, trade_id, new_size): pass @@ -120,7 +125,6 @@ class ActiveManagement(object): self.bulk_adjust(action_cast_list) def handle_violation(self, check_type, action, trade_id, **kwargs): - print("VIOLATION", check_type, action, trade_id, kwargs) if action == "none": return self.add_action(action, check_type, trade_id, **kwargs) @@ -148,6 +152,8 @@ class ActiveManagement(object): self.handle_violation( "trading_time", self.policy.when_trading_time_violated, trade["id"] ) + if self.policy.when_trading_time_violated == "close": + raise TradeClosed def check_trends(self, trade): direction = trade["direction"] @@ -157,6 +163,8 @@ class ActiveManagement(object): self.handle_violation( "trends", self.policy.when_trends_violated, trade["id"] ) + if self.policy.when_trends_violated == "close": + raise TradeClosed def check_position_size(self, trade): """ @@ -189,8 +197,13 @@ class ActiveManagement(object): "position_size", self.policy.when_position_size_violated, trade["id"], - {"size": expected_trade_size}, + size=expected_trade_size, ) + if self.policy.when_position_size_violated == "close": + raise TradeClosed + elif self.policy.when_position_size_violated == "adjust": + trade["amount"] = expected_trade_size + trade["units"] = expected_trade_size def check_protection(self, trade): deviation = D(0.05) # 5% @@ -252,8 +265,14 @@ class ActiveManagement(object): "protection", self.policy.when_protection_violated, trade["id"], - violations + **violations ) + if self.policy.when_protection_violated == "close": + raise TradeClosed + elif self.policy.when_protection_violated == "adjust": + trade.update(violations) + annotate_trade_tp_sl_percent(trade) + market.convert_trades_to_usd(self.strategy.account, [trade]) def check_asset_groups(self, trade): if self.strategy.asset_group is not None: @@ -267,6 +286,8 @@ class ActiveManagement(object): self.handle_violation( "asset_group", self.policy.when_asset_groups_violated, trade["id"] ) + if self.policy.when_asset_groups_violated == "close": + raise TradeClosed def get_sorted_trades_copy(self, trades, reverse=True): trades_copy = deepcopy(trades) @@ -280,18 +301,22 @@ class ActiveManagement(object): def check_crossfilter(self, trades): close_trades = [] - trades_copy = self.get_sorted_trades_copy(trades) + # trades_copy = self.get_sorted_trades_copy(trades) iterations = 0 finished = [] # Recursively run crossfilter on the newest-first list until we have no more # conflicts - while not len(finished) == len(trades): + length_before = len(trades) + while not len(finished) == length_before: iterations += 1 if iterations > 10000: raise Exception("Too many iterations") # For each trade - for trade in trades_copy: + # We need reverse because we are removing items from the list + # This works in our favour because the list is sorted the wrong + # way around in run_checks() + for trade in reversed(trades): # Abort if we've already checked this trade if trade in close_trades: continue @@ -299,7 +324,7 @@ class ActiveManagement(object): # Also remove if we have already checked this others = [ t - for t in trades_copy + for t in trades if t["id"] != trade["id"] and t not in close_trades ] symbol = trade["symbol"] @@ -320,6 +345,10 @@ class ActiveManagement(object): # And don't check it again finished.append(trade) close_trades.append(trade) + + # Remove it from the trades list + if self.policy.when_crossfilter_violated == "close": + trades.remove(trade) if not close_trades: return @@ -334,15 +363,17 @@ class ActiveManagement(object): return max_open_pass = risk.check_max_open_trades(self.strategy.risk_model, trades) if not max_open_pass: - trades_copy = self.get_sorted_trades_copy(trades, reverse=False) + # trades_copy = self.get_sorted_trades_copy(trades, reverse=False) # fmt: off - trades_over_limit = trades_copy[self.strategy.risk_model.max_open_trades:] + trades_over_limit = trades[self.strategy.risk_model.max_open_trades:] for trade in trades_over_limit: self.handle_violation( "max_open_trades", self.policy.when_max_open_trades_violated, trade["id"], ) + if self.policy.when_max_open_trades_violated == "close": + trades.remove(trade) def check_max_open_trades_per_symbol(self, trades): if self.strategy.risk_model is None: @@ -352,10 +383,10 @@ class ActiveManagement(object): ) max_open_pass = list(max_open_pass) if max_open_pass: - trades_copy = self.get_sorted_trades_copy(trades, reverse=False) + # trades_copy = self.get_sorted_trades_copy(trades, reverse=False) trades_over_limit = [] for symbol in max_open_pass: - symbol_trades = [x for x in trades_copy if x["symbol"] == symbol] + symbol_trades = [x for x in trades if x["symbol"] == symbol] # fmt: off exceeding_limit = symbol_trades[ self.strategy.risk_model.max_open_trades_per_symbol: @@ -369,8 +400,10 @@ class ActiveManagement(object): self.policy.when_max_open_trades_violated, trade["id"], ) + if self.policy.when_max_open_trades_violated == "close": + trades.remove(trade) - def check_max_loss(self): + def check_max_loss(self, trades): if self.strategy.risk_model is None: return check_passed = risk.check_max_loss( @@ -382,6 +415,9 @@ class ActiveManagement(object): self.handle_violation( "max_loss", self.policy.when_max_loss_violated, None # Close all trades ) + if self.policy.when_max_loss_violated == "close": + for trade in trades: + trades.remove(trade) def check_max_risk(self, trades): if self.strategy.risk_model is None: @@ -389,7 +425,7 @@ class ActiveManagement(object): close_trades = [] trades_copy = self.get_sorted_trades_copy(trades, reverse=False) - market.convert_trades_to_usd(self.strategy.account, trades_copy) + # market.convert_trades_to_usd(self.strategy.account, trades_copy) iterations = 0 finished = False @@ -414,21 +450,36 @@ class ActiveManagement(object): self.handle_violation( "max_risk", self.policy.when_max_risk_violated, trade["id"] ) + if self.policy.when_max_risk_violated == "close": + trades.remove(trade) def run_checks(self): converted_trades = convert_trades(self.get_trades()) - for trade in converted_trades: - self.check_trading_time(trade) - self.check_trends(trade) - self.check_position_size(trade) - self.check_protection(trade) - self.check_asset_groups(trade) - - self.check_crossfilter(converted_trades) - self.check_max_open_trades(converted_trades) - self.check_max_open_trades_per_symbol(converted_trades) - self.check_max_loss() - self.check_max_risk(converted_trades) + trades_copy = self.get_sorted_trades_copy(converted_trades, reverse=False) + market.convert_trades_to_usd(self.strategy.account, trades_copy) + for trade in reversed(trades_copy): + try: + self.check_trading_time(trade) + self.check_trends(trade) + self.check_position_size(trade) + self.check_protection(trade) + self.check_asset_groups(trade) + except TradeClosed: + # Trade was closed, don't check it again + trades_copy.remove(trade) + continue + + self.check_crossfilter(trades_copy) + self.check_max_open_trades(trades_copy) + self.check_max_open_trades_per_symbol(trades_copy) + self.check_max_loss(trades_copy) + self.check_max_risk(trades_copy) + + def execute_actions(self): + if not self.actions: + return + self.reduce_actions() + self.run_actions() # Trading Time # Max loss