Convert API responses with Glom

Mark Veidemanis 2 years ago
parent 396d838416
commit 5cb7d08614
Signed by: m
GPG Key ID: 5ACFCEED46C0904F

@ -114,7 +114,7 @@ urlpatterns = [
name="positions", name="positions",
), ),
path( path(
"positions/<str:type>/<str:account_id>/<str:asset_id>/", "positions/<str:type>/<str:account_id>/<str:symbol>/",
positions.PositionAction.as_view(), positions.PositionAction.as_view(),
name="position_action", name="position_action",
), ),

@ -1,13 +1,16 @@
from glom import glom
from pydantic import ValidationError from pydantic import ValidationError
from core.lib import schemas
from core.util import logs from core.util import logs
class BaseExchange(object): class BaseExchange(object):
def __init__(self, account): def __init__(self, account):
name = self.__class__.__name__ name = self.__class__.__name__
self.name = name.replace("Exchange", "").lower()
self.account = account self.account = account
self.log = logs.get_logger(name) self.log = logs.get_logger(self.name)
self.client = None self.client = None
self.set_schema() self.set_schema()
@ -19,6 +22,26 @@ class BaseExchange(object):
def connect(self): def connect(self):
raise NotImplementedError raise NotImplementedError
def convert_spec(self, response, msg_type):
# Does the schemas library have a library for this exchange name?
if hasattr(schemas, f"{self.name}_s"):
schema_instance = getattr(schemas, f"{self.name}_s")
else:
raise Exception(f"No schema for {self.name} in schema mapping")
# Does the message type have a conversion spec for this message type?
if hasattr(schema_instance, f"{msg_type}_schema"):
schema = getattr(schema_instance, f"{msg_type}_schema")
else:
# Let us know so we can implement it, but don't do anything with it
self.log.error(f"No schema for message: {msg_type} - {response}")
# raise Exception(f"No schema for {msg_type} in schema mapping")
return response
# Use glom to convert the response to the schema
converted = glom(response, schema)
print(f"[{self.name}] Converted of {msg_type}: {converted}")
return converted
def call(self, method, *args, **kwargs) -> (bool, dict): def call(self, method, *args, **kwargs) -> (bool, dict):
if hasattr(self.client, method): if hasattr(self.client, method):
@ -29,8 +52,13 @@ class BaseExchange(object):
if method not in self.schema: if method not in self.schema:
self.log.error(f"Method cannot be validated: {method}") self.log.error(f"Method cannot be validated: {method}")
self.log.debug(f"Response: {response}") self.log.debug(f"Response: {response}")
return (False, f"Method cannot be validated: {method}") # return (False, f"Method cannot be validated: {method}")
return (True, self.schema[method](**response).dict()) return (True, response)
# Return a dict of the validated response
response_valid = self.schema[method](**response).dict()
# Convert the response to a format that we can use
response_converted = self.convert_spec(response_valid, method)
return (True, response_converted)
except ValidationError as e: except ValidationError as e:
self.log.error(f"Could not validate response: {e}") self.log.error(f"Could not validate response: {e}")
return (False, e) return (False, e)
@ -64,7 +92,7 @@ class BaseExchange(object):
def cancel_trade(self, trade_id): def cancel_trade(self, trade_id):
raise NotImplementedError raise NotImplementedError
def get_position_info(self, asset_id): def get_position_info(self, symbol):
raise NotImplementedError raise NotImplementedError
def get_all_positions(self): def get_all_positions(self):

@ -133,19 +133,19 @@ class AlpacaExchange(BaseExchange):
def cancel_trade(self, trade_id): def cancel_trade(self, trade_id):
pass pass
def get_position_info(self, asset_id): def get_position_info(self, symbol):
success, position = self.call("get_open_position", asset_id) success, position = self.call("get_open_position", symbol)
if not success: if not success:
return (success, position) return (success, position)
return (True, position) return (True, position)
def get_all_positions(self): def get_all_positions(self):
items = [] items = []
success, positions = self.call("get_all_positions") success, response = self.call("get_all_positions")
if not success: if not success:
return (success, positions) return (success, response)
for item in positions["itemlist"]: for item in response["itemlist"]:
item["account_id"] = self.account.id item["account_id"] = self.account.id
item["unrealized_pl"] = float(item["unrealized_pl"]) item["unrealized_pl"] = float(item["unrealized_pl"])
items.append(item) items.append(item)

@ -1,5 +1,6 @@
from oandapyV20 import API from oandapyV20 import API
from oandapyV20.endpoints import accounts, orders, positions, trades from oandapyV20.endpoints import accounts, orders, positions, trades
from pydantic import ValidationError
from core.exchanges import BaseExchange from core.exchanges import BaseExchange
from core.lib.schemas import oanda_s from core.lib.schemas import oanda_s
@ -18,7 +19,11 @@ class OANDAExchange(BaseExchange):
self.log.debug(f"Response: {response}") self.log.debug(f"Response: {response}")
return (False, f"Method cannot be validated: {method}") return (False, f"Method cannot be validated: {method}")
try: try:
return (True, self.schema[method](**response).dict()) # Return a dict of the validated response
response_valid = self.schema[method](**response).dict()
# Convert the response to a format that we can use
response_converted = self.convert_spec(response_valid, method)
return (True, response_converted)
except ValidationError as e: except ValidationError as e:
self.log.error(f"Could not validate response: {e}") self.log.error(f"Could not validate response: {e}")
return (False, e) return (False, e)
@ -66,16 +71,21 @@ class OANDAExchange(BaseExchange):
def cancel_trade(self, trade_id): def cancel_trade(self, trade_id):
raise NotImplementedError raise NotImplementedError
def get_position_info(self, asset_id): def get_position_info(self, symbol):
r = positions.PositionDetails(self.account_id, asset_id) r = positions.PositionDetails(self.account_id, symbol)
self.client.request(r) self.client.request(r)
return r.response return r.response
def get_all_positions(self): def get_all_positions(self):
items = []
r = positions.OpenPositions(accountID=self.account_id) r = positions.OpenPositions(accountID=self.account_id)
success, response = self.call("OpenPositions", r) success, response = self.call("OpenPositions", r)
if not success: if not success:
return (success, response) return (success, response)
print("Positions", response) print("Positions", response)
return (True, []) for item in response["itemlist"]:
item["account_id"] = self.account.id
item["unrealized_pl"] = float(item["unrealized_pl"])
items.append(item)
return (True, items)

@ -72,6 +72,44 @@ class GetAllPositions(BaseModel):
itemlist: list[Position] itemlist: list[Position]
{
"itemlist": [
{
"asset_id": "64bbff51-59d6-4b3c-9351-13ad85e3c752",
"symbol": "BTCUSD",
"exchange": "FTXU",
"asset_class": "crypto",
"asset_marginable": False,
"qty": "0.009975",
"avg_entry_price": "20714",
"side": "long",
"market_value": "204.297975",
"cost_basis": "206.62215",
"unrealized_pl": "-2.324175",
"unrealized_plpc": "-0.0112484310128416",
"unrealized_intraday_pl": "-0.269325",
"unrealized_intraday_plpc": "-0.001316559391457",
"current_price": "20481",
"lastday_price": "20508",
"change_today": "-0.001316559391457",
"qty_available": "0.009975",
}
]
}
get_all_positions_schema = {
"itemlist": (
"itemlist",
[
{
"symbol": "symbol",
"unrealized_pl": "unrealized_pl",
"price:": "current_price",
}
],
)
}
# get_account # get_account
class GetAccount(BaseModel): class GetAccount(BaseModel):
id: str id: str
@ -108,3 +146,6 @@ class GetAccount(BaseModel):
sma: str sma: str
daytrade_count: int daytrade_count: int
balance_asof: str balance_asof: str
get_account_schema = {"": ""}

@ -77,3 +77,47 @@ class Position(BaseModel):
class OpenPositions(BaseModel): class OpenPositions(BaseModel):
positions: list[Position] positions: list[Position]
lastTransactionID: str lastTransactionID: str
{
"positions": [
{
"instrument": "EUR_USD",
"long": {
"units": "1",
"averagePrice": "0.99361",
"pl": "-0.1014",
"resettablePL": "-0.1014",
"financing": "-0.0002",
"dividendAdjustment": "0.0000",
"guaranteedExecutionFees": "0.0000",
"tradeIDs": ["71"],
"unrealizedPL": "-0.0044",
},
"short": {
"units": "0",
"pl": "0.0932",
"resettablePL": "0.0932",
"financing": "0.0000",
"dividendAdjustment": "0.0000",
"guaranteedExecutionFees": "0.0000",
"unrealizedPL": "0.0000",
},
"pl": "-0.0082",
"resettablePL": "-0.0082",
"financing": "-0.0002",
"commission": "0.0000",
"dividendAdjustment": "0.0000",
"guaranteedExecutionFees": "0.0000",
"unrealizedPL": "-0.0044",
"marginUsed": "0.0287",
}
],
"lastTransactionID": "73",
}
OpenPositions_schema = {
"itemlist": (
"positions",
[{"symbol": "instrument", "unrealized_pl": "unrealizedPL"}],
)
}

@ -73,10 +73,10 @@ def close_trade(trade):
pass pass
def get_position_info(account, asset_id): def get_position_info(account, symbol):
trading_client = account.get_client() trading_client = account.get_client()
try: try:
position = trading_client.get_open_position(asset_id) position = trading_client.get_open_position(symbol)
except APIError as e: except APIError as e:
return (False, e) return (False, e)
return (True, position) return (True, position)

@ -18,3 +18,4 @@ qrcode
pydantic pydantic
alpaca-py alpaca-py
oandapyV20 oandapyV20
glom

Loading…
Cancel
Save