You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

152 lines
4.9 KiB
Python

import functools
from alpaca.common.exceptions import APIError
from alpaca.trading.client import TradingClient
from alpaca.trading.enums import OrderSide, TimeInForce
from alpaca.trading.requests import (
GetAssetsRequest,
LimitOrderRequest,
MarketOrderRequest,
)
from core.exchanges import BaseExchange, ExchangeError, GenericAPIError
def handle_errors(func):
    """Unwrap ``(success, payload)`` status tuples returned by a method.

    The decorated method is called with its original arguments. Its return
    value is then post-processed:

    * a tuple whose first element is ``False`` is treated as a failure: the
      second element is printed and the whole tuple is returned unchanged so
      the caller can still inspect it;
    * any other tuple of two or more elements is treated as a success and
      only its second element (the payload) is returned;
    * anything else (dicts, floats, ``None``, short tuples) is returned
      unchanged.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        return_value = func(self, *args, **kwargs)
        # Only (status, payload) pairs are unwrapped; plain results must not
        # be swallowed or indexed into.
        if not isinstance(return_value, tuple) or len(return_value) < 2:
            return return_value
        if return_value[0] is False:
            print("Error: ", return_value[1])
            return return_value
        return return_value[1]

    return wrapper
class AlpacaExchange(BaseExchange):
    """Alpaca brokerage adapter built on the alpaca-py ``TradingClient``.

    ``self.account``, ``self.call`` and ``self.log`` are not defined in this
    class; presumably they come from ``BaseExchange`` — confirm there.
    """

    def connect(self):
        """Create the ``TradingClient`` for ``self.account``.

        ``paper`` follows the account's ``sandbox`` flag. ``raw_data=True``
        makes alpaca-py return plain dicts/lists instead of model objects,
        which is why the methods below index responses with ``[...]``.
        """
        self.client = TradingClient(
            self.account.api_key,
            self.account.api_secret,
            paper=self.account.sandbox,
            raw_data=True,
        )

    def get_account(self):
        """Return the raw account payload from ``self.call("get_account")``."""
        return self.call("get_account")

    def get_supported_assets(self):
        """Return the symbols of all active crypto assets.

        Asset entries without a ``"symbol"`` key are skipped.
        """
        request = GetAssetsRequest(status="active", asset_class="crypto")
        response = self.call("get_all_assets", filter=request)
        # self.call appears to wrap list responses under "itemlist".
        asset_list = [
            asset["symbol"] for asset in response["itemlist"] if "symbol" in asset
        ]
        print("Supported symbols", asset_list)
        return asset_list

    def get_balance(self):
        """Return the account equity as a float.

        Raises:
            GenericAPIError: if the ``equity`` field cannot be converted
                to a float.
        """
        account_info = self.call("get_account")
        equity = account_info["equity"]
        try:
            return float(equity)
        except (TypeError, ValueError) as err:
            # TypeError covers a None/non-numeric object, which float()
            # rejects without raising ValueError.
            raise GenericAPIError(f"Balance is not a float: {equity}") from err

    def get_market_value(self, symbol):  # TODO: pydantic
        """Return the market value of the open position in ``symbol``.

        Raises:
            GenericAPIError: if Alpaca rejects the lookup (e.g. no position).
        """
        try:
            # TradingClient exposes get_open_position; there is no
            # get_position method, so the old call raised AttributeError.
            position = self.client.get_open_position(symbol)
        except APIError as e:
            self.log.error(f"Could not get market value for {symbol}: {e}")
            raise GenericAPIError(e) from e
        return float(position["market_value"])

    def _build_order_request(self, trade):
        """Translate ``trade`` into an alpaca-py order request (no API call).

        Raises:
            ExchangeError: on an unknown direction or type, a missing or
                ambiguous amount, or a limit order without a price.
        """
        if trade.direction == "buy":
            direction = OrderSide.BUY
        elif trade.direction == "sell":
            direction = OrderSide.SELL
        else:
            raise ExchangeError("Unknown direction")

        cast = {
            "symbol": trade.symbol,
            "side": direction,
            # IOC: any part of the order that cannot fill immediately is
            # cancelled.
            "time_in_force": TimeInForce.IOC,
        }

        # Alpaca accepts exactly one of qty / notional; sending both makes the
        # alpaca-py request constructor raise a ValueError.
        if trade.amount and trade.amount_usd:
            raise ExchangeError("Specify either amount or amount_usd, not both")
        if trade.amount:
            cast["qty"] = trade.amount
        elif trade.amount_usd:
            cast["notional"] = trade.amount_usd
        else:
            raise ExchangeError("No amount specified")

        # NOTE(review): Alpaca usually expects an order_class (bracket/oto)
        # when take_profit/stop_loss legs are attached — confirm.
        if trade.take_profit:
            cast["take_profit"] = {"limit_price": trade.take_profit}
        if trade.stop_loss:
            # Leave 0.5% room past the stop in the direction the protective
            # leg trades. A long (buy) position is protected by a sell, so
            # the limit goes below the stop. A short (sell) position is
            # protected by a buy, so the limit goes above it.
            offset = trade.stop_loss * 0.005
            if direction == OrderSide.BUY:
                stop_limit_price = trade.stop_loss - offset
            else:
                stop_limit_price = trade.stop_loss + offset
            cast["stop_loss"] = {
                "stop_price": trade.stop_loss,
                "limit_price": stop_limit_price,
            }

        if trade.type == "market":
            return MarketOrderRequest(**cast)
        if trade.type == "limit":
            if not trade.price:
                raise ExchangeError("No price specified for limit order")
            cast["limit_price"] = trade.price
            return LimitOrderRequest(**cast)
        raise ExchangeError("Unknown trade type")

    def post_trade(self, trade):  # TODO: pydantic
        """Submit ``trade`` to Alpaca and record the outcome on ``trade``.

        On success, sets ``response``, ``status="posted"``, ``order_id`` and
        ``client_order_id`` on the trade, saves it, and returns the raw
        order dict.

        Raises:
            ExchangeError: if the trade's fields are invalid (nothing is
                saved in that case).
            GenericAPIError: if Alpaca rejects the order; the trade is saved
                with ``status="error"`` first.
        """
        # the trade is not placed yet
        order_data = self._build_order_request(trade)
        try:
            order = self.client.submit_order(order_data=order_data)
        except APIError as e:
            # trade.type is "market" or "limit" here, matching the two
            # original per-type log messages.
            self.log.error(f"Error placing {trade.type} order: {e}")
            trade.status = "error"
            trade.save()
            raise GenericAPIError(e) from e

        trade.response = order
        trade.status = "posted"
        trade.order_id = order["id"]
        trade.client_order_id = order["client_order_id"]
        trade.save()
        return order

    def get_trade(self, trade_id):
        """Not implemented for Alpaca yet; always returns ``None``."""

    def update_trade(self, trade):
        """Not implemented for Alpaca yet; always returns ``None``."""

    def cancel_trade(self, trade_id):
        """Not implemented for Alpaca yet; always returns ``None``.

        NOTE(review): callers cannot tell from the return value that nothing
        was cancelled.
        """

    def get_position_info(self, symbol):
        """Return the open position for ``symbol`` via ``self.call``."""
        return self.call("get_open_position", symbol)

    def get_all_positions(self):
        """Return all open positions, tagged with this account's name and id.

        ``unrealized_pl`` on each position dict is converted to float in
        place.
        """
        items = []
        response = self.call("get_all_positions")
        for item in response["itemlist"]:
            item["account"] = self.account.name
            item["account_id"] = self.account.id
            item["unrealized_pl"] = float(item["unrealized_pl"])
            items.append(item)
        return items