diff --git a/core/exchanges/__init__.py b/core/exchanges/__init__.py index c144ad1..93bb8dc 100644 --- a/core/exchanges/__init__.py +++ b/core/exchanges/__init__.py @@ -13,6 +13,9 @@ class BaseExchange(object): def connect(self): raise NotImplementedError + def get_account(self): + raise NotImplementedError + def get_supported_assets(self): raise NotImplementedError @@ -39,4 +42,3 @@ class BaseExchange(object): def get_all_positions(self): raise NotImplementedError - \ No newline at end of file diff --git a/core/exchanges/alpaca.py b/core/exchanges/alpaca.py index d9ff496..37cbbcd 100644 --- a/core/exchanges/alpaca.py +++ b/core/exchanges/alpaca.py @@ -1,18 +1,27 @@ -from core.exchanges import BaseExchange -from alpaca.trading.client import TradingClient -from alpaca.trading.requests import GetAssetsRequest from alpaca.common.exceptions import APIError +from alpaca.trading.client import TradingClient from alpaca.trading.enums import OrderSide, TimeInForce -from alpaca.trading.requests import LimitOrderRequest, MarketOrderRequest +from alpaca.trading.requests import ( + GetAssetsRequest, + LimitOrderRequest, + MarketOrderRequest, +) +from core.exchanges import BaseExchange class AlpacaExchange(BaseExchange): def connect(self): self.client = TradingClient( - self.account.api_key, self.account.api_secret, paper=self.account.sandbox, raw_data=True + self.account.api_key, + self.account.api_secret, + paper=self.account.sandbox, + raw_data=True, ) + def get_account(self): + return self.client.get_account() + def get_supported_assets(self): try: request = GetAssetsRequest(status="active", asset_class="crypto") @@ -20,7 +29,7 @@ class AlpacaExchange(BaseExchange): asset_list = [x["symbol"] for x in assets if "symbol" in x] print("Supported symbols", asset_list) except APIError as e: - log.error(f"Could not get asset list: {e}") + self.log.error(f"Could not get asset list: {e}") # return False return asset_list @@ -55,7 +64,11 @@ class AlpacaExchange(BaseExchange): else: raise Exception("Unknown direction") - cast = {"symbol": trade.symbol, "side": direction, "time_in_force": TimeInForce.IOC} + cast = { + "symbol": trade.symbol, + "side": direction, + "time_in_force": TimeInForce.IOC, + } if trade.amount is not None: cast["qty"] = trade.amount if trade.amount_usd is not None: @@ -75,7 +88,7 @@ class AlpacaExchange(BaseExchange): try: order = self.client.submit_order(order_data=market_order_data) except APIError as e: - log.error(f"Error placing market order: {e}") + self.log.error(f"Error placing market order: {e}") return (False, e) elif trade.type == "limit": if not trade.price: @@ -85,7 +98,7 @@ class AlpacaExchange(BaseExchange): try: order = self.client.submit_order(order_data=limit_order_data) except APIError as e: - log.error(f"Error placing limit order: {e}") + self.log.error(f"Error placing limit order: {e}") return (False, e) else: @@ -122,4 +135,4 @@ class AlpacaExchange(BaseExchange): item["account_id"] = self.account.id item["unrealized_pl"] = float(item["unrealized_pl"]) items.append(item) - return items \ No newline at end of file + return items diff --git a/core/exchanges/oanda.py b/core/exchanges/oanda.py index ce8bfa2..b8fce68 100644 --- a/core/exchanges/oanda.py +++ b/core/exchanges/oanda.py @@ -1,15 +1,21 @@ -from core.exchanges import BaseExchange from oandapyV20 import API -from oandapyV20.endpoints import trades -from oandapyV20.endpoints import positions +from oandapyV20.endpoints import accounts, orders, positions, trades + +from core.exchanges import BaseExchange + class OANDAExchange(BaseExchange): def connect(self): self.client = API(access_token=self.account.api_secret) self.account_id = self.account.api_key + def get_account(self): + r = accounts.AccountDetails(self.account_id) + self.client.request(r) + return r.response + def get_supported_assets(self): - raise NotImplementedError + return False def get_balance(self): raise NotImplementedError @@ -19,18 +25,32 @@ class OANDAExchange(BaseExchange): def post_trade(self, trade): raise NotImplementedError + r = orders.OrderCreate(accountID, data=data) + self.client.request(r) + return r.response def get_trade(self, trade_id): - raise NotImplementedError + r = accounts.TradeDetails(accountID=self.account_id, tradeID=trade_id) + self.client.request(r) + return r.response def update_trade(self, trade): raise NotImplementedError + r = orders.OrderReplace( + accountID=self.account_id, orderID=trade.order_id, data=data + ) + self.client.request(r) + return r.response def cancel_trade(self, trade_id): raise NotImplementedError def get_position_info(self, asset_id): - raise NotImplementedError + r = positions.PositionDetails(self.account_id, asset_id) + self.client.request(r) + return r.response def get_all_positions(self): - pass + r = positions.OpenPositions(accountID=self.account_id) + self.client.request(r) + return r.response["positions"] diff --git a/core/views/accounts.py b/core/views/accounts.py index 12b74c7..da795c6 100644 --- a/core/views/accounts.py +++ b/core/views/accounts.py @@ -39,7 +39,7 @@ class AccountInfo(LoginRequiredMixin, View): } return render(request, template_name, context) - live_info = dict(account.rawclient.get_account()) + live_info = dict(account.client.get_account()) account_info = account.__dict__ account_info = { k: v for k, v in account_info.items() if k in self.VIEWABLE_FIELDS_MODEL diff --git a/core/views/positions.py b/core/views/positions.py index a00576c..4a04242 100644 --- a/core/views/positions.py +++ b/core/views/positions.py @@ -18,6 +18,7 @@ def get_positions(user, account_id=None): accounts = Account.objects.filter(user=user) for account in accounts: positions = account.client.get_all_positions() + print("positions", positions) for item in positions: items.append(item)