Start implementing active management actions

This commit is contained in:
Mark Veidemanis 2023-02-18 17:55:39 +00:00
parent 15a8bec105
commit ae104f446a
Signed by: m
GPG Key ID: 5ACFCEED46C0904F
2 changed files with 242 additions and 2 deletions

View File

@ -540,6 +540,157 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
[self.trades[2], self.trades[3]], [self.trades[2], self.trades[3]],
) )
@patch(
"core.trading.active_management.ActiveManagement.add_action",
)
def test_handle_violation(self, add_action):
self.ams.handle_violation("max_loss", "close", "trade_id")
self.ams.handle_violation("position_size", "adjust", "trade_id2", size=1000)
self.assertEqual(add_action.call_count, 2)
add_action.assert_any_call("close", "max_loss", "trade_id")
add_action.assert_any_call("adjust", "position_size", "trade_id2", size=1000)
def test_add_action(self):
protection_args = {
"take_profit_price": D("1.07934"),
"stop_loss_price": D("1.05276"),
}
self.ams.add_action("close", "trading_time", "fake_trade_id")
self.ams.add_action("close", "protection", "fake_trade_id2", **protection_args)
self.assertEqual(
self.ams.actions,
{
"close": [
{"id": "fake_trade_id", "check": "trading_time", "extra": {}},
{
"id": "fake_trade_id2",
"check": "protection",
"extra": protection_args,
},
],
},
)
def test_reduce_actions(self):
pass
@patch("core.trading.active_management.ActiveManagement.bulk_close_trades")
@patch("core.trading.active_management.ActiveManagement.bulk_notify")
def test_run_actions(self, bulk_notify, bulk_close_trades):
protection_args = {
"take_profit_price": D("1.07934"),
"stop_loss_price": D("1.05276"),
}
self.ams.add_action("close", "trading_time", "fake_trade_id")
self.ams.add_action("close", "protection", "fake_trade_id2", **protection_args)
self.ams.run_actions()
expected_action_cast = [
{"id": "fake_trade_id", "check": "trading_time", "extra": {}},
{
"id": "fake_trade_id2",
"check": "protection",
"extra": {
**protection_args,
},
},
]
bulk_notify.assert_called_once_with(
"close",
expected_action_cast,
)
bulk_close_trades.assert_called_once_with(
["fake_trade_id", "fake_trade_id2"],
)
@patch("core.trading.active_management.ActiveManagement.bulk_close_trades")
@patch("core.trading.active_management.ActiveManagement.bulk_notify")
def test_run_actions_notify_only(self, bulk_notify, bulk_close_trades):
protection_args = {
"take_profit_price": D("1.07934"),
"stop_loss_price": D("1.05276"),
}
self.ams.add_action("notify", "trading_time", "fake_trade_id")
self.ams.add_action("notify", "protection", "fake_trade_id2", **protection_args)
self.ams.run_actions()
expected_action_cast = [
{"id": "fake_trade_id", "check": "trading_time", "extra": {}},
{
"id": "fake_trade_id2",
"check": "protection",
"extra": {
**protection_args,
},
},
]
bulk_notify.assert_called_once_with(
"notify",
expected_action_cast,
)
bulk_close_trades.assert_not_called()
@patch("core.trading.active_management.ActiveManagement.bulk_close_trades")
@patch("core.trading.active_management.ActiveManagement.bulk_adjust")
@patch("core.trading.active_management.ActiveManagement.bulk_notify")
def test_run_actions_notify_adjust_only(
self, bulk_notify, bulk_adjust, bulk_close_trades
):
protection_args = {
"take_profit_price": D("1.07934"),
"stop_loss_price": D("1.05276"),
}
self.ams.add_action("adjust", "position_size", "fake_trade_id", size=1000)
self.ams.add_action("adjust", "protection", "fake_trade_id2", **protection_args)
self.ams.run_actions()
expected_action_cast = [
{"id": "fake_trade_id", "check": "position_size", "extra": {"size": 1000}},
{
"id": "fake_trade_id2",
"check": "protection",
"extra": {
**protection_args,
},
},
]
bulk_notify.assert_called_once_with(
"adjust",
expected_action_cast,
)
bulk_adjust.assert_called_once_with(expected_action_cast)
bulk_close_trades.assert_not_called()
@patch("core.trading.active_management.ActiveManagement.adjust_position_size")
@patch("core.trading.active_management.ActiveManagement.adjust_protection")
def test_bulk_adjust(self, adjust_protection, adjust_position_size):
expected_protection = {"take_profit_price": 1.10, "stop_loss_price": 1.05}
cast_list = [
{"id": "id1", "check": "position_size", "extra": {"size": 1000}},
{"id": "id2", "check": "protection", "extra": expected_protection},
]
self.ams.bulk_adjust(cast_list)
adjust_position_size.assert_called_once_with("id1", 1000)
adjust_protection.assert_called_once_with("id2", expected_protection)
@patch("core.trading.active_management.ActiveManagement.close_trade")
def test_bulk_close_trades(self, close_trade):
self.ams.bulk_close_trades(["id1", "id2"])
self.assertEqual(close_trade.call_count, 2)
close_trade.assert_any_call("id1")
close_trade.assert_any_call("id2")
def test_bulk_notify(self):
pass
def test_max_risk_not_violated_after_adjusting_protection(self): def test_max_risk_not_violated_after_adjusting_protection(self):
""" """
Ensure the max risk check is not violated after adjusting the protection. Ensure the max risk check is not violated after adjusting the protection.
@ -551,3 +702,12 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
Ensure the max risk check is not violated after adjusting the position size. Ensure the max risk check is not violated after adjusting the position size.
""" """
pass pass
def test_position_size_reduced(self):
pass
def test_protection_added(self):
pass
def test_protection_amended(self):
pass

View File

@ -10,6 +10,9 @@ from core.exchanges.convert import (
from core.trading import assetfilter, checks, market, risk from core.trading import assetfilter, checks, market, risk
from core.trading.crossfilter import crossfilter from core.trading.crossfilter import crossfilter
from core.trading.market import get_base_quote, get_trade_size_in_base from core.trading.market import get_base_quote, get_trade_size_in_base
from core.util import logs
log = logs.get_logger("ams")
class ActiveManagement(object): class ActiveManagement(object):
@ -18,9 +21,30 @@ class ActiveManagement(object):
self.policy = strategy.active_management_policy self.policy = strategy.active_management_policy
self.trades = [] self.trades = []
self.actions = {}
self.balance = None self.balance = None
self.balance_usd = None self.balance_usd = None
def add_action(self, action, check_type, trade_id, **kwargs):
if action not in self.actions:
self.actions[action] = []
self.actions[action].append(
{"id": trade_id, "check": check_type, "extra": kwargs}
)
def reduce_actions(self):
"""
If a trade is in the close actions, remove it from adjust.
"""
if "close" in self.actions:
for close_action in self.actions["close"]:
if "adjust" in self.actions:
self.actions["adjust"] = [
action
for action in self.actions["adjust"]
if action["id"] != close_action["id"]
]
def get_trades(self): def get_trades(self):
if not self.trades: if not self.trades:
self.trades = self.strategy.account.client.get_all_open_trades() self.trades = self.strategy.account.client.get_all_open_trades()
@ -42,8 +66,64 @@ class ActiveManagement(object):
else: else:
return self.balance return self.balance
def handle_violation(self, check_type, action, trade, **kwargs): def close_trade(self, trade_id):
print("VIOLATION", check_type, action, trade, kwargs) self.strategy.account.client.close_trade(trade_id)
def bulk_close_trades(self, trade_ids):
for trade_id in trade_ids:
self.close_trade(trade_id)
def bulk_notify(self, action, action_cast_list):
print("CALL", action, action_cast_list)
msg = ""
for action_cast in action_cast_list:
msg += f"ACTION: '{action}' on trade ID '{action_cast['id']}'\n"
msg += f"VIOLATION: '{action_cast['check']}'\n"
if action_cast["extra"]:
extra = action_cast["extra"]
extra = ", ".join([f"{k}: {v}" for k, v in extra.items()])
msg += f"EXTRA: {extra}\n"
msg += "=========\n"
print("NOTIFY", msg)
def adjust_position_size(self, trade_id, new_size):
pass
def adjust_protection(self, trade_id, new_protection):
pass
def bulk_adjust(self, action_cast_list):
for item in action_cast_list:
trade_id = item["id"]
check = item["check"]
if "extra" in item:
extra = item["extra"]
else:
log.error(f"Adjust action missing extra data: {item}")
continue
if check == "position_size":
new_size = extra["size"]
self.adjust_position_size(trade_id, new_size)
elif check == "protection":
self.adjust_protection(trade_id, extra)
def run_actions(self):
for action, action_cast_list in self.actions.items():
if action == "none":
continue
self.bulk_notify(action, action_cast_list)
if action == "close":
trade_ids = [action_cast["id"] for action_cast in action_cast_list]
self.bulk_close_trades(trade_ids)
elif action == "adjust":
self.bulk_adjust(action_cast_list)
def handle_violation(self, check_type, action, trade_id, **kwargs):
print("VIOLATION", check_type, action, trade_id, kwargs)
if action == "none":
return
self.add_action(action, check_type, trade_id, **kwargs)
# TODO: close/notify for: # TODO: close/notify for:
# - trading time # - trading time
# - trends # - trends