Implement more validation and conversion

This commit is contained in:
Mark Veidemanis 2022-11-04 07:20:14 +00:00
parent d34ac39d68
commit 60979652d9
Signed by: m
GPG Key ID: 5ACFCEED46C0904F
8 changed files with 294 additions and 158 deletions

View File

@ -1,5 +1,4 @@
from glom import glom from glom import glom
from pydantic import ValidationError
from core.lib import schemas from core.lib import schemas
from core.util import logs from core.util import logs
@ -12,6 +11,8 @@ STRICT_CONVERSION = False
# TODO: Set them to True when all message types are implemented # TODO: Set them to True when all message types are implemented
log = logs.get_logger("exchanges")
class NoSchema(Exception): class NoSchema(Exception):
""" """
@ -63,14 +64,10 @@ class BaseExchange(object):
name = self.__class__.__name__ name = self.__class__.__name__
self.name = name.replace("Exchange", "").lower() self.name = name.replace("Exchange", "").lower()
self.account = account self.account = account
self.log = logs.get_logger(self.name)
self.client = None self.client = None
self.connect() self.connect()
def set_schema(self):
raise NotImplementedError
def connect(self): def connect(self):
raise NotImplementedError raise NotImplementedError
@ -83,7 +80,8 @@ class BaseExchange(object):
if hasattr(schemas, f"{self.name}_s"): if hasattr(schemas, f"{self.name}_s"):
schema_instance = getattr(schemas, f"{self.name}_s") schema_instance = getattr(schemas, f"{self.name}_s")
else: else:
raise NoSchema(f"No schema for {self.name} in schema mapping") log.error(f"No schema library for {self.name}")
raise Exception(f"No schema library for exchange {self.name}")
return schema_instance return schema_instance
@ -93,14 +91,13 @@ class BaseExchange(object):
else: else:
to_camel = snake_to_camel(method.__class__.__name__) to_camel = snake_to_camel(method.__class__.__name__)
if convert: if convert:
to_camel += "Schema" to_camel = f"{to_camel}Schema"
# if hasattr(self.schema, method): # if hasattr(self.schema, method):
# schema = getattr(self.schema, method) # schema = getattr(self.schema, method)
if hasattr(self.schema, to_camel): if hasattr(self.schema, to_camel):
schema = getattr(self.schema, to_camel) schema = getattr(self.schema, to_camel)
else: else:
self.log.error(f"Method cannot be validated: {to_camel}") raise NoSchema(f"Could not get schema: {to_camel}")
raise NoSchema(f"Method cannot be validated: {to_camel}")
return schema return schema
def call_method(self, method, *args, **kwargs): def call_method(self, method, *args, **kwargs):
@ -124,40 +121,37 @@ class BaseExchange(object):
# Use glom to convert the response to the schema # Use glom to convert the response to the schema
converted = glom(response, schema) converted = glom(response, schema)
print(f"[{self.name}] Converted of {method}: {converted}")
return converted return converted
def validate_response(self, response, method):
schema = self.get_schema(method)
# Return a dict of the validated response
response_valid = schema(**response).dict()
return response_valid
def call(self, method, *args, **kwargs): def call(self, method, *args, **kwargs):
""" """
Call the exchange API and validate the response Call the exchange API and validate the response
:raises NoSchema: If the method is not in the schema mapping :raises NoSchema: If the method is not in the schema mapping
:raises ValidationError: If the response cannot be validated :raises ValidationError: If the response cannot be validated
""" """
response = self.call_method(method, *args, **kwargs)
try:
response_valid = self.validate_response(response, method)
except NoSchema as e:
log.error(f"{e} - {response}")
response_valid = response
# Convert the response to a format that we can use
try: try:
response = self.call_method(method, *args, **kwargs)
try:
schema = self.get_schema(method)
# Return a dict of the validated response
response_valid = schema(**response).dict()
except NoSchema:
self.log.debug(f"No schema: {response}")
if STRICT_VALIDATION:
raise
# Return the response as is
response_valid = response
# Convert the response to a format that we can use
response_converted = self.convert_spec(response_valid, method) response_converted = self.convert_spec(response_valid, method)
# return (True, response_converted) except NoSchema as e:
return response_converted log.error(f"{e} - {response}")
except ValidationError as e: response_converted = response_valid
self.log.error(f"Could not validate response: {e}") # return (True, response_converted)
raise return response_converted
except NoSuchMethod:
self.log.error(f"Method not found: {method}")
raise
# except Exception as e: # except Exception as e:
# self.log.error(f"Error calling method: {e}") # log.error(f"Error calling method: {e}")
# raise GenericAPIError(e) # raise GenericAPIError(e)
def get_account(self): def get_account(self):

View File

@ -1,5 +1,3 @@
import functools
from alpaca.common.exceptions import APIError from alpaca.common.exceptions import APIError
from alpaca.trading.client import TradingClient from alpaca.trading.client import TradingClient
from alpaca.trading.enums import OrderSide, TimeInForce from alpaca.trading.enums import OrderSide, TimeInForce
@ -12,19 +10,6 @@ from alpaca.trading.requests import (
from core.exchanges import BaseExchange, ExchangeError, GenericAPIError from core.exchanges import BaseExchange, ExchangeError, GenericAPIError
def handle_errors(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
return_value = func(self, *args, **kwargs)
if isinstance(return_value, tuple):
if return_value[0] is False:
print("Error: ", return_value[1])
return return_value
return return_value[1]
return wrapper
class AlpacaExchange(BaseExchange): class AlpacaExchange(BaseExchange):
def connect(self): def connect(self):
self.client = TradingClient( self.client = TradingClient(
@ -34,18 +19,14 @@ class AlpacaExchange(BaseExchange):
raw_data=True, raw_data=True,
) )
@handle_errors
def get_account(self): def get_account(self):
# return self.call("get_account") return self.call("get_account")
market = self.get_market_value("NONEXISTENT")
print("MARTKET", market)
def get_supported_assets(self): def get_supported_assets(self):
request = GetAssetsRequest(status="active", asset_class="crypto") request = GetAssetsRequest(status="active", asset_class="crypto")
assets = self.call("get_all_assets", filter=request) assets = self.call("get_all_assets", filter=request)
assets = assets["itemlist"] assets = assets["itemlist"]
asset_list = [x["symbol"] for x in assets if "symbol" in x] asset_list = [x["symbol"] for x in assets if "symbol" in x]
print("Supported symbols", asset_list)
return asset_list return asset_list

View File

@ -21,7 +21,9 @@ class OANDAExchange(BaseExchange):
return self.call(r) return self.call(r)
def get_supported_assets(self): def get_supported_assets(self):
return False r = accounts.AccountInstruments(accountID=self.account_id)
response = self.call(r)
return [x["name"] for x in response["itemlist"]]
def get_balance(self): def get_balance(self):
raise NotImplementedError raise NotImplementedError
@ -59,7 +61,6 @@ class OANDAExchange(BaseExchange):
r = positions.OpenPositions(accountID=self.account_id) r = positions.OpenPositions(accountID=self.account_id)
response = self.call(r) response = self.call(r)
print("Positions", response)
for item in response["itemlist"]: for item in response["itemlist"]:
item["account"] = self.account.name item["account"] = self.account.name
item["account_id"] = self.account.id item["account_id"] = self.account.id

View File

@ -24,6 +24,32 @@ class GetAllAssets(BaseModel):
itemlist: list[Asset] itemlist: list[Asset]
GetAllAssetsSchema = {
"itemlist": (
"itemlist",
[
{
"id": "id",
"class": "class",
"exchange": "exchange",
"symbol": "symbol",
"name": "name",
"status": "status",
"tradable": "tradable",
"marginable": "marginable",
"maintenance_margin_requirement": "maintenanceMarginRequirement",
"shortable": "shortable",
"easy_to_borrow": "easyToBorrow",
"fractionable": "fractionable",
"min_order_size": "minOrderSize",
"min_trade_increment": "minTradeIncrement",
"price_increment": "priceIncrement",
}
],
)
}
# get_open_position # get_open_position
class GetOpenPosition(BaseModel): class GetOpenPosition(BaseModel):
asset_id: str asset_id: str
@ -72,31 +98,6 @@ class GetAllPositions(BaseModel):
itemlist: list[Position] itemlist: list[Position]
{
"itemlist": [
{
"asset_id": "64bbff51-59d6-4b3c-9351-13ad85e3c752",
"symbol": "BTCUSD",
"exchange": "FTXU",
"asset_class": "crypto",
"asset_marginable": False,
"qty": "0.009975",
"avg_entry_price": "20714",
"side": "long",
"market_value": "204.297975",
"cost_basis": "206.62215",
"unrealized_pl": "-2.324175",
"unrealized_plpc": "-0.0112484310128416",
"unrealized_intraday_pl": "-0.269325",
"unrealized_intraday_plpc": "-0.001316559391457",
"current_price": "20481",
"lastday_price": "20508",
"change_today": "-0.001316559391457",
"qty_available": "0.009975",
}
]
}
GetAllPositionsSchema = { GetAllPositionsSchema = {
"itemlist": ( "itemlist": (
"itemlist", "itemlist",
@ -152,4 +153,39 @@ class GetAccount(BaseModel):
balance_asof: str balance_asof: str
GetAccountSchema = {"": ""} GetAccountSchema = {
"id": "id",
"account_number": "account_number",
"status": "status",
"crypto_status": "crypto_status",
"currency": "currency",
"buying_power": "buying_power",
"regt_buying_power": "regt_buying_power",
"daytrading_buying_power": "daytrading_buying_power",
"effective_buying_power": "effective_buying_power",
"non_marginable_buying_power": "non_marginable_buying_power",
"bod_dtbp": "bod_dtbp",
"cash": "cash",
"accrued_fees": "accrued_fees",
"pending_transfer_in": "pending_transfer_in",
"portfolio_value": "portfolio_value",
"pattern_day_trader": "pattern_day_trader",
"trading_blocked": "trading_blocked",
"transfers_blocked": "transfers_blocked",
"account_blocked": "account_blocked",
"created_at": "created_at",
"trade_suspended_by_user": "trade_suspended_by_user",
"multiplier": "multiplier",
"shorting_enabled": "shorting_enabled",
"equity": "equity",
"last_equity": "last_equity",
"long_market_value": "long_market_value",
"short_market_value": "short_market_value",
"position_market_value": "position_market_value",
"initial_margin": "initial_margin",
"maintenance_margin": "maintenance_margin",
"last_maintenance_margin": "last_maintenance_margin",
"sma": "sma",
"daytrade_count": "daytrade_count",
"balance_asof": "balance_asof",
}

View File

@ -1,42 +1,5 @@
from pydantic import BaseModel from pydantic import BaseModel
a = {
"positions": [
{
"instrument": "EUR_USD",
"long": {
"units": "1",
"averagePrice": "0.99361",
"pl": "-0.1014",
"resettablePL": "-0.1014",
"financing": "0.0000",
"dividendAdjustment": "0.0000",
"guaranteedExecutionFees": "0.0000",
"tradeIDs": ["71"],
"unrealizedPL": "-0.0002",
},
"short": {
"units": "0",
"pl": "0.0932",
"resettablePL": "0.0932",
"financing": "0.0000",
"dividendAdjustment": "0.0000",
"guaranteedExecutionFees": "0.0000",
"unrealizedPL": "0.0000",
},
"pl": "-0.0082",
"resettablePL": "-0.0082",
"financing": "0.0000",
"commission": "0.0000",
"dividendAdjustment": "0.0000",
"guaranteedExecutionFees": "0.0000",
"unrealizedPL": "-0.0002",
"marginUsed": "0.0286",
}
],
"lastTransactionID": "71",
}
class PositionLong(BaseModel): class PositionLong(BaseModel):
units: str units: str
@ -79,44 +42,6 @@ class OpenPositions(BaseModel):
lastTransactionID: str lastTransactionID: str
{
"positions": [
{
"instrument": "EUR_USD",
"long": {
"units": "1",
"averagePrice": "0.99361",
"pl": "-0.1014",
"resettablePL": "-0.1014",
"financing": "-0.0002",
"dividendAdjustment": "0.0000",
"guaranteedExecutionFees": "0.0000",
"tradeIDs": ["71"],
"unrealizedPL": "-0.0044",
},
"short": {
"units": "0",
"pl": "0.0932",
"resettablePL": "0.0932",
"financing": "0.0000",
"dividendAdjustment": "0.0000",
"guaranteedExecutionFees": "0.0000",
"unrealizedPL": "0.0000",
},
"pl": "-0.0082",
"resettablePL": "-0.0082",
"financing": "-0.0002",
"commission": "0.0000",
"dividendAdjustment": "0.0000",
"guaranteedExecutionFees": "0.0000",
"unrealizedPL": "-0.0044",
"marginUsed": "0.0287",
}
],
"lastTransactionID": "73",
}
def parse_prices(x): def parse_prices(x):
if float(x["long"]["units"]) > 0: if float(x["long"]["units"]) > 0:
return x["long"]["averagePrice"] return x["long"]["averagePrice"]
@ -168,3 +93,196 @@ OpenPositionsSchema = {
], ],
) )
} }
class AccountDetailsNested(BaseModel):
guaranteedStopLossOrderMode: str
hedgingEnabled: bool
id: str
createdTime: str
currency: str
createdByUserID: int
alias: str
marginRate: str
lastTransactionID: str
balance: str
openTradeCount: int
openPositionCount: int
pendingOrderCount: int
pl: str
resettablePL: str
resettablePLTime: str
financing: str
commission: str
dividendAdjustment: str
guaranteedExecutionFees: str
orders: list # Order
positions: list # Position
trades: list # Trade
unrealizedPL: str
NAV: str
marginUsed: str
marginAvailable: str
positionValue: str
marginCloseoutUnrealizedPL: str
marginCloseoutNAV: str
marginCloseoutMarginUsed: str
marginCloseoutPositionValue: str
marginCloseoutPercent: str
withdrawalLimit: str
marginCallMarginUsed: str
marginCallPercent: str
class AccountDetails(BaseModel):
account: AccountDetailsNested
lastTransactionID: str
AccountDetailsSchema = {
"guaranteedSLOM": "account.guaranteedStopLossOrderMode",
"hedgingEnabled": "account.hedgingEnabled",
"id": "account.id",
"created_at": "account.createdTime",
"currency": "account.currency",
"createdByUserID": "account.createdByUserID",
"alias": "account.alias",
"marginRate": "account.marginRate",
"lastTransactionID": "account.lastTransactionID",
"balance": "account.balance",
"openTradeCount": "account.openTradeCount",
"openPositionCount": "account.openPositionCount",
"pendingOrderCount": "account.pendingOrderCount",
"pl": "account.pl",
"resettablePL": "account.resettablePL",
"resettablePLTime": "account.resettablePLTime",
"financing": "account.financing",
"commission": "account.commission",
"dividendAdjustment": "account.dividendAdjustment",
"guaranteedExecutionFees": "account.guaranteedExecutionFees",
# "orders": "account.orders",
# "positions": "account.positions",
# "trades": "account.trades",
"unrealizedPL": "account.unrealizedPL",
"NAV": "account.NAV",
"marginUsed": "account.marginUsed",
"marginAvailable": "account.marginAvailable",
"positionValue": "account.positionValue",
"marginCloseoutUnrealizedPL": "account.marginCloseoutUnrealizedPL",
"marginCloseoutNAV": "account.marginCloseoutNAV",
"marginCloseoutMarginUsed": "account.marginCloseoutMarginUsed",
"marginCloseoutPositionValue": "account.marginCloseoutPositionValue",
"marginCloseoutPercent": "account.marginCloseoutPercent",
"withdrawalLimit": "account.withdrawalLimit",
"marginCallMarginUsed": "account.marginCallMarginUsed",
"marginCallPercent": "account.marginCallPercent",
}
class PositionDetailsNested(BaseModel):
instrument: str
long: PositionLong
short: PositionShort
pl: str
resettablePL: str
financing: str
commission: str
dividendAdjustment: str
guaranteedExecutionFees: str
unrealizedPL: str
marginUsed: str
class PositionDetails(BaseModel):
position: PositionDetailsNested
lastTransactionID: str
PositionDetailsSchema = {
"symbol": "position.instrument",
"long": "position.long",
"short": "position.short",
"pl": "position.pl",
"resettablePL": "position.resettablePL",
"financing": "position.financing",
"commission": "position.commission",
"dividendAdjustment": "position.dividendAdjustment",
"guaranteedExecutionFees": "position.guaranteedExecutionFees",
"unrealizedPL": "position.unrealizedPL",
"marginUsed": "position.marginUsed",
"price": lambda x: parse_prices(x["position"]),
"units": lambda x: parse_units(x["position"]),
"side": lambda x: parse_side(x["position"]),
"value": lambda x: parse_value(x["position"]),
}
class InstrumentTag(BaseModel):
type: str
name: str
class InstrumentFinancingDaysOfWeek(BaseModel):
dayOfWeek: str
daysCharged: int
class InstrumentFinancing(BaseModel):
longRate: str
shortRate: str
financingDaysOfWeek: list[InstrumentFinancingDaysOfWeek]
class InstrumentGuaranteedRestriction(BaseModel):
volume: str
priceRange: str
class Instrument(BaseModel):
name: str
type: str
displayName: str
pipLocation: int
displayPrecision: int
tradeUnitsPrecision: int
minimumTradeSize: str
maximumTrailingStopDistance: str
minimumTrailingStopDistance: str
maximumPositionSize: str
maximumOrderUnits: str
marginRate: str
guaranteedStopLossOrderMode: str
tags: list[InstrumentTag]
financing: InstrumentFinancing
guaranteedStopLossOrderLevelRestriction: InstrumentGuaranteedRestriction | None
class AccountInstruments(BaseModel):
instruments: list[Instrument]
AccountInstrumentsSchema = {
"itemlist": (
"instruments",
[
{
"name": "name",
"type": "type",
"displayName": "displayName",
"pipLocation": "pipLocation",
"displayPrecision": "displayPrecision",
"tradeUnitsPrecision": "tradeUnitsPrecision",
"minimumTradeSize": "minimumTradeSize",
"maximumTrailingStopDistance": "maximumTrailingStopDistance",
"minimumTrailingStopDistance": "minimumTrailingStopDistance",
"maximumPositionSize": "maximumPositionSize",
"maximumOrderUnits": "maximumOrderUnits",
"marginRate": "marginRate",
"guaranteedSLOM": "guaranteedStopLossOrderMode",
"tags": "tags",
"financing": "financing",
"guaranteedSLOLR": "guaranteedStopLossOrderLevelRestriction",
}
],
)
}

View File

@ -91,6 +91,7 @@ class Account(models.Model):
client = self.get_client() client = self.get_client()
if client: if client:
supported_symbols = client.get_supported_assets() supported_symbols = client.get_supported_assets()
log.debug(f"Supported symbols for {self.name}: {supported_symbols}")
self.supported_symbols = supported_symbols self.supported_symbols = supported_symbols
super().save(*args, **kwargs) super().save(*args, **kwargs)

View File

@ -14,7 +14,13 @@ log = logs.get_logger(__name__)
class AccountInfo(LoginRequiredMixin, View): class AccountInfo(LoginRequiredMixin, View):
VIEWABLE_FIELDS_MODEL = ["name", "exchange", "api_key", "sandbox"] VIEWABLE_FIELDS_MODEL = [
"name",
"exchange",
"api_key",
"sandbox",
"supported_symbols",
]
allowed_types = ["modal", "widget", "window", "page"] allowed_types = ["modal", "widget", "window", "page"]
window_content = "window-content/account-info.html" window_content = "window-content/account-info.html"
@ -45,6 +51,7 @@ class AccountInfo(LoginRequiredMixin, View):
account_info = { account_info = {
k: v for k, v in account_info.items() if k in self.VIEWABLE_FIELDS_MODEL k: v for k, v in account_info.items() if k in self.VIEWABLE_FIELDS_MODEL
} }
account_info["supported_symbols"] = ", ".join(account_info["supported_symbols"])
if type == "page": if type == "page":
type = "modal" type = "modal"

View File

@ -36,7 +36,6 @@ class Positions(LoginRequiredMixin, View):
template_name = f"wm/{type}.html" template_name = f"wm/{type}.html"
unique = str(uuid.uuid4())[:8] unique = str(uuid.uuid4())[:8]
items = get_positions(request.user, account_id) items = get_positions(request.user, account_id)
print("ITEMS", items)
if type == "page": if type == "page":
type = "modal" type = "modal"
context = { context = {
@ -68,7 +67,6 @@ class PositionAction(LoginRequiredMixin, View):
account = Account.get_by_id(account_id, request.user) account = Account.get_by_id(account_id, request.user)
info = account.client.get_position_info(symbol) info = account.client.get_position_info(symbol)
print("ACCT INFO", info)
if type == "page": if type == "page":
type = "modal" type = "modal"