Write protection check tests

This commit is contained in:
Mark Veidemanis 2023-02-17 17:05:52 +00:00
parent 1dbb3fcf79
commit 67117f0978
Signed by: m
GPG Key ID: 5ACFCEED46C0904F
7 changed files with 172 additions and 41 deletions

View File

@ -103,6 +103,30 @@ def tp_price_to_percent(tp_price, side, current_price, current_units, unrealised
return round(change_percent, 5) return round(change_percent, 5)
def tp_percent_to_price(tp_percent, side, current_price, current_units, unrealised_pl):
"""
Determine the price of the TP percent from the initial price.
"""
pl_per_unit = D(unrealised_pl) / D(current_units)
if side == "long":
initial_price = D(current_price) - pl_per_unit
else:
initial_price = D(current_price) + pl_per_unit
# Get the percent change of the TP price from the initial price.
change_percent = D(tp_percent) / 100
# Get the price of the TP percent from the initial price.
change_price = initial_price * change_percent
if side == "long":
tp_price = initial_price - change_price
else:
tp_price = initial_price + change_price
return round(tp_price, 5)
def sl_price_to_percent(sl_price, side, current_price, current_units, unrealised_pl): def sl_price_to_percent(sl_price, side, current_price, current_units, unrealised_pl):
""" """
Determine the percent change of the SL price from the initial price. Determine the percent change of the SL price from the initial price.
@ -146,6 +170,30 @@ def sl_price_to_percent(sl_price, side, current_price, current_units, unrealised
return round(change_percent, 5) return round(change_percent, 5)
def sl_percent_to_price(sl_percent, side, current_price, current_units, unrealised_pl):
"""
Determine the price of the SL percent from the initial price.
"""
pl_per_unit = D(unrealised_pl) / D(current_units)
if side == "long":
initial_price = D(current_price) - pl_per_unit
else:
initial_price = D(current_price) + pl_per_unit
# Get the percent change of the SL price from the initial price.
change_percent = D(sl_percent) / 100
# Get the price of the SL percent from the initial price.
change_price = initial_price * change_percent
if side == "long":
sl_price = initial_price - change_price
else:
sl_price = initial_price + change_price
return round(sl_price, 5)
def annotate_trade_tp_sl_percent(trade): def annotate_trade_tp_sl_percent(trade):
""" """
Annotate the trade with the TP and SL percent. Annotate the trade with the TP and SL percent.
@ -228,6 +276,8 @@ def open_trade_to_unified_format(trade):
"current_price": current_price, "current_price": current_price,
"pl": unrealised_pl, "pl": unrealised_pl,
} }
if "openTime" in trade:
cast["open_time"] = trade["openTime"]
# Add some extra fields, sometimes we have already looked up the # Add some extra fields, sometimes we have already looked up the
# prices and don't need to call convert_trades_to_usd # prices and don't need to call convert_trades_to_usd
# This is mostly for tests, but it can be useful in other places. # This is mostly for tests, but it can be useful in other places.

View File

@ -1,8 +1,8 @@
# Generated by Django 4.1.7 on 2023-02-17 11:50 # Generated by Django 4.1.7 on 2023-02-17 11:50
import django.db.models.deletion
from django.conf import settings from django.conf import settings
from django.db import migrations, models from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration): class Migration(migrations.Migration):

View File

@ -1,7 +1,7 @@
# Generated by Django 4.1.7 on 2023-02-17 13:16 # Generated by Django 4.1.7 on 2023-02-17 13:16
from django.db import migrations, models
import django.db.models.deletion import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):

View File

@ -1,6 +1,13 @@
from decimal import Decimal as D
from django.test import TestCase from django.test import TestCase
from core.exchanges.convert import sl_price_to_percent, tp_price_to_percent from core.exchanges.convert import (
sl_percent_to_price,
sl_price_to_percent,
tp_percent_to_price,
tp_price_to_percent,
)
class CommonTestCase(TestCase): class CommonTestCase(TestCase):
@ -247,25 +254,39 @@ class CommonTestCase(TestCase):
Test that the SL price to percent conversion works for long trades Test that the SL price to percent conversion works for long trades
when the price has changed, with multiple units, and the SL is at a profit. when the price has changed, with multiple units, and the SL is at a profit.
""" """
sl_price = 1.2 # +20% sl_price = D(1.2) # +20%
current_price = 1.1 # +10% current_price = 1.1 # +10%
current_units = 10 current_units = 10
unrealised_pl = 1 # +10% unrealised_pl = 1 # +10%
expected_percent = -20
percent = sl_price_to_percent( percent = sl_price_to_percent(
sl_price, "long", current_price, current_units, unrealised_pl sl_price, "long", current_price, current_units, unrealised_pl
) )
self.assertEqual(percent, -20) self.assertEqual(percent, expected_percent)
self.assertEqual(
tp_percent_to_price(
expected_percent, "long", current_price, current_units, unrealised_pl
),
sl_price,
)
def test_sl_price_to_percent_change_short_multi_profit(self): def test_sl_price_to_percent_change_short_multi_profit(self):
""" """
Test that the SL price to percent conversion works for short trades Test that the SL price to percent conversion works for short trades
when the price has changed, with multiple units, and the SL is at a profit. when the price has changed, with multiple units, and the SL is at a profit.
""" """
sl_price = 0.8 # -20% sl_price = D(0.8) # -20%
current_price = 0.9 # +10% current_price = 0.9 # +10%
current_units = 10 current_units = 10
unrealised_pl = 1 # +10% unrealised_pl = 1 # +10%
expected_percent = -20
percent = sl_price_to_percent( percent = sl_price_to_percent(
sl_price, "short", current_price, current_units, unrealised_pl sl_price, "short", current_price, current_units, unrealised_pl
) )
self.assertEqual(percent, -20) self.assertEqual(percent, expected_percent)
self.assertEqual(
tp_percent_to_price(
expected_percent, "short", current_price, current_units, unrealised_pl
),
sl_price,
)

View File

@ -1,10 +1,13 @@
from unittest.mock import Mock, patch
from django.test import TestCase from django.test import TestCase
from core.exchanges.convert import convert_trades
from core.lib.schemas.oanda_s import parse_time
from core.models import Account, ActiveManagementPolicy, Hook, Signal, User
from core.tests.helpers import StrategyMixin, SymbolPriceMock from core.tests.helpers import StrategyMixin, SymbolPriceMock
from core.trading.active_management import ActiveManagement from core.trading.active_management import ActiveManagement
from core.models import User, Account, ActiveManagementPolicy, Hook, Signal
from unittest.mock import Mock, patch
from core.lib.schemas.oanda_s import parse_time
class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
def setUp(self): def setUp(self):
@ -43,7 +46,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
"id": "20083", "id": "20083",
"symbol": "EUR_USD", "symbol": "EUR_USD",
"price": "1.06331", "price": "1.06331",
"openTime": "2023-02-13T11:38:06.302917985Z", # Monday at 11:38 "openTime": "2023-02-13T11:38:06.302917985Z", # Monday at 11:38
"initialUnits": "10", "initialUnits": "10",
"initialMarginRequired": "0.2966", "initialMarginRequired": "0.2966",
"state": "OPEN", "state": "OPEN",
@ -53,9 +56,9 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
"dividendAdjustment": "0.0000", "dividendAdjustment": "0.0000",
"unrealizedPL": "-0.0008", "unrealizedPL": "-0.0008",
"marginUsed": "0.2966", "marginUsed": "0.2966",
"takeProfitOrder": None, "takeProfitOrder": {"price": "1.06331"},
"stopLossOrder": None, "stopLossOrder": {"price": "1.06331"},
"trailingStopLossOrder": None, "trailingStopLossOrder": {"price": "1.06331"},
"trailingStopValue": None, "trailingStopValue": None,
"side": "long", "side": "long",
}, },
@ -63,7 +66,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
"id": "20083", "id": "20083",
"symbol": "EUR_USD", "symbol": "EUR_USD",
"price": "1.06331", "price": "1.06331",
"openTime": "2023-02-13T11:38:06.302917985Z", # Monday at 11:38 "openTime": "2023-02-13T11:38:06.302917985Z", # Monday at 11:38
"initialUnits": "10", "initialUnits": "10",
"initialMarginRequired": "0.2966", "initialMarginRequired": "0.2966",
"state": "OPEN", "state": "OPEN",
@ -73,12 +76,12 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
"dividendAdjustment": "0.0000", "dividendAdjustment": "0.0000",
"unrealizedPL": "-0.0008", "unrealizedPL": "-0.0008",
"marginUsed": "0.2966", "marginUsed": "0.2966",
"takeProfitOrder": None, "takeProfitOrder": {"price": "1.06331"},
"stopLossOrder": None, "stopLossOrder": {"price": "1.06331"},
"trailingStopLossOrder": None, "trailingStopLossOrder": {"price": "1.06331"},
"trailingStopValue": None, "trailingStopValue": None,
"side": "long", "side": "long",
} },
] ]
# Run parse_time on all items in trades # Run parse_time on all items in trades
for trade in self.trades: for trade in self.trades:
@ -86,7 +89,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
self.ams.get_trades = self.fake_get_trades self.ams.get_trades = self.fake_get_trades
self.ams.get_balance = self.fake_get_balance self.ams.get_balance = self.fake_get_balance
# self.ams.trades = self.trades # self.ams.trades = self.trades
def fake_get_trades(self): def fake_get_trades(self):
self.ams.trades = self.trades self.ams.trades = self.trades
return self.trades return self.trades
@ -105,7 +108,9 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
balance = self.ams.get_balance() balance = self.ams.get_balance()
self.assertEqual(balance, 10000) self.assertEqual(balance, 10000)
def check_violation(self, violation, calls, expected_action, expected_trades): def check_violation(
self, violation, calls, expected_action, expected_trades, expected_args=None
):
""" """
Check that the violation was called with the expected action and trades. Check that the violation was called with the expected action and trades.
Matches the first argument of the call to the violation name. Matches the first argument of the call to the violation name.
@ -113,6 +118,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
:param: calls: list of calls to the violation :param: calls: list of calls to the violation
:param: expected_action: expected action to be called, close, notify, etc. :param: expected_action: expected action to be called, close, notify, etc.
:param: expected_trades: list of expected trades to be passed to the violation :param: expected_trades: list of expected trades to be passed to the violation
:param: expected_args: optional, expected args to be passed to the violation
""" """
calls = list(calls) calls = list(calls)
violation_calls = [] violation_calls = []
@ -121,22 +127,28 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
violation_calls.append(call) violation_calls.append(call)
self.assertEqual(len(violation_calls), len(expected_trades)) self.assertEqual(len(violation_calls), len(expected_trades))
expected_trades = convert_trades(expected_trades)
for call in violation_calls: for call in violation_calls:
# Ensure the correct action has been called, like close # Ensure the correct action has been called, like close
self.assertEqual(call[0][1], expected_action) self.assertEqual(call[0][1], expected_action)
# Ensure the correct trade has been passed to the violation # Ensure the correct trade has been passed to the violation
self.assertIn(call[0][2], expected_trades) self.assertIn(call[0][2], expected_trades)
if expected_args:
self.assertEqual(call[0][3], expected_args)
@patch("core.trading.active_management.ActiveManagement.handle_violation") @patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_run_checks(self, handle_violation): def test_run_checks(self, handle_violation):
self.ams.run_checks() self.ams.run_checks()
print("handle_violation.call_count", handle_violation.call_args_list)
self.assertEqual(handle_violation.call_count, 0) self.assertEqual(handle_violation.call_count, 0)
@patch("core.trading.active_management.ActiveManagement.handle_violation") @patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_trading_time_violated(self, handle_violation): def test_trading_time_violated(self, handle_violation):
self.trades[0]["openTime"] = "2023-02-17T11:38:06.302917Z" # Friday self.trades[0]["openTime"] = "2023-02-17T11:38:06.302917Z" # Friday
self.ams.run_checks() self.ams.run_checks()
self.check_violation("trading_time", handle_violation.call_args_list, "close", [self.trades[0]]) self.check_violation(
"trading_time", handle_violation.call_args_list, "close", [self.trades[0]]
)
def create_hook_signal(self): def create_hook_signal(self):
hook = Hook.objects.create( hook = Hook.objects.create(
@ -158,7 +170,9 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
self.strategy.trends = {"EUR_USD": "sell"} self.strategy.trends = {"EUR_USD": "sell"}
self.strategy.save() self.strategy.save()
self.ams.run_checks() self.ams.run_checks()
self.check_violation("trends", handle_violation.call_args_list, "close", self.trades) self.check_violation(
"trends", handle_violation.call_args_list, "close", self.trades
)
@patch("core.trading.active_management.ActiveManagement.handle_violation") @patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_trends_violated_none(self, handle_violation): def test_trends_violated_none(self, handle_violation):
@ -175,19 +189,27 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
self.strategy.trend_signals.set([signal]) self.strategy.trend_signals.set([signal])
self.strategy.trends = {"EUR_USD": "sell"} self.strategy.trends = {"EUR_USD": "sell"}
self.strategy.save() self.strategy.save()
# Change the side of the first trade to match the trends # Change the side of the first trade to match the trends
self.trades[0]["side"] = "short" self.trades[0]["side"] = "short"
self.ams.run_checks() self.ams.run_checks()
self.check_violation("trends", handle_violation.call_args_list, "close", [self.trades[1]]) self.check_violation(
"trends", handle_violation.call_args_list, "close", [self.trades[1]]
)
@patch("core.trading.active_management.ActiveManagement.handle_violation") @patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_position_size_violated(self, handle_violation): def test_position_size_violated(self, handle_violation):
self.trades[0]["currentUnits"] = "100000" self.trades[0]["currentUnits"] = "100000"
self.ams.run_checks() self.ams.run_checks()
self.check_violation("position_size", handle_violation.call_args_list, "close", [self.trades[0]]) self.check_violation(
"position_size",
handle_violation.call_args_list,
"close",
[self.trades[0]],
{"size": 50},
)
def test_protection_violated(self): def test_protection_violated(self):
pass pass

View File

@ -1,7 +1,7 @@
from datetime import datetime from datetime import datetime
from decimal import Decimal as D from decimal import Decimal as D
from core.exchanges.convert import side_to_direction from core.exchanges.convert import convert_trades, side_to_direction
from core.trading import checks from core.trading import checks
from core.trading.market import get_base_quote, get_trade_size_in_base from core.trading.market import get_base_quote, get_trade_size_in_base
@ -25,11 +25,11 @@ class ActiveManagement(object):
else: else:
return self.balance return self.balance
def handle_violation(self, check_type, action, trade): def handle_violation(self, check_type, action, trade, **kwargs):
print("VIOLATION", check_type, action, trade) print("VIOLATION", check_type, action, trade, kwargs)
def check_trading_time(self, trade): def check_trading_time(self, trade):
open_ts = trade["openTime"] open_ts = trade["open_time"]
open_ts_as_date = datetime.strptime(open_ts, "%Y-%m-%dT%H:%M:%S.%fZ") open_ts_as_date = datetime.strptime(open_ts, "%Y-%m-%dT%H:%M:%S.%fZ")
trading_time_pass = checks.within_trading_times(self.strategy, open_ts_as_date) trading_time_pass = checks.within_trading_times(self.strategy, open_ts_as_date)
if not trading_time_pass: if not trading_time_pass:
@ -38,40 +38,80 @@ class ActiveManagement(object):
) )
def check_trends(self, trade): def check_trends(self, trade):
direction = side_to_direction(trade["side"]) direction = trade["direction"]
symbol = trade["symbol"] symbol = trade["symbol"]
trends_pass = checks.within_trends(self.strategy, symbol, direction) trends_pass = checks.within_trends(self.strategy, symbol, direction)
if not trends_pass: if not trends_pass:
print("VIOLATION", "trends", self.policy.when_trends_violated, trade)
self.handle_violation("trends", self.policy.when_trends_violated, trade) self.handle_violation("trends", self.policy.when_trends_violated, trade)
def check_position_size(self, trade): def check_position_size(self, trade):
"""
Check the position size is within the allowed deviation.
WARNING: This uses the current balance, not the balance at the time of the trade.
WARNING: This uses the current symbol prices, not those at the time of the trade.
This should normally be run every 5 seconds, so this is fine.
"""
# TODO: add the trade value to the balance
# Need to determine which prices to use
balance = self.get_balance() balance = self.get_balance()
print("BALANCE", balance) direction = trade["direction"]
direction = side_to_direction(trade["side"])
symbol = trade["symbol"] symbol = trade["symbol"]
# TODO:
base, quote = get_base_quote(self.strategy.account.exchange, symbol) base, quote = get_base_quote(self.strategy.account.exchange, symbol)
expected_trade_size = get_trade_size_in_base( expected_trade_size = get_trade_size_in_base(
direction, self.strategy.account, self.strategy, balance, base direction, self.strategy.account, self.strategy, balance, base
) )
print("TRADE SIZE", expected_trade_size)
deviation = D(0.05) # 5% deviation = D(0.05) # 5%
actual_trade_size = D(trade["currentUnits"]) actual_trade_size = D(trade["amount"])
# Ensure the trade size not above the expected trade size by more than 5% # Ensure the trade size not above the expected trade size by more than 5%
max_trade_size = expected_trade_size + (deviation * expected_trade_size) max_trade_size = expected_trade_size + (deviation * expected_trade_size)
within_max_trade_size = actual_trade_size <= max_trade_size within_max_trade_size = actual_trade_size <= max_trade_size
if not within_max_trade_size: if not within_max_trade_size:
self.handle_violation( self.handle_violation(
"position_size", self.policy.when_position_size_violated, trade "position_size",
self.policy.when_position_size_violated,
trade,
{"size": expected_trade_size},
)
def check_protection(self, trade):
print("CHECK PROTECTION", trade)
deviation = D(0.05) # 5%
matches = {
"stop_loss_percent": self.strategy.order_settings.stop_loss_percent,
"take_profit_percent": self.strategy.order_settings.take_profit_percent,
"trailing_stop_percent": self.strategy.order_settings.trailing_stop_loss_percent,
}
violations = {}
for key, expected in matches.items():
if key in trade:
actual = D(trade[key])
if expected is None:
continue
expected = D(expected)
min_val = expected - (deviation * expected)
max_val = expected + (deviation * expected)
within_deviation = min_val <= actual <= max_val
if not within_deviation:
violations[key] = expected
if violations:
self.handle_violation(
"protection", self.policy.when_protection_violated, trade, violations
) )
def run_checks(self): def run_checks(self):
for trade in self.get_trades(): converted_trades = convert_trades(self.get_trades())
for trade in converted_trades:
self.check_trading_time(trade) self.check_trading_time(trade)
self.check_trends(trade) self.check_trends(trade)
self.check_position_size(trade) self.check_position_size(trade)
self.check_protection(trade)
# Trading Time # Trading Time
# Max loss # Max loss

View File

@ -53,9 +53,7 @@ def within_max_loss(strategy):
def within_trends(strategy, symbol, direction): def within_trends(strategy, symbol, direction):
print("WITHIN TRENDS", symbol, direction)
if strategy.trend_signals.exists(): if strategy.trend_signals.exists():
print("TREND SIGNALS EXIST")
if strategy.trends is None: if strategy.trends is None:
log.debug("Refusing to trade with no trend signals received") log.debug("Refusing to trade with no trend signals received")
sendmsg( sendmsg(