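# NOTE: the text originally at the top of this file was repository web-UI
# residue, not part of the module. It is kept here as a comment so the file
# parses as Python:
#   "You cannot select more than 25 topics. Topics must start with a letter or
#   number, can include dashes ('-') and can be up to 35 characters long."
#   487 lines / 17 KiB / Python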

from datetime import time
from decimal import Decimal as D
from unittest.mock import patch
from django.test import TestCase
from core.exchanges.convert import convert_trades
from core.models import (
ActiveManagementPolicy,
Hook,
RiskModel,
Signal,
Trade,
TradingTime,
)
from core.tests.helpers import ElasticMock, LiveBase, StrategyMixin
from core.trading import market, risk
from core.trading.active_management import ActiveManagement
class ActiveManagementMixinTestCase(StrategyMixin):
    """
    Active-management tests meant to be mixed into a concrete test case.

    The tests use ``self.user``, ``self.account`` and ``self.strategy`` and the
    ``create_complex_trade``, ``open_trade`` and ``close_trade`` helpers; none
    of these are defined in this class, so the host class (or its bases) must
    provide them.
    """

    def setUp(self):
        """
        Attach an always-open trading time and a default policy to the strategy.
        """
        super().setUp()
        self.trading_time_all = TradingTime.objects.create(
            user=self.user,
            name="All",
            start_day=1,
            start_time=time(0, 0, 0),
            end_day=7,
            end_time=time(23, 59, 59),
        )
        self.active_management_policy = ActiveManagementPolicy.objects.create(
            user=self.user,
            name="Test Policy",
            when_trading_time_violated="close",
            when_trends_violated="close",
            when_position_size_violated="adjust",
            when_protection_violated="adjust",
            when_asset_groups_violated="close",
            when_max_open_trades_violated="close",
            when_max_open_trades_per_symbol_violated="close",
            when_max_loss_violated="close",
            when_max_risk_violated="close",
            when_crossfilter_violated="close",
        )
        self.strategy.active_management_policy = self.active_management_policy
        self.strategy.trading_times.set([self.trading_time_all])
        self.strategy.save()
        self.ams = ActiveManagement(self.strategy)

    def _assert_actions_close_the_trade(self):
        """
        Check one trade is open, execute the actions, then check none remain.
        """
        self.assertEqual(len(self.account.client.get_all_open_trades()), 1)
        self.ams.execute_actions()
        self.assertEqual(len(self.account.client.get_all_open_trades()), 0)

    def _open_trade_violating_protection(self):
        """
        Open an EUR_USD buy with a 5% take profit and a 5% stop loss.

        The amount comes from ``market.get_trade_size_in_base`` so that the
        position size check is not violated as well.

        :return: the opened Trade
        """
        trade_size = market.get_trade_size_in_base(
            "buy", self.account, self.strategy, self.account.client.get_balance(), "EUR"
        )
        # The trade is placed with a whole number of units.
        trade_size = round(trade_size, 0)
        complex_trade = self.create_complex_trade("buy", trade_size, "EUR_USD", 5, 5)
        self.open_trade(complex_trade)
        return complex_trade

    def test_ams_success(self):
        """
        A trade that violates nothing produces no actions.
        """
        complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
        self.open_trade(complex_trade)
        try:
            self.ams.run_checks()
            self.assertEqual(self.ams.actions, {})
        finally:
            # Close the trade even when the assertion fails so it does not
            # leak into the other tests.
            self.close_trade(complex_trade)

    def test_ams_trading_time_violated(self):
        """
        A trade open outside the allowed trading times is closed.
        """
        # Only allow trading at the weekend. The forex market is closed on
        # Saturday and Sunday, and these tests cannot run then anyway, so
        # whenever the tests can run the trade is outside the allowed window.
        trading_time_weekend = TradingTime.objects.create(
            user=self.user,
            name="Weekend",
            start_day=6,
            start_time=time(0, 0, 0),
            end_day=7,
            end_time=time(23, 59, 59),
        )
        self.strategy.trading_times.set([trading_time_weekend])
        self.strategy.save()

        complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
        self.open_trade(complex_trade)
        self.ams.run_checks()

        expected_actions = {
            "close": [
                {"id": complex_trade.order_id, "check": "trading_time", "extra": {}}
            ]
        }
        self.assertEqual(self.ams.actions, expected_actions)
        # Executing the actions closes the trade, so no explicit close_trade
        # call is needed (it would fail on an already-closed trade).
        self._assert_actions_close_the_trade()

    def test_ams_trends_violated(self):
        """
        A buy is closed when the strategy's trend for the symbol is "sell".
        """
        hook = Hook.objects.create(
            user=self.user,
            name="Test Hook",
        )
        signal = Signal.objects.create(
            user=self.user,
            name="Test Signal",
            hook=hook,
            type="trend",
        )
        self.strategy.trend_signals.set([signal])
        self.strategy.trends = {"EUR_USD": "sell"}
        self.strategy.save()

        complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
        self.open_trade(complex_trade)
        self.ams.run_checks()

        expected_actions = {
            "close": [{"id": complex_trade.order_id, "check": "trends", "extra": {}}]
        }
        self.assertEqual(self.ams.actions, expected_actions)
        # Executing the actions closes the trade, so no explicit close_trade
        # call is needed (it would fail on an already-closed trade).
        self._assert_actions_close_the_trade()

    def test_ams_position_size_violated(self):
        """
        With the policy set to "close", a 600 unit trade is closed.
        """
        self.active_management_policy.when_position_size_violated = "close"
        self.active_management_policy.save()

        complex_trade = self.create_complex_trade("buy", 600, "EUR_USD", 1.5, 1.0)
        self.open_trade(complex_trade)
        self.ams.run_checks()

        expected_actions = {
            "close": [
                {"id": complex_trade.order_id, "check": "position_size", "extra": {}}
            ]
        }
        self.assertEqual(len(self.ams.actions["close"]), 1)
        # The suggested size is not known in advance, so drop it before
        # comparing the rest of the action.
        del self.ams.actions["close"][0]["extra"]["size"]
        self.assertEqual(self.ams.actions, expected_actions)
        self._assert_actions_close_the_trade()

    def test_ams_position_size_violated_adjust(self):
        """
        With the default "adjust" policy, a 600 unit trade is resized.
        """
        complex_trade = self.create_complex_trade("buy", 600, "EUR_USD", 1.5, 1.0)
        self.open_trade(complex_trade)
        self.ams.run_checks()

        # Round the suggested size to a whole number of units.
        expected = round(self.ams.actions["adjust"][0]["extra"]["size"], 0)

        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 1)
        self.assertEqual(trades[0]["currentUnits"], "600")

        self.ams.execute_actions()

        trades = self.account.client.get_all_open_trades()
        self.assertEqual(len(trades), 1)
        self.assertEqual(trades[0]["currentUnits"], str(expected))

        # close_trade checks the closed units against the stored amount, so
        # record the adjusted size first.
        complex_trade.amount = expected
        complex_trade.save()
        self.close_trade(complex_trade)

    def test_ams_protection_violated(self):
        """
        With the policy set to "close", a 5%/5% TP/SL trade is marked to close.
        """
        self.active_management_policy.when_protection_violated = "close"
        self.active_management_policy.save()

        complex_trade = self._open_trade_violating_protection()
        try:
            self.ams.run_checks()
            self.assertEqual(len(self.ams.actions["close"]), 1)
            # NOTE(review): this test previously compared ``self.ams.checks``
            # and ``complex_trade.id``; every other test in this class reads
            # ``self.ams.actions`` and identifies the trade by ``order_id``,
            # so the same is done here -- confirm against ActiveManagement.
            expected = {
                "close": [{"id": complex_trade.order_id, "check": "protection"}]
            }
            # The contents of "extra" are not pinned by this test.
            del self.ams.actions["close"][0]["extra"]
            self.assertEqual(self.ams.actions, expected)
        finally:
            # run_checks only records actions; the trade is still open.
            self.close_trade(complex_trade)

    def test_ams_protection_violated_adjust(self):
        """
        With the default "adjust" policy, a 5%/5% TP/SL trade is marked to adjust.
        """
        complex_trade = self._open_trade_violating_protection()
        try:
            self.ams.run_checks()
            adjustments = self.ams.actions["adjust"]
            self.assertEqual(len(adjustments), 1)
            action = adjustments[0]
            self.assertEqual(action["id"], complex_trade.order_id)
            self.assertEqual(action["check"], "protection")
            # NOTE(review): the key names come from the sample output this
            # test used to carry; the prices themselves depend on the market
            # so only the presence of the keys is checked -- confirm.
            self.assertIn("stop_loss_price", action["extra"])
            self.assertIn("take_profit_price", action["extra"])
        finally:
            # run_checks only records actions; the trade is still open.
            self.close_trade(complex_trade)

    # The remaining checks have no tests yet. They are reported as skipped
    # rather than silently passing.

    def test_ams_asset_groups_violated(self):
        """
        Placeholder for the asset groups check.
        """
        self.skipTest("asset groups check test is not implemented")

    def test_ams_crossfilter_violated(self):
        """
        Placeholder for the crossfilter check.
        """
        self.skipTest("crossfilter check test is not implemented")

    def test_ams_max_open_trades_violated(self):
        """
        Placeholder for the max open trades check.
        """
        self.skipTest("max open trades check test is not implemented")

    def test_ams_max_open_trades_per_symbol_violated(self):
        """
        Placeholder for the max open trades per symbol check.
        """
        self.skipTest("max open trades per symbol check test is not implemented")

    def test_ams_max_loss_violated(self):
        """
        Placeholder for the max loss check.
        """
        self.skipTest("max loss check test is not implemented")

    def test_ams_max_risk_violated(self):
        """
        Placeholder for the max risk check.
        """
        self.skipTest("max risk check test is not implemented")
class LiveTradingTestCase(
    ElasticMock, ActiveManagementMixinTestCase, LiveBase, TestCase
):
    """
    Trading tests that place, inspect and close trades via ``self.account.client``.

    ``self.user`` and ``self.account`` are not defined in this class; they are
    expected to be provided by the base classes.
    """

    def setUp(self):
        """
        Create a simple 10 unit EUR_USD market buy and a risk model.
        """
        super().setUp()
        self.trade = Trade.objects.create(
            user=self.user,
            account=self.account,
            symbol="EUR_USD",
            time_in_force="FOK",
            type="market",
            amount=10,
            direction="buy",
        )
        self.commission = 0.025
        self.risk_model = RiskModel.objects.create(
            user=self.user,
            name="Test Risk Model",
            max_loss_percent=4,
            max_risk_percent=2,
            max_open_trades=3,
            max_open_trades_per_symbol=2,
        )

    def test_account_functional(self):
        """
        Test that the account is functional.
        """
        balance = self.account.client.get_balance()
        # We need some money to place trades
        self.assertGreater(balance, 1000)

    def open_trade(self, trade=None):
        """
        Post a trade and check the response.

        :param trade: the Trade to post, defaults to ``self.trade``
        :return: the response from ``trade.post()``
        """
        if trade is None:
            trade = self.trade
        posted = trade.post()

        # Check the opened trade
        self.assertEqual(posted["type"], "MARKET_ORDER")
        self.assertEqual(posted["symbol"], trade.symbol)
        self.assertEqual(posted["units"], str(trade.amount))
        self.assertEqual(posted["timeInForce"], "FOK")
        return posted

    def close_trade(self, trade=None):
        """
        Close a trade and check the response.

        :param trade: the Trade to close, defaults to ``self.trade``
        :return: the response from the client's ``close_trade``
        """
        if trade is None:
            trade = self.trade
        # Refresh the trade to get the order id
        trade.refresh_from_db()
        closed = self.account.client.close_trade(trade.order_id)

        # Check the feedback from closing the trade
        self.assertEqual(closed["type"], "MARKET_ORDER")
        self.assertEqual(closed["symbol"], trade.symbol)
        # The closing order is for the negated amount of the trade.
        self.assertEqual(closed["units"], str(-int(trade.amount)))
        self.assertEqual(closed["timeInForce"], "FOK")
        self.assertEqual(closed["reason"], "TRADE_CLOSE")
        return closed

    def test_place_close_trade(self):
        """
        Test placing a trade.
        """
        self.open_trade()
        self.close_trade()

    def test_get_all_open_trades(self):
        """
        Test getting all open trades.
        """
        self.open_trade()
        try:
            trades = self.account.client.get_all_open_trades()
            self.trade.refresh_from_db()
            opened = next(
                (item for item in trades if item["id"] == self.trade.order_id),
                None,
            )
            if opened is None:
                self.fail("Could not find the trade in the list of open trades")
            self.assertEqual(opened["symbol"], "EUR_USD")
            self.assertEqual(opened["currentUnits"], "10")
            self.assertEqual(opened["initialUnits"], "10")
            self.assertEqual(opened["state"], "OPEN")
        finally:
            # Close the trade even when a check above fails.
            self.close_trade()

    def create_complex_trade(self, direction, amount, symbol, tp_percent, sl_percent):
        """
        Create (but do not post) a market trade with a take profit and stop loss.

        :param direction: trade direction, e.g. "buy"
        :param amount: number of units
        :param symbol: instrument, e.g. "EUR_USD"
        :param tp_percent: take profit percentage passed to ``market.get_tp``
        :param sl_percent: stop loss percentage passed to ``market.get_sl``
        :return: the created Trade
        """
        price = market.get_price(self.account, direction, symbol)
        _, display_precision = market.get_precision(self.account, symbol)

        # Round everything to the display precision
        trade_tp = round(market.get_tp(direction, tp_percent, price), display_precision)
        trade_sl = round(market.get_sl(direction, sl_percent, price), display_precision)

        # NOTE: no trailing stop loss is attached to the trade.
        return Trade.objects.create(
            user=self.user,
            account=self.account,
            symbol=symbol,
            time_in_force="FOK",
            type="market",
            amount=amount,
            direction=direction,
            take_profit=trade_tp,
            stop_loss=trade_sl,
        )

    @patch("core.exchanges.oanda.OANDAExchange.get_balance", return_value=100000)
    def test_check_risk_max_risk_pass(self, mock_balance):
        """
        A trade risking less than the 2% maximum is allowed.
        """
        # An SL of 18% on a 10000 trade on a 100000 account is a 1.8% loss.
        # Should be comfortably under 2% risk.
        trade = self.create_complex_trade("buy", 10000, "EUR_USD", 1, 18)
        allowed = risk.check_risk(self.risk_model, self.account, trade)
        self.assertTrue(allowed["allowed"])

    @patch("core.exchanges.oanda.OANDAExchange.get_balance", return_value=100000)
    def test_check_risk_max_risk_fail(self, mock_balance):
        """
        A trade risking more than the 2% maximum is rejected.
        """
        # An SL of 22% on a 10000 trade on a 100000 account is a 2.2% loss.
        # Should be over 2% risk.
        trade = self.create_complex_trade("buy", 10000, "EUR_USD", 1, 22)
        allowed = risk.check_risk(self.risk_model, self.account, trade)
        self.assertFalse(allowed["allowed"])
        self.assertEqual(allowed["reason"], "Maximum risk exceeded.")

    # We have lost 6% of our account
    @patch("core.exchanges.oanda.OANDAExchange.get_balance", return_value=94000)
    def test_check_risk_max_loss_fail(self, mock_balance):
        """
        Any trade is rejected once the maximum loss has been exceeded.
        """
        # Doesn't matter, shouldn't get as far as the trade
        trade = self.create_complex_trade("buy", 1, "EUR_USD", 1, 1)
        allowed = risk.check_risk(self.risk_model, self.account, trade)
        self.assertFalse(allowed["allowed"])
        self.assertEqual(allowed["reason"], "Maximum loss exceeded.")

    @patch("core.exchanges.oanda.OANDAExchange.get_balance", return_value=100000)
    def test_check_risk_max_open_trades_fail(self, mock_balance):
        """
        A third trade is rejected when two are already open.
        """
        # The maximum open trades is 3. Open 2 trades manually, bypassing the
        # risk model, then check a third.
        # Don't be confused by the next test: the max open trades check fails
        # before the per-symbol one is run, but both would fail.
        opened = []
        try:
            for _ in range(2):
                trade = self.create_complex_trade("buy", 1, "EUR_USD", 1, 1)
                self.open_trade(trade)
                opened.append(trade)

            candidate = self.create_complex_trade("buy", 1, "EUR_USD", 1, 1)
            allowed = risk.check_risk(self.risk_model, self.account, candidate)
            self.assertFalse(allowed["allowed"])
            self.assertEqual(allowed["reason"], "Maximum open trades exceeded.")
        finally:
            # Close whatever was opened, even when a check above fails.
            for trade in opened:
                self.close_trade(trade)

    @patch("core.exchanges.oanda.OANDAExchange.get_balance", return_value=100000)
    def test_check_risk_max_open_trades_per_symbol_fail(self, mock_balance):
        """
        A second EUR_USD trade is rejected when one is already open.
        """
        trade1 = self.create_complex_trade("buy", 1, "EUR_USD", 1, 1)
        self.open_trade(trade1)
        try:
            trade2 = self.create_complex_trade("buy", 1, "EUR_USD", 1, 1)
            allowed = risk.check_risk(self.risk_model, self.account, trade2)
            self.assertFalse(allowed["allowed"])
            self.assertEqual(
                allowed["reason"], "Maximum open trades per symbol exceeded."
            )
        finally:
            # Close the trade even when a check above fails.
            self.close_trade(trade1)

    def test_convert_trades(self):
        """
        Test converting open trades response to Trade-like format.
        """
        complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 1, 2)
        self.open_trade(complex_trade)
        try:
            # Get and annotate the trades
            trades = self.account.client.get_all_open_trades()
            trades_converted = convert_trades(trades)

            # Check the converted trades
            self.assertEqual(len(trades_converted), 1)
            converted = trades_converted[0]
            expected_tp_percent = D(1 - self.commission)
            expected_sl_percent = D(2 - self.commission)
            tp_percent_difference = abs(
                expected_tp_percent - converted["take_profit_percent"]
            )
            sl_percent_difference = abs(
                expected_sl_percent - converted["stop_loss_percent"]
            )
            max_difference = D("0.08")  # depends on market conditions
            self.assertLess(tp_percent_difference, max_difference)
            self.assertLess(sl_percent_difference, max_difference)

            # Convert the trades to USD
            trades_usd = market.convert_trades_to_usd(self.account, trades_converted)
            trade_usd = trades_usd[0]

            # Convert the trade to USD ourselves
            trade_in_usd = D(trade_usd["amount"]) * D(trade_usd["current_price"])

            # It will never be perfect, but let's check it's at least close.
            # Compare absolute differences: a signed difference would let any
            # large negative error pass.
            trade_usd_conversion_difference = abs(
                trade_usd["trade_amount_usd"] - trade_in_usd
            )
            self.assertLess(trade_usd_conversion_difference, D("0.01"))

            # Check the converted TP and SL values
            trade_usd_tp_difference = abs(trade_usd["take_profit_usd"] - D("0.1"))
            trade_usd_sl_difference = abs(trade_usd["stop_loss_usd"] - D("0.2"))
            self.assertLess(trade_usd_tp_difference, D("0.01"))
            self.assertLess(trade_usd_sl_difference, D("0.02"))
        finally:
            # Close the trade even when a check above fails.
            self.close_trade(complex_trade)