Finish AMS tests
This commit is contained in:
parent
ed63085e10
commit
9c537187f0
|
@ -7,6 +7,8 @@ from django.test import TestCase
|
|||
from core.exchanges.convert import convert_trades
|
||||
from core.models import (
|
||||
ActiveManagementPolicy,
|
||||
AssetGroup,
|
||||
AssetRule,
|
||||
Hook,
|
||||
RiskModel,
|
||||
Signal,
|
||||
|
@ -128,6 +130,10 @@ class ActiveManagementMixinTestCase(StrategyMixin):
|
|||
# Don't need to close the trade, it's already closed.
|
||||
# Otherwise the test would fail.
|
||||
|
||||
self.strategy.trend_signals.set([])
|
||||
self.strategy.trends = {}
|
||||
self.strategy.save()
|
||||
|
||||
def test_ams_position_size_violated(self):
|
||||
self.active_management_policy.when_position_size_violated = "close"
|
||||
self.active_management_policy.save()
|
||||
|
@ -177,12 +183,8 @@ class ActiveManagementMixinTestCase(StrategyMixin):
|
|||
def test_ams_protection_violated(self):
|
||||
self.active_management_policy.when_protection_violated = "close"
|
||||
self.active_management_policy.save()
|
||||
# Don't violate position size check
|
||||
trade_size = market.get_trade_size_in_base(
|
||||
"buy", self.account, self.strategy, self.account.client.get_balance(), "EUR"
|
||||
)
|
||||
trade_size = round(trade_size, 0)
|
||||
complex_trade = self.create_complex_trade("buy", trade_size, "EUR_USD", 5, 5)
|
||||
|
||||
complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 5, 5)
|
||||
self.open_trade(complex_trade)
|
||||
|
||||
self.ams.run_checks()
|
||||
|
@ -206,12 +208,7 @@ class ActiveManagementMixinTestCase(StrategyMixin):
|
|||
self.assertEqual(len(trades), 0)
|
||||
|
||||
def test_ams_protection_violated_adjust(self):
|
||||
# Don't violate position size check
|
||||
trade_size = market.get_trade_size_in_base(
|
||||
"buy", self.account, self.strategy, self.account.client.get_balance(), "EUR"
|
||||
)
|
||||
trade_size = round(trade_size, 0)
|
||||
complex_trade = self.create_complex_trade("buy", trade_size, "EUR_USD", 5, 5)
|
||||
complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 5, 5)
|
||||
self.open_trade(complex_trade)
|
||||
|
||||
self.ams.run_checks()
|
||||
|
@ -231,22 +228,246 @@ class ActiveManagementMixinTestCase(StrategyMixin):
|
|||
self.close_trade(complex_trade)
|
||||
|
||||
def test_ams_asset_groups_violated(self):
|
||||
pass
|
||||
asset_group = AssetGroup.objects.create(
|
||||
user=self.user,
|
||||
name="Test Asset Group",
|
||||
)
|
||||
AssetRule.objects.create(
|
||||
user=self.user,
|
||||
asset="USD",
|
||||
group=asset_group,
|
||||
status=2, # Bullish
|
||||
)
|
||||
self.strategy.asset_group = asset_group
|
||||
self.strategy.save()
|
||||
|
||||
complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
|
||||
self.open_trade(complex_trade)
|
||||
|
||||
self.ams.run_checks()
|
||||
|
||||
expected = {
|
||||
"close": [
|
||||
{"id": complex_trade.order_id, "check": "asset_group", "extra": {}}
|
||||
]
|
||||
}
|
||||
self.assertEqual(self.ams.actions, expected)
|
||||
|
||||
self.ams.execute_actions()
|
||||
|
||||
trades = self.account.client.get_all_open_trades()
|
||||
self.assertEqual(len(trades), 0)
|
||||
|
||||
def test_ams_crossfilter_violated(self):
|
||||
pass
|
||||
complex_trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
|
||||
self.open_trade(complex_trade1)
|
||||
|
||||
def test_ams_max_open_trades_violated(self):
|
||||
pass
|
||||
complex_trade2 = self.create_complex_trade("buy", 10, "USD_JPY", 1.5, 1.0)
|
||||
self.open_trade(complex_trade2)
|
||||
|
||||
def test_ams_max_open_trades_per_symbol_violated(self):
|
||||
pass
|
||||
trades = self.account.client.get_all_open_trades()
|
||||
self.assertEqual(len(trades), 2)
|
||||
|
||||
self.ams.run_checks()
|
||||
|
||||
expected = {
|
||||
"close": [
|
||||
{
|
||||
"id": complex_trade2.order_id, # Only the second one
|
||||
"check": "crossfilter",
|
||||
"extra": {},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
self.assertEqual(self.ams.actions, expected)
|
||||
|
||||
self.ams.execute_actions()
|
||||
|
||||
trades = self.account.client.get_all_open_trades()
|
||||
self.assertEqual(len(trades), 1)
|
||||
|
||||
self.close_trade(complex_trade1)
|
||||
|
||||
@patch(
|
||||
"core.trading.active_management.ActiveManagement.check_trends",
|
||||
return_value=None,
|
||||
)
|
||||
@patch(
|
||||
"core.trading.active_management.ActiveManagement.check_crossfilter",
|
||||
return_value=None,
|
||||
)
|
||||
def test_ams_max_open_trades_violated(self, check_crossfilter, check_trends):
|
||||
self.strategy.risk_model.max_open_trades = 2
|
||||
self.strategy.risk_model.save()
|
||||
trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
|
||||
self.open_trade(trade1)
|
||||
|
||||
trade2 = self.create_complex_trade("buy", 10, "USD_JPY", 1.5, 1.0)
|
||||
self.open_trade(trade2)
|
||||
|
||||
trade3 = self.create_complex_trade("buy", 10, "EUR_JPY", 1.5, 1.0)
|
||||
self.open_trade(trade3)
|
||||
|
||||
trades = self.account.client.get_all_open_trades()
|
||||
self.assertEqual(len(trades), 3)
|
||||
|
||||
self.ams.run_checks()
|
||||
|
||||
expected = {
|
||||
"close": [
|
||||
{
|
||||
"id": trade3.order_id,
|
||||
"check": "max_open_trades",
|
||||
"extra": {},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
self.assertEqual(self.ams.actions, expected)
|
||||
|
||||
self.ams.execute_actions()
|
||||
|
||||
trades = self.account.client.get_all_open_trades()
|
||||
self.assertEqual(len(trades), 2)
|
||||
|
||||
trade_ids = [trade["id"] for trade in trades]
|
||||
self.assertIn(trade1.order_id, trade_ids)
|
||||
self.assertIn(trade2.order_id, trade_ids)
|
||||
self.assertNotIn(trade3.order_id, trade_ids)
|
||||
|
||||
for x in [trade1, trade2]:
|
||||
self.close_trade(x)
|
||||
|
||||
@patch(
|
||||
"core.trading.active_management.ActiveManagement.check_trends",
|
||||
return_value=None,
|
||||
)
|
||||
def test_ams_max_open_trades_per_symbol_violated(self, check_trends):
|
||||
self.strategy.risk_model.max_open_trades_per_symbol = 2
|
||||
self.strategy.risk_model.save()
|
||||
|
||||
trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
|
||||
self.open_trade(trade1)
|
||||
|
||||
trade2 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
|
||||
self.open_trade(trade2)
|
||||
|
||||
trade3 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
|
||||
self.open_trade(trade3)
|
||||
|
||||
trade4 = self.create_complex_trade("buy", 10, "EUR_JPY", 1.5, 1.0)
|
||||
self.open_trade(trade4)
|
||||
|
||||
trade5 = self.create_complex_trade("buy", 10, "EUR_JPY", 1.5, 1.0)
|
||||
self.open_trade(trade5)
|
||||
|
||||
trade6 = self.create_complex_trade("buy", 10, "EUR_JPY", 1.5, 1.0)
|
||||
self.open_trade(trade6)
|
||||
|
||||
trades = self.account.client.get_all_open_trades()
|
||||
self.assertEqual(len(trades), 6)
|
||||
|
||||
self.ams.run_checks()
|
||||
|
||||
expected = {
|
||||
"close": [
|
||||
{
|
||||
"id": trade3.order_id,
|
||||
"check": "max_open_trades_per_symbol",
|
||||
"extra": {},
|
||||
},
|
||||
{
|
||||
"id": trade6.order_id,
|
||||
"check": "max_open_trades_per_symbol",
|
||||
"extra": {},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
self.assertEqual(self.ams.actions, expected)
|
||||
|
||||
self.ams.execute_actions()
|
||||
|
||||
trades = self.account.client.get_all_open_trades()
|
||||
self.assertEqual(len(trades), 4)
|
||||
|
||||
trade_ids = [trade["id"] for trade in trades]
|
||||
self.assertIn(trade1.order_id, trade_ids)
|
||||
self.assertIn(trade2.order_id, trade_ids)
|
||||
self.assertNotIn(trade3.order_id, trade_ids)
|
||||
|
||||
self.assertIn(trade4.order_id, trade_ids)
|
||||
self.assertIn(trade5.order_id, trade_ids)
|
||||
self.assertNotIn(trade6.order_id, trade_ids)
|
||||
|
||||
def test_ams_max_loss_violated(self):
|
||||
pass
|
||||
trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
|
||||
self.open_trade(trade1)
|
||||
|
||||
def test_ams_max_risk_violated(self):
|
||||
pass
|
||||
self.account.initial_balance = self.account.client.get_balance() * 2
|
||||
self.account.save()
|
||||
|
||||
self.ams.run_checks()
|
||||
|
||||
expected = {
|
||||
"close": [
|
||||
{
|
||||
"id": None,
|
||||
"check": "max_loss",
|
||||
"extra": {},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
self.assertEqual(self.ams.actions, expected)
|
||||
|
||||
self.ams.execute_actions()
|
||||
|
||||
trades = self.account.client.get_all_open_trades()
|
||||
self.assertEqual(len(trades), 0)
|
||||
|
||||
@patch(
|
||||
"core.trading.active_management.ActiveManagement.check_position_size",
|
||||
return_value=None,
|
||||
)
|
||||
def test_ams_max_risk_violated(self, check_position_size):
|
||||
self.strategy.risk_model.max_risk_percent = 0.001
|
||||
self.strategy.risk_model.save()
|
||||
|
||||
trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
|
||||
self.open_trade(trade1)
|
||||
|
||||
trade2 = self.create_complex_trade("buy", 1000, "EUR_USD", 1.5, 1.0)
|
||||
self.open_trade(trade2)
|
||||
|
||||
trades = self.account.client.get_all_open_trades()
|
||||
self.assertEqual(len(trades), 2)
|
||||
|
||||
self.ams.run_checks()
|
||||
|
||||
expected = {
|
||||
"close": [
|
||||
{
|
||||
"id": trade2.order_id,
|
||||
"check": "max_risk",
|
||||
"extra": {},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
self.assertEqual(self.ams.actions, expected)
|
||||
|
||||
self.ams.execute_actions()
|
||||
|
||||
trades = self.account.client.get_all_open_trades()
|
||||
self.assertEqual(len(trades), 1)
|
||||
|
||||
trade_ids = [trade["id"] for trade in trades]
|
||||
self.assertIn(trade1.order_id, trade_ids)
|
||||
self.assertNotIn(trade2.order_id, trade_ids)
|
||||
|
||||
self.close_trade(trade1)
|
||||
|
||||
|
||||
class LiveTradingTestCase(
|
||||
|
@ -290,6 +511,9 @@ class LiveTradingTestCase(
|
|||
# Check the opened trade
|
||||
self.assertEqual(posted["type"], "MARKET_ORDER")
|
||||
self.assertEqual(posted["symbol"], trade.symbol)
|
||||
if trade.direction == "sell":
|
||||
self.assertEqual(posted["units"], str(0 - trade.amount))
|
||||
else:
|
||||
self.assertEqual(posted["units"], str(trade.amount))
|
||||
self.assertEqual(posted["timeInForce"], "FOK")
|
||||
|
||||
|
|
|
@ -79,7 +79,11 @@ class ActiveManagement(object):
|
|||
|
||||
def bulk_close_trades(self, trade_ids):
|
||||
for trade_id in trade_ids:
|
||||
if trade_id is not None:
|
||||
self.close_trade(trade_id)
|
||||
else:
|
||||
self.strategy.account.client.close_all_positions()
|
||||
return
|
||||
|
||||
def bulk_notify(self, action, action_cast_list):
|
||||
msg = ""
|
||||
|
@ -461,6 +465,7 @@ class ActiveManagement(object):
|
|||
converted_trades = convert_trades(self.get_trades())
|
||||
trades_copy = self.get_sorted_trades_copy(converted_trades, reverse=False)
|
||||
market.convert_trades_to_usd(self.strategy.account, trades_copy)
|
||||
|
||||
for trade in reversed(trades_copy):
|
||||
try:
|
||||
self.check_trading_time(trade)
|
||||
|
|
|
@ -46,6 +46,7 @@ def within_callback_price_deviation(strategy, price, current_price):
|
|||
|
||||
def within_trends(strategy, symbol, direction):
|
||||
if strategy.trend_signals.exists():
|
||||
if len(strategy.trend_signals.all()) > 0:
|
||||
if strategy.trends is None:
|
||||
log.debug("Refusing to trade with no trend signals received")
|
||||
sendmsg(
|
||||
|
@ -74,6 +75,5 @@ def within_trends(strategy, symbol, direction):
|
|||
else:
|
||||
log.debug(f"Trend check passed for {symbol} - {direction}")
|
||||
return True
|
||||
else:
|
||||
log.debug("No trend signals configured")
|
||||
return True
|
||||
|
|
Loading…
Reference in New Issue