Wrap API calls in helper and validate response

This commit is contained in:
Mark Veidemanis 2022-10-30 19:11:07 +00:00
parent f22fcfdaaa
commit c15ae379f5
Signed by: m
GPG Key ID: 5ACFCEED46C0904F
15 changed files with 224 additions and 163 deletions

View File

@ -1,3 +1,5 @@
from pydantic import ValidationError
from core.util import logs
@ -8,11 +10,36 @@ class BaseExchange(object):
self.log = logs.get_logger(name)
self.client = None
self.set_schema()
self.connect()
def set_schema(self):
raise NotImplementedError
def connect(self):
raise NotImplementedError
def call(self, method, *args, **kwargs) -> (bool, dict):
if hasattr(self.client, method):
try:
response = getattr(self.client, method)(*args, **kwargs)
if isinstance(response, list):
response = {"itemlist": response}
if method not in self.schema:
self.log.error(f"Method cannot be validated: {method}")
self.log.debug(f"Response: {response}")
return (False, f"Method cannot be validated: {method}")
return (True, self.schema[method](**response).dict())
except ValidationError as e:
self.log.error(f"Could not validate response: {e}")
return (False, e)
except Exception as e:
self.log.error(f"Error calling {method}: {e}")
return (False, e)
else:
return (False, "No such method")
def get_account(self):
raise NotImplementedError

View File

@ -8,9 +8,20 @@ from alpaca.trading.requests import (
)
from core.exchanges import BaseExchange
from core.lib.schemas import alpaca_s
ALPACA_SCHEMA_MAPPING = {
"get_account": alpaca_s.GetAccount,
"get_all_assets": alpaca_s.GetAllAssets,
"get_all_positions": alpaca_s.GetAllPositions,
"get_open_position": alpaca_s.GetOpenPosition,
}
class AlpacaExchange(BaseExchange):
def set_schema(self):
self.schema = ALPACA_SCHEMA_MAPPING
def connect(self):
self.client = TradingClient(
self.account.api_key,
@ -20,32 +31,31 @@ class AlpacaExchange(BaseExchange):
)
def get_account(self):
return self.client.get_account()
return self.call("get_account")
def get_supported_assets(self):
try:
request = GetAssetsRequest(status="active", asset_class="crypto")
assets = self.client.get_all_assets(filter=request)
success, assets = self.call("get_all_assets", filter=request)
# assets = self.client.get_all_assets(filter=request)
if not success:
return (success, assets)
assets = assets["itemlist"]
asset_list = [x["symbol"] for x in assets if "symbol" in x]
print("Supported symbols", asset_list)
except APIError as e:
self.log.error(f"Could not get asset list: {e}")
# return False
return asset_list
return (True, asset_list)
def get_balance(self):
try:
account_info = self.client.get_account()
except APIError as e:
self.log.error(f"Could not get account balance: {e}")
return False
success, account_info = self.call("get_account")
if not success:
return (success, account_info)
equity = account_info["equity"]
try:
balance = float(equity)
except ValueError:
return False
return (False, "Invalid balance")
return balance
return (True, balance)
def get_market_value(self, symbol):
try:
@ -89,6 +99,8 @@ class AlpacaExchange(BaseExchange):
order = self.client.submit_order(order_data=market_order_data)
except APIError as e:
self.log.error(f"Error placing market order: {e}")
trade.status = "error"
trade.save()
return (False, e)
elif trade.type == "limit":
if not trade.price:
@ -99,6 +111,8 @@ class AlpacaExchange(BaseExchange):
order = self.client.submit_order(order_data=limit_order_data)
except APIError as e:
self.log.error(f"Error placing limit order: {e}")
trade.status = "error"
trade.save()
return (False, e)
else:
@ -120,19 +134,19 @@ class AlpacaExchange(BaseExchange):
pass
def get_position_info(self, asset_id):
try:
position = self.client.get_open_position(asset_id)
except APIError as e:
return (False, e)
success, position = self.call("get_open_position", asset_id)
if not success:
return (success, position)
return (True, position)
def get_all_positions(self):
items = []
positions = self.client.get_all_positions()
success, positions = self.call("get_all_positions")
if not success:
return (success, positions)
for item in positions:
item = dict(item)
for item in positions["itemlist"]:
item["account_id"] = self.account.id
item["unrealized_pl"] = float(item["unrealized_pl"])
items.append(item)
return items
return (True, items)

View File

@ -2,9 +2,15 @@ from oandapyV20 import API
from oandapyV20.endpoints import accounts, orders, positions, trades
from core.exchanges import BaseExchange
from core.lib.schemas import oanda_s
OANDA_SCHEMA_MAPPING = {}
class OANDAExchange(BaseExchange):
def set_schema(self):
self.schema = OANDA_SCHEMA_MAPPING
def connect(self):
self.client = API(access_token=self.account.api_secret)
self.account_id = self.account.api_key
@ -53,4 +59,4 @@ class OANDAExchange(BaseExchange):
def get_all_positions(self):
r = positions.OpenPositions(accountID=self.account_id)
self.client.request(r)
return r.response["positions"]
return (True, [])

View File

@ -24,9 +24,9 @@ def get_market_value(account, symbol):
def execute_strategy(callback, strategy):
cash_balance = get_balance(strategy.account)
success, cash_balance = strategy.account.client.get_balance()
log.debug(f"Cash balance: {cash_balance}")
if not cash_balance:
if not success:
return None
user = strategy.user
@ -41,7 +41,7 @@ def execute_strategy(callback, strategy):
quote = "usd" # TODO: MASSIVE HACK
symbol = f"{base.upper()}/{quote.upper()}"
if symbol not in account.supported_assets:
if symbol not in account.supported_symbols:
log.error(f"Symbol not supported by account: {symbol}")
return False

View File

@ -0,0 +1,105 @@
from pydantic import BaseModel, Field
class Asset(BaseModel):
id: str
class_: str = Field(..., alias="class")
exchange: str
symbol: str
name: str
status: str
tradable: bool
marginable: bool
maintenance_margin_requirement: int
shortable: bool
easy_to_borrow: bool
fractionable: bool
min_order_size: str
min_trade_increment: str
price_increment: str
# get_all_assets
class GetAllAssets(BaseModel):
itemlist: list[Asset]
# get_open_position
class GetOpenPosition(BaseModel):
asset_id: str
symbol: str
exchange: str
asset_class: str = Field(..., alias="asset_class")
asset_marginable: bool
qty: str
avg_entry_price: str
side: str
market_value: str
cost_basis: str
unrealized_pl: str
unrealized_plpc: str
unrealized_intraday_pl: str
unrealized_intraday_plpc: str
current_price: str
lastday_price: str
change_today: str
qty_available: str
class Position(BaseModel):
asset_id: str
symbol: str
exchange: str
asset_class: str
asset_marginable: bool
qty: str
avg_entry_price: str
side: str
market_value: str
cost_basis: str
unrealized_pl: str
unrealized_plpc: str
unrealized_intraday_pl: str
unrealized_intraday_plpc: str
current_price: str
lastday_price: str
change_today: str
qty_available: str
# get_all_positions
class GetAllPositions(BaseModel):
itemlist: list[Position]
# get_account
class GetAccount(BaseModel):
id: str
account_number: str
status: str
crypto_status: str
currency: str
buying_power: str
regt_buying_power: str
daytrading_buying_power: str
effective_buying_power: str
non_marginable_buying_power: str
bod_dtbp: str
cash: str
accrued_fees: str
pending_transfer_in: str
portfolio_value: str
pattern_day_trader: bool
trading_blocked: bool
transfers_blocked: bool
account_blocked: bool
created_at: str
trade_suspended_by_user: bool
multiplier: str
shorting_enabled: bool
equity: str
last_equity: str
long_market_value: str
short_market_value: str
position_market_value: str
initial_margin: str
maintenance_margin: str
last_maintenance_margin: str
sma: str
daytrade_count: int
balance_asof: str

View File

@ -0,0 +1,21 @@
from pydantic import BaseModel
class DrakdooMarket(BaseModel):
exchange: str
item: str
currency: str
contract: str
class DrakdooTimestamp(BaseModel):
sent: int
trade: int
class DrakdooCallback(BaseModel):
title: str
message: str
period: str
market: DrakdooMarket
timestamp: DrakdooTimestamp

View File

@ -0,0 +1 @@
from pydantic import BaseModel

View File

@ -1,125 +0,0 @@
from serde import Model, fields
# {
# "id": "92f0b26b-4c98-4553-9c74-cdafc7e037db",
# "clientOrderId": "ccxt_26adcbf445674f01af38a66a15e6f5b5",
# "timestamp": 1666096856515,
# "datetime": "2022-10-18T12:40:56.515477181Z",
# "lastTradeTimeStamp": null,
# "status": "open",
# "symbol": "BTC/USD",
# "type": "market",
# "timeInForce": "gtc",
# "postOnly": null,
# "side": "buy",
# "price": null,
# "stopPrice": null,
# "cost": null,
# "average": null,
# "amount": 1.1,
# "filled": 0.0,
# "remaining": 1.1,
# "trades": [],
# "fee": null,
# "info": {
# "id": "92f0b26b-4c98-4553-9c74-cdafc7e037db",
# "client_order_id": "ccxt_26adcbf445674f01af38a66a15e6f5b5",
# "created_at": "2022-10-18T12:40:56.516095561Z",
# "updated_at": "2022-10-18T12:40:56.516173841Z",
# "submitted_at": "2022-10-18T12:40:56.515477181Z",
# "filled_at": null,
# "expired_at": null,
# "canceled_at": null,
# "failed_at": null,
# "replaced_at": null,
# "replaced_by": null,
# "replaces": null,
# "asset_id": "276e2673-764b-4ab6-a611-caf665ca6340",
# "symbol": "BTC/USD",
# "asset_class": "crypto",
# "notional": null,
# "qty": "1.1",
# "filled_qty": "0",
# "filled_avg_price": null,
# "order_class": "",
# "order_type": "market",
# "type": "market",
# "side": "buy",
# "time_in_force": "gtc",
# "limit_price": null,
# "stop_price": null,
# "status": "pending_new",
# "extended_hours": false,
# "legs": null,
# "trail_percent": null,
# "trail_price": null,
# "hwm": null,
# "subtag": null,
# "source": null
# },
# "fees": [],
# "lastTradeTimestamp": null
# }
class CCXTInfo(Model):
id = fields.Uuid()
client_order_id = fields.Str()
created_at = fields.Str()
updated_at = fields.Str()
submitted_at = fields.Str()
filled_at = fields.Optional(fields.Str())
expired_at = fields.Optional(fields.Str())
canceled_at = fields.Optional(fields.Str())
failed_at = fields.Optional(fields.Str())
replaced_at = fields.Optional(fields.Str())
replaced_by = fields.Optional(fields.Str())
replaces = fields.Optional(fields.Str())
asset_id = fields.Uuid()
symbol = fields.Str()
asset_class = fields.Str()
notional = fields.Optional(fields.Str())
qty = fields.Str()
filled_qty = fields.Str()
filled_avg_price = fields.Optional(fields.Str())
order_class = fields.Str()
order_type = fields.Str()
type = fields.Str()
side = fields.Str()
time_in_force = fields.Str()
limit_price = fields.Optional(fields.Str())
stop_price = fields.Optional(fields.Str())
status = fields.Str()
extended_hours = fields.Bool()
legs = fields.Optional(fields.List(fields.Nested("CCXTInfo")))
trail_percent = fields.Optional(fields.Str())
trail_price = fields.Optional(fields.Str())
hwm = fields.Optional(fields.Str())
subtag = fields.Optional(fields.Str())
source = fields.Optional(fields.Str())
class CCXTRoot(Model):
id = fields.Uuid()
clientOrderId = fields.Str()
timestamp = fields.Int()
datetime = fields.Str()
lastTradeTimeStamp = fields.Optional(fields.Str())
status = fields.Str()
symbol = fields.Str()
type = fields.Str()
timeInForce = fields.Str()
postOnly = fields.Optional(fields.Str())
side = fields.Str()
price = fields.Optional(fields.Float())
stopPrice = fields.Optional(fields.Float())
cost = fields.Optional(fields.Float())
average = fields.Optional(fields.Float())
amount = fields.Float()
filled = fields.Float()
remaining = fields.Float()
trades = fields.Optional(fields.List(fields.Dict()))
fee = fields.Optional(fields.Float())
info = fields.Nested(CCXTInfo)
fees = fields.Optional(fields.List(fields.Dict()))
lastTradeTimestamp = fields.Optional(fields.Str())

View File

@ -91,8 +91,8 @@ class Account(models.Model):
"""
client = self.get_client()
if client:
supported_symbols = client.get_supported_assets()
if supported_symbols:
success, supported_symbols = client.get_supported_assets()
if success:
self.supported_symbols = supported_symbols
super().save(*args, **kwargs)

View File

@ -1,3 +1,5 @@
{% include 'partials/notify.html' %}
<h1 class="title">Live information</h1>
<table class="table is-fullwidth is-hoverable">
<thead>

View File

@ -39,7 +39,15 @@ class AccountInfo(LoginRequiredMixin, View):
}
return render(request, template_name, context)
live_info = dict(account.client.get_account())
success, live_info = account.client.get_account()
if not success:
context = {
"message": "Could not get account info",
"class": "danger",
"window_content": self.window_content,
}
return render(request, template_name, context)
live_info = live_info
account_info = account.__dict__
account_info = {
k: v for k, v in account_info.items() if k in self.VIEWABLE_FIELDS_MODEL

View File

@ -4,13 +4,13 @@ import orjson
from django.conf import settings
from django.contrib.auth.mixins import LoginRequiredMixin
from django.http import HttpResponse, HttpResponseBadRequest
from pydantic import ValidationError
from rest_framework.parsers import JSONParser
from rest_framework.views import APIView
from serde import ValidationError
from core.forms import HookForm
from core.lib import market
from core.lib.serde import drakdoo_s
from core.lib.schemas.drakdoo_s import DrakdooCallback
from core.models import Callback, Hook
from core.util import logs
from core.views import ObjectCreate, ObjectDelete, ObjectList, ObjectUpdate
@ -44,7 +44,7 @@ class HookAPI(APIView):
# Try validating the JSON
try:
hook_resp = drakdoo_s.BaseDrakdoo.from_dict(request.data)
hook_resp = DrakdooCallback(**request.data)
except ValidationError as e:
log.error(f"HookAPI POST: {e}")
return HttpResponseBadRequest(e)

View File

@ -17,8 +17,10 @@ def get_positions(user, account_id=None):
items = []
accounts = Account.objects.filter(user=user)
for account in accounts:
positions = account.client.get_all_positions()
print("positions", positions)
success, positions = account.client.get_all_positions()
if not success:
items.append({"name": account.name, "status": "error"})
continue
for item in positions:
items.append(item)

View File

@ -15,6 +15,6 @@ django-debug-toolbar-template-profiler
orjson
django-otp
qrcode
serde[ext]
pydantic
alpaca-py
oandapyV20