From b36791d56b3f513a66c0a02fed508d102df482ef Mon Sep 17 00:00:00 2001 From: Mark Veidemanis Date: Fri, 4 Nov 2022 07:20:55 +0000 Subject: [PATCH] Simplify schema and error handling --- core/exchanges/__init__.py | 155 +++++++++++++++++++++++++++-------- core/exchanges/alpaca.py | 74 ++++++++--------- core/exchanges/oanda.py | 47 +++-------- core/lib/schemas/alpaca_s.py | 5 +- core/lib/schemas/oanda_s.py | 2 +- core/models.py | 3 +- core/views/hooks.py | 5 +- core/views/positions.py | 18 +--- 8 files changed, 181 insertions(+), 128 deletions(-) diff --git a/core/exchanges/__init__.py b/core/exchanges/__init__.py index ad1f6a1..1af7dba 100644 --- a/core/exchanges/__init__.py +++ b/core/exchanges/__init__.py @@ -8,10 +8,56 @@ from core.util import logs STRICT_VALIDATION = False # Raise exception if the conversion schema is not found -STRICT_CONVERSTION = False +STRICT_CONVERSION = False # TODO: Set them to True when all message types are implemented + +class NoSchema(Exception): + """ + Raised when: + - The schema for the message type is not found + - The conversion schema is not found + - There is no schema library for the exchange + """ + + pass + + +class NoSuchMethod(Exception): + """ + Exchange library has no such method. + """ + + pass + + +class GenericAPIError(Exception): + """ + Generic API error. + """ + + pass + + +class ExchangeError(Exception): + """ + Exchange error. + """ + + pass + + +def is_camel_case(s): + return s != s.lower() and s != s.upper() and "_" not in s + + +def snake_to_camel(word): + if is_camel_case(word): + return word + return "".join(x.capitalize() or "_" for x in word.split("_")) + + class BaseExchange(object): def __init__(self, account): name = self.__class__.__name__ @@ -20,7 +66,6 @@ class BaseExchange(object): self.log = logs.get_logger(self.name) self.client = None - self.set_schema() self.connect() def set_schema(self): @@ -29,53 +74,91 @@ class BaseExchange(object): def connect(self): raise NotImplementedError - def convert_spec(self, response, msg_type): + @property + def schema(self): + """ + Get the schema library for the exchange. + """ # Does the schemas library have a library for this exchange name? if hasattr(schemas, f"{self.name}_s"): schema_instance = getattr(schemas, f"{self.name}_s") else: - raise Exception(f"No schema for {self.name} in schema mapping") - # Does the message type have a conversion spec for this message type? - if hasattr(schema_instance, f"{msg_type}_schema"): - schema = getattr(schema_instance, f"{msg_type}_schema") + raise NoSchema(f"No schema for {self.name} in schema mapping") + + return schema_instance + + def get_schema(self, method, convert=False): + if isinstance(method, str): + to_camel = snake_to_camel(method) + else: + to_camel = snake_to_camel(method.__class__.__name__) + if convert: + to_camel += "Schema" + # if hasattr(self.schema, method): + # schema = getattr(self.schema, method) + if hasattr(self.schema, to_camel): + schema = getattr(self.schema, to_camel) else: - # Let us know so we can implement it, but don't do anything with it - self.log.error(f"No schema for message: {msg_type} - {response}") - if STRICT_CONVERSION: - raise Exception(f"No schema for {msg_type} in schema mapping") + raise NoSchema + return schema + + def call_method(self, method, *args, **kwargs): + """ + Get a method from the exchange library. + """ + if hasattr(self.client, method): + response = getattr(self.client, method)(*args, **kwargs) + if isinstance(response, list): + response = {"itemlist": response} return response + else: + raise NoSuchMethod + + def convert_spec(self, response, method): + """ + Convert an API response to the requested spec. + :raises NoSchema: If the conversion schema is not found + """ + schema = self.get_schema(method, convert=True) # Use glom to convert the response to the schema converted = glom(response, schema) - print(f"[{self.name}] Converted of {msg_type}: {converted}") + print(f"[{self.name}] Converted of {method}: {converted}") return converted - def call(self, method, *args, **kwargs) -> (bool, dict): - - if hasattr(self.client, method): + def call(self, method, *args, **kwargs): + """ + Call the exchange API and validate the response + :raises NoSchema: If the method is not in the schema mapping + :raises ValidationError: If the response cannot be validated + """ + try: + response = self.call_method(method, *args, **kwargs) try: - response = getattr(self.client, method)(*args, **kwargs) - if isinstance(response, list): - response = {"itemlist": response} - if method not in self.schema: - self.log.error(f"Method cannot be validated: {method}") - self.log.debug(f"Response: {response}") - if STRICT_VALIDATION: - return (False, f"Method cannot be validated: {method}") - return (True, response) + schema = self.get_schema(method) # Return a dict of the validated response - response_valid = self.schema[method](**response).dict() - # Convert the response to a format that we can use - response_converted = self.convert_spec(response_valid, method) - return (True, response_converted) - except ValidationError as e: - self.log.error(f"Could not validate response: {e}") - return (False, e) - except Exception as e: - self.log.error(f"Error calling {method}: {e}") - return (False, e) - else: - return (False, "No such method") + response_valid = schema(**response).dict() + except NoSchema: + self.log.error(f"Method cannot be validated: {method}") + self.log.debug(f"Response: {response}") + if STRICT_VALIDATION: + raise + # Return the response as is + response_valid = response + + # Convert the response to a format that we can use + response_converted = self.convert_spec(response_valid, method) + # return (True, response_converted) + return response_converted + except ValidationError as e: + self.log.error(f"Could not validate response: {e}") + raise + except NoSuchMethod: + self.log.error(f"Method not found: {method}") + raise + except Exception as e: + self.log.error(f"Error calling method: {e}") + raise GenericAPIError(e) def get_account(self): raise NotImplementedError diff --git a/core/exchanges/alpaca.py b/core/exchanges/alpaca.py index ed39f28..1a7133f 100644 --- a/core/exchanges/alpaca.py +++ b/core/exchanges/alpaca.py @@ -1,3 +1,5 @@ +import functools + from alpaca.common.exceptions import APIError from alpaca.trading.client import TradingClient from alpaca.trading.enums import OrderSide, TimeInForce @@ -7,21 +9,23 @@ from alpaca.trading.requests import ( MarketOrderRequest, ) -from core.exchanges import BaseExchange -from core.lib.schemas import alpaca_s +from core.exchanges import BaseExchange, ExchangeError, GenericAPIError -ALPACA_SCHEMA_MAPPING = { - "get_account": alpaca_s.GetAccount, - "get_all_assets": alpaca_s.GetAllAssets, - "get_all_positions": alpaca_s.GetAllPositions, - "get_open_position": alpaca_s.GetOpenPosition, -} +def handle_errors(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + return_value = func(self, *args, **kwargs) + if isinstance(return_value, tuple): + if return_value[0] is False: + print("Error: ", return_value[1]) + return return_value + return return_value[1] -class AlpacaExchange(BaseExchange): - def set_schema(self): - self.schema = ALPACA_SCHEMA_MAPPING + return wrapper + +class AlpacaExchange(BaseExchange): def connect(self): self.client = TradingClient( self.account.api_key, @@ -30,39 +34,37 @@ class AlpacaExchange(BaseExchange): raw_data=True, ) + @handle_errors def get_account(self): - return self.call("get_account") + # return self.call("get_account") + market = self.get_market_value("NONEXISTENT") + print("MARTKET", market) def get_supported_assets(self): request = GetAssetsRequest(status="active", asset_class="crypto") - success, assets = self.call("get_all_assets", filter=request) - # assets = self.client.get_all_assets(filter=request) - if not success: - return (success, assets) + assets = self.call("get_all_assets", filter=request) assets = assets["itemlist"] asset_list = [x["symbol"] for x in assets if "symbol" in x] print("Supported symbols", asset_list) - return (True, asset_list) + return asset_list def get_balance(self): - success, account_info = self.call("get_account") - if not success: - return (success, account_info) + account_info = self.call("get_account") equity = account_info["equity"] try: balance = float(equity) except ValueError: - return (False, "Invalid balance") + raise GenericAPIError(f"Balance is not a float: {equity}") - return (True, balance) + return balance def get_market_value(self, symbol): # TODO: pydantic try: position = self.client.get_position(symbol) except APIError as e: self.log.error(f"Could not get market value for {symbol}: {e}") - return False + raise GenericAPIError(e) return float(position["market_value"]) def post_trade(self, trade): # TODO: pydantic @@ -72,7 +74,7 @@ class AlpacaExchange(BaseExchange): elif trade.direction == "sell": direction = OrderSide.SELL else: - raise Exception("Unknown direction") + raise ExchangeError("Unknown direction") cast = { "symbol": trade.symbol, @@ -84,7 +86,7 @@ class AlpacaExchange(BaseExchange): if trade.amount_usd is not None: cast["notional"] = trade.amount_usd if not trade.amount and not trade.amount_usd: - return (False, "No amount specified") + raise ExchangeError("No amount specified") if trade.take_profit: cast["take_profit"] = {"limit_price": trade.take_profit} if trade.stop_loss: @@ -101,10 +103,10 @@ class AlpacaExchange(BaseExchange): self.log.error(f"Error placing market order: {e}") trade.status = "error" trade.save() - return (False, e) + raise GenericAPIError(e) elif trade.type == "limit": if not trade.price: - return (False, "Limit order with no price") + raise ExchangeError("No price specified for limit order") cast["limit_price"] = trade.price limit_order_data = LimitOrderRequest(**cast) try: @@ -113,16 +115,16 @@ class AlpacaExchange(BaseExchange): self.log.error(f"Error placing limit order: {e}") trade.status = "error" trade.save() - return (False, e) + raise GenericAPIError(e) else: - raise Exception("Unknown trade type") + raise ExchangeError("Unknown trade type") trade.response = order trade.status = "posted" trade.order_id = order["id"] trade.client_order_id = order["client_order_id"] trade.save() - return (True, order) + return order def get_trade(self, trade_id): pass @@ -134,20 +136,16 @@ class AlpacaExchange(BaseExchange): pass def get_position_info(self, symbol): - success, position = self.call("get_open_position", symbol) - if not success: - return (success, position) - return (True, position) + position = self.call("get_open_position", symbol) + return position def get_all_positions(self): items = [] - success, response = self.call("get_all_positions") - if not success: - return (success, response) + response = self.call("get_all_positions") for item in response["itemlist"]: item["account"] = self.account.name item["account_id"] = self.account.id item["unrealized_pl"] = float(item["unrealized_pl"]) items.append(item) - return (True, items) + return items diff --git a/core/exchanges/oanda.py b/core/exchanges/oanda.py index d42b2c8..bbf12c0 100644 --- a/core/exchanges/oanda.py +++ b/core/exchanges/oanda.py @@ -1,35 +1,16 @@ from oandapyV20 import API -from oandapyV20.endpoints import accounts, orders, positions, trades -from pydantic import ValidationError +from oandapyV20.endpoints import accounts, positions from core.exchanges import BaseExchange -from core.lib.schemas import oanda_s - -OANDA_SCHEMA_MAPPING = {"OpenPositions": oanda_s.OpenPositions} class OANDAExchange(BaseExchange): - def call(self, method, request): + def call_method(self, request): self.client.request(request) response = request.response if isinstance(response, list): response = {"itemlist": response} - if method not in self.schema: - self.log.error(f"Method cannot be validated: {method}") - self.log.debug(f"Response: {response}") - return (False, f"Method cannot be validated: {method}") - try: - # Return a dict of the validated response - response_valid = self.schema[method](**response).dict() - # Convert the response to a format that we can use - response_converted = self.convert_spec(response_valid, method) - return (True, response_converted) - except ValidationError as e: - self.log.error(f"Could not validate response: {e}") - return (False, e) - - def set_schema(self): - self.schema = OANDA_SCHEMA_MAPPING + return response def connect(self): self.client = API(access_token=self.account.api_secret) @@ -51,9 +32,9 @@ class OANDAExchange(BaseExchange): def post_trade(self, trade): raise NotImplementedError - r = orders.OrderCreate(accountID, data=data) - self.client.request(r) - return r.response + # r = orders.OrderCreate(accountID, data=data) + # self.client.request(r) + # return r.response def get_trade(self, trade_id): r = accounts.TradeDetails(accountID=self.account_id, tradeID=trade_id) @@ -62,11 +43,11 @@ class OANDAExchange(BaseExchange): def update_trade(self, trade): raise NotImplementedError - r = orders.OrderReplace( - accountID=self.account_id, orderID=trade.order_id, data=data - ) - self.client.request(r) - return r.response + # r = orders.OrderReplace( + # accountID=self.account_id, orderID=trade.order_id, data=data + # ) + # self.client.request(r) + # return r.response def cancel_trade(self, trade_id): raise NotImplementedError @@ -79,9 +60,7 @@ class OANDAExchange(BaseExchange): def get_all_positions(self): items = [] r = positions.OpenPositions(accountID=self.account_id) - success, response = self.call("OpenPositions", r) - if not success: - return (success, response) + response = self.call(r) print("Positions", response) for item in response["itemlist"]: @@ -89,4 +68,4 @@ class OANDAExchange(BaseExchange): item["account_id"] = self.account.id item["unrealized_pl"] = float(item["unrealized_pl"]) items.append(item) - return (True, items) + return items diff --git a/core/lib/schemas/alpaca_s.py b/core/lib/schemas/alpaca_s.py index 08d7de4..5745e22 100644 --- a/core/lib/schemas/alpaca_s.py +++ b/core/lib/schemas/alpaca_s.py @@ -97,7 +97,7 @@ class GetAllPositions(BaseModel): ] } -get_all_positions_schema = { +GetAllPositionsSchema = { "itemlist": ( "itemlist", [ @@ -113,6 +113,7 @@ get_all_positions_schema = { ) } + # get_account class GetAccount(BaseModel): id: str @@ -151,4 +152,4 @@ class GetAccount(BaseModel): balance_asof: str -get_account_schema = {"": ""} +GetAccountSchema = {"": ""} diff --git a/core/lib/schemas/oanda_s.py b/core/lib/schemas/oanda_s.py index d421d84..5aa8583 100644 --- a/core/lib/schemas/oanda_s.py +++ b/core/lib/schemas/oanda_s.py @@ -153,7 +153,7 @@ def parse_side(x): return "unknown" -OpenPositions_schema = { +OpenPositionsSchema = { "itemlist": ( "positions", [ diff --git a/core/models.py b/core/models.py index 0ff1ac0..da55226 100644 --- a/core/models.py +++ b/core/models.py @@ -5,7 +5,6 @@ from django.db import models from core.exchanges.alpaca import AlpacaExchange from core.exchanges.oanda import OANDAExchange -from core.lib import trades from core.lib.customers import get_or_create, update_customer_fields from core.util import logs @@ -100,7 +99,7 @@ class Account(models.Model): if self.exchange in EXCHANGE_MAP: return EXCHANGE_MAP[self.exchange](self) else: - raise Exception("Exchange not supported") + raise Exception(f"Exchange not supported : {self.exchange}") @property def client(self): diff --git a/core/views/hooks.py b/core/views/hooks.py index e92d1d9..37a162d 100644 --- a/core/views/hooks.py +++ b/core/views/hooks.py @@ -111,7 +111,10 @@ class HookList(LoginRequiredMixin, ObjectList): title = "Hooks" title_singular = "Hook" page_title = "List of active URL endpoints for receiving hooks." - page_subtitle = "Add URLs here to receive Drakdoo callbacks. Make then unique!" + page_subtitle = ( + "Add URLs here to receive Drakdoo callbacks. " + "Make then unique and hard to guess!" + ) list_url_name = "hooks" list_url_args = ["type"] diff --git a/core/views/positions.py b/core/views/positions.py index 831ed9f..b781aad 100644 --- a/core/views/positions.py +++ b/core/views/positions.py @@ -6,7 +6,6 @@ from django.shortcuts import render from django.views import View from rest_framework.parsers import FormParser -from core.lib import trades from core.models import Account from core.util import logs @@ -17,10 +16,7 @@ def get_positions(user, account_id=None): items = [] accounts = Account.objects.filter(user=user) for account in accounts: - success, positions = account.client.get_all_positions() - if not success: - items.append({"name": account.name, "status": "error"}) - continue + positions = account.client.get_all_positions() for item in positions: items.append(item) @@ -71,11 +67,8 @@ class PositionAction(LoginRequiredMixin, View): unique = str(uuid.uuid4())[:8] account = Account.get_by_id(account_id, request.user) - success, info = account.client.get_position_info(symbol) + info = account.client.get_position_info(symbol) print("ACCT INFO", info) - if not success: - message = "Position does not exist" - message_class = "danger" items = get_positions(request.user, account_id) if type == "page": @@ -87,9 +80,6 @@ class PositionAction(LoginRequiredMixin, View): "items": items, "type": type, } - if success: - context["items"] = info - else: - context["message"] = message - context["class"] = message_class + context["items"] = info + return render(request, template_name, context)