From ae104f446a78a133840a7a484394afc40c83c144 Mon Sep 17 00:00:00 2001 From: Mark Veidemanis Date: Sat, 18 Feb 2023 17:55:39 +0000 Subject: [PATCH] Start implementing active management actions --- core/tests/trading/test_active_management.py | 160 +++++++++++++++++++ core/trading/active_management.py | 84 +++++++++- 2 files changed, 242 insertions(+), 2 deletions(-) diff --git a/core/tests/trading/test_active_management.py b/core/tests/trading/test_active_management.py index ce779f2..4b01c90 100644 --- a/core/tests/trading/test_active_management.py +++ b/core/tests/trading/test_active_management.py @@ -540,6 +540,157 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): [self.trades[2], self.trades[3]], ) + @patch( + "core.trading.active_management.ActiveManagement.add_action", + ) + def test_handle_violation(self, add_action): + self.ams.handle_violation("max_loss", "close", "trade_id") + self.ams.handle_violation("position_size", "adjust", "trade_id2", size=1000) + + self.assertEqual(add_action.call_count, 2) + add_action.assert_any_call("close", "max_loss", "trade_id") + add_action.assert_any_call("adjust", "position_size", "trade_id2", size=1000) + + def test_add_action(self): + protection_args = { + "take_profit_price": D("1.07934"), + "stop_loss_price": D("1.05276"), + } + self.ams.add_action("close", "trading_time", "fake_trade_id") + self.ams.add_action("close", "protection", "fake_trade_id2", **protection_args) + + self.assertEqual( + self.ams.actions, + { + "close": [ + {"id": "fake_trade_id", "check": "trading_time", "extra": {}}, + { + "id": "fake_trade_id2", + "check": "protection", + "extra": protection_args, + }, + ], + }, + ) + + def test_reduce_actions(self): + pass + + @patch("core.trading.active_management.ActiveManagement.bulk_close_trades") + @patch("core.trading.active_management.ActiveManagement.bulk_notify") + def test_run_actions(self, bulk_notify, bulk_close_trades): + protection_args = { + "take_profit_price": D("1.07934"), + "stop_loss_price": D("1.05276"), + } + self.ams.add_action("close", "trading_time", "fake_trade_id") + self.ams.add_action("close", "protection", "fake_trade_id2", **protection_args) + self.ams.run_actions() + + expected_action_cast = [ + {"id": "fake_trade_id", "check": "trading_time", "extra": {}}, + { + "id": "fake_trade_id2", + "check": "protection", + "extra": { + **protection_args, + }, + }, + ] + bulk_notify.assert_called_once_with( + "close", + expected_action_cast, + ) + + bulk_close_trades.assert_called_once_with( + ["fake_trade_id", "fake_trade_id2"], + ) + + @patch("core.trading.active_management.ActiveManagement.bulk_close_trades") + @patch("core.trading.active_management.ActiveManagement.bulk_notify") + def test_run_actions_notify_only(self, bulk_notify, bulk_close_trades): + protection_args = { + "take_profit_price": D("1.07934"), + "stop_loss_price": D("1.05276"), + } + self.ams.add_action("notify", "trading_time", "fake_trade_id") + self.ams.add_action("notify", "protection", "fake_trade_id2", **protection_args) + self.ams.run_actions() + + expected_action_cast = [ + {"id": "fake_trade_id", "check": "trading_time", "extra": {}}, + { + "id": "fake_trade_id2", + "check": "protection", + "extra": { + **protection_args, + }, + }, + ] + bulk_notify.assert_called_once_with( + "notify", + expected_action_cast, + ) + + bulk_close_trades.assert_not_called() + + @patch("core.trading.active_management.ActiveManagement.bulk_close_trades") + @patch("core.trading.active_management.ActiveManagement.bulk_adjust") + @patch("core.trading.active_management.ActiveManagement.bulk_notify") + def test_run_actions_notify_adjust_only( + self, bulk_notify, bulk_adjust, bulk_close_trades + ): + protection_args = { + "take_profit_price": D("1.07934"), + "stop_loss_price": D("1.05276"), + } + self.ams.add_action("adjust", "position_size", "fake_trade_id", size=1000) + self.ams.add_action("adjust", "protection", "fake_trade_id2", **protection_args) + self.ams.run_actions() + + expected_action_cast = [ + {"id": "fake_trade_id", "check": "position_size", "extra": {"size": 1000}}, + { + "id": "fake_trade_id2", + "check": "protection", + "extra": { + **protection_args, + }, + }, + ] + bulk_notify.assert_called_once_with( + "adjust", + expected_action_cast, + ) + + bulk_adjust.assert_called_once_with(expected_action_cast) + bulk_close_trades.assert_not_called() + + @patch("core.trading.active_management.ActiveManagement.adjust_position_size") + @patch("core.trading.active_management.ActiveManagement.adjust_protection") + def test_bulk_adjust(self, adjust_protection, adjust_position_size): + expected_protection = {"take_profit_price": 1.10, "stop_loss_price": 1.05} + cast_list = [ + {"id": "id1", "check": "position_size", "extra": {"size": 1000}}, + {"id": "id2", "check": "protection", "extra": expected_protection}, + ] + + self.ams.bulk_adjust(cast_list) + + adjust_position_size.assert_called_once_with("id1", 1000) + adjust_protection.assert_called_once_with("id2", expected_protection) + + @patch("core.trading.active_management.ActiveManagement.close_trade") + def test_bulk_close_trades(self, close_trade): + self.ams.bulk_close_trades(["id1", "id2"]) + + self.assertEqual(close_trade.call_count, 2) + close_trade.assert_any_call("id1") + close_trade.assert_any_call("id2") + + def test_bulk_notify(self): + pass + def test_max_risk_not_violated_after_adjusting_protection(self): """ Ensure the max risk check is not violated after adjusting the protection. @@ -551,3 +702,12 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): Ensure the max risk check is not violated after adjusting the position size. """ pass + + def test_position_size_reduced(self): + pass + + def test_protection_added(self): + pass + + def test_protection_amended(self): + pass diff --git a/core/trading/active_management.py b/core/trading/active_management.py index 1391dc4..d25146d 100644 --- a/core/trading/active_management.py +++ b/core/trading/active_management.py @@ -10,6 +10,9 @@ from core.exchanges.convert import ( from core.trading import assetfilter, checks, market, risk from core.trading.crossfilter import crossfilter from core.trading.market import get_base_quote, get_trade_size_in_base +from core.util import logs + +log = logs.get_logger("ams") class ActiveManagement(object): @@ -18,9 +21,30 @@ class ActiveManagement(object): self.policy = strategy.active_management_policy self.trades = [] + self.actions = {} self.balance = None self.balance_usd = None + def add_action(self, action, check_type, trade_id, **kwargs): + if action not in self.actions: + self.actions[action] = [] + self.actions[action].append( + {"id": trade_id, "check": check_type, "extra": kwargs} + ) + + def reduce_actions(self): + """ + If a trade is in the close actions, remove it from adjust. + """ + if "close" in self.actions: + for close_action in self.actions["close"]: + if "adjust" in self.actions: + self.actions["adjust"] = [ + action + for action in self.actions["adjust"] + if action["id"] != close_action["id"] + ] + def get_trades(self): if not self.trades: self.trades = self.strategy.account.client.get_all_open_trades() @@ -42,8 +66,64 @@ class ActiveManagement(object): else: return self.balance - def handle_violation(self, check_type, action, trade, **kwargs): - print("VIOLATION", check_type, action, trade, kwargs) + def close_trade(self, trade_id): + self.strategy.account.client.close_trade(trade_id) + + def bulk_close_trades(self, trade_ids): + for trade_id in trade_ids: + self.close_trade(trade_id) + + def bulk_notify(self, action, action_cast_list): + print("CALL", action, action_cast_list) + msg = "" + for action_cast in action_cast_list: + msg += f"ACTION: '{action}' on trade ID '{action_cast['id']}'\n" + msg += f"VIOLATION: '{action_cast['check']}'\n" + if action_cast["extra"]: + extra = action_cast["extra"] + extra = ", ".join([f"{k}: {v}" for k, v in extra.items()]) + msg += f"EXTRA: {extra}\n" + msg += "=========\n" + + print("NOTIFY", msg) + + def adjust_position_size(self, trade_id, new_size): + pass + + def adjust_protection(self, trade_id, new_protection): + pass + + def bulk_adjust(self, action_cast_list): + for item in action_cast_list: + trade_id = item["id"] + check = item["check"] + if "extra" in item: + extra = item["extra"] + else: + log.error(f"Adjust action missing extra data: {item}") + continue + if check == "position_size": + new_size = extra["size"] + self.adjust_position_size(trade_id, new_size) + elif check == "protection": + self.adjust_protection(trade_id, extra) + + def run_actions(self): + for action, action_cast_list in self.actions.items(): + if action == "none": + continue + self.bulk_notify(action, action_cast_list) + if action == "close": + trade_ids = [action_cast["id"] for action_cast in action_cast_list] + self.bulk_close_trades(trade_ids) + elif action == "adjust": + self.bulk_adjust(action_cast_list) + + def handle_violation(self, check_type, action, trade_id, **kwargs): + print("VIOLATION", check_type, action, trade_id, kwargs) + if action == "none": + return + self.add_action(action, check_type, trade_id, **kwargs) # TODO: close/notify for: # - trading time # - trends