From 60979652d9f716fb334a7893098ea3268b965016 Mon Sep 17 00:00:00 2001 From: Mark Veidemanis Date: Fri, 4 Nov 2022 07:20:14 +0000 Subject: [PATCH] Implement more validation and conversion --- core/exchanges/__init__.py | 58 ++++---- core/exchanges/alpaca.py | 21 +-- core/exchanges/oanda.py | 5 +- core/lib/schemas/alpaca_s.py | 88 ++++++++---- core/lib/schemas/oanda_s.py | 268 +++++++++++++++++++++++++---------- core/models.py | 1 + core/views/accounts.py | 9 +- core/views/positions.py | 2 - 8 files changed, 294 insertions(+), 158 deletions(-) diff --git a/core/exchanges/__init__.py b/core/exchanges/__init__.py index 970705a..049e000 100644 --- a/core/exchanges/__init__.py +++ b/core/exchanges/__init__.py @@ -1,5 +1,4 @@ from glom import glom -from pydantic import ValidationError from core.lib import schemas from core.util import logs @@ -12,6 +11,8 @@ STRICT_CONVERSION = False # TODO: Set them to True when all message types are implemented +log = logs.get_logger("exchanges") + class NoSchema(Exception): """ @@ -63,14 +64,10 @@ class BaseExchange(object): name = self.__class__.__name__ self.name = name.replace("Exchange", "").lower() self.account = account - self.log = logs.get_logger(self.name) self.client = None self.connect() - def set_schema(self): - raise NotImplementedError - def connect(self): raise NotImplementedError @@ -83,7 +80,8 @@ class BaseExchange(object): if hasattr(schemas, f"{self.name}_s"): schema_instance = getattr(schemas, f"{self.name}_s") else: - raise NoSchema(f"No schema for {self.name} in schema mapping") + log.error(f"No schema library for {self.name}") + raise Exception(f"No schema library for exchange {self.name}") return schema_instance @@ -93,14 +91,13 @@ class BaseExchange(object): else: to_camel = snake_to_camel(method.__class__.__name__) if convert: - to_camel += "Schema" + to_camel = f"{to_camel}Schema" # if hasattr(self.schema, method): # schema = getattr(self.schema, method) if hasattr(self.schema, to_camel): schema = getattr(self.schema, to_camel) else: - self.log.error(f"Method cannot be validated: {to_camel}") - raise NoSchema(f"Method cannot be validated: {to_camel}") + raise NoSchema(f"Could not get schema: {to_camel}") return schema def call_method(self, method, *args, **kwargs): @@ -124,40 +121,37 @@ class BaseExchange(object): # Use glom to convert the response to the schema converted = glom(response, schema) - print(f"[{self.name}] Converted of {method}: {converted}") return converted + def validate_response(self, response, method): + schema = self.get_schema(method) + # Return a dict of the validated response + response_valid = schema(**response).dict() + return response_valid + def call(self, method, *args, **kwargs): """ Call the exchange API and validate the response :raises NoSchema: If the method is not in the schema mapping :raises ValidationError: If the response cannot be validated """ + response = self.call_method(method, *args, **kwargs) + try: + response_valid = self.validate_response(response, method) + except NoSchema as e: + log.error(f"{e} - {response}") + response_valid = response + # Convert the response to a format that we can use try: - response = self.call_method(method, *args, **kwargs) - try: - schema = self.get_schema(method) - # Return a dict of the validated response - response_valid = schema(**response).dict() - except NoSchema: - self.log.debug(f"No schema: {response}") - if STRICT_VALIDATION: - raise - # Return the response as is - response_valid = response - - # Convert the response to a format that we can use response_converted = self.convert_spec(response_valid, method) - # return (True, response_converted) - return response_converted - except ValidationError as e: - self.log.error(f"Could not validate response: {e}") - raise - except NoSuchMethod: - self.log.error(f"Method not found: {method}") - raise + except NoSchema as e: + log.error(f"{e} - {response}") + response_converted = response_valid + # return (True, response_converted) + return response_converted + # except Exception as e: - # self.log.error(f"Error calling method: {e}") + # log.error(f"Error calling method: {e}") # raise GenericAPIError(e) def get_account(self): diff --git a/core/exchanges/alpaca.py b/core/exchanges/alpaca.py index 1a7133f..07a3e24 100644 --- a/core/exchanges/alpaca.py +++ b/core/exchanges/alpaca.py @@ -1,5 +1,3 @@ -import functools - from alpaca.common.exceptions import APIError from alpaca.trading.client import TradingClient from alpaca.trading.enums import OrderSide, TimeInForce @@ -12,19 +10,6 @@ from alpaca.trading.requests import ( from core.exchanges import BaseExchange, ExchangeError, GenericAPIError -def handle_errors(func): - @functools.wraps(func) - def wrapper(self, *args, **kwargs): - return_value = func(self, *args, **kwargs) - if isinstance(return_value, tuple): - if return_value[0] is False: - print("Error: ", return_value[1]) - return return_value - return return_value[1] - - return wrapper - - class AlpacaExchange(BaseExchange): def connect(self): self.client = TradingClient( @@ -34,18 +19,14 @@ class AlpacaExchange(BaseExchange): raw_data=True, ) - @handle_errors def get_account(self): - # return self.call("get_account") - market = self.get_market_value("NONEXISTENT") - print("MARTKET", market) + return self.call("get_account") def get_supported_assets(self): request = GetAssetsRequest(status="active", asset_class="crypto") assets = self.call("get_all_assets", filter=request) assets = assets["itemlist"] asset_list = [x["symbol"] for x in assets if "symbol" in x] - print("Supported symbols", asset_list) return asset_list diff --git a/core/exchanges/oanda.py b/core/exchanges/oanda.py index 3719912..574003d 100644 --- a/core/exchanges/oanda.py +++ b/core/exchanges/oanda.py @@ -21,7 +21,9 @@ class OANDAExchange(BaseExchange): return self.call(r) def get_supported_assets(self): - return False + r = accounts.AccountInstruments(accountID=self.account_id) + response = self.call(r) + return [x["name"] for x in response["itemlist"]] def get_balance(self): raise NotImplementedError @@ -59,7 +61,6 @@ class OANDAExchange(BaseExchange): r = positions.OpenPositions(accountID=self.account_id) response = self.call(r) - print("Positions", response) for item in response["itemlist"]: item["account"] = self.account.name item["account_id"] = self.account.id diff --git a/core/lib/schemas/alpaca_s.py b/core/lib/schemas/alpaca_s.py index 5745e22..49955f9 100644 --- a/core/lib/schemas/alpaca_s.py +++ b/core/lib/schemas/alpaca_s.py @@ -24,6 +24,32 @@ class GetAllAssets(BaseModel): itemlist: list[Asset] +GetAllAssetsSchema = { + "itemlist": ( + "itemlist", + [ + { + "id": "id", + "class": "class", + "exchange": "exchange", + "symbol": "symbol", + "name": "name", + "status": "status", + "tradable": "tradable", + "marginable": "marginable", + "maintenance_margin_requirement": "maintenanceMarginRequirement", + "shortable": "shortable", + "easy_to_borrow": "easyToBorrow", + "fractionable": "fractionable", + "min_order_size": "minOrderSize", + "min_trade_increment": "minTradeIncrement", + "price_increment": "priceIncrement", + } + ], + ) +} + + # get_open_position class GetOpenPosition(BaseModel): asset_id: str @@ -72,31 +98,6 @@ class GetAllPositions(BaseModel): itemlist: list[Position] -{ - "itemlist": [ - { - "asset_id": "64bbff51-59d6-4b3c-9351-13ad85e3c752", - "symbol": "BTCUSD", - "exchange": "FTXU", - "asset_class": "crypto", - "asset_marginable": False, - "qty": "0.009975", - "avg_entry_price": "20714", - "side": "long", - "market_value": "204.297975", - "cost_basis": "206.62215", - "unrealized_pl": "-2.324175", - "unrealized_plpc": "-0.0112484310128416", - "unrealized_intraday_pl": "-0.269325", - "unrealized_intraday_plpc": "-0.001316559391457", - "current_price": "20481", - "lastday_price": "20508", - "change_today": "-0.001316559391457", - "qty_available": "0.009975", - } - ] -} - GetAllPositionsSchema = { "itemlist": ( "itemlist", @@ -152,4 +153,39 @@ class GetAccount(BaseModel): balance_asof: str -GetAccountSchema = {"": ""} +GetAccountSchema = { + "id": "id", + "account_number": "account_number", + "status": "status", + "crypto_status": "crypto_status", + "currency": "currency", + "buying_power": "buying_power", + "regt_buying_power": "regt_buying_power", + "daytrading_buying_power": "daytrading_buying_power", + "effective_buying_power": "effective_buying_power", + "non_marginable_buying_power": "non_marginable_buying_power", + "bod_dtbp": "bod_dtbp", + "cash": "cash", + "accrued_fees": "accrued_fees", + "pending_transfer_in": "pending_transfer_in", + "portfolio_value": "portfolio_value", + "pattern_day_trader": "pattern_day_trader", + "trading_blocked": "trading_blocked", + "transfers_blocked": "transfers_blocked", + "account_blocked": "account_blocked", + "created_at": "created_at", + "trade_suspended_by_user": "trade_suspended_by_user", + "multiplier": "multiplier", + "shorting_enabled": "shorting_enabled", + "equity": "equity", + "last_equity": "last_equity", + "long_market_value": "long_market_value", + "short_market_value": "short_market_value", + "position_market_value": "position_market_value", + "initial_margin": "initial_margin", + "maintenance_margin": "maintenance_margin", + "last_maintenance_margin": "last_maintenance_margin", + "sma": "sma", + "daytrade_count": "daytrade_count", + "balance_asof": "balance_asof", +} diff --git a/core/lib/schemas/oanda_s.py b/core/lib/schemas/oanda_s.py index 5aa8583..0897783 100644 --- a/core/lib/schemas/oanda_s.py +++ b/core/lib/schemas/oanda_s.py @@ -1,42 +1,5 @@ from pydantic import BaseModel -a = { - "positions": [ - { - "instrument": "EUR_USD", - "long": { - "units": "1", - "averagePrice": "0.99361", - "pl": "-0.1014", - "resettablePL": "-0.1014", - "financing": "0.0000", - "dividendAdjustment": "0.0000", - "guaranteedExecutionFees": "0.0000", - "tradeIDs": ["71"], - "unrealizedPL": "-0.0002", - }, - "short": { - "units": "0", - "pl": "0.0932", - "resettablePL": "0.0932", - "financing": "0.0000", - "dividendAdjustment": "0.0000", - "guaranteedExecutionFees": "0.0000", - "unrealizedPL": "0.0000", - }, - "pl": "-0.0082", - "resettablePL": "-0.0082", - "financing": "0.0000", - "commission": "0.0000", - "dividendAdjustment": "0.0000", - "guaranteedExecutionFees": "0.0000", - "unrealizedPL": "-0.0002", - "marginUsed": "0.0286", - } - ], - "lastTransactionID": "71", -} - class PositionLong(BaseModel): units: str @@ -79,44 +42,6 @@ class OpenPositions(BaseModel): lastTransactionID: str -{ - "positions": [ - { - "instrument": "EUR_USD", - "long": { - "units": "1", - "averagePrice": "0.99361", - "pl": "-0.1014", - "resettablePL": "-0.1014", - "financing": "-0.0002", - "dividendAdjustment": "0.0000", - "guaranteedExecutionFees": "0.0000", - "tradeIDs": ["71"], - "unrealizedPL": "-0.0044", - }, - "short": { - "units": "0", - "pl": "0.0932", - "resettablePL": "0.0932", - "financing": "0.0000", - "dividendAdjustment": "0.0000", - "guaranteedExecutionFees": "0.0000", - "unrealizedPL": "0.0000", - }, - "pl": "-0.0082", - "resettablePL": "-0.0082", - "financing": "-0.0002", - "commission": "0.0000", - "dividendAdjustment": "0.0000", - "guaranteedExecutionFees": "0.0000", - "unrealizedPL": "-0.0044", - "marginUsed": "0.0287", - } - ], - "lastTransactionID": "73", -} - - def parse_prices(x): if float(x["long"]["units"]) > 0: return x["long"]["averagePrice"] @@ -168,3 +93,196 @@ OpenPositionsSchema = { ], ) } + + +class AccountDetailsNested(BaseModel): + guaranteedStopLossOrderMode: str + hedgingEnabled: bool + id: str + createdTime: str + currency: str + createdByUserID: int + alias: str + marginRate: str + lastTransactionID: str + balance: str + openTradeCount: int + openPositionCount: int + pendingOrderCount: int + pl: str + resettablePL: str + resettablePLTime: str + financing: str + commission: str + dividendAdjustment: str + guaranteedExecutionFees: str + orders: list # Order + positions: list # Position + trades: list # Trade + unrealizedPL: str + NAV: str + marginUsed: str + marginAvailable: str + positionValue: str + marginCloseoutUnrealizedPL: str + marginCloseoutNAV: str + marginCloseoutMarginUsed: str + marginCloseoutPositionValue: str + marginCloseoutPercent: str + withdrawalLimit: str + marginCallMarginUsed: str + marginCallPercent: str + + +class AccountDetails(BaseModel): + account: AccountDetailsNested + lastTransactionID: str + + +AccountDetailsSchema = { + "guaranteedSLOM": "account.guaranteedStopLossOrderMode", + "hedgingEnabled": "account.hedgingEnabled", + "id": "account.id", + "created_at": "account.createdTime", + "currency": "account.currency", + "createdByUserID": "account.createdByUserID", + "alias": "account.alias", + "marginRate": "account.marginRate", + "lastTransactionID": "account.lastTransactionID", + "balance": "account.balance", + "openTradeCount": "account.openTradeCount", + "openPositionCount": "account.openPositionCount", + "pendingOrderCount": "account.pendingOrderCount", + "pl": "account.pl", + "resettablePL": "account.resettablePL", + "resettablePLTime": "account.resettablePLTime", + "financing": "account.financing", + "commission": "account.commission", + "dividendAdjustment": "account.dividendAdjustment", + "guaranteedExecutionFees": "account.guaranteedExecutionFees", + # "orders": "account.orders", + # "positions": "account.positions", + # "trades": "account.trades", + "unrealizedPL": "account.unrealizedPL", + "NAV": "account.NAV", + "marginUsed": "account.marginUsed", + "marginAvailable": "account.marginAvailable", + "positionValue": "account.positionValue", + "marginCloseoutUnrealizedPL": "account.marginCloseoutUnrealizedPL", + "marginCloseoutNAV": "account.marginCloseoutNAV", + "marginCloseoutMarginUsed": "account.marginCloseoutMarginUsed", + "marginCloseoutPositionValue": "account.marginCloseoutPositionValue", + "marginCloseoutPercent": "account.marginCloseoutPercent", + "withdrawalLimit": "account.withdrawalLimit", + "marginCallMarginUsed": "account.marginCallMarginUsed", + "marginCallPercent": "account.marginCallPercent", +} + + +class PositionDetailsNested(BaseModel): + instrument: str + long: PositionLong + short: PositionShort + pl: str + resettablePL: str + financing: str + commission: str + dividendAdjustment: str + guaranteedExecutionFees: str + unrealizedPL: str + marginUsed: str + + +class PositionDetails(BaseModel): + position: PositionDetailsNested + lastTransactionID: str + + +PositionDetailsSchema = { + "symbol": "position.instrument", + "long": "position.long", + "short": "position.short", + "pl": "position.pl", + "resettablePL": "position.resettablePL", + "financing": "position.financing", + "commission": "position.commission", + "dividendAdjustment": "position.dividendAdjustment", + "guaranteedExecutionFees": "position.guaranteedExecutionFees", + "unrealizedPL": "position.unrealizedPL", + "marginUsed": "position.marginUsed", + "price": lambda x: parse_prices(x["position"]), + "units": lambda x: parse_units(x["position"]), + "side": lambda x: parse_side(x["position"]), + "value": lambda x: parse_value(x["position"]), +} + + +class InstrumentTag(BaseModel): + type: str + name: str + + +class InstrumentFinancingDaysOfWeek(BaseModel): + dayOfWeek: str + daysCharged: int + + +class InstrumentFinancing(BaseModel): + longRate: str + shortRate: str + financingDaysOfWeek: list[InstrumentFinancingDaysOfWeek] + + +class InstrumentGuaranteedRestriction(BaseModel): + volume: str + priceRange: str + + +class Instrument(BaseModel): + name: str + type: str + displayName: str + pipLocation: int + displayPrecision: int + tradeUnitsPrecision: int + minimumTradeSize: str + maximumTrailingStopDistance: str + minimumTrailingStopDistance: str + maximumPositionSize: str + maximumOrderUnits: str + marginRate: str + guaranteedStopLossOrderMode: str + tags: list[InstrumentTag] + financing: InstrumentFinancing + guaranteedStopLossOrderLevelRestriction: InstrumentGuaranteedRestriction | None + + +class AccountInstruments(BaseModel): + instruments: list[Instrument] + + +AccountInstrumentsSchema = { + "itemlist": ( + "instruments", + [ + { + "name": "name", + "type": "type", + "displayName": "displayName", + "pipLocation": "pipLocation", + "displayPrecision": "displayPrecision", + "tradeUnitsPrecision": "tradeUnitsPrecision", + "minimumTradeSize": "minimumTradeSize", + "maximumTrailingStopDistance": "maximumTrailingStopDistance", + "minimumTrailingStopDistance": "minimumTrailingStopDistance", + "maximumPositionSize": "maximumPositionSize", + "maximumOrderUnits": "maximumOrderUnits", + "marginRate": "marginRate", + "guaranteedSLOM": "guaranteedStopLossOrderMode", + "tags": "tags", + "financing": "financing", + "guaranteedSLOLR": "guaranteedStopLossOrderLevelRestriction", + } + ], + ) +} diff --git a/core/models.py b/core/models.py index 2765c17..dd58864 100644 --- a/core/models.py +++ b/core/models.py @@ -91,6 +91,7 @@ class Account(models.Model): client = self.get_client() if client: supported_symbols = client.get_supported_assets() + log.debug(f"Supported symbols for {self.name}: {supported_symbols}") self.supported_symbols = supported_symbols super().save(*args, **kwargs) diff --git a/core/views/accounts.py b/core/views/accounts.py index 2bdf6ab..41575c5 100644 --- a/core/views/accounts.py +++ b/core/views/accounts.py @@ -14,7 +14,13 @@ log = logs.get_logger(__name__) class AccountInfo(LoginRequiredMixin, View): - VIEWABLE_FIELDS_MODEL = ["name", "exchange", "api_key", "sandbox"] + VIEWABLE_FIELDS_MODEL = [ + "name", + "exchange", + "api_key", + "sandbox", + "supported_symbols", + ] allowed_types = ["modal", "widget", "window", "page"] window_content = "window-content/account-info.html" @@ -45,6 +51,7 @@ class AccountInfo(LoginRequiredMixin, View): account_info = { k: v for k, v in account_info.items() if k in self.VIEWABLE_FIELDS_MODEL } + account_info["supported_symbols"] = ", ".join(account_info["supported_symbols"]) if type == "page": type = "modal" diff --git a/core/views/positions.py b/core/views/positions.py index b47b6b9..9f70b95 100644 --- a/core/views/positions.py +++ b/core/views/positions.py @@ -36,7 +36,6 @@ class Positions(LoginRequiredMixin, View): template_name = f"wm/{type}.html" unique = str(uuid.uuid4())[:8] items = get_positions(request.user, account_id) - print("ITEMS", items) if type == "page": type = "modal" context = { @@ -68,7 +67,6 @@ class PositionAction(LoginRequiredMixin, View): account = Account.get_by_id(account_id, request.user) info = account.client.get_position_info(symbol) - print("ACCT INFO", info) if type == "page": type = "modal"