"""Test mixins: Elasticsearch / symbol-price mocks and live-exchange fixtures."""

from datetime import time
from decimal import Decimal as D
from os import getenv
from unittest import SkipTest
from unittest.mock import Mock, patch

from core.models import (
    Account,
    OrderSettings,
    RiskModel,
    Strategy,
    Trade,
    TradingTime,
    User,
)
from core.trading import market


class ElasticMock:
    """
    Mixin to mock out the Elastic client.

    For the lifetime of the test class, patches
    ``core.lib.elastic.initialise_elasticsearch`` to return a fake client whose
    ``index`` call returns ``{"result": "created"}``.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Mixin-specific attribute name: a shared ``cls.patcher`` would be
        # overwritten when combined with another patching mixin (such as
        # SymbolPriceMock), leaving one patch active after the class finishes.
        cls.elastic_patcher = patch("core.lib.elastic.initialise_elasticsearch")
        initialise_elasticsearch = cls.elastic_patcher.start()
        fake_client = Mock()
        fake_client.index.return_value = {"result": "created"}
        initialise_elasticsearch.return_value = fake_client

    @classmethod
    def tearDownClass(cls):
        try:
            super().tearDownClass()
        finally:
            # Always undo the patch so it cannot leak into other test classes.
            cls.elastic_patcher.stop()


class SymbolPriceMock:
    """
    Mixin that patches ``core.exchanges.common.get_symbol_price`` for the
    lifetime of the test class so that it always returns ``Decimal(1)``.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Own attribute name so this can be combined with ElasticMock safely.
        cls.symbol_price_patcher = patch("core.exchanges.common.get_symbol_price")
        get_symbol_price = cls.symbol_price_patcher.start()
        get_symbol_price.return_value = D(1)

    @classmethod
    def tearDownClass(cls):
        try:
            super().tearDownClass()
        finally:
            cls.symbol_price_patcher.stop()


class LiveBase:
    """
    Mixin for tests that run against a real exchange account.

    Tests only run when the ``LIVE`` environment variable is ``"1"``. The
    account is built from ``TEST_ACCOUNT_NAME``, ``TEST_ACCOUNT_EXCHANGE``,
    ``TEST_ACCOUNT_API_KEY`` and ``TEST_ACCOUNT_API_SECRET``. If any of the last
    three is missing, a help message is printed and the whole class is skipped.
    Any positions still open on the account are closed when the class finishes.
    """

    @classmethod
    def tearDownClass(cls):
        try:
            cls.account.client.close_all_positions()
        finally:
            # Run the base teardown even if closing positions fails (e.g. a
            # network error), so its class-level cleanup is not skipped.
            super().tearDownClass()

    @classmethod
    def setUpClass(cls):
        # All skip decisions are made BEFORE super().setUpClass(). If setUpClass
        # raises after the base class has set up, unittest never calls
        # tearDownClass, so that setup would not be undone. NOTE(review): the
        # base is presumably a Django TestCase with a class-level transaction;
        # confirm.
        cls.live = getenv("LIVE", "0") == "1"
        if not cls.live:
            raise SkipTest("Not running live tests")

        account_name = getenv("TEST_ACCOUNT_NAME", "Test Account")
        exchange = getenv("TEST_ACCOUNT_EXCHANGE", None)
        api_key = getenv("TEST_ACCOUNT_API_KEY", None)
        api_secret = getenv("TEST_ACCOUNT_API_SECRET", None)

        # Collect every missing setting, not just the last one checked.
        missing = [
            message
            for value, message in (
                (exchange, "No exchange specified"),
                (api_key, "No API key specified"),
                (api_secret, "No API secret specified"),
            )
            if not value
        ]
        if missing:
            reason = "; ".join(missing)
            print("Live tests aborted.")
            print(
                f"""
Please check that the following environment variables are set:

TEST_ACCOUNT_NAME="Test Account"
TEST_ACCOUNT_EXCHANGE="oanda"
TEST_ACCOUNT_API_KEY="xxx-xxx-xxxxxxxx-xxx"
TEST_ACCOUNT_API_SECRET="xxx-xxx"

If you have done this, please see the following line for more information:

{reason}
"""
            )
            # Skip the whole class rather than creating an Account with
            # missing credentials and skipping each test individually.
            raise SkipTest(f"Live tests aborted: {reason}")

        super().setUpClass()

        cls.user = User.objects.create_user(
            username="testuser",
            email="test@example.com",
        )
        cls.account = Account.objects.create(
            user=cls.user,
            name=account_name,
            exchange=exchange,
            api_key=api_key,
            api_secret=api_secret,
            initial_balance=100000,
        )

    def open_trade(self, trade=None):
        """
        Post ``trade`` (default: ``self.trade``) and check the reported order.

        Asserts a FOK market order for the trade's symbol, with sells reported
        as negative units. Returns the dict returned by ``trade.post()``.
        """
        if trade is None:
            trade = self.trade
        posted = trade.post()

        self.assertEqual(posted["type"], "MARKET_ORDER")
        self.assertEqual(posted["symbol"], trade.symbol)
        if trade.direction == "sell":
            expected_units = 0 - trade.amount
        else:
            expected_units = trade.amount
        self.assertEqual(posted["units"], str(expected_units))
        self.assertEqual(posted["timeInForce"], "FOK")
        return posted

    def close_trade(self, trade=None):
        """
        Close ``trade`` (default: ``self.trade``) and check the closing order.

        Asserts a FOK market order with reason ``TRADE_CLOSE`` whose units
        reverse the position. Returns the client's ``close_trade`` result.
        """
        if trade is None:
            trade = self.trade
        # Refresh the trade to get the order id stored when it was posted.
        trade.refresh_from_db()
        closed = self.account.client.close_trade(trade.order_id)

        units = int(trade.amount)
        # Closing reverses the position: a long is closed by selling (negative
        # units), a short by buying (positive units).
        # NOTE(review): assumes the client reports close units with the
        # opposite sign of the opening order; confirm for the exchange in use.
        expected_units = units if trade.direction == "sell" else 0 - units

        self.assertEqual(closed["type"], "MARKET_ORDER")
        self.assertEqual(closed["symbol"], trade.symbol)
        self.assertEqual(closed["units"], str(expected_units))
        self.assertEqual(closed["timeInForce"], "FOK")
        self.assertEqual(closed["reason"], "TRADE_CLOSE")
        return closed

    def create_complex_trade(self, direction, amount, symbol, tp_percent, sl_percent):
        """
        Create (but do not post) a FOK market Trade with take-profit and stop-loss.

        TP and SL come from ``market.get_tp`` / ``market.get_sl``, applied to the
        current price from ``market.get_price``. Both are rounded to the
        symbol's display precision. Returns the saved ``Trade``.
        """
        price = market.get_price(self.account, direction, symbol)
        take_profit = market.get_tp(direction, tp_percent, price)
        stop_loss = market.get_sl(direction, sl_percent, price)
        # Only the display precision is needed; the trade precision is unused.
        _, display_precision = market.get_precision(self.account, symbol)

        # TODO: trailing stop loss is not exercised yet (it would use
        # market.get_sl(..., return_var=True) and the trailing_stop_loss field).
        return Trade.objects.create(
            user=self.user,
            account=self.account,
            symbol=symbol,
            time_in_force="FOK",
            type="market",
            amount=amount,
            direction=direction,
            take_profit=round(take_profit, display_precision),
            stop_loss=round(stop_loss, display_precision),
        )


class StrategyMixin:
    """
    Per-test fixture that builds a Strategy on ``self.account``.

    Creates default OrderSettings, a Monday 08:00-16:00 TradingTime and a
    RiskModel, then links them to an active-management Strategy.
    Expects ``self.user`` and ``self.account`` to exist (e.g. from LiveBase).
    """

    def setUp(self):
        super().setUp()
        self.time_8 = time(8, 0, 0)
        self.time_16 = time(16, 0, 0)

        self.order_settings = OrderSettings.objects.create(
            user=self.user,
            name="Default",
        )
        self.trading_time_now = TradingTime.objects.create(
            user=self.user,
            name="Test Trading Time",
            start_day=1,  # Monday
            start_time=self.time_8,
            end_day=1,  # Monday
            end_time=self.time_16,
        )
        self.risk_model = RiskModel.objects.create(
            user=self.user,
            name="Test Risk Model",
            max_loss_percent=50,
            max_risk_percent=10,
            max_open_trades=10,
            max_open_trades_per_symbol=5,
        )
        self.strategy = Strategy.objects.create(
            user=self.user,
            name="Test Strategy",
            account=self.account,
            order_settings=self.order_settings,
            risk_model=self.risk_model,
            active_management_enabled=True,
        )
        self.strategy.trading_times.set([self.trading_time_now])
        self.strategy.save()