Finish implementing active management hooks

This commit is contained in:
Mark Veidemanis 2023-02-18 11:54:30 +00:00
parent 3e35214e82
commit 466b17400f
Signed by: m
GPG Key ID: 5ACFCEED46C0904F
9 changed files with 271 additions and 123 deletions

View File

@ -98,5 +98,4 @@ def to_currency(direction, account, amount, from_currency, to_currency):
# Convert the amount to the destination currency # Convert the amount to the destination currency
converted = D(amount) * price converted = D(amount) * price
return converted return converted

View File

@ -419,7 +419,7 @@ class ActiveManagementPolicyForm(RestrictedFormMixin, ModelForm):
"when_asset_groups_violated": "The action to take when a trade violating the asset group rules is discovered.", "when_asset_groups_violated": "The action to take when a trade violating the asset group rules is discovered.",
"when_max_open_trades_violated": "The action to take when a trade puts the account above the maximum open trades.", "when_max_open_trades_violated": "The action to take when a trade puts the account above the maximum open trades.",
"when_max_open_trades_per_symbol_violated": "The action to take when a trade puts the account above the maximum open trades per symbol.", "when_max_open_trades_per_symbol_violated": "The action to take when a trade puts the account above the maximum open trades per symbol.",
"when_max_loss_violated": "The action to take when a trade puts the account above the maximum loss.", "when_max_loss_violated": "The action to take when the account exceeds its maximum loss. NOTE: The close action will close all trades.",
"when_max_risk_violated": "The action to take when a trade exposes the account to more than the maximum risk.", "when_max_risk_violated": "The action to take when a trade exposes the account to more than the maximum risk.",
"when_crossfilter_violated": "The action to take when a trade is deemed to conflict with another -- e.g. a buy and sell on the same asset.", "when_crossfilter_violated": "The action to take when a trade is deemed to conflict with another -- e.g. a buy and sell on the same asset.",
} }

View File

@ -1,17 +1,9 @@
from datetime import time from datetime import time
from decimal import Decimal as D
from os import getenv from os import getenv
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from core.models import ( from core.models import Account, OrderSettings, RiskModel, Strategy, TradingTime, User
Account,
Hook,
OrderSettings,
RiskModel,
Signal,
Strategy,
TradingTime,
User,
)
# Create patch mixin to mock out the Elastic client # Create patch mixin to mock out the Elastic client
@ -48,7 +40,7 @@ class SymbolPriceMock:
cls.patcher = patch("core.exchanges.common.get_symbol_price") cls.patcher = patch("core.exchanges.common.get_symbol_price")
patcher = cls.patcher.start() patcher = cls.patcher.start()
patcher.return_value = 1 patcher.return_value = D(1)
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
@ -141,7 +133,7 @@ class StrategyMixin:
max_loss_percent=50, max_loss_percent=50,
max_risk_percent=10, max_risk_percent=10,
max_open_trades=10, max_open_trades=10,
max_open_trades_per_symbol=2, max_open_trades_per_symbol=5,
) )
self.strategy = Strategy.objects.create( self.strategy = Strategy.objects.create(

View File

@ -1,5 +1,5 @@
from decimal import Decimal as D from decimal import Decimal as D
from unittest.mock import Mock, patch from unittest.mock import patch
from django.test import TestCase from django.test import TestCase
@ -28,6 +28,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
name="Test Account", name="Test Account",
exchange="fake", exchange="fake",
currency="USD", currency="USD",
initial_balance=100000,
) )
self.account.supported_symbols = ["EUR_USD", "EUR_XXX", "USD_EUR", "XXX_EUR"] self.account.supported_symbols = ["EUR_USD", "EUR_XXX", "USD_EUR", "XXX_EUR"]
self.account.save() self.account.save()
@ -75,7 +76,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
"id": "20084", "id": "20084",
"symbol": "EUR_USD", "symbol": "EUR_USD",
"price": "1.06331", "price": "1.06331",
"openTime": "2023-02-13T11:39:06.302917985Z", # Monday at 11:38 "openTime": "2023-02-13T11:39:06.302917985Z", # Monday at 11:39
"initialUnits": "10", "initialUnits": "10",
"initialMarginRequired": "0.2966", "initialMarginRequired": "0.2966",
"state": "OPEN", "state": "OPEN",
@ -95,6 +96,9 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
# Run parse_time on all items in trades # Run parse_time on all items in trades
for trade in self.trades: for trade in self.trades:
trade["openTime"] = parse_time(trade) trade["openTime"] = parse_time(trade)
self.balance = 100000
self.balance_usd = 120000
self.ams.get_trades = self.fake_get_trades self.ams.get_trades = self.fake_get_trades
self.ams.get_balance = self.fake_get_balance self.ams.get_balance = self.fake_get_balance
# self.ams.trades = self.trades # self.ams.trades = self.trades
@ -123,12 +127,25 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
trade["openTime"] = parse_time(trade) trade["openTime"] = parse_time(trade)
self.trades.append(trade) self.trades.append(trade)
def amend_tp_sl_flip_side(self):
"""
Amend the take profit and stop loss orders to be the opposite side.
This lets the protection tests pass, so we can only test one violation
per test.
"""
for trade in self.trades:
if trade["side"] == "short":
trade["stopLossOrder"]["price"] = "1.07386"
trade["takeProfitOrder"]["price"] = "1.04728"
def fake_get_trades(self): def fake_get_trades(self):
self.ams.trades = self.trades self.ams.trades = self.trades
return self.trades return self.trades
def fake_get_balance(self): def fake_get_balance(self, return_usd=None):
return 10000 if return_usd:
return self.balance_usd
return self.balance
def fake_get_currencies(self, symbols): def fake_get_currencies(self, symbols):
pass pass
@ -139,7 +156,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
def test_get_balance(self): def test_get_balance(self):
balance = self.ams.get_balance() balance = self.ams.get_balance()
self.assertEqual(balance, 10000) self.assertEqual(balance, self.balance)
def check_violation( def check_violation(
self, violation, calls, expected_action, expected_trades, expected_args=None self, violation, calls, expected_action, expected_trades, expected_args=None
@ -153,6 +170,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
:param: expected_trades: list of expected trades to be passed to the violation :param: expected_trades: list of expected trades to be passed to the violation
:param: expected_args: optional, expected args to be passed to the violation :param: expected_args: optional, expected args to be passed to the violation
""" """
self.assertEqual(len(calls), len(expected_trades))
calls = list(calls) calls = list(calls)
violation_calls = [] violation_calls = []
for call in calls: for call in calls:
@ -160,11 +178,18 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
violation_calls.append(call) violation_calls.append(call)
self.assertEqual(len(violation_calls), len(expected_trades)) self.assertEqual(len(violation_calls), len(expected_trades))
expected_trades = convert_trades(expected_trades) if all(expected_trades):
expected_trades = convert_trades(expected_trades)
for call in violation_calls: for call in violation_calls:
# Ensure the correct action has been called, like close # Ensure the correct action has been called, like close
self.assertEqual(call[0][1], expected_action) self.assertEqual(call[0][1], expected_action)
# Ensure the correct trade has been passed to the violation # Ensure the correct trade has been passed to the violation
trade = call[0][2]
if trade:
for field in list(trade.keys()):
if "_usd" in field:
if field in trade.keys():
del trade[field]
self.assertIn(call[0][2], expected_trades) self.assertIn(call[0][2], expected_trades)
if expected_args: if expected_args:
self.assertEqual(call[0][3], expected_args) self.assertEqual(call[0][3], expected_args)
@ -172,7 +197,6 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
@patch("core.trading.active_management.ActiveManagement.handle_violation") @patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_run_checks(self, handle_violation): def test_run_checks(self, handle_violation):
self.ams.run_checks() self.ams.run_checks()
print("handle_violation.call_count", handle_violation.call_args_list)
self.assertEqual(handle_violation.call_count, 0) self.assertEqual(handle_violation.call_count, 0)
@patch("core.trading.active_management.ActiveManagement.handle_violation") @patch("core.trading.active_management.ActiveManagement.handle_violation")
@ -201,7 +225,9 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
signal = self.create_hook_signal() signal = self.create_hook_signal()
self.strategy.trend_signals.set([signal]) self.strategy.trend_signals.set([signal])
self.strategy.trends = {"EUR_USD": "sell"} self.strategy.trends = {"EUR_USD": "sell"}
self.amend_tp_sl_flip_side()
self.strategy.save() self.strategy.save()
self.ams.run_checks() self.ams.run_checks()
self.check_violation( self.check_violation(
"trends", handle_violation.call_args_list, "close", self.trades "trends", handle_violation.call_args_list, "close", self.trades
@ -216,8 +242,14 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
self.ams.run_checks() self.ams.run_checks()
self.check_violation("trends", handle_violation.call_args_list, "close", []) self.check_violation("trends", handle_violation.call_args_list, "close", [])
# Mock crossfilter here since we want to allow this conflict in order to test that
# trends only close trades that are in the wrong direction
@patch(
"core.trading.active_management.ActiveManagement.check_crossfilter",
return_value=None,
)
@patch("core.trading.active_management.ActiveManagement.handle_violation") @patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_trends_violated_partial(self, handle_violation): def test_trends_violated_partial(self, handle_violation, check_crossfilter):
signal = self.create_hook_signal() signal = self.create_hook_signal()
self.strategy.trend_signals.set([signal]) self.strategy.trend_signals.set([signal])
self.strategy.trends = {"EUR_USD": "sell"} self.strategy.trends = {"EUR_USD": "sell"}
@ -225,6 +257,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
# Change the side of the first trade to match the trends # Change the side of the first trade to match the trends
self.trades[0]["side"] = "short" self.trades[0]["side"] = "short"
self.amend_tp_sl_flip_side()
self.ams.run_checks() self.ams.run_checks()
self.check_violation( self.check_violation(
@ -241,7 +274,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
handle_violation.call_args_list, handle_violation.call_args_list,
"close", "close",
[self.trades[0]], [self.trades[0]],
{"size": 50}, {"size": 500},
) )
@patch("core.trading.active_management.ActiveManagement.handle_violation") @patch("core.trading.active_management.ActiveManagement.handle_violation")
@ -270,7 +303,6 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
self.trades[0]["takeProfitOrder"] = None self.trades[0]["takeProfitOrder"] = None
self.trades[0]["stopLossOrder"] = None self.trades[0]["stopLossOrder"] = None
self.ams.run_checks() self.ams.run_checks()
print("CALLS", handle_violation.call_args_list)
self.assertEqual(handle_violation.call_count, 0) self.assertEqual(handle_violation.call_count, 0)
@ -301,6 +333,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
def test_asset_groups_violated_invert(self, handle_violation): def test_asset_groups_violated_invert(self, handle_violation):
self.trades[0]["side"] = "short" self.trades[0]["side"] = "short"
self.trades[1]["side"] = "short" self.trades[1]["side"] = "short"
self.amend_tp_sl_flip_side()
asset_group = AssetGroup.objects.create( asset_group = AssetGroup.objects.create(
user=self.user, user=self.user,
name="Test Asset Group", name="Test Asset Group",
@ -325,6 +358,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
@patch("core.trading.active_management.ActiveManagement.handle_violation") @patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_crossfilter_violated_side(self, handle_violation): def test_crossfilter_violated_side(self, handle_violation):
self.trades[1]["side"] = "short" self.trades[1]["side"] = "short"
self.amend_tp_sl_flip_side()
self.ams.run_checks() self.ams.run_checks()
self.check_violation( self.check_violation(
@ -336,16 +370,19 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
@patch("core.trading.active_management.ActiveManagement.handle_violation") @patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_crossfilter_violated_side_multiple(self, handle_violation): def test_crossfilter_violated_side_multiple(self, handle_violation):
self.add_trade("20085", "EUR_USD", "short", "2023-02-13T12:39:06.302917985Z") self.add_trade(
self.add_trade("20086", "EUR_USD", "short", "2023-02-14T12:39:06.302917985Z") "20085", "EUR_USD", "short", "2023-02-13T12:39:06.302917985Z"
self.add_trade("20087", "EUR_USD", "short", "2023-02-10T12:39:06.302917985Z") ) # 2:
self.add_trade("20086", "EUR_USD", "short", "2023-02-13T13:39:07.302917985Z")
self.add_trade("20087", "EUR_USD", "short", "2023-02-13T14:39:06.302917985Z")
self.amend_tp_sl_flip_side()
self.ams.run_checks() self.ams.run_checks()
self.check_violation( self.check_violation(
"crossfilter", "crossfilter",
handle_violation.call_args_list, handle_violation.call_args_list,
"close", "close",
self.trades[0:4], # Only close newer trades self.trades[2:], # Only close newer trades
) )
@patch("core.trading.active_management.ActiveManagement.handle_violation") @patch("core.trading.active_management.ActiveManagement.handle_violation")
@ -363,16 +400,18 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
@patch("core.trading.active_management.ActiveManagement.handle_violation") @patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_crossfilter_violated_symbol_multiple(self, handle_violation): def test_crossfilter_violated_symbol_multiple(self, handle_violation):
self.add_trade("20085", "USD_EUR", "long", "2023-02-13T12:39:06.302917985Z") self.add_trade(
self.add_trade("20086", "USD_EUR", "long", "2023-02-14T12:39:06.302917985Z") "20085", "USD_EUR", "long", "2023-02-13T12:39:06.302917985Z"
self.add_trade("20087", "USD_EUR", "long", "2023-02-10T12:39:06.302917985Z") ) # 2:
self.add_trade("20086", "USD_EUR", "long", "2023-02-13T13:39:06.302917985Z")
self.add_trade("20087", "USD_EUR", "long", "2023-02-13T14:39:06.302917985Z")
self.ams.run_checks() self.ams.run_checks()
self.check_violation( self.check_violation(
"crossfilter", "crossfilter",
handle_violation.call_args_list, handle_violation.call_args_list,
"close", "close",
self.trades[0:4], # Only close newer trades self.trades[2:], # Only close newer trades
) )
@patch("core.trading.active_management.ActiveManagement.handle_violation") @patch("core.trading.active_management.ActiveManagement.handle_violation")
@ -380,7 +419,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
for x in range(9): for x in range(9):
self.add_trade( self.add_trade(
str(x), str(x),
"EUR_USD", f"EUR_USD{x}", # Vary symbol to prevent max open trades per symbol
"long", "long",
f"2023-02-13T12:39:1{x}.302917985Z", f"2023-02-13T12:39:1{x}.302917985Z",
) )
@ -395,7 +434,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
@patch("core.trading.active_management.ActiveManagement.handle_violation") @patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_max_open_trades_per_symbol_violated(self, handle_violation): def test_max_open_trades_per_symbol_violated(self, handle_violation):
for x in range(2): for x in range(4):
self.add_trade( self.add_trade(
str(x), str(x),
"EUR_USD", "EUR_USD",
@ -408,11 +447,102 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
"max_open_trades_per_symbol", "max_open_trades_per_symbol",
handle_violation.call_args_list, handle_violation.call_args_list,
"close", "close",
self.trades[2:], # Only close newer trades self.trades[5:], # Only close newer trades
) )
def test_max_loss_violated(self): # Mock position size as we have no way of checking the balance at the start of the
pass # trade.
# TODO: Fix this when we have a way of checking the balance at the start of the
# trade.
# Max risk is also mocked as this puts us over the limit, due to the low account
# size.
@patch(
"core.trading.active_management.ActiveManagement.check_max_risk",
return_value=None,
)
@patch(
"core.trading.active_management.ActiveManagement.check_position_size",
return_value=None,
)
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_max_loss_violated(
self, handle_violation, check_position_size, check_max_risk
):
self.balance = D("1")
self.balance_usd = D("0.69")
self.ams.run_checks()
def test_max_risk_violated(self): self.check_violation(
pass "max_loss",
handle_violation.call_args_list,
"close",
[None],
)
@patch(
"core.trading.active_management.ActiveManagement.check_position_size",
return_value=None,
)
@patch(
"core.trading.active_management.ActiveManagement.check_protection",
return_value=None,
)
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_max_risk_violated(
self, handle_violation, check_protection, check_position_size
):
self.add_trade(
"20085",
"EUR_USD",
"long",
"2023-02-13T15:39:19.302917985Z",
)
self.trades[2]["stopLossOrder"]["price"] = "0.001"
self.trades[2]["currentUnits"] = "13000"
self.ams.run_checks()
self.check_violation(
"max_risk",
handle_violation.call_args_list,
"close",
[self.trades[2]],
)
@patch(
"core.trading.active_management.ActiveManagement.check_position_size",
return_value=None,
)
@patch(
"core.trading.active_management.ActiveManagement.check_protection",
return_value=None,
)
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_max_risk_violated_multiple(
self, handle_violation, check_protection, check_position_size
):
self.add_trade(
"20085",
"EUR_USD",
"long",
"2023-02-13T15:39:19.302917985Z",
)
self.add_trade(
"20086",
"EUR_USD",
"long",
"2023-02-13T15:45:19.302917985Z",
)
self.trades[2]["stopLossOrder"]["price"] = "0.001"
self.trades[2]["currentUnits"] = "13000"
self.trades[3]["stopLossOrder"]["price"] = "0.001"
self.trades[3]["currentUnits"] = "13000"
self.ams.run_checks()
self.check_violation(
"max_risk",
handle_violation.call_args_list,
"close",
[self.trades[2], self.trades[3]],
)

View File

@ -1,18 +1,7 @@
from datetime import time
import freezegun import freezegun
from django.test import TestCase from django.test import TestCase
from core.models import ( from core.models import Account, Hook, Signal, User
Account,
Hook,
OrderSettings,
RiskModel,
Signal,
Strategy,
TradingTime,
User,
)
from core.tests.helpers import StrategyMixin from core.tests.helpers import StrategyMixin
from core.trading import checks from core.trading import checks

View File

@ -1,9 +1,8 @@
from django.test import TestCase from django.test import TestCase
import core.trading.market # noqa # to avoid messy circular import
from core.exchanges import convert from core.exchanges import convert
from core.models import RiskModel, User from core.models import RiskModel, User
import core.trading.market # to avoid messy circular import
from core.trading import risk from core.trading import risk

View File

@ -1,14 +1,13 @@
from copy import deepcopy
from datetime import datetime from datetime import datetime
from decimal import Decimal as D from decimal import Decimal as D
import core.trading.market # to avoid messy circular import
from core.exchanges.convert import ( from core.exchanges.convert import (
convert_trades, convert_trades,
side_to_direction,
sl_percent_to_price, sl_percent_to_price,
tp_percent_to_price, tp_percent_to_price,
) )
from core.trading import assetfilter, checks, risk from core.trading import assetfilter, checks, market, risk
from core.trading.crossfilter import crossfilter from core.trading.crossfilter import crossfilter
from core.trading.market import get_base_quote, get_trade_size_in_base from core.trading.market import get_base_quote, get_trade_size_in_base
@ -20,17 +19,28 @@ class ActiveManagement(object):
self.trades = [] self.trades = []
self.balance = None self.balance = None
self.balance_usd = None
def get_trades(self): def get_trades(self):
if not self.trades: if not self.trades:
self.trades = self.strategy.account.client.get_all_open_trades() self.trades = self.strategy.account.client.get_all_open_trades()
return self.trades return self.trades
def get_balance(self): def get_balance(self, return_usd=False):
if self.balance is None: if return_usd:
self.balance = self.strategy.account.client.get_balance() if self.balance_usd is None:
self.balance_usd = self.strategy.account.client.get_balance(
return_usd=True
)
else:
return self.balance_usd
else: else:
return self.balance if self.balance is None:
self.balance = self.strategy.account.client.get_balance(
return_usd=False
)
else:
return self.balance
def handle_violation(self, check_type, action, trade, **kwargs): def handle_violation(self, check_type, action, trade, **kwargs):
print("VIOLATION", check_type, action, trade, kwargs) print("VIOLATION", check_type, action, trade, kwargs)
@ -54,8 +64,10 @@ class ActiveManagement(object):
def check_position_size(self, trade): def check_position_size(self, trade):
""" """
Check the position size is within the allowed deviation. Check the position size is within the allowed deviation.
WARNING: This uses the current balance, not the balance at the time of the trade. WARNING: This uses the current balance, not the balance at the time of the
WARNING: This uses the current symbol prices, not those at the time of the trade. trade.
WARNING: This uses the current symbol prices, not those at the time of the
trade.
This should normally be run every 5 seconds, so this is fine. This should normally be run every 5 seconds, so this is fine.
""" """
# TODO: add the trade value to the balance # TODO: add the trade value to the balance
@ -86,10 +98,12 @@ class ActiveManagement(object):
def check_protection(self, trade): def check_protection(self, trade):
deviation = D(0.05) # 5% deviation = D(0.05) # 5%
# fmt: off
matches = { matches = {
"stop_loss_percent": self.strategy.order_settings.stop_loss_percent, "stop_loss_percent": self.strategy.order_settings.stop_loss_percent,
"take_profit_percent": self.strategy.order_settings.take_profit_percent, "take_profit_percent": self.strategy.order_settings.take_profit_percent,
"trailing_stop_percent": self.strategy.order_settings.trailing_stop_loss_percent, "trailing_stop_percent":
self.strategy.order_settings.trailing_stop_loss_percent,
} }
violations = {} violations = {}
@ -155,7 +169,7 @@ class ActiveManagement(object):
) )
def get_sorted_trades_copy(self, trades, reverse=True): def get_sorted_trades_copy(self, trades, reverse=True):
trades_copy = trades.copy() trades_copy = deepcopy(trades)
# sort by open time, newest first # sort by open time, newest first
trades_copy.sort( trades_copy.sort(
key=lambda x: datetime.strptime(x["open_time"], "%Y-%m-%dT%H:%M:%S.%fZ"), key=lambda x: datetime.strptime(x["open_time"], "%Y-%m-%dT%H:%M:%S.%fZ"),
@ -170,7 +184,8 @@ class ActiveManagement(object):
iterations = 0 iterations = 0
finished = [] finished = []
# Recursively run crossfilter on the newest-first list until we have no more conflicts # Recursively run crossfilter on the newest-first list until we have no more
# conflicts
while not len(finished) == len(trades): while not len(finished) == len(trades):
iterations += 1 iterations += 1
if iterations > 10000: if iterations > 10000:
@ -207,17 +222,7 @@ class ActiveManagement(object):
close_trades.append(trade) close_trades.append(trade)
if not close_trades: if not close_trades:
return return
# For each conflicting symbol, identify the oldest trades
# removed_trades = []
# for symbol in conflict:
# newest_trade = max(conflict, key=lambda x: datetime.strptime(x["open_time"], "%Y-%m-%dT%H:%M:%S.%fZ"))
# removed_trades.append(newest_trade)
# print("KEEP TRADES", keep_trade_ids)
# close_trades = []
# for x in keep_trade_ids:
# for position in conflict[x]:
# if position["id"] not in keep_trade_ids[x]:
# close_trades.append(position)
if close_trades: if close_trades:
for trade in close_trades: for trade in close_trades:
self.handle_violation( self.handle_violation(
@ -225,57 +230,90 @@ class ActiveManagement(object):
) )
def check_max_open_trades(self, trades): def check_max_open_trades(self, trades):
if self.strategy.risk_model is not None: if self.strategy.risk_model is None:
max_open_pass = risk.check_max_open_trades(self.strategy.risk_model, trades) return
if not max_open_pass: max_open_pass = risk.check_max_open_trades(self.strategy.risk_model, trades)
trades_copy = self.get_sorted_trades_copy(trades, reverse=False) if not max_open_pass:
print("TRADES COPY", [x["id"] for x in trades_copy]) trades_copy = self.get_sorted_trades_copy(trades, reverse=False)
print("MAX", self.strategy.risk_model.max_open_trades) # fmt: off
trades_over_limit = trades_copy[ trades_over_limit = trades_copy[self.strategy.risk_model.max_open_trades:]
self.strategy.risk_model.max_open_trades : for trade in trades_over_limit:
] self.handle_violation(
for trade in trades_over_limit: "max_open_trades",
self.handle_violation( self.policy.when_max_open_trades_violated,
"max_open_trades", trade,
self.policy.when_max_open_trades_violated, )
trade,
)
print("TRADES OVER LIMNIT", trades_over_limit)
def check_max_open_trades_per_symbol(self, trades): def check_max_open_trades_per_symbol(self, trades):
if self.strategy.risk_model is not None: if self.strategy.risk_model is None:
max_open_pass = risk.check_max_open_trades_per_symbol( return
self.strategy.risk_model, trades, return_symbols=True max_open_pass = risk.check_max_open_trades_per_symbol(
) self.strategy.risk_model, trades, return_symbols=True
print("max_open_pass", max_open_pass) )
max_open_pass = list(max_open_pass) max_open_pass = list(max_open_pass)
print("MAX OPEN PASS", max_open_pass) if max_open_pass:
if max_open_pass: trades_copy = self.get_sorted_trades_copy(trades, reverse=False)
trades_copy = self.get_sorted_trades_copy(trades, reverse=False) trades_over_limit = []
trades_over_limit = [] for symbol in max_open_pass:
for symbol in max_open_pass: symbol_trades = [x for x in trades_copy if x["symbol"] == symbol]
print("SYMBOL", symbol) # fmt: off
print("TRADES", trades) exceeding_limit = symbol_trades[
symbol_trades = [x for x in trades_copy if x["symbol"] == symbol] self.strategy.risk_model.max_open_trades_per_symbol:
exceeding_limit = symbol_trades[ ]
self.strategy.risk_model.max_open_trades_per_symbol : for x in exceeding_limit:
] trades_over_limit.append(x)
for x in exceeding_limit:
trades_over_limit.append(x)
for trade in trades_over_limit: for trade in trades_over_limit:
self.handle_violation( self.handle_violation(
"max_open_trades_per_symbol", "max_open_trades_per_symbol",
self.policy.when_max_open_trades_violated, self.policy.when_max_open_trades_violated,
trade, trade,
) )
print("TRADES OVER LIMNIT", trades_over_limit)
def check_max_loss(self): def check_max_loss(self):
check_passed = risk.check_max_loss(self.strategy.risk_model, self.strategy.account.initial_balance, self.get_balance()) if self.strategy.risk_model is None:
return
check_passed = risk.check_max_loss(
self.strategy.risk_model,
self.strategy.account.initial_balance,
self.get_balance(),
)
if not check_passed:
self.handle_violation(
"max_loss", self.policy.when_max_loss_violated, None # Close all trades
)
def check_max_risk(self, trades): def check_max_risk(self, trades):
pass if self.strategy.risk_model is None:
return
close_trades = []
trades_copy = self.get_sorted_trades_copy(trades, reverse=False)
market.convert_trades_to_usd(self.strategy.account, trades_copy)
iterations = 0
finished = False
while not finished:
iterations += 1
if iterations > 10000:
raise Exception("Too many iterations")
check_passed = risk.check_max_risk(
self.strategy.risk_model,
self.get_balance(return_usd=True),
trades_copy,
)
if check_passed:
finished = True
else:
# Add the newest trade to close_trades and remove it from trades_copy
close_trades.append(trades_copy[-1])
trades_copy = trades_copy[:-1]
if close_trades:
for trade in close_trades:
self.handle_violation(
"max_risk", self.policy.when_max_risk_violated, trade
)
def run_checks(self): def run_checks(self):
converted_trades = convert_trades(self.get_trades()) converted_trades = convert_trades(self.get_trades())

View File

@ -20,7 +20,7 @@ def convert_trades_to_usd(account, trades):
:return: List of trades, with amount_usd added :return: List of trades, with amount_usd added
""" """
for trade in trades: for trade in trades:
amount = trade["amount"] amount = D(trade["amount"])
symbol = trade["symbol"] symbol = trade["symbol"]
side = trade["side"] side = trade["side"]
direction = side_to_direction(side) direction = side_to_direction(side)

View File

@ -26,6 +26,7 @@ def check_max_risk(risk_model, account_balance_usd, account_trades):
# Calculate the max risk of the account in USD # Calculate the max risk of the account in USD
max_risk_usd = account_balance_usd * (max_risk_percent / D(100)) max_risk_usd = account_balance_usd * (max_risk_percent / D(100))
total_risk = 0 total_risk = 0
for trade in account_trades: for trade in account_trades:
max_tmp = [] max_tmp = []
# Need to calculate the max risk in base account currency # Need to calculate the max risk in base account currency
@ -36,7 +37,8 @@ def check_max_risk(risk_model, account_balance_usd, account_trades):
if "trailing_stop_loss_usd" in trade: if "trailing_stop_loss_usd" in trade:
max_tmp.append(trade["trailing_stop_loss_usd"]) max_tmp.append(trade["trailing_stop_loss_usd"])
if max_tmp: if max_tmp:
total_risk += max(max_tmp) max_risk = max(max_tmp)
total_risk += max_risk
allowed = total_risk < max_risk_usd allowed = total_risk < max_risk_usd
return allowed return allowed
@ -59,7 +61,6 @@ def check_max_open_trades_per_symbol(risk_model, account_trades, return_symbols=
if symbol not in symbol_map: if symbol not in symbol_map:
symbol_map[symbol] = 0 symbol_map[symbol] = 0
symbol_map[symbol] += 1 symbol_map[symbol] += 1
print("Symbol map: ", symbol_map)
violating_symbols = [] violating_symbols = []
for symbol, count in symbol_map.items(): for symbol, count in symbol_map.items():
if count >= risk_model.max_open_trades_per_symbol: if count >= risk_model.max_open_trades_per_symbol: