Implement adjusting positions and begin writing live tests for AMS

This commit is contained in:
Mark Veidemanis 2023-02-20 07:20:03 +00:00
parent a840be3834
commit 9e22abe057
Signed by: m
GPG Key ID: 5ACFCEED46C0904F
5 changed files with 249 additions and 9 deletions

View File

@ -205,7 +205,7 @@ class BaseExchange(ABC):
pass pass
@abstractmethod @abstractmethod
def close_trade(self, trade_id): def close_trade(self, trade_id, units=None):
pass pass
@abstractmethod @abstractmethod

View File

@ -121,7 +121,7 @@ class AlpacaExchange(BaseExchange):
trade.save() trade.save()
return order return order
def close_trade(self, trade_id): # TODO def close_trade(self, trade_id, units=None): # TODO
""" """
Close a trade Close a trade
""" """

View File

@ -2,6 +2,9 @@ from oandapyV20 import API
from oandapyV20.endpoints import accounts, orders, positions, pricing, trades from oandapyV20.endpoints import accounts, orders, positions, pricing, trades
from core.exchanges import BaseExchange, common from core.exchanges import BaseExchange, common
from core.util import logs
log = logs.get_logger("oanda")
class OANDAExchange(BaseExchange): class OANDAExchange(BaseExchange):
@ -98,12 +101,44 @@ class OANDAExchange(BaseExchange):
trade.save() trade.save()
return response return response
def close_trade(self, trade_id): def get_trade_precision(self, symbol):
instruments = self.account.instruments
if not instruments:
log.error("No instruments found")
return None
# Extract the information for the symbol
instrument = self.extract_instrument(instruments, symbol)
if not instrument:
log.error(f"Symbol not found: {symbol}")
return None
# Get the required precision
try:
trade_precision = instrument["tradeUnitsPrecision"]
return trade_precision
except KeyError:
log.error(f"Precision not found for {symbol} from {instrument}")
return None
def close_trade(self, trade_id, units=None, symbol=None):
""" """
Close a trade. Close a trade.
""" """
if not units:
r = trades.TradeClose(accountID=self.account_id, tradeID=trade_id) r = trades.TradeClose(accountID=self.account_id, tradeID=trade_id)
return self.call(r) return self.call(r)
else:
trade_precision = self.get_trade_precision(symbol)
if trade_precision is None:
log.error(f"Unable to get trade precision for {symbol}")
return None
units = round(units, trade_precision)
data = {
"units": str(units),
}
r = trades.TradeClose(
accountID=self.account_id, tradeID=trade_id, data=data
)
return self.call(r)
def get_trade(self, trade_id): def get_trade(self, trade_id):
# OANDA is off by one... # OANDA is off by one...

View File

@ -1,15 +1,204 @@
from datetime import time
from decimal import Decimal as D from decimal import Decimal as D
from unittest.mock import patch from unittest.mock import patch
from django.test import TestCase from django.test import TestCase
from core.exchanges.convert import convert_trades from core.exchanges.convert import convert_trades
from core.models import RiskModel, Trade from core.models import (
from core.tests.helpers import ElasticMock, LiveBase ActiveManagementPolicy,
Hook,
RiskModel,
Signal,
Trade,
TradingTime,
)
from core.tests.helpers import ElasticMock, LiveBase, StrategyMixin
from core.trading import market, risk from core.trading import market, risk
from core.trading.active_management import ActiveManagement
class LiveTradingTestCase(ElasticMock, LiveBase, TestCase): class ActiveManagementMixinTestCase(StrategyMixin):
def setUp(self):
super().setUp()
self.trading_time_all = TradingTime.objects.create(
user=self.user,
name="All",
start_day=1,
start_time=time(0, 0, 0),
end_day=7,
end_time=time(23, 59, 59),
)
self.active_management_policy = ActiveManagementPolicy.objects.create(
user=self.user,
name="Test Policy",
when_trading_time_violated="close",
when_trends_violated="close",
when_position_size_violated="adjust",
when_protection_violated="adjust",
when_asset_groups_violated="close",
when_max_open_trades_violated="close",
when_max_open_trades_per_symbol_violated="close",
when_max_loss_violated="close",
when_max_risk_violated="close",
when_crossfilter_violated="close",
)
self.strategy.active_management_policy = self.active_management_policy
self.strategy.trading_times.set([self.trading_time_all])
self.strategy.save()
self.ams = ActiveManagement(self.strategy)
def test_ams_success(self):
complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
self.open_trade(complex_trade)
self.ams.run_checks()
self.assertEqual(self.ams.actions, {})
self.close_trade(complex_trade)
def test_ams_trading_time_violated(self):
# Forex market is closed on Saturday and Sunday.
# All of these tests will fail then anyway, so it's the only time
# we can use for this.
trading_time_weekend = TradingTime.objects.create(
user=self.user,
name="Weekend",
start_day=6,
start_time=time(0, 0, 0),
end_day=7,
end_time=time(23, 59, 59),
)
self.strategy.trading_times.set([trading_time_weekend])
self.strategy.save()
complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
self.open_trade(complex_trade)
self.ams.run_checks()
expected_actions = {
"close": [
{"id": complex_trade.order_id, "check": "trading_time", "extra": {}}
]
}
self.assertEqual(self.ams.actions, expected_actions)
trades = self.account.client.get_all_open_trades()
self.assertEqual(len(trades), 1)
self.ams.execute_actions()
trades = self.account.client.get_all_open_trades()
self.assertEqual(len(trades), 0)
# Don't need to close the trade, it's already closed.
# Otherwise the test would fail.
def test_ams_trends_violated(self):
hook = Hook.objects.create(
user=self.user,
name="Test Hook",
)
signal = Signal.objects.create(
user=self.user,
name="Test Signal",
hook=hook,
type="trend",
)
self.strategy.trend_signals.set([signal])
self.strategy.trends = {"EUR_USD": "sell"}
self.strategy.save()
complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
self.open_trade(complex_trade)
self.ams.run_checks()
expected_actions = {
"close": [{"id": complex_trade.order_id, "check": "trends", "extra": {}}]
}
self.assertEqual(self.ams.actions, expected_actions)
trades = self.account.client.get_all_open_trades()
self.assertEqual(len(trades), 1)
self.ams.execute_actions()
trades = self.account.client.get_all_open_trades()
self.assertEqual(len(trades), 0)
# Don't need to close the trade, it's already closed.
# Otherwise the test would fail.
def test_ams_position_size_violated(self):
self.active_management_policy.when_position_size_violated = "close"
self.active_management_policy.save()
complex_trade = self.create_complex_trade("buy", 600, "EUR_USD", 1.5, 1.0)
self.open_trade(complex_trade)
self.ams.run_checks()
expected_actions = {
"close": [
{"id": complex_trade.order_id, "check": "position_size", "extra": {}}
]
}
self.assertEqual(len(self.ams.actions["close"]), 1)
del self.ams.actions["close"][0]["extra"]["size"] # We don't know the size
self.assertEqual(self.ams.actions, expected_actions)
trades = self.account.client.get_all_open_trades()
self.assertEqual(len(trades), 1)
self.ams.execute_actions()
trades = self.account.client.get_all_open_trades()
self.assertEqual(len(trades), 0)
def test_ams_position_size_violated_adjust(self):
complex_trade = self.create_complex_trade("buy", 600, "EUR_USD", 1.5, 1.0)
self.open_trade(complex_trade)
self.ams.run_checks()
# Convert to int to make a whole number
expected = round(self.ams.actions["adjust"][0]["extra"]["size"], 0)
trades = self.account.client.get_all_open_trades()
self.assertEqual(len(trades), 1)
self.assertEqual(trades[0]["currentUnits"], "600")
self.ams.execute_actions()
trades = self.account.client.get_all_open_trades()
self.assertEqual(len(trades), 1)
self.assertEqual(trades[0]["currentUnits"], str(expected))
complex_trade.amount = expected
complex_trade.save()
self.close_trade(complex_trade)
def test_ams_protection_violated(self):
pass
def test_ams_asset_groups_violated(self):
pass
def test_ams_crossfilter_violated(self):
pass
def test_ams_max_open_trades_violated(self):
pass
def test_ams_max_open_trades_per_symbol_violated(self):
pass
def test_ams_max_loss_violated(self):
pass
def test_ams_max_risk_violated(self):
pass
class LiveTradingTestCase(
ElasticMock, ActiveManagementMixinTestCase, LiveBase, TestCase
):
def setUp(self): def setUp(self):
super(LiveTradingTestCase, self).setUp() super(LiveTradingTestCase, self).setUp()
self.trade = Trade.objects.create( self.trade = Trade.objects.create(

View File

@ -95,7 +95,23 @@ class ActiveManagement(object):
sendmsg(self.strategy.user, msg, title=f"AMS: {action}") sendmsg(self.strategy.user, msg, title=f"AMS: {action}")
def adjust_position_size(self, trade_id, new_size): def adjust_position_size(self, trade_id, new_size):
pass # TODO # Get old size
old_size = None
for trade in self.trades:
if trade["id"] == trade_id:
old_size = D(trade["currentUnits"])
symbol = trade["symbol"]
break
if old_size is None:
log.error(f"Could not find trade ID {trade_id} in active management")
return
# Reduce only
assert old_size > new_size
difference = old_size - new_size
# Close the difference
self.strategy.account.client.close_trade(trade_id, difference, symbol)
def adjust_protection(self, trade_id, new_protection): def adjust_protection(self, trade_id, new_protection):
pass # TODO pass # TODO