Finish implementing active management hooks

This commit is contained in:
2023-02-18 11:54:30 +00:00
parent 3e35214e82
commit 466b17400f
9 changed files with 271 additions and 123 deletions

View File

@@ -1,17 +1,9 @@
from datetime import time
from decimal import Decimal as D
from os import getenv
from unittest.mock import Mock, patch
from core.models import (
Account,
Hook,
OrderSettings,
RiskModel,
Signal,
Strategy,
TradingTime,
User,
)
from core.models import Account, OrderSettings, RiskModel, Strategy, TradingTime, User
# Create patch mixin to mock out the Elastic client
@@ -48,7 +40,7 @@ class SymbolPriceMock:
cls.patcher = patch("core.exchanges.common.get_symbol_price")
patcher = cls.patcher.start()
patcher.return_value = 1
patcher.return_value = D(1)
@classmethod
def tearDownClass(cls):
@@ -141,7 +133,7 @@ class StrategyMixin:
max_loss_percent=50,
max_risk_percent=10,
max_open_trades=10,
max_open_trades_per_symbol=2,
max_open_trades_per_symbol=5,
)
self.strategy = Strategy.objects.create(

View File

@@ -1,5 +1,5 @@
from decimal import Decimal as D
from unittest.mock import Mock, patch
from unittest.mock import patch
from django.test import TestCase
@@ -28,6 +28,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
name="Test Account",
exchange="fake",
currency="USD",
initial_balance=100000,
)
self.account.supported_symbols = ["EUR_USD", "EUR_XXX", "USD_EUR", "XXX_EUR"]
self.account.save()
@@ -75,7 +76,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
"id": "20084",
"symbol": "EUR_USD",
"price": "1.06331",
"openTime": "2023-02-13T11:39:06.302917985Z", # Monday at 11:38
"openTime": "2023-02-13T11:39:06.302917985Z", # Monday at 11:39
"initialUnits": "10",
"initialMarginRequired": "0.2966",
"state": "OPEN",
@@ -95,6 +96,9 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
# Run parse_time on all items in trades
for trade in self.trades:
trade["openTime"] = parse_time(trade)
self.balance = 100000
self.balance_usd = 120000
self.ams.get_trades = self.fake_get_trades
self.ams.get_balance = self.fake_get_balance
# self.ams.trades = self.trades
@@ -123,12 +127,25 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
trade["openTime"] = parse_time(trade)
self.trades.append(trade)
def amend_tp_sl_flip_side(self):
"""
Amend the take profit and stop loss orders to be the opposite side.
This lets the protection tests pass, so we can only test one violation
per test.
"""
for trade in self.trades:
if trade["side"] == "short":
trade["stopLossOrder"]["price"] = "1.07386"
trade["takeProfitOrder"]["price"] = "1.04728"
def fake_get_trades(self):
self.ams.trades = self.trades
return self.trades
def fake_get_balance(self):
return 10000
def fake_get_balance(self, return_usd=None):
if return_usd:
return self.balance_usd
return self.balance
def fake_get_currencies(self, symbols):
pass
@@ -139,7 +156,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
def test_get_balance(self):
balance = self.ams.get_balance()
self.assertEqual(balance, 10000)
self.assertEqual(balance, self.balance)
def check_violation(
self, violation, calls, expected_action, expected_trades, expected_args=None
@@ -153,6 +170,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
:param: expected_trades: list of expected trades to be passed to the violation
:param: expected_args: optional, expected args to be passed to the violation
"""
self.assertEqual(len(calls), len(expected_trades))
calls = list(calls)
violation_calls = []
for call in calls:
@@ -160,11 +178,18 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
violation_calls.append(call)
self.assertEqual(len(violation_calls), len(expected_trades))
expected_trades = convert_trades(expected_trades)
if all(expected_trades):
expected_trades = convert_trades(expected_trades)
for call in violation_calls:
# Ensure the correct action has been called, like close
self.assertEqual(call[0][1], expected_action)
# Ensure the correct trade has been passed to the violation
trade = call[0][2]
if trade:
for field in list(trade.keys()):
if "_usd" in field:
if field in trade.keys():
del trade[field]
self.assertIn(call[0][2], expected_trades)
if expected_args:
self.assertEqual(call[0][3], expected_args)
@@ -172,7 +197,6 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_run_checks(self, handle_violation):
self.ams.run_checks()
print("handle_violation.call_count", handle_violation.call_args_list)
self.assertEqual(handle_violation.call_count, 0)
@patch("core.trading.active_management.ActiveManagement.handle_violation")
@@ -201,7 +225,9 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
signal = self.create_hook_signal()
self.strategy.trend_signals.set([signal])
self.strategy.trends = {"EUR_USD": "sell"}
self.amend_tp_sl_flip_side()
self.strategy.save()
self.ams.run_checks()
self.check_violation(
"trends", handle_violation.call_args_list, "close", self.trades
@@ -216,8 +242,14 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
self.ams.run_checks()
self.check_violation("trends", handle_violation.call_args_list, "close", [])
# Mock crossfilter here since we want to allow this conflict in order to test that
# trends only close trades that are in the wrong direction
@patch(
"core.trading.active_management.ActiveManagement.check_crossfilter",
return_value=None,
)
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_trends_violated_partial(self, handle_violation):
def test_trends_violated_partial(self, handle_violation, check_crossfilter):
signal = self.create_hook_signal()
self.strategy.trend_signals.set([signal])
self.strategy.trends = {"EUR_USD": "sell"}
@@ -225,6 +257,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
# Change the side of the first trade to match the trends
self.trades[0]["side"] = "short"
self.amend_tp_sl_flip_side()
self.ams.run_checks()
self.check_violation(
@@ -241,7 +274,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
handle_violation.call_args_list,
"close",
[self.trades[0]],
{"size": 50},
{"size": 500},
)
@patch("core.trading.active_management.ActiveManagement.handle_violation")
@@ -270,7 +303,6 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
self.trades[0]["takeProfitOrder"] = None
self.trades[0]["stopLossOrder"] = None
self.ams.run_checks()
print("CALLS", handle_violation.call_args_list)
self.assertEqual(handle_violation.call_count, 0)
@@ -301,6 +333,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
def test_asset_groups_violated_invert(self, handle_violation):
self.trades[0]["side"] = "short"
self.trades[1]["side"] = "short"
self.amend_tp_sl_flip_side()
asset_group = AssetGroup.objects.create(
user=self.user,
name="Test Asset Group",
@@ -325,6 +358,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_crossfilter_violated_side(self, handle_violation):
self.trades[1]["side"] = "short"
self.amend_tp_sl_flip_side()
self.ams.run_checks()
self.check_violation(
@@ -336,16 +370,19 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_crossfilter_violated_side_multiple(self, handle_violation):
self.add_trade("20085", "EUR_USD", "short", "2023-02-13T12:39:06.302917985Z")
self.add_trade("20086", "EUR_USD", "short", "2023-02-14T12:39:06.302917985Z")
self.add_trade("20087", "EUR_USD", "short", "2023-02-10T12:39:06.302917985Z")
self.add_trade(
"20085", "EUR_USD", "short", "2023-02-13T12:39:06.302917985Z"
) # 2:
self.add_trade("20086", "EUR_USD", "short", "2023-02-13T13:39:07.302917985Z")
self.add_trade("20087", "EUR_USD", "short", "2023-02-13T14:39:06.302917985Z")
self.amend_tp_sl_flip_side()
self.ams.run_checks()
self.check_violation(
"crossfilter",
handle_violation.call_args_list,
"close",
self.trades[0:4], # Only close newer trades
self.trades[2:], # Only close newer trades
)
@patch("core.trading.active_management.ActiveManagement.handle_violation")
@@ -363,16 +400,18 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_crossfilter_violated_symbol_multiple(self, handle_violation):
self.add_trade("20085", "USD_EUR", "long", "2023-02-13T12:39:06.302917985Z")
self.add_trade("20086", "USD_EUR", "long", "2023-02-14T12:39:06.302917985Z")
self.add_trade("20087", "USD_EUR", "long", "2023-02-10T12:39:06.302917985Z")
self.add_trade(
"20085", "USD_EUR", "long", "2023-02-13T12:39:06.302917985Z"
) # 2:
self.add_trade("20086", "USD_EUR", "long", "2023-02-13T13:39:06.302917985Z")
self.add_trade("20087", "USD_EUR", "long", "2023-02-13T14:39:06.302917985Z")
self.ams.run_checks()
self.check_violation(
"crossfilter",
handle_violation.call_args_list,
"close",
self.trades[0:4], # Only close newer trades
self.trades[2:], # Only close newer trades
)
@patch("core.trading.active_management.ActiveManagement.handle_violation")
@@ -380,7 +419,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
for x in range(9):
self.add_trade(
str(x),
"EUR_USD",
f"EUR_USD{x}", # Vary symbol to prevent max open trades per symbol
"long",
f"2023-02-13T12:39:1{x}.302917985Z",
)
@@ -395,7 +434,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_max_open_trades_per_symbol_violated(self, handle_violation):
for x in range(2):
for x in range(4):
self.add_trade(
str(x),
"EUR_USD",
@@ -408,11 +447,102 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
"max_open_trades_per_symbol",
handle_violation.call_args_list,
"close",
self.trades[2:], # Only close newer trades
self.trades[5:], # Only close newer trades
)
def test_max_loss_violated(self):
pass
# Mock position size as we have no way of checking the balance at the start of the
# trade.
# TODO: Fix this when we have a way of checking the balance at the start of the
# trade.
# Max risk is also mocked as this puts us over the limit, due to the low account
# size.
@patch(
"core.trading.active_management.ActiveManagement.check_max_risk",
return_value=None,
)
@patch(
"core.trading.active_management.ActiveManagement.check_position_size",
return_value=None,
)
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_max_loss_violated(
self, handle_violation, check_position_size, check_max_risk
):
self.balance = D("1")
self.balance_usd = D("0.69")
self.ams.run_checks()
def test_max_risk_violated(self):
pass
self.check_violation(
"max_loss",
handle_violation.call_args_list,
"close",
[None],
)
@patch(
"core.trading.active_management.ActiveManagement.check_position_size",
return_value=None,
)
@patch(
"core.trading.active_management.ActiveManagement.check_protection",
return_value=None,
)
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_max_risk_violated(
self, handle_violation, check_protection, check_position_size
):
self.add_trade(
"20085",
"EUR_USD",
"long",
"2023-02-13T15:39:19.302917985Z",
)
self.trades[2]["stopLossOrder"]["price"] = "0.001"
self.trades[2]["currentUnits"] = "13000"
self.ams.run_checks()
self.check_violation(
"max_risk",
handle_violation.call_args_list,
"close",
[self.trades[2]],
)
@patch(
"core.trading.active_management.ActiveManagement.check_position_size",
return_value=None,
)
@patch(
"core.trading.active_management.ActiveManagement.check_protection",
return_value=None,
)
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_max_risk_violated_multiple(
self, handle_violation, check_protection, check_position_size
):
self.add_trade(
"20085",
"EUR_USD",
"long",
"2023-02-13T15:39:19.302917985Z",
)
self.add_trade(
"20086",
"EUR_USD",
"long",
"2023-02-13T15:45:19.302917985Z",
)
self.trades[2]["stopLossOrder"]["price"] = "0.001"
self.trades[2]["currentUnits"] = "13000"
self.trades[3]["stopLossOrder"]["price"] = "0.001"
self.trades[3]["currentUnits"] = "13000"
self.ams.run_checks()
self.check_violation(
"max_risk",
handle_violation.call_args_list,
"close",
[self.trades[2], self.trades[3]],
)

View File

@@ -1,18 +1,7 @@
from datetime import time
import freezegun
from django.test import TestCase
from core.models import (
Account,
Hook,
OrderSettings,
RiskModel,
Signal,
Strategy,
TradingTime,
User,
)
from core.models import Account, Hook, Signal, User
from core.tests.helpers import StrategyMixin
from core.trading import checks

View File

@@ -1,9 +1,8 @@
from django.test import TestCase
import core.trading.market # noqa # to avoid messy circular import
from core.exchanges import convert
from core.models import RiskModel, User
import core.trading.market # to avoid messy circular import
from core.trading import risk