Simplify schema and error handling
This commit is contained in:
parent
04a87c1da6
commit
b36791d56b
|
@ -8,10 +8,56 @@ from core.util import logs
|
|||
STRICT_VALIDATION = False
|
||||
|
||||
# Raise exception if the conversion schema is not found
|
||||
STRICT_CONVERSTION = False
|
||||
STRICT_CONVERSION = False
|
||||
|
||||
# TODO: Set them to True when all message types are implemented
|
||||
|
||||
|
||||
class NoSchema(Exception):
|
||||
"""
|
||||
Raised when:
|
||||
- The schema for the message type is not found
|
||||
- The conversion schema is not found
|
||||
- There is no schema library for the exchange
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class NoSuchMethod(Exception):
|
||||
"""
|
||||
Exchange library has no such method.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class GenericAPIError(Exception):
|
||||
"""
|
||||
Generic API error.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ExchangeError(Exception):
|
||||
"""
|
||||
Exchange error.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def is_camel_case(s):
|
||||
return s != s.lower() and s != s.upper() and "_" not in s
|
||||
|
||||
|
||||
def snake_to_camel(word):
|
||||
if is_camel_case(word):
|
||||
return word
|
||||
return "".join(x.capitalize() or "_" for x in word.split("_"))
|
||||
|
||||
|
||||
class BaseExchange(object):
|
||||
def __init__(self, account):
|
||||
name = self.__class__.__name__
|
||||
|
@ -20,7 +66,6 @@ class BaseExchange(object):
|
|||
self.log = logs.get_logger(self.name)
|
||||
self.client = None
|
||||
|
||||
self.set_schema()
|
||||
self.connect()
|
||||
|
||||
def set_schema(self):
|
||||
|
@ -29,53 +74,91 @@ class BaseExchange(object):
|
|||
def connect(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def convert_spec(self, response, msg_type):
|
||||
@property
|
||||
def schema(self):
|
||||
"""
|
||||
Get the schema library for the exchange.
|
||||
"""
|
||||
# Does the schemas library have a library for this exchange name?
|
||||
if hasattr(schemas, f"{self.name}_s"):
|
||||
schema_instance = getattr(schemas, f"{self.name}_s")
|
||||
else:
|
||||
raise Exception(f"No schema for {self.name} in schema mapping")
|
||||
# Does the message type have a conversion spec for this message type?
|
||||
if hasattr(schema_instance, f"{msg_type}_schema"):
|
||||
schema = getattr(schema_instance, f"{msg_type}_schema")
|
||||
raise NoSchema(f"No schema for {self.name} in schema mapping")
|
||||
|
||||
return schema_instance
|
||||
|
||||
def get_schema(self, method, convert=False):
|
||||
if isinstance(method, str):
|
||||
to_camel = snake_to_camel(method)
|
||||
else:
|
||||
# Let us know so we can implement it, but don't do anything with it
|
||||
self.log.error(f"No schema for message: {msg_type} - {response}")
|
||||
if STRICT_CONVERSION:
|
||||
raise Exception(f"No schema for {msg_type} in schema mapping")
|
||||
return response
|
||||
|
||||
# Use glom to convert the response to the schema
|
||||
converted = glom(response, schema)
|
||||
print(f"[{self.name}] Converted of {msg_type}: {converted}")
|
||||
return converted
|
||||
|
||||
def call(self, method, *args, **kwargs) -> (bool, dict):
|
||||
to_camel = snake_to_camel(method.__class__.__name__)
|
||||
if convert:
|
||||
to_camel += "Schema"
|
||||
# if hasattr(self.schema, method):
|
||||
# schema = getattr(self.schema, method)
|
||||
if hasattr(self.schema, to_camel):
|
||||
schema = getattr(self.schema, to_camel)
|
||||
else:
|
||||
raise NoSchema
|
||||
return schema
|
||||
|
||||
def call_method(self, method, *args, **kwargs):
|
||||
"""
|
||||
Get a method from the exchange library.
|
||||
"""
|
||||
if hasattr(self.client, method):
|
||||
try:
|
||||
response = getattr(self.client, method)(*args, **kwargs)
|
||||
if isinstance(response, list):
|
||||
response = {"itemlist": response}
|
||||
if method not in self.schema:
|
||||
return response
|
||||
else:
|
||||
raise NoSuchMethod
|
||||
|
||||
def convert_spec(self, response, method):
|
||||
"""
|
||||
Convert an API response to the requested spec.
|
||||
:raises NoSchema: If the conversion schema is not found
|
||||
"""
|
||||
schema = self.get_schema(method, convert=True)
|
||||
|
||||
# Use glom to convert the response to the schema
|
||||
converted = glom(response, schema)
|
||||
print(f"[{self.name}] Converted of {method}: {converted}")
|
||||
return converted
|
||||
|
||||
def call(self, method, *args, **kwargs):
|
||||
"""
|
||||
Call the exchange API and validate the response
|
||||
:raises NoSchema: If the method is not in the schema mapping
|
||||
:raises ValidationError: If the response cannot be validated
|
||||
"""
|
||||
try:
|
||||
response = self.call_method(method, *args, **kwargs)
|
||||
try:
|
||||
schema = self.get_schema(method)
|
||||
# Return a dict of the validated response
|
||||
response_valid = schema(**response).dict()
|
||||
except NoSchema:
|
||||
self.log.error(f"Method cannot be validated: {method}")
|
||||
self.log.debug(f"Response: {response}")
|
||||
if STRICT_VALIDATION:
|
||||
return (False, f"Method cannot be validated: {method}")
|
||||
return (True, response)
|
||||
# Return a dict of the validated response
|
||||
response_valid = self.schema[method](**response).dict()
|
||||
raise
|
||||
# Return the response as is
|
||||
response_valid = response
|
||||
|
||||
# Convert the response to a format that we can use
|
||||
response_converted = self.convert_spec(response_valid, method)
|
||||
return (True, response_converted)
|
||||
# return (True, response_converted)
|
||||
return response_converted
|
||||
except ValidationError as e:
|
||||
self.log.error(f"Could not validate response: {e}")
|
||||
return (False, e)
|
||||
raise
|
||||
except NoSuchMethod:
|
||||
self.log.error(f"Method not found: {method}")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.log.error(f"Error calling {method}: {e}")
|
||||
return (False, e)
|
||||
else:
|
||||
return (False, "No such method")
|
||||
self.log.error(f"Error calling method: {e}")
|
||||
raise GenericAPIError(e)
|
||||
|
||||
def get_account(self):
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import functools
|
||||
|
||||
from alpaca.common.exceptions import APIError
|
||||
from alpaca.trading.client import TradingClient
|
||||
from alpaca.trading.enums import OrderSide, TimeInForce
|
||||
|
@ -7,21 +9,23 @@ from alpaca.trading.requests import (
|
|||
MarketOrderRequest,
|
||||
)
|
||||
|
||||
from core.exchanges import BaseExchange
|
||||
from core.lib.schemas import alpaca_s
|
||||
from core.exchanges import BaseExchange, ExchangeError, GenericAPIError
|
||||
|
||||
ALPACA_SCHEMA_MAPPING = {
|
||||
"get_account": alpaca_s.GetAccount,
|
||||
"get_all_assets": alpaca_s.GetAllAssets,
|
||||
"get_all_positions": alpaca_s.GetAllPositions,
|
||||
"get_open_position": alpaca_s.GetOpenPosition,
|
||||
}
|
||||
|
||||
def handle_errors(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
return_value = func(self, *args, **kwargs)
|
||||
if isinstance(return_value, tuple):
|
||||
if return_value[0] is False:
|
||||
print("Error: ", return_value[1])
|
||||
return return_value
|
||||
return return_value[1]
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class AlpacaExchange(BaseExchange):
|
||||
def set_schema(self):
|
||||
self.schema = ALPACA_SCHEMA_MAPPING
|
||||
|
||||
def connect(self):
|
||||
self.client = TradingClient(
|
||||
self.account.api_key,
|
||||
|
@ -30,39 +34,37 @@ class AlpacaExchange(BaseExchange):
|
|||
raw_data=True,
|
||||
)
|
||||
|
||||
@handle_errors
|
||||
def get_account(self):
|
||||
return self.call("get_account")
|
||||
# return self.call("get_account")
|
||||
market = self.get_market_value("NONEXISTENT")
|
||||
print("MARTKET", market)
|
||||
|
||||
def get_supported_assets(self):
|
||||
request = GetAssetsRequest(status="active", asset_class="crypto")
|
||||
success, assets = self.call("get_all_assets", filter=request)
|
||||
# assets = self.client.get_all_assets(filter=request)
|
||||
if not success:
|
||||
return (success, assets)
|
||||
assets = self.call("get_all_assets", filter=request)
|
||||
assets = assets["itemlist"]
|
||||
asset_list = [x["symbol"] for x in assets if "symbol" in x]
|
||||
print("Supported symbols", asset_list)
|
||||
|
||||
return (True, asset_list)
|
||||
return asset_list
|
||||
|
||||
def get_balance(self):
|
||||
success, account_info = self.call("get_account")
|
||||
if not success:
|
||||
return (success, account_info)
|
||||
account_info = self.call("get_account")
|
||||
equity = account_info["equity"]
|
||||
try:
|
||||
balance = float(equity)
|
||||
except ValueError:
|
||||
return (False, "Invalid balance")
|
||||
raise GenericAPIError(f"Balance is not a float: {equity}")
|
||||
|
||||
return (True, balance)
|
||||
return balance
|
||||
|
||||
def get_market_value(self, symbol): # TODO: pydantic
|
||||
try:
|
||||
position = self.client.get_position(symbol)
|
||||
except APIError as e:
|
||||
self.log.error(f"Could not get market value for {symbol}: {e}")
|
||||
return False
|
||||
raise GenericAPIError(e)
|
||||
return float(position["market_value"])
|
||||
|
||||
def post_trade(self, trade): # TODO: pydantic
|
||||
|
@ -72,7 +74,7 @@ class AlpacaExchange(BaseExchange):
|
|||
elif trade.direction == "sell":
|
||||
direction = OrderSide.SELL
|
||||
else:
|
||||
raise Exception("Unknown direction")
|
||||
raise ExchangeError("Unknown direction")
|
||||
|
||||
cast = {
|
||||
"symbol": trade.symbol,
|
||||
|
@ -84,7 +86,7 @@ class AlpacaExchange(BaseExchange):
|
|||
if trade.amount_usd is not None:
|
||||
cast["notional"] = trade.amount_usd
|
||||
if not trade.amount and not trade.amount_usd:
|
||||
return (False, "No amount specified")
|
||||
raise ExchangeError("No amount specified")
|
||||
if trade.take_profit:
|
||||
cast["take_profit"] = {"limit_price": trade.take_profit}
|
||||
if trade.stop_loss:
|
||||
|
@ -101,10 +103,10 @@ class AlpacaExchange(BaseExchange):
|
|||
self.log.error(f"Error placing market order: {e}")
|
||||
trade.status = "error"
|
||||
trade.save()
|
||||
return (False, e)
|
||||
raise GenericAPIError(e)
|
||||
elif trade.type == "limit":
|
||||
if not trade.price:
|
||||
return (False, "Limit order with no price")
|
||||
raise ExchangeError("No price specified for limit order")
|
||||
cast["limit_price"] = trade.price
|
||||
limit_order_data = LimitOrderRequest(**cast)
|
||||
try:
|
||||
|
@ -113,16 +115,16 @@ class AlpacaExchange(BaseExchange):
|
|||
self.log.error(f"Error placing limit order: {e}")
|
||||
trade.status = "error"
|
||||
trade.save()
|
||||
return (False, e)
|
||||
raise GenericAPIError(e)
|
||||
|
||||
else:
|
||||
raise Exception("Unknown trade type")
|
||||
raise ExchangeError("Unknown trade type")
|
||||
trade.response = order
|
||||
trade.status = "posted"
|
||||
trade.order_id = order["id"]
|
||||
trade.client_order_id = order["client_order_id"]
|
||||
trade.save()
|
||||
return (True, order)
|
||||
return order
|
||||
|
||||
def get_trade(self, trade_id):
|
||||
pass
|
||||
|
@ -134,20 +136,16 @@ class AlpacaExchange(BaseExchange):
|
|||
pass
|
||||
|
||||
def get_position_info(self, symbol):
|
||||
success, position = self.call("get_open_position", symbol)
|
||||
if not success:
|
||||
return (success, position)
|
||||
return (True, position)
|
||||
position = self.call("get_open_position", symbol)
|
||||
return position
|
||||
|
||||
def get_all_positions(self):
|
||||
items = []
|
||||
success, response = self.call("get_all_positions")
|
||||
if not success:
|
||||
return (success, response)
|
||||
response = self.call("get_all_positions")
|
||||
|
||||
for item in response["itemlist"]:
|
||||
item["account"] = self.account.name
|
||||
item["account_id"] = self.account.id
|
||||
item["unrealized_pl"] = float(item["unrealized_pl"])
|
||||
items.append(item)
|
||||
return (True, items)
|
||||
return items
|
||||
|
|
|
@ -1,35 +1,16 @@
|
|||
from oandapyV20 import API
|
||||
from oandapyV20.endpoints import accounts, orders, positions, trades
|
||||
from pydantic import ValidationError
|
||||
from oandapyV20.endpoints import accounts, positions
|
||||
|
||||
from core.exchanges import BaseExchange
|
||||
from core.lib.schemas import oanda_s
|
||||
|
||||
OANDA_SCHEMA_MAPPING = {"OpenPositions": oanda_s.OpenPositions}
|
||||
|
||||
|
||||
class OANDAExchange(BaseExchange):
|
||||
def call(self, method, request):
|
||||
def call_method(self, request):
|
||||
self.client.request(request)
|
||||
response = request.response
|
||||
if isinstance(response, list):
|
||||
response = {"itemlist": response}
|
||||
if method not in self.schema:
|
||||
self.log.error(f"Method cannot be validated: {method}")
|
||||
self.log.debug(f"Response: {response}")
|
||||
return (False, f"Method cannot be validated: {method}")
|
||||
try:
|
||||
# Return a dict of the validated response
|
||||
response_valid = self.schema[method](**response).dict()
|
||||
# Convert the response to a format that we can use
|
||||
response_converted = self.convert_spec(response_valid, method)
|
||||
return (True, response_converted)
|
||||
except ValidationError as e:
|
||||
self.log.error(f"Could not validate response: {e}")
|
||||
return (False, e)
|
||||
|
||||
def set_schema(self):
|
||||
self.schema = OANDA_SCHEMA_MAPPING
|
||||
return response
|
||||
|
||||
def connect(self):
|
||||
self.client = API(access_token=self.account.api_secret)
|
||||
|
@ -51,9 +32,9 @@ class OANDAExchange(BaseExchange):
|
|||
|
||||
def post_trade(self, trade):
|
||||
raise NotImplementedError
|
||||
r = orders.OrderCreate(accountID, data=data)
|
||||
self.client.request(r)
|
||||
return r.response
|
||||
# r = orders.OrderCreate(accountID, data=data)
|
||||
# self.client.request(r)
|
||||
# return r.response
|
||||
|
||||
def get_trade(self, trade_id):
|
||||
r = accounts.TradeDetails(accountID=self.account_id, tradeID=trade_id)
|
||||
|
@ -62,11 +43,11 @@ class OANDAExchange(BaseExchange):
|
|||
|
||||
def update_trade(self, trade):
|
||||
raise NotImplementedError
|
||||
r = orders.OrderReplace(
|
||||
accountID=self.account_id, orderID=trade.order_id, data=data
|
||||
)
|
||||
self.client.request(r)
|
||||
return r.response
|
||||
# r = orders.OrderReplace(
|
||||
# accountID=self.account_id, orderID=trade.order_id, data=data
|
||||
# )
|
||||
# self.client.request(r)
|
||||
# return r.response
|
||||
|
||||
def cancel_trade(self, trade_id):
|
||||
raise NotImplementedError
|
||||
|
@ -79,9 +60,7 @@ class OANDAExchange(BaseExchange):
|
|||
def get_all_positions(self):
|
||||
items = []
|
||||
r = positions.OpenPositions(accountID=self.account_id)
|
||||
success, response = self.call("OpenPositions", r)
|
||||
if not success:
|
||||
return (success, response)
|
||||
response = self.call(r)
|
||||
|
||||
print("Positions", response)
|
||||
for item in response["itemlist"]:
|
||||
|
@ -89,4 +68,4 @@ class OANDAExchange(BaseExchange):
|
|||
item["account_id"] = self.account.id
|
||||
item["unrealized_pl"] = float(item["unrealized_pl"])
|
||||
items.append(item)
|
||||
return (True, items)
|
||||
return items
|
||||
|
|
|
@ -97,7 +97,7 @@ class GetAllPositions(BaseModel):
|
|||
]
|
||||
}
|
||||
|
||||
get_all_positions_schema = {
|
||||
GetAllPositionsSchema = {
|
||||
"itemlist": (
|
||||
"itemlist",
|
||||
[
|
||||
|
@ -113,6 +113,7 @@ get_all_positions_schema = {
|
|||
)
|
||||
}
|
||||
|
||||
|
||||
# get_account
|
||||
class GetAccount(BaseModel):
|
||||
id: str
|
||||
|
@ -151,4 +152,4 @@ class GetAccount(BaseModel):
|
|||
balance_asof: str
|
||||
|
||||
|
||||
get_account_schema = {"": ""}
|
||||
GetAccountSchema = {"": ""}
|
||||
|
|
|
@ -153,7 +153,7 @@ def parse_side(x):
|
|||
return "unknown"
|
||||
|
||||
|
||||
OpenPositions_schema = {
|
||||
OpenPositionsSchema = {
|
||||
"itemlist": (
|
||||
"positions",
|
||||
[
|
||||
|
|
|
@ -5,7 +5,6 @@ from django.db import models
|
|||
|
||||
from core.exchanges.alpaca import AlpacaExchange
|
||||
from core.exchanges.oanda import OANDAExchange
|
||||
from core.lib import trades
|
||||
from core.lib.customers import get_or_create, update_customer_fields
|
||||
from core.util import logs
|
||||
|
||||
|
@ -100,7 +99,7 @@ class Account(models.Model):
|
|||
if self.exchange in EXCHANGE_MAP:
|
||||
return EXCHANGE_MAP[self.exchange](self)
|
||||
else:
|
||||
raise Exception("Exchange not supported")
|
||||
raise Exception(f"Exchange not supported : {self.exchange}")
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
|
|
|
@ -111,7 +111,10 @@ class HookList(LoginRequiredMixin, ObjectList):
|
|||
title = "Hooks"
|
||||
title_singular = "Hook"
|
||||
page_title = "List of active URL endpoints for receiving hooks."
|
||||
page_subtitle = "Add URLs here to receive Drakdoo callbacks. Make then unique!"
|
||||
page_subtitle = (
|
||||
"Add URLs here to receive Drakdoo callbacks. "
|
||||
"Make then unique and hard to guess!"
|
||||
)
|
||||
|
||||
list_url_name = "hooks"
|
||||
list_url_args = ["type"]
|
||||
|
|
|
@ -6,7 +6,6 @@ from django.shortcuts import render
|
|||
from django.views import View
|
||||
from rest_framework.parsers import FormParser
|
||||
|
||||
from core.lib import trades
|
||||
from core.models import Account
|
||||
from core.util import logs
|
||||
|
||||
|
@ -17,10 +16,7 @@ def get_positions(user, account_id=None):
|
|||
items = []
|
||||
accounts = Account.objects.filter(user=user)
|
||||
for account in accounts:
|
||||
success, positions = account.client.get_all_positions()
|
||||
if not success:
|
||||
items.append({"name": account.name, "status": "error"})
|
||||
continue
|
||||
positions = account.client.get_all_positions()
|
||||
|
||||
for item in positions:
|
||||
items.append(item)
|
||||
|
@ -71,11 +67,8 @@ class PositionAction(LoginRequiredMixin, View):
|
|||
unique = str(uuid.uuid4())[:8]
|
||||
|
||||
account = Account.get_by_id(account_id, request.user)
|
||||
success, info = account.client.get_position_info(symbol)
|
||||
info = account.client.get_position_info(symbol)
|
||||
print("ACCT INFO", info)
|
||||
if not success:
|
||||
message = "Position does not exist"
|
||||
message_class = "danger"
|
||||
|
||||
items = get_positions(request.user, account_id)
|
||||
if type == "page":
|
||||
|
@ -87,9 +80,6 @@ class PositionAction(LoginRequiredMixin, View):
|
|||
"items": items,
|
||||
"type": type,
|
||||
}
|
||||
if success:
|
||||
context["items"] = info
|
||||
else:
|
||||
context["message"] = message
|
||||
context["class"] = message_class
|
||||
|
||||
return render(request, template_name, context)
|
||||
|
|
Loading…
Reference in New Issue