Simplify schema and error handling

Mark Veidemanis 2 years ago
parent 04a87c1da6
commit b36791d56b
Signed by: m
GPG Key ID: 5ACFCEED46C0904F

@ -8,10 +8,56 @@ from core.util import logs
STRICT_VALIDATION = False
# Raise exception if the conversion schema is not found
STRICT_CONVERSTION = False
STRICT_CONVERSION = False
# TODO: Set them to True when all message types are implemented
class NoSchema(Exception):
"""
Raised when:
- The schema for the message type is not found
- The conversion schema is not found
- There is no schema library for the exchange
"""
pass
class NoSuchMethod(Exception):
"""
Exchange library has no such method.
"""
pass
class GenericAPIError(Exception):
"""
Generic API error.
"""
pass
class ExchangeError(Exception):
"""
Exchange error.
"""
pass
def is_camel_case(s):
return s != s.lower() and s != s.upper() and "_" not in s
def snake_to_camel(word):
if is_camel_case(word):
return word
return "".join(x.capitalize() or "_" for x in word.split("_"))
class BaseExchange(object):
def __init__(self, account):
name = self.__class__.__name__
@ -20,7 +66,6 @@ class BaseExchange(object):
self.log = logs.get_logger(self.name)
self.client = None
self.set_schema()
self.connect()
def set_schema(self):
@ -29,53 +74,91 @@ class BaseExchange(object):
def connect(self):
raise NotImplementedError
def convert_spec(self, response, msg_type):
@property
def schema(self):
"""
Get the schema library for the exchange.
"""
# Does the schemas library have a library for this exchange name?
if hasattr(schemas, f"{self.name}_s"):
schema_instance = getattr(schemas, f"{self.name}_s")
else:
raise Exception(f"No schema for {self.name} in schema mapping")
# Does the message type have a conversion spec for this message type?
if hasattr(schema_instance, f"{msg_type}_schema"):
schema = getattr(schema_instance, f"{msg_type}_schema")
raise NoSchema(f"No schema for {self.name} in schema mapping")
return schema_instance
def get_schema(self, method, convert=False):
if isinstance(method, str):
to_camel = snake_to_camel(method)
else:
to_camel = snake_to_camel(method.__class__.__name__)
if convert:
to_camel += "Schema"
# if hasattr(self.schema, method):
# schema = getattr(self.schema, method)
if hasattr(self.schema, to_camel):
schema = getattr(self.schema, to_camel)
else:
# Let us know so we can implement it, but don't do anything with it
self.log.error(f"No schema for message: {msg_type} - {response}")
if STRICT_CONVERSION:
raise Exception(f"No schema for {msg_type} in schema mapping")
raise NoSchema
return schema
def call_method(self, method, *args, **kwargs):
"""
Get a method from the exchange library.
"""
if hasattr(self.client, method):
response = getattr(self.client, method)(*args, **kwargs)
if isinstance(response, list):
response = {"itemlist": response}
return response
else:
raise NoSuchMethod
def convert_spec(self, response, method):
"""
Convert an API response to the requested spec.
:raises NoSchema: If the conversion schema is not found
"""
schema = self.get_schema(method, convert=True)
# Use glom to convert the response to the schema
converted = glom(response, schema)
print(f"[{self.name}] Converted of {msg_type}: {converted}")
print(f"[{self.name}] Converted of {method}: {converted}")
return converted
def call(self, method, *args, **kwargs) -> (bool, dict):
if hasattr(self.client, method):
def call(self, method, *args, **kwargs):
"""
Call the exchange API and validate the response
:raises NoSchema: If the method is not in the schema mapping
:raises ValidationError: If the response cannot be validated
"""
try:
response = self.call_method(method, *args, **kwargs)
try:
response = getattr(self.client, method)(*args, **kwargs)
if isinstance(response, list):
response = {"itemlist": response}
if method not in self.schema:
self.log.error(f"Method cannot be validated: {method}")
self.log.debug(f"Response: {response}")
if STRICT_VALIDATION:
return (False, f"Method cannot be validated: {method}")
return (True, response)
schema = self.get_schema(method)
# Return a dict of the validated response
response_valid = self.schema[method](**response).dict()
# Convert the response to a format that we can use
response_converted = self.convert_spec(response_valid, method)
return (True, response_converted)
except ValidationError as e:
self.log.error(f"Could not validate response: {e}")
return (False, e)
except Exception as e:
self.log.error(f"Error calling {method}: {e}")
return (False, e)
else:
return (False, "No such method")
response_valid = schema(**response).dict()
except NoSchema:
self.log.error(f"Method cannot be validated: {method}")
self.log.debug(f"Response: {response}")
if STRICT_VALIDATION:
raise
# Return the response as is
response_valid = response
# Convert the response to a format that we can use
response_converted = self.convert_spec(response_valid, method)
# return (True, response_converted)
return response_converted
except ValidationError as e:
self.log.error(f"Could not validate response: {e}")
raise
except NoSuchMethod:
self.log.error(f"Method not found: {method}")
raise
except Exception as e:
self.log.error(f"Error calling method: {e}")
raise GenericAPIError(e)
def get_account(self):
raise NotImplementedError

@ -1,3 +1,5 @@
import functools
from alpaca.common.exceptions import APIError
from alpaca.trading.client import TradingClient
from alpaca.trading.enums import OrderSide, TimeInForce
@ -7,21 +9,23 @@ from alpaca.trading.requests import (
MarketOrderRequest,
)
from core.exchanges import BaseExchange
from core.lib.schemas import alpaca_s
from core.exchanges import BaseExchange, ExchangeError, GenericAPIError
ALPACA_SCHEMA_MAPPING = {
"get_account": alpaca_s.GetAccount,
"get_all_assets": alpaca_s.GetAllAssets,
"get_all_positions": alpaca_s.GetAllPositions,
"get_open_position": alpaca_s.GetOpenPosition,
}
def handle_errors(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
return_value = func(self, *args, **kwargs)
if isinstance(return_value, tuple):
if return_value[0] is False:
print("Error: ", return_value[1])
return return_value
return return_value[1]
class AlpacaExchange(BaseExchange):
def set_schema(self):
self.schema = ALPACA_SCHEMA_MAPPING
return wrapper
class AlpacaExchange(BaseExchange):
def connect(self):
self.client = TradingClient(
self.account.api_key,
@ -30,39 +34,37 @@ class AlpacaExchange(BaseExchange):
raw_data=True,
)
@handle_errors
def get_account(self):
return self.call("get_account")
# return self.call("get_account")
market = self.get_market_value("NONEXISTENT")
print("MARTKET", market)
def get_supported_assets(self):
request = GetAssetsRequest(status="active", asset_class="crypto")
success, assets = self.call("get_all_assets", filter=request)
# assets = self.client.get_all_assets(filter=request)
if not success:
return (success, assets)
assets = self.call("get_all_assets", filter=request)
assets = assets["itemlist"]
asset_list = [x["symbol"] for x in assets if "symbol" in x]
print("Supported symbols", asset_list)
return (True, asset_list)
return asset_list
def get_balance(self):
success, account_info = self.call("get_account")
if not success:
return (success, account_info)
account_info = self.call("get_account")
equity = account_info["equity"]
try:
balance = float(equity)
except ValueError:
return (False, "Invalid balance")
raise GenericAPIError(f"Balance is not a float: {equity}")
return (True, balance)
return balance
def get_market_value(self, symbol): # TODO: pydantic
try:
position = self.client.get_position(symbol)
except APIError as e:
self.log.error(f"Could not get market value for {symbol}: {e}")
return False
raise GenericAPIError(e)
return float(position["market_value"])
def post_trade(self, trade): # TODO: pydantic
@ -72,7 +74,7 @@ class AlpacaExchange(BaseExchange):
elif trade.direction == "sell":
direction = OrderSide.SELL
else:
raise Exception("Unknown direction")
raise ExchangeError("Unknown direction")
cast = {
"symbol": trade.symbol,
@ -84,7 +86,7 @@ class AlpacaExchange(BaseExchange):
if trade.amount_usd is not None:
cast["notional"] = trade.amount_usd
if not trade.amount and not trade.amount_usd:
return (False, "No amount specified")
raise ExchangeError("No amount specified")
if trade.take_profit:
cast["take_profit"] = {"limit_price": trade.take_profit}
if trade.stop_loss:
@ -101,10 +103,10 @@ class AlpacaExchange(BaseExchange):
self.log.error(f"Error placing market order: {e}")
trade.status = "error"
trade.save()
return (False, e)
raise GenericAPIError(e)
elif trade.type == "limit":
if not trade.price:
return (False, "Limit order with no price")
raise ExchangeError("No price specified for limit order")
cast["limit_price"] = trade.price
limit_order_data = LimitOrderRequest(**cast)
try:
@ -113,16 +115,16 @@ class AlpacaExchange(BaseExchange):
self.log.error(f"Error placing limit order: {e}")
trade.status = "error"
trade.save()
return (False, e)
raise GenericAPIError(e)
else:
raise Exception("Unknown trade type")
raise ExchangeError("Unknown trade type")
trade.response = order
trade.status = "posted"
trade.order_id = order["id"]
trade.client_order_id = order["client_order_id"]
trade.save()
return (True, order)
return order
def get_trade(self, trade_id):
pass
@ -134,20 +136,16 @@ class AlpacaExchange(BaseExchange):
pass
def get_position_info(self, symbol):
success, position = self.call("get_open_position", symbol)
if not success:
return (success, position)
return (True, position)
position = self.call("get_open_position", symbol)
return position
def get_all_positions(self):
items = []
success, response = self.call("get_all_positions")
if not success:
return (success, response)
response = self.call("get_all_positions")
for item in response["itemlist"]:
item["account"] = self.account.name
item["account_id"] = self.account.id
item["unrealized_pl"] = float(item["unrealized_pl"])
items.append(item)
return (True, items)
return items

@ -1,35 +1,16 @@
from oandapyV20 import API
from oandapyV20.endpoints import accounts, orders, positions, trades
from pydantic import ValidationError
from oandapyV20.endpoints import accounts, positions
from core.exchanges import BaseExchange
from core.lib.schemas import oanda_s
OANDA_SCHEMA_MAPPING = {"OpenPositions": oanda_s.OpenPositions}
class OANDAExchange(BaseExchange):
def call(self, method, request):
def call_method(self, request):
self.client.request(request)
response = request.response
if isinstance(response, list):
response = {"itemlist": response}
if method not in self.schema:
self.log.error(f"Method cannot be validated: {method}")
self.log.debug(f"Response: {response}")
return (False, f"Method cannot be validated: {method}")
try:
# Return a dict of the validated response
response_valid = self.schema[method](**response).dict()
# Convert the response to a format that we can use
response_converted = self.convert_spec(response_valid, method)
return (True, response_converted)
except ValidationError as e:
self.log.error(f"Could not validate response: {e}")
return (False, e)
def set_schema(self):
self.schema = OANDA_SCHEMA_MAPPING
return response
def connect(self):
self.client = API(access_token=self.account.api_secret)
@ -51,9 +32,9 @@ class OANDAExchange(BaseExchange):
def post_trade(self, trade):
raise NotImplementedError
r = orders.OrderCreate(accountID, data=data)
self.client.request(r)
return r.response
# r = orders.OrderCreate(accountID, data=data)
# self.client.request(r)
# return r.response
def get_trade(self, trade_id):
r = accounts.TradeDetails(accountID=self.account_id, tradeID=trade_id)
@ -62,11 +43,11 @@ class OANDAExchange(BaseExchange):
def update_trade(self, trade):
raise NotImplementedError
r = orders.OrderReplace(
accountID=self.account_id, orderID=trade.order_id, data=data
)
self.client.request(r)
return r.response
# r = orders.OrderReplace(
# accountID=self.account_id, orderID=trade.order_id, data=data
# )
# self.client.request(r)
# return r.response
def cancel_trade(self, trade_id):
raise NotImplementedError
@ -79,9 +60,7 @@ class OANDAExchange(BaseExchange):
def get_all_positions(self):
items = []
r = positions.OpenPositions(accountID=self.account_id)
success, response = self.call("OpenPositions", r)
if not success:
return (success, response)
response = self.call(r)
print("Positions", response)
for item in response["itemlist"]:
@ -89,4 +68,4 @@ class OANDAExchange(BaseExchange):
item["account_id"] = self.account.id
item["unrealized_pl"] = float(item["unrealized_pl"])
items.append(item)
return (True, items)
return items

@ -97,7 +97,7 @@ class GetAllPositions(BaseModel):
]
}
get_all_positions_schema = {
GetAllPositionsSchema = {
"itemlist": (
"itemlist",
[
@ -113,6 +113,7 @@ get_all_positions_schema = {
)
}
# get_account
class GetAccount(BaseModel):
id: str
@ -151,4 +152,4 @@ class GetAccount(BaseModel):
balance_asof: str
get_account_schema = {"": ""}
GetAccountSchema = {"": ""}

@ -153,7 +153,7 @@ def parse_side(x):
return "unknown"
OpenPositions_schema = {
OpenPositionsSchema = {
"itemlist": (
"positions",
[

@ -5,7 +5,6 @@ from django.db import models
from core.exchanges.alpaca import AlpacaExchange
from core.exchanges.oanda import OANDAExchange
from core.lib import trades
from core.lib.customers import get_or_create, update_customer_fields
from core.util import logs
@ -100,7 +99,7 @@ class Account(models.Model):
if self.exchange in EXCHANGE_MAP:
return EXCHANGE_MAP[self.exchange](self)
else:
raise Exception("Exchange not supported")
raise Exception(f"Exchange not supported : {self.exchange}")
@property
def client(self):

@ -111,7 +111,10 @@ class HookList(LoginRequiredMixin, ObjectList):
title = "Hooks"
title_singular = "Hook"
page_title = "List of active URL endpoints for receiving hooks."
page_subtitle = "Add URLs here to receive Drakdoo callbacks. Make then unique!"
page_subtitle = (
"Add URLs here to receive Drakdoo callbacks. "
"Make then unique and hard to guess!"
)
list_url_name = "hooks"
list_url_args = ["type"]

@ -6,7 +6,6 @@ from django.shortcuts import render
from django.views import View
from rest_framework.parsers import FormParser
from core.lib import trades
from core.models import Account
from core.util import logs
@ -17,10 +16,7 @@ def get_positions(user, account_id=None):
items = []
accounts = Account.objects.filter(user=user)
for account in accounts:
success, positions = account.client.get_all_positions()
if not success:
items.append({"name": account.name, "status": "error"})
continue
positions = account.client.get_all_positions()
for item in positions:
items.append(item)
@ -71,11 +67,8 @@ class PositionAction(LoginRequiredMixin, View):
unique = str(uuid.uuid4())[:8]
account = Account.get_by_id(account_id, request.user)
success, info = account.client.get_position_info(symbol)
info = account.client.get_position_info(symbol)
print("ACCT INFO", info)
if not success:
message = "Position does not exist"
message_class = "danger"
items = get_positions(request.user, account_id)
if type == "page":
@ -87,9 +80,6 @@ class PositionAction(LoginRequiredMixin, View):
"items": items,
"type": type,
}
if success:
context["items"] = info
else:
context["message"] = message
context["class"] = message_class
context["items"] = info
return render(request, template_name, context)

Loading…
Cancel
Save