Simplify schema and error handling

This commit is contained in:
Mark Veidemanis 2022-11-04 07:20:55 +00:00
parent 04a87c1da6
commit b36791d56b
Signed by: m
GPG Key ID: 5ACFCEED46C0904F
8 changed files with 181 additions and 128 deletions

View File

@ -8,10 +8,56 @@ from core.util import logs
STRICT_VALIDATION = False STRICT_VALIDATION = False
# Raise exception if the conversion schema is not found # Raise exception if the conversion schema is not found
STRICT_CONVERSTION = False STRICT_CONVERSION = False
# TODO: Set them to True when all message types are implemented # TODO: Set them to True when all message types are implemented
class NoSchema(Exception):
"""
Raised when:
- The schema for the message type is not found
- The conversion schema is not found
- There is no schema library for the exchange
"""
pass
class NoSuchMethod(Exception):
"""
Exchange library has no such method.
"""
pass
class GenericAPIError(Exception):
"""
Generic API error.
"""
pass
class ExchangeError(Exception):
"""
Exchange error.
"""
pass
def is_camel_case(s):
return s != s.lower() and s != s.upper() and "_" not in s
def snake_to_camel(word):
if is_camel_case(word):
return word
return "".join(x.capitalize() or "_" for x in word.split("_"))
class BaseExchange(object): class BaseExchange(object):
def __init__(self, account): def __init__(self, account):
name = self.__class__.__name__ name = self.__class__.__name__
@ -20,7 +66,6 @@ class BaseExchange(object):
self.log = logs.get_logger(self.name) self.log = logs.get_logger(self.name)
self.client = None self.client = None
self.set_schema()
self.connect() self.connect()
def set_schema(self): def set_schema(self):
@ -29,53 +74,91 @@ class BaseExchange(object):
def connect(self): def connect(self):
raise NotImplementedError raise NotImplementedError
def convert_spec(self, response, msg_type): @property
def schema(self):
"""
Get the schema library for the exchange.
"""
# Does the schemas library have a library for this exchange name? # Does the schemas library have a library for this exchange name?
if hasattr(schemas, f"{self.name}_s"): if hasattr(schemas, f"{self.name}_s"):
schema_instance = getattr(schemas, f"{self.name}_s") schema_instance = getattr(schemas, f"{self.name}_s")
else: else:
raise Exception(f"No schema for {self.name} in schema mapping") raise NoSchema(f"No schema for {self.name} in schema mapping")
# Does the message type have a conversion spec for this message type?
if hasattr(schema_instance, f"{msg_type}_schema"): return schema_instance
schema = getattr(schema_instance, f"{msg_type}_schema")
def get_schema(self, method, convert=False):
if isinstance(method, str):
to_camel = snake_to_camel(method)
else: else:
# Let us know so we can implement it, but don't do anything with it to_camel = snake_to_camel(method.__class__.__name__)
self.log.error(f"No schema for message: {msg_type} - {response}") if convert:
if STRICT_CONVERSION: to_camel += "Schema"
raise Exception(f"No schema for {msg_type} in schema mapping") # if hasattr(self.schema, method):
return response # schema = getattr(self.schema, method)
if hasattr(self.schema, to_camel):
# Use glom to convert the response to the schema schema = getattr(self.schema, to_camel)
converted = glom(response, schema) else:
print(f"[{self.name}] Converted of {msg_type}: {converted}") raise NoSchema
return converted return schema
def call(self, method, *args, **kwargs) -> (bool, dict):
def call_method(self, method, *args, **kwargs):
"""
Get a method from the exchange library.
"""
if hasattr(self.client, method): if hasattr(self.client, method):
try:
response = getattr(self.client, method)(*args, **kwargs) response = getattr(self.client, method)(*args, **kwargs)
if isinstance(response, list): if isinstance(response, list):
response = {"itemlist": response} response = {"itemlist": response}
if method not in self.schema: return response
else:
raise NoSuchMethod
def convert_spec(self, response, method):
"""
Convert an API response to the requested spec.
:raises NoSchema: If the conversion schema is not found
"""
schema = self.get_schema(method, convert=True)
# Use glom to convert the response to the schema
converted = glom(response, schema)
print(f"[{self.name}] Converted of {method}: {converted}")
return converted
def call(self, method, *args, **kwargs):
"""
Call the exchange API and validate the response
:raises NoSchema: If the method is not in the schema mapping
:raises ValidationError: If the response cannot be validated
"""
try:
response = self.call_method(method, *args, **kwargs)
try:
schema = self.get_schema(method)
# Return a dict of the validated response
response_valid = schema(**response).dict()
except NoSchema:
self.log.error(f"Method cannot be validated: {method}") self.log.error(f"Method cannot be validated: {method}")
self.log.debug(f"Response: {response}") self.log.debug(f"Response: {response}")
if STRICT_VALIDATION: if STRICT_VALIDATION:
return (False, f"Method cannot be validated: {method}") raise
return (True, response) # Return the response as is
# Return a dict of the validated response response_valid = response
response_valid = self.schema[method](**response).dict()
# Convert the response to a format that we can use # Convert the response to a format that we can use
response_converted = self.convert_spec(response_valid, method) response_converted = self.convert_spec(response_valid, method)
return (True, response_converted) # return (True, response_converted)
return response_converted
except ValidationError as e: except ValidationError as e:
self.log.error(f"Could not validate response: {e}") self.log.error(f"Could not validate response: {e}")
return (False, e) raise
except NoSuchMethod:
self.log.error(f"Method not found: {method}")
raise
except Exception as e: except Exception as e:
self.log.error(f"Error calling {method}: {e}") self.log.error(f"Error calling method: {e}")
return (False, e) raise GenericAPIError(e)
else:
return (False, "No such method")
def get_account(self): def get_account(self):
raise NotImplementedError raise NotImplementedError

View File

@ -1,3 +1,5 @@
import functools
from alpaca.common.exceptions import APIError from alpaca.common.exceptions import APIError
from alpaca.trading.client import TradingClient from alpaca.trading.client import TradingClient
from alpaca.trading.enums import OrderSide, TimeInForce from alpaca.trading.enums import OrderSide, TimeInForce
@ -7,21 +9,23 @@ from alpaca.trading.requests import (
MarketOrderRequest, MarketOrderRequest,
) )
from core.exchanges import BaseExchange from core.exchanges import BaseExchange, ExchangeError, GenericAPIError
from core.lib.schemas import alpaca_s
ALPACA_SCHEMA_MAPPING = {
"get_account": alpaca_s.GetAccount, def handle_errors(func):
"get_all_assets": alpaca_s.GetAllAssets, @functools.wraps(func)
"get_all_positions": alpaca_s.GetAllPositions, def wrapper(self, *args, **kwargs):
"get_open_position": alpaca_s.GetOpenPosition, return_value = func(self, *args, **kwargs)
} if isinstance(return_value, tuple):
if return_value[0] is False:
print("Error: ", return_value[1])
return return_value
return return_value[1]
return wrapper
class AlpacaExchange(BaseExchange): class AlpacaExchange(BaseExchange):
def set_schema(self):
self.schema = ALPACA_SCHEMA_MAPPING
def connect(self): def connect(self):
self.client = TradingClient( self.client = TradingClient(
self.account.api_key, self.account.api_key,
@ -30,39 +34,37 @@ class AlpacaExchange(BaseExchange):
raw_data=True, raw_data=True,
) )
@handle_errors
def get_account(self): def get_account(self):
return self.call("get_account") # return self.call("get_account")
market = self.get_market_value("NONEXISTENT")
print("MARTKET", market)
def get_supported_assets(self): def get_supported_assets(self):
request = GetAssetsRequest(status="active", asset_class="crypto") request = GetAssetsRequest(status="active", asset_class="crypto")
success, assets = self.call("get_all_assets", filter=request) assets = self.call("get_all_assets", filter=request)
# assets = self.client.get_all_assets(filter=request)
if not success:
return (success, assets)
assets = assets["itemlist"] assets = assets["itemlist"]
asset_list = [x["symbol"] for x in assets if "symbol" in x] asset_list = [x["symbol"] for x in assets if "symbol" in x]
print("Supported symbols", asset_list) print("Supported symbols", asset_list)
return (True, asset_list) return asset_list
def get_balance(self): def get_balance(self):
success, account_info = self.call("get_account") account_info = self.call("get_account")
if not success:
return (success, account_info)
equity = account_info["equity"] equity = account_info["equity"]
try: try:
balance = float(equity) balance = float(equity)
except ValueError: except ValueError:
return (False, "Invalid balance") raise GenericAPIError(f"Balance is not a float: {equity}")
return (True, balance) return balance
def get_market_value(self, symbol): # TODO: pydantic def get_market_value(self, symbol): # TODO: pydantic
try: try:
position = self.client.get_position(symbol) position = self.client.get_position(symbol)
except APIError as e: except APIError as e:
self.log.error(f"Could not get market value for {symbol}: {e}") self.log.error(f"Could not get market value for {symbol}: {e}")
return False raise GenericAPIError(e)
return float(position["market_value"]) return float(position["market_value"])
def post_trade(self, trade): # TODO: pydantic def post_trade(self, trade): # TODO: pydantic
@ -72,7 +74,7 @@ class AlpacaExchange(BaseExchange):
elif trade.direction == "sell": elif trade.direction == "sell":
direction = OrderSide.SELL direction = OrderSide.SELL
else: else:
raise Exception("Unknown direction") raise ExchangeError("Unknown direction")
cast = { cast = {
"symbol": trade.symbol, "symbol": trade.symbol,
@ -84,7 +86,7 @@ class AlpacaExchange(BaseExchange):
if trade.amount_usd is not None: if trade.amount_usd is not None:
cast["notional"] = trade.amount_usd cast["notional"] = trade.amount_usd
if not trade.amount and not trade.amount_usd: if not trade.amount and not trade.amount_usd:
return (False, "No amount specified") raise ExchangeError("No amount specified")
if trade.take_profit: if trade.take_profit:
cast["take_profit"] = {"limit_price": trade.take_profit} cast["take_profit"] = {"limit_price": trade.take_profit}
if trade.stop_loss: if trade.stop_loss:
@ -101,10 +103,10 @@ class AlpacaExchange(BaseExchange):
self.log.error(f"Error placing market order: {e}") self.log.error(f"Error placing market order: {e}")
trade.status = "error" trade.status = "error"
trade.save() trade.save()
return (False, e) raise GenericAPIError(e)
elif trade.type == "limit": elif trade.type == "limit":
if not trade.price: if not trade.price:
return (False, "Limit order with no price") raise ExchangeError("No price specified for limit order")
cast["limit_price"] = trade.price cast["limit_price"] = trade.price
limit_order_data = LimitOrderRequest(**cast) limit_order_data = LimitOrderRequest(**cast)
try: try:
@ -113,16 +115,16 @@ class AlpacaExchange(BaseExchange):
self.log.error(f"Error placing limit order: {e}") self.log.error(f"Error placing limit order: {e}")
trade.status = "error" trade.status = "error"
trade.save() trade.save()
return (False, e) raise GenericAPIError(e)
else: else:
raise Exception("Unknown trade type") raise ExchangeError("Unknown trade type")
trade.response = order trade.response = order
trade.status = "posted" trade.status = "posted"
trade.order_id = order["id"] trade.order_id = order["id"]
trade.client_order_id = order["client_order_id"] trade.client_order_id = order["client_order_id"]
trade.save() trade.save()
return (True, order) return order
def get_trade(self, trade_id): def get_trade(self, trade_id):
pass pass
@ -134,20 +136,16 @@ class AlpacaExchange(BaseExchange):
pass pass
def get_position_info(self, symbol): def get_position_info(self, symbol):
success, position = self.call("get_open_position", symbol) position = self.call("get_open_position", symbol)
if not success: return position
return (success, position)
return (True, position)
def get_all_positions(self): def get_all_positions(self):
items = [] items = []
success, response = self.call("get_all_positions") response = self.call("get_all_positions")
if not success:
return (success, response)
for item in response["itemlist"]: for item in response["itemlist"]:
item["account"] = self.account.name item["account"] = self.account.name
item["account_id"] = self.account.id item["account_id"] = self.account.id
item["unrealized_pl"] = float(item["unrealized_pl"]) item["unrealized_pl"] = float(item["unrealized_pl"])
items.append(item) items.append(item)
return (True, items) return items

View File

@ -1,35 +1,16 @@
from oandapyV20 import API from oandapyV20 import API
from oandapyV20.endpoints import accounts, orders, positions, trades from oandapyV20.endpoints import accounts, positions
from pydantic import ValidationError
from core.exchanges import BaseExchange from core.exchanges import BaseExchange
from core.lib.schemas import oanda_s
OANDA_SCHEMA_MAPPING = {"OpenPositions": oanda_s.OpenPositions}
class OANDAExchange(BaseExchange): class OANDAExchange(BaseExchange):
def call(self, method, request): def call_method(self, request):
self.client.request(request) self.client.request(request)
response = request.response response = request.response
if isinstance(response, list): if isinstance(response, list):
response = {"itemlist": response} response = {"itemlist": response}
if method not in self.schema: return response
self.log.error(f"Method cannot be validated: {method}")
self.log.debug(f"Response: {response}")
return (False, f"Method cannot be validated: {method}")
try:
# Return a dict of the validated response
response_valid = self.schema[method](**response).dict()
# Convert the response to a format that we can use
response_converted = self.convert_spec(response_valid, method)
return (True, response_converted)
except ValidationError as e:
self.log.error(f"Could not validate response: {e}")
return (False, e)
def set_schema(self):
self.schema = OANDA_SCHEMA_MAPPING
def connect(self): def connect(self):
self.client = API(access_token=self.account.api_secret) self.client = API(access_token=self.account.api_secret)
@ -51,9 +32,9 @@ class OANDAExchange(BaseExchange):
def post_trade(self, trade): def post_trade(self, trade):
raise NotImplementedError raise NotImplementedError
r = orders.OrderCreate(accountID, data=data) # r = orders.OrderCreate(accountID, data=data)
self.client.request(r) # self.client.request(r)
return r.response # return r.response
def get_trade(self, trade_id): def get_trade(self, trade_id):
r = accounts.TradeDetails(accountID=self.account_id, tradeID=trade_id) r = accounts.TradeDetails(accountID=self.account_id, tradeID=trade_id)
@ -62,11 +43,11 @@ class OANDAExchange(BaseExchange):
def update_trade(self, trade): def update_trade(self, trade):
raise NotImplementedError raise NotImplementedError
r = orders.OrderReplace( # r = orders.OrderReplace(
accountID=self.account_id, orderID=trade.order_id, data=data # accountID=self.account_id, orderID=trade.order_id, data=data
) # )
self.client.request(r) # self.client.request(r)
return r.response # return r.response
def cancel_trade(self, trade_id): def cancel_trade(self, trade_id):
raise NotImplementedError raise NotImplementedError
@ -79,9 +60,7 @@ class OANDAExchange(BaseExchange):
def get_all_positions(self): def get_all_positions(self):
items = [] items = []
r = positions.OpenPositions(accountID=self.account_id) r = positions.OpenPositions(accountID=self.account_id)
success, response = self.call("OpenPositions", r) response = self.call(r)
if not success:
return (success, response)
print("Positions", response) print("Positions", response)
for item in response["itemlist"]: for item in response["itemlist"]:
@ -89,4 +68,4 @@ class OANDAExchange(BaseExchange):
item["account_id"] = self.account.id item["account_id"] = self.account.id
item["unrealized_pl"] = float(item["unrealized_pl"]) item["unrealized_pl"] = float(item["unrealized_pl"])
items.append(item) items.append(item)
return (True, items) return items

View File

@ -97,7 +97,7 @@ class GetAllPositions(BaseModel):
] ]
} }
get_all_positions_schema = { GetAllPositionsSchema = {
"itemlist": ( "itemlist": (
"itemlist", "itemlist",
[ [
@ -113,6 +113,7 @@ get_all_positions_schema = {
) )
} }
# get_account # get_account
class GetAccount(BaseModel): class GetAccount(BaseModel):
id: str id: str
@ -151,4 +152,4 @@ class GetAccount(BaseModel):
balance_asof: str balance_asof: str
get_account_schema = {"": ""} GetAccountSchema = {"": ""}

View File

@ -153,7 +153,7 @@ def parse_side(x):
return "unknown" return "unknown"
OpenPositions_schema = { OpenPositionsSchema = {
"itemlist": ( "itemlist": (
"positions", "positions",
[ [

View File

@ -5,7 +5,6 @@ from django.db import models
from core.exchanges.alpaca import AlpacaExchange from core.exchanges.alpaca import AlpacaExchange
from core.exchanges.oanda import OANDAExchange from core.exchanges.oanda import OANDAExchange
from core.lib import trades
from core.lib.customers import get_or_create, update_customer_fields from core.lib.customers import get_or_create, update_customer_fields
from core.util import logs from core.util import logs
@ -100,7 +99,7 @@ class Account(models.Model):
if self.exchange in EXCHANGE_MAP: if self.exchange in EXCHANGE_MAP:
return EXCHANGE_MAP[self.exchange](self) return EXCHANGE_MAP[self.exchange](self)
else: else:
raise Exception("Exchange not supported") raise Exception(f"Exchange not supported : {self.exchange}")
@property @property
def client(self): def client(self):

View File

@ -111,7 +111,10 @@ class HookList(LoginRequiredMixin, ObjectList):
title = "Hooks" title = "Hooks"
title_singular = "Hook" title_singular = "Hook"
page_title = "List of active URL endpoints for receiving hooks." page_title = "List of active URL endpoints for receiving hooks."
page_subtitle = "Add URLs here to receive Drakdoo callbacks. Make then unique!" page_subtitle = (
"Add URLs here to receive Drakdoo callbacks. "
"Make then unique and hard to guess!"
)
list_url_name = "hooks" list_url_name = "hooks"
list_url_args = ["type"] list_url_args = ["type"]

View File

@ -6,7 +6,6 @@ from django.shortcuts import render
from django.views import View from django.views import View
from rest_framework.parsers import FormParser from rest_framework.parsers import FormParser
from core.lib import trades
from core.models import Account from core.models import Account
from core.util import logs from core.util import logs
@ -17,10 +16,7 @@ def get_positions(user, account_id=None):
items = [] items = []
accounts = Account.objects.filter(user=user) accounts = Account.objects.filter(user=user)
for account in accounts: for account in accounts:
success, positions = account.client.get_all_positions() positions = account.client.get_all_positions()
if not success:
items.append({"name": account.name, "status": "error"})
continue
for item in positions: for item in positions:
items.append(item) items.append(item)
@ -71,11 +67,8 @@ class PositionAction(LoginRequiredMixin, View):
unique = str(uuid.uuid4())[:8] unique = str(uuid.uuid4())[:8]
account = Account.get_by_id(account_id, request.user) account = Account.get_by_id(account_id, request.user)
success, info = account.client.get_position_info(symbol) info = account.client.get_position_info(symbol)
print("ACCT INFO", info) print("ACCT INFO", info)
if not success:
message = "Position does not exist"
message_class = "danger"
items = get_positions(request.user, account_id) items = get_positions(request.user, account_id)
if type == "page": if type == "page":
@ -87,9 +80,6 @@ class PositionAction(LoginRequiredMixin, View):
"items": items, "items": items,
"type": type, "type": type,
} }
if success:
context["items"] = info context["items"] = info
else:
context["message"] = message
context["class"] = message_class
return render(request, template_name, context) return render(request, template_name, context)