Simplify active management by only specifying trade IDs for violations

This commit is contained in:
Mark Veidemanis 2023-02-18 14:36:58 +00:00
parent 466b17400f
commit 15a8bec105
Signed by: m
GPG Key ID: 5ACFCEED46C0904F
2 changed files with 51 additions and 26 deletions

View File

@ -3,7 +3,6 @@ from unittest.mock import patch
from django.test import TestCase from django.test import TestCase
from core.exchanges.convert import convert_trades
from core.lib.schemas.oanda_s import parse_time from core.lib.schemas.oanda_s import parse_time
from core.models import ( from core.models import (
Account, Account,
@ -103,6 +102,9 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
self.ams.get_balance = self.fake_get_balance self.ams.get_balance = self.fake_get_balance
# self.ams.trades = self.trades # self.ams.trades = self.trades
def get_ids(self, trades):
return [trade["id"] for trade in trades]
def add_trade(self, id, symbol, side, open_time): def add_trade(self, id, symbol, side, open_time):
trade = { trade = {
"id": id, "id": id,
@ -179,17 +181,12 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
self.assertEqual(len(violation_calls), len(expected_trades)) self.assertEqual(len(violation_calls), len(expected_trades))
if all(expected_trades): if all(expected_trades):
expected_trades = convert_trades(expected_trades) # expected_trades = convert_trades(expected_trades)
expected_trades = self.get_ids(expected_trades)
for call in violation_calls: for call in violation_calls:
# Ensure the correct action has been called, like close # Ensure the correct action has been called, like close
self.assertEqual(call[0][1], expected_action) self.assertEqual(call[0][1], expected_action)
# Ensure the correct trade has been passed to the violation # Ensure the correct trade has been passed to the violation
trade = call[0][2]
if trade:
for field in list(trade.keys()):
if "_usd" in field:
if field in trade.keys():
del trade[field]
self.assertIn(call[0][2], expected_trades) self.assertIn(call[0][2], expected_trades)
if expected_args: if expected_args:
self.assertEqual(call[0][3], expected_args) self.assertEqual(call[0][3], expected_args)
@ -454,22 +451,18 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
# trade. # trade.
# TODO: Fix this when we have a way of checking the balance at the start of the # TODO: Fix this when we have a way of checking the balance at the start of the
# trade. # trade.
# Max risk is also mocked as this puts us over the limit, due to the low account
# size.
@patch(
"core.trading.active_management.ActiveManagement.check_max_risk",
return_value=None,
)
@patch( @patch(
"core.trading.active_management.ActiveManagement.check_position_size", "core.trading.active_management.ActiveManagement.check_position_size",
return_value=None, return_value=None,
) )
@patch("core.trading.active_management.ActiveManagement.handle_violation") @patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_max_loss_violated( def test_max_loss_violated(self, handle_violation, check_position_size):
self, handle_violation, check_position_size, check_max_risk
):
self.balance = D("1") self.balance = D("1")
self.balance_usd = D("0.69") self.balance_usd = D("0.69")
self.trades = []
self.ams.run_checks() self.ams.run_checks()
self.check_violation( self.check_violation(
@ -546,3 +539,15 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
"close", "close",
[self.trades[2], self.trades[3]], [self.trades[2], self.trades[3]],
) )
def test_max_risk_not_violated_after_adjusting_protection(self):
"""
Ensure the max risk check is not violated after adjusting the protection.
"""
pass
def test_max_risk_not_violated_after_adjusting_position_size(self):
"""
Ensure the max risk check is not violated after adjusting the position size.
"""
pass

View File

@ -44,6 +44,21 @@ class ActiveManagement(object):
def handle_violation(self, check_type, action, trade, **kwargs): def handle_violation(self, check_type, action, trade, **kwargs):
print("VIOLATION", check_type, action, trade, kwargs) print("VIOLATION", check_type, action, trade, kwargs)
# TODO: close/notify for:
# - trading time
# - trends
# - position size
# - protection
# - asset groups
# - crossfilter
# - max open trades
# - max open trades per symbol
# - max loss
# - max risk
# TODO: adjust for:
# - position size
# - protection
def check_trading_time(self, trade): def check_trading_time(self, trade):
open_ts = trade["open_time"] open_ts = trade["open_time"]
@ -51,7 +66,7 @@ class ActiveManagement(object):
trading_time_pass = checks.within_trading_times(self.strategy, open_ts_as_date) trading_time_pass = checks.within_trading_times(self.strategy, open_ts_as_date)
if not trading_time_pass: if not trading_time_pass:
self.handle_violation( self.handle_violation(
"trading_time", self.policy.when_trading_time_violated, trade "trading_time", self.policy.when_trading_time_violated, trade["id"]
) )
def check_trends(self, trade): def check_trends(self, trade):
@ -59,7 +74,9 @@ class ActiveManagement(object):
symbol = trade["symbol"] symbol = trade["symbol"]
trends_pass = checks.within_trends(self.strategy, symbol, direction) trends_pass = checks.within_trends(self.strategy, symbol, direction)
if not trends_pass: if not trends_pass:
self.handle_violation("trends", self.policy.when_trends_violated, trade) self.handle_violation(
"trends", self.policy.when_trends_violated, trade["id"]
)
def check_position_size(self, trade): def check_position_size(self, trade):
""" """
@ -91,7 +108,7 @@ class ActiveManagement(object):
self.handle_violation( self.handle_violation(
"position_size", "position_size",
self.policy.when_position_size_violated, self.policy.when_position_size_violated,
trade, trade["id"],
{"size": expected_trade_size}, {"size": expected_trade_size},
) )
@ -152,7 +169,10 @@ class ActiveManagement(object):
if violations: if violations:
self.handle_violation( self.handle_violation(
"protection", self.policy.when_protection_violated, trade, violations "protection",
self.policy.when_protection_violated,
trade["id"],
violations
) )
def check_asset_groups(self, trade): def check_asset_groups(self, trade):
@ -165,7 +185,7 @@ class ActiveManagement(object):
) )
if not allowed: if not allowed:
self.handle_violation( self.handle_violation(
"asset_group", self.policy.when_asset_groups_violated, trade "asset_group", self.policy.when_asset_groups_violated, trade["id"]
) )
def get_sorted_trades_copy(self, trades, reverse=True): def get_sorted_trades_copy(self, trades, reverse=True):
@ -226,7 +246,7 @@ class ActiveManagement(object):
if close_trades: if close_trades:
for trade in close_trades: for trade in close_trades:
self.handle_violation( self.handle_violation(
"crossfilter", self.policy.when_crossfilter_violated, trade "crossfilter", self.policy.when_crossfilter_violated, trade["id"]
) )
def check_max_open_trades(self, trades): def check_max_open_trades(self, trades):
@ -241,7 +261,7 @@ class ActiveManagement(object):
self.handle_violation( self.handle_violation(
"max_open_trades", "max_open_trades",
self.policy.when_max_open_trades_violated, self.policy.when_max_open_trades_violated,
trade, trade["id"],
) )
def check_max_open_trades_per_symbol(self, trades): def check_max_open_trades_per_symbol(self, trades):
@ -267,7 +287,7 @@ class ActiveManagement(object):
self.handle_violation( self.handle_violation(
"max_open_trades_per_symbol", "max_open_trades_per_symbol",
self.policy.when_max_open_trades_violated, self.policy.when_max_open_trades_violated,
trade, trade["id"],
) )
def check_max_loss(self): def check_max_loss(self):
@ -312,7 +332,7 @@ class ActiveManagement(object):
if close_trades: if close_trades:
for trade in close_trades: for trade in close_trades:
self.handle_violation( self.handle_violation(
"max_risk", self.policy.when_max_risk_violated, trade "max_risk", self.policy.when_max_risk_violated, trade["id"]
) )
def run_checks(self): def run_checks(self):