Simplify schema and error handling
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
import functools
|
||||
|
||||
from alpaca.common.exceptions import APIError
|
||||
from alpaca.trading.client import TradingClient
|
||||
from alpaca.trading.enums import OrderSide, TimeInForce
|
||||
@@ -7,21 +9,23 @@ from alpaca.trading.requests import (
|
||||
MarketOrderRequest,
|
||||
)
|
||||
|
||||
from core.exchanges import BaseExchange
|
||||
from core.lib.schemas import alpaca_s
|
||||
from core.exchanges import BaseExchange, ExchangeError, GenericAPIError
|
||||
|
||||
ALPACA_SCHEMA_MAPPING = {
|
||||
"get_account": alpaca_s.GetAccount,
|
||||
"get_all_assets": alpaca_s.GetAllAssets,
|
||||
"get_all_positions": alpaca_s.GetAllPositions,
|
||||
"get_open_position": alpaca_s.GetOpenPosition,
|
||||
}
|
||||
|
||||
def handle_errors(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
return_value = func(self, *args, **kwargs)
|
||||
if isinstance(return_value, tuple):
|
||||
if return_value[0] is False:
|
||||
print("Error: ", return_value[1])
|
||||
return return_value
|
||||
return return_value[1]
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class AlpacaExchange(BaseExchange):
|
||||
def set_schema(self):
|
||||
self.schema = ALPACA_SCHEMA_MAPPING
|
||||
|
||||
def connect(self):
|
||||
self.client = TradingClient(
|
||||
self.account.api_key,
|
||||
@@ -30,39 +34,37 @@ class AlpacaExchange(BaseExchange):
|
||||
raw_data=True,
|
||||
)
|
||||
|
||||
@handle_errors
|
||||
def get_account(self):
|
||||
return self.call("get_account")
|
||||
# return self.call("get_account")
|
||||
market = self.get_market_value("NONEXISTENT")
|
||||
print("MARTKET", market)
|
||||
|
||||
def get_supported_assets(self):
|
||||
request = GetAssetsRequest(status="active", asset_class="crypto")
|
||||
success, assets = self.call("get_all_assets", filter=request)
|
||||
# assets = self.client.get_all_assets(filter=request)
|
||||
if not success:
|
||||
return (success, assets)
|
||||
assets = self.call("get_all_assets", filter=request)
|
||||
assets = assets["itemlist"]
|
||||
asset_list = [x["symbol"] for x in assets if "symbol" in x]
|
||||
print("Supported symbols", asset_list)
|
||||
|
||||
return (True, asset_list)
|
||||
return asset_list
|
||||
|
||||
def get_balance(self):
|
||||
success, account_info = self.call("get_account")
|
||||
if not success:
|
||||
return (success, account_info)
|
||||
account_info = self.call("get_account")
|
||||
equity = account_info["equity"]
|
||||
try:
|
||||
balance = float(equity)
|
||||
except ValueError:
|
||||
return (False, "Invalid balance")
|
||||
raise GenericAPIError(f"Balance is not a float: {equity}")
|
||||
|
||||
return (True, balance)
|
||||
return balance
|
||||
|
||||
def get_market_value(self, symbol): # TODO: pydantic
|
||||
try:
|
||||
position = self.client.get_position(symbol)
|
||||
except APIError as e:
|
||||
self.log.error(f"Could not get market value for {symbol}: {e}")
|
||||
return False
|
||||
raise GenericAPIError(e)
|
||||
return float(position["market_value"])
|
||||
|
||||
def post_trade(self, trade): # TODO: pydantic
|
||||
@@ -72,7 +74,7 @@ class AlpacaExchange(BaseExchange):
|
||||
elif trade.direction == "sell":
|
||||
direction = OrderSide.SELL
|
||||
else:
|
||||
raise Exception("Unknown direction")
|
||||
raise ExchangeError("Unknown direction")
|
||||
|
||||
cast = {
|
||||
"symbol": trade.symbol,
|
||||
@@ -84,7 +86,7 @@ class AlpacaExchange(BaseExchange):
|
||||
if trade.amount_usd is not None:
|
||||
cast["notional"] = trade.amount_usd
|
||||
if not trade.amount and not trade.amount_usd:
|
||||
return (False, "No amount specified")
|
||||
raise ExchangeError("No amount specified")
|
||||
if trade.take_profit:
|
||||
cast["take_profit"] = {"limit_price": trade.take_profit}
|
||||
if trade.stop_loss:
|
||||
@@ -101,10 +103,10 @@ class AlpacaExchange(BaseExchange):
|
||||
self.log.error(f"Error placing market order: {e}")
|
||||
trade.status = "error"
|
||||
trade.save()
|
||||
return (False, e)
|
||||
raise GenericAPIError(e)
|
||||
elif trade.type == "limit":
|
||||
if not trade.price:
|
||||
return (False, "Limit order with no price")
|
||||
raise ExchangeError("No price specified for limit order")
|
||||
cast["limit_price"] = trade.price
|
||||
limit_order_data = LimitOrderRequest(**cast)
|
||||
try:
|
||||
@@ -113,16 +115,16 @@ class AlpacaExchange(BaseExchange):
|
||||
self.log.error(f"Error placing limit order: {e}")
|
||||
trade.status = "error"
|
||||
trade.save()
|
||||
return (False, e)
|
||||
raise GenericAPIError(e)
|
||||
|
||||
else:
|
||||
raise Exception("Unknown trade type")
|
||||
raise ExchangeError("Unknown trade type")
|
||||
trade.response = order
|
||||
trade.status = "posted"
|
||||
trade.order_id = order["id"]
|
||||
trade.client_order_id = order["client_order_id"]
|
||||
trade.save()
|
||||
return (True, order)
|
||||
return order
|
||||
|
||||
def get_trade(self, trade_id):
|
||||
pass
|
||||
@@ -134,20 +136,16 @@ class AlpacaExchange(BaseExchange):
|
||||
pass
|
||||
|
||||
def get_position_info(self, symbol):
|
||||
success, position = self.call("get_open_position", symbol)
|
||||
if not success:
|
||||
return (success, position)
|
||||
return (True, position)
|
||||
position = self.call("get_open_position", symbol)
|
||||
return position
|
||||
|
||||
def get_all_positions(self):
|
||||
items = []
|
||||
success, response = self.call("get_all_positions")
|
||||
if not success:
|
||||
return (success, response)
|
||||
response = self.call("get_all_positions")
|
||||
|
||||
for item in response["itemlist"]:
|
||||
item["account"] = self.account.name
|
||||
item["account_id"] = self.account.id
|
||||
item["unrealized_pl"] = float(item["unrealized_pl"])
|
||||
items.append(item)
|
||||
return (True, items)
|
||||
return items
|
||||
|
||||
Reference in New Issue
Block a user