Implement stub functions for OANDA

Mark Veidemanis 2 years ago
parent f6fa9bdbb6
commit f22fcfdaaa
Signed by: m
GPG Key ID: 5ACFCEED46C0904F

@ -13,6 +13,9 @@ class BaseExchange(object):
def connect(self):
raise NotImplementedError
def get_account(self):
raise NotImplementedError
def get_supported_assets(self):
raise NotImplementedError
@ -39,4 +42,3 @@ class BaseExchange(object):
def get_all_positions(self):
raise NotImplementedError

@ -1,18 +1,27 @@
from core.exchanges import BaseExchange
from alpaca.trading.client import TradingClient
from alpaca.trading.requests import GetAssetsRequest
from alpaca.common.exceptions import APIError
from alpaca.trading.client import TradingClient
from alpaca.trading.enums import OrderSide, TimeInForce
from alpaca.trading.requests import LimitOrderRequest, MarketOrderRequest
from alpaca.trading.requests import (
GetAssetsRequest,
LimitOrderRequest,
MarketOrderRequest,
)
from core.exchanges import BaseExchange
class AlpacaExchange(BaseExchange):
def connect(self):
self.client = TradingClient(
self.account.api_key, self.account.api_secret, paper=self.account.sandbox, raw_data=True
self.account.api_key,
self.account.api_secret,
paper=self.account.sandbox,
raw_data=True,
)
def get_account(self):
return self.client.get_account()
def get_supported_assets(self):
try:
request = GetAssetsRequest(status="active", asset_class="crypto")
@ -20,7 +29,7 @@ class AlpacaExchange(BaseExchange):
asset_list = [x["symbol"] for x in assets if "symbol" in x]
print("Supported symbols", asset_list)
except APIError as e:
log.error(f"Could not get asset list: {e}")
self.log.error(f"Could not get asset list: {e}")
# return False
return asset_list
@ -55,7 +64,11 @@ class AlpacaExchange(BaseExchange):
else:
raise Exception("Unknown direction")
cast = {"symbol": trade.symbol, "side": direction, "time_in_force": TimeInForce.IOC}
cast = {
"symbol": trade.symbol,
"side": direction,
"time_in_force": TimeInForce.IOC,
}
if trade.amount is not None:
cast["qty"] = trade.amount
if trade.amount_usd is not None:
@ -75,7 +88,7 @@ class AlpacaExchange(BaseExchange):
try:
order = self.client.submit_order(order_data=market_order_data)
except APIError as e:
log.error(f"Error placing market order: {e}")
self.log.error(f"Error placing market order: {e}")
return (False, e)
elif trade.type == "limit":
if not trade.price:
@ -85,7 +98,7 @@ class AlpacaExchange(BaseExchange):
try:
order = self.client.submit_order(order_data=limit_order_data)
except APIError as e:
log.error(f"Error placing limit order: {e}")
self.log.error(f"Error placing limit order: {e}")
return (False, e)
else:
@ -122,4 +135,4 @@ class AlpacaExchange(BaseExchange):
item["account_id"] = self.account.id
item["unrealized_pl"] = float(item["unrealized_pl"])
items.append(item)
return items
return items

@ -1,15 +1,21 @@
from core.exchanges import BaseExchange
from oandapyV20 import API
from oandapyV20.endpoints import trades
from oandapyV20.endpoints import positions
from oandapyV20.endpoints import accounts, orders, positions, trades
from core.exchanges import BaseExchange
class OANDAExchange(BaseExchange):
def connect(self):
self.client = API(access_token=self.account.api_secret)
self.account_id = self.account.api_key
def get_account(self):
r = accounts.AccountDetails(self.account_id)
self.client.request(r)
return r.response
def get_supported_assets(self):
raise NotImplementedError
return False
def get_balance(self):
raise NotImplementedError
@ -19,18 +25,32 @@ class OANDAExchange(BaseExchange):
def post_trade(self, trade):
raise NotImplementedError
r = orders.OrderCreate(accountID, data=data)
self.client.request(r)
return r.response
def get_trade(self, trade_id):
raise NotImplementedError
r = accounts.TradeDetails(accountID=self.account_id, tradeID=trade_id)
self.client.request(r)
return r.response
def update_trade(self, trade):
raise NotImplementedError
r = orders.OrderReplace(
accountID=self.account_id, orderID=trade.order_id, data=data
)
self.client.request(r)
return r.response
def cancel_trade(self, trade_id):
raise NotImplementedError
def get_position_info(self, asset_id):
raise NotImplementedError
r = positions.PositionDetails(self.account_id, asset_id)
self.client.request(r)
return r.response
def get_all_positions(self):
pass
r = positions.OpenPositions(accountID=self.account_id)
self.client.request(r)
return r.response["positions"]

@ -39,7 +39,7 @@ class AccountInfo(LoginRequiredMixin, View):
}
return render(request, template_name, context)
live_info = dict(account.rawclient.get_account())
live_info = dict(account.client.get_account())
account_info = account.__dict__
account_info = {
k: v for k, v in account_info.items() if k in self.VIEWABLE_FIELDS_MODEL

@ -18,6 +18,7 @@ def get_positions(user, account_id=None):
accounts = Account.objects.filter(user=user)
for account in accounts:
positions = account.client.get_all_positions()
print("positions", positions)
for item in positions:
items.append(item)

Loading…
Cancel
Save