Implement more validation and conversion
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
from glom import glom
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.lib import schemas
|
||||
from core.util import logs
|
||||
@@ -12,6 +11,8 @@ STRICT_CONVERSION = False
|
||||
|
||||
# TODO: Set them to True when all message types are implemented
|
||||
|
||||
log = logs.get_logger("exchanges")
|
||||
|
||||
|
||||
class NoSchema(Exception):
|
||||
"""
|
||||
@@ -63,14 +64,10 @@ class BaseExchange(object):
|
||||
name = self.__class__.__name__
|
||||
self.name = name.replace("Exchange", "").lower()
|
||||
self.account = account
|
||||
self.log = logs.get_logger(self.name)
|
||||
self.client = None
|
||||
|
||||
self.connect()
|
||||
|
||||
def set_schema(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def connect(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -83,7 +80,8 @@ class BaseExchange(object):
|
||||
if hasattr(schemas, f"{self.name}_s"):
|
||||
schema_instance = getattr(schemas, f"{self.name}_s")
|
||||
else:
|
||||
raise NoSchema(f"No schema for {self.name} in schema mapping")
|
||||
log.error(f"No schema library for {self.name}")
|
||||
raise Exception(f"No schema library for exchange {self.name}")
|
||||
|
||||
return schema_instance
|
||||
|
||||
@@ -93,14 +91,13 @@ class BaseExchange(object):
|
||||
else:
|
||||
to_camel = snake_to_camel(method.__class__.__name__)
|
||||
if convert:
|
||||
to_camel += "Schema"
|
||||
to_camel = f"{to_camel}Schema"
|
||||
# if hasattr(self.schema, method):
|
||||
# schema = getattr(self.schema, method)
|
||||
if hasattr(self.schema, to_camel):
|
||||
schema = getattr(self.schema, to_camel)
|
||||
else:
|
||||
self.log.error(f"Method cannot be validated: {to_camel}")
|
||||
raise NoSchema(f"Method cannot be validated: {to_camel}")
|
||||
raise NoSchema(f"Could not get schema: {to_camel}")
|
||||
return schema
|
||||
|
||||
def call_method(self, method, *args, **kwargs):
|
||||
@@ -124,40 +121,37 @@ class BaseExchange(object):
|
||||
|
||||
# Use glom to convert the response to the schema
|
||||
converted = glom(response, schema)
|
||||
print(f"[{self.name}] Converted of {method}: {converted}")
|
||||
return converted
|
||||
|
||||
def validate_response(self, response, method):
|
||||
schema = self.get_schema(method)
|
||||
# Return a dict of the validated response
|
||||
response_valid = schema(**response).dict()
|
||||
return response_valid
|
||||
|
||||
def call(self, method, *args, **kwargs):
|
||||
"""
|
||||
Call the exchange API and validate the response
|
||||
:raises NoSchema: If the method is not in the schema mapping
|
||||
:raises ValidationError: If the response cannot be validated
|
||||
"""
|
||||
response = self.call_method(method, *args, **kwargs)
|
||||
try:
|
||||
response_valid = self.validate_response(response, method)
|
||||
except NoSchema as e:
|
||||
log.error(f"{e} - {response}")
|
||||
response_valid = response
|
||||
# Convert the response to a format that we can use
|
||||
try:
|
||||
response = self.call_method(method, *args, **kwargs)
|
||||
try:
|
||||
schema = self.get_schema(method)
|
||||
# Return a dict of the validated response
|
||||
response_valid = schema(**response).dict()
|
||||
except NoSchema:
|
||||
self.log.debug(f"No schema: {response}")
|
||||
if STRICT_VALIDATION:
|
||||
raise
|
||||
# Return the response as is
|
||||
response_valid = response
|
||||
|
||||
# Convert the response to a format that we can use
|
||||
response_converted = self.convert_spec(response_valid, method)
|
||||
# return (True, response_converted)
|
||||
return response_converted
|
||||
except ValidationError as e:
|
||||
self.log.error(f"Could not validate response: {e}")
|
||||
raise
|
||||
except NoSuchMethod:
|
||||
self.log.error(f"Method not found: {method}")
|
||||
raise
|
||||
except NoSchema as e:
|
||||
log.error(f"{e} - {response}")
|
||||
response_converted = response_valid
|
||||
# return (True, response_converted)
|
||||
return response_converted
|
||||
|
||||
# except Exception as e:
|
||||
# self.log.error(f"Error calling method: {e}")
|
||||
# log.error(f"Error calling method: {e}")
|
||||
# raise GenericAPIError(e)
|
||||
|
||||
def get_account(self):
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import functools
|
||||
|
||||
from alpaca.common.exceptions import APIError
|
||||
from alpaca.trading.client import TradingClient
|
||||
from alpaca.trading.enums import OrderSide, TimeInForce
|
||||
@@ -12,19 +10,6 @@ from alpaca.trading.requests import (
|
||||
from core.exchanges import BaseExchange, ExchangeError, GenericAPIError
|
||||
|
||||
|
||||
def handle_errors(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
return_value = func(self, *args, **kwargs)
|
||||
if isinstance(return_value, tuple):
|
||||
if return_value[0] is False:
|
||||
print("Error: ", return_value[1])
|
||||
return return_value
|
||||
return return_value[1]
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class AlpacaExchange(BaseExchange):
|
||||
def connect(self):
|
||||
self.client = TradingClient(
|
||||
@@ -34,18 +19,14 @@ class AlpacaExchange(BaseExchange):
|
||||
raw_data=True,
|
||||
)
|
||||
|
||||
@handle_errors
|
||||
def get_account(self):
|
||||
# return self.call("get_account")
|
||||
market = self.get_market_value("NONEXISTENT")
|
||||
print("MARTKET", market)
|
||||
return self.call("get_account")
|
||||
|
||||
def get_supported_assets(self):
|
||||
request = GetAssetsRequest(status="active", asset_class="crypto")
|
||||
assets = self.call("get_all_assets", filter=request)
|
||||
assets = assets["itemlist"]
|
||||
asset_list = [x["symbol"] for x in assets if "symbol" in x]
|
||||
print("Supported symbols", asset_list)
|
||||
|
||||
return asset_list
|
||||
|
||||
|
||||
@@ -21,7 +21,9 @@ class OANDAExchange(BaseExchange):
|
||||
return self.call(r)
|
||||
|
||||
def get_supported_assets(self):
|
||||
return False
|
||||
r = accounts.AccountInstruments(accountID=self.account_id)
|
||||
response = self.call(r)
|
||||
return [x["name"] for x in response["itemlist"]]
|
||||
|
||||
def get_balance(self):
|
||||
raise NotImplementedError
|
||||
@@ -59,7 +61,6 @@ class OANDAExchange(BaseExchange):
|
||||
r = positions.OpenPositions(accountID=self.account_id)
|
||||
response = self.call(r)
|
||||
|
||||
print("Positions", response)
|
||||
for item in response["itemlist"]:
|
||||
item["account"] = self.account.name
|
||||
item["account_id"] = self.account.id
|
||||
|
||||
Reference in New Issue
Block a user