Implement stub functions for OANDA

This commit is contained in:
Mark Veidemanis 2022-10-30 11:21:48 +00:00
parent f6fa9bdbb6
commit f22fcfdaaa
Signed by: m
GPG Key ID: 5ACFCEED46C0904F
5 changed files with 55 additions and 19 deletions

View File

@ -13,6 +13,9 @@ class BaseExchange(object):
def connect(self):
raise NotImplementedError
def get_account(self):
raise NotImplementedError
def get_supported_assets(self):
raise NotImplementedError
@ -39,4 +42,3 @@ class BaseExchange(object):
def get_all_positions(self):
raise NotImplementedError

View File

@ -1,18 +1,27 @@
from core.exchanges import BaseExchange
from alpaca.trading.client import TradingClient
from alpaca.trading.requests import GetAssetsRequest
from alpaca.common.exceptions import APIError
from alpaca.trading.client import TradingClient
from alpaca.trading.enums import OrderSide, TimeInForce
from alpaca.trading.requests import LimitOrderRequest, MarketOrderRequest
from alpaca.trading.requests import (
GetAssetsRequest,
LimitOrderRequest,
MarketOrderRequest,
)
from core.exchanges import BaseExchange
class AlpacaExchange(BaseExchange):
def connect(self):
self.client = TradingClient(
self.account.api_key, self.account.api_secret, paper=self.account.sandbox, raw_data=True
self.account.api_key,
self.account.api_secret,
paper=self.account.sandbox,
raw_data=True,
)
def get_account(self):
return self.client.get_account()
def get_supported_assets(self):
try:
request = GetAssetsRequest(status="active", asset_class="crypto")
@ -20,7 +29,7 @@ class AlpacaExchange(BaseExchange):
asset_list = [x["symbol"] for x in assets if "symbol" in x]
print("Supported symbols", asset_list)
except APIError as e:
log.error(f"Could not get asset list: {e}")
self.log.error(f"Could not get asset list: {e}")
# return False
return asset_list
@ -55,7 +64,11 @@ class AlpacaExchange(BaseExchange):
else:
raise Exception("Unknown direction")
cast = {"symbol": trade.symbol, "side": direction, "time_in_force": TimeInForce.IOC}
cast = {
"symbol": trade.symbol,
"side": direction,
"time_in_force": TimeInForce.IOC,
}
if trade.amount is not None:
cast["qty"] = trade.amount
if trade.amount_usd is not None:
@ -75,7 +88,7 @@ class AlpacaExchange(BaseExchange):
try:
order = self.client.submit_order(order_data=market_order_data)
except APIError as e:
log.error(f"Error placing market order: {e}")
self.log.error(f"Error placing market order: {e}")
return (False, e)
elif trade.type == "limit":
if not trade.price:
@ -85,7 +98,7 @@ class AlpacaExchange(BaseExchange):
try:
order = self.client.submit_order(order_data=limit_order_data)
except APIError as e:
log.error(f"Error placing limit order: {e}")
self.log.error(f"Error placing limit order: {e}")
return (False, e)
else:
@ -122,4 +135,4 @@ class AlpacaExchange(BaseExchange):
item["account_id"] = self.account.id
item["unrealized_pl"] = float(item["unrealized_pl"])
items.append(item)
return items
return items

View File

@ -1,15 +1,21 @@
from core.exchanges import BaseExchange
from oandapyV20 import API
from oandapyV20.endpoints import trades
from oandapyV20.endpoints import positions
from oandapyV20.endpoints import accounts, orders, positions, trades
from core.exchanges import BaseExchange
class OANDAExchange(BaseExchange):
def connect(self):
self.client = API(access_token=self.account.api_secret)
self.account_id = self.account.api_key
def get_account(self):
r = accounts.AccountDetails(self.account_id)
self.client.request(r)
return r.response
def get_supported_assets(self):
raise NotImplementedError
return False
def get_balance(self):
raise NotImplementedError
@ -19,18 +25,32 @@ class OANDAExchange(BaseExchange):
def post_trade(self, trade):
raise NotImplementedError
r = orders.OrderCreate(accountID, data=data)
self.client.request(r)
return r.response
def get_trade(self, trade_id):
raise NotImplementedError
r = accounts.TradeDetails(accountID=self.account_id, tradeID=trade_id)
self.client.request(r)
return r.response
def update_trade(self, trade):
raise NotImplementedError
r = orders.OrderReplace(
accountID=self.account_id, orderID=trade.order_id, data=data
)
self.client.request(r)
return r.response
def cancel_trade(self, trade_id):
raise NotImplementedError
def get_position_info(self, asset_id):
raise NotImplementedError
r = positions.PositionDetails(self.account_id, asset_id)
self.client.request(r)
return r.response
def get_all_positions(self):
pass
r = positions.OpenPositions(accountID=self.account_id)
self.client.request(r)
return r.response["positions"]

View File

@ -39,7 +39,7 @@ class AccountInfo(LoginRequiredMixin, View):
}
return render(request, template_name, context)
live_info = dict(account.rawclient.get_account())
live_info = dict(account.client.get_account())
account_info = account.__dict__
account_info = {
k: v for k, v in account_info.items() if k in self.VIEWABLE_FIELDS_MODEL

View File

@ -18,6 +18,7 @@ def get_positions(user, account_id=None):
accounts = Account.objects.filter(user=user)
for account in accounts:
positions = account.client.get_all_positions()
print("positions", positions)
for item in positions:
items.append(item)