from datetime import time
from decimal import Decimal as D
from unittest.mock import patch

from django.test import TestCase

from core.exchanges.convert import convert_trades
from core.models import (
    ActiveManagementPolicy,
    AssetGroup,
    AssetRule,
    Hook,
    RiskModel,
    Signal,
    Trade,
    TradingTime,
)
from core.tests.helpers import ElasticMock, LiveBase, StrategyMixin
from core.trading import market, risk
from core.trading.active_management import ActiveManagement


class ActiveManagementLiveTestCase(ElasticMock, StrategyMixin, LiveBase, TestCase):
    """
    Exercise ActiveManagement checks and actions through the account's client.

    Each test opens one or more trades via ``self.open_trade``, runs
    ``self.ams.run_checks()``, compares ``self.ams.actions`` with the expected
    actions and, where relevant, runs ``self.ams.execute_actions()`` and
    verifies the open trades reported by ``self.account.client``.
    """

    def setUp(self):
        """
        Create the trade, risk model, trading time and policy fixtures.

        The policy defaults to "close" for every violation except position
        size and protection, which default to "adjust". Individual tests
        override these where they need the other behaviour.
        """
        super().setUp()
        self.trade = Trade.objects.create(
            user=self.user,
            account=self.account,
            symbol="EUR_USD",
            time_in_force="FOK",
            type="market",
            amount=10,
            direction="buy",
        )
        self.commission = 0.025
        self.risk_model = RiskModel.objects.create(
            user=self.user,
            name="Test Risk Model",
            max_loss_percent=4,
            max_risk_percent=2,
            max_open_trades=3,
            max_open_trades_per_symbol=2,
        )
        self.trading_time_all = TradingTime.objects.create(
            user=self.user,
            name="All",
            start_day=1,
            start_time=time(0, 0, 0),
            end_day=7,
            end_time=time(23, 59, 59),
        )
        self.active_management_policy = ActiveManagementPolicy.objects.create(
            user=self.user,
            name="Test Policy",
            when_trading_time_violated="close",
            when_trends_violated="close",
            when_position_size_violated="adjust",
            when_protection_violated="adjust",
            when_asset_groups_violated="close",
            when_max_open_trades_violated="close",
            when_max_open_trades_per_symbol_violated="close",
            when_max_loss_violated="close",
            when_max_risk_violated="close",
            when_crossfilter_violated="close",
        )
        self.strategy.active_management_policy = self.active_management_policy
        self.strategy.trading_times.set([self.trading_time_all])
        self.strategy.save()
        self.ams = ActiveManagement(self.strategy)

        # Every test must start from an account with no open trades.
        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 0)

    def test_ams_success(self):
        """
        Test that a single small trade produces no actions.
        """
        complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
        self.open_trade(complex_trade)

        self.ams.run_checks()
        self.assertEqual(self.ams.actions, {})

        self.close_trade(complex_trade)

    def test_ams_trading_time_violated(self):
        """
        Test that a trade outside the strategy's trading times is closed.
        """
        # Forex market is closed on Saturday and Sunday.
        # All of these tests will fail then anyway, so it's the only time
        # we can use for this.
        trading_time_weekend = TradingTime.objects.create(
            user=self.user,
            name="Weekend",
            start_day=6,
            start_time=time(0, 0, 0),
            end_day=7,
            end_time=time(23, 59, 59),
        )
        self.strategy.trading_times.set([trading_time_weekend])
        self.strategy.save()

        complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
        self.open_trade(complex_trade)

        self.ams.run_checks()
        expected_actions = {
            "close": [
                {"id": complex_trade.order_id, "check": "trading_time", "extra": {}}
            ]
        }
        self.assertEqual(self.ams.actions, expected_actions)

        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 1)
        self.ams.execute_actions()
        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 0)
        # Don't need to close the trade, it's already closed.
        # Otherwise the test would fail.

    def test_ams_trends_violated(self):
        """
        Test that a buy trade is closed when the strategy trend is "sell".
        """
        hook = Hook.objects.create(
            user=self.user,
            name="Test Hook",
        )
        signal = Signal.objects.create(
            user=self.user,
            name="Test Signal",
            hook=hook,
            type="trend",
        )
        self.strategy.trend_signals.set([signal])
        self.strategy.trends = {"EUR_USD": "sell"}
        self.strategy.save()

        complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
        self.open_trade(complex_trade)

        self.ams.run_checks()
        expected_actions = {
            "close": [{"id": complex_trade.order_id, "check": "trends", "extra": {}}]
        }
        self.assertEqual(self.ams.actions, expected_actions)

        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 1)
        self.ams.execute_actions()
        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 0)
        # Don't need to close the trade, it's already closed.
        # Otherwise the test would fail.

        # Reset the trend configuration set up above.
        self.strategy.trend_signals.set([])
        self.strategy.trends = {}
        self.strategy.save()

    def test_ams_position_size_violated(self):
        """
        Test that an oversized trade is closed when the policy says "close".
        """
        self.active_management_policy.when_position_size_violated = "close"
        self.active_management_policy.save()

        complex_trade = self.create_complex_trade("buy", 600, "EUR_USD", 1.5, 1.0)
        self.open_trade(complex_trade)

        self.ams.run_checks()
        expected_actions = {
            "close": [
                {"id": complex_trade.order_id, "check": "position_size", "extra": {}}
            ]
        }
        self.assertEqual(len(self.ams.actions["close"]), 1)
        # We don't know the size, so drop it before comparing.
        del self.ams.actions["close"][0]["extra"]["size"]
        self.assertEqual(self.ams.actions, expected_actions)

        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 1)
        self.ams.execute_actions()
        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 0)

    def test_ams_position_size_violated_adjust(self):
        """
        Test that an oversized trade is resized when the policy says "adjust".
        """
        complex_trade = self.create_complex_trade("buy", 600, "EUR_USD", 1.5, 1.0)
        self.open_trade(complex_trade)

        self.ams.run_checks()
        # Round the suggested size to a whole number of units
        expected = round(self.ams.actions["adjust"][0]["extra"]["size"], 0)

        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 1)
        self.assertEqual(trades[0]["currentUnits"], "600")

        self.ams.execute_actions()
        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 1)
        self.assertEqual(trades[0]["currentUnits"], str(expected))

        # Record the adjusted amount on the trade before closing it.
        complex_trade.amount = expected
        complex_trade.save()
        self.close_trade(complex_trade)
        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 0)

    def test_ams_protection_violated(self):
        """
        Test that a trade with TP/SL of 5 is closed when the policy says "close".
        """
        self.active_management_policy.when_protection_violated = "close"
        self.active_management_policy.save()

        complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 5, 5)
        self.open_trade(complex_trade)

        self.ams.run_checks()
        self.assertEqual(len(self.ams.actions["close"]), 1)
        expected = {
            "close": [
                {
                    "id": complex_trade.order_id,
                    "check": "protection",
                }
            ]
        }
        # The contents of "extra" are not checked here; drop the key entirely.
        del self.ams.actions["close"][0]["extra"]
        self.assertEqual(self.ams.actions, expected)

        self.ams.execute_actions()
        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 0)

    def test_ams_protection_violated_adjust(self):
        """
        Test that TP/SL orders are moved to the prices the adjust action names.
        """
        complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 5, 5)
        self.open_trade(complex_trade)

        self.ams.run_checks()
        self.assertEqual(len(self.ams.actions["adjust"]), 1)
        adjustment = self.ams.actions["adjust"][0]["extra"]
        expected_tp = adjustment["take_profit_price"]
        expected_sl = adjustment["stop_loss_price"]

        self.ams.execute_actions()
        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 1)
        self.assertEqual(D(trades[0]["takeProfitOrder"]["price"]), expected_tp)
        self.assertEqual(D(trades[0]["stopLossOrder"]["price"]), expected_sl)

        self.close_trade(complex_trade)
        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 0)

    def test_ams_asset_groups_violated(self):
        """
        Test that a EUR_USD buy is closed when the asset group marks USD bullish.
        """
        asset_group = AssetGroup.objects.create(
            user=self.user,
            name="Test Asset Group",
        )
        AssetRule.objects.create(
            user=self.user,
            asset="USD",
            group=asset_group,
            status=2,  # Bullish
        )
        self.strategy.asset_group = asset_group
        self.strategy.save()

        complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
        self.open_trade(complex_trade)

        self.ams.run_checks()
        expected = {
            "close": [
                {"id": complex_trade.order_id, "check": "asset_group", "extra": {}}
            ]
        }
        self.assertEqual(self.ams.actions, expected)

        self.ams.execute_actions()
        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 0)

    def test_ams_crossfilter_violated(self):
        """
        Test that the second of two conflicting trades is closed by crossfilter.
        """
        complex_trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
        self.open_trade(complex_trade1)
        complex_trade2 = self.create_complex_trade("buy", 10, "USD_JPY", 1.5, 1.0)
        self.open_trade(complex_trade2)

        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 2)

        self.ams.run_checks()
        expected = {
            "close": [
                {
                    "id": complex_trade2.order_id,  # Only the second one
                    "check": "crossfilter",
                    "extra": {},
                }
            ]
        }
        self.assertEqual(self.ams.actions, expected)

        self.ams.execute_actions()
        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 1)

        self.close_trade(complex_trade1)
        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 0)

    @patch(
        "core.trading.active_management.ActiveManagement.check_trends",
        return_value=None,
    )
    @patch(
        "core.trading.active_management.ActiveManagement.check_crossfilter",
        return_value=None,
    )
    def test_ams_max_open_trades_violated(self, check_crossfilter, check_trends):
        """
        Test that the third trade is closed when max_open_trades is 2.

        The trends and crossfilter checks are patched out so that only the
        max open trades check contributes actions.
        """
        self.strategy.risk_model.max_open_trades = 2
        self.strategy.risk_model.save()

        trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
        self.open_trade(trade1)
        trade2 = self.create_complex_trade("buy", 10, "USD_JPY", 1.5, 1.0)
        self.open_trade(trade2)
        trade3 = self.create_complex_trade("buy", 10, "EUR_JPY", 1.5, 1.0)
        self.open_trade(trade3)

        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 3)

        self.ams.run_checks()
        expected = {
            "close": [
                {
                    "id": trade3.order_id,
                    "check": "max_open_trades",
                    "extra": {},
                }
            ]
        }
        self.assertEqual(self.ams.actions, expected)

        self.ams.execute_actions()
        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 2)
        trade_ids = [trade["id"] for trade in trades]
        self.assertIn(trade1.order_id, trade_ids)
        self.assertIn(trade2.order_id, trade_ids)
        self.assertNotIn(trade3.order_id, trade_ids)

        for remaining in [trade1, trade2]:
            self.close_trade(remaining)
        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 0)

    @patch(
        "core.trading.active_management.ActiveManagement.check_trends",
        return_value=None,
    )
    def test_ams_max_open_trades_per_symbol_violated(self, check_trends):
        """
        Test that the third trade on each symbol is closed when the limit is 2.

        The trends check is patched out so it cannot contribute actions.
        """
        self.strategy.risk_model.max_open_trades_per_symbol = 2
        self.strategy.risk_model.save()

        trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
        self.open_trade(trade1)
        trade2 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
        self.open_trade(trade2)
        trade3 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
        self.open_trade(trade3)

        trade4 = self.create_complex_trade("buy", 10, "EUR_JPY", 1.5, 1.0)
        self.open_trade(trade4)
        trade5 = self.create_complex_trade("buy", 10, "EUR_JPY", 1.5, 1.0)
        self.open_trade(trade5)
        trade6 = self.create_complex_trade("buy", 10, "EUR_JPY", 1.5, 1.0)
        self.open_trade(trade6)

        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 6)

        self.ams.run_checks()
        expected = {
            "close": [
                {
                    "id": trade3.order_id,
                    "check": "max_open_trades_per_symbol",
                    "extra": {},
                },
                {
                    "id": trade6.order_id,
                    "check": "max_open_trades_per_symbol",
                    "extra": {},
                },
            ]
        }
        self.assertEqual(self.ams.actions, expected)

        self.ams.execute_actions()
        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 4)
        trade_ids = [trade["id"] for trade in trades]
        self.assertIn(trade1.order_id, trade_ids)
        self.assertIn(trade2.order_id, trade_ids)
        self.assertNotIn(trade3.order_id, trade_ids)
        self.assertIn(trade4.order_id, trade_ids)
        self.assertIn(trade5.order_id, trade_ids)
        self.assertNotIn(trade6.order_id, trade_ids)

        for remaining in [trade1, trade2, trade4, trade5]:
            self.close_trade(remaining)
        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 0)

        # Raise the per-symbol limit again now that the check has been tested.
        self.strategy.risk_model.max_open_trades_per_symbol = 5
        self.strategy.risk_model.save()

    def test_ams_max_loss_violated(self):
        """
        Test that all trades are closed when the max loss check is violated.

        The initial balance is set to twice the current balance before the
        checks run. The resulting close action carries no trade id.
        """
        trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
        self.open_trade(trade1)

        self.account.initial_balance = self.account.client.get_balance() * 2
        self.account.save()

        self.ams.run_checks()
        expected = {
            "close": [
                {
                    "id": None,
                    "check": "max_loss",
                    "extra": {},
                }
            ]
        }
        self.assertEqual(self.ams.actions, expected)

        self.ams.execute_actions()
        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 0)

        # Put the initial balance back to a fixed value.
        self.account.initial_balance = 100000
        self.account.save()

    @patch(
        "core.trading.active_management.ActiveManagement.check_position_size",
        return_value=None,
    )
    def test_ams_max_risk_violated(self, check_position_size):
        """
        Test that the larger trade is closed when max_risk_percent is tiny.

        The position size check is patched out so it cannot contribute actions.
        """
        self.strategy.risk_model.max_risk_percent = 0.001
        self.strategy.risk_model.save()

        trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
        self.open_trade(trade1)
        trade2 = self.create_complex_trade("buy", 1000, "EUR_USD", 1.5, 1.0)
        self.open_trade(trade2)

        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 2)

        self.ams.run_checks()
        expected = {
            "close": [
                {
                    "id": trade2.order_id,
                    "check": "max_risk",
                    "extra": {},
                }
            ]
        }
        self.assertEqual(self.ams.actions, expected)

        self.ams.execute_actions()
        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 1)
        trade_ids = [trade["id"] for trade in trades]
        self.assertIn(trade1.order_id, trade_ids)
        self.assertNotIn(trade2.order_id, trade_ids)

        self.close_trade(trade1)
        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 0)


class LiveTradingTestCase(ElasticMock, LiveBase, TestCase):
    """
    Exercise trade placement, risk checks and trade conversion.

    Trades are opened and closed through the ``open_trade`` / ``close_trade``
    helpers and inspected through ``self.account.client``.
    """

    def setUp(self):
        """
        Create the default trade and risk model used by the tests.
        """
        super().setUp()
        self.trade = Trade.objects.create(
            user=self.user,
            account=self.account,
            symbol="EUR_USD",
            time_in_force="FOK",
            type="market",
            amount=10,
            direction="buy",
        )
        self.commission = 0.025
        self.risk_model = RiskModel.objects.create(
            user=self.user,
            name="Test Risk Model",
            max_loss_percent=4,
            max_risk_percent=2,
            max_open_trades=3,
            max_open_trades_per_symbol=2,
        )

    def test_account_functional(self):
        """
        Test that the account is functional.
        """
        balance = self.account.client.get_balance()
        # We need some money to place trades
        self.assertGreater(balance, 1000)

    def test_place_close_trade(self):
        """
        Test placing a trade.
        """
        self.open_trade()
        self.close_trade()

    def test_get_all_open_trades(self):
        """
        Test getting all open trades.
        """
        self.open_trade()

        trades = self.account.client.get_all_open_trades()
        self.trade.refresh_from_db()
        found = False
        for trade in trades:
            if trade["id"] == self.trade.order_id:
                self.assertEqual(trade["symbol"], "EUR_USD")
                self.assertEqual(trade["currentUnits"], "10")
                self.assertEqual(trade["initialUnits"], "10")
                self.assertEqual(trade["state"], "OPEN")
                found = True
                break

        # Close the trade before reporting a missing entry.
        self.close_trade()

        if not found:
            self.fail("Could not find the trade in the list of open trades")

    @patch("core.exchanges.oanda.OANDAExchange.get_balance", return_value=100000)
    def test_check_risk_max_risk_pass(self, mock_balance):
        """
        Test that a trade risking under 2% of the balance is allowed.
        """
        # SL of 18% on a 10000 trade on a 100000 account is 1.8 loss
        # Should be comfortably under 2% risk
        trade = self.create_complex_trade("buy", 10000, "EUR_USD", 1, 18)
        allowed = risk.check_risk(self.risk_model, self.account, trade)
        self.assertTrue(allowed["allowed"])

    @patch("core.exchanges.oanda.OANDAExchange.get_balance", return_value=100000)
    def test_check_risk_max_risk_fail(self, mock_balance):
        """
        Test that a trade risking over 2% of the balance is rejected.
        """
        # SL of 22% on a 10000 trade on a 100000 account is 2.2 loss
        # Should be over 2% risk
        trade = self.create_complex_trade("buy", 10000, "EUR_USD", 1, 22)
        allowed = risk.check_risk(self.risk_model, self.account, trade)
        self.assertFalse(allowed["allowed"])
        self.assertEqual(allowed["reason"], "Maximum risk exceeded.")

    # We have lost 6% of our account
    @patch("core.exchanges.oanda.OANDAExchange.get_balance", return_value=94000)
    def test_check_risk_max_loss_fail(self, mock_balance):
        """
        Test that any trade is rejected once the maximum loss is exceeded.
        """
        # Doesn't matter, shouldn't get as far as the trade
        trade = self.create_complex_trade("buy", 1, "EUR_USD", 1, 1)
        allowed = risk.check_risk(self.risk_model, self.account, trade)
        self.assertFalse(allowed["allowed"])
        self.assertEqual(allowed["reason"], "Maximum loss exceeded.")

    @patch("core.exchanges.oanda.OANDAExchange.get_balance", return_value=100000)
    def test_check_risk_max_open_trades_fail(self, mock_balance):
        """
        Test that a further trade is rejected for exceeding max open trades.
        """
        # The maximum open trades is 3. Let's open 2 trades
        # This would not be allowed by the risk model but we're doing it
        # manually
        # Don't be confused by the next test. The max open trades check
        # fails before the symbol one is run, but yes, they would both
        # fail.
        trade1 = self.create_complex_trade("buy", 1, "EUR_USD", 1, 1)
        self.open_trade(trade1)
        trade2 = self.create_complex_trade("buy", 1, "EUR_USD", 1, 1)
        self.open_trade(trade2)

        trade3 = self.create_complex_trade("buy", 1, "EUR_USD", 1, 1)
        allowed = risk.check_risk(self.risk_model, self.account, trade3)
        self.assertFalse(allowed["allowed"])
        self.assertEqual(allowed["reason"], "Maximum open trades exceeded.")

        self.close_trade(trade1)
        self.close_trade(trade2)

    @patch("core.exchanges.oanda.OANDAExchange.get_balance", return_value=100000)
    def test_check_risk_max_open_trades_per_symbol_fail(self, mock_balance):
        """
        Test that a further trade on one symbol is rejected by the symbol limit.
        """
        trade1 = self.create_complex_trade("buy", 1, "EUR_USD", 1, 1)
        self.open_trade(trade1)

        trade2 = self.create_complex_trade("buy", 1, "EUR_USD", 1, 1)
        allowed = risk.check_risk(self.risk_model, self.account, trade2)
        self.assertFalse(allowed["allowed"])
        self.assertEqual(allowed["reason"], "Maximum open trades per symbol exceeded.")

        self.close_trade(trade1)

    def test_convert_trades(self):
        """
        Test converting open trades response to Trade-like format.
        """
        complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 1, 2)
        self.open_trade(complex_trade)

        # Get and annotate the trades
        trades = self.account.client.get_all_open_trades()
        trades_converted = convert_trades(trades)

        # Check the converted trades
        self.assertEqual(len(trades_converted), 1)
        expected_tp_percent = D(1 - self.commission)
        expected_sl_percent = D(2 - self.commission)
        actual_tp_percent = trades_converted[0]["take_profit_percent"]
        actual_sl_percent = trades_converted[0]["stop_loss_percent"]

        tp_percent_difference = abs(expected_tp_percent - actual_tp_percent)
        sl_percent_difference = abs(expected_sl_percent - actual_sl_percent)
        # Tolerances are built from strings so that they are exact decimals.
        max_difference = D("0.08")  # depends on market conditions
        self.assertLess(tp_percent_difference, max_difference)
        self.assertLess(sl_percent_difference, max_difference)

        # Convert the trades to USD
        trades_usd = market.convert_trades_to_usd(self.account, trades_converted)

        # Convert the trade to USD ourselves
        trade_in_usd = D(trades_usd[0]["amount"]) * D(trades_usd[0]["current_price"])

        # It will never be perfect, but let's check it's at least close.
        # Compare absolute differences: a signed difference would let an
        # arbitrarily large undershoot pass the check.
        trade_usd_conversion_difference = abs(
            trades_usd[0]["trade_amount_usd"] - trade_in_usd
        )
        self.assertLess(trade_usd_conversion_difference, D("0.01"))

        # Check the converted TP and SL values
        trade_usd_tp_difference = abs(trades_usd[0]["take_profit_usd"] - D("0.1"))
        trade_usd_sl_difference = abs(trades_usd[0]["stop_loss_usd"] - D("0.2"))
        self.assertLess(trade_usd_tp_difference, D("0.01"))
        self.assertLess(trade_usd_sl_difference, D("0.02"))

        self.close_trade(complex_trade)