Compare commits

...

4 Commits

5 changed files with 202 additions and 82 deletions

View File

@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
from alpaca.common.exceptions import APIError from alpaca.common.exceptions import APIError
from glom import glom from glom import glom
from oandapyV20.exceptions import V20Error from oandapyV20.exceptions import V20Error
from pydantic.error_wrappers import ValidationError
from core.lib import schemas from core.lib import schemas
from core.util import logs from core.util import logs
@ -131,7 +132,12 @@ class BaseExchange(ABC):
def validate_response(self, response, method): def validate_response(self, response, method):
schema = self.get_schema(method) schema = self.get_schema(method)
# Return a dict of the validated response # Return a dict of the validated response
response_valid = schema(**response).dict() try:
response_valid = schema(**response).dict()
except ValidationError as e:
log.error(f"Error validating {method} response: {response}")
log.error(f"Errors: {e}")
raise GenericAPIError("Error validating response")
return response_valid return response_valid
def call(self, method, *args, **kwargs): def call(self, method, *args, **kwargs):
@ -198,6 +204,10 @@ class BaseExchange(ABC):
def post_trade(self, trade): def post_trade(self, trade):
pass pass
@abstractmethod
def close_trade(self, trade_id):
pass
@abstractmethod @abstractmethod
def get_trade(self, trade_id): def get_trade(self, trade_id):
pass pass

View File

@ -121,6 +121,11 @@ class AlpacaExchange(BaseExchange):
trade.save() trade.save()
return order return order
def close_trade(self, trade_id): # TODO
"""
Close a trade
"""
def get_trade(self, trade_id): def get_trade(self, trade_id):
pass # TODO pass # TODO

View File

@ -94,6 +94,13 @@ class OANDAExchange(BaseExchange):
trade.save() trade.save()
return response return response
def close_trade(self, trade_id):
"""
Close a trade.
"""
r = trades.TradeClose(accountID=self.account_id, tradeID=trade_id)
return self.call(r)
def get_trade(self, trade_id): def get_trade(self, trade_id):
# OANDA is off by one... # OANDA is off by one...
r = trades.TradeDetails(accountID=self.account_id, tradeID=trade_id) r = trades.TradeDetails(accountID=self.account_id, tradeID=trade_id)

View File

@ -386,41 +386,6 @@ AccountInstrumentsSchema = {
} }
class OrderTransaction(BaseModel):
id: str
accountID: str
userID: int
batchID: str
requestID: str
time: str
type: str
instrument: str
units: str
timeInForce: str
positionFill: str
reason: str
class OrderCreate(BaseModel):
orderCreateTransaction: OrderTransaction
OrderCreateSchema = {
"id": "orderCreateTransaction.id",
"accountID": "orderCreateTransaction.accountID",
"userID": "orderCreateTransaction.userID",
"batchID": "orderCreateTransaction.batchID",
"requestID": "orderCreateTransaction.requestID",
"time": "orderCreateTransaction.time",
"type": "orderCreateTransaction.type",
"symbol": "orderCreateTransaction.instrument",
"units": "orderCreateTransaction.units",
"timeInForce": "orderCreateTransaction.timeInForce",
"positionFill": "orderCreateTransaction.positionFill",
"reason": "orderCreateTransaction.reason",
}
class PriceBid(BaseModel): class PriceBid(BaseModel):
price: str price: str
liquidity: int liquidity: int
@ -476,11 +441,6 @@ PricingInfoSchema = {
} }
class LongPositionCloseout(BaseModel):
instrument: str
units: str
class Trade(BaseModel): class Trade(BaseModel):
tradeID: str tradeID: str
clientTradeID: str clientTradeID: str
@ -498,30 +458,52 @@ class Trade(BaseModel):
trailingStopLossOrder: dict | None trailingStopLossOrder: dict | None
{ class SideCarOrder(BaseModel):
"trades": [ id: str
{ createTime: str
"id": "14480", state: str
"instrument": "EUR_USD", price: str | None
"price": "1.06345", timeInForce: str
"openTime": "2022-12-22T08:57:10.459593310Z", gtdTime: str | None
"initialUnits": "100", clientExtensions: dict | None
"initialMarginRequired": "2.9226", tradeID: str
"state": "OPEN", clientTradeID: str | None
"currentUnits": "100", type: str
"realizedPL": "0.0000", time: str | None
"financing": "0.0000", priceBound: str | None
"dividendAdjustment": "0.0000", positionFill: str | None
"unrealizedPL": "-0.0158", reason: str | None
"marginUsed": "2.9228", orderFillTransactionID: str | None
} tradeOpenedID: str | None
], tradeReducedID: str | None
"lastTransactionID": "14480", tradeClosedIDs: list[str] | None
} cancellingTransactionID: str | None
replacesOrderID: str | None
replacedByOrderID: str | None
class OpenTradesTrade(BaseModel):
id: str
instrument: str
price: str
openTime: str
initialUnits: str
initialMarginRequired: str
state: str
currentUnits: str
realizedPL: str
financing: str
dividendAdjustment: str
unrealizedPL: str
marginUsed: str
takeProfitOrder: SideCarOrder | None
stopLossOrder: SideCarOrder | None
trailingStopLossOrder: SideCarOrder | None
trailingStopValue: dict | None
class OpenTrades(BaseModel): class OpenTrades(BaseModel):
trades: list[Trade] trades: list[OpenTradesTrade]
lastTransactionID: str lastTransactionID: str
@ -530,8 +512,8 @@ OpenTradesSchema = {
"trades", "trades",
[ [
{ {
"id": "tradeID", "id": "id",
"instrument": "instrument", "symbol": "instrument",
"price": "price", "price": "price",
"openTime": "openTime", "openTime": "openTime",
"initialUnits": "initialUnits", "initialUnits": "initialUnits",
@ -546,6 +528,7 @@ OpenTradesSchema = {
"takeProfitOrder": "takeProfitOrder", "takeProfitOrder": "takeProfitOrder",
"stopLossOrder": "stopLossOrder", "stopLossOrder": "stopLossOrder",
"trailingStopLossOrder": "trailingStopLossOrder", "trailingStopLossOrder": "trailingStopLossOrder",
"trailingStopValue": "trailingStopValue",
} }
], ],
), ),
@ -560,6 +543,48 @@ class HomeConversionFactors(BaseModel):
lossBaseHome: str lossBaseHome: str
class LongPositionCloseout(BaseModel):
instrument: str
units: str
class OrderTransaction(BaseModel):
id: str
accountID: str
userID: int
batchID: str
requestID: str
time: str
type: str
instrument: str | None
units: str | None
timeInForce: str | None
positionFill: str | None
reason: str
longPositionCloseout: LongPositionCloseout | None
longOrderFillTransaction: dict | None
class OrderCreate(BaseModel):
orderCreateTransaction: OrderTransaction
OrderCreateSchema = {
"id": "orderCreateTransaction.id",
"accountID": "orderCreateTransaction.accountID",
"userID": "orderCreateTransaction.userID",
"batchID": "orderCreateTransaction.batchID",
"requestID": "orderCreateTransaction.requestID",
"time": "orderCreateTransaction.time",
"type": "orderCreateTransaction.type",
"symbol": "orderCreateTransaction.instrument",
"units": "orderCreateTransaction.units",
"timeInForce": "orderCreateTransaction.timeInForce",
"positionFill": "orderCreateTransaction.positionFill",
"reason": "orderCreateTransaction.reason",
}
class LongOrderFillTransaction(BaseModel): class LongOrderFillTransaction(BaseModel):
id: str id: str
accountID: str accountID: str
@ -592,23 +617,6 @@ class LongOrderFillTransaction(BaseModel):
longPositionCloseout: LongPositionCloseout longPositionCloseout: LongPositionCloseout
class OrderTransaction(BaseModel):
id: str
accountID: str
userID: int
batchID: str
requestID: str
time: str
type: str
instrument: str | None
units: str | None
timeInForce: str | None
positionFill: str | None
reason: str
longPositionCloseout: LongPositionCloseout | None
longOrderFillTransaction: dict | None
class PositionClose(BaseModel): class PositionClose(BaseModel):
longOrderCreateTransaction: OrderTransaction | None longOrderCreateTransaction: OrderTransaction | None
longOrderFillTransaction: OrderTransaction | None longOrderFillTransaction: OrderTransaction | None
@ -678,3 +686,25 @@ TradeDetailsSchema = {
"clientExtensions": "trade.clientExtensions", "clientExtensions": "trade.clientExtensions",
"lastTransactionID": "lastTransactionID", "lastTransactionID": "lastTransactionID",
} }
class TradeClose(BaseModel):
orderCreateTransaction: OrderTransaction
TradeCloseSchema = {
"id": "orderCreateTransaction.id",
"accountID": "orderCreateTransaction.accountID",
"userID": "orderCreateTransaction.userID",
"batchID": "orderCreateTransaction.batchID",
"requestID": "orderCreateTransaction.requestID",
"time": "orderCreateTransaction.time",
"type": "orderCreateTransaction.type",
"symbol": "orderCreateTransaction.instrument",
"units": "orderCreateTransaction.units",
"timeInForce": "orderCreateTransaction.timeInForce",
"positionFill": "orderCreateTransaction.positionFill",
"reason": "orderCreateTransaction.reason",
"longPositionCloseout": "orderCreateTransaction.longPositionCloseout",
"longOrderFillTransaction": "orderCreateTransaction.longOrderFillTransaction",
}

View File

@ -1,12 +1,80 @@
from django.test import TestCase from django.test import TestCase
from core.models import Trade
from core.tests.helpers import ElasticMock, LiveBase from core.tests.helpers import ElasticMock, LiveBase
class LiveTradingTestCase(ElasticMock, LiveBase, TestCase): class LiveTradingTestCase(ElasticMock, LiveBase, TestCase):
def setUp(self):
super(LiveTradingTestCase, self).setUp()
self.trade = Trade.objects.create(
user=self.user,
account=self.account,
symbol="EUR_USD",
time_in_force="FOK",
type="market",
amount=10,
direction="buy",
)
def test_account_functional(self): def test_account_functional(self):
""" """
Test that the account is functional. Test that the account is functional.
""" """
balance = self.account.client.get_balance() balance = self.account.client.get_balance()
self.assertTrue(balance > 0) # We need some money to place trades
self.assertTrue(balance > 1000)
def open_trade(self):
posted = self.trade.post()
# Check the opened trade
self.assertEqual(posted["type"], "MARKET_ORDER")
self.assertEqual(posted["symbol"], "EUR_USD")
self.assertEqual(posted["units"], "10")
self.assertEqual(posted["timeInForce"], "FOK")
return posted
def close_trade(self):
# refresh the trade to get the trade id
self.trade.refresh_from_db()
closed = self.account.client.close_trade(self.trade.order_id)
# Check the feedback from closing the trade
self.assertEqual(closed["type"], "MARKET_ORDER")
self.assertEqual(closed["symbol"], "EUR_USD")
self.assertEqual(closed["units"], "-10")
self.assertEqual(closed["timeInForce"], "FOK")
self.assertEqual(closed["reason"], "TRADE_CLOSE")
return closed
def test_place_close_trade(self):
"""
Test placing a trade.
"""
self.open_trade()
self.close_trade()
def test_get_all_open_trades(self):
"""
Test getting all open trades.
"""
self.open_trade()
trades = self.account.client.get_all_open_trades()
self.trade.refresh_from_db()
found = False
for trade in trades["itemlist"]:
if trade["id"] == self.trade.order_id:
self.assertEqual(trade["symbol"], "EUR_USD")
self.assertEqual(trade["currentUnits"], "10")
self.assertEqual(trade["initialUnits"], "10")
self.assertEqual(trade["state"], "OPEN")
found = True
break
self.close_trade()
if not found:
self.fail("Could not find the trade in the list of open trades")