diff --git a/core/tests/trading/test_active_management.py b/core/tests/trading/test_active_management.py index cc92179..ce779f2 100644 --- a/core/tests/trading/test_active_management.py +++ b/core/tests/trading/test_active_management.py @@ -3,7 +3,6 @@ from unittest.mock import patch from django.test import TestCase -from core.exchanges.convert import convert_trades from core.lib.schemas.oanda_s import parse_time from core.models import ( Account, @@ -103,6 +102,9 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): self.ams.get_balance = self.fake_get_balance # self.ams.trades = self.trades + def get_ids(self, trades): + return [trade["id"] for trade in trades] + def add_trade(self, id, symbol, side, open_time): trade = { "id": id, @@ -179,17 +181,12 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): self.assertEqual(len(violation_calls), len(expected_trades)) if all(expected_trades): - expected_trades = convert_trades(expected_trades) + # expected_trades = convert_trades(expected_trades) + expected_trades = self.get_ids(expected_trades) for call in violation_calls: # Ensure the correct action has been called, like close self.assertEqual(call[0][1], expected_action) # Ensure the correct trade has been passed to the violation - trade = call[0][2] - if trade: - for field in list(trade.keys()): - if "_usd" in field: - if field in trade.keys(): - del trade[field] self.assertIn(call[0][2], expected_trades) if expected_args: self.assertEqual(call[0][3], expected_args) @@ -454,22 +451,18 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): # trade. # TODO: Fix this when we have a way of checking the balance at the start of the # trade. - # Max risk is also mocked as this puts us over the limit, due to the low account - # size. - @patch( - "core.trading.active_management.ActiveManagement.check_max_risk", - return_value=None, - ) + @patch( "core.trading.active_management.ActiveManagement.check_position_size", return_value=None, ) @patch("core.trading.active_management.ActiveManagement.handle_violation") - def test_max_loss_violated( - self, handle_violation, check_position_size, check_max_risk - ): + def test_max_loss_violated(self, handle_violation, check_position_size): self.balance = D("1") self.balance_usd = D("0.69") + + self.trades = [] + self.ams.run_checks() self.check_violation( @@ -546,3 +539,15 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): "close", [self.trades[2], self.trades[3]], ) + + def test_max_risk_not_violated_after_adjusting_protection(self): + """ + Ensure the max risk check is not violated after adjusting the protection. + """ + pass + + def test_max_risk_not_violated_after_adjusting_position_size(self): + """ + Ensure the max risk check is not violated after adjusting the position size. + """ + pass diff --git a/core/trading/active_management.py b/core/trading/active_management.py index da2c5a2..1391dc4 100644 --- a/core/trading/active_management.py +++ b/core/trading/active_management.py @@ -44,6 +44,21 @@ class ActiveManagement(object): def handle_violation(self, check_type, action, trade, **kwargs): print("VIOLATION", check_type, action, trade, kwargs) + # TODO: close/notify for: + # - trading time + # - trends + # - position size + # - protection + # - asset groups + # - crossfilter + # - max open trades + # - max open trades per symbol + # - max loss + # - max risk + + # TODO: adjust for: + # - position size + # - protection def check_trading_time(self, trade): open_ts = trade["open_time"] @@ -51,7 +66,7 @@ class ActiveManagement(object): trading_time_pass = checks.within_trading_times(self.strategy, open_ts_as_date) if not trading_time_pass: self.handle_violation( - "trading_time", self.policy.when_trading_time_violated, trade + "trading_time", self.policy.when_trading_time_violated, trade["id"] ) def check_trends(self, trade): @@ -59,7 +74,9 @@ class ActiveManagement(object): symbol = trade["symbol"] trends_pass = checks.within_trends(self.strategy, symbol, direction) if not trends_pass: - self.handle_violation("trends", self.policy.when_trends_violated, trade) + self.handle_violation( + "trends", self.policy.when_trends_violated, trade["id"] + ) def check_position_size(self, trade): """ @@ -91,7 +108,7 @@ class ActiveManagement(object): self.handle_violation( "position_size", self.policy.when_position_size_violated, - trade, + trade["id"], {"size": expected_trade_size}, ) @@ -152,7 +169,10 @@ class ActiveManagement(object): if violations: self.handle_violation( - "protection", self.policy.when_protection_violated, trade, violations + "protection", + self.policy.when_protection_violated, + trade["id"], + violations ) def check_asset_groups(self, trade): @@ -165,7 +185,7 @@ class ActiveManagement(object): ) if not allowed: self.handle_violation( - "asset_group", self.policy.when_asset_groups_violated, trade + "asset_group", self.policy.when_asset_groups_violated, trade["id"] ) def get_sorted_trades_copy(self, trades, reverse=True): @@ -226,7 +246,7 @@ class ActiveManagement(object): if close_trades: for trade in close_trades: self.handle_violation( - "crossfilter", self.policy.when_crossfilter_violated, trade + "crossfilter", self.policy.when_crossfilter_violated, trade["id"] ) def check_max_open_trades(self, trades): @@ -241,7 +261,7 @@ class ActiveManagement(object): self.handle_violation( "max_open_trades", self.policy.when_max_open_trades_violated, - trade, + trade["id"], ) def check_max_open_trades_per_symbol(self, trades): @@ -267,7 +287,7 @@ class ActiveManagement(object): self.handle_violation( "max_open_trades_per_symbol", self.policy.when_max_open_trades_violated, - trade, + trade["id"], ) def check_max_loss(self): @@ -312,7 +332,7 @@ class ActiveManagement(object): if close_trades: for trade in close_trades: self.handle_violation( - "max_risk", self.policy.when_max_risk_violated, trade + "max_risk", self.policy.when_max_risk_violated, trade["id"] ) def run_checks(self):