Wrap API calls in helper and validate response
This commit is contained in:
parent
f22fcfdaaa
commit
c15ae379f5
|
@ -1,3 +1,5 @@
|
|||
from pydantic import ValidationError
|
||||
|
||||
from core.util import logs
|
||||
|
||||
|
||||
|
@ -8,11 +10,36 @@ class BaseExchange(object):
|
|||
self.log = logs.get_logger(name)
|
||||
self.client = None
|
||||
|
||||
self.set_schema()
|
||||
self.connect()
|
||||
|
||||
def set_schema(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def connect(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, method, *args, **kwargs) -> (bool, dict):
|
||||
|
||||
if hasattr(self.client, method):
|
||||
try:
|
||||
response = getattr(self.client, method)(*args, **kwargs)
|
||||
if isinstance(response, list):
|
||||
response = {"itemlist": response}
|
||||
if method not in self.schema:
|
||||
self.log.error(f"Method cannot be validated: {method}")
|
||||
self.log.debug(f"Response: {response}")
|
||||
return (False, f"Method cannot be validated: {method}")
|
||||
return (True, self.schema[method](**response).dict())
|
||||
except ValidationError as e:
|
||||
self.log.error(f"Could not validate response: {e}")
|
||||
return (False, e)
|
||||
except Exception as e:
|
||||
self.log.error(f"Error calling {method}: {e}")
|
||||
return (False, e)
|
||||
else:
|
||||
return (False, "No such method")
|
||||
|
||||
def get_account(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -8,9 +8,20 @@ from alpaca.trading.requests import (
|
|||
)
|
||||
|
||||
from core.exchanges import BaseExchange
|
||||
from core.lib.schemas import alpaca_s
|
||||
|
||||
ALPACA_SCHEMA_MAPPING = {
|
||||
"get_account": alpaca_s.GetAccount,
|
||||
"get_all_assets": alpaca_s.GetAllAssets,
|
||||
"get_all_positions": alpaca_s.GetAllPositions,
|
||||
"get_open_position": alpaca_s.GetOpenPosition,
|
||||
}
|
||||
|
||||
|
||||
class AlpacaExchange(BaseExchange):
|
||||
def set_schema(self):
|
||||
self.schema = ALPACA_SCHEMA_MAPPING
|
||||
|
||||
def connect(self):
|
||||
self.client = TradingClient(
|
||||
self.account.api_key,
|
||||
|
@ -20,32 +31,31 @@ class AlpacaExchange(BaseExchange):
|
|||
)
|
||||
|
||||
def get_account(self):
|
||||
return self.client.get_account()
|
||||
return self.call("get_account")
|
||||
|
||||
def get_supported_assets(self):
|
||||
try:
|
||||
request = GetAssetsRequest(status="active", asset_class="crypto")
|
||||
assets = self.client.get_all_assets(filter=request)
|
||||
success, assets = self.call("get_all_assets", filter=request)
|
||||
# assets = self.client.get_all_assets(filter=request)
|
||||
if not success:
|
||||
return (success, assets)
|
||||
assets = assets["itemlist"]
|
||||
asset_list = [x["symbol"] for x in assets if "symbol" in x]
|
||||
print("Supported symbols", asset_list)
|
||||
except APIError as e:
|
||||
self.log.error(f"Could not get asset list: {e}")
|
||||
# return False
|
||||
return asset_list
|
||||
|
||||
return (True, asset_list)
|
||||
|
||||
def get_balance(self):
|
||||
try:
|
||||
account_info = self.client.get_account()
|
||||
except APIError as e:
|
||||
self.log.error(f"Could not get account balance: {e}")
|
||||
return False
|
||||
success, account_info = self.call("get_account")
|
||||
if not success:
|
||||
return (success, account_info)
|
||||
equity = account_info["equity"]
|
||||
try:
|
||||
balance = float(equity)
|
||||
except ValueError:
|
||||
return False
|
||||
return (False, "Invalid balance")
|
||||
|
||||
return balance
|
||||
return (True, balance)
|
||||
|
||||
def get_market_value(self, symbol):
|
||||
try:
|
||||
|
@ -89,6 +99,8 @@ class AlpacaExchange(BaseExchange):
|
|||
order = self.client.submit_order(order_data=market_order_data)
|
||||
except APIError as e:
|
||||
self.log.error(f"Error placing market order: {e}")
|
||||
trade.status = "error"
|
||||
trade.save()
|
||||
return (False, e)
|
||||
elif trade.type == "limit":
|
||||
if not trade.price:
|
||||
|
@ -99,6 +111,8 @@ class AlpacaExchange(BaseExchange):
|
|||
order = self.client.submit_order(order_data=limit_order_data)
|
||||
except APIError as e:
|
||||
self.log.error(f"Error placing limit order: {e}")
|
||||
trade.status = "error"
|
||||
trade.save()
|
||||
return (False, e)
|
||||
|
||||
else:
|
||||
|
@ -120,19 +134,19 @@ class AlpacaExchange(BaseExchange):
|
|||
pass
|
||||
|
||||
def get_position_info(self, asset_id):
|
||||
try:
|
||||
position = self.client.get_open_position(asset_id)
|
||||
except APIError as e:
|
||||
return (False, e)
|
||||
success, position = self.call("get_open_position", asset_id)
|
||||
if not success:
|
||||
return (success, position)
|
||||
return (True, position)
|
||||
|
||||
def get_all_positions(self):
|
||||
items = []
|
||||
positions = self.client.get_all_positions()
|
||||
success, positions = self.call("get_all_positions")
|
||||
if not success:
|
||||
return (success, positions)
|
||||
|
||||
for item in positions:
|
||||
item = dict(item)
|
||||
for item in positions["itemlist"]:
|
||||
item["account_id"] = self.account.id
|
||||
item["unrealized_pl"] = float(item["unrealized_pl"])
|
||||
items.append(item)
|
||||
return items
|
||||
return (True, items)
|
||||
|
|
|
@ -2,9 +2,15 @@ from oandapyV20 import API
|
|||
from oandapyV20.endpoints import accounts, orders, positions, trades
|
||||
|
||||
from core.exchanges import BaseExchange
|
||||
from core.lib.schemas import oanda_s
|
||||
|
||||
OANDA_SCHEMA_MAPPING = {}
|
||||
|
||||
|
||||
class OANDAExchange(BaseExchange):
|
||||
def set_schema(self):
|
||||
self.schema = OANDA_SCHEMA_MAPPING
|
||||
|
||||
def connect(self):
|
||||
self.client = API(access_token=self.account.api_secret)
|
||||
self.account_id = self.account.api_key
|
||||
|
@ -53,4 +59,4 @@ class OANDAExchange(BaseExchange):
|
|||
def get_all_positions(self):
|
||||
r = positions.OpenPositions(accountID=self.account_id)
|
||||
self.client.request(r)
|
||||
return r.response["positions"]
|
||||
return (True, [])
|
||||
|
|
|
@ -24,9 +24,9 @@ def get_market_value(account, symbol):
|
|||
|
||||
|
||||
def execute_strategy(callback, strategy):
|
||||
cash_balance = get_balance(strategy.account)
|
||||
success, cash_balance = strategy.account.client.get_balance()
|
||||
log.debug(f"Cash balance: {cash_balance}")
|
||||
if not cash_balance:
|
||||
if not success:
|
||||
return None
|
||||
|
||||
user = strategy.user
|
||||
|
@ -41,7 +41,7 @@ def execute_strategy(callback, strategy):
|
|||
quote = "usd" # TODO: MASSIVE HACK
|
||||
symbol = f"{base.upper()}/{quote.upper()}"
|
||||
|
||||
if symbol not in account.supported_assets:
|
||||
if symbol not in account.supported_symbols:
|
||||
log.error(f"Symbol not supported by account: {symbol}")
|
||||
return False
|
||||
|
||||
|
|
|
@ -0,0 +1,105 @@
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
class Asset(BaseModel):
|
||||
id: str
|
||||
class_: str = Field(..., alias="class")
|
||||
exchange: str
|
||||
symbol: str
|
||||
name: str
|
||||
status: str
|
||||
tradable: bool
|
||||
marginable: bool
|
||||
maintenance_margin_requirement: int
|
||||
shortable: bool
|
||||
easy_to_borrow: bool
|
||||
fractionable: bool
|
||||
min_order_size: str
|
||||
min_trade_increment: str
|
||||
price_increment: str
|
||||
|
||||
# get_all_assets
|
||||
class GetAllAssets(BaseModel):
|
||||
itemlist: list[Asset]
|
||||
|
||||
# get_open_position
|
||||
class GetOpenPosition(BaseModel):
|
||||
asset_id: str
|
||||
symbol: str
|
||||
exchange: str
|
||||
asset_class: str = Field(..., alias="asset_class")
|
||||
asset_marginable: bool
|
||||
qty: str
|
||||
avg_entry_price: str
|
||||
side: str
|
||||
market_value: str
|
||||
cost_basis: str
|
||||
unrealized_pl: str
|
||||
unrealized_plpc: str
|
||||
unrealized_intraday_pl: str
|
||||
unrealized_intraday_plpc: str
|
||||
current_price: str
|
||||
lastday_price: str
|
||||
change_today: str
|
||||
qty_available: str
|
||||
|
||||
|
||||
class Position(BaseModel):
|
||||
asset_id: str
|
||||
symbol: str
|
||||
exchange: str
|
||||
asset_class: str
|
||||
asset_marginable: bool
|
||||
qty: str
|
||||
avg_entry_price: str
|
||||
side: str
|
||||
market_value: str
|
||||
cost_basis: str
|
||||
unrealized_pl: str
|
||||
unrealized_plpc: str
|
||||
unrealized_intraday_pl: str
|
||||
unrealized_intraday_plpc: str
|
||||
current_price: str
|
||||
lastday_price: str
|
||||
change_today: str
|
||||
qty_available: str
|
||||
|
||||
# get_all_positions
|
||||
class GetAllPositions(BaseModel):
|
||||
itemlist: list[Position]
|
||||
|
||||
# get_account
|
||||
class GetAccount(BaseModel):
|
||||
id: str
|
||||
account_number: str
|
||||
status: str
|
||||
crypto_status: str
|
||||
currency: str
|
||||
buying_power: str
|
||||
regt_buying_power: str
|
||||
daytrading_buying_power: str
|
||||
effective_buying_power: str
|
||||
non_marginable_buying_power: str
|
||||
bod_dtbp: str
|
||||
cash: str
|
||||
accrued_fees: str
|
||||
pending_transfer_in: str
|
||||
portfolio_value: str
|
||||
pattern_day_trader: bool
|
||||
trading_blocked: bool
|
||||
transfers_blocked: bool
|
||||
account_blocked: bool
|
||||
created_at: str
|
||||
trade_suspended_by_user: bool
|
||||
multiplier: str
|
||||
shorting_enabled: bool
|
||||
equity: str
|
||||
last_equity: str
|
||||
long_market_value: str
|
||||
short_market_value: str
|
||||
position_market_value: str
|
||||
initial_margin: str
|
||||
maintenance_margin: str
|
||||
last_maintenance_margin: str
|
||||
sma: str
|
||||
daytrade_count: int
|
||||
balance_asof: str
|
|
@ -0,0 +1,21 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DrakdooMarket(BaseModel):
|
||||
exchange: str
|
||||
item: str
|
||||
currency: str
|
||||
contract: str
|
||||
|
||||
|
||||
class DrakdooTimestamp(BaseModel):
|
||||
sent: int
|
||||
trade: int
|
||||
|
||||
|
||||
class DrakdooCallback(BaseModel):
|
||||
title: str
|
||||
message: str
|
||||
period: str
|
||||
market: DrakdooMarket
|
||||
timestamp: DrakdooTimestamp
|
|
@ -0,0 +1 @@
|
|||
from pydantic import BaseModel
|
|
@ -1,125 +0,0 @@
|
|||
from serde import Model, fields
|
||||
|
||||
# {
|
||||
# "id": "92f0b26b-4c98-4553-9c74-cdafc7e037db",
|
||||
# "clientOrderId": "ccxt_26adcbf445674f01af38a66a15e6f5b5",
|
||||
# "timestamp": 1666096856515,
|
||||
# "datetime": "2022-10-18T12:40:56.515477181Z",
|
||||
# "lastTradeTimeStamp": null,
|
||||
# "status": "open",
|
||||
# "symbol": "BTC/USD",
|
||||
# "type": "market",
|
||||
# "timeInForce": "gtc",
|
||||
# "postOnly": null,
|
||||
# "side": "buy",
|
||||
# "price": null,
|
||||
# "stopPrice": null,
|
||||
# "cost": null,
|
||||
# "average": null,
|
||||
# "amount": 1.1,
|
||||
# "filled": 0.0,
|
||||
# "remaining": 1.1,
|
||||
# "trades": [],
|
||||
# "fee": null,
|
||||
# "info": {
|
||||
# "id": "92f0b26b-4c98-4553-9c74-cdafc7e037db",
|
||||
# "client_order_id": "ccxt_26adcbf445674f01af38a66a15e6f5b5",
|
||||
# "created_at": "2022-10-18T12:40:56.516095561Z",
|
||||
# "updated_at": "2022-10-18T12:40:56.516173841Z",
|
||||
# "submitted_at": "2022-10-18T12:40:56.515477181Z",
|
||||
# "filled_at": null,
|
||||
# "expired_at": null,
|
||||
# "canceled_at": null,
|
||||
# "failed_at": null,
|
||||
# "replaced_at": null,
|
||||
# "replaced_by": null,
|
||||
# "replaces": null,
|
||||
# "asset_id": "276e2673-764b-4ab6-a611-caf665ca6340",
|
||||
# "symbol": "BTC/USD",
|
||||
# "asset_class": "crypto",
|
||||
# "notional": null,
|
||||
# "qty": "1.1",
|
||||
# "filled_qty": "0",
|
||||
# "filled_avg_price": null,
|
||||
# "order_class": "",
|
||||
# "order_type": "market",
|
||||
# "type": "market",
|
||||
# "side": "buy",
|
||||
# "time_in_force": "gtc",
|
||||
# "limit_price": null,
|
||||
# "stop_price": null,
|
||||
# "status": "pending_new",
|
||||
# "extended_hours": false,
|
||||
# "legs": null,
|
||||
# "trail_percent": null,
|
||||
# "trail_price": null,
|
||||
# "hwm": null,
|
||||
# "subtag": null,
|
||||
# "source": null
|
||||
# },
|
||||
# "fees": [],
|
||||
# "lastTradeTimestamp": null
|
||||
# }
|
||||
|
||||
|
||||
class CCXTInfo(Model):
|
||||
id = fields.Uuid()
|
||||
client_order_id = fields.Str()
|
||||
created_at = fields.Str()
|
||||
updated_at = fields.Str()
|
||||
submitted_at = fields.Str()
|
||||
filled_at = fields.Optional(fields.Str())
|
||||
expired_at = fields.Optional(fields.Str())
|
||||
canceled_at = fields.Optional(fields.Str())
|
||||
failed_at = fields.Optional(fields.Str())
|
||||
replaced_at = fields.Optional(fields.Str())
|
||||
replaced_by = fields.Optional(fields.Str())
|
||||
replaces = fields.Optional(fields.Str())
|
||||
asset_id = fields.Uuid()
|
||||
symbol = fields.Str()
|
||||
asset_class = fields.Str()
|
||||
notional = fields.Optional(fields.Str())
|
||||
qty = fields.Str()
|
||||
filled_qty = fields.Str()
|
||||
filled_avg_price = fields.Optional(fields.Str())
|
||||
order_class = fields.Str()
|
||||
order_type = fields.Str()
|
||||
type = fields.Str()
|
||||
side = fields.Str()
|
||||
time_in_force = fields.Str()
|
||||
limit_price = fields.Optional(fields.Str())
|
||||
stop_price = fields.Optional(fields.Str())
|
||||
status = fields.Str()
|
||||
extended_hours = fields.Bool()
|
||||
legs = fields.Optional(fields.List(fields.Nested("CCXTInfo")))
|
||||
trail_percent = fields.Optional(fields.Str())
|
||||
trail_price = fields.Optional(fields.Str())
|
||||
hwm = fields.Optional(fields.Str())
|
||||
subtag = fields.Optional(fields.Str())
|
||||
source = fields.Optional(fields.Str())
|
||||
|
||||
|
||||
class CCXTRoot(Model):
|
||||
id = fields.Uuid()
|
||||
clientOrderId = fields.Str()
|
||||
timestamp = fields.Int()
|
||||
datetime = fields.Str()
|
||||
lastTradeTimeStamp = fields.Optional(fields.Str())
|
||||
status = fields.Str()
|
||||
symbol = fields.Str()
|
||||
type = fields.Str()
|
||||
timeInForce = fields.Str()
|
||||
postOnly = fields.Optional(fields.Str())
|
||||
side = fields.Str()
|
||||
price = fields.Optional(fields.Float())
|
||||
stopPrice = fields.Optional(fields.Float())
|
||||
cost = fields.Optional(fields.Float())
|
||||
average = fields.Optional(fields.Float())
|
||||
amount = fields.Float()
|
||||
filled = fields.Float()
|
||||
remaining = fields.Float()
|
||||
trades = fields.Optional(fields.List(fields.Dict()))
|
||||
fee = fields.Optional(fields.Float())
|
||||
info = fields.Nested(CCXTInfo)
|
||||
fees = fields.Optional(fields.List(fields.Dict()))
|
||||
lastTradeTimestamp = fields.Optional(fields.Str())
|
|
@ -91,8 +91,8 @@ class Account(models.Model):
|
|||
"""
|
||||
client = self.get_client()
|
||||
if client:
|
||||
supported_symbols = client.get_supported_assets()
|
||||
if supported_symbols:
|
||||
success, supported_symbols = client.get_supported_assets()
|
||||
if success:
|
||||
self.supported_symbols = supported_symbols
|
||||
super().save(*args, **kwargs)
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
{% include 'partials/notify.html' %}
|
||||
|
||||
<h1 class="title">Live information</h1>
|
||||
<table class="table is-fullwidth is-hoverable">
|
||||
<thead>
|
||||
|
|
|
@ -39,7 +39,15 @@ class AccountInfo(LoginRequiredMixin, View):
|
|||
}
|
||||
return render(request, template_name, context)
|
||||
|
||||
live_info = dict(account.client.get_account())
|
||||
success, live_info = account.client.get_account()
|
||||
if not success:
|
||||
context = {
|
||||
"message": "Could not get account info",
|
||||
"class": "danger",
|
||||
"window_content": self.window_content,
|
||||
}
|
||||
return render(request, template_name, context)
|
||||
live_info = live_info
|
||||
account_info = account.__dict__
|
||||
account_info = {
|
||||
k: v for k, v in account_info.items() if k in self.VIEWABLE_FIELDS_MODEL
|
||||
|
|
|
@ -4,13 +4,13 @@ import orjson
|
|||
from django.conf import settings
|
||||
from django.contrib.auth.mixins import LoginRequiredMixin
|
||||
from django.http import HttpResponse, HttpResponseBadRequest
|
||||
from pydantic import ValidationError
|
||||
from rest_framework.parsers import JSONParser
|
||||
from rest_framework.views import APIView
|
||||
from serde import ValidationError
|
||||
|
||||
from core.forms import HookForm
|
||||
from core.lib import market
|
||||
from core.lib.serde import drakdoo_s
|
||||
from core.lib.schemas.drakdoo_s import DrakdooCallback
|
||||
from core.models import Callback, Hook
|
||||
from core.util import logs
|
||||
from core.views import ObjectCreate, ObjectDelete, ObjectList, ObjectUpdate
|
||||
|
@ -44,7 +44,7 @@ class HookAPI(APIView):
|
|||
|
||||
# Try validating the JSON
|
||||
try:
|
||||
hook_resp = drakdoo_s.BaseDrakdoo.from_dict(request.data)
|
||||
hook_resp = DrakdooCallback(**request.data)
|
||||
except ValidationError as e:
|
||||
log.error(f"HookAPI POST: {e}")
|
||||
return HttpResponseBadRequest(e)
|
||||
|
|
|
@ -17,8 +17,10 @@ def get_positions(user, account_id=None):
|
|||
items = []
|
||||
accounts = Account.objects.filter(user=user)
|
||||
for account in accounts:
|
||||
positions = account.client.get_all_positions()
|
||||
print("positions", positions)
|
||||
success, positions = account.client.get_all_positions()
|
||||
if not success:
|
||||
items.append({"name": account.name, "status": "error"})
|
||||
continue
|
||||
|
||||
for item in positions:
|
||||
items.append(item)
|
||||
|
|
|
@ -15,6 +15,6 @@ django-debug-toolbar-template-profiler
|
|||
orjson
|
||||
django-otp
|
||||
qrcode
|
||||
serde[ext]
|
||||
pydantic
|
||||
alpaca-py
|
||||
oandapyV20
|
||||
|
|
Loading…
Reference in New Issue