Finish implementing active management hooks

master
Mark Veidemanis 1 year ago
parent 3e35214e82
commit 466b17400f
Signed by: m
GPG Key ID: 5ACFCEED46C0904F

@ -98,5 +98,4 @@ def to_currency(direction, account, amount, from_currency, to_currency):
# Convert the amount to the destination currency
converted = D(amount) * price
return converted

@ -419,7 +419,7 @@ class ActiveManagementPolicyForm(RestrictedFormMixin, ModelForm):
"when_asset_groups_violated": "The action to take when a trade violating the asset group rules is discovered.",
"when_max_open_trades_violated": "The action to take when a trade puts the account above the maximum open trades.",
"when_max_open_trades_per_symbol_violated": "The action to take when a trade puts the account above the maximum open trades per symbol.",
"when_max_loss_violated": "The action to take when a trade puts the account above the maximum loss.",
"when_max_loss_violated": "The action to take when the account exceeds its maximum loss. NOTE: The close action will close all trades.",
"when_max_risk_violated": "The action to take when a trade exposes the account to more than the maximum risk.",
"when_crossfilter_violated": "The action to take when a trade is deemed to conflict with another -- e.g. a buy and sell on the same asset.",
}

@ -1,17 +1,9 @@
from datetime import time
from decimal import Decimal as D
from os import getenv
from unittest.mock import Mock, patch
from core.models import (
Account,
Hook,
OrderSettings,
RiskModel,
Signal,
Strategy,
TradingTime,
User,
)
from core.models import Account, OrderSettings, RiskModel, Strategy, TradingTime, User
# Create patch mixin to mock out the Elastic client
@ -48,7 +40,7 @@ class SymbolPriceMock:
cls.patcher = patch("core.exchanges.common.get_symbol_price")
patcher = cls.patcher.start()
patcher.return_value = 1
patcher.return_value = D(1)
@classmethod
def tearDownClass(cls):
@ -141,7 +133,7 @@ class StrategyMixin:
max_loss_percent=50,
max_risk_percent=10,
max_open_trades=10,
max_open_trades_per_symbol=2,
max_open_trades_per_symbol=5,
)
self.strategy = Strategy.objects.create(

@ -1,5 +1,5 @@
from decimal import Decimal as D
from unittest.mock import Mock, patch
from unittest.mock import patch
from django.test import TestCase
@ -28,6 +28,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
name="Test Account",
exchange="fake",
currency="USD",
initial_balance=100000,
)
self.account.supported_symbols = ["EUR_USD", "EUR_XXX", "USD_EUR", "XXX_EUR"]
self.account.save()
@ -75,7 +76,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
"id": "20084",
"symbol": "EUR_USD",
"price": "1.06331",
"openTime": "2023-02-13T11:39:06.302917985Z", # Monday at 11:38
"openTime": "2023-02-13T11:39:06.302917985Z", # Monday at 11:39
"initialUnits": "10",
"initialMarginRequired": "0.2966",
"state": "OPEN",
@ -95,6 +96,9 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
# Run parse_time on all items in trades
for trade in self.trades:
trade["openTime"] = parse_time(trade)
self.balance = 100000
self.balance_usd = 120000
self.ams.get_trades = self.fake_get_trades
self.ams.get_balance = self.fake_get_balance
# self.ams.trades = self.trades
@ -123,12 +127,25 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
trade["openTime"] = parse_time(trade)
self.trades.append(trade)
def amend_tp_sl_flip_side(self):
"""
Amend the take profit and stop loss orders to be the opposite side.
This lets the protection tests pass, so we can only test one violation
per test.
"""
for trade in self.trades:
if trade["side"] == "short":
trade["stopLossOrder"]["price"] = "1.07386"
trade["takeProfitOrder"]["price"] = "1.04728"
def fake_get_trades(self):
self.ams.trades = self.trades
return self.trades
def fake_get_balance(self):
return 10000
def fake_get_balance(self, return_usd=None):
if return_usd:
return self.balance_usd
return self.balance
def fake_get_currencies(self, symbols):
pass
@ -139,7 +156,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
def test_get_balance(self):
balance = self.ams.get_balance()
self.assertEqual(balance, 10000)
self.assertEqual(balance, self.balance)
def check_violation(
self, violation, calls, expected_action, expected_trades, expected_args=None
@ -153,6 +170,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
:param: expected_trades: list of expected trades to be passed to the violation
:param: expected_args: optional, expected args to be passed to the violation
"""
self.assertEqual(len(calls), len(expected_trades))
calls = list(calls)
violation_calls = []
for call in calls:
@ -160,11 +178,18 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
violation_calls.append(call)
self.assertEqual(len(violation_calls), len(expected_trades))
expected_trades = convert_trades(expected_trades)
if all(expected_trades):
expected_trades = convert_trades(expected_trades)
for call in violation_calls:
# Ensure the correct action has been called, like close
self.assertEqual(call[0][1], expected_action)
# Ensure the correct trade has been passed to the violation
trade = call[0][2]
if trade:
for field in list(trade.keys()):
if "_usd" in field:
if field in trade.keys():
del trade[field]
self.assertIn(call[0][2], expected_trades)
if expected_args:
self.assertEqual(call[0][3], expected_args)
@ -172,7 +197,6 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_run_checks(self, handle_violation):
self.ams.run_checks()
print("handle_violation.call_count", handle_violation.call_args_list)
self.assertEqual(handle_violation.call_count, 0)
@patch("core.trading.active_management.ActiveManagement.handle_violation")
@ -201,7 +225,9 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
signal = self.create_hook_signal()
self.strategy.trend_signals.set([signal])
self.strategy.trends = {"EUR_USD": "sell"}
self.amend_tp_sl_flip_side()
self.strategy.save()
self.ams.run_checks()
self.check_violation(
"trends", handle_violation.call_args_list, "close", self.trades
@ -216,8 +242,14 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
self.ams.run_checks()
self.check_violation("trends", handle_violation.call_args_list, "close", [])
# Mock crossfilter here since we want to allow this conflict in order to test that
# trends only close trades that are in the wrong direction
@patch(
"core.trading.active_management.ActiveManagement.check_crossfilter",
return_value=None,
)
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_trends_violated_partial(self, handle_violation):
def test_trends_violated_partial(self, handle_violation, check_crossfilter):
signal = self.create_hook_signal()
self.strategy.trend_signals.set([signal])
self.strategy.trends = {"EUR_USD": "sell"}
@ -225,6 +257,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
# Change the side of the first trade to match the trends
self.trades[0]["side"] = "short"
self.amend_tp_sl_flip_side()
self.ams.run_checks()
self.check_violation(
@ -241,7 +274,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
handle_violation.call_args_list,
"close",
[self.trades[0]],
{"size": 50},
{"size": 500},
)
@patch("core.trading.active_management.ActiveManagement.handle_violation")
@ -270,7 +303,6 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
self.trades[0]["takeProfitOrder"] = None
self.trades[0]["stopLossOrder"] = None
self.ams.run_checks()
print("CALLS", handle_violation.call_args_list)
self.assertEqual(handle_violation.call_count, 0)
@ -301,6 +333,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
def test_asset_groups_violated_invert(self, handle_violation):
self.trades[0]["side"] = "short"
self.trades[1]["side"] = "short"
self.amend_tp_sl_flip_side()
asset_group = AssetGroup.objects.create(
user=self.user,
name="Test Asset Group",
@ -325,6 +358,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_crossfilter_violated_side(self, handle_violation):
self.trades[1]["side"] = "short"
self.amend_tp_sl_flip_side()
self.ams.run_checks()
self.check_violation(
@ -336,16 +370,19 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_crossfilter_violated_side_multiple(self, handle_violation):
self.add_trade("20085", "EUR_USD", "short", "2023-02-13T12:39:06.302917985Z")
self.add_trade("20086", "EUR_USD", "short", "2023-02-14T12:39:06.302917985Z")
self.add_trade("20087", "EUR_USD", "short", "2023-02-10T12:39:06.302917985Z")
self.add_trade(
"20085", "EUR_USD", "short", "2023-02-13T12:39:06.302917985Z"
) # 2:
self.add_trade("20086", "EUR_USD", "short", "2023-02-13T13:39:07.302917985Z")
self.add_trade("20087", "EUR_USD", "short", "2023-02-13T14:39:06.302917985Z")
self.amend_tp_sl_flip_side()
self.ams.run_checks()
self.check_violation(
"crossfilter",
handle_violation.call_args_list,
"close",
self.trades[0:4], # Only close newer trades
self.trades[2:], # Only close newer trades
)
@patch("core.trading.active_management.ActiveManagement.handle_violation")
@ -363,16 +400,18 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_crossfilter_violated_symbol_multiple(self, handle_violation):
self.add_trade("20085", "USD_EUR", "long", "2023-02-13T12:39:06.302917985Z")
self.add_trade("20086", "USD_EUR", "long", "2023-02-14T12:39:06.302917985Z")
self.add_trade("20087", "USD_EUR", "long", "2023-02-10T12:39:06.302917985Z")
self.add_trade(
"20085", "USD_EUR", "long", "2023-02-13T12:39:06.302917985Z"
) # 2:
self.add_trade("20086", "USD_EUR", "long", "2023-02-13T13:39:06.302917985Z")
self.add_trade("20087", "USD_EUR", "long", "2023-02-13T14:39:06.302917985Z")
self.ams.run_checks()
self.check_violation(
"crossfilter",
handle_violation.call_args_list,
"close",
self.trades[0:4], # Only close newer trades
self.trades[2:], # Only close newer trades
)
@patch("core.trading.active_management.ActiveManagement.handle_violation")
@ -380,7 +419,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
for x in range(9):
self.add_trade(
str(x),
"EUR_USD",
f"EUR_USD{x}", # Vary symbol to prevent max open trades per symbol
"long",
f"2023-02-13T12:39:1{x}.302917985Z",
)
@ -395,7 +434,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_max_open_trades_per_symbol_violated(self, handle_violation):
for x in range(2):
for x in range(4):
self.add_trade(
str(x),
"EUR_USD",
@ -408,11 +447,102 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
"max_open_trades_per_symbol",
handle_violation.call_args_list,
"close",
self.trades[2:], # Only close newer trades
self.trades[5:], # Only close newer trades
)
def test_max_loss_violated(self):
pass
# Mock position size as we have no way of checking the balance at the start of the
# trade.
# TODO: Fix this when we have a way of checking the balance at the start of the
# trade.
# Max risk is also mocked as this puts us over the limit, due to the low account
# size.
@patch(
"core.trading.active_management.ActiveManagement.check_max_risk",
return_value=None,
)
@patch(
"core.trading.active_management.ActiveManagement.check_position_size",
return_value=None,
)
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_max_loss_violated(
self, handle_violation, check_position_size, check_max_risk
):
self.balance = D("1")
self.balance_usd = D("0.69")
self.ams.run_checks()
def test_max_risk_violated(self):
pass
self.check_violation(
"max_loss",
handle_violation.call_args_list,
"close",
[None],
)
@patch(
"core.trading.active_management.ActiveManagement.check_position_size",
return_value=None,
)
@patch(
"core.trading.active_management.ActiveManagement.check_protection",
return_value=None,
)
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_max_risk_violated(
self, handle_violation, check_protection, check_position_size
):
self.add_trade(
"20085",
"EUR_USD",
"long",
"2023-02-13T15:39:19.302917985Z",
)
self.trades[2]["stopLossOrder"]["price"] = "0.001"
self.trades[2]["currentUnits"] = "13000"
self.ams.run_checks()
self.check_violation(
"max_risk",
handle_violation.call_args_list,
"close",
[self.trades[2]],
)
@patch(
"core.trading.active_management.ActiveManagement.check_position_size",
return_value=None,
)
@patch(
"core.trading.active_management.ActiveManagement.check_protection",
return_value=None,
)
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_max_risk_violated_multiple(
self, handle_violation, check_protection, check_position_size
):
self.add_trade(
"20085",
"EUR_USD",
"long",
"2023-02-13T15:39:19.302917985Z",
)
self.add_trade(
"20086",
"EUR_USD",
"long",
"2023-02-13T15:45:19.302917985Z",
)
self.trades[2]["stopLossOrder"]["price"] = "0.001"
self.trades[2]["currentUnits"] = "13000"
self.trades[3]["stopLossOrder"]["price"] = "0.001"
self.trades[3]["currentUnits"] = "13000"
self.ams.run_checks()
self.check_violation(
"max_risk",
handle_violation.call_args_list,
"close",
[self.trades[2], self.trades[3]],
)

@ -1,18 +1,7 @@
from datetime import time
import freezegun
from django.test import TestCase
from core.models import (
Account,
Hook,
OrderSettings,
RiskModel,
Signal,
Strategy,
TradingTime,
User,
)
from core.models import Account, Hook, Signal, User
from core.tests.helpers import StrategyMixin
from core.trading import checks

@ -1,9 +1,8 @@
from django.test import TestCase
import core.trading.market # noqa # to avoid messy circular import
from core.exchanges import convert
from core.models import RiskModel, User
import core.trading.market # to avoid messy circular import
from core.trading import risk

@ -1,14 +1,13 @@
from copy import deepcopy
from datetime import datetime
from decimal import Decimal as D
import core.trading.market # to avoid messy circular import
from core.exchanges.convert import (
convert_trades,
side_to_direction,
sl_percent_to_price,
tp_percent_to_price,
)
from core.trading import assetfilter, checks, risk
from core.trading import assetfilter, checks, market, risk
from core.trading.crossfilter import crossfilter
from core.trading.market import get_base_quote, get_trade_size_in_base
@ -20,17 +19,28 @@ class ActiveManagement(object):
self.trades = []
self.balance = None
self.balance_usd = None
def get_trades(self):
if not self.trades:
self.trades = self.strategy.account.client.get_all_open_trades()
return self.trades
def get_balance(self):
if self.balance is None:
self.balance = self.strategy.account.client.get_balance()
def get_balance(self, return_usd=False):
if return_usd:
if self.balance_usd is None:
self.balance_usd = self.strategy.account.client.get_balance(
return_usd=True
)
else:
return self.balance_usd
else:
return self.balance
if self.balance is None:
self.balance = self.strategy.account.client.get_balance(
return_usd=False
)
else:
return self.balance
def handle_violation(self, check_type, action, trade, **kwargs):
print("VIOLATION", check_type, action, trade, kwargs)
@ -54,8 +64,10 @@ class ActiveManagement(object):
def check_position_size(self, trade):
"""
Check the position size is within the allowed deviation.
WARNING: This uses the current balance, not the balance at the time of the trade.
WARNING: This uses the current symbol prices, not those at the time of the trade.
WARNING: This uses the current balance, not the balance at the time of the
trade.
WARNING: This uses the current symbol prices, not those at the time of the
trade.
This should normally be run every 5 seconds, so this is fine.
"""
# TODO: add the trade value to the balance
@ -86,10 +98,12 @@ class ActiveManagement(object):
def check_protection(self, trade):
deviation = D(0.05) # 5%
# fmt: off
matches = {
"stop_loss_percent": self.strategy.order_settings.stop_loss_percent,
"take_profit_percent": self.strategy.order_settings.take_profit_percent,
"trailing_stop_percent": self.strategy.order_settings.trailing_stop_loss_percent,
"trailing_stop_percent":
self.strategy.order_settings.trailing_stop_loss_percent,
}
violations = {}
@ -155,7 +169,7 @@ class ActiveManagement(object):
)
def get_sorted_trades_copy(self, trades, reverse=True):
trades_copy = trades.copy()
trades_copy = deepcopy(trades)
# sort by open time, newest first
trades_copy.sort(
key=lambda x: datetime.strptime(x["open_time"], "%Y-%m-%dT%H:%M:%S.%fZ"),
@ -170,7 +184,8 @@ class ActiveManagement(object):
iterations = 0
finished = []
# Recursively run crossfilter on the newest-first list until we have no more conflicts
# Recursively run crossfilter on the newest-first list until we have no more
# conflicts
while not len(finished) == len(trades):
iterations += 1
if iterations > 10000:
@ -207,17 +222,7 @@ class ActiveManagement(object):
close_trades.append(trade)
if not close_trades:
return
# For each conflicting symbol, identify the oldest trades
# removed_trades = []
# for symbol in conflict:
# newest_trade = max(conflict, key=lambda x: datetime.strptime(x["open_time"], "%Y-%m-%dT%H:%M:%S.%fZ"))
# removed_trades.append(newest_trade)
# print("KEEP TRADES", keep_trade_ids)
# close_trades = []
# for x in keep_trade_ids:
# for position in conflict[x]:
# if position["id"] not in keep_trade_ids[x]:
# close_trades.append(position)
if close_trades:
for trade in close_trades:
self.handle_violation(
@ -225,57 +230,90 @@ class ActiveManagement(object):
)
def check_max_open_trades(self, trades):
if self.strategy.risk_model is not None:
max_open_pass = risk.check_max_open_trades(self.strategy.risk_model, trades)
if not max_open_pass:
trades_copy = self.get_sorted_trades_copy(trades, reverse=False)
print("TRADES COPY", [x["id"] for x in trades_copy])
print("MAX", self.strategy.risk_model.max_open_trades)
trades_over_limit = trades_copy[
self.strategy.risk_model.max_open_trades :
]
for trade in trades_over_limit:
self.handle_violation(
"max_open_trades",
self.policy.when_max_open_trades_violated,
trade,
)
print("TRADES OVER LIMNIT", trades_over_limit)
if self.strategy.risk_model is None:
return
max_open_pass = risk.check_max_open_trades(self.strategy.risk_model, trades)
if not max_open_pass:
trades_copy = self.get_sorted_trades_copy(trades, reverse=False)
# fmt: off
trades_over_limit = trades_copy[self.strategy.risk_model.max_open_trades:]
for trade in trades_over_limit:
self.handle_violation(
"max_open_trades",
self.policy.when_max_open_trades_violated,
trade,
)
def check_max_open_trades_per_symbol(self, trades):
if self.strategy.risk_model is not None:
max_open_pass = risk.check_max_open_trades_per_symbol(
self.strategy.risk_model, trades, return_symbols=True
)
print("max_open_pass", max_open_pass)
max_open_pass = list(max_open_pass)
print("MAX OPEN PASS", max_open_pass)
if max_open_pass:
trades_copy = self.get_sorted_trades_copy(trades, reverse=False)
trades_over_limit = []
for symbol in max_open_pass:
print("SYMBOL", symbol)
print("TRADES", trades)
symbol_trades = [x for x in trades_copy if x["symbol"] == symbol]
exceeding_limit = symbol_trades[
self.strategy.risk_model.max_open_trades_per_symbol :
]
for x in exceeding_limit:
trades_over_limit.append(x)
for trade in trades_over_limit:
self.handle_violation(
"max_open_trades_per_symbol",
self.policy.when_max_open_trades_violated,
trade,
)
print("TRADES OVER LIMNIT", trades_over_limit)
if self.strategy.risk_model is None:
return
max_open_pass = risk.check_max_open_trades_per_symbol(
self.strategy.risk_model, trades, return_symbols=True
)
max_open_pass = list(max_open_pass)
if max_open_pass:
trades_copy = self.get_sorted_trades_copy(trades, reverse=False)
trades_over_limit = []
for symbol in max_open_pass:
symbol_trades = [x for x in trades_copy if x["symbol"] == symbol]
# fmt: off
exceeding_limit = symbol_trades[
self.strategy.risk_model.max_open_trades_per_symbol:
]
for x in exceeding_limit:
trades_over_limit.append(x)
for trade in trades_over_limit:
self.handle_violation(
"max_open_trades_per_symbol",
self.policy.when_max_open_trades_violated,
trade,
)
def check_max_loss(self):
check_passed = risk.check_max_loss(self.strategy.risk_model, self.strategy.account.initial_balance, self.get_balance())
if self.strategy.risk_model is None:
return
check_passed = risk.check_max_loss(
self.strategy.risk_model,
self.strategy.account.initial_balance,
self.get_balance(),
)
if not check_passed:
self.handle_violation(
"max_loss", self.policy.when_max_loss_violated, None # Close all trades
)
def check_max_risk(self, trades):
pass
if self.strategy.risk_model is None:
return
close_trades = []
trades_copy = self.get_sorted_trades_copy(trades, reverse=False)
market.convert_trades_to_usd(self.strategy.account, trades_copy)
iterations = 0
finished = False
while not finished:
iterations += 1
if iterations > 10000:
raise Exception("Too many iterations")
check_passed = risk.check_max_risk(
self.strategy.risk_model,
self.get_balance(return_usd=True),
trades_copy,
)
if check_passed:
finished = True
else:
# Add the newest trade to close_trades and remove it from trades_copy
close_trades.append(trades_copy[-1])
trades_copy = trades_copy[:-1]
if close_trades:
for trade in close_trades:
self.handle_violation(
"max_risk", self.policy.when_max_risk_violated, trade
)
def run_checks(self):
converted_trades = convert_trades(self.get_trades())

@ -20,7 +20,7 @@ def convert_trades_to_usd(account, trades):
:return: List of trades, with amount_usd added
"""
for trade in trades:
amount = trade["amount"]
amount = D(trade["amount"])
symbol = trade["symbol"]
side = trade["side"]
direction = side_to_direction(side)

@ -26,6 +26,7 @@ def check_max_risk(risk_model, account_balance_usd, account_trades):
# Calculate the max risk of the account in USD
max_risk_usd = account_balance_usd * (max_risk_percent / D(100))
total_risk = 0
for trade in account_trades:
max_tmp = []
# Need to calculate the max risk in base account currency
@ -36,7 +37,8 @@ def check_max_risk(risk_model, account_balance_usd, account_trades):
if "trailing_stop_loss_usd" in trade:
max_tmp.append(trade["trailing_stop_loss_usd"])
if max_tmp:
total_risk += max(max_tmp)
max_risk = max(max_tmp)
total_risk += max_risk
allowed = total_risk < max_risk_usd
return allowed
@ -59,7 +61,6 @@ def check_max_open_trades_per_symbol(risk_model, account_trades, return_symbols=
if symbol not in symbol_map:
symbol_map[symbol] = 0
symbol_map[symbol] += 1
print("Symbol map: ", symbol_map)
violating_symbols = []
for symbol, count in symbol_map.items():
if count >= risk_model.max_open_trades_per_symbol:

Loading…
Cancel
Save