Wrap API calls in helper and validate response

This commit is contained in:
Mark Veidemanis 2022-10-30 19:11:07 +00:00
parent f22fcfdaaa
commit c15ae379f5
Signed by: m
GPG Key ID: 5ACFCEED46C0904F
15 changed files with 224 additions and 163 deletions

View File

@ -1,3 +1,5 @@
from pydantic import ValidationError
from core.util import logs from core.util import logs
@ -8,11 +10,36 @@ class BaseExchange(object):
self.log = logs.get_logger(name) self.log = logs.get_logger(name)
self.client = None self.client = None
self.set_schema()
self.connect() self.connect()
def set_schema(self):
raise NotImplementedError
def connect(self): def connect(self):
raise NotImplementedError raise NotImplementedError
def call(self, method, *args, **kwargs) -> (bool, dict):
if hasattr(self.client, method):
try:
response = getattr(self.client, method)(*args, **kwargs)
if isinstance(response, list):
response = {"itemlist": response}
if method not in self.schema:
self.log.error(f"Method cannot be validated: {method}")
self.log.debug(f"Response: {response}")
return (False, f"Method cannot be validated: {method}")
return (True, self.schema[method](**response).dict())
except ValidationError as e:
self.log.error(f"Could not validate response: {e}")
return (False, e)
except Exception as e:
self.log.error(f"Error calling {method}: {e}")
return (False, e)
else:
return (False, "No such method")
def get_account(self): def get_account(self):
raise NotImplementedError raise NotImplementedError

View File

@ -8,9 +8,20 @@ from alpaca.trading.requests import (
) )
from core.exchanges import BaseExchange from core.exchanges import BaseExchange
from core.lib.schemas import alpaca_s
ALPACA_SCHEMA_MAPPING = {
"get_account": alpaca_s.GetAccount,
"get_all_assets": alpaca_s.GetAllAssets,
"get_all_positions": alpaca_s.GetAllPositions,
"get_open_position": alpaca_s.GetOpenPosition,
}
class AlpacaExchange(BaseExchange): class AlpacaExchange(BaseExchange):
def set_schema(self):
self.schema = ALPACA_SCHEMA_MAPPING
def connect(self): def connect(self):
self.client = TradingClient( self.client = TradingClient(
self.account.api_key, self.account.api_key,
@ -20,32 +31,31 @@ class AlpacaExchange(BaseExchange):
) )
def get_account(self): def get_account(self):
return self.client.get_account() return self.call("get_account")
def get_supported_assets(self): def get_supported_assets(self):
try:
request = GetAssetsRequest(status="active", asset_class="crypto") request = GetAssetsRequest(status="active", asset_class="crypto")
assets = self.client.get_all_assets(filter=request) success, assets = self.call("get_all_assets", filter=request)
# assets = self.client.get_all_assets(filter=request)
if not success:
return (success, assets)
assets = assets["itemlist"]
asset_list = [x["symbol"] for x in assets if "symbol" in x] asset_list = [x["symbol"] for x in assets if "symbol" in x]
print("Supported symbols", asset_list) print("Supported symbols", asset_list)
except APIError as e:
self.log.error(f"Could not get asset list: {e}") return (True, asset_list)
# return False
return asset_list
def get_balance(self): def get_balance(self):
try: success, account_info = self.call("get_account")
account_info = self.client.get_account() if not success:
except APIError as e: return (success, account_info)
self.log.error(f"Could not get account balance: {e}")
return False
equity = account_info["equity"] equity = account_info["equity"]
try: try:
balance = float(equity) balance = float(equity)
except ValueError: except ValueError:
return False return (False, "Invalid balance")
return balance return (True, balance)
def get_market_value(self, symbol): def get_market_value(self, symbol):
try: try:
@ -89,6 +99,8 @@ class AlpacaExchange(BaseExchange):
order = self.client.submit_order(order_data=market_order_data) order = self.client.submit_order(order_data=market_order_data)
except APIError as e: except APIError as e:
self.log.error(f"Error placing market order: {e}") self.log.error(f"Error placing market order: {e}")
trade.status = "error"
trade.save()
return (False, e) return (False, e)
elif trade.type == "limit": elif trade.type == "limit":
if not trade.price: if not trade.price:
@ -99,6 +111,8 @@ class AlpacaExchange(BaseExchange):
order = self.client.submit_order(order_data=limit_order_data) order = self.client.submit_order(order_data=limit_order_data)
except APIError as e: except APIError as e:
self.log.error(f"Error placing limit order: {e}") self.log.error(f"Error placing limit order: {e}")
trade.status = "error"
trade.save()
return (False, e) return (False, e)
else: else:
@ -120,19 +134,19 @@ class AlpacaExchange(BaseExchange):
pass pass
def get_position_info(self, asset_id): def get_position_info(self, asset_id):
try: success, position = self.call("get_open_position", asset_id)
position = self.client.get_open_position(asset_id) if not success:
except APIError as e: return (success, position)
return (False, e)
return (True, position) return (True, position)
def get_all_positions(self): def get_all_positions(self):
items = [] items = []
positions = self.client.get_all_positions() success, positions = self.call("get_all_positions")
if not success:
return (success, positions)
for item in positions: for item in positions["itemlist"]:
item = dict(item)
item["account_id"] = self.account.id item["account_id"] = self.account.id
item["unrealized_pl"] = float(item["unrealized_pl"]) item["unrealized_pl"] = float(item["unrealized_pl"])
items.append(item) items.append(item)
return items return (True, items)

View File

@ -2,9 +2,15 @@ from oandapyV20 import API
from oandapyV20.endpoints import accounts, orders, positions, trades from oandapyV20.endpoints import accounts, orders, positions, trades
from core.exchanges import BaseExchange from core.exchanges import BaseExchange
from core.lib.schemas import oanda_s
OANDA_SCHEMA_MAPPING = {}
class OANDAExchange(BaseExchange): class OANDAExchange(BaseExchange):
def set_schema(self):
self.schema = OANDA_SCHEMA_MAPPING
def connect(self): def connect(self):
self.client = API(access_token=self.account.api_secret) self.client = API(access_token=self.account.api_secret)
self.account_id = self.account.api_key self.account_id = self.account.api_key
@ -53,4 +59,4 @@ class OANDAExchange(BaseExchange):
def get_all_positions(self): def get_all_positions(self):
r = positions.OpenPositions(accountID=self.account_id) r = positions.OpenPositions(accountID=self.account_id)
self.client.request(r) self.client.request(r)
return r.response["positions"] return (True, [])

View File

@ -24,9 +24,9 @@ def get_market_value(account, symbol):
def execute_strategy(callback, strategy): def execute_strategy(callback, strategy):
cash_balance = get_balance(strategy.account) success, cash_balance = strategy.account.client.get_balance()
log.debug(f"Cash balance: {cash_balance}") log.debug(f"Cash balance: {cash_balance}")
if not cash_balance: if not success:
return None return None
user = strategy.user user = strategy.user
@ -41,7 +41,7 @@ def execute_strategy(callback, strategy):
quote = "usd" # TODO: MASSIVE HACK quote = "usd" # TODO: MASSIVE HACK
symbol = f"{base.upper()}/{quote.upper()}" symbol = f"{base.upper()}/{quote.upper()}"
if symbol not in account.supported_assets: if symbol not in account.supported_symbols:
log.error(f"Symbol not supported by account: {symbol}") log.error(f"Symbol not supported by account: {symbol}")
return False return False

View File

@ -0,0 +1,105 @@
from pydantic import BaseModel, Field
class Asset(BaseModel):
id: str
class_: str = Field(..., alias="class")
exchange: str
symbol: str
name: str
status: str
tradable: bool
marginable: bool
maintenance_margin_requirement: int
shortable: bool
easy_to_borrow: bool
fractionable: bool
min_order_size: str
min_trade_increment: str
price_increment: str
# get_all_assets
class GetAllAssets(BaseModel):
itemlist: list[Asset]
# get_open_position
class GetOpenPosition(BaseModel):
asset_id: str
symbol: str
exchange: str
asset_class: str = Field(..., alias="asset_class")
asset_marginable: bool
qty: str
avg_entry_price: str
side: str
market_value: str
cost_basis: str
unrealized_pl: str
unrealized_plpc: str
unrealized_intraday_pl: str
unrealized_intraday_plpc: str
current_price: str
lastday_price: str
change_today: str
qty_available: str
class Position(BaseModel):
asset_id: str
symbol: str
exchange: str
asset_class: str
asset_marginable: bool
qty: str
avg_entry_price: str
side: str
market_value: str
cost_basis: str
unrealized_pl: str
unrealized_plpc: str
unrealized_intraday_pl: str
unrealized_intraday_plpc: str
current_price: str
lastday_price: str
change_today: str
qty_available: str
# get_all_positions
class GetAllPositions(BaseModel):
itemlist: list[Position]
# get_account
class GetAccount(BaseModel):
id: str
account_number: str
status: str
crypto_status: str
currency: str
buying_power: str
regt_buying_power: str
daytrading_buying_power: str
effective_buying_power: str
non_marginable_buying_power: str
bod_dtbp: str
cash: str
accrued_fees: str
pending_transfer_in: str
portfolio_value: str
pattern_day_trader: bool
trading_blocked: bool
transfers_blocked: bool
account_blocked: bool
created_at: str
trade_suspended_by_user: bool
multiplier: str
shorting_enabled: bool
equity: str
last_equity: str
long_market_value: str
short_market_value: str
position_market_value: str
initial_margin: str
maintenance_margin: str
last_maintenance_margin: str
sma: str
daytrade_count: int
balance_asof: str

View File

@ -0,0 +1,21 @@
from pydantic import BaseModel
class DrakdooMarket(BaseModel):
exchange: str
item: str
currency: str
contract: str
class DrakdooTimestamp(BaseModel):
sent: int
trade: int
class DrakdooCallback(BaseModel):
title: str
message: str
period: str
market: DrakdooMarket
timestamp: DrakdooTimestamp

View File

@ -0,0 +1 @@
from pydantic import BaseModel

View File

@ -1,125 +0,0 @@
from serde import Model, fields
# {
# "id": "92f0b26b-4c98-4553-9c74-cdafc7e037db",
# "clientOrderId": "ccxt_26adcbf445674f01af38a66a15e6f5b5",
# "timestamp": 1666096856515,
# "datetime": "2022-10-18T12:40:56.515477181Z",
# "lastTradeTimeStamp": null,
# "status": "open",
# "symbol": "BTC/USD",
# "type": "market",
# "timeInForce": "gtc",
# "postOnly": null,
# "side": "buy",
# "price": null,
# "stopPrice": null,
# "cost": null,
# "average": null,
# "amount": 1.1,
# "filled": 0.0,
# "remaining": 1.1,
# "trades": [],
# "fee": null,
# "info": {
# "id": "92f0b26b-4c98-4553-9c74-cdafc7e037db",
# "client_order_id": "ccxt_26adcbf445674f01af38a66a15e6f5b5",
# "created_at": "2022-10-18T12:40:56.516095561Z",
# "updated_at": "2022-10-18T12:40:56.516173841Z",
# "submitted_at": "2022-10-18T12:40:56.515477181Z",
# "filled_at": null,
# "expired_at": null,
# "canceled_at": null,
# "failed_at": null,
# "replaced_at": null,
# "replaced_by": null,
# "replaces": null,
# "asset_id": "276e2673-764b-4ab6-a611-caf665ca6340",
# "symbol": "BTC/USD",
# "asset_class": "crypto",
# "notional": null,
# "qty": "1.1",
# "filled_qty": "0",
# "filled_avg_price": null,
# "order_class": "",
# "order_type": "market",
# "type": "market",
# "side": "buy",
# "time_in_force": "gtc",
# "limit_price": null,
# "stop_price": null,
# "status": "pending_new",
# "extended_hours": false,
# "legs": null,
# "trail_percent": null,
# "trail_price": null,
# "hwm": null,
# "subtag": null,
# "source": null
# },
# "fees": [],
# "lastTradeTimestamp": null
# }
class CCXTInfo(Model):
id = fields.Uuid()
client_order_id = fields.Str()
created_at = fields.Str()
updated_at = fields.Str()
submitted_at = fields.Str()
filled_at = fields.Optional(fields.Str())
expired_at = fields.Optional(fields.Str())
canceled_at = fields.Optional(fields.Str())
failed_at = fields.Optional(fields.Str())
replaced_at = fields.Optional(fields.Str())
replaced_by = fields.Optional(fields.Str())
replaces = fields.Optional(fields.Str())
asset_id = fields.Uuid()
symbol = fields.Str()
asset_class = fields.Str()
notional = fields.Optional(fields.Str())
qty = fields.Str()
filled_qty = fields.Str()
filled_avg_price = fields.Optional(fields.Str())
order_class = fields.Str()
order_type = fields.Str()
type = fields.Str()
side = fields.Str()
time_in_force = fields.Str()
limit_price = fields.Optional(fields.Str())
stop_price = fields.Optional(fields.Str())
status = fields.Str()
extended_hours = fields.Bool()
legs = fields.Optional(fields.List(fields.Nested("CCXTInfo")))
trail_percent = fields.Optional(fields.Str())
trail_price = fields.Optional(fields.Str())
hwm = fields.Optional(fields.Str())
subtag = fields.Optional(fields.Str())
source = fields.Optional(fields.Str())
class CCXTRoot(Model):
id = fields.Uuid()
clientOrderId = fields.Str()
timestamp = fields.Int()
datetime = fields.Str()
lastTradeTimeStamp = fields.Optional(fields.Str())
status = fields.Str()
symbol = fields.Str()
type = fields.Str()
timeInForce = fields.Str()
postOnly = fields.Optional(fields.Str())
side = fields.Str()
price = fields.Optional(fields.Float())
stopPrice = fields.Optional(fields.Float())
cost = fields.Optional(fields.Float())
average = fields.Optional(fields.Float())
amount = fields.Float()
filled = fields.Float()
remaining = fields.Float()
trades = fields.Optional(fields.List(fields.Dict()))
fee = fields.Optional(fields.Float())
info = fields.Nested(CCXTInfo)
fees = fields.Optional(fields.List(fields.Dict()))
lastTradeTimestamp = fields.Optional(fields.Str())

View File

@ -91,8 +91,8 @@ class Account(models.Model):
""" """
client = self.get_client() client = self.get_client()
if client: if client:
supported_symbols = client.get_supported_assets() success, supported_symbols = client.get_supported_assets()
if supported_symbols: if success:
self.supported_symbols = supported_symbols self.supported_symbols = supported_symbols
super().save(*args, **kwargs) super().save(*args, **kwargs)

View File

@ -1,3 +1,5 @@
{% include 'partials/notify.html' %}
<h1 class="title">Live information</h1> <h1 class="title">Live information</h1>
<table class="table is-fullwidth is-hoverable"> <table class="table is-fullwidth is-hoverable">
<thead> <thead>

View File

@ -39,7 +39,15 @@ class AccountInfo(LoginRequiredMixin, View):
} }
return render(request, template_name, context) return render(request, template_name, context)
live_info = dict(account.client.get_account()) success, live_info = account.client.get_account()
if not success:
context = {
"message": "Could not get account info",
"class": "danger",
"window_content": self.window_content,
}
return render(request, template_name, context)
live_info = live_info
account_info = account.__dict__ account_info = account.__dict__
account_info = { account_info = {
k: v for k, v in account_info.items() if k in self.VIEWABLE_FIELDS_MODEL k: v for k, v in account_info.items() if k in self.VIEWABLE_FIELDS_MODEL

View File

@ -4,13 +4,13 @@ import orjson
from django.conf import settings from django.conf import settings
from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.auth.mixins import LoginRequiredMixin
from django.http import HttpResponse, HttpResponseBadRequest from django.http import HttpResponse, HttpResponseBadRequest
from pydantic import ValidationError
from rest_framework.parsers import JSONParser from rest_framework.parsers import JSONParser
from rest_framework.views import APIView from rest_framework.views import APIView
from serde import ValidationError
from core.forms import HookForm from core.forms import HookForm
from core.lib import market from core.lib import market
from core.lib.serde import drakdoo_s from core.lib.schemas.drakdoo_s import DrakdooCallback
from core.models import Callback, Hook from core.models import Callback, Hook
from core.util import logs from core.util import logs
from core.views import ObjectCreate, ObjectDelete, ObjectList, ObjectUpdate from core.views import ObjectCreate, ObjectDelete, ObjectList, ObjectUpdate
@ -44,7 +44,7 @@ class HookAPI(APIView):
# Try validating the JSON # Try validating the JSON
try: try:
hook_resp = drakdoo_s.BaseDrakdoo.from_dict(request.data) hook_resp = DrakdooCallback(**request.data)
except ValidationError as e: except ValidationError as e:
log.error(f"HookAPI POST: {e}") log.error(f"HookAPI POST: {e}")
return HttpResponseBadRequest(e) return HttpResponseBadRequest(e)

View File

@ -17,8 +17,10 @@ def get_positions(user, account_id=None):
items = [] items = []
accounts = Account.objects.filter(user=user) accounts = Account.objects.filter(user=user)
for account in accounts: for account in accounts:
positions = account.client.get_all_positions() success, positions = account.client.get_all_positions()
print("positions", positions) if not success:
items.append({"name": account.name, "status": "error"})
continue
for item in positions: for item in positions:
items.append(item) items.append(item)

View File

@ -15,6 +15,6 @@ django-debug-toolbar-template-profiler
orjson orjson
django-otp django-otp
qrcode qrcode
serde[ext] pydantic
alpaca-py alpaca-py
oandapyV20 oandapyV20