from glom import glom
from pydantic import ValidationError

from core.lib import schemas
from core.util import logs

# Return error if the schema for the message type is not found
STRICT_VALIDATION = False

# Raise exception if the conversion schema is not found
STRICT_CONVERSION = False

# TODO: Set them to True when all message types are implemented


class NoSchema(Exception):
    """
    Raised when:
    - The schema for the message type is not found
    - The conversion schema is not found
    - There is no schema library for the exchange
    """

    pass


class NoSuchMethod(Exception):
    """
    Exchange library has no such method.
    """

    pass


class GenericAPIError(Exception):
    """
    Generic API error.
    """

    pass


class ExchangeError(Exception):
    """
    Exchange error.
    """

    pass


def is_camel_case(s):
    """
    Return True if ``s`` looks like a CamelCase identifier.

    A string counts as camel case when it mixes upper- and lower-case
    characters and contains no underscores.
    """
    return s != s.lower() and s != s.upper() and "_" not in s


def snake_to_camel(word):
    """
    Convert a snake_case name to CamelCase (``get_balance`` -> ``GetBalance``).

    Names that are already camel case are returned unchanged. Empty segments
    (from leading, trailing or doubled underscores) are kept as ``_``.
    """
    if is_camel_case(word):
        return word
    return "".join(x.capitalize() or "_" for x in word.split("_"))


class BaseExchange(object):
    """
    Base class for exchange integrations.

    Subclasses implement ``connect`` (presumably responsible for setting
    ``self.client``) and the exchange API methods at the bottom of the class.
    Responses obtained through ``call`` are validated against the exchange's
    schema library and then converted with a glom spec.
    """

    def __init__(self, account):
        name = self.__class__.__name__
        # e.g. "FooExchange" -> "foo"; used for the logger and schema lookup
        self.name = name.replace("Exchange", "").lower()
        self.account = account
        self.log = logs.get_logger(self.name)
        # The client library instance that call_method() dispatches to
        self.client = None
        self.connect()

    def set_schema(self):
        raise NotImplementedError

    def connect(self):
        raise NotImplementedError

    @property
    def schema(self):
        """
        Get the schema library for the exchange (``schemas.<name>_s``).

        :raises NoSchema: If the schemas package has no library for this exchange
        """
        # Does the schemas library have a library for this exchange name?
        library_name = f"{self.name}_s"
        if not hasattr(schemas, library_name):
            raise NoSchema(f"No schema for {self.name} in schema mapping")
        return getattr(schemas, library_name)

    def get_schema(self, method, convert=False):
        """
        Look up the schema for an API method in the exchange's schema library.

        :param method: Method name (snake_case or CamelCase), or an object
            whose class name is used instead
        :param convert: If True, look up the conversion schema
            (``<CamelName>Schema``) instead of the validation schema
        :returns: The schema attribute found in the library
        :raises NoSchema: If the exchange has no schema library, or the
            library has no matching schema
        """
        if isinstance(method, str):
            to_camel = snake_to_camel(method)
        else:
            # NOTE(review): this uses the object's *class* name, not
            # ``method.__name__``; confirm this is intended for callables.
            to_camel = snake_to_camel(method.__class__.__name__)
        if convert:
            to_camel += "Schema"

        library = self.schema
        if not hasattr(library, to_camel):
            raise NoSchema(f"No schema {to_camel} for {self.name}")
        return getattr(library, to_camel)

    def call_method(self, method, *args, **kwargs):
        """
        Call ``method`` on the exchange client library.

        List responses are wrapped as ``{"itemlist": [...]}``. This
        presumably lets every response be unpacked into a schema with
        ``schema(**response)``.

        :raises NoSuchMethod: If the client has no attribute named ``method``
        """
        if not hasattr(self.client, method):
            raise NoSuchMethod(f"{self.name} client has no method {method}")
        response = getattr(self.client, method)(*args, **kwargs)
        if isinstance(response, list):
            response = {"itemlist": response}
        return response

    def convert_spec(self, response, method):
        """
        Convert an API response to the requested spec.

        The conversion schema (``<CamelName>Schema``) is used as a glom spec.

        :raises NoSchema: If the conversion schema is not found
        """
        schema = self.get_schema(method, convert=True)
        # Use glom to convert the response to the schema
        converted = glom(response, schema)
        self.log.debug(f"Converted of {method}: {converted}")
        return converted

    def call(self, method, *args, **kwargs):
        """
        Call the exchange API, validate the response and convert it.

        A missing validation or conversion schema is logged and skipped,
        unless the matching STRICT_VALIDATION / STRICT_CONVERSION flag is set.

        :raises NoSchema: If a schema is missing and its STRICT_* flag is set
        :raises ValidationError: If the response cannot be validated
        :raises NoSuchMethod: If the client library has no such method
        :raises GenericAPIError: For any other error (original chained as cause)
        """
        try:
            response = self.call_method(method, *args, **kwargs)

            try:
                schema = self.get_schema(method)
            except NoSchema:
                self.log.error(f"Method cannot be validated: {method}")
                self.log.debug(f"Response: {response}")
                if STRICT_VALIDATION:
                    raise
                # Return the response as is
                response_valid = response
            else:
                # Return a dict of the validated response
                response_valid = schema(**response).dict()

            # Convert the response to a format that we can use
            try:
                return self.convert_spec(response_valid, method)
            except NoSchema:
                self.log.error(f"Method cannot be converted: {method}")
                if STRICT_CONVERSION:
                    raise
                # Fall back to the (validated, if possible) response
                return response_valid

        except NoSchema:
            # Only reached when a STRICT_* flag is set: propagate as documented
            # instead of hiding it inside GenericAPIError.
            raise
        except ValidationError as e:
            self.log.error(f"Could not validate response: {e}")
            raise
        except NoSuchMethod:
            self.log.error(f"Method not found: {method}")
            raise
        except Exception as e:
            self.log.error(f"Error calling method: {e}")
            raise GenericAPIError(e) from e

    # Exchange API methods; each subclass must implement these.

    def get_account(self):
        raise NotImplementedError

    def get_supported_assets(self):
        raise NotImplementedError

    def get_balance(self):
        raise NotImplementedError

    def get_market_value(self, symbol):
        raise NotImplementedError

    def post_trade(self, trade):
        raise NotImplementedError

    def get_trade(self, trade_id):
        raise NotImplementedError

    def update_trade(self, trade):
        raise NotImplementedError

    def cancel_trade(self, trade_id):
        raise NotImplementedError

    def get_position_info(self, symbol):
        raise NotImplementedError

    def get_all_positions(self):
        raise NotImplementedError