diff --git a/core/exchanges/__init__.py b/core/exchanges/__init__.py index cee8b2b..0f5eeff 100644 --- a/core/exchanges/__init__.py +++ b/core/exchanges/__init__.py @@ -205,7 +205,7 @@ class BaseExchange(ABC): pass @abstractmethod - def close_trade(self, trade_id): + def close_trade(self, trade_id, units=None): pass @abstractmethod diff --git a/core/exchanges/alpaca.py b/core/exchanges/alpaca.py index 03861d9..55f92c2 100644 --- a/core/exchanges/alpaca.py +++ b/core/exchanges/alpaca.py @@ -121,7 +121,7 @@ class AlpacaExchange(BaseExchange): trade.save() return order - def close_trade(self, trade_id): # TODO + def close_trade(self, trade_id, units=None): # TODO """ Close a trade """ diff --git a/core/exchanges/oanda.py b/core/exchanges/oanda.py index bbaab76..a27b6d8 100644 --- a/core/exchanges/oanda.py +++ b/core/exchanges/oanda.py @@ -2,6 +2,9 @@ from oandapyV20 import API from oandapyV20.endpoints import accounts, orders, positions, pricing, trades from core.exchanges import BaseExchange, common +from core.util import logs + +log = logs.get_logger("oanda") class OANDAExchange(BaseExchange): @@ -98,12 +101,44 @@ class OANDAExchange(BaseExchange): trade.save() return response - def close_trade(self, trade_id): + def get_trade_precision(self, symbol): + instruments = self.account.instruments + if not instruments: + log.error("No instruments found") + return None + # Extract the information for the symbol + instrument = self.extract_instrument(instruments, symbol) + if not instrument: + log.error(f"Symbol not found: {symbol}") + return None + # Get the required precision + try: + trade_precision = instrument["tradeUnitsPrecision"] + return trade_precision + except KeyError: + log.error(f"Precision not found for {symbol} from {instrument}") + return None + + def close_trade(self, trade_id, units=None, symbol=None): """ Close a trade. """ - r = trades.TradeClose(accountID=self.account_id, tradeID=trade_id) - return self.call(r) + if not units: + r = trades.TradeClose(accountID=self.account_id, tradeID=trade_id) + return self.call(r) + else: + trade_precision = self.get_trade_precision(symbol) + if trade_precision is None: + log.error(f"Unable to get trade precision for {symbol}") + return None + units = round(units, trade_precision) + data = { + "units": str(units), + } + r = trades.TradeClose( + accountID=self.account_id, tradeID=trade_id, data=data + ) + return self.call(r) def get_trade(self, trade_id): # OANDA is off by one... diff --git a/core/tests/trading/test_live.py b/core/tests/trading/test_live.py index 8b48010..a6a8278 100644 --- a/core/tests/trading/test_live.py +++ b/core/tests/trading/test_live.py @@ -1,15 +1,204 @@ +from datetime import time from decimal import Decimal as D from unittest.mock import patch from django.test import TestCase from core.exchanges.convert import convert_trades -from core.models import RiskModel, Trade -from core.tests.helpers import ElasticMock, LiveBase +from core.models import ( + ActiveManagementPolicy, + Hook, + RiskModel, + Signal, + Trade, + TradingTime, +) +from core.tests.helpers import ElasticMock, LiveBase, StrategyMixin from core.trading import market, risk +from core.trading.active_management import ActiveManagement -class LiveTradingTestCase(ElasticMock, LiveBase, TestCase): +class ActiveManagementMixinTestCase(StrategyMixin): + def setUp(self): + super().setUp() + self.trading_time_all = TradingTime.objects.create( + user=self.user, + name="All", + start_day=1, + start_time=time(0, 0, 0), + end_day=7, + end_time=time(23, 59, 59), + ) + self.active_management_policy = ActiveManagementPolicy.objects.create( + user=self.user, + name="Test Policy", + when_trading_time_violated="close", + when_trends_violated="close", + when_position_size_violated="adjust", + when_protection_violated="adjust", + when_asset_groups_violated="close", + when_max_open_trades_violated="close", + when_max_open_trades_per_symbol_violated="close", + when_max_loss_violated="close", + when_max_risk_violated="close", + when_crossfilter_violated="close", + ) + + self.strategy.active_management_policy = self.active_management_policy + self.strategy.trading_times.set([self.trading_time_all]) + self.strategy.save() + self.ams = ActiveManagement(self.strategy) + + def test_ams_success(self): + complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0) + self.open_trade(complex_trade) + + self.ams.run_checks() + self.assertEqual(self.ams.actions, {}) + + self.close_trade(complex_trade) + + def test_ams_trading_time_violated(self): + # Forex market is closed on Saturday and Sunday. + # All of these tests will fail then anyway, so it's the only time + # we can use for this. + trading_time_weekend = TradingTime.objects.create( + user=self.user, + name="Weekend", + start_day=6, + start_time=time(0, 0, 0), + end_day=7, + end_time=time(23, 59, 59), + ) + self.strategy.trading_times.set([trading_time_weekend]) + self.strategy.save() + + complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0) + self.open_trade(complex_trade) + + self.ams.run_checks() + expected_actions = { + "close": [ + {"id": complex_trade.order_id, "check": "trading_time", "extra": {}} + ] + } + self.assertEqual(self.ams.actions, expected_actions) + + trades = self.account.client.get_all_open_trades() + self.assertEqual(len(trades), 1) + + self.ams.execute_actions() + trades = self.account.client.get_all_open_trades() + self.assertEqual(len(trades), 0) + + # Don't need to close the trade, it's already closed. + # Otherwise the test would fail. + + def test_ams_trends_violated(self): + hook = Hook.objects.create( + user=self.user, + name="Test Hook", + ) + signal = Signal.objects.create( + user=self.user, + name="Test Signal", + hook=hook, + type="trend", + ) + self.strategy.trend_signals.set([signal]) + self.strategy.trends = {"EUR_USD": "sell"} + self.strategy.save() + + complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0) + self.open_trade(complex_trade) + + self.ams.run_checks() + expected_actions = { + "close": [{"id": complex_trade.order_id, "check": "trends", "extra": {}}] + } + self.assertEqual(self.ams.actions, expected_actions) + + trades = self.account.client.get_all_open_trades() + self.assertEqual(len(trades), 1) + + self.ams.execute_actions() + trades = self.account.client.get_all_open_trades() + self.assertEqual(len(trades), 0) + + # Don't need to close the trade, it's already closed. + # Otherwise the test would fail. + + def test_ams_position_size_violated(self): + self.active_management_policy.when_position_size_violated = "close" + self.active_management_policy.save() + + complex_trade = self.create_complex_trade("buy", 600, "EUR_USD", 1.5, 1.0) + self.open_trade(complex_trade) + + self.ams.run_checks() + expected_actions = { + "close": [ + {"id": complex_trade.order_id, "check": "position_size", "extra": {}} + ] + } + self.assertEqual(len(self.ams.actions["close"]), 1) + del self.ams.actions["close"][0]["extra"]["size"] # We don't know the size + self.assertEqual(self.ams.actions, expected_actions) + + trades = self.account.client.get_all_open_trades() + self.assertEqual(len(trades), 1) + + self.ams.execute_actions() + trades = self.account.client.get_all_open_trades() + self.assertEqual(len(trades), 0) + + def test_ams_position_size_violated_adjust(self): + complex_trade = self.create_complex_trade("buy", 600, "EUR_USD", 1.5, 1.0) + self.open_trade(complex_trade) + + self.ams.run_checks() + # Convert to int to make a whole number + expected = round(self.ams.actions["adjust"][0]["extra"]["size"], 0) + + trades = self.account.client.get_all_open_trades() + self.assertEqual(len(trades), 1) + self.assertEqual(trades[0]["currentUnits"], "600") + + self.ams.execute_actions() + trades = self.account.client.get_all_open_trades() + self.assertEqual(len(trades), 1) + self.assertEqual(trades[0]["currentUnits"], str(expected)) + + complex_trade.amount = expected + complex_trade.save() + + self.close_trade(complex_trade) + + def test_ams_protection_violated(self): + pass + + def test_ams_asset_groups_violated(self): + pass + + def test_ams_crossfilter_violated(self): + pass + + def test_ams_max_open_trades_violated(self): + pass + + def test_ams_max_open_trades_per_symbol_violated(self): + pass + + def test_ams_max_loss_violated(self): + pass + + def test_ams_max_risk_violated(self): + pass + + +class LiveTradingTestCase( + ElasticMock, ActiveManagementMixinTestCase, LiveBase, TestCase +): def setUp(self): super(LiveTradingTestCase, self).setUp() self.trade = Trade.objects.create( diff --git a/core/trading/active_management.py b/core/trading/active_management.py index 97c42e7..da9ca2d 100644 --- a/core/trading/active_management.py +++ b/core/trading/active_management.py @@ -95,7 +95,23 @@ class ActiveManagement(object): sendmsg(self.strategy.user, msg, title=f"AMS: {action}") def adjust_position_size(self, trade_id, new_size): - pass # TODO + # Get old size + old_size = None + for trade in self.trades: + if trade["id"] == trade_id: + old_size = D(trade["currentUnits"]) + symbol = trade["symbol"] + break + if old_size is None: + log.error(f"Could not find trade ID {trade_id} in active management") + return + + # Reduce only + assert old_size > new_size + difference = old_size - new_size + + # Close the difference + self.strategy.account.client.close_trade(trade_id, difference, symbol) def adjust_protection(self, trade_id, new_protection): pass # TODO