Finish AMS tests

This commit is contained in:
Mark Veidemanis 2023-02-22 07:20:58 +00:00
parent ed63085e10
commit 9c537187f0
Signed by: m
GPG Key ID: 5ACFCEED46C0904F
3 changed files with 279 additions and 50 deletions

View File

@ -7,6 +7,8 @@ from django.test import TestCase
from core.exchanges.convert import convert_trades from core.exchanges.convert import convert_trades
from core.models import ( from core.models import (
ActiveManagementPolicy, ActiveManagementPolicy,
AssetGroup,
AssetRule,
Hook, Hook,
RiskModel, RiskModel,
Signal, Signal,
@ -128,6 +130,10 @@ class ActiveManagementMixinTestCase(StrategyMixin):
# Don't need to close the trade, it's already closed. # Don't need to close the trade, it's already closed.
# Otherwise the test would fail. # Otherwise the test would fail.
self.strategy.trend_signals.set([])
self.strategy.trends = {}
self.strategy.save()
def test_ams_position_size_violated(self): def test_ams_position_size_violated(self):
self.active_management_policy.when_position_size_violated = "close" self.active_management_policy.when_position_size_violated = "close"
self.active_management_policy.save() self.active_management_policy.save()
@ -177,12 +183,8 @@ class ActiveManagementMixinTestCase(StrategyMixin):
def test_ams_protection_violated(self): def test_ams_protection_violated(self):
self.active_management_policy.when_protection_violated = "close" self.active_management_policy.when_protection_violated = "close"
self.active_management_policy.save() self.active_management_policy.save()
# Don't violate position size check
trade_size = market.get_trade_size_in_base( complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 5, 5)
"buy", self.account, self.strategy, self.account.client.get_balance(), "EUR"
)
trade_size = round(trade_size, 0)
complex_trade = self.create_complex_trade("buy", trade_size, "EUR_USD", 5, 5)
self.open_trade(complex_trade) self.open_trade(complex_trade)
self.ams.run_checks() self.ams.run_checks()
@ -206,12 +208,7 @@ class ActiveManagementMixinTestCase(StrategyMixin):
self.assertEqual(len(trades), 0) self.assertEqual(len(trades), 0)
def test_ams_protection_violated_adjust(self): def test_ams_protection_violated_adjust(self):
# Don't violate position size check complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 5, 5)
trade_size = market.get_trade_size_in_base(
"buy", self.account, self.strategy, self.account.client.get_balance(), "EUR"
)
trade_size = round(trade_size, 0)
complex_trade = self.create_complex_trade("buy", trade_size, "EUR_USD", 5, 5)
self.open_trade(complex_trade) self.open_trade(complex_trade)
self.ams.run_checks() self.ams.run_checks()
@ -231,22 +228,246 @@ class ActiveManagementMixinTestCase(StrategyMixin):
self.close_trade(complex_trade) self.close_trade(complex_trade)
def test_ams_asset_groups_violated(self): def test_ams_asset_groups_violated(self):
pass asset_group = AssetGroup.objects.create(
user=self.user,
name="Test Asset Group",
)
AssetRule.objects.create(
user=self.user,
asset="USD",
group=asset_group,
status=2, # Bullish
)
self.strategy.asset_group = asset_group
self.strategy.save()
complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
self.open_trade(complex_trade)
self.ams.run_checks()
expected = {
"close": [
{"id": complex_trade.order_id, "check": "asset_group", "extra": {}}
]
}
self.assertEqual(self.ams.actions, expected)
self.ams.execute_actions()
trades = self.account.client.get_all_open_trades()
self.assertEqual(len(trades), 0)
def test_ams_crossfilter_violated(self): def test_ams_crossfilter_violated(self):
pass complex_trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
self.open_trade(complex_trade1)
def test_ams_max_open_trades_violated(self): complex_trade2 = self.create_complex_trade("buy", 10, "USD_JPY", 1.5, 1.0)
pass self.open_trade(complex_trade2)
def test_ams_max_open_trades_per_symbol_violated(self): trades = self.account.client.get_all_open_trades()
pass self.assertEqual(len(trades), 2)
self.ams.run_checks()
expected = {
"close": [
{
"id": complex_trade2.order_id, # Only the second one
"check": "crossfilter",
"extra": {},
}
]
}
self.assertEqual(self.ams.actions, expected)
self.ams.execute_actions()
trades = self.account.client.get_all_open_trades()
self.assertEqual(len(trades), 1)
self.close_trade(complex_trade1)
@patch(
"core.trading.active_management.ActiveManagement.check_trends",
return_value=None,
)
@patch(
"core.trading.active_management.ActiveManagement.check_crossfilter",
return_value=None,
)
def test_ams_max_open_trades_violated(self, check_crossfilter, check_trends):
self.strategy.risk_model.max_open_trades = 2
self.strategy.risk_model.save()
trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
self.open_trade(trade1)
trade2 = self.create_complex_trade("buy", 10, "USD_JPY", 1.5, 1.0)
self.open_trade(trade2)
trade3 = self.create_complex_trade("buy", 10, "EUR_JPY", 1.5, 1.0)
self.open_trade(trade3)
trades = self.account.client.get_all_open_trades()
self.assertEqual(len(trades), 3)
self.ams.run_checks()
expected = {
"close": [
{
"id": trade3.order_id,
"check": "max_open_trades",
"extra": {},
}
]
}
self.assertEqual(self.ams.actions, expected)
self.ams.execute_actions()
trades = self.account.client.get_all_open_trades()
self.assertEqual(len(trades), 2)
trade_ids = [trade["id"] for trade in trades]
self.assertIn(trade1.order_id, trade_ids)
self.assertIn(trade2.order_id, trade_ids)
self.assertNotIn(trade3.order_id, trade_ids)
for x in [trade1, trade2]:
self.close_trade(x)
@patch(
"core.trading.active_management.ActiveManagement.check_trends",
return_value=None,
)
def test_ams_max_open_trades_per_symbol_violated(self, check_trends):
self.strategy.risk_model.max_open_trades_per_symbol = 2
self.strategy.risk_model.save()
trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
self.open_trade(trade1)
trade2 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
self.open_trade(trade2)
trade3 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
self.open_trade(trade3)
trade4 = self.create_complex_trade("buy", 10, "EUR_JPY", 1.5, 1.0)
self.open_trade(trade4)
trade5 = self.create_complex_trade("buy", 10, "EUR_JPY", 1.5, 1.0)
self.open_trade(trade5)
trade6 = self.create_complex_trade("buy", 10, "EUR_JPY", 1.5, 1.0)
self.open_trade(trade6)
trades = self.account.client.get_all_open_trades()
self.assertEqual(len(trades), 6)
self.ams.run_checks()
expected = {
"close": [
{
"id": trade3.order_id,
"check": "max_open_trades_per_symbol",
"extra": {},
},
{
"id": trade6.order_id,
"check": "max_open_trades_per_symbol",
"extra": {},
},
]
}
self.assertEqual(self.ams.actions, expected)
self.ams.execute_actions()
trades = self.account.client.get_all_open_trades()
self.assertEqual(len(trades), 4)
trade_ids = [trade["id"] for trade in trades]
self.assertIn(trade1.order_id, trade_ids)
self.assertIn(trade2.order_id, trade_ids)
self.assertNotIn(trade3.order_id, trade_ids)
self.assertIn(trade4.order_id, trade_ids)
self.assertIn(trade5.order_id, trade_ids)
self.assertNotIn(trade6.order_id, trade_ids)
def test_ams_max_loss_violated(self): def test_ams_max_loss_violated(self):
pass trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
self.open_trade(trade1)
def test_ams_max_risk_violated(self): self.account.initial_balance = self.account.client.get_balance() * 2
pass self.account.save()
self.ams.run_checks()
expected = {
"close": [
{
"id": None,
"check": "max_loss",
"extra": {},
}
]
}
self.assertEqual(self.ams.actions, expected)
self.ams.execute_actions()
trades = self.account.client.get_all_open_trades()
self.assertEqual(len(trades), 0)
@patch(
"core.trading.active_management.ActiveManagement.check_position_size",
return_value=None,
)
def test_ams_max_risk_violated(self, check_position_size):
self.strategy.risk_model.max_risk_percent = 0.001
self.strategy.risk_model.save()
trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
self.open_trade(trade1)
trade2 = self.create_complex_trade("buy", 1000, "EUR_USD", 1.5, 1.0)
self.open_trade(trade2)
trades = self.account.client.get_all_open_trades()
self.assertEqual(len(trades), 2)
self.ams.run_checks()
expected = {
"close": [
{
"id": trade2.order_id,
"check": "max_risk",
"extra": {},
}
]
}
self.assertEqual(self.ams.actions, expected)
self.ams.execute_actions()
trades = self.account.client.get_all_open_trades()
self.assertEqual(len(trades), 1)
trade_ids = [trade["id"] for trade in trades]
self.assertIn(trade1.order_id, trade_ids)
self.assertNotIn(trade2.order_id, trade_ids)
self.close_trade(trade1)
class LiveTradingTestCase( class LiveTradingTestCase(
@ -290,6 +511,9 @@ class LiveTradingTestCase(
# Check the opened trade # Check the opened trade
self.assertEqual(posted["type"], "MARKET_ORDER") self.assertEqual(posted["type"], "MARKET_ORDER")
self.assertEqual(posted["symbol"], trade.symbol) self.assertEqual(posted["symbol"], trade.symbol)
if trade.direction == "sell":
self.assertEqual(posted["units"], str(0 - trade.amount))
else:
self.assertEqual(posted["units"], str(trade.amount)) self.assertEqual(posted["units"], str(trade.amount))
self.assertEqual(posted["timeInForce"], "FOK") self.assertEqual(posted["timeInForce"], "FOK")

View File

@ -79,7 +79,11 @@ class ActiveManagement(object):
def bulk_close_trades(self, trade_ids): def bulk_close_trades(self, trade_ids):
for trade_id in trade_ids: for trade_id in trade_ids:
if trade_id is not None:
self.close_trade(trade_id) self.close_trade(trade_id)
else:
self.strategy.account.client.close_all_positions()
return
def bulk_notify(self, action, action_cast_list): def bulk_notify(self, action, action_cast_list):
msg = "" msg = ""
@ -461,6 +465,7 @@ class ActiveManagement(object):
converted_trades = convert_trades(self.get_trades()) converted_trades = convert_trades(self.get_trades())
trades_copy = self.get_sorted_trades_copy(converted_trades, reverse=False) trades_copy = self.get_sorted_trades_copy(converted_trades, reverse=False)
market.convert_trades_to_usd(self.strategy.account, trades_copy) market.convert_trades_to_usd(self.strategy.account, trades_copy)
for trade in reversed(trades_copy): for trade in reversed(trades_copy):
try: try:
self.check_trading_time(trade) self.check_trading_time(trade)

View File

@ -46,6 +46,7 @@ def within_callback_price_deviation(strategy, price, current_price):
def within_trends(strategy, symbol, direction): def within_trends(strategy, symbol, direction):
if strategy.trend_signals.exists(): if strategy.trend_signals.exists():
if len(strategy.trend_signals.all()) > 0:
if strategy.trends is None: if strategy.trends is None:
log.debug("Refusing to trade with no trend signals received") log.debug("Refusing to trade with no trend signals received")
sendmsg( sendmsg(
@ -74,6 +75,5 @@ def within_trends(strategy, symbol, direction):
else: else:
log.debug(f"Trend check passed for {symbol} - {direction}") log.debug(f"Trend check passed for {symbol} - {direction}")
return True return True
else:
log.debug("No trend signals configured") log.debug("No trend signals configured")
return True return True