Compare commits

...

2 Commits

11 changed files with 287 additions and 35 deletions

View File

@ -102,6 +102,11 @@ urlpatterns = [
trades.TradeDelete.as_view(), trades.TradeDelete.as_view(),
name="trade_delete", name="trade_delete",
), ),
path(
"trades/action/delete_all/",
trades.TradeDeleteAll.as_view(),
name="trade_delete_all",
),
path("positions/<str:type>/", positions.Positions.as_view(), name="positions"), path("positions/<str:type>/", positions.Positions.as_view(), name="positions"),
path( path(
"positions/<str:type>/<str:account_id>/", "positions/<str:type>/<str:account_id>/",

View File

@ -0,0 +1,42 @@
from core.util import logs
class BaseExchange(object):
def __init__(self, account):
name = self.__class__.__name__
self.account = account
self.log = logs.get_logger(name)
self.client = None
self.connect()
def connect(self):
raise NotImplementedError
def get_supported_assets(self):
raise NotImplementedError
def get_balance(self):
raise NotImplementedError
def get_market_value(self, symbol):
raise NotImplementedError
def post_trade(self, trade):
raise NotImplementedError
def get_trade(self, trade_id):
raise NotImplementedError
def update_trade(self, trade):
raise NotImplementedError
def cancel_trade(self, trade_id):
raise NotImplementedError
def get_position_info(self, asset_id):
raise NotImplementedError
def get_all_positions(self):
raise NotImplementedError

125
core/exchanges/alpaca.py Normal file
View File

@ -0,0 +1,125 @@
from core.exchanges import BaseExchange
from alpaca.trading.client import TradingClient
from alpaca.trading.requests import GetAssetsRequest
from alpaca.common.exceptions import APIError
from alpaca.trading.enums import OrderSide, TimeInForce
from alpaca.trading.requests import LimitOrderRequest, MarketOrderRequest
class AlpacaExchange(BaseExchange):
def connect(self):
self.client = TradingClient(
self.account.api_key, self.account.api_secret, paper=self.account.sandbox, raw_data=True
)
def get_supported_assets(self):
try:
request = GetAssetsRequest(status="active", asset_class="crypto")
assets = self.client.get_all_assets(filter=request)
asset_list = [x["symbol"] for x in assets if "symbol" in x]
print("Supported symbols", asset_list)
except APIError as e:
log.error(f"Could not get asset list: {e}")
# return False
return asset_list
def get_balance(self):
try:
account_info = self.client.get_account()
except APIError as e:
self.log.error(f"Could not get account balance: {e}")
return False
equity = account_info["equity"]
try:
balance = float(equity)
except ValueError:
return False
return balance
def get_market_value(self, symbol):
try:
position = self.client.get_position(symbol)
except APIError as e:
self.log.error(f"Could not get market value for {symbol}: {e}")
return False
return float(position["market_value"])
def post_trade(self, trade):
# the trade is not placed yet
if trade.direction == "buy":
direction = OrderSide.BUY
elif trade.direction == "sell":
direction = OrderSide.SELL
else:
raise Exception("Unknown direction")
cast = {"symbol": trade.symbol, "side": direction, "time_in_force": TimeInForce.IOC}
if trade.amount is not None:
cast["qty"] = trade.amount
if trade.amount_usd is not None:
cast["notional"] = trade.amount_usd
if not trade.amount and not trade.amount_usd:
return (False, "No amount specified")
if trade.take_profit:
cast["take_profit"] = {"limit_price": trade.take_profit}
if trade.stop_loss:
stop_limit_price = trade.stop_loss - (trade.stop_loss * 0.005)
cast["stop_loss"] = {
"stop_price": trade.stop_loss,
"limit_price": stop_limit_price,
}
if trade.type == "market":
market_order_data = MarketOrderRequest(**cast)
try:
order = self.client.submit_order(order_data=market_order_data)
except APIError as e:
log.error(f"Error placing market order: {e}")
return (False, e)
elif trade.type == "limit":
if not trade.price:
return (False, "Limit order with no price")
cast["limit_price"] = trade.price
limit_order_data = LimitOrderRequest(**cast)
try:
order = self.client.submit_order(order_data=limit_order_data)
except APIError as e:
log.error(f"Error placing limit order: {e}")
return (False, e)
else:
raise Exception("Unknown trade type")
trade.response = order
trade.status = "posted"
trade.order_id = order["id"]
trade.client_order_id = order["client_order_id"]
trade.save()
return (True, order)
def get_trade(self, trade_id):
pass
def update_trade(self, trade):
pass
def cancel_trade(self, trade_id):
pass
def get_position_info(self, asset_id):
try:
position = self.client.get_open_position(asset_id)
except APIError as e:
return (False, e)
return (True, position)
def get_all_positions(self):
items = []
positions = self.client.get_all_positions()
for item in positions:
item = dict(item)
item["account_id"] = self.account.id
item["unrealized_pl"] = float(item["unrealized_pl"])
items.append(item)
return items

36
core/exchanges/oanda.py Normal file
View File

@ -0,0 +1,36 @@
from core.exchanges import BaseExchange
from oandapyV20 import API
from oandapyV20.endpoints import trades
from oandapyV20.endpoints import positions
class OANDAExchange(BaseExchange):
def connect(self):
self.client = API(access_token=self.account.api_secret)
self.account_id = self.account.api_key
def get_supported_assets(self):
raise NotImplementedError
def get_balance(self):
raise NotImplementedError
def get_market_value(self, symbol):
raise NotImplementedError
def post_trade(self, trade):
raise NotImplementedError
def get_trade(self, trade_id):
raise NotImplementedError
def update_trade(self, trade):
raise NotImplementedError
def cancel_trade(self, trade_id):
raise NotImplementedError
def get_position_info(self, asset_id):
raise NotImplementedError
def get_all_positions(self):
pass

View File

@ -1,16 +1,16 @@
import stripe import stripe
from alpaca.common.exceptions import APIError
from alpaca.trading.client import TradingClient
from alpaca.trading.requests import GetAssetsRequest
from django.conf import settings from django.conf import settings
from django.contrib.auth.models import AbstractUser from django.contrib.auth.models import AbstractUser
from django.db import models from django.db import models
from core.exchanges.alpaca import AlpacaExchange
from core.exchanges.oanda import OANDAExchange
from core.lib import trades from core.lib import trades
from core.lib.customers import get_or_create, update_customer_fields from core.lib.customers import get_or_create, update_customer_fields
from core.util import logs from core.util import logs
log = logs.get_logger(__name__) log = logs.get_logger(__name__)
EXCHANGE_MAP = {"alpaca": AlpacaExchange, "oanda": OANDAExchange}
class Plan(models.Model): class Plan(models.Model):
@ -70,7 +70,7 @@ class User(AbstractUser):
class Account(models.Model): class Account(models.Model):
EXCHANGE_CHOICES = (("alpaca", "Alpaca"),) EXCHANGE_CHOICES = (("alpaca", "Alpaca"), ("oanda", "OANDA"))
user = models.ForeignKey(User, on_delete=models.CASCADE) user = models.ForeignKey(User, on_delete=models.CASCADE)
name = models.CharField(max_length=255) name = models.CharField(max_length=255)
exchange = models.CharField(choices=EXCHANGE_CHOICES, max_length=255) exchange = models.CharField(choices=EXCHANGE_CHOICES, max_length=255)
@ -89,23 +89,18 @@ class Account(models.Model):
""" """
Override the save function to update supported symbols. Override the save function to update supported symbols.
""" """
try: client = self.get_client()
request = GetAssetsRequest(status="active", asset_class="crypto") if client:
assets = self.client.get_all_assets(filter=request) supported_symbols = client.get_supported_assets()
asset_list = [x["symbol"] for x in assets if "symbol" in x] if supported_symbols:
self.supported_symbols = asset_list self.supported_symbols = supported_symbols
print("Supported symbols", self.supported_symbols)
except APIError as e:
log.error(f"Could not get asset list: {e}")
# return False
super().save(*args, **kwargs) super().save(*args, **kwargs)
def get_client(self): def get_client(self):
trading_client = TradingClient( if self.exchange in EXCHANGE_MAP:
self.api_key, self.api_secret, paper=self.sandbox, raw_data=True return EXCHANGE_MAP[self.exchange](self)
) else:
return trading_client raise Exception("Exchange not supported")
@property @property
def client(self): def client(self):
@ -114,6 +109,13 @@ class Account(models.Model):
""" """
return self.get_client() return self.get_client()
@property
def rawclient(self):
"""
Convenience property for one-off API calls.
"""
return self.get_client().client
@classmethod @classmethod
def get_by_id(cls, account_id, user): def get_by_id(cls, account_id, user):
return cls.objects.get(id=account_id, user=user) return cls.objects.get(id=account_id, user=user)
@ -174,7 +176,7 @@ class Trade(models.Model):
self._original = self self._original = self
def post(self): def post(self):
return trades.post_trade(self) return self.account.client.post_trade(self)
def delete(self, *args, **kwargs): def delete(self, *args, **kwargs):
# close the trade # close the trade

View File

@ -5,9 +5,9 @@
{% if page_subtitle is not None %} {% if page_subtitle is not None %}
<h1 class="subtitle">{{ page_subtitle }}</h1> <h1 class="subtitle">{{ page_subtitle }}</h1>
{% endif %} {% endif %}
<div class="buttons">
{% if submit_url is not None %} {% if submit_url is not None %}
<div class="buttons">
<button <button
hx-headers='{"X-CSRFToken": "{{ csrf_token }}"}' hx-headers='{"X-CSRFToken": "{{ csrf_token }}"}'
hx-get="{{ submit_url }}" hx-get="{{ submit_url }}"
@ -21,8 +21,25 @@
<span>{{ title_singular }}</span> <span>{{ title_singular }}</span>
</span> </span>
</button> </button>
</div> {% endif %}
{% endif %} {% if delete_all_url is not None %}
<button
hx-headers='{"X-CSRFToken": "{{ csrf_token }}"}'
hx-delete="{{ delete_all_url }}"
hx-trigger="click"
hx-target="#modals-here"
hx-swap="innerHTML"
hx-confirm="Are you sure you wish to delete all {{ context_object_name }}?"
class="button is-info">
<span class="icon-text">
<span class="icon">
<i class="fa-solid fa-plus"></i>
</span>
<span>Delete all {{ context_object_name }} </span>
</span>
</button>
{% endif %}
</div>
{% include list_template %} {% include list_template %}

View File

@ -32,6 +32,8 @@ class ObjectList(ListView):
submit_url_name = None submit_url_name = None
delete_all_url_name = None
# copied from BaseListView # copied from BaseListView
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
self.object_list = self.get_queryset() self.object_list = self.get_queryset()
@ -80,6 +82,8 @@ class ObjectList(ListView):
context["list_url"] = list_url context["list_url"] = list_url
context["context_object_name"] = self.context_object_name context["context_object_name"] = self.context_object_name
context["context_object_name_singular"] = self.context_object_name_singular context["context_object_name_singular"] = self.context_object_name_singular
if self.delete_all_url_name:
context["delete_all_url"] = reverse(self.delete_all_url_name)
# Return partials for HTMX # Return partials for HTMX
if self.request.htmx: if self.request.htmx:
@ -102,6 +106,9 @@ class ObjectCreate(CreateView):
request = None request = None
def post_save(self, obj):
pass
def form_valid(self, form): def form_valid(self, form):
obj = form.save(commit=False) obj = form.save(commit=False)
if self.request is None: if self.request is None:
@ -109,6 +116,7 @@ class ObjectCreate(CreateView):
obj.user = self.request.user obj.user = self.request.user
obj.save() obj.save()
form.save_m2m() form.save_m2m()
self.post_save(obj)
context = {"message": "Object created", "class": "success"} context = {"message": "Object created", "class": "success"}
response = self.render_to_response(context) response = self.render_to_response(context)
response["HX-Trigger"] = f"{self.context_object_name_singular}Event" response["HX-Trigger"] = f"{self.context_object_name_singular}Event"
@ -170,12 +178,16 @@ class ObjectUpdate(UpdateView):
request = None request = None
def post_save(self, obj):
pass
def form_valid(self, form): def form_valid(self, form):
obj = form.save(commit=False) obj = form.save(commit=False)
if self.request is None: if self.request is None:
raise Exception("Request is None") raise Exception("Request is None")
obj.save() obj.save()
form.save_m2m() form.save_m2m()
self.post_save(obj)
context = {"message": "Object updated", "class": "success"} context = {"message": "Object updated", "class": "success"}
response = self.render_to_response(context) response = self.render_to_response(context)
response["HX-Trigger"] = f"{self.context_object_name_singular}Event" response["HX-Trigger"] = f"{self.context_object_name_singular}Event"

View File

@ -39,7 +39,7 @@ class AccountInfo(LoginRequiredMixin, View):
} }
return render(request, template_name, context) return render(request, template_name, context)
live_info = dict(account.client.get_account()) live_info = dict(account.rawclient.get_account())
account_info = account.__dict__ account_info = account.__dict__
account_info = { account_info = {
k: v for k, v in account_info.items() if k in self.VIEWABLE_FIELDS_MODEL k: v for k, v in account_info.items() if k in self.VIEWABLE_FIELDS_MODEL

View File

@ -20,17 +20,7 @@ def get_positions(user, account_id=None):
positions = account.client.get_all_positions() positions = account.client.get_all_positions()
for item in positions: for item in positions:
item = dict(item)
item["account_id"] = account.id
item["unrealized_pl"] = float(item["unrealized_pl"])
items.append(item) items.append(item)
# try:
# parsed = ccxt_s.CCXTRoot.from_dict(order)
# except ValidationError as e:
# log.error(f"Error creating trade: {e}")
# return False
# self.status = parsed.status
# self.response = order
return items return items
@ -77,7 +67,7 @@ class PositionAction(LoginRequiredMixin, View):
unique = str(uuid.uuid4())[:8] unique = str(uuid.uuid4())[:8]
account = Account.get_by_id(account_id, request.user) account = Account.get_by_id(account_id, request.user)
success, info = trades.get_position_info(account, asset_id) success, info = account.client.get_position_info(asset_id)
if not success: if not success:
message = "Position does not exist" message = "Position does not exist"
message_class = "danger" message_class = "danger"

View File

@ -1,4 +1,6 @@
from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.auth.mixins import LoginRequiredMixin
from django.shortcuts import render
from django.views import View
from core.forms import TradeForm from core.forms import TradeForm
from core.models import Trade from core.models import Trade
@ -25,6 +27,8 @@ class TradeList(LoginRequiredMixin, ObjectList):
submit_url_name = "trade_create" submit_url_name = "trade_create"
delete_all_url_name = "trade_delete_all"
class TradeCreate(LoginRequiredMixin, ObjectCreate): class TradeCreate(LoginRequiredMixin, ObjectCreate):
model = Trade model = Trade
@ -37,6 +41,10 @@ class TradeCreate(LoginRequiredMixin, ObjectCreate):
submit_url_name = "trade_create" submit_url_name = "trade_create"
def post_save(self, obj):
obj.post()
log.debug(f"Posting trade {obj}")
class TradeUpdate(LoginRequiredMixin, ObjectUpdate): class TradeUpdate(LoginRequiredMixin, ObjectUpdate):
model = Trade model = Trade
@ -57,3 +65,17 @@ class TradeDelete(LoginRequiredMixin, ObjectDelete):
list_url_name = "trades" list_url_name = "trades"
list_url_args = ["type"] list_url_args = ["type"]
class TradeDeleteAll(LoginRequiredMixin, View):
template_name = "partials/notify.html"
def delete(self, request):
"""
Delete all trades by the current user
"""
Trade.objects.filter(user=request.user).delete()
context = {"message": "All trades deleted", "class": "success"}
response = render(request, self.template_name, context)
response["HX-Trigger"] = f"{self.context_object_name_singular}Event"
return response

View File

@ -17,3 +17,4 @@ django-otp
qrcode qrcode
serde[ext] serde[ext]
alpaca-py alpaca-py
oandapyV20