fisk/core/tests/trading/test_active_management.py

1004 lines
35 KiB
Python
Raw Normal View History

from decimal import Decimal as D
from unittest.mock import patch
2023-02-17 17:05:52 +00:00
2023-02-17 07:20:15 +00:00
from django.test import TestCase
2023-02-17 17:05:52 +00:00
from core.lib.schemas.oanda_s import parse_time
from core.models import (
Account,
ActiveManagementPolicy,
AssetGroup,
AssetRule,
Hook,
Signal,
User,
)
2023-02-17 07:20:15 +00:00
from core.tests.helpers import StrategyMixin, SymbolPriceMock
from core.trading.active_management import ActiveManagement
2023-02-17 17:05:52 +00:00
2023-02-17 07:20:15 +00:00
class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
    """
    Tests for ``ActiveManagement`` driven by an in-memory list of fake trades.

    ``setUp`` replaces ``get_trades`` and ``get_balance`` on the instance under
    test with local fakes, so every check runs against ``self.trades``,
    ``self.balance`` and ``self.balance_usd``.
    """

    # Dotted path of the class whose methods are patched by many tests below.
    AMS_PATH = "core.trading.active_management.ActiveManagement"

    # Protection prices the tests expect for the default 1.06331 long trade.
    PROTECTION_PRICES = {
        "take_profit_price": D("1.07934"),
        "stop_loss_price": D("1.05276"),
    }

    def setUp(self):
        """Create the user, account, policy, strategy wiring and two open trades."""
        self.user = User.objects.create_user(
            username="testuser", email="test@example.com", password="test"
        )
        self.account = Account.objects.create(
            user=self.user,
            name="Test Account",
            exchange="fake",
            currency="USD",
            initial_balance=100000,
        )
        self.account.supported_symbols = ["EUR_USD", "EUR_XXX", "USD_EUR", "XXX_EUR"]
        self.account.save()
        # Called only after the user and account exist. self.strategy is not
        # assigned in this class, so it presumably comes from the mixins' setUp
        # -- confirm against core.tests.helpers.
        super().setUp()
        self.active_management_policy = ActiveManagementPolicy.objects.create(
            user=self.user,
            name="Test Policy",
            when_trading_time_violated="close",
            when_trends_violated="close",
            when_position_size_violated="close",
            when_protection_violated="close",
            when_asset_groups_violated="close",
            when_max_open_trades_violated="close",
            when_max_open_trades_per_symbol_violated="close",
            when_max_loss_violated="close",
            when_max_risk_violated="close",
            when_crossfilter_violated="close",
        )
        self.strategy.active_management_policy = self.active_management_policy
        self.strategy.save()
        self.ams = ActiveManagement(self.strategy)

        # Two baseline long trades opened on Monday 2023-02-13, a minute apart.
        self.trades = []
        self.add_trade("20083", "EUR_USD", "long", "2023-02-13T11:38:06.302917985Z")
        self.add_trade("20084", "EUR_USD", "long", "2023-02-13T11:39:06.302917985Z")

        self.balance = 100000
        self.balance_usd = 120000

        # Route the instance under test to the in-memory fakes.
        self.ams.get_trades = self.fake_get_trades
        self.ams.get_balance = self.fake_get_balance

    # ------------------------------------------------------------------ helpers

    def get_ids(self, trades):
        """Return the ``id`` of every trade in ``trades``, preserving order."""
        return [trade["id"] for trade in trades]

    def add_trade(self, id, symbol, side, open_time):
        """
        Append an open 10-unit trade at price 1.06331 to ``self.trades``.

        :param: id: trade ID (the name shadows the builtin but is kept for callers)
        :param: symbol: instrument name, e.g. ``EUR_USD``
        :param: side: ``long`` or ``short``
        :param: open_time: raw timestamp string stored under ``openTime``
        """
        trade = {
            "id": id,
            "symbol": symbol,
            "price": "1.06331",
            "openTime": open_time,
            "initialUnits": "10",
            "initialMarginRequired": "0.2966",
            "state": "OPEN",
            "currentUnits": "10",
            "realizedPL": "0.0000",
            "financing": "0.0000",
            "dividendAdjustment": "0.0000",
            "unrealizedPL": "-0.0008",
            "marginUsed": "0.2966",
            "takeProfitOrder": {"price": "1.07934"},
            "stopLossOrder": {"price": "1.05276"},
            "trailingStopLossOrder": None,
            "trailingStopValue": None,
            "side": side,
        }
        # parse_time is handed the whole trade; its result replaces the raw string.
        trade["openTime"] = parse_time(trade)
        self.trades.append(trade)

    def amend_tp_sl_flip_side(self):
        """
        Rewrite the take profit and stop loss prices of every short trade.

        The default prices are laid out for a long trade. Amending them lets the
        protection check pass for shorts, so each test sees only one violation.
        """
        for trade in self.trades:
            if trade["side"] == "short":
                trade["stopLossOrder"]["price"] = "1.07386"
                trade["takeProfitOrder"]["price"] = "1.04728"

    def fake_get_trades(self):
        """Stand-in for ``get_trades``: publish ``self.trades`` on the AMS and return it."""
        self.ams.trades = self.trades
        return self.trades

    def fake_get_balance(self, return_usd=None):
        """Stand-in for ``get_balance``: the USD balance if ``return_usd`` is truthy."""
        if return_usd:
            return self.balance_usd
        return self.balance

    def fake_get_currencies(self, symbols):
        """No-op stand-in; not referenced anywhere else in this class."""

    def create_hook_signal(self):
        """Create and return a ``trend`` signal attached to a fresh hook."""
        hook = Hook.objects.create(
            user=self.user,
            name="Test Hook",
        )
        signal = Signal.objects.create(
            user=self.user,
            name="Test Signal",
            hook=hook,
            type="trend",
        )
        return signal

    def check_violation(
        self, violation, calls, expected_action, expected_trades, expected_args=None
    ):
        """
        Check that the violation was called with the expected action and trades.
        Matches the first argument of the call to the violation name.

        :param: violation: type of the violation to check against
        :param: calls: list of calls to the violation
        :param: expected_action: expected action to be called, close, notify, etc.
        :param: expected_trades: list of expected trades to be passed to the violation
        :param: expected_args: optional, expected args to be passed to the violation
        """
        self.assertEqual(len(calls), len(expected_trades))
        violation_calls = [call for call in calls if call[0][0] == violation]
        self.assertEqual(len(violation_calls), len(expected_trades))

        # Compare by ID unless the list contains a falsy placeholder (e.g. None).
        if all(expected_trades):
            expected_trades = self.get_ids(expected_trades)

        for call in violation_calls:
            # Ensure the correct action has been called, like close
            self.assertEqual(call[0][1], expected_action)
            # Ensure the correct trade has been passed to the violation
            self.assertIn(call[0][2], expected_trades)
            if expected_args:
                _, kwargs = call
                self.assertEqual(kwargs, expected_args)

    def _assert_closed(self, check, *trade_ids, extra=None):
        """Assert the recorded actions are one ``close`` per ID for ``check``, in order."""
        expected = [
            {"id": trade_id, "check": check, "extra": dict(extra or {})}
            for trade_id in trade_ids
        ]
        self.assertEqual(self.ams.actions, {"close": expected})

    def _set_trend(self, symbol, direction):
        """Attach a trend signal to the strategy and record ``direction`` for ``symbol``."""
        signal = self.create_hook_signal()
        self.strategy.trend_signals.set([signal])
        self.strategy.trends = {symbol: direction}
        self.strategy.save()

    def _attach_asset_group(self, usd_status):
        """Give the strategy an asset group with a single USD rule of ``usd_status``."""
        asset_group = AssetGroup.objects.create(
            user=self.user,
            name="Test Asset Group",
        )
        AssetRule.objects.create(
            user=self.user,
            asset="USD",
            group=asset_group,
            status=usd_status,
        )
        self.strategy.asset_group = asset_group
        self.strategy.save()

    # -------------------------------------------------------------- fake checks

    def test_get_trades(self):
        """The patched ``get_trades`` returns the fixture list."""
        trades = self.ams.get_trades()
        self.assertEqual(trades, self.trades)

    def test_get_balance(self):
        """The patched ``get_balance`` returns the fixture balance."""
        balance = self.ams.get_balance()
        self.assertEqual(balance, self.balance)

    # --------------------------------------------------------------- violations

    def test_run_checks(self):
        """The baseline fixture produces no actions."""
        self.ams.run_checks()
        self.assertEqual(len(self.ams.actions), 0)

    def test_trading_time_violated(self):
        """A trade opened on the Friday timestamp is closed for trading time."""
        self.trades[0]["openTime"] = "2023-02-17T11:38:06.302917Z"  # Friday
        self.ams.run_checks()
        self._assert_closed("trading_time", "20083")

    def test_trends_violated(self):
        """Both long trades are closed when the EUR_USD trend is ``sell``."""
        self._set_trend("EUR_USD", "sell")
        self.amend_tp_sl_flip_side()
        self.ams.run_checks()
        self._assert_closed("trends", "20084", "20083")

    def test_trends_violated_none(self):
        """No action is taken when the EUR_USD trend is ``buy``."""
        self._set_trend("EUR_USD", "buy")
        self.ams.run_checks()
        self.assertEqual(self.ams.actions, {})

    # Mock crossfilter here since we want to allow this conflict in order to test that
    # trends only close trades that are in the wrong direction
    @patch(f"{AMS_PATH}.check_crossfilter", return_value=None)
    def test_trends_violated_partial(self, check_crossfilter):
        """Only the trade still long is closed when the trend is ``sell``."""
        self._set_trend("EUR_USD", "sell")

        # Change the side of the first trade to match the trends
        self.trades[0]["side"] = "short"
        self.amend_tp_sl_flip_side()

        self.ams.run_checks()
        self._assert_closed("trends", "20084")

    def test_position_size_violated(self):
        """An oversized trade is closed and the expected size reported as extra."""
        self.trades[0]["currentUnits"] = "100000"
        self.ams.run_checks()
        self._assert_closed("position_size", "20083", extra={"size": D("500.000")})

    def test_position_size_violated_increase_only(self):
        """Placeholder: reported as skipped until it is written."""
        self.skipTest("Not implemented yet")

    def test_position_size_violated_decrease_only(self):
        """Placeholder: reported as skipped until it is written."""
        self.skipTest("Not implemented yet")

    def test_position_size_violated_increase_decrease(self):
        """Placeholder: reported as skipped until it is written."""
        self.skipTest("Not implemented yet")

    def test_protection_violated(self):
        """Wrong TP/SL prices close the trade and report the expected prices."""
        self.trades[0]["takeProfitOrder"] = {"price": "0.0001"}
        self.trades[0]["stopLossOrder"] = {"price": "0.0001"}
        self.ams.run_checks()
        self._assert_closed("protection", "20083", extra=self.PROTECTION_PRICES)

    def test_protection_violated_absent(self):
        """Missing TP/SL orders close the trade and report the expected prices."""
        self.trades[0]["takeProfitOrder"] = None
        self.trades[0]["stopLossOrder"] = None
        self.ams.run_checks()
        self._assert_closed("protection", "20083", extra=self.PROTECTION_PRICES)

    def test_protection_violated_absent_not_required(self):
        """Missing TP/SL orders are fine when both percentages are zero."""
        self.strategy.order_settings.take_profit_percent = 0
        self.strategy.order_settings.stop_loss_percent = 0
        self.strategy.order_settings.save()
        self.trades[0]["takeProfitOrder"] = None
        self.trades[0]["stopLossOrder"] = None
        self.ams.run_checks()
        self.assertEqual(self.ams.actions, {})

    def test_asset_groups_violated(self):
        """Both long EUR_USD trades are closed when the USD rule has status 2."""
        self._attach_asset_group(2)  # Bullish
        self.ams.run_checks()
        self._assert_closed("asset_group", "20084", "20083")

    def test_asset_groups_violated_invert(self):
        """Both short EUR_USD trades are closed when the USD rule has status 3."""
        self.trades[0]["side"] = "short"
        self.trades[1]["side"] = "short"
        self.amend_tp_sl_flip_side()
        # NOTE(review): the original comment labelled 3 as "Bullish", the same
        # label as status 2 above; presumably 3 is Bearish -- confirm against the
        # AssetRule status choices.
        self._attach_asset_group(3)
        self.ams.run_checks()
        self._assert_closed("asset_group", "20084", "20083")

    def test_crossfilter_violated_side(self):
        """A short opened after a long on the same symbol is closed."""
        self.trades[1]["side"] = "short"
        self.amend_tp_sl_flip_side()
        self.ams.run_checks()
        self._assert_closed("crossfilter", "20084")

    def test_crossfilter_violated_side_multiple(self):
        """Three later shorts on EUR_USD are all closed."""
        self.add_trade("20085", "EUR_USD", "short", "2023-02-13T12:39:06.302917985Z")
        self.add_trade("20086", "EUR_USD", "short", "2023-02-13T13:39:07.302917985Z")
        self.add_trade("20087", "EUR_USD", "short", "2023-02-13T14:39:06.302917985Z")
        self.amend_tp_sl_flip_side()
        self.ams.run_checks()
        self._assert_closed("crossfilter", "20087", "20086", "20085")

    def test_crossfilter_violated_symbol(self):
        """A long on USD_EUR is closed when a long on EUR_USD is already open."""
        # Change symbol to conflict with long on EUR_USD
        self.trades[1]["symbol"] = "USD_EUR"
        self.ams.run_checks()
        self._assert_closed("crossfilter", "20084")

    def test_crossfilter_violated_symbol_multiple(self):
        """Three later longs on USD_EUR are all closed."""
        self.add_trade("20085", "USD_EUR", "long", "2023-02-13T12:39:06.302917985Z")
        self.add_trade("20086", "USD_EUR", "long", "2023-02-13T13:39:06.302917985Z")
        self.add_trade("20087", "USD_EUR", "long", "2023-02-13T14:39:06.302917985Z")
        self.ams.run_checks()
        self._assert_closed("crossfilter", "20087", "20086", "20085")

    def test_max_open_trades_violated(self):
        """With eleven open trades, only the last one added is closed."""
        for x in range(9):
            self.add_trade(
                str(x),
                f"EUR_USD{x}",  # Vary symbol to prevent max open trades per symbol
                "long",
                f"2023-02-13T12:39:1{x}.302917985Z",
            )
        self.ams.run_checks()
        self._assert_closed("max_open_trades", "8")

    def test_max_open_trades_per_symbol_violated(self):
        """With six trades on one symbol, only the last one added is closed."""
        for x in range(4):
            self.add_trade(
                str(x),
                "EUR_USD",
                "long",
                f"2023-02-13T12:39:1{x}.302917985Z",
            )
        self.ams.run_checks()
        self._assert_closed("max_open_trades_per_symbol", "3")

    # Mock position size as we have no way of checking the balance at the start of the
    # trade.
    # TODO: Fix this when we have a way of checking the balance at the start of the
    # trade.
    @patch(f"{AMS_PATH}.check_position_size", return_value=None)
    def test_max_loss_violated(self, check_position_size):
        """A collapsed balance with no trades yields a close action with no ID."""
        self.balance = D("1")
        self.balance_usd = D("0.69")
        self.trades = []
        self.ams.run_checks()
        self._assert_closed("max_loss", None)

    @patch(f"{AMS_PATH}.check_position_size", return_value=None)
    @patch(f"{AMS_PATH}.check_protection", return_value=None)
    def test_max_risk_violated(self, check_protection, check_position_size):
        """A large trade with a far-away stop loss is closed for max risk."""
        self.add_trade(
            "20085",
            "EUR_USD",
            "long",
            "2023-02-13T15:39:19.302917985Z",
        )
        self.trades[2]["stopLossOrder"]["price"] = "0.001"
        self.trades[2]["currentUnits"] = "13000"
        self.ams.run_checks()
        self._assert_closed("max_risk", "20085")

    @patch(f"{AMS_PATH}.check_position_size", return_value=None)
    @patch(f"{AMS_PATH}.check_protection", return_value=None)
    def test_max_risk_violated_multiple(self, check_protection, check_position_size):
        """Two large trades with far-away stop losses are both closed."""
        self.add_trade(
            "20085",
            "EUR_USD",
            "long",
            "2023-02-13T15:39:19.302917985Z",
        )
        self.add_trade(
            "20086",
            "EUR_USD",
            "long",
            "2023-02-13T15:45:19.302917985Z",
        )
        for risky in (self.trades[2], self.trades[3]):
            risky["stopLossOrder"]["price"] = "0.001"
            risky["currentUnits"] = "13000"
        self.ams.run_checks()
        self._assert_closed("max_risk", "20086", "20085")

    # ------------------------------------------------------------------ actions

    @patch(f"{AMS_PATH}.add_action")
    def test_handle_violation(self, add_action):
        """``handle_violation`` forwards to ``add_action`` with action first."""
        self.ams.handle_violation("max_loss", "close", "trade_id")
        self.ams.handle_violation("position_size", "adjust", "trade_id2", size=1000)

        self.assertEqual(add_action.call_count, 2)
        add_action.assert_any_call("close", "max_loss", "trade_id")
        add_action.assert_any_call("adjust", "position_size", "trade_id2", size=1000)

    def test_add_action(self):
        """Actions are grouped by action name; keyword args land under ``extra``."""
        protection_args = dict(self.PROTECTION_PRICES)
        self.ams.add_action("close", "trading_time", "fake_trade_id")
        self.ams.add_action("close", "protection", "fake_trade_id2", **protection_args)

        self.assertEqual(
            self.ams.actions,
            {
                "close": [
                    {"id": "fake_trade_id", "check": "trading_time", "extra": {}},
                    {
                        "id": "fake_trade_id2",
                        "check": "protection",
                        "extra": protection_args,
                    },
                ],
            },
        )

    def test_reduce_actions(self):
        """
        Test that closing actions precede adjusting actions.
        """
        self.ams.add_action("close", "trading_time", "fake_trade_id")
        self.ams.add_action("adjust", "position_size", "fake_trade_id", size=1000)
        self.assertEqual(len(self.ams.actions["close"]), 1)
        self.assertEqual(len(self.ams.actions["adjust"]), 1)

        self.ams.reduce_actions()
        self.assertEqual(len(self.ams.actions["close"]), 1)
        self.assertEqual(len(self.ams.actions["adjust"]), 0)

    @patch(f"{AMS_PATH}.bulk_close_trades")
    @patch(f"{AMS_PATH}.bulk_notify")
    def test_run_actions(self, bulk_notify, bulk_close_trades):
        """Close actions trigger one notification and one bulk close by ID."""
        protection_args = dict(self.PROTECTION_PRICES)
        self.ams.add_action("close", "trading_time", "fake_trade_id")
        self.ams.add_action("close", "protection", "fake_trade_id2", **protection_args)
        self.ams.run_actions()

        expected_action_cast = [
            {"id": "fake_trade_id", "check": "trading_time", "extra": {}},
            {"id": "fake_trade_id2", "check": "protection", "extra": protection_args},
        ]
        bulk_notify.assert_called_once_with("close", expected_action_cast)
        bulk_close_trades.assert_called_once_with(["fake_trade_id", "fake_trade_id2"])

    @patch(f"{AMS_PATH}.bulk_close_trades")
    @patch(f"{AMS_PATH}.bulk_notify")
    def test_run_actions_notify_only(self, bulk_notify, bulk_close_trades):
        """Notify actions trigger a notification but never a close."""
        protection_args = dict(self.PROTECTION_PRICES)
        self.ams.add_action("notify", "trading_time", "fake_trade_id")
        self.ams.add_action("notify", "protection", "fake_trade_id2", **protection_args)
        self.ams.run_actions()

        expected_action_cast = [
            {"id": "fake_trade_id", "check": "trading_time", "extra": {}},
            {"id": "fake_trade_id2", "check": "protection", "extra": protection_args},
        ]
        bulk_notify.assert_called_once_with("notify", expected_action_cast)
        bulk_close_trades.assert_not_called()

    @patch(f"{AMS_PATH}.bulk_close_trades")
    @patch(f"{AMS_PATH}.bulk_adjust")
    @patch(f"{AMS_PATH}.bulk_notify")
    def test_run_actions_notify_adjust_only(
        self, bulk_notify, bulk_adjust, bulk_close_trades
    ):
        """Adjust actions trigger a notification and a bulk adjust, never a close."""
        protection_args = dict(self.PROTECTION_PRICES)
        self.ams.add_action("adjust", "position_size", "fake_trade_id", size=1000)
        self.ams.add_action("adjust", "protection", "fake_trade_id2", **protection_args)
        self.ams.run_actions()

        expected_action_cast = [
            {"id": "fake_trade_id", "check": "position_size", "extra": {"size": 1000}},
            {"id": "fake_trade_id2", "check": "protection", "extra": protection_args},
        ]
        bulk_notify.assert_called_once_with("adjust", expected_action_cast)
        bulk_adjust.assert_called_once_with(expected_action_cast)
        bulk_close_trades.assert_not_called()

    @patch(f"{AMS_PATH}.adjust_position_size")
    @patch(f"{AMS_PATH}.adjust_protection")
    def test_bulk_adjust(self, adjust_protection, adjust_position_size):
        """``bulk_adjust`` dispatches each entry to the adjuster for its check."""
        expected_protection = {"take_profit_price": 1.10, "stop_loss_price": 1.05}
        cast_list = [
            {"id": "id1", "check": "position_size", "extra": {"size": 1000}},
            {"id": "id2", "check": "protection", "extra": expected_protection},
        ]
        self.ams.bulk_adjust(cast_list)
        adjust_position_size.assert_called_once_with("id1", 1000)
        adjust_protection.assert_called_once_with("id2", expected_protection)

    @patch(f"{AMS_PATH}.close_trade")
    def test_bulk_close_trades(self, close_trade):
        """``bulk_close_trades`` calls ``close_trade`` once per ID."""
        self.ams.bulk_close_trades(["id1", "id2"])
        self.assertEqual(close_trade.call_count, 2)
        close_trade.assert_any_call("id1")
        close_trade.assert_any_call("id2")

    @patch("core.trading.active_management.sendmsg")
    def test_bulk_notify_plain(self, sendmsg):
        """A notification without extras contains only the action and violation."""
        self.ams.bulk_notify("close", [{"id": "id1", "check": "check1", "extra": {}}])
        sendmsg.assert_called_once_with(
            self.user,
            "ACTION: 'close' on trade ID 'id1'\nVIOLATION: 'check1'\n=========\n",
            title="AMS: close",
        )

    @patch("core.trading.active_management.sendmsg")
    def test_bulk_notify_extra(self, sendmsg):
        """A notification with extras appends them after the violation."""
        self.ams.bulk_notify(
            "close", [{"id": "id1", "check": "check1", "extra": {"field1": "value1"}}]
        )
        sendmsg.assert_called_once_with(
            self.user,
            (
                "ACTION: 'close' on trade ID 'id1'\nVIOLATION: 'check1'\nEXTRA:"
                " field1: value1\n=========\n"
            ),
            title="AMS: close",
        )

    # -------------------------------------------------------------- adjustments

    @patch(f"{AMS_PATH}.check_protection")
    def test_position_size_reduced(self, check_protection):
        """In ``adjust`` mode the oversized trade reaches the next check resized."""
        self.active_management_policy.when_position_size_violated = "adjust"
        self.active_management_policy.save()
        self.trades[0]["currentUnits"] = "100000"
        self.ams.run_checks()

        call_args = check_protection.call_args[0][0]
        self.assertEqual(call_args["amount"], 500)
        self.assertEqual(call_args["units"], 500)

    @patch(f"{AMS_PATH}.check_asset_groups")
    def test_protection_added(self, check_asset_groups):
        """In ``adjust`` mode missing TP/SL reach the next check filled in."""
        self.active_management_policy.when_protection_violated = "adjust"
        self.active_management_policy.save()
        self.trades[0]["takeProfitOrder"] = None
        self.trades[0]["stopLossOrder"] = None
        self.ams.run_checks()

        call_args = check_asset_groups.call_args[0][0]
        self.assertEqual(call_args["take_profit_price"], D("1.07934"))
        self.assertEqual(call_args["stop_loss_price"], D("1.05276"))

    @patch(f"{AMS_PATH}.check_asset_groups")
    def test_protection_amended(self, check_asset_groups):
        """In ``adjust`` mode wrong TP/SL reach the next check corrected."""
        self.active_management_policy.when_protection_violated = "adjust"
        self.active_management_policy.save()
        self.trades[0]["takeProfitOrder"] = {"price": "0.0001"}
        self.trades[0]["stopLossOrder"] = {"price": "0.0001"}
        self.ams.run_checks()

        call_args = check_asset_groups.call_args[0][0]
        self.assertEqual(call_args["take_profit_price"], D("1.07934"))
        self.assertEqual(call_args["stop_loss_price"], D("1.05276"))

    @patch(f"{AMS_PATH}.close_trade")
    def test_max_risk_not_violated_after_adjusting_protection(self, close_trade):
        """
        Ensure the max risk check is not violated after adjusting the protection.
        """
        self.active_management_policy.when_protection_violated = "adjust"
        self.active_management_policy.save()
        self.add_trade(
            "20085",
            "EUR_USD",
            "long",
            "2023-02-13T15:39:19.302917985Z",
        )
        self.trades[2]["stopLossOrder"]["price"] = "0.001"
        self.trades[2]["currentUnits"] = "13000"
        self.ams.run_checks()
        self.assertEqual(close_trade.call_count, 0)

    @patch(f"{AMS_PATH}.close_trade")
    def test_max_risk_not_violated_after_adjusting_position_size(self, close_trade):
        """
        Ensure the max risk check is not violated after adjusting the position size.
        """
        self.active_management_policy.when_position_size_violated = "adjust"
        self.active_management_policy.save()
        self.trades[0]["currentUnits"] = "100000"
        self.ams.run_checks()
        self.assertEqual(close_trade.call_count, 0)

    # ---------------------------------------------------------------- mutations
    # Each test below violates one check and patches a later one, to verify that
    # the violating trade is no longer handed to the checks that follow.

    @patch(f"{AMS_PATH}.check_crossfilter")
    @patch(f"{AMS_PATH}.check_trends")
    def test_trading_time_mutation(self, check_trends, check_crossfilter):
        """A trade failing trading time is not passed to trends or crossfilter."""
        self.trades[0]["openTime"] = "2023-02-17T11:38:06.302917Z"  # Friday
        self.ams.run_checks()

        # Only run trends on the second trade
        self.assertEqual(check_trends.call_count, 1)
        call_args = check_trends.call_args[0][0]
        self.assertEqual(call_args["id"], "20084")

        # Same for crossfilter
        crossfilter_call_args = check_crossfilter.call_args[0][0]
        self.assertEqual(len(crossfilter_call_args), 1)
        self.assertEqual(crossfilter_call_args[0]["id"], "20084")

    @patch(f"{AMS_PATH}.check_crossfilter")
    @patch(f"{AMS_PATH}.check_position_size")
    def test_check_trends_mutation(self, check_position_size, check_crossfilter):
        """Trades failing trends are not passed to position size or crossfilter."""
        self._set_trend("EUR_USD", "sell")
        self.amend_tp_sl_flip_side()
        self.ams.run_checks()

        self.assertEqual(check_position_size.call_count, 0)
        crossfilter_call_args = check_crossfilter.call_args[0][0]
        self.assertEqual(len(crossfilter_call_args), 0)

    @patch(f"{AMS_PATH}.check_crossfilter")
    @patch(f"{AMS_PATH}.check_protection")
    def test_check_position_size_mutation(self, check_protection, check_crossfilter):
        """A trade failing position size is not passed to protection or crossfilter."""
        self.trades[0]["currentUnits"] = "100000"
        self.ams.run_checks()

        # Only run protection on the second trade
        self.assertEqual(check_protection.call_count, 1)
        call_args = check_protection.call_args[0][0]
        self.assertEqual(call_args["id"], "20084")

        # Same for crossfilter
        crossfilter_call_args = check_crossfilter.call_args[0][0]
        self.assertEqual(len(crossfilter_call_args), 1)
        self.assertEqual(crossfilter_call_args[0]["id"], "20084")

    @patch(f"{AMS_PATH}.check_crossfilter")
    @patch(f"{AMS_PATH}.check_protection")
    def test_check_protection_mutation(self, check_protection, check_crossfilter):
        """
        Currently repeats the position size mutation scenario.

        NOTE(review): this body is identical to test_check_position_size_mutation.
        It oversizes trade 20083 and patches check_protection itself, so a trade
        removed by the protection check is never exercised. Presumably it should
        break the TP/SL of trade 20083 and patch the check that follows
        protection instead -- confirm against run_checks before changing.
        """
        self.trades[0]["currentUnits"] = "100000"
        self.ams.run_checks()

        # Only run protection on the second trade
        self.assertEqual(check_protection.call_count, 1)
        call_args = check_protection.call_args[0][0]
        self.assertEqual(call_args["id"], "20084")

        # Same for crossfilter
        crossfilter_call_args = check_crossfilter.call_args[0][0]
        self.assertEqual(len(crossfilter_call_args), 1)
        self.assertEqual(crossfilter_call_args[0]["id"], "20084")

    # This may look similar but check_crossfilter is called with the whole trade list.
    # Check that the trade that is removed from the list is not checked.
    @patch(f"{AMS_PATH}.check_crossfilter")
    def test_check_asset_groups_mutation(self, check_crossfilter):
        """Trades failing asset groups leave an empty list for crossfilter."""
        self._attach_asset_group(2)  # Bullish
        self.ams.run_checks()
        check_crossfilter.assert_called_once_with([])

    @patch(f"{AMS_PATH}.check_max_open_trades")
    def test_check_crossfilter_mutation(self, check_max_open_trades):
        """A trade failing crossfilter is not passed to max open trades."""
        self.trades[1]["side"] = "short"
        self.amend_tp_sl_flip_side()
        self.ams.run_checks()

        self.assertEqual(check_max_open_trades.call_count, 1)
        call_args = check_max_open_trades.call_args[0][0]
        self.assertEqual(len(call_args), 1)
        self.assertEqual(call_args[0]["id"], "20083")

    @patch(f"{AMS_PATH}.check_max_open_trades_per_symbol")
    def test_check_max_open_trades_mutation(self, check_max_open_trades_per_symbol):
        """The trade failing max open trades is not passed to the per-symbol check."""
        for x in range(9):
            self.add_trade(
                str(x),
                f"EUR_USD{x}",  # Vary symbol to prevent max open trades per symbol
                "long",
                f"2023-02-13T12:39:1{x}.302917985Z",
            )
        self.ams.run_checks()

        self.assertEqual(check_max_open_trades_per_symbol.call_count, 1)
        call_args = check_max_open_trades_per_symbol.call_args[0][0]
        called_with_ids = [x["id"] for x in call_args]
        self.assertListEqual(
            called_with_ids, ["20083", "20084", "0", "1", "2", "3", "4", "5", "6", "7"]
        )

    @patch(f"{AMS_PATH}.check_max_loss")
    def test_check_max_open_trades_per_symbol_mutation(self, check_max_loss):
        """The trade failing the per-symbol limit is not passed to max loss."""
        for x in range(4):
            self.add_trade(
                str(x),
                "EUR_USD",
                "long",
                f"2023-02-13T12:39:1{x}.302917985Z",
            )
        self.ams.run_checks()

        self.assertEqual(check_max_loss.call_count, 1)
        call_args = check_max_loss.call_args[0][0]
        called_with_ids = [x["id"] for x in call_args]
        self.assertListEqual(called_with_ids, ["20083", "20084", "0", "1", "2"])

    @patch(f"{AMS_PATH}.check_max_risk")
    def test_check_max_loss_mutation(self, check_max_risk):
        """With no trades and a collapsed balance, max risk gets an empty list."""
        self.balance = D("1")
        self.balance_usd = D("0.69")
        self.trades = []
        self.ams.run_checks()

        self.assertEqual(check_max_risk.call_count, 1)
        call_args = check_max_risk.call_args[0][0]
        called_with_ids = [x["id"] for x in call_args]
        self.assertListEqual(called_with_ids, [])

    def test_check_max_risk_mutation(self):
        """
        This cannot be tested as there are no hooks after it.
        """
        self.skipTest("No check runs after max_risk, so there is nothing to observe")