You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

227 lines
7.0 KiB
Python

from datetime import time
from decimal import Decimal as D
from os import getenv
from unittest import SkipTest
from unittest.mock import Mock, patch

from core.models import (
    Account,
    OrderSettings,
    RiskModel,
    Strategy,
    Trade,
    TradingTime,
    User,
)
from core.trading import market
# Create patch mixin to mock out the Elastic client
class ElasticMock:
    """
    Mixin to mock out the Elastic client.

    Patches ``core.lib.elastic.initialise_elasticsearch`` for the lifetime of
    the test class so that it returns a fake client. That client's ``index``
    method always returns ``{"result": "created"}``.

    The patcher is stored on ``cls.elastic_patcher`` rather than a generic
    ``cls.patcher``. Another patching mixin combined with this one therefore
    cannot overwrite it, which would leave one of the patches running after
    the test class finishes.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.elastic_patcher = patch("core.lib.elastic.initialise_elasticsearch")
        # Pretend the client has been created: every call to
        # initialise_elasticsearch() hands back this fake client.
        fake_initialise = cls.elastic_patcher.start()
        fake_client = Mock()
        fake_client.index.return_value = {"result": "created"}
        fake_initialise.return_value = fake_client

    @classmethod
    def tearDownClass(cls):
        # Stop the patch even if the parent teardown raises, so it never
        # leaks into other test classes.
        try:
            super().tearDownClass()
        finally:
            cls.elastic_patcher.stop()
class SymbolPriceMock:
    """
    Mixin that patches ``core.exchanges.common.get_symbol_price``.

    For the lifetime of the test class the function returns ``D(1)``
    regardless of its arguments.

    The patcher is stored on ``cls.symbol_price_patcher`` rather than a
    generic ``cls.patcher``. Another patching mixin combined with this one
    therefore cannot overwrite it and leave a patch running after the class
    finishes.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.symbol_price_patcher = patch("core.exchanges.common.get_symbol_price")
        fake_get_price = cls.symbol_price_patcher.start()
        fake_get_price.return_value = D(1)

    @classmethod
    def tearDownClass(cls):
        # Stop the patch even if the parent teardown raises.
        try:
            super().tearDownClass()
        finally:
            cls.symbol_price_patcher.stop()
class LiveBase:
    """
    Mixin for tests that trade against a real exchange account.

    Live tests only run when the ``LIVE`` environment variable is ``"1"``.
    The account is configured from these environment variables:

    * ``TEST_ACCOUNT_NAME`` (default ``"Test Account"``)
    * ``TEST_ACCOUNT_EXCHANGE``
    * ``TEST_ACCOUNT_API_KEY``
    * ``TEST_ACCOUNT_API_SECRET``

    The whole test class is skipped from ``setUpClass`` when live tests are
    disabled or a required variable is missing. The skip happens before any
    parent class setup runs, so no user or account is created. Because
    unittest does not call ``tearDownClass`` for a skipped class, nothing is
    left half set up.

    ``open_trade`` / ``close_trade`` fall back to ``self.trade`` when no trade
    is passed, so the test using them must set that attribute.
    """

    @classmethod
    def setUpClass(cls):
        cls.live = getenv("LIVE", "0") == "1"
        if not cls.live:
            # Skip before calling the parent: tearDownClass is not run for a
            # skipped class, so any parent class-level setup would leak.
            raise SkipTest("Not running live tests")

        account_name = getenv("TEST_ACCOUNT_NAME", "Test Account")
        exchange = getenv("TEST_ACCOUNT_EXCHANGE", None)
        api_key = getenv("TEST_ACCOUNT_API_KEY", None)
        api_secret = getenv("TEST_ACCOUNT_API_SECRET", None)

        # Make sure all the variables are set; report every missing one.
        missing = []
        if not exchange:
            missing.append("No exchange specified")
        if not api_key:
            missing.append("No API key specified")
        if not api_secret:
            missing.append("No API secret specified")

        # Deliberately NOT named ``fail``: that would shadow
        # unittest.TestCase.fail, which many assert* methods call internally.
        cls.live_aborted = bool(missing)
        if cls.live_aborted:
            reason = "; ".join(missing)
            print("Live tests aborted.")
            print(
                f"""
Please check that the following environment variables are set:
TEST_ACCOUNT_NAME="Test Account"
TEST_ACCOUNT_EXCHANGE="oanda"
TEST_ACCOUNT_API_KEY="xxx-xxx-xxxxxxxx-xxx"
TEST_ACCOUNT_API_SECRET="xxx-xxx"
If you have done this, please see the following line for more information:
{reason}
"""
            )
            # Skip instead of creating an Account with missing credentials
            # (which tearDownClass would then use to call the exchange).
            raise SkipTest(f"Live tests aborted: {reason}")

        super().setUpClass()
        cls.user = User.objects.create_user(
            username="testuser",
            email="test@example.com",
        )
        cls.account = Account.objects.create(
            user=cls.user,
            name=account_name,
            exchange=exchange,
            api_key=api_key,
            api_secret=api_secret,
            initial_balance=100000,
        )

    @classmethod
    def tearDownClass(cls):
        # Always run the parent teardown, even if closing positions fails.
        try:
            cls.account.client.close_all_positions()
        finally:
            super().tearDownClass()

    def setUp(self):
        # Defensive: setUpClass already skips misconfigured runs.
        if getattr(self, "live_aborted", False):
            self.skipTest("Live tests aborted")
        super().setUp()

    def open_trade(self, trade=None):
        """
        Post ``trade`` (default ``self.trade``) and check the response.

        Asserts that the response is a ``MARKET_ORDER`` with ``timeInForce``
        ``"FOK"`` for the trade's symbol. Its ``units`` must equal the trade
        amount, negated for ``"sell"`` trades.

        Returns the posted response.
        """
        if trade is None:
            trade = self.trade
        posted = trade.post()

        # Check the opened trade
        if trade.direction == "sell":
            expected_units = str(0 - trade.amount)
        else:
            expected_units = str(trade.amount)
        self.assertEqual(posted["type"], "MARKET_ORDER")
        self.assertEqual(posted["symbol"], trade.symbol)
        self.assertEqual(posted["units"], expected_units)
        self.assertEqual(posted["timeInForce"], "FOK")
        return posted

    def close_trade(self, trade=None):
        """
        Close ``trade`` (default ``self.trade``) on the account's client.

        Asserts that the response is a ``"TRADE_CLOSE"`` ``MARKET_ORDER`` with
        ``timeInForce`` ``"FOK"`` for the trade's symbol.

        Returns the close response.
        """
        if trade is None:
            trade = self.trade
        # Refresh the trade to get the trade id (order_id) from the database.
        trade.refresh_from_db()
        closed = self.account.client.close_trade(trade.order_id)

        # Check the feedback from closing the trade
        self.assertEqual(closed["type"], "MARKET_ORDER")
        self.assertEqual(closed["symbol"], trade.symbol)
        # NOTE(review): expects negated units regardless of direction. A sell
        # trade is opened with negative units (see open_trade), so its close
        # would presumably report positive units; confirm against the client.
        self.assertEqual(closed["units"], str(0 - int(trade.amount)))
        self.assertEqual(closed["timeInForce"], "FOK")
        self.assertEqual(closed["reason"], "TRADE_CLOSE")
        return closed

    def create_complex_trade(self, direction, amount, symbol, tp_percent, sl_percent):
        """
        Create and return a FOK market ``Trade`` with take-profit and stop-loss.

        The current price of ``symbol`` comes from ``market.get_price``. TP and
        SL are derived from ``tp_percent`` / ``sl_percent`` via
        ``market.get_tp`` / ``market.get_sl``, then rounded to the symbol's
        display precision.
        """
        price = market.get_price(self.account, direction, symbol)
        _, display_precision = market.get_precision(self.account, symbol)

        # Round everything to the display precision
        trade_tp = round(market.get_tp(direction, tp_percent, price), display_precision)
        trade_sl = round(market.get_sl(direction, sl_percent, price), display_precision)

        # TODO: trailing stop loss is not set yet.
        return Trade.objects.create(
            user=self.user,
            account=self.account,
            symbol=symbol,
            time_in_force="FOK",
            type="market",
            amount=amount,
            direction=direction,
            take_profit=trade_tp,
            stop_loss=trade_sl,
        )
class StrategyMixin:
    """
    Test mixin that builds a ready-to-use ``Strategy`` fixture before each test.

    Relies on ``self.user`` and ``self.account`` already existing (for example
    the ones created by ``LiveBase.setUpClass``). For ``self.user`` it creates:

    * ``self.order_settings`` - an ``OrderSettings`` named "Default";
    * ``self.trading_time_now`` - a ``TradingTime`` from day 1 08:00 to
      day 1 16:00;
    * ``self.risk_model`` - a ``RiskModel`` with fixed limits;
    * ``self.strategy`` - a ``Strategy`` with active management enabled that
      ties the above together.

    ``self.time_8`` and ``self.time_16`` hold the 08:00 and 16:00 boundaries.
    """

    def setUp(self):
        super().setUp()

        # Boundaries of the trading window.
        self.time_8, self.time_16 = time(8, 0, 0), time(16, 0, 0)

        owner = self.user

        self.order_settings = OrderSettings.objects.create(
            name="Default",
            user=owner,
        )
        self.trading_time_now = TradingTime.objects.create(
            user=owner,
            name="Test Trading Time",
            start_day=1,  # Monday
            end_day=1,  # Monday
            start_time=self.time_8,
            end_time=self.time_16,
        )
        self.risk_model = RiskModel.objects.create(
            user=owner,
            name="Test Risk Model",
            max_open_trades=10,
            max_open_trades_per_symbol=5,
            max_loss_percent=50,
            max_risk_percent=10,
        )
        self.strategy = strategy = Strategy.objects.create(
            user=owner,
            account=self.account,
            name="Test Strategy",
            risk_model=self.risk_model,
            order_settings=self.order_settings,
            active_management_enabled=True,
        )
        strategy.trading_times.set([self.trading_time_now])
        strategy.save()