Implement adjusting positions and begin writing live tests for AMS
This commit is contained in:
parent
a840be3834
commit
9e22abe057
|
@ -205,7 +205,7 @@ class BaseExchange(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def close_trade(self, trade_id):
|
def close_trade(self, trade_id, units=None):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
|
@ -121,7 +121,7 @@ class AlpacaExchange(BaseExchange):
|
||||||
trade.save()
|
trade.save()
|
||||||
return order
|
return order
|
||||||
|
|
||||||
def close_trade(self, trade_id): # TODO
|
def close_trade(self, trade_id, units=None): # TODO
|
||||||
"""
|
"""
|
||||||
Close a trade
|
Close a trade
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -2,6 +2,9 @@ from oandapyV20 import API
|
||||||
from oandapyV20.endpoints import accounts, orders, positions, pricing, trades
|
from oandapyV20.endpoints import accounts, orders, positions, pricing, trades
|
||||||
|
|
||||||
from core.exchanges import BaseExchange, common
|
from core.exchanges import BaseExchange, common
|
||||||
|
from core.util import logs
|
||||||
|
|
||||||
|
log = logs.get_logger("oanda")
|
||||||
|
|
||||||
|
|
||||||
class OANDAExchange(BaseExchange):
|
class OANDAExchange(BaseExchange):
|
||||||
|
@ -98,12 +101,44 @@ class OANDAExchange(BaseExchange):
|
||||||
trade.save()
|
trade.save()
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def close_trade(self, trade_id):
|
def get_trade_precision(self, symbol):
|
||||||
|
instruments = self.account.instruments
|
||||||
|
if not instruments:
|
||||||
|
log.error("No instruments found")
|
||||||
|
return None
|
||||||
|
# Extract the information for the symbol
|
||||||
|
instrument = self.extract_instrument(instruments, symbol)
|
||||||
|
if not instrument:
|
||||||
|
log.error(f"Symbol not found: {symbol}")
|
||||||
|
return None
|
||||||
|
# Get the required precision
|
||||||
|
try:
|
||||||
|
trade_precision = instrument["tradeUnitsPrecision"]
|
||||||
|
return trade_precision
|
||||||
|
except KeyError:
|
||||||
|
log.error(f"Precision not found for {symbol} from {instrument}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def close_trade(self, trade_id, units=None, symbol=None):
|
||||||
"""
|
"""
|
||||||
Close a trade.
|
Close a trade.
|
||||||
"""
|
"""
|
||||||
r = trades.TradeClose(accountID=self.account_id, tradeID=trade_id)
|
if not units:
|
||||||
return self.call(r)
|
r = trades.TradeClose(accountID=self.account_id, tradeID=trade_id)
|
||||||
|
return self.call(r)
|
||||||
|
else:
|
||||||
|
trade_precision = self.get_trade_precision(symbol)
|
||||||
|
if trade_precision is None:
|
||||||
|
log.error(f"Unable to get trade precision for {symbol}")
|
||||||
|
return None
|
||||||
|
units = round(units, trade_precision)
|
||||||
|
data = {
|
||||||
|
"units": str(units),
|
||||||
|
}
|
||||||
|
r = trades.TradeClose(
|
||||||
|
accountID=self.account_id, tradeID=trade_id, data=data
|
||||||
|
)
|
||||||
|
return self.call(r)
|
||||||
|
|
||||||
def get_trade(self, trade_id):
|
def get_trade(self, trade_id):
|
||||||
# OANDA is off by one...
|
# OANDA is off by one...
|
||||||
|
|
|
@ -1,15 +1,204 @@
|
||||||
|
from datetime import time
|
||||||
from decimal import Decimal as D
|
from decimal import Decimal as D
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from core.exchanges.convert import convert_trades
|
from core.exchanges.convert import convert_trades
|
||||||
from core.models import RiskModel, Trade
|
from core.models import (
|
||||||
from core.tests.helpers import ElasticMock, LiveBase
|
ActiveManagementPolicy,
|
||||||
|
Hook,
|
||||||
|
RiskModel,
|
||||||
|
Signal,
|
||||||
|
Trade,
|
||||||
|
TradingTime,
|
||||||
|
)
|
||||||
|
from core.tests.helpers import ElasticMock, LiveBase, StrategyMixin
|
||||||
from core.trading import market, risk
|
from core.trading import market, risk
|
||||||
|
from core.trading.active_management import ActiveManagement
|
||||||
|
|
||||||
|
|
||||||
class LiveTradingTestCase(ElasticMock, LiveBase, TestCase):
|
class ActiveManagementMixinTestCase(StrategyMixin):
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self.trading_time_all = TradingTime.objects.create(
|
||||||
|
user=self.user,
|
||||||
|
name="All",
|
||||||
|
start_day=1,
|
||||||
|
start_time=time(0, 0, 0),
|
||||||
|
end_day=7,
|
||||||
|
end_time=time(23, 59, 59),
|
||||||
|
)
|
||||||
|
self.active_management_policy = ActiveManagementPolicy.objects.create(
|
||||||
|
user=self.user,
|
||||||
|
name="Test Policy",
|
||||||
|
when_trading_time_violated="close",
|
||||||
|
when_trends_violated="close",
|
||||||
|
when_position_size_violated="adjust",
|
||||||
|
when_protection_violated="adjust",
|
||||||
|
when_asset_groups_violated="close",
|
||||||
|
when_max_open_trades_violated="close",
|
||||||
|
when_max_open_trades_per_symbol_violated="close",
|
||||||
|
when_max_loss_violated="close",
|
||||||
|
when_max_risk_violated="close",
|
||||||
|
when_crossfilter_violated="close",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.strategy.active_management_policy = self.active_management_policy
|
||||||
|
self.strategy.trading_times.set([self.trading_time_all])
|
||||||
|
self.strategy.save()
|
||||||
|
self.ams = ActiveManagement(self.strategy)
|
||||||
|
|
||||||
|
def test_ams_success(self):
|
||||||
|
complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
|
||||||
|
self.open_trade(complex_trade)
|
||||||
|
|
||||||
|
self.ams.run_checks()
|
||||||
|
self.assertEqual(self.ams.actions, {})
|
||||||
|
|
||||||
|
self.close_trade(complex_trade)
|
||||||
|
|
||||||
|
def test_ams_trading_time_violated(self):
|
||||||
|
# Forex market is closed on Saturday and Sunday.
|
||||||
|
# All of these tests will fail then anyway, so it's the only time
|
||||||
|
# we can use for this.
|
||||||
|
trading_time_weekend = TradingTime.objects.create(
|
||||||
|
user=self.user,
|
||||||
|
name="Weekend",
|
||||||
|
start_day=6,
|
||||||
|
start_time=time(0, 0, 0),
|
||||||
|
end_day=7,
|
||||||
|
end_time=time(23, 59, 59),
|
||||||
|
)
|
||||||
|
self.strategy.trading_times.set([trading_time_weekend])
|
||||||
|
self.strategy.save()
|
||||||
|
|
||||||
|
complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
|
||||||
|
self.open_trade(complex_trade)
|
||||||
|
|
||||||
|
self.ams.run_checks()
|
||||||
|
expected_actions = {
|
||||||
|
"close": [
|
||||||
|
{"id": complex_trade.order_id, "check": "trading_time", "extra": {}}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
self.assertEqual(self.ams.actions, expected_actions)
|
||||||
|
|
||||||
|
trades = self.account.client.get_all_open_trades()
|
||||||
|
self.assertEqual(len(trades), 1)
|
||||||
|
|
||||||
|
self.ams.execute_actions()
|
||||||
|
trades = self.account.client.get_all_open_trades()
|
||||||
|
self.assertEqual(len(trades), 0)
|
||||||
|
|
||||||
|
# Don't need to close the trade, it's already closed.
|
||||||
|
# Otherwise the test would fail.
|
||||||
|
|
||||||
|
def test_ams_trends_violated(self):
|
||||||
|
hook = Hook.objects.create(
|
||||||
|
user=self.user,
|
||||||
|
name="Test Hook",
|
||||||
|
)
|
||||||
|
signal = Signal.objects.create(
|
||||||
|
user=self.user,
|
||||||
|
name="Test Signal",
|
||||||
|
hook=hook,
|
||||||
|
type="trend",
|
||||||
|
)
|
||||||
|
self.strategy.trend_signals.set([signal])
|
||||||
|
self.strategy.trends = {"EUR_USD": "sell"}
|
||||||
|
self.strategy.save()
|
||||||
|
|
||||||
|
complex_trade = self.create_complex_trade("buy", 10, "EUR_USD", 1.5, 1.0)
|
||||||
|
self.open_trade(complex_trade)
|
||||||
|
|
||||||
|
self.ams.run_checks()
|
||||||
|
expected_actions = {
|
||||||
|
"close": [{"id": complex_trade.order_id, "check": "trends", "extra": {}}]
|
||||||
|
}
|
||||||
|
self.assertEqual(self.ams.actions, expected_actions)
|
||||||
|
|
||||||
|
trades = self.account.client.get_all_open_trades()
|
||||||
|
self.assertEqual(len(trades), 1)
|
||||||
|
|
||||||
|
self.ams.execute_actions()
|
||||||
|
trades = self.account.client.get_all_open_trades()
|
||||||
|
self.assertEqual(len(trades), 0)
|
||||||
|
|
||||||
|
# Don't need to close the trade, it's already closed.
|
||||||
|
# Otherwise the test would fail.
|
||||||
|
|
||||||
|
def test_ams_position_size_violated(self):
|
||||||
|
self.active_management_policy.when_position_size_violated = "close"
|
||||||
|
self.active_management_policy.save()
|
||||||
|
|
||||||
|
complex_trade = self.create_complex_trade("buy", 600, "EUR_USD", 1.5, 1.0)
|
||||||
|
self.open_trade(complex_trade)
|
||||||
|
|
||||||
|
self.ams.run_checks()
|
||||||
|
expected_actions = {
|
||||||
|
"close": [
|
||||||
|
{"id": complex_trade.order_id, "check": "position_size", "extra": {}}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
self.assertEqual(len(self.ams.actions["close"]), 1)
|
||||||
|
del self.ams.actions["close"][0]["extra"]["size"] # We don't know the size
|
||||||
|
self.assertEqual(self.ams.actions, expected_actions)
|
||||||
|
|
||||||
|
trades = self.account.client.get_all_open_trades()
|
||||||
|
self.assertEqual(len(trades), 1)
|
||||||
|
|
||||||
|
self.ams.execute_actions()
|
||||||
|
trades = self.account.client.get_all_open_trades()
|
||||||
|
self.assertEqual(len(trades), 0)
|
||||||
|
|
||||||
|
def test_ams_position_size_violated_adjust(self):
|
||||||
|
complex_trade = self.create_complex_trade("buy", 600, "EUR_USD", 1.5, 1.0)
|
||||||
|
self.open_trade(complex_trade)
|
||||||
|
|
||||||
|
self.ams.run_checks()
|
||||||
|
# Convert to int to make a whole number
|
||||||
|
expected = round(self.ams.actions["adjust"][0]["extra"]["size"], 0)
|
||||||
|
|
||||||
|
trades = self.account.client.get_all_open_trades()
|
||||||
|
self.assertEqual(len(trades), 1)
|
||||||
|
self.assertEqual(trades[0]["currentUnits"], "600")
|
||||||
|
|
||||||
|
self.ams.execute_actions()
|
||||||
|
trades = self.account.client.get_all_open_trades()
|
||||||
|
self.assertEqual(len(trades), 1)
|
||||||
|
self.assertEqual(trades[0]["currentUnits"], str(expected))
|
||||||
|
|
||||||
|
complex_trade.amount = expected
|
||||||
|
complex_trade.save()
|
||||||
|
|
||||||
|
self.close_trade(complex_trade)
|
||||||
|
|
||||||
|
def test_ams_protection_violated(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_ams_asset_groups_violated(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_ams_crossfilter_violated(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_ams_max_open_trades_violated(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_ams_max_open_trades_per_symbol_violated(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_ams_max_loss_violated(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_ams_max_risk_violated(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class LiveTradingTestCase(
|
||||||
|
ElasticMock, ActiveManagementMixinTestCase, LiveBase, TestCase
|
||||||
|
):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(LiveTradingTestCase, self).setUp()
|
super(LiveTradingTestCase, self).setUp()
|
||||||
self.trade = Trade.objects.create(
|
self.trade = Trade.objects.create(
|
||||||
|
|
|
@ -95,7 +95,23 @@ class ActiveManagement(object):
|
||||||
sendmsg(self.strategy.user, msg, title=f"AMS: {action}")
|
sendmsg(self.strategy.user, msg, title=f"AMS: {action}")
|
||||||
|
|
||||||
def adjust_position_size(self, trade_id, new_size):
|
def adjust_position_size(self, trade_id, new_size):
|
||||||
pass # TODO
|
# Get old size
|
||||||
|
old_size = None
|
||||||
|
for trade in self.trades:
|
||||||
|
if trade["id"] == trade_id:
|
||||||
|
old_size = D(trade["currentUnits"])
|
||||||
|
symbol = trade["symbol"]
|
||||||
|
break
|
||||||
|
if old_size is None:
|
||||||
|
log.error(f"Could not find trade ID {trade_id} in active management")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Reduce only
|
||||||
|
assert old_size > new_size
|
||||||
|
difference = old_size - new_size
|
||||||
|
|
||||||
|
# Close the difference
|
||||||
|
self.strategy.account.client.close_trade(trade_id, difference, symbol)
|
||||||
|
|
||||||
def adjust_protection(self, trade_id, new_protection):
|
def adjust_protection(self, trade_id, new_protection):
|
||||||
pass # TODO
|
pass # TODO
|
||||||
|
|
Loading…
Reference in New Issue