Implement adjusting positions and begin writing live tests for AMS

This commit is contained in:
2023-02-20 07:20:03 +00:00
parent a840be3834
commit 9e22abe057
5 changed files with 249 additions and 9 deletions

View File

@@ -205,7 +205,7 @@ class BaseExchange(ABC):
pass
@abstractmethod
def close_trade(self, trade_id):
def close_trade(self, trade_id, units=None):
pass
@abstractmethod

View File

@@ -121,7 +121,7 @@ class AlpacaExchange(BaseExchange):
trade.save()
return order
def close_trade(self, trade_id): # TODO
def close_trade(self, trade_id, units=None): # TODO
"""
Close a trade
"""

View File

@@ -2,6 +2,9 @@ from oandapyV20 import API
from oandapyV20.endpoints import accounts, orders, positions, pricing, trades
from core.exchanges import BaseExchange, common
from core.util import logs
log = logs.get_logger("oanda")
class OANDAExchange(BaseExchange):
@@ -98,12 +101,44 @@ class OANDAExchange(BaseExchange):
trade.save()
return response
def close_trade(self, trade_id):
def get_trade_precision(self, symbol):
instruments = self.account.instruments
if not instruments:
log.error("No instruments found")
return None
# Extract the information for the symbol
instrument = self.extract_instrument(instruments, symbol)
if not instrument:
log.error(f"Symbol not found: {symbol}")
return None
# Get the required precision
try:
trade_precision = instrument["tradeUnitsPrecision"]
return trade_precision
except KeyError:
log.error(f"Precision not found for {symbol} from {instrument}")
return None
def close_trade(self, trade_id, units=None, symbol=None):
"""
Close a trade.
"""
r = trades.TradeClose(accountID=self.account_id, tradeID=trade_id)
return self.call(r)
if not units:
r = trades.TradeClose(accountID=self.account_id, tradeID=trade_id)
return self.call(r)
else:
trade_precision = self.get_trade_precision(symbol)
if trade_precision is None:
log.error(f"Unable to get trade precision for {symbol}")
return None
units = round(units, trade_precision)
data = {
"units": str(units),
}
r = trades.TradeClose(
accountID=self.account_id, tradeID=trade_id, data=data
)
return self.call(r)
def get_trade(self, trade_id):
# OANDA is off by one...