From 3e35214e826b606f8a8d56594721aca243a934d8 Mon Sep 17 00:00:00 2001 From: Mark Veidemanis Date: Fri, 17 Feb 2023 22:23:12 +0000 Subject: [PATCH] Fix open trades checks --- core/tests/helpers.py | 2 +- core/tests/trading/test_active_management.py | 19 +++++++++++++++++-- core/tests/trading/test_risk.py | 2 ++ core/trading/active_management.py | 5 +++-- core/trading/risk.py | 18 ++++++++++-------- 5 files changed, 33 insertions(+), 13 deletions(-) diff --git a/core/tests/helpers.py b/core/tests/helpers.py index 019ab88..198fb58 100644 --- a/core/tests/helpers.py +++ b/core/tests/helpers.py @@ -141,7 +141,7 @@ class StrategyMixin: max_loss_percent=50, max_risk_percent=10, max_open_trades=10, - max_open_trades_per_symbol=5, + max_open_trades_per_symbol=2, ) self.strategy = Strategy.objects.create( diff --git a/core/tests/trading/test_active_management.py b/core/tests/trading/test_active_management.py index 32508ab..727651c 100644 --- a/core/tests/trading/test_active_management.py +++ b/core/tests/trading/test_active_management.py @@ -393,8 +393,23 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): self.trades[10:], # Only close newer trades ) - def test_max_open_trades_per_symbol_violated(self): - pass + @patch("core.trading.active_management.ActiveManagement.handle_violation") + def test_max_open_trades_per_symbol_violated(self, handle_violation): + for x in range(2): + self.add_trade( + str(x), + "EUR_USD", + "long", + f"2023-02-13T12:39:1{x}.302917985Z", + ) + + self.ams.run_checks() + self.check_violation( + "max_open_trades_per_symbol", + handle_violation.call_args_list, + "close", + self.trades[2:], # Only close newer trades + ) def test_max_loss_violated(self): pass diff --git a/core/tests/trading/test_risk.py b/core/tests/trading/test_risk.py index 86715c8..d45d2af 100644 --- a/core/tests/trading/test_risk.py +++ b/core/tests/trading/test_risk.py @@ -2,6 +2,8 @@ from django.test import TestCase from core.exchanges import convert from core.models import RiskModel, User +import core.trading.market # to avoid messy circular import + from core.trading import risk diff --git a/core/trading/active_management.py b/core/trading/active_management.py index 9dafa99..c78e7b9 100644 --- a/core/trading/active_management.py +++ b/core/trading/active_management.py @@ -245,8 +245,9 @@ class ActiveManagement(object): def check_max_open_trades_per_symbol(self, trades): if self.strategy.risk_model is not None: max_open_pass = risk.check_max_open_trades_per_symbol( - self.strategy.risk_model, trades + self.strategy.risk_model, trades, return_symbols=True ) + print("max_open_pass", max_open_pass) max_open_pass = list(max_open_pass) print("MAX OPEN PASS", max_open_pass) if max_open_pass: @@ -271,7 +272,7 @@ class ActiveManagement(object): print("TRADES OVER LIMNIT", trades_over_limit) def check_max_loss(self): - pass + check_passed = risk.check_max_loss(self.strategy.risk_model, self.strategy.account.initial_balance, self.get_balance()) def check_max_risk(self, trades): pass diff --git a/core/trading/risk.py b/core/trading/risk.py index 204bf47..1762957 100644 --- a/core/trading/risk.py +++ b/core/trading/risk.py @@ -49,7 +49,7 @@ def check_max_open_trades(risk_model, account_trades): return len(account_trades) < risk_model.max_open_trades -def check_max_open_trades_per_symbol(risk_model, account_trades, yield_symbol=False): +def check_max_open_trades_per_symbol(risk_model, account_trades, return_symbols=False): """ Check we cannot open more trades per symbol than permissible. """ @@ -59,15 +59,17 @@ def check_max_open_trades_per_symbol(risk_model, account_trades, yield_symbol=Fa if symbol not in symbol_map: symbol_map[symbol] = 0 symbol_map[symbol] += 1 - + print("Symbol map: ", symbol_map) + violating_symbols = [] for symbol, count in symbol_map.items(): if count >= risk_model.max_open_trades_per_symbol: - if yield_symbol: - yield symbol - else: - return False - if not yield_symbol: - return True + violating_symbols.append(symbol) + + if return_symbols: + return violating_symbols + if violating_symbols: + return False + return True def check_risk(risk_model, account, proposed_trade):