You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
fisk/core/tests/trading/test_active_management.py

549 lines
19 KiB
Python

from decimal import Decimal as D
from unittest.mock import patch
from django.test import TestCase
from core.exchanges.convert import convert_trades
from core.lib.schemas.oanda_s import parse_time
from core.models import (
Account,
ActiveManagementPolicy,
AssetGroup,
AssetRule,
Hook,
Signal,
User,
)
from core.tests.helpers import StrategyMixin, SymbolPriceMock
from core.trading.active_management import ActiveManagement
class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
    """
    Tests for the violation checks run by ``ActiveManagement.run_checks``.

    ``get_trades`` and ``get_balance`` on the ``ActiveManagement`` instance are
    replaced with in-memory fakes backed by ``self.trades`` / ``self.balance`` /
    ``self.balance_usd``. Most tests patch ``ActiveManagement.handle_violation``
    and use ``check_violation`` to assert which violation fired, with which
    action, and for which trades.
    """

    def setUp(self):
        self.user = User.objects.create_user(
            username="testuser", email="test@example.com", password="test"
        )
        self.account = Account.objects.create(
            user=self.user,
            name="Test Account",
            exchange="fake",
            currency="USD",
            initial_balance=100000,
        )
        self.account.supported_symbols = ["EUR_USD", "EUR_XXX", "USD_EUR", "XXX_EUR"]
        self.account.save()
        super().setUp()
        # Every violation type is configured to "close", so the tests below can
        # all expect the "close" action.
        self.active_management_policy = ActiveManagementPolicy.objects.create(
            user=self.user,
            name="Test Policy",
            when_trading_time_violated="close",
            when_trends_violated="close",
            when_position_size_violated="close",
            when_protection_violated="close",
            when_asset_groups_violated="close",
            when_max_open_trades_violated="close",
            when_max_open_trades_per_symbol_violated="close",
            when_max_loss_violated="close",
            when_max_risk_violated="close",
            when_crossfilter_violated="close",
        )
        self.strategy.active_management_policy = self.active_management_policy
        self.strategy.save()
        self.ams = ActiveManagement(self.strategy)

        # Baseline fixture: two otherwise identical long EUR_USD trades opened a
        # minute apart. add_trade() builds each dict and runs parse_time on it.
        self.trades = []
        # Monday at 11:38
        self.add_trade("20083", "EUR_USD", "long", "2023-02-13T11:38:06.302917985Z")
        # Monday at 11:39
        self.add_trade("20084", "EUR_USD", "long", "2023-02-13T11:39:06.302917985Z")

        self.balance = 100000
        self.balance_usd = 120000
        # Serve trades and balances from the attributes above instead of the
        # exchange.
        self.ams.get_trades = self.fake_get_trades
        self.ams.get_balance = self.fake_get_balance

    def add_trade(self, id, symbol, side, open_time):
        """
        Append an open 10-unit trade to ``self.trades``.

        All fields other than the arguments use fixed values: entry price
        1.06331, take profit above it (1.07934) and stop loss below it (1.05276).
        Tests override only the fields they exercise.

        :param id: trade ID (name kept as ``id`` for existing callers)
        :param symbol: instrument name, e.g. "EUR_USD"
        :param side: "long" or "short"
        :param open_time: open time string; the built trade is passed through
            ``parse_time`` and its result stored as ``openTime``
        """
        trade = {
            "id": id,
            "symbol": symbol,
            "price": "1.06331",
            "openTime": open_time,
            "initialUnits": "10",
            "initialMarginRequired": "0.2966",
            "state": "OPEN",
            "currentUnits": "10",
            "realizedPL": "0.0000",
            "financing": "0.0000",
            "dividendAdjustment": "0.0000",
            "unrealizedPL": "-0.0008",
            "marginUsed": "0.2966",
            "takeProfitOrder": {"price": "1.07934"},
            "stopLossOrder": {"price": "1.05276"},
            "trailingStopLossOrder": None,
            "trailingStopValue": None,
            "side": side,
        }
        trade["openTime"] = parse_time(trade)
        self.trades.append(trade)

    def amend_tp_sl_flip_side(self):
        """
        Flip the protection levels of every short trade in ``self.trades``.

        The take profit is moved below the entry price (1.04728) and the stop
        loss above it (1.07386). Long trades are left untouched. The default
        levels from ``add_trade`` only suit long trades, so this stops the
        protection check from also firing. That way each test exercises a single
        violation.
        """
        for trade in self.trades:
            if trade["side"] == "short":
                trade["stopLossOrder"]["price"] = "1.07386"
                trade["takeProfitOrder"]["price"] = "1.04728"

    def fake_get_trades(self):
        """Stand-in for ``ActiveManagement.get_trades``: set and return ``self.trades``."""
        self.ams.trades = self.trades
        return self.trades

    def fake_get_balance(self, return_usd=None):
        """Stand-in for ``ActiveManagement.get_balance``.

        :param return_usd: when truthy, return ``self.balance_usd`` instead of
            ``self.balance``
        """
        if return_usd:
            return self.balance_usd
        return self.balance

    def fake_get_currencies(self, symbols):
        """No-op stub; not wired into ``self.ams`` by ``setUp``."""
        pass

    def test_get_trades(self):
        trades = self.ams.get_trades()
        self.assertEqual(trades, self.trades)

    def test_get_balance(self):
        balance = self.ams.get_balance()
        self.assertEqual(balance, self.balance)

    def check_violation(
        self, violation, calls, expected_action, expected_trades, expected_args=None
    ):
        """
        Assert that exactly the expected ``handle_violation`` calls were made.

        Each recorded call is expected to look like
        ``(violation_name, action, trade[, args])``. Calls are matched on their
        first positional argument.

        :param violation: violation name to look for, e.g. "crossfilter"
        :param calls: ``call_args_list`` of the patched ``handle_violation``
        :param expected_action: action every matching call must carry, e.g. "close"
        :param expected_trades: trades that must have been reported. Use
            ``[None]`` for a violation reported without a trade.
        :param expected_args: optional; when given, the fourth positional
            argument of every matching call must equal it
        """
        calls = list(calls)
        # No other check may have fired: the total call count must match too.
        self.assertEqual(len(calls), len(expected_trades))
        violation_calls = [recorded for recorded in calls if recorded[0][0] == violation]
        self.assertEqual(len(violation_calls), len(expected_trades))

        # Only convert real trades; [None] marks a violation without a trade.
        if all(expected_trades):
            expected_trades = convert_trades(expected_trades)

        for recorded in violation_calls:
            positional = recorded[0]
            # Ensure the correct action has been called, like close
            self.assertEqual(positional[1], expected_action)
            # Ensure the correct trade has been passed to the violation
            trade = positional[2]
            if trade:
                # Drop "*_usd" fields before comparing. Presumably they are
                # added by ActiveManagement and not produced by convert_trades.
                for field in [key for key in trade if "_usd" in key]:
                    del trade[field]
            self.assertIn(trade, expected_trades)
            if expected_args:
                self.assertEqual(positional[3], expected_args)

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_run_checks(self, handle_violation):
        # The unmodified fixture must not trigger any violation.
        self.ams.run_checks()
        self.assertEqual(handle_violation.call_count, 0)

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_trading_time_violated(self, handle_violation):
        # Friday. Assigned directly in microsecond precision instead of going
        # through add_trade/parse_time.
        self.trades[0]["openTime"] = "2023-02-17T11:38:06.302917Z"
        self.ams.run_checks()
        self.check_violation(
            "trading_time", handle_violation.call_args_list, "close", [self.trades[0]]
        )

    def create_hook_signal(self):
        """Create a Hook and a "trend" Signal attached to it for ``self.user``."""
        hook = Hook.objects.create(
            user=self.user,
            name="Test Hook",
        )
        signal = Signal.objects.create(
            user=self.user,
            name="Test Signal",
            hook=hook,
            type="trend",
        )
        return signal

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_trends_violated(self, handle_violation):
        signal = self.create_hook_signal()
        self.strategy.trend_signals.set([signal])
        self.strategy.trends = {"EUR_USD": "sell"}
        self.amend_tp_sl_flip_side()
        self.strategy.save()
        self.ams.run_checks()
        self.check_violation(
            "trends", handle_violation.call_args_list, "close", self.trades
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_trends_violated_none(self, handle_violation):
        signal = self.create_hook_signal()
        self.strategy.trend_signals.set([signal])
        self.strategy.trends = {"EUR_USD": "buy"}
        self.strategy.save()
        self.ams.run_checks()
        self.check_violation("trends", handle_violation.call_args_list, "close", [])

    # Mock crossfilter here since we want to allow this conflict in order to test that
    # trends only close trades that are in the wrong direction
    @patch(
        "core.trading.active_management.ActiveManagement.check_crossfilter",
        return_value=None,
    )
    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_trends_violated_partial(self, handle_violation, check_crossfilter):
        signal = self.create_hook_signal()
        self.strategy.trend_signals.set([signal])
        self.strategy.trends = {"EUR_USD": "sell"}
        self.strategy.save()
        # Change the side of the first trade to match the trends
        self.trades[0]["side"] = "short"
        self.amend_tp_sl_flip_side()
        self.ams.run_checks()
        self.check_violation(
            "trends", handle_violation.call_args_list, "close", [self.trades[1]]
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_position_size_violated(self, handle_violation):
        self.trades[0]["currentUnits"] = "100000"
        self.ams.run_checks()
        self.check_violation(
            "position_size",
            handle_violation.call_args_list,
            "close",
            [self.trades[0]],
            {"size": 500},
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_protection_violated_absent(self, handle_violation):
        self.trades[0]["takeProfitOrder"] = None
        self.trades[0]["stopLossOrder"] = None
        self.ams.run_checks()
        expected_args = {
            "take_profit_price": D("1.07934"),
            "stop_loss_price": D("1.05276"),
        }
        self.check_violation(
            "protection",
            handle_violation.call_args_list,
            "close",
            [self.trades[0]],
            expected_args,
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_protection_violated_absent_not_required(self, handle_violation):
        # With TP/SL disabled in the order settings, missing protection is fine.
        self.strategy.order_settings.take_profit_percent = 0
        self.strategy.order_settings.stop_loss_percent = 0
        self.strategy.order_settings.save()
        self.trades[0]["takeProfitOrder"] = None
        self.trades[0]["stopLossOrder"] = None
        self.ams.run_checks()
        self.assertEqual(handle_violation.call_count, 0)

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_asset_groups_violated(self, handle_violation):
        asset_group = AssetGroup.objects.create(
            user=self.user,
            name="Test Asset Group",
        )
        AssetRule.objects.create(
            user=self.user,
            asset="USD",
            group=asset_group,
            status=2,  # Bullish
        )
        self.strategy.asset_group = asset_group
        self.strategy.save()
        self.ams.run_checks()
        self.check_violation(
            "asset_group",
            handle_violation.call_args_list,
            "close",
            self.trades,  # All trades should be closed, since all are USD quote
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_asset_groups_violated_invert(self, handle_violation):
        self.trades[0]["side"] = "short"
        self.trades[1]["side"] = "short"
        self.amend_tp_sl_flip_side()
        asset_group = AssetGroup.objects.create(
            user=self.user,
            name="Test Asset Group",
        )
        AssetRule.objects.create(
            user=self.user,
            asset="USD",
            group=asset_group,
            # NOTE(review): presumably Bearish (2 is Bullish in the test above),
            # the inverse of the non-inverted case; confirm against AssetRule.
            status=3,
        )
        self.strategy.asset_group = asset_group
        self.strategy.save()
        self.ams.run_checks()
        self.check_violation(
            "asset_group",
            handle_violation.call_args_list,
            "close",
            self.trades,  # All trades should be closed, since all are USD quote
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_crossfilter_violated_side(self, handle_violation):
        self.trades[1]["side"] = "short"
        self.amend_tp_sl_flip_side()
        self.ams.run_checks()
        self.check_violation(
            "crossfilter",
            handle_violation.call_args_list,
            "close",
            [self.trades[1]],  # Only close newer trade
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_crossfilter_violated_side_multiple(self, handle_violation):
        # The three short trades added here become self.trades[2:].
        self.add_trade("20085", "EUR_USD", "short", "2023-02-13T12:39:06.302917985Z")
        self.add_trade("20086", "EUR_USD", "short", "2023-02-13T13:39:07.302917985Z")
        self.add_trade("20087", "EUR_USD", "short", "2023-02-13T14:39:06.302917985Z")
        self.amend_tp_sl_flip_side()
        self.ams.run_checks()
        self.check_violation(
            "crossfilter",
            handle_violation.call_args_list,
            "close",
            self.trades[2:],  # Only close newer trades
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_crossfilter_violated_symbol(self, handle_violation):
        # Change symbol to conflict with long on EUR_USD
        self.trades[1]["symbol"] = "USD_EUR"
        self.ams.run_checks()
        self.check_violation(
            "crossfilter",
            handle_violation.call_args_list,
            "close",
            [self.trades[1]],  # Only close newer trade
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_crossfilter_violated_symbol_multiple(self, handle_violation):
        # The three USD_EUR trades added here become self.trades[2:].
        self.add_trade("20085", "USD_EUR", "long", "2023-02-13T12:39:06.302917985Z")
        self.add_trade("20086", "USD_EUR", "long", "2023-02-13T13:39:06.302917985Z")
        self.add_trade("20087", "USD_EUR", "long", "2023-02-13T14:39:06.302917985Z")
        self.ams.run_checks()
        self.check_violation(
            "crossfilter",
            handle_violation.call_args_list,
            "close",
            self.trades[2:],  # Only close newer trades
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_max_open_trades_violated(self, handle_violation):
        # 2 fixture trades + 9 added = 11 open; only the newest one is closed.
        for x in range(9):
            self.add_trade(
                str(x),
                f"EUR_USD{x}",  # Vary symbol to prevent max open trades per symbol
                "long",
                f"2023-02-13T12:39:1{x}.302917985Z",
            )
        self.ams.run_checks()
        self.check_violation(
            "max_open_trades",
            handle_violation.call_args_list,
            "close",
            self.trades[10:],  # Only close newer trades
        )

    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_max_open_trades_per_symbol_violated(self, handle_violation):
        # 2 fixture trades + 4 added = 6 on EUR_USD; only the newest is closed.
        for x in range(4):
            self.add_trade(
                str(x),
                "EUR_USD",
                "long",
                f"2023-02-13T12:39:1{x}.302917985Z",
            )
        self.ams.run_checks()
        self.check_violation(
            "max_open_trades_per_symbol",
            handle_violation.call_args_list,
            "close",
            self.trades[5:],  # Only close newer trades
        )

    # Position size is mocked because there is no way yet to check the balance at
    # the start of the trade.
    # TODO: Fix this when we have a way of checking the balance at the start of the
    # trade.
    # Max risk is also mocked, because the low account size would put us over the
    # limit.
    @patch(
        "core.trading.active_management.ActiveManagement.check_max_risk",
        return_value=None,
    )
    @patch(
        "core.trading.active_management.ActiveManagement.check_position_size",
        return_value=None,
    )
    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_max_loss_violated(
        self, handle_violation, check_position_size, check_max_risk
    ):
        self.balance = D("1")
        self.balance_usd = D("0.69")
        self.ams.run_checks()
        # max_loss is reported without a specific trade.
        self.check_violation(
            "max_loss",
            handle_violation.call_args_list,
            "close",
            [None],
        )

    @patch(
        "core.trading.active_management.ActiveManagement.check_position_size",
        return_value=None,
    )
    @patch(
        "core.trading.active_management.ActiveManagement.check_protection",
        return_value=None,
    )
    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_max_risk_violated(
        self, handle_violation, check_protection, check_position_size
    ):
        self.add_trade(
            "20085",
            "EUR_USD",
            "long",
            "2023-02-13T15:39:19.302917985Z",
        )
        # A far-away stop loss on a large position makes this trade too risky.
        self.trades[2]["stopLossOrder"]["price"] = "0.001"
        self.trades[2]["currentUnits"] = "13000"
        self.ams.run_checks()
        self.check_violation(
            "max_risk",
            handle_violation.call_args_list,
            "close",
            [self.trades[2]],
        )

    @patch(
        "core.trading.active_management.ActiveManagement.check_position_size",
        return_value=None,
    )
    @patch(
        "core.trading.active_management.ActiveManagement.check_protection",
        return_value=None,
    )
    @patch("core.trading.active_management.ActiveManagement.handle_violation")
    def test_max_risk_violated_multiple(
        self, handle_violation, check_protection, check_position_size
    ):
        self.add_trade(
            "20085",
            "EUR_USD",
            "long",
            "2023-02-13T15:39:19.302917985Z",
        )
        self.add_trade(
            "20086",
            "EUR_USD",
            "long",
            "2023-02-13T15:45:19.302917985Z",
        )
        self.trades[2]["stopLossOrder"]["price"] = "0.001"
        self.trades[2]["currentUnits"] = "13000"
        self.trades[3]["stopLossOrder"]["price"] = "0.001"
        self.trades[3]["currentUnits"] = "13000"
        self.ams.run_checks()
        self.check_violation(
            "max_risk",
            handle_violation.call_args_list,
            "close",
            [self.trades[2], self.trades[3]],
        )