Implement stub functions for OANDA

Mark Veidemanis 2 years ago
parent f6fa9bdbb6
commit f22fcfdaaa
Signed by: m
GPG Key ID: 5ACFCEED46C0904F

@ -13,6 +13,9 @@ class BaseExchange(object):
def connect(self): def connect(self):
raise NotImplementedError raise NotImplementedError
def get_account(self):
raise NotImplementedError
def get_supported_assets(self): def get_supported_assets(self):
raise NotImplementedError raise NotImplementedError
@ -39,4 +42,3 @@ class BaseExchange(object):
def get_all_positions(self): def get_all_positions(self):
raise NotImplementedError raise NotImplementedError

@ -1,18 +1,27 @@
from core.exchanges import BaseExchange
from alpaca.trading.client import TradingClient
from alpaca.trading.requests import GetAssetsRequest
from alpaca.common.exceptions import APIError from alpaca.common.exceptions import APIError
from alpaca.trading.client import TradingClient
from alpaca.trading.enums import OrderSide, TimeInForce from alpaca.trading.enums import OrderSide, TimeInForce
from alpaca.trading.requests import LimitOrderRequest, MarketOrderRequest from alpaca.trading.requests import (
GetAssetsRequest,
LimitOrderRequest,
MarketOrderRequest,
)
from core.exchanges import BaseExchange
class AlpacaExchange(BaseExchange): class AlpacaExchange(BaseExchange):
def connect(self): def connect(self):
self.client = TradingClient( self.client = TradingClient(
self.account.api_key, self.account.api_secret, paper=self.account.sandbox, raw_data=True self.account.api_key,
self.account.api_secret,
paper=self.account.sandbox,
raw_data=True,
) )
def get_account(self):
return self.client.get_account()
def get_supported_assets(self): def get_supported_assets(self):
try: try:
request = GetAssetsRequest(status="active", asset_class="crypto") request = GetAssetsRequest(status="active", asset_class="crypto")
@ -20,7 +29,7 @@ class AlpacaExchange(BaseExchange):
asset_list = [x["symbol"] for x in assets if "symbol" in x] asset_list = [x["symbol"] for x in assets if "symbol" in x]
print("Supported symbols", asset_list) print("Supported symbols", asset_list)
except APIError as e: except APIError as e:
log.error(f"Could not get asset list: {e}") self.log.error(f"Could not get asset list: {e}")
# return False # return False
return asset_list return asset_list
@ -55,7 +64,11 @@ class AlpacaExchange(BaseExchange):
else: else:
raise Exception("Unknown direction") raise Exception("Unknown direction")
cast = {"symbol": trade.symbol, "side": direction, "time_in_force": TimeInForce.IOC} cast = {
"symbol": trade.symbol,
"side": direction,
"time_in_force": TimeInForce.IOC,
}
if trade.amount is not None: if trade.amount is not None:
cast["qty"] = trade.amount cast["qty"] = trade.amount
if trade.amount_usd is not None: if trade.amount_usd is not None:
@ -75,7 +88,7 @@ class AlpacaExchange(BaseExchange):
try: try:
order = self.client.submit_order(order_data=market_order_data) order = self.client.submit_order(order_data=market_order_data)
except APIError as e: except APIError as e:
log.error(f"Error placing market order: {e}") self.log.error(f"Error placing market order: {e}")
return (False, e) return (False, e)
elif trade.type == "limit": elif trade.type == "limit":
if not trade.price: if not trade.price:
@ -85,7 +98,7 @@ class AlpacaExchange(BaseExchange):
try: try:
order = self.client.submit_order(order_data=limit_order_data) order = self.client.submit_order(order_data=limit_order_data)
except APIError as e: except APIError as e:
log.error(f"Error placing limit order: {e}") self.log.error(f"Error placing limit order: {e}")
return (False, e) return (False, e)
else: else:

@ -1,15 +1,21 @@
from core.exchanges import BaseExchange
from oandapyV20 import API from oandapyV20 import API
from oandapyV20.endpoints import trades from oandapyV20.endpoints import accounts, orders, positions, trades
from oandapyV20.endpoints import positions
from core.exchanges import BaseExchange
class OANDAExchange(BaseExchange): class OANDAExchange(BaseExchange):
def connect(self): def connect(self):
self.client = API(access_token=self.account.api_secret) self.client = API(access_token=self.account.api_secret)
self.account_id = self.account.api_key self.account_id = self.account.api_key
def get_account(self):
r = accounts.AccountDetails(self.account_id)
self.client.request(r)
return r.response
def get_supported_assets(self): def get_supported_assets(self):
raise NotImplementedError return False
def get_balance(self): def get_balance(self):
raise NotImplementedError raise NotImplementedError
@ -19,18 +25,32 @@ class OANDAExchange(BaseExchange):
def post_trade(self, trade): def post_trade(self, trade):
raise NotImplementedError raise NotImplementedError
r = orders.OrderCreate(accountID, data=data)
self.client.request(r)
return r.response
def get_trade(self, trade_id): def get_trade(self, trade_id):
raise NotImplementedError r = accounts.TradeDetails(accountID=self.account_id, tradeID=trade_id)
self.client.request(r)
return r.response
def update_trade(self, trade): def update_trade(self, trade):
raise NotImplementedError raise NotImplementedError
r = orders.OrderReplace(
accountID=self.account_id, orderID=trade.order_id, data=data
)
self.client.request(r)
return r.response
def cancel_trade(self, trade_id): def cancel_trade(self, trade_id):
raise NotImplementedError raise NotImplementedError
def get_position_info(self, asset_id): def get_position_info(self, asset_id):
raise NotImplementedError r = positions.PositionDetails(self.account_id, asset_id)
self.client.request(r)
return r.response
def get_all_positions(self): def get_all_positions(self):
pass r = positions.OpenPositions(accountID=self.account_id)
self.client.request(r)
return r.response["positions"]

@ -39,7 +39,7 @@ class AccountInfo(LoginRequiredMixin, View):
} }
return render(request, template_name, context) return render(request, template_name, context)
live_info = dict(account.rawclient.get_account()) live_info = dict(account.client.get_account())
account_info = account.__dict__ account_info = account.__dict__
account_info = { account_info = {
k: v for k, v in account_info.items() if k in self.VIEWABLE_FIELDS_MODEL k: v for k, v in account_info.items() if k in self.VIEWABLE_FIELDS_MODEL

@ -18,6 +18,7 @@ def get_positions(user, account_id=None):
accounts = Account.objects.filter(user=user) accounts = Account.objects.filter(user=user)
for account in accounts: for account in accounts:
positions = account.client.get_all_positions() positions = account.client.get_all_positions()
print("positions", positions)
for item in positions: for item in positions:
items.append(item) items.append(item)

Loading…
Cancel
Save