Simplify schema and error handling

This commit is contained in:
2022-11-04 07:20:55 +00:00
parent 04a87c1da6
commit b36791d56b
8 changed files with 181 additions and 128 deletions

View File

@@ -1,3 +1,5 @@
import functools
from alpaca.common.exceptions import APIError
from alpaca.trading.client import TradingClient
from alpaca.trading.enums import OrderSide, TimeInForce
@@ -7,21 +9,23 @@ from alpaca.trading.requests import (
MarketOrderRequest,
)
from core.exchanges import BaseExchange
from core.lib.schemas import alpaca_s
from core.exchanges import BaseExchange, ExchangeError, GenericAPIError
ALPACA_SCHEMA_MAPPING = {
"get_account": alpaca_s.GetAccount,
"get_all_assets": alpaca_s.GetAllAssets,
"get_all_positions": alpaca_s.GetAllPositions,
"get_open_position": alpaca_s.GetOpenPosition,
}
def handle_errors(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
return_value = func(self, *args, **kwargs)
if isinstance(return_value, tuple):
if return_value[0] is False:
print("Error: ", return_value[1])
return return_value
return return_value[1]
return wrapper
class AlpacaExchange(BaseExchange):
def set_schema(self):
self.schema = ALPACA_SCHEMA_MAPPING
def connect(self):
self.client = TradingClient(
self.account.api_key,
@@ -30,39 +34,37 @@ class AlpacaExchange(BaseExchange):
raw_data=True,
)
@handle_errors
def get_account(self):
return self.call("get_account")
# return self.call("get_account")
market = self.get_market_value("NONEXISTENT")
print("MARTKET", market)
def get_supported_assets(self):
request = GetAssetsRequest(status="active", asset_class="crypto")
success, assets = self.call("get_all_assets", filter=request)
# assets = self.client.get_all_assets(filter=request)
if not success:
return (success, assets)
assets = self.call("get_all_assets", filter=request)
assets = assets["itemlist"]
asset_list = [x["symbol"] for x in assets if "symbol" in x]
print("Supported symbols", asset_list)
return (True, asset_list)
return asset_list
def get_balance(self):
success, account_info = self.call("get_account")
if not success:
return (success, account_info)
account_info = self.call("get_account")
equity = account_info["equity"]
try:
balance = float(equity)
except ValueError:
return (False, "Invalid balance")
raise GenericAPIError(f"Balance is not a float: {equity}")
return (True, balance)
return balance
def get_market_value(self, symbol): # TODO: pydantic
try:
position = self.client.get_position(symbol)
except APIError as e:
self.log.error(f"Could not get market value for {symbol}: {e}")
return False
raise GenericAPIError(e)
return float(position["market_value"])
def post_trade(self, trade): # TODO: pydantic
@@ -72,7 +74,7 @@ class AlpacaExchange(BaseExchange):
elif trade.direction == "sell":
direction = OrderSide.SELL
else:
raise Exception("Unknown direction")
raise ExchangeError("Unknown direction")
cast = {
"symbol": trade.symbol,
@@ -84,7 +86,7 @@ class AlpacaExchange(BaseExchange):
if trade.amount_usd is not None:
cast["notional"] = trade.amount_usd
if not trade.amount and not trade.amount_usd:
return (False, "No amount specified")
raise ExchangeError("No amount specified")
if trade.take_profit:
cast["take_profit"] = {"limit_price": trade.take_profit}
if trade.stop_loss:
@@ -101,10 +103,10 @@ class AlpacaExchange(BaseExchange):
self.log.error(f"Error placing market order: {e}")
trade.status = "error"
trade.save()
return (False, e)
raise GenericAPIError(e)
elif trade.type == "limit":
if not trade.price:
return (False, "Limit order with no price")
raise ExchangeError("No price specified for limit order")
cast["limit_price"] = trade.price
limit_order_data = LimitOrderRequest(**cast)
try:
@@ -113,16 +115,16 @@ class AlpacaExchange(BaseExchange):
self.log.error(f"Error placing limit order: {e}")
trade.status = "error"
trade.save()
return (False, e)
raise GenericAPIError(e)
else:
raise Exception("Unknown trade type")
raise ExchangeError("Unknown trade type")
trade.response = order
trade.status = "posted"
trade.order_id = order["id"]
trade.client_order_id = order["client_order_id"]
trade.save()
return (True, order)
return order
def get_trade(self, trade_id):
pass
@@ -134,20 +136,16 @@ class AlpacaExchange(BaseExchange):
pass
def get_position_info(self, symbol):
success, position = self.call("get_open_position", symbol)
if not success:
return (success, position)
return (True, position)
position = self.call("get_open_position", symbol)
return position
def get_all_positions(self):
items = []
success, response = self.call("get_all_positions")
if not success:
return (success, response)
response = self.call("get_all_positions")
for item in response["itemlist"]:
item["account"] = self.account.name
item["account_id"] = self.account.id
item["unrealized_pl"] = float(item["unrealized_pl"])
items.append(item)
return (True, items)
return items