149 lines
4.3 KiB
Python
149 lines
4.3 KiB
Python
from datetime import time
|
|
from decimal import Decimal as D
|
|
from os import getenv
|
|
from unittest.mock import Mock, patch
|
|
|
|
from core.models import Account, OrderSettings, RiskModel, Strategy, TradingTime, User
|
|
|
|
|
|
# Create patch mixin to mock out the Elastic client
|
|
class ElasticMock:
    """
    Mixin to mock out the Elastic client.

    Patches the initialise_elasticsearch function to return a mock client.
    This client returns a mock response for the index function.

    The patcher lives on a mixin-specific class attribute so that combining
    this mixin with another one that also stores a ``patcher`` on the test
    class cannot cause the wrong patch to be stopped (or this one to leak).
    """

    @classmethod
    def setUpClass(cls):
        """Set up the parent fixtures, then start the Elastic patch."""
        super().setUpClass()
        cls._elastic_patcher = patch("core.lib.elastic.initialise_elasticsearch")
        # Kept as an alias for any code that already reads ``cls.patcher``;
        # teardown never relies on it.
        cls.patcher = cls._elastic_patcher

        # Pretend the object has been created
        initialise_mock = cls._elastic_patcher.start()

        fake_client = Mock()
        fake_client.index.return_value = {"result": "created"}
        initialise_mock.return_value = fake_client

    @classmethod
    def tearDownClass(cls):
        """Tear down the parent fixtures and always stop the Elastic patch."""
        try:
            super().tearDownClass()
        finally:
            # Stop the patch even if the parent teardown raised, so the mock
            # cannot leak into later test classes.
            cls._elastic_patcher.stop()
|
|
|
|
|
|
class SymbolPriceMock:
    """
    Mixin to mock out the symbol price lookup.

    Patches core.exchanges.common.get_symbol_price so that every call
    returns Decimal(1) for the lifetime of the test class.

    The patcher lives on a mixin-specific class attribute so that combining
    this mixin with another one that also stores a ``patcher`` on the test
    class cannot cause the wrong patch to be stopped (or this one to leak).
    """

    @classmethod
    def setUpClass(cls):
        """Set up the parent fixtures, then start the price patch."""
        super().setUpClass()
        cls._symbol_price_patcher = patch("core.exchanges.common.get_symbol_price")
        # Kept as an alias for any code that already reads ``cls.patcher``;
        # teardown never relies on it.
        cls.patcher = cls._symbol_price_patcher

        price_mock = cls._symbol_price_patcher.start()
        price_mock.return_value = D(1)

    @classmethod
    def tearDownClass(cls):
        """Tear down the parent fixtures and always stop the price patch."""
        try:
            super().tearDownClass()
        finally:
            # Stop the patch even if the parent teardown raised, so the mock
            # cannot leak into later test classes.
            cls._symbol_price_patcher.stop()
|
|
|
|
|
|
class LiveBase:
    """
    Mixin for tests that run against a real exchange account.

    Behaviour is driven by environment variables:

    * ``LIVE`` must be ``"1"``, otherwise the whole class is skipped.
    * ``TEST_ACCOUNT_EXCHANGE``, ``TEST_ACCOUNT_API_KEY`` and
      ``TEST_ACCOUNT_API_SECRET`` must be set, otherwise ``cls.fail`` is set,
      no user/account is created and every test is skipped in ``setUp``.
    * ``TEST_ACCOUNT_NAME`` is optional and defaults to ``"Test Account"``.

    On success ``cls.user`` and ``cls.account`` are available to the tests.
    """

    @classmethod
    def tearDownClass(cls):
        """Close all open positions on the live account, then tear down."""
        try:
            # The account only exists when setUpClass found full credentials.
            account = getattr(cls, "account", None)
            if account is not None:
                account.client.close_all_positions()
        finally:
            # Always release the parent fixtures, even if the exchange call
            # above raised.
            super().tearDownClass()

    @classmethod
    def setUpClass(cls):
        """
        Create the live test user and account.

        Raises:
            unittest.SkipTest: when ``LIVE`` is not ``"1"``.
        """
        from unittest import SkipTest

        cls.live = getenv("LIVE", "0") == "1"
        if not cls.live:
            # Skip before any parent fixture is created: unittest does not
            # call tearDownClass when setUpClass raises, so anything set up
            # before this point would never be released.
            raise SkipTest("Not running live tests")

        account_name = getenv("TEST_ACCOUNT_NAME", "Test Account")
        exchange = getenv("TEST_ACCOUNT_EXCHANGE", None)
        api_key = getenv("TEST_ACCOUNT_API_KEY", None)
        api_secret = getenv("TEST_ACCOUNT_API_SECRET", None)

        # Make sure all the variables are set, and remember every one that
        # is missing rather than only the last one checked.
        missing = []
        if not exchange:
            missing.append("No exchange specified")
        if not api_key:
            missing.append("No API key specified")
        if not api_secret:
            missing.append("No API secret specified")

        # NOTE(review): on a unittest.TestCase subclass this attribute
        # shadows TestCase.fail(); the name is kept because setUp (and
        # possibly subclasses) read it.
        cls.fail = bool(missing)

        super().setUpClass()

        if cls.fail:
            reason = "; ".join(missing)
            print("Live tests aborted.")
            print(
                f"""
                Please check that the following environment variables are set:
                TEST_ACCOUNT_NAME="Test Account"
                TEST_ACCOUNT_EXCHANGE="oanda"
                TEST_ACCOUNT_API_KEY="xxx-xxx-xxxxxxxx-xxx"
                TEST_ACCOUNT_API_SECRET="xxx-xxx"

                If you have done this, please see the following line for more information:
                {reason}
                """
            )
            # Do not create an account with missing credentials; setUp will
            # skip every test of this class.
            return

        cls.user = User.objects.create_user(
            username="testuser",
            email="test@example.com",
        )

        cls.account = Account.objects.create(
            user=cls.user,
            name=account_name,
            exchange=exchange,
            api_key=api_key,
            api_secret=api_secret,
            initial_balance=100000,
        )

    def setUp(self):
        """Skip the test if class setup was aborted, else continue the chain."""
        if self.fail:
            self.skipTest("Live tests aborted")
        # Keep cooperative setUp working for mixins later in the MRO.
        super().setUp()
|
|
|
|
|
|
class StrategyMixin:
    """
    Mixin whose setUp builds a Strategy and the objects it references.

    Reads ``self.user`` and ``self.account``, which this mixin does not
    create itself, and exposes ``time_8``, ``time_16``, ``order_settings``,
    ``trading_time_now``, ``risk_model`` and ``strategy`` on the test
    instance.
    """

    def setUp(self):
        """Create order settings, a trading time, a risk model and a strategy."""
        super().setUp()
        owner = self.user

        self.time_8, self.time_16 = (
            time(hour=8, minute=0, second=0),
            time(hour=16, minute=0, second=0),
        )

        self.order_settings = OrderSettings.objects.create(user=owner, name="Default")

        monday = 1
        trading_window = {
            "user": owner,
            "name": "Test Trading Time",
            "start_day": monday,
            "start_time": self.time_8,
            "end_day": monday,
            "end_time": self.time_16,
        }
        self.trading_time_now = TradingTime.objects.create(**trading_window)

        risk_limits = {
            "max_loss_percent": 50,
            "max_risk_percent": 10,
            "max_open_trades": 10,
            "max_open_trades_per_symbol": 5,
        }
        self.risk_model = RiskModel.objects.create(
            user=owner, name="Test Risk Model", **risk_limits
        )

        strategy = Strategy.objects.create(
            user=owner,
            name="Test Strategy",
            account=self.account,
            order_settings=self.order_settings,
            risk_model=self.risk_model,
            active_management_enabled=True,
        )
        self.strategy = strategy
        # Attach the single trading window, then persist the strategy again
        # exactly as the original setup did.
        strategy.trading_times.set([self.trading_time_now])
        strategy.save()
|