Simplify active management by only specifying trade IDs for violations
This commit is contained in:
parent
466b17400f
commit
15a8bec105
|
@ -3,7 +3,6 @@ from unittest.mock import patch
|
||||||
|
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from core.exchanges.convert import convert_trades
|
|
||||||
from core.lib.schemas.oanda_s import parse_time
|
from core.lib.schemas.oanda_s import parse_time
|
||||||
from core.models import (
|
from core.models import (
|
||||||
Account,
|
Account,
|
||||||
|
@ -103,6 +102,9 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
|
||||||
self.ams.get_balance = self.fake_get_balance
|
self.ams.get_balance = self.fake_get_balance
|
||||||
# self.ams.trades = self.trades
|
# self.ams.trades = self.trades
|
||||||
|
|
||||||
|
def get_ids(self, trades):
|
||||||
|
return [trade["id"] for trade in trades]
|
||||||
|
|
||||||
def add_trade(self, id, symbol, side, open_time):
|
def add_trade(self, id, symbol, side, open_time):
|
||||||
trade = {
|
trade = {
|
||||||
"id": id,
|
"id": id,
|
||||||
|
@ -179,17 +181,12 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
|
||||||
|
|
||||||
self.assertEqual(len(violation_calls), len(expected_trades))
|
self.assertEqual(len(violation_calls), len(expected_trades))
|
||||||
if all(expected_trades):
|
if all(expected_trades):
|
||||||
expected_trades = convert_trades(expected_trades)
|
# expected_trades = convert_trades(expected_trades)
|
||||||
|
expected_trades = self.get_ids(expected_trades)
|
||||||
for call in violation_calls:
|
for call in violation_calls:
|
||||||
# Ensure the correct action has been called, like close
|
# Ensure the correct action has been called, like close
|
||||||
self.assertEqual(call[0][1], expected_action)
|
self.assertEqual(call[0][1], expected_action)
|
||||||
# Ensure the correct trade has been passed to the violation
|
# Ensure the correct trade has been passed to the violation
|
||||||
trade = call[0][2]
|
|
||||||
if trade:
|
|
||||||
for field in list(trade.keys()):
|
|
||||||
if "_usd" in field:
|
|
||||||
if field in trade.keys():
|
|
||||||
del trade[field]
|
|
||||||
self.assertIn(call[0][2], expected_trades)
|
self.assertIn(call[0][2], expected_trades)
|
||||||
if expected_args:
|
if expected_args:
|
||||||
self.assertEqual(call[0][3], expected_args)
|
self.assertEqual(call[0][3], expected_args)
|
||||||
|
@ -454,22 +451,18 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
|
||||||
# trade.
|
# trade.
|
||||||
# TODO: Fix this when we have a way of checking the balance at the start of the
|
# TODO: Fix this when we have a way of checking the balance at the start of the
|
||||||
# trade.
|
# trade.
|
||||||
# Max risk is also mocked as this puts us over the limit, due to the low account
|
|
||||||
# size.
|
|
||||||
@patch(
|
|
||||||
"core.trading.active_management.ActiveManagement.check_max_risk",
|
|
||||||
return_value=None,
|
|
||||||
)
|
|
||||||
@patch(
|
@patch(
|
||||||
"core.trading.active_management.ActiveManagement.check_position_size",
|
"core.trading.active_management.ActiveManagement.check_position_size",
|
||||||
return_value=None,
|
return_value=None,
|
||||||
)
|
)
|
||||||
@patch("core.trading.active_management.ActiveManagement.handle_violation")
|
@patch("core.trading.active_management.ActiveManagement.handle_violation")
|
||||||
def test_max_loss_violated(
|
def test_max_loss_violated(self, handle_violation, check_position_size):
|
||||||
self, handle_violation, check_position_size, check_max_risk
|
|
||||||
):
|
|
||||||
self.balance = D("1")
|
self.balance = D("1")
|
||||||
self.balance_usd = D("0.69")
|
self.balance_usd = D("0.69")
|
||||||
|
|
||||||
|
self.trades = []
|
||||||
|
|
||||||
self.ams.run_checks()
|
self.ams.run_checks()
|
||||||
|
|
||||||
self.check_violation(
|
self.check_violation(
|
||||||
|
@ -546,3 +539,15 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
|
||||||
"close",
|
"close",
|
||||||
[self.trades[2], self.trades[3]],
|
[self.trades[2], self.trades[3]],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_max_risk_not_violated_after_adjusting_protection(self):
|
||||||
|
"""
|
||||||
|
Ensure the max risk check is not violated after adjusting the protection.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_max_risk_not_violated_after_adjusting_position_size(self):
|
||||||
|
"""
|
||||||
|
Ensure the max risk check is not violated after adjusting the position size.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
|
@ -44,6 +44,21 @@ class ActiveManagement(object):
|
||||||
|
|
||||||
def handle_violation(self, check_type, action, trade, **kwargs):
|
def handle_violation(self, check_type, action, trade, **kwargs):
|
||||||
print("VIOLATION", check_type, action, trade, kwargs)
|
print("VIOLATION", check_type, action, trade, kwargs)
|
||||||
|
# TODO: close/notify for:
|
||||||
|
# - trading time
|
||||||
|
# - trends
|
||||||
|
# - position size
|
||||||
|
# - protection
|
||||||
|
# - asset groups
|
||||||
|
# - crossfilter
|
||||||
|
# - max open trades
|
||||||
|
# - max open trades per symbol
|
||||||
|
# - max loss
|
||||||
|
# - max risk
|
||||||
|
|
||||||
|
# TODO: adjust for:
|
||||||
|
# - position size
|
||||||
|
# - protection
|
||||||
|
|
||||||
def check_trading_time(self, trade):
|
def check_trading_time(self, trade):
|
||||||
open_ts = trade["open_time"]
|
open_ts = trade["open_time"]
|
||||||
|
@ -51,7 +66,7 @@ class ActiveManagement(object):
|
||||||
trading_time_pass = checks.within_trading_times(self.strategy, open_ts_as_date)
|
trading_time_pass = checks.within_trading_times(self.strategy, open_ts_as_date)
|
||||||
if not trading_time_pass:
|
if not trading_time_pass:
|
||||||
self.handle_violation(
|
self.handle_violation(
|
||||||
"trading_time", self.policy.when_trading_time_violated, trade
|
"trading_time", self.policy.when_trading_time_violated, trade["id"]
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_trends(self, trade):
|
def check_trends(self, trade):
|
||||||
|
@ -59,7 +74,9 @@ class ActiveManagement(object):
|
||||||
symbol = trade["symbol"]
|
symbol = trade["symbol"]
|
||||||
trends_pass = checks.within_trends(self.strategy, symbol, direction)
|
trends_pass = checks.within_trends(self.strategy, symbol, direction)
|
||||||
if not trends_pass:
|
if not trends_pass:
|
||||||
self.handle_violation("trends", self.policy.when_trends_violated, trade)
|
self.handle_violation(
|
||||||
|
"trends", self.policy.when_trends_violated, trade["id"]
|
||||||
|
)
|
||||||
|
|
||||||
def check_position_size(self, trade):
|
def check_position_size(self, trade):
|
||||||
"""
|
"""
|
||||||
|
@ -91,7 +108,7 @@ class ActiveManagement(object):
|
||||||
self.handle_violation(
|
self.handle_violation(
|
||||||
"position_size",
|
"position_size",
|
||||||
self.policy.when_position_size_violated,
|
self.policy.when_position_size_violated,
|
||||||
trade,
|
trade["id"],
|
||||||
{"size": expected_trade_size},
|
{"size": expected_trade_size},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -152,7 +169,10 @@ class ActiveManagement(object):
|
||||||
|
|
||||||
if violations:
|
if violations:
|
||||||
self.handle_violation(
|
self.handle_violation(
|
||||||
"protection", self.policy.when_protection_violated, trade, violations
|
"protection",
|
||||||
|
self.policy.when_protection_violated,
|
||||||
|
trade["id"],
|
||||||
|
violations
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_asset_groups(self, trade):
|
def check_asset_groups(self, trade):
|
||||||
|
@ -165,7 +185,7 @@ class ActiveManagement(object):
|
||||||
)
|
)
|
||||||
if not allowed:
|
if not allowed:
|
||||||
self.handle_violation(
|
self.handle_violation(
|
||||||
"asset_group", self.policy.when_asset_groups_violated, trade
|
"asset_group", self.policy.when_asset_groups_violated, trade["id"]
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_sorted_trades_copy(self, trades, reverse=True):
|
def get_sorted_trades_copy(self, trades, reverse=True):
|
||||||
|
@ -226,7 +246,7 @@ class ActiveManagement(object):
|
||||||
if close_trades:
|
if close_trades:
|
||||||
for trade in close_trades:
|
for trade in close_trades:
|
||||||
self.handle_violation(
|
self.handle_violation(
|
||||||
"crossfilter", self.policy.when_crossfilter_violated, trade
|
"crossfilter", self.policy.when_crossfilter_violated, trade["id"]
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_max_open_trades(self, trades):
|
def check_max_open_trades(self, trades):
|
||||||
|
@ -241,7 +261,7 @@ class ActiveManagement(object):
|
||||||
self.handle_violation(
|
self.handle_violation(
|
||||||
"max_open_trades",
|
"max_open_trades",
|
||||||
self.policy.when_max_open_trades_violated,
|
self.policy.when_max_open_trades_violated,
|
||||||
trade,
|
trade["id"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_max_open_trades_per_symbol(self, trades):
|
def check_max_open_trades_per_symbol(self, trades):
|
||||||
|
@ -267,7 +287,7 @@ class ActiveManagement(object):
|
||||||
self.handle_violation(
|
self.handle_violation(
|
||||||
"max_open_trades_per_symbol",
|
"max_open_trades_per_symbol",
|
||||||
self.policy.when_max_open_trades_violated,
|
self.policy.when_max_open_trades_violated,
|
||||||
trade,
|
trade["id"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_max_loss(self):
|
def check_max_loss(self):
|
||||||
|
@ -312,7 +332,7 @@ class ActiveManagement(object):
|
||||||
if close_trades:
|
if close_trades:
|
||||||
for trade in close_trades:
|
for trade in close_trades:
|
||||||
self.handle_violation(
|
self.handle_violation(
|
||||||
"max_risk", self.policy.when_max_risk_violated, trade
|
"max_risk", self.policy.when_max_risk_violated, trade["id"]
|
||||||
)
|
)
|
||||||
|
|
||||||
def run_checks(self):
|
def run_checks(self):
|
||||||
|
|
Loading…
Reference in New Issue