diff --git a/core/tests/trading/test_live.py b/core/tests/trading/test_live.py index b036572..26ba225 100644 --- a/core/tests/trading/test_live.py +++ b/core/tests/trading/test_live.py @@ -7,6 +7,8 @@ from django.test import TestCase from core.exchanges.convert import convert_trades from core.models import ( ActiveManagementPolicy, + AssetGroup, + AssetRule, Hook, RiskModel, Signal, @@ -128,6 +130,10 @@ class ActiveManagementMixinTestCase(StrategyMixin): # Don't need to close the trade, it's already closed. # Otherwise the test would fail. + self.strategy.trend_signals.set([]) + self.strategy.trends = {} + self.strategy.save() + def test_ams_position_size_violated(self): self.active_management_policy.when_position_size_violated = "close" self.active_management_policy.save() @@ -177,12 +183,8 @@ class ActiveManagementMixinTestCase(StrategyMixin): def test_ams_protection_violated(self): self.active_management_policy.when_protection_violated = "close" self.active_management_policy.save() - # Don't violate position size check - trade_size = market.get_trade_size_in_base( - "buy", self.account, self.strategy, self.account.client.get_balance(), "EUR" - ) - trade_size = round(trade_size, 0) - complex_trade = self.create_complex_trade("buy", trade_size, "EUR_USD", 5, 5) + + complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 5, 5) self.open_trade(complex_trade) self.ams.run_checks() @@ -206,12 +208,7 @@ class ActiveManagementMixinTestCase(StrategyMixin): self.assertEqual(len(trades), 0) def test_ams_protection_violated_adjust(self): - # Don't violate position size check - trade_size = market.get_trade_size_in_base( - "buy", self.account, self.strategy, self.account.client.get_balance(), "EUR" - ) - trade_size = round(trade_size, 0) - complex_trade = self.create_complex_trade("buy", trade_size, "EUR_USD", 5, 5) + complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 5, 5) self.open_trade(complex_trade) self.ams.run_checks() @@ -231,22 +228,246 @@ class ActiveManagementMixinTestCase(StrategyMixin): self.close_trade(complex_trade) def test_ams_asset_groups_violated(self): - pass + asset_group = AssetGroup.objects.create( + user=self.user, + name="Test Asset Group", + ) + AssetRule.objects.create( + user=self.user, + asset="USD", + group=asset_group, + status=2, # Bullish + ) + self.strategy.asset_group = asset_group + self.strategy.save() + + complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0) + self.open_trade(complex_trade) + + self.ams.run_checks() + + expected = { + "close": [ + {"id": complex_trade.order_id, "check": "asset_group", "extra": {}} + ] + } + self.assertEqual(self.ams.actions, expected) + + self.ams.execute_actions() + + trades = self.account.client.get_all_open_trades() + self.assertEqual(len(trades), 0) def test_ams_crossfilter_violated(self): - pass + complex_trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0) + self.open_trade(complex_trade1) + + complex_trade2 = self.create_complex_trade("buy", 10, "USD_JPY", 1.5, 1.0) + self.open_trade(complex_trade2) + + trades = self.account.client.get_all_open_trades() + self.assertEqual(len(trades), 2) + + self.ams.run_checks() + + expected = { + "close": [ + { + "id": complex_trade2.order_id, # Only the second one + "check": "crossfilter", + "extra": {}, + } + ] + } + + self.assertEqual(self.ams.actions, expected) + + self.ams.execute_actions() + + trades = self.account.client.get_all_open_trades() + self.assertEqual(len(trades), 1) + + self.close_trade(complex_trade1) + + @patch( + "core.trading.active_management.ActiveManagement.check_trends", + return_value=None, + ) + @patch( + "core.trading.active_management.ActiveManagement.check_crossfilter", + return_value=None, + ) + def test_ams_max_open_trades_violated(self, check_crossfilter, check_trends): + self.strategy.risk_model.max_open_trades = 2 + self.strategy.risk_model.save() + trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0) + self.open_trade(trade1) + + trade2 = self.create_complex_trade("buy", 10, "USD_JPY", 1.5, 1.0) + self.open_trade(trade2) - def test_ams_max_open_trades_violated(self): - pass + trade3 = self.create_complex_trade("buy", 10, "EUR_JPY", 1.5, 1.0) + self.open_trade(trade3) + + trades = self.account.client.get_all_open_trades() + self.assertEqual(len(trades), 3) + + self.ams.run_checks() - def test_ams_max_open_trades_per_symbol_violated(self): - pass + expected = { + "close": [ + { + "id": trade3.order_id, + "check": "max_open_trades", + "extra": {}, + } + ] + } + + self.assertEqual(self.ams.actions, expected) + + self.ams.execute_actions() + + trades = self.account.client.get_all_open_trades() + self.assertEqual(len(trades), 2) + + trade_ids = [trade["id"] for trade in trades] + self.assertIn(trade1.order_id, trade_ids) + self.assertIn(trade2.order_id, trade_ids) + self.assertNotIn(trade3.order_id, trade_ids) + + for x in [trade1, trade2]: + self.close_trade(x) + + @patch( + "core.trading.active_management.ActiveManagement.check_trends", + return_value=None, + ) + def test_ams_max_open_trades_per_symbol_violated(self, check_trends): + self.strategy.risk_model.max_open_trades_per_symbol = 2 + self.strategy.risk_model.save() + + trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0) + self.open_trade(trade1) + + trade2 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0) + self.open_trade(trade2) + + trade3 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0) + self.open_trade(trade3) + + trade4 = self.create_complex_trade("buy", 10, "EUR_JPY", 1.5, 1.0) + self.open_trade(trade4) + + trade5 = self.create_complex_trade("buy", 10, "EUR_JPY", 1.5, 1.0) + self.open_trade(trade5) + + trade6 = self.create_complex_trade("buy", 10, "EUR_JPY", 1.5, 1.0) + self.open_trade(trade6) + + trades = self.account.client.get_all_open_trades() + self.assertEqual(len(trades), 6) + + self.ams.run_checks() + + expected = { + "close": [ + { + "id": trade3.order_id, + "check": "max_open_trades_per_symbol", + "extra": {}, + }, + { + "id": trade6.order_id, + "check": "max_open_trades_per_symbol", + "extra": {}, + }, + ] + } + + self.assertEqual(self.ams.actions, expected) + + self.ams.execute_actions() + + trades = self.account.client.get_all_open_trades() + self.assertEqual(len(trades), 4) + + trade_ids = [trade["id"] for trade in trades] + self.assertIn(trade1.order_id, trade_ids) + self.assertIn(trade2.order_id, trade_ids) + self.assertNotIn(trade3.order_id, trade_ids) + + self.assertIn(trade4.order_id, trade_ids) + self.assertIn(trade5.order_id, trade_ids) + self.assertNotIn(trade6.order_id, trade_ids) def test_ams_max_loss_violated(self): - pass + trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0) + self.open_trade(trade1) - def test_ams_max_risk_violated(self): - pass + self.account.initial_balance = self.account.client.get_balance() * 2 + self.account.save() + + self.ams.run_checks() + + expected = { + "close": [ + { + "id": None, + "check": "max_loss", + "extra": {}, + } + ] + } + + self.assertEqual(self.ams.actions, expected) + + self.ams.execute_actions() + + trades = self.account.client.get_all_open_trades() + self.assertEqual(len(trades), 0) + + @patch( + "core.trading.active_management.ActiveManagement.check_position_size", + return_value=None, + ) + def test_ams_max_risk_violated(self, check_position_size): + self.strategy.risk_model.max_risk_percent = 0.001 + self.strategy.risk_model.save() + + trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0) + self.open_trade(trade1) + + trade2 = self.create_complex_trade("buy", 1000, "EUR_USD", 1.5, 1.0) + self.open_trade(trade2) + + trades = self.account.client.get_all_open_trades() + self.assertEqual(len(trades), 2) + + self.ams.run_checks() + + expected = { + "close": [ + { + "id": trade2.order_id, + "check": "max_risk", + "extra": {}, + } + ] + } + + self.assertEqual(self.ams.actions, expected) + + self.ams.execute_actions() + + trades = self.account.client.get_all_open_trades() + self.assertEqual(len(trades), 1) + + trade_ids = [trade["id"] for trade in trades] + self.assertIn(trade1.order_id, trade_ids) + self.assertNotIn(trade2.order_id, trade_ids) + + self.close_trade(trade1) class LiveTradingTestCase( @@ -290,7 +511,10 @@ class LiveTradingTestCase( # Check the opened trade self.assertEqual(posted["type"], "MARKET_ORDER") self.assertEqual(posted["symbol"], trade.symbol) - self.assertEqual(posted["units"], str(trade.amount)) + if trade.direction == "sell": + self.assertEqual(posted["units"], str(0 - trade.amount)) + else: + self.assertEqual(posted["units"], str(trade.amount)) self.assertEqual(posted["timeInForce"], "FOK") return posted diff --git a/core/trading/active_management.py b/core/trading/active_management.py index 2916c40..1d7de11 100644 --- a/core/trading/active_management.py +++ b/core/trading/active_management.py @@ -79,7 +79,11 @@ class ActiveManagement(object): def bulk_close_trades(self, trade_ids): for trade_id in trade_ids: - self.close_trade(trade_id) + if trade_id is not None: + self.close_trade(trade_id) + else: + self.strategy.account.client.close_all_positions() + return def bulk_notify(self, action, action_cast_list): msg = "" @@ -461,6 +465,7 @@ class ActiveManagement(object): converted_trades = convert_trades(self.get_trades()) trades_copy = self.get_sorted_trades_copy(converted_trades, reverse=False) market.convert_trades_to_usd(self.strategy.account, trades_copy) + for trade in reversed(trades_copy): try: self.check_trading_time(trade) diff --git a/core/trading/checks.py b/core/trading/checks.py index 2b0fc80..41fecb8 100644 --- a/core/trading/checks.py +++ b/core/trading/checks.py @@ -46,34 +46,34 @@ def within_callback_price_deviation(strategy, price, current_price): def within_trends(strategy, symbol, direction): if strategy.trend_signals.exists(): - if strategy.trends is None: - log.debug("Refusing to trade with no trend signals received") - sendmsg( - strategy.user, - f"Refusing to trade {symbol} with no trend signals received", - title="Trend not ready", - ) - return None - if symbol not in strategy.trends: - log.debug("Refusing to trade asset without established trend") - sendmsg( - strategy.user, - f"Refusing to trade {symbol} without established trend", - title="Trend not ready", - ) - return None - else: - if strategy.trends[symbol] != direction: - log.debug("Refusing to trade against the trend") + if len(strategy.trend_signals.all()) > 0: + if strategy.trends is None: + log.debug("Refusing to trade with no trend signals received") sendmsg( strategy.user, - f"Refusing to trade {symbol} against the trend", - title="Trend rejection", + f"Refusing to trade {symbol} with no trend signals received", + title="Trend not ready", ) - return False + return None + if symbol not in strategy.trends: + log.debug("Refusing to trade asset without established trend") + sendmsg( + strategy.user, + f"Refusing to trade {symbol} without established trend", + title="Trend not ready", + ) + return None else: - log.debug(f"Trend check passed for {symbol} - {direction}") - return True - else: - log.debug("No trend signals configured") - return True + if strategy.trends[symbol] != direction: + log.debug("Refusing to trade against the trend") + sendmsg( + strategy.user, + f"Refusing to trade {symbol} against the trend", + title="Trend rejection", + ) + return False + else: + log.debug(f"Trend check passed for {symbol} - {direction}") + return True + log.debug("No trend signals configured") + return True