diff --git a/core/exchanges/__init__.py b/core/exchanges/__init__.py index 93bb8dc..64d2f6a 100644 --- a/core/exchanges/__init__.py +++ b/core/exchanges/__init__.py @@ -1,3 +1,5 @@ +from pydantic import ValidationError + from core.util import logs @@ -8,11 +10,36 @@ class BaseExchange(object): self.log = logs.get_logger(name) self.client = None + self.set_schema() self.connect() + def set_schema(self): + raise NotImplementedError + def connect(self): raise NotImplementedError + def call(self, method, *args, **kwargs) -> (bool, dict): + + if hasattr(self.client, method): + try: + response = getattr(self.client, method)(*args, **kwargs) + if isinstance(response, list): + response = {"itemlist": response} + if method not in self.schema: + self.log.error(f"Method cannot be validated: {method}") + self.log.debug(f"Response: {response}") + return (False, f"Method cannot be validated: {method}") + return (True, self.schema[method](**response).dict()) + except ValidationError as e: + self.log.error(f"Could not validate response: {e}") + return (False, e) + except Exception as e: + self.log.error(f"Error calling {method}: {e}") + return (False, e) + else: + return (False, "No such method") + def get_account(self): raise NotImplementedError diff --git a/core/exchanges/alpaca.py b/core/exchanges/alpaca.py index 37cbbcd..a6c2ffd 100644 --- a/core/exchanges/alpaca.py +++ b/core/exchanges/alpaca.py @@ -8,9 +8,20 @@ from alpaca.trading.requests import ( ) from core.exchanges import BaseExchange +from core.lib.schemas import alpaca_s + +ALPACA_SCHEMA_MAPPING = { + "get_account": alpaca_s.GetAccount, + "get_all_assets": alpaca_s.GetAllAssets, + "get_all_positions": alpaca_s.GetAllPositions, + "get_open_position": alpaca_s.GetOpenPosition, +} class AlpacaExchange(BaseExchange): + def set_schema(self): + self.schema = ALPACA_SCHEMA_MAPPING + def connect(self): self.client = TradingClient( self.account.api_key, @@ -20,32 +31,31 @@ class AlpacaExchange(BaseExchange): ) def get_account(self): - return self.client.get_account() + return self.call("get_account") def get_supported_assets(self): - try: - request = GetAssetsRequest(status="active", asset_class="crypto") - assets = self.client.get_all_assets(filter=request) - asset_list = [x["symbol"] for x in assets if "symbol" in x] - print("Supported symbols", asset_list) - except APIError as e: - self.log.error(f"Could not get asset list: {e}") - # return False - return asset_list + request = GetAssetsRequest(status="active", asset_class="crypto") + success, assets = self.call("get_all_assets", filter=request) + # assets = self.client.get_all_assets(filter=request) + if not success: + return (success, assets) + assets = assets["itemlist"] + asset_list = [x["symbol"] for x in assets if "symbol" in x] + print("Supported symbols", asset_list) + + return (True, asset_list) def get_balance(self): - try: - account_info = self.client.get_account() - except APIError as e: - self.log.error(f"Could not get account balance: {e}") - return False + success, account_info = self.call("get_account") + if not success: + return (success, account_info) equity = account_info["equity"] try: balance = float(equity) except ValueError: - return False + return (False, "Invalid balance") - return balance + return (True, balance) def get_market_value(self, symbol): try: @@ -89,6 +99,8 @@ class AlpacaExchange(BaseExchange): order = self.client.submit_order(order_data=market_order_data) except APIError as e: self.log.error(f"Error placing market order: {e}") + trade.status = "error" + trade.save() return (False, e) elif trade.type == "limit": if not trade.price: @@ -99,6 +111,8 @@ class AlpacaExchange(BaseExchange): order = self.client.submit_order(order_data=limit_order_data) except APIError as e: self.log.error(f"Error placing limit order: {e}") + trade.status = "error" + trade.save() return (False, e) else: @@ -120,19 +134,19 @@ class AlpacaExchange(BaseExchange): pass def get_position_info(self, asset_id): - try: - position = self.client.get_open_position(asset_id) - except APIError as e: - return (False, e) + success, position = self.call("get_open_position", asset_id) + if not success: + return (success, position) return (True, position) def get_all_positions(self): items = [] - positions = self.client.get_all_positions() + success, positions = self.call("get_all_positions") + if not success: + return (success, positions) - for item in positions: - item = dict(item) + for item in positions["itemlist"]: item["account_id"] = self.account.id item["unrealized_pl"] = float(item["unrealized_pl"]) items.append(item) - return items + return (True, items) diff --git a/core/exchanges/oanda.py b/core/exchanges/oanda.py index b8fce68..6f284d1 100644 --- a/core/exchanges/oanda.py +++ b/core/exchanges/oanda.py @@ -2,9 +2,15 @@ from oandapyV20 import API from oandapyV20.endpoints import accounts, orders, positions, trades from core.exchanges import BaseExchange +from core.lib.schemas import oanda_s + +OANDA_SCHEMA_MAPPING = {} class OANDAExchange(BaseExchange): + def set_schema(self): + self.schema = OANDA_SCHEMA_MAPPING + def connect(self): self.client = API(access_token=self.account.api_secret) self.account_id = self.account.api_key @@ -53,4 +59,4 @@ class OANDAExchange(BaseExchange): def get_all_positions(self): r = positions.OpenPositions(accountID=self.account_id) self.client.request(r) - return r.response["positions"] + return (True, []) diff --git a/core/lib/market.py b/core/lib/market.py index f187910..e804574 100644 --- a/core/lib/market.py +++ b/core/lib/market.py @@ -24,9 +24,9 @@ def get_market_value(account, symbol): def execute_strategy(callback, strategy): - cash_balance = get_balance(strategy.account) + success, cash_balance = strategy.account.client.get_balance() log.debug(f"Cash balance: {cash_balance}") - if not cash_balance: + if not success: return None user = strategy.user @@ -41,7 +41,7 @@ def execute_strategy(callback, strategy): quote = "usd" # TODO: MASSIVE HACK symbol = f"{base.upper()}/{quote.upper()}" - if symbol not in account.supported_assets: + if symbol not in account.supported_symbols: log.error(f"Symbol not supported by account: {symbol}") return False diff --git a/core/lib/serde/__init__.py b/core/lib/schemas/__init__.py similarity index 100% rename from core/lib/serde/__init__.py rename to core/lib/schemas/__init__.py diff --git a/core/lib/schemas/alpaca_s.py b/core/lib/schemas/alpaca_s.py new file mode 100644 index 0000000..a76e0d9 --- /dev/null +++ b/core/lib/schemas/alpaca_s.py @@ -0,0 +1,105 @@ +from pydantic import BaseModel, Field + +class Asset(BaseModel): + id: str + class_: str = Field(..., alias="class") + exchange: str + symbol: str + name: str + status: str + tradable: bool + marginable: bool + maintenance_margin_requirement: int + shortable: bool + easy_to_borrow: bool + fractionable: bool + min_order_size: str + min_trade_increment: str + price_increment: str + +# get_all_assets +class GetAllAssets(BaseModel): + itemlist: list[Asset] + +# get_open_position +class GetOpenPosition(BaseModel): + asset_id: str + symbol: str + exchange: str + asset_class: str = Field(..., alias="asset_class") + asset_marginable: bool + qty: str + avg_entry_price: str + side: str + market_value: str + cost_basis: str + unrealized_pl: str + unrealized_plpc: str + unrealized_intraday_pl: str + unrealized_intraday_plpc: str + current_price: str + lastday_price: str + change_today: str + qty_available: str + + +class Position(BaseModel): + asset_id: str + symbol: str + exchange: str + asset_class: str + asset_marginable: bool + qty: str + avg_entry_price: str + side: str + market_value: str + cost_basis: str + unrealized_pl: str + unrealized_plpc: str + unrealized_intraday_pl: str + unrealized_intraday_plpc: str + current_price: str + lastday_price: str + change_today: str + qty_available: str + +# get_all_positions +class GetAllPositions(BaseModel): + itemlist: list[Position] + +# get_account +class GetAccount(BaseModel): + id: str + account_number: str + status: str + crypto_status: str + currency: str + buying_power: str + regt_buying_power: str + daytrading_buying_power: str + effective_buying_power: str + non_marginable_buying_power: str + bod_dtbp: str + cash: str + accrued_fees: str + pending_transfer_in: str + portfolio_value: str + pattern_day_trader: bool + trading_blocked: bool + transfers_blocked: bool + account_blocked: bool + created_at: str + trade_suspended_by_user: bool + multiplier: str + shorting_enabled: bool + equity: str + last_equity: str + long_market_value: str + short_market_value: str + position_market_value: str + initial_margin: str + maintenance_margin: str + last_maintenance_margin: str + sma: str + daytrade_count: int + balance_asof: str diff --git a/core/lib/schemas/drakdoo_s.py b/core/lib/schemas/drakdoo_s.py new file mode 100644 index 0000000..34b51ea --- /dev/null +++ b/core/lib/schemas/drakdoo_s.py @@ -0,0 +1,21 @@ +from pydantic import BaseModel + + +class DrakdooMarket(BaseModel): + exchange: str + item: str + currency: str + contract: str + + +class DrakdooTimestamp(BaseModel): + sent: int + trade: int + + +class DrakdooCallback(BaseModel): + title: str + message: str + period: str + market: DrakdooMarket + timestamp: DrakdooTimestamp diff --git a/core/lib/schemas/oanda_s.py b/core/lib/schemas/oanda_s.py new file mode 100644 index 0000000..22145a4 --- /dev/null +++ b/core/lib/schemas/oanda_s.py @@ -0,0 +1 @@ +from pydantic import BaseModel diff --git a/core/lib/serde/ccxt_s.py b/core/lib/serde/ccxt_s.py deleted file mode 100644 index 9ea53d0..0000000 --- a/core/lib/serde/ccxt_s.py +++ /dev/null @@ -1,125 +0,0 @@ -from serde import Model, fields - -# { -# "id": "92f0b26b-4c98-4553-9c74-cdafc7e037db", -# "clientOrderId": "ccxt_26adcbf445674f01af38a66a15e6f5b5", -# "timestamp": 1666096856515, -# "datetime": "2022-10-18T12:40:56.515477181Z", -# "lastTradeTimeStamp": null, -# "status": "open", -# "symbol": "BTC/USD", -# "type": "market", -# "timeInForce": "gtc", -# "postOnly": null, -# "side": "buy", -# "price": null, -# "stopPrice": null, -# "cost": null, -# "average": null, -# "amount": 1.1, -# "filled": 0.0, -# "remaining": 1.1, -# "trades": [], -# "fee": null, -# "info": { -# "id": "92f0b26b-4c98-4553-9c74-cdafc7e037db", -# "client_order_id": "ccxt_26adcbf445674f01af38a66a15e6f5b5", -# "created_at": "2022-10-18T12:40:56.516095561Z", -# "updated_at": "2022-10-18T12:40:56.516173841Z", -# "submitted_at": "2022-10-18T12:40:56.515477181Z", -# "filled_at": null, -# "expired_at": null, -# "canceled_at": null, -# "failed_at": null, -# "replaced_at": null, -# "replaced_by": null, -# "replaces": null, -# "asset_id": "276e2673-764b-4ab6-a611-caf665ca6340", -# "symbol": "BTC/USD", -# "asset_class": "crypto", -# "notional": null, -# "qty": "1.1", -# "filled_qty": "0", -# "filled_avg_price": null, -# "order_class": "", -# "order_type": "market", -# "type": "market", -# "side": "buy", -# "time_in_force": "gtc", -# "limit_price": null, -# "stop_price": null, -# "status": "pending_new", -# "extended_hours": false, -# "legs": null, -# "trail_percent": null, -# "trail_price": null, -# "hwm": null, -# "subtag": null, -# "source": null -# }, -# "fees": [], -# "lastTradeTimestamp": null -# } - - -class CCXTInfo(Model): - id = fields.Uuid() - client_order_id = fields.Str() - created_at = fields.Str() - updated_at = fields.Str() - submitted_at = fields.Str() - filled_at = fields.Optional(fields.Str()) - expired_at = fields.Optional(fields.Str()) - canceled_at = fields.Optional(fields.Str()) - failed_at = fields.Optional(fields.Str()) - replaced_at = fields.Optional(fields.Str()) - replaced_by = fields.Optional(fields.Str()) - replaces = fields.Optional(fields.Str()) - asset_id = fields.Uuid() - symbol = fields.Str() - asset_class = fields.Str() - notional = fields.Optional(fields.Str()) - qty = fields.Str() - filled_qty = fields.Str() - filled_avg_price = fields.Optional(fields.Str()) - order_class = fields.Str() - order_type = fields.Str() - type = fields.Str() - side = fields.Str() - time_in_force = fields.Str() - limit_price = fields.Optional(fields.Str()) - stop_price = fields.Optional(fields.Str()) - status = fields.Str() - extended_hours = fields.Bool() - legs = fields.Optional(fields.List(fields.Nested("CCXTInfo"))) - trail_percent = fields.Optional(fields.Str()) - trail_price = fields.Optional(fields.Str()) - hwm = fields.Optional(fields.Str()) - subtag = fields.Optional(fields.Str()) - source = fields.Optional(fields.Str()) - - -class CCXTRoot(Model): - id = fields.Uuid() - clientOrderId = fields.Str() - timestamp = fields.Int() - datetime = fields.Str() - lastTradeTimeStamp = fields.Optional(fields.Str()) - status = fields.Str() - symbol = fields.Str() - type = fields.Str() - timeInForce = fields.Str() - postOnly = fields.Optional(fields.Str()) - side = fields.Str() - price = fields.Optional(fields.Float()) - stopPrice = fields.Optional(fields.Float()) - cost = fields.Optional(fields.Float()) - average = fields.Optional(fields.Float()) - amount = fields.Float() - filled = fields.Float() - remaining = fields.Float() - trades = fields.Optional(fields.List(fields.Dict())) - fee = fields.Optional(fields.Float()) - info = fields.Nested(CCXTInfo) - fees = fields.Optional(fields.List(fields.Dict())) - lastTradeTimestamp = fields.Optional(fields.Str()) diff --git a/core/models.py b/core/models.py index a4cd6b8..0ff1ac0 100644 --- a/core/models.py +++ b/core/models.py @@ -91,8 +91,8 @@ class Account(models.Model): """ client = self.get_client() if client: - supported_symbols = client.get_supported_assets() - if supported_symbols: + success, supported_symbols = client.get_supported_assets() + if success: self.supported_symbols = supported_symbols super().save(*args, **kwargs) diff --git a/core/templates/window-content/account-info.html b/core/templates/window-content/account-info.html index d072ed7..8641214 100644 --- a/core/templates/window-content/account-info.html +++ b/core/templates/window-content/account-info.html @@ -1,3 +1,5 @@ +{% include 'partials/notify.html' %} +

Live information

diff --git a/core/views/accounts.py b/core/views/accounts.py index da795c6..d651596 100644 --- a/core/views/accounts.py +++ b/core/views/accounts.py @@ -39,7 +39,15 @@ class AccountInfo(LoginRequiredMixin, View): } return render(request, template_name, context) - live_info = dict(account.client.get_account()) + success, live_info = account.client.get_account() + if not success: + context = { + "message": "Could not get account info", + "class": "danger", + "window_content": self.window_content, + } + return render(request, template_name, context) + live_info = live_info account_info = account.__dict__ account_info = { k: v for k, v in account_info.items() if k in self.VIEWABLE_FIELDS_MODEL diff --git a/core/views/hooks.py b/core/views/hooks.py index b5553d3..280a168 100644 --- a/core/views/hooks.py +++ b/core/views/hooks.py @@ -4,13 +4,13 @@ import orjson from django.conf import settings from django.contrib.auth.mixins import LoginRequiredMixin from django.http import HttpResponse, HttpResponseBadRequest +from pydantic import ValidationError from rest_framework.parsers import JSONParser from rest_framework.views import APIView -from serde import ValidationError from core.forms import HookForm from core.lib import market -from core.lib.serde import drakdoo_s +from core.lib.schemas.drakdoo_s import DrakdooCallback from core.models import Callback, Hook from core.util import logs from core.views import ObjectCreate, ObjectDelete, ObjectList, ObjectUpdate @@ -44,7 +44,7 @@ class HookAPI(APIView): # Try validating the JSON try: - hook_resp = drakdoo_s.BaseDrakdoo.from_dict(request.data) + hook_resp = DrakdooCallback(**request.data) except ValidationError as e: log.error(f"HookAPI POST: {e}") return HttpResponseBadRequest(e) diff --git a/core/views/positions.py b/core/views/positions.py index 4a04242..372bd5b 100644 --- a/core/views/positions.py +++ b/core/views/positions.py @@ -17,8 +17,10 @@ def get_positions(user, account_id=None): items = [] accounts = Account.objects.filter(user=user) for account in accounts: - positions = account.client.get_all_positions() - print("positions", positions) + success, positions = account.client.get_all_positions() + if not success: + items.append({"name": account.name, "status": "error"}) + continue for item in positions: items.append(item) diff --git a/docker/prod/requirements.prod.txt b/docker/prod/requirements.prod.txt index 41c8898..dd802da 100644 --- a/docker/prod/requirements.prod.txt +++ b/docker/prod/requirements.prod.txt @@ -15,6 +15,6 @@ django-debug-toolbar-template-profiler orjson django-otp qrcode -serde[ext] +pydantic alpaca-py oandapyV20