Finish AMS tests
This commit is contained in:
parent
ed63085e10
commit
9c537187f0
|
@ -7,6 +7,8 @@ from django.test import TestCase
|
||||||
from core.exchanges.convert import convert_trades
|
from core.exchanges.convert import convert_trades
|
||||||
from core.models import (
|
from core.models import (
|
||||||
ActiveManagementPolicy,
|
ActiveManagementPolicy,
|
||||||
|
AssetGroup,
|
||||||
|
AssetRule,
|
||||||
Hook,
|
Hook,
|
||||||
RiskModel,
|
RiskModel,
|
||||||
Signal,
|
Signal,
|
||||||
|
@ -128,6 +130,10 @@ class ActiveManagementMixinTestCase(StrategyMixin):
|
||||||
# Don't need to close the trade, it's already closed.
|
# Don't need to close the trade, it's already closed.
|
||||||
# Otherwise the test would fail.
|
# Otherwise the test would fail.
|
||||||
|
|
||||||
|
self.strategy.trend_signals.set([])
|
||||||
|
self.strategy.trends = {}
|
||||||
|
self.strategy.save()
|
||||||
|
|
||||||
def test_ams_position_size_violated(self):
|
def test_ams_position_size_violated(self):
|
||||||
self.active_management_policy.when_position_size_violated = "close"
|
self.active_management_policy.when_position_size_violated = "close"
|
||||||
self.active_management_policy.save()
|
self.active_management_policy.save()
|
||||||
|
@ -177,12 +183,8 @@ class ActiveManagementMixinTestCase(StrategyMixin):
|
||||||
def test_ams_protection_violated(self):
|
def test_ams_protection_violated(self):
|
||||||
self.active_management_policy.when_protection_violated = "close"
|
self.active_management_policy.when_protection_violated = "close"
|
||||||
self.active_management_policy.save()
|
self.active_management_policy.save()
|
||||||
# Don't violate position size check
|
|
||||||
trade_size = market.get_trade_size_in_base(
|
complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 5, 5)
|
||||||
"buy", self.account, self.strategy, self.account.client.get_balance(), "EUR"
|
|
||||||
)
|
|
||||||
trade_size = round(trade_size, 0)
|
|
||||||
complex_trade = self.create_complex_trade("buy", trade_size, "EUR_USD", 5, 5)
|
|
||||||
self.open_trade(complex_trade)
|
self.open_trade(complex_trade)
|
||||||
|
|
||||||
self.ams.run_checks()
|
self.ams.run_checks()
|
||||||
|
@ -206,12 +208,7 @@ class ActiveManagementMixinTestCase(StrategyMixin):
|
||||||
self.assertEqual(len(trades), 0)
|
self.assertEqual(len(trades), 0)
|
||||||
|
|
||||||
def test_ams_protection_violated_adjust(self):
|
def test_ams_protection_violated_adjust(self):
|
||||||
# Don't violate position size check
|
complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 5, 5)
|
||||||
trade_size = market.get_trade_size_in_base(
|
|
||||||
"buy", self.account, self.strategy, self.account.client.get_balance(), "EUR"
|
|
||||||
)
|
|
||||||
trade_size = round(trade_size, 0)
|
|
||||||
complex_trade = self.create_complex_trade("buy", trade_size, "EUR_USD", 5, 5)
|
|
||||||
self.open_trade(complex_trade)
|
self.open_trade(complex_trade)
|
||||||
|
|
||||||
self.ams.run_checks()
|
self.ams.run_checks()
|
||||||
|
@ -231,22 +228,246 @@ class ActiveManagementMixinTestCase(StrategyMixin):
|
||||||
self.close_trade(complex_trade)
|
self.close_trade(complex_trade)
|
||||||
|
|
||||||
def test_ams_asset_groups_violated(self):
|
def test_ams_asset_groups_violated(self):
|
||||||
pass
|
asset_group = AssetGroup.objects.create(
|
||||||
|
user=self.user,
|
||||||
|
name="Test Asset Group",
|
||||||
|
)
|
||||||
|
AssetRule.objects.create(
|
||||||
|
user=self.user,
|
||||||
|
asset="USD",
|
||||||
|
group=asset_group,
|
||||||
|
status=2, # Bullish
|
||||||
|
)
|
||||||
|
self.strategy.asset_group = asset_group
|
||||||
|
self.strategy.save()
|
||||||
|
|
||||||
|
complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
|
||||||
|
self.open_trade(complex_trade)
|
||||||
|
|
||||||
|
self.ams.run_checks()
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"close": [
|
||||||
|
{"id": complex_trade.order_id, "check": "asset_group", "extra": {}}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
self.assertEqual(self.ams.actions, expected)
|
||||||
|
|
||||||
|
self.ams.execute_actions()
|
||||||
|
|
||||||
|
trades = self.account.client.get_all_open_trades()
|
||||||
|
self.assertEqual(len(trades), 0)
|
||||||
|
|
||||||
def test_ams_crossfilter_violated(self):
|
def test_ams_crossfilter_violated(self):
|
||||||
pass
|
complex_trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
|
||||||
|
self.open_trade(complex_trade1)
|
||||||
|
|
||||||
def test_ams_max_open_trades_violated(self):
|
complex_trade2 = self.create_complex_trade("buy", 10, "USD_JPY", 1.5, 1.0)
|
||||||
pass
|
self.open_trade(complex_trade2)
|
||||||
|
|
||||||
def test_ams_max_open_trades_per_symbol_violated(self):
|
trades = self.account.client.get_all_open_trades()
|
||||||
pass
|
self.assertEqual(len(trades), 2)
|
||||||
|
|
||||||
|
self.ams.run_checks()
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"close": [
|
||||||
|
{
|
||||||
|
"id": complex_trade2.order_id, # Only the second one
|
||||||
|
"check": "crossfilter",
|
||||||
|
"extra": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
self.assertEqual(self.ams.actions, expected)
|
||||||
|
|
||||||
|
self.ams.execute_actions()
|
||||||
|
|
||||||
|
trades = self.account.client.get_all_open_trades()
|
||||||
|
self.assertEqual(len(trades), 1)
|
||||||
|
|
||||||
|
self.close_trade(complex_trade1)
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.trading.active_management.ActiveManagement.check_trends",
|
||||||
|
return_value=None,
|
||||||
|
)
|
||||||
|
@patch(
|
||||||
|
"core.trading.active_management.ActiveManagement.check_crossfilter",
|
||||||
|
return_value=None,
|
||||||
|
)
|
||||||
|
def test_ams_max_open_trades_violated(self, check_crossfilter, check_trends):
|
||||||
|
self.strategy.risk_model.max_open_trades = 2
|
||||||
|
self.strategy.risk_model.save()
|
||||||
|
trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
|
||||||
|
self.open_trade(trade1)
|
||||||
|
|
||||||
|
trade2 = self.create_complex_trade("buy", 10, "USD_JPY", 1.5, 1.0)
|
||||||
|
self.open_trade(trade2)
|
||||||
|
|
||||||
|
trade3 = self.create_complex_trade("buy", 10, "EUR_JPY", 1.5, 1.0)
|
||||||
|
self.open_trade(trade3)
|
||||||
|
|
||||||
|
trades = self.account.client.get_all_open_trades()
|
||||||
|
self.assertEqual(len(trades), 3)
|
||||||
|
|
||||||
|
self.ams.run_checks()
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"close": [
|
||||||
|
{
|
||||||
|
"id": trade3.order_id,
|
||||||
|
"check": "max_open_trades",
|
||||||
|
"extra": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
self.assertEqual(self.ams.actions, expected)
|
||||||
|
|
||||||
|
self.ams.execute_actions()
|
||||||
|
|
||||||
|
trades = self.account.client.get_all_open_trades()
|
||||||
|
self.assertEqual(len(trades), 2)
|
||||||
|
|
||||||
|
trade_ids = [trade["id"] for trade in trades]
|
||||||
|
self.assertIn(trade1.order_id, trade_ids)
|
||||||
|
self.assertIn(trade2.order_id, trade_ids)
|
||||||
|
self.assertNotIn(trade3.order_id, trade_ids)
|
||||||
|
|
||||||
|
for x in [trade1, trade2]:
|
||||||
|
self.close_trade(x)
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.trading.active_management.ActiveManagement.check_trends",
|
||||||
|
return_value=None,
|
||||||
|
)
|
||||||
|
def test_ams_max_open_trades_per_symbol_violated(self, check_trends):
|
||||||
|
self.strategy.risk_model.max_open_trades_per_symbol = 2
|
||||||
|
self.strategy.risk_model.save()
|
||||||
|
|
||||||
|
trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
|
||||||
|
self.open_trade(trade1)
|
||||||
|
|
||||||
|
trade2 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
|
||||||
|
self.open_trade(trade2)
|
||||||
|
|
||||||
|
trade3 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
|
||||||
|
self.open_trade(trade3)
|
||||||
|
|
||||||
|
trade4 = self.create_complex_trade("buy", 10, "EUR_JPY", 1.5, 1.0)
|
||||||
|
self.open_trade(trade4)
|
||||||
|
|
||||||
|
trade5 = self.create_complex_trade("buy", 10, "EUR_JPY", 1.5, 1.0)
|
||||||
|
self.open_trade(trade5)
|
||||||
|
|
||||||
|
trade6 = self.create_complex_trade("buy", 10, "EUR_JPY", 1.5, 1.0)
|
||||||
|
self.open_trade(trade6)
|
||||||
|
|
||||||
|
trades = self.account.client.get_all_open_trades()
|
||||||
|
self.assertEqual(len(trades), 6)
|
||||||
|
|
||||||
|
self.ams.run_checks()
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"close": [
|
||||||
|
{
|
||||||
|
"id": trade3.order_id,
|
||||||
|
"check": "max_open_trades_per_symbol",
|
||||||
|
"extra": {},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": trade6.order_id,
|
||||||
|
"check": "max_open_trades_per_symbol",
|
||||||
|
"extra": {},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
self.assertEqual(self.ams.actions, expected)
|
||||||
|
|
||||||
|
self.ams.execute_actions()
|
||||||
|
|
||||||
|
trades = self.account.client.get_all_open_trades()
|
||||||
|
self.assertEqual(len(trades), 4)
|
||||||
|
|
||||||
|
trade_ids = [trade["id"] for trade in trades]
|
||||||
|
self.assertIn(trade1.order_id, trade_ids)
|
||||||
|
self.assertIn(trade2.order_id, trade_ids)
|
||||||
|
self.assertNotIn(trade3.order_id, trade_ids)
|
||||||
|
|
||||||
|
self.assertIn(trade4.order_id, trade_ids)
|
||||||
|
self.assertIn(trade5.order_id, trade_ids)
|
||||||
|
self.assertNotIn(trade6.order_id, trade_ids)
|
||||||
|
|
||||||
def test_ams_max_loss_violated(self):
|
def test_ams_max_loss_violated(self):
|
||||||
pass
|
trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
|
||||||
|
self.open_trade(trade1)
|
||||||
|
|
||||||
def test_ams_max_risk_violated(self):
|
self.account.initial_balance = self.account.client.get_balance() * 2
|
||||||
pass
|
self.account.save()
|
||||||
|
|
||||||
|
self.ams.run_checks()
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"close": [
|
||||||
|
{
|
||||||
|
"id": None,
|
||||||
|
"check": "max_loss",
|
||||||
|
"extra": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
self.assertEqual(self.ams.actions, expected)
|
||||||
|
|
||||||
|
self.ams.execute_actions()
|
||||||
|
|
||||||
|
trades = self.account.client.get_all_open_trades()
|
||||||
|
self.assertEqual(len(trades), 0)
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"core.trading.active_management.ActiveManagement.check_position_size",
|
||||||
|
return_value=None,
|
||||||
|
)
|
||||||
|
def test_ams_max_risk_violated(self, check_position_size):
|
||||||
|
self.strategy.risk_model.max_risk_percent = 0.001
|
||||||
|
self.strategy.risk_model.save()
|
||||||
|
|
||||||
|
trade1 = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
|
||||||
|
self.open_trade(trade1)
|
||||||
|
|
||||||
|
trade2 = self.create_complex_trade("buy", 1000, "EUR_USD", 1.5, 1.0)
|
||||||
|
self.open_trade(trade2)
|
||||||
|
|
||||||
|
trades = self.account.client.get_all_open_trades()
|
||||||
|
self.assertEqual(len(trades), 2)
|
||||||
|
|
||||||
|
self.ams.run_checks()
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"close": [
|
||||||
|
{
|
||||||
|
"id": trade2.order_id,
|
||||||
|
"check": "max_risk",
|
||||||
|
"extra": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
self.assertEqual(self.ams.actions, expected)
|
||||||
|
|
||||||
|
self.ams.execute_actions()
|
||||||
|
|
||||||
|
trades = self.account.client.get_all_open_trades()
|
||||||
|
self.assertEqual(len(trades), 1)
|
||||||
|
|
||||||
|
trade_ids = [trade["id"] for trade in trades]
|
||||||
|
self.assertIn(trade1.order_id, trade_ids)
|
||||||
|
self.assertNotIn(trade2.order_id, trade_ids)
|
||||||
|
|
||||||
|
self.close_trade(trade1)
|
||||||
|
|
||||||
|
|
||||||
class LiveTradingTestCase(
|
class LiveTradingTestCase(
|
||||||
|
@ -290,6 +511,9 @@ class LiveTradingTestCase(
|
||||||
# Check the opened trade
|
# Check the opened trade
|
||||||
self.assertEqual(posted["type"], "MARKET_ORDER")
|
self.assertEqual(posted["type"], "MARKET_ORDER")
|
||||||
self.assertEqual(posted["symbol"], trade.symbol)
|
self.assertEqual(posted["symbol"], trade.symbol)
|
||||||
|
if trade.direction == "sell":
|
||||||
|
self.assertEqual(posted["units"], str(0 - trade.amount))
|
||||||
|
else:
|
||||||
self.assertEqual(posted["units"], str(trade.amount))
|
self.assertEqual(posted["units"], str(trade.amount))
|
||||||
self.assertEqual(posted["timeInForce"], "FOK")
|
self.assertEqual(posted["timeInForce"], "FOK")
|
||||||
|
|
||||||
|
|
|
@ -79,7 +79,11 @@ class ActiveManagement(object):
|
||||||
|
|
||||||
def bulk_close_trades(self, trade_ids):
|
def bulk_close_trades(self, trade_ids):
|
||||||
for trade_id in trade_ids:
|
for trade_id in trade_ids:
|
||||||
|
if trade_id is not None:
|
||||||
self.close_trade(trade_id)
|
self.close_trade(trade_id)
|
||||||
|
else:
|
||||||
|
self.strategy.account.client.close_all_positions()
|
||||||
|
return
|
||||||
|
|
||||||
def bulk_notify(self, action, action_cast_list):
|
def bulk_notify(self, action, action_cast_list):
|
||||||
msg = ""
|
msg = ""
|
||||||
|
@ -461,6 +465,7 @@ class ActiveManagement(object):
|
||||||
converted_trades = convert_trades(self.get_trades())
|
converted_trades = convert_trades(self.get_trades())
|
||||||
trades_copy = self.get_sorted_trades_copy(converted_trades, reverse=False)
|
trades_copy = self.get_sorted_trades_copy(converted_trades, reverse=False)
|
||||||
market.convert_trades_to_usd(self.strategy.account, trades_copy)
|
market.convert_trades_to_usd(self.strategy.account, trades_copy)
|
||||||
|
|
||||||
for trade in reversed(trades_copy):
|
for trade in reversed(trades_copy):
|
||||||
try:
|
try:
|
||||||
self.check_trading_time(trade)
|
self.check_trading_time(trade)
|
||||||
|
|
|
@ -46,6 +46,7 @@ def within_callback_price_deviation(strategy, price, current_price):
|
||||||
|
|
||||||
def within_trends(strategy, symbol, direction):
|
def within_trends(strategy, symbol, direction):
|
||||||
if strategy.trend_signals.exists():
|
if strategy.trend_signals.exists():
|
||||||
|
if len(strategy.trend_signals.all()) > 0:
|
||||||
if strategy.trends is None:
|
if strategy.trends is None:
|
||||||
log.debug("Refusing to trade with no trend signals received")
|
log.debug("Refusing to trade with no trend signals received")
|
||||||
sendmsg(
|
sendmsg(
|
||||||
|
@ -74,6 +75,5 @@ def within_trends(strategy, symbol, direction):
|
||||||
else:
|
else:
|
||||||
log.debug(f"Trend check passed for {symbol} - {direction}")
|
log.debug(f"Trend check passed for {symbol} - {direction}")
|
||||||
return True
|
return True
|
||||||
else:
|
|
||||||
log.debug("No trend signals configured")
|
log.debug("No trend signals configured")
|
||||||
return True
|
return True
|
||||||
|
|
Loading…
Reference in New Issue