Implement trade mutation pipeline and active management actions
This commit is contained in:
@@ -3,10 +3,12 @@ from datetime import datetime
|
||||
from decimal import Decimal as D
|
||||
|
||||
from core.exchanges.convert import (
|
||||
annotate_trade_tp_sl_percent,
|
||||
convert_trades,
|
||||
sl_percent_to_price,
|
||||
tp_percent_to_price,
|
||||
)
|
||||
from core.lib.notify import sendmsg
|
||||
from core.trading import assetfilter, checks, market, risk
|
||||
from core.trading.crossfilter import crossfilter
|
||||
from core.trading.market import get_base_quote, get_trade_size_in_base
|
||||
@@ -15,6 +17,10 @@ from core.util import logs
|
||||
log = logs.get_logger("ams")
|
||||
|
||||
|
||||
class TradeClosed(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ActiveManagement(object):
|
||||
def __init__(self, strategy):
|
||||
self.strategy = strategy
|
||||
@@ -74,7 +80,6 @@ class ActiveManagement(object):
|
||||
self.close_trade(trade_id)
|
||||
|
||||
def bulk_notify(self, action, action_cast_list):
|
||||
print("CALL", action, action_cast_list)
|
||||
msg = ""
|
||||
for action_cast in action_cast_list:
|
||||
msg += f"ACTION: '{action}' on trade ID '{action_cast['id']}'\n"
|
||||
@@ -85,7 +90,7 @@ class ActiveManagement(object):
|
||||
msg += f"EXTRA: {extra}\n"
|
||||
msg += "=========\n"
|
||||
|
||||
print("NOTIFY", msg)
|
||||
sendmsg(self.strategy.user, msg, title=f"AMS: {action}")
|
||||
|
||||
def adjust_position_size(self, trade_id, new_size):
|
||||
pass
|
||||
@@ -120,7 +125,6 @@ class ActiveManagement(object):
|
||||
self.bulk_adjust(action_cast_list)
|
||||
|
||||
def handle_violation(self, check_type, action, trade_id, **kwargs):
|
||||
print("VIOLATION", check_type, action, trade_id, kwargs)
|
||||
if action == "none":
|
||||
return
|
||||
self.add_action(action, check_type, trade_id, **kwargs)
|
||||
@@ -148,6 +152,8 @@ class ActiveManagement(object):
|
||||
self.handle_violation(
|
||||
"trading_time", self.policy.when_trading_time_violated, trade["id"]
|
||||
)
|
||||
if self.policy.when_trading_time_violated == "close":
|
||||
raise TradeClosed
|
||||
|
||||
def check_trends(self, trade):
|
||||
direction = trade["direction"]
|
||||
@@ -157,6 +163,8 @@ class ActiveManagement(object):
|
||||
self.handle_violation(
|
||||
"trends", self.policy.when_trends_violated, trade["id"]
|
||||
)
|
||||
if self.policy.when_trends_violated == "close":
|
||||
raise TradeClosed
|
||||
|
||||
def check_position_size(self, trade):
|
||||
"""
|
||||
@@ -189,8 +197,13 @@ class ActiveManagement(object):
|
||||
"position_size",
|
||||
self.policy.when_position_size_violated,
|
||||
trade["id"],
|
||||
{"size": expected_trade_size},
|
||||
size=expected_trade_size,
|
||||
)
|
||||
if self.policy.when_position_size_violated == "close":
|
||||
raise TradeClosed
|
||||
elif self.policy.when_position_size_violated == "adjust":
|
||||
trade["amount"] = expected_trade_size
|
||||
trade["units"] = expected_trade_size
|
||||
|
||||
def check_protection(self, trade):
|
||||
deviation = D(0.05) # 5%
|
||||
@@ -252,8 +265,14 @@ class ActiveManagement(object):
|
||||
"protection",
|
||||
self.policy.when_protection_violated,
|
||||
trade["id"],
|
||||
violations
|
||||
**violations
|
||||
)
|
||||
if self.policy.when_protection_violated == "close":
|
||||
raise TradeClosed
|
||||
elif self.policy.when_protection_violated == "adjust":
|
||||
trade.update(violations)
|
||||
annotate_trade_tp_sl_percent(trade)
|
||||
market.convert_trades_to_usd(self.strategy.account, [trade])
|
||||
|
||||
def check_asset_groups(self, trade):
|
||||
if self.strategy.asset_group is not None:
|
||||
@@ -267,6 +286,8 @@ class ActiveManagement(object):
|
||||
self.handle_violation(
|
||||
"asset_group", self.policy.when_asset_groups_violated, trade["id"]
|
||||
)
|
||||
if self.policy.when_asset_groups_violated == "close":
|
||||
raise TradeClosed
|
||||
|
||||
def get_sorted_trades_copy(self, trades, reverse=True):
|
||||
trades_copy = deepcopy(trades)
|
||||
@@ -280,18 +301,22 @@ class ActiveManagement(object):
|
||||
def check_crossfilter(self, trades):
|
||||
close_trades = []
|
||||
|
||||
trades_copy = self.get_sorted_trades_copy(trades)
|
||||
# trades_copy = self.get_sorted_trades_copy(trades)
|
||||
|
||||
iterations = 0
|
||||
finished = []
|
||||
# Recursively run crossfilter on the newest-first list until we have no more
|
||||
# conflicts
|
||||
while not len(finished) == len(trades):
|
||||
length_before = len(trades)
|
||||
while not len(finished) == length_before:
|
||||
iterations += 1
|
||||
if iterations > 10000:
|
||||
raise Exception("Too many iterations")
|
||||
# For each trade
|
||||
for trade in trades_copy:
|
||||
# We need reverse because we are removing items from the list
|
||||
# This works in our favour because the list is sorted the wrong
|
||||
# way around in run_checks()
|
||||
for trade in reversed(trades):
|
||||
# Abort if we've already checked this trade
|
||||
if trade in close_trades:
|
||||
continue
|
||||
@@ -299,7 +324,7 @@ class ActiveManagement(object):
|
||||
# Also remove if we have already checked this
|
||||
others = [
|
||||
t
|
||||
for t in trades_copy
|
||||
for t in trades
|
||||
if t["id"] != trade["id"] and t not in close_trades
|
||||
]
|
||||
symbol = trade["symbol"]
|
||||
@@ -320,6 +345,10 @@ class ActiveManagement(object):
|
||||
# And don't check it again
|
||||
finished.append(trade)
|
||||
close_trades.append(trade)
|
||||
|
||||
# Remove it from the trades list
|
||||
if self.policy.when_crossfilter_violated == "close":
|
||||
trades.remove(trade)
|
||||
if not close_trades:
|
||||
return
|
||||
|
||||
@@ -334,15 +363,17 @@ class ActiveManagement(object):
|
||||
return
|
||||
max_open_pass = risk.check_max_open_trades(self.strategy.risk_model, trades)
|
||||
if not max_open_pass:
|
||||
trades_copy = self.get_sorted_trades_copy(trades, reverse=False)
|
||||
# trades_copy = self.get_sorted_trades_copy(trades, reverse=False)
|
||||
# fmt: off
|
||||
trades_over_limit = trades_copy[self.strategy.risk_model.max_open_trades:]
|
||||
trades_over_limit = trades[self.strategy.risk_model.max_open_trades:]
|
||||
for trade in trades_over_limit:
|
||||
self.handle_violation(
|
||||
"max_open_trades",
|
||||
self.policy.when_max_open_trades_violated,
|
||||
trade["id"],
|
||||
)
|
||||
if self.policy.when_max_open_trades_violated == "close":
|
||||
trades.remove(trade)
|
||||
|
||||
def check_max_open_trades_per_symbol(self, trades):
|
||||
if self.strategy.risk_model is None:
|
||||
@@ -352,10 +383,10 @@ class ActiveManagement(object):
|
||||
)
|
||||
max_open_pass = list(max_open_pass)
|
||||
if max_open_pass:
|
||||
trades_copy = self.get_sorted_trades_copy(trades, reverse=False)
|
||||
# trades_copy = self.get_sorted_trades_copy(trades, reverse=False)
|
||||
trades_over_limit = []
|
||||
for symbol in max_open_pass:
|
||||
symbol_trades = [x for x in trades_copy if x["symbol"] == symbol]
|
||||
symbol_trades = [x for x in trades if x["symbol"] == symbol]
|
||||
# fmt: off
|
||||
exceeding_limit = symbol_trades[
|
||||
self.strategy.risk_model.max_open_trades_per_symbol:
|
||||
@@ -369,8 +400,10 @@ class ActiveManagement(object):
|
||||
self.policy.when_max_open_trades_violated,
|
||||
trade["id"],
|
||||
)
|
||||
if self.policy.when_max_open_trades_violated == "close":
|
||||
trades.remove(trade)
|
||||
|
||||
def check_max_loss(self):
|
||||
def check_max_loss(self, trades):
|
||||
if self.strategy.risk_model is None:
|
||||
return
|
||||
check_passed = risk.check_max_loss(
|
||||
@@ -382,6 +415,9 @@ class ActiveManagement(object):
|
||||
self.handle_violation(
|
||||
"max_loss", self.policy.when_max_loss_violated, None # Close all trades
|
||||
)
|
||||
if self.policy.when_max_loss_violated == "close":
|
||||
for trade in trades:
|
||||
trades.remove(trade)
|
||||
|
||||
def check_max_risk(self, trades):
|
||||
if self.strategy.risk_model is None:
|
||||
@@ -389,7 +425,7 @@ class ActiveManagement(object):
|
||||
close_trades = []
|
||||
|
||||
trades_copy = self.get_sorted_trades_copy(trades, reverse=False)
|
||||
market.convert_trades_to_usd(self.strategy.account, trades_copy)
|
||||
# market.convert_trades_to_usd(self.strategy.account, trades_copy)
|
||||
|
||||
iterations = 0
|
||||
finished = False
|
||||
@@ -414,21 +450,36 @@ class ActiveManagement(object):
|
||||
self.handle_violation(
|
||||
"max_risk", self.policy.when_max_risk_violated, trade["id"]
|
||||
)
|
||||
if self.policy.when_max_risk_violated == "close":
|
||||
trades.remove(trade)
|
||||
|
||||
def run_checks(self):
|
||||
converted_trades = convert_trades(self.get_trades())
|
||||
for trade in converted_trades:
|
||||
self.check_trading_time(trade)
|
||||
self.check_trends(trade)
|
||||
self.check_position_size(trade)
|
||||
self.check_protection(trade)
|
||||
self.check_asset_groups(trade)
|
||||
trades_copy = self.get_sorted_trades_copy(converted_trades, reverse=False)
|
||||
market.convert_trades_to_usd(self.strategy.account, trades_copy)
|
||||
for trade in reversed(trades_copy):
|
||||
try:
|
||||
self.check_trading_time(trade)
|
||||
self.check_trends(trade)
|
||||
self.check_position_size(trade)
|
||||
self.check_protection(trade)
|
||||
self.check_asset_groups(trade)
|
||||
except TradeClosed:
|
||||
# Trade was closed, don't check it again
|
||||
trades_copy.remove(trade)
|
||||
continue
|
||||
|
||||
self.check_crossfilter(converted_trades)
|
||||
self.check_max_open_trades(converted_trades)
|
||||
self.check_max_open_trades_per_symbol(converted_trades)
|
||||
self.check_max_loss()
|
||||
self.check_max_risk(converted_trades)
|
||||
self.check_crossfilter(trades_copy)
|
||||
self.check_max_open_trades(trades_copy)
|
||||
self.check_max_open_trades_per_symbol(trades_copy)
|
||||
self.check_max_loss(trades_copy)
|
||||
self.check_max_risk(trades_copy)
|
||||
|
||||
def execute_actions(self):
|
||||
if not self.actions:
|
||||
return
|
||||
self.reduce_actions()
|
||||
self.run_actions()
|
||||
|
||||
# Trading Time
|
||||
# Max loss
|
||||
|
||||
Reference in New Issue
Block a user