You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

545 lines
19 KiB
Python

import uuid
from datetime import timedelta
import stripe
from django.conf import settings
from django.contrib.auth.models import AbstractUser
from django.db import models
from core.exchanges.alpaca import AlpacaExchange
from core.exchanges.fake import FakeExchange
from core.exchanges.oanda import OANDAExchange
# from core.lib.customers import get_or_create, update_customer_fields
from core.lib import billing
from core.util import logs
log = logs.get_logger(__name__)

# Exchange identifier (the value stored in ``Account.exchange``) -> client class.
EXCHANGE_MAP = {
    "alpaca": AlpacaExchange,
    "oanda": OANDAExchange,
    "fake": FakeExchange,
}
# Order types.
TYPE_CHOICES = (
    ("market", "Market"),
    ("limit", "Limit"),
)

# Trade / signal direction.
DIRECTION_CHOICES = (
    ("buy", "Buy"),
    ("sell", "Sell"),
)

# Order time-in-force options.
TIF_CHOICES = (
    ("gtc", "GTC (Good Til Cancelled)"),
    ("gfd", "GFD (Good For Day)"),
    ("fok", "FOK (Fill Or Kill)"),
    ("ioc", "IOC (Immediate Or Cancel)"),
)

# Days are numbered 1 (Monday) .. 7 (Sunday), i.e. ``datetime.weekday() + 1``.
DAY_CHOICES = tuple(
    enumerate(
        (
            "Monday",
            "Tuesday",
            "Wednesday",
            "Thursday",
            "Friday",
            "Saturday",
            "Sunday",
        ),
        start=1,
    )
)

# Role a signal plays within a strategy.
SIGNAL_TYPE_CHOICES = (
    ("entry", "Entry"),
    ("exit", "Exit"),
    ("trend", "Trend"),
)

# How asset rule values are aggregated.
AGGREGATION_CHOICES = (
    ("none", "None"),
    ("avg_sentiment", "Average sentiment"),
)

# Evaluated status codes, numbered consecutively from 0.
STATUS_CHOICES = tuple(
    enumerate(
        (
            "No data",
            "No match",
            "Bullish",
            "Bearish",
            "No aggregation",
            "Not in bounds",
            "Always allow",
            "Always deny",
        )
    )
)

# Subset of STATUS_CHOICES that a status may be mapped to, in display order.
MAPPING_CHOICES = tuple(
    (code, dict(STATUS_CHOICES)[code]) for code in (6, 7, 2, 3)
)

# Base actions for active-management violations; the variants below extend it.
CLOSE_NOTIFY_CHOICES = (
    ("none", "None"),
    ("close", "Close violating trades"),
    ("notify", "Notify only"),
)

ADJUST_CLOSE_NOTIFY_CHOICES = CLOSE_NOTIFY_CHOICES + (
    ("adjust", "Adjust violating trades"),
)

ADJUST_WITH_DIRECTION_CHOICES = CLOSE_NOTIFY_CHOICES + (
    ("adjust", "Increase and reduce"),
    ("adjust_up", "Increase only"),
    ("adjust_down", "Reduce only"),
)
# class Plan(models.Model):
# name = models.CharField(max_length=255, unique=True)
# description = models.CharField(max_length=1024, null=True, blank=True)
# cost = models.IntegerField()
# product_id = models.CharField(max_length=255, unique=True, null=True, blank=True)
# image = models.CharField(max_length=1024, null=True, blank=True)
# def __str__(self):
# return f"{self.name} (£{self.cost})"
class User(AbstractUser):
    """Custom user model with optional Stripe / billing-provider customer records."""

    # Stripe customer ID
    stripe_id = models.CharField(max_length=255, blank=True, null=True)
    # Per-user UUID; save() fills it in again if it is ever None.
    customer_id = models.UUIDField(default=uuid.uuid4, blank=True, null=True)
    # Customer ID in the billing provider managed by ``core.lib.billing``.
    billing_provider_id = models.CharField(max_length=255, blank=True, null=True)
    # last_payment = models.DateTimeField(null=True, blank=True)
    # plans = models.ManyToManyField(Plan, blank=True)
    email = models.EmailField(unique=True)

    def delete(self, *args, **kwargs):
        """
        Delete the user after removing its external customer records.

        When ``settings.BILLING_ENABLED`` is true, the Stripe customer and the
        billing-provider customer are deleted (each only if its ID is set)
        before the database row is removed.
        """
        billing_enabled = settings.BILLING_ENABLED
        if billing_enabled and self.stripe_id:
            stripe.Customer.delete(self.stripe_id)
            log.info(f"Deleted Stripe customer {self.stripe_id}")
        if billing_enabled and self.billing_provider_id:
            billing.delete_customer(self)
            log.info(f"Deleted Billing customer {self.billing_provider_id}")
        super().delete(*args, **kwargs)

    def save(self, *args, **kwargs):
        """
        Save the user, syncing its customer records when billing is enabled.

        Always ensures ``customer_id`` is populated. With
        ``settings.BILLING_ENABLED``, missing Stripe / billing-provider IDs are
        created first, then the current fields are pushed to the billing
        provider (Lago) before the row is written.
        """
        if self.customer_id is None:
            self.customer_id = uuid.uuid4()
        if settings.BILLING_ENABLED:
            # Only reach out to the provider when no ID is stored yet.
            self.stripe_id = self.stripe_id or billing.get_or_create(
                self.email, self.first_name, self.last_name
            )
            self.billing_provider_id = (
                self.billing_provider_id or billing.create_or_update_customer(self)
            )
            billing.update_customer_fields(self)
        super().save(*args, **kwargs)

    def get_notification_settings(self):
        """Return this user's NotificationSettings, creating the row if needed."""
        notification_settings, _created = NotificationSettings.objects.get_or_create(
            user=self
        )
        return notification_settings
class Account(models.Model):
    """
    An exchange/broker account belonging to a user.

    Supported symbols, raw instrument data and the account currency are cached
    on the row and refreshed from the exchange whenever the account is saved
    (except for the "fake" exchange).
    """

    EXCHANGE_CHOICES = (("alpaca", "Alpaca"), ("oanda", "OANDA"), ("fake", "Fake"))
    user = models.ForeignKey(User, on_delete=models.CASCADE)
    name = models.CharField(max_length=255)
    exchange = models.CharField(choices=EXCHANGE_CHOICES, max_length=255)
    api_key = models.CharField(max_length=255)
    api_secret = models.CharField(max_length=255)
    sandbox = models.BooleanField(default=False)
    enabled = models.BooleanField(default=True)
    supported_symbols = models.JSONField(default=list)
    instruments = models.JSONField(default=list)
    currency = models.CharField(max_length=255, null=True, blank=True)
    initial_balance = models.FloatField(default=0)

    def __str__(self):
        name = f"{self.name} ({self.exchange})"
        if self.sandbox:
            name += " (sandbox)"
        return name

    def update_info(self, save=True):
        """
        Refresh supported symbols, instruments and currency from the exchange.

        :param save: when true, persist the refreshed fields afterwards
        :type save: bool
        :raises Exception: if the account's exchange is not supported
        """
        client = self.get_client()
        if client:
            response = client.get_instruments()
            supported_symbols = client.get_supported_assets(response)
            currency = client.get_account()["currency"]
            log.debug(f"Supported symbols for {self.name}: {supported_symbols}")
            self.supported_symbols = supported_symbols
            self.instruments = response
            self.currency = currency
            if save:
                # Go straight to Model.save(): Account.save() would call
                # update_info() again and repeat every exchange request above.
                super().save()

    def save(self, *args, **kwargs):
        """
        Override the save function to update supported symbols.

        Accounts on the "fake" exchange are saved without contacting it.
        """
        if self.exchange != "fake":
            self.update_info(save=False)
        super().save(*args, **kwargs)

    def get_client(self):
        """
        Instantiate the exchange client class registered for ``self.exchange``.

        :raises Exception: if the exchange is not in ``EXCHANGE_MAP``
        """
        if self.exchange in EXCHANGE_MAP:
            return EXCHANGE_MAP[self.exchange](self)
        else:
            raise Exception(f"Exchange not supported : {self.exchange}")

    @property
    def client(self):
        """
        Convenience property for one-off API calls (a new exchange wrapper each access).
        """
        return self.get_client()

    @property
    def rawclient(self):
        """
        Convenience property for one-off API calls: the wrapper's ``.client`` object.
        """
        return self.get_client().client

    @classmethod
    def get_by_id(cls, account_id, user):
        """Fetch an account by ID, restricted to accounts owned by ``user``."""
        return cls.objects.get(id=account_id, user=user)

    @classmethod
    def get_by_id_no_user_check(cls, account_id):
        """Fetch an account by ID without checking ownership."""
        return cls.objects.get(id=account_id)
class Hook(models.Model):
    """An inbound webhook endpoint owned by a user."""

    user = models.ForeignKey(User, on_delete=models.CASCADE)
    name = models.CharField(max_length=1024)
    # Hook URL; unique across all users.
    hook = models.CharField(unique=True, max_length=255)
    # Receive counter. NOTE(review): presumably incremented by the webhook
    # handler, which is not shown here.
    received = models.IntegerField(default=0)

    def __str__(self):
        return f"{self.name} ({self.hook})"
class Signal(models.Model):
    """A named signal delivered through a ``Hook``, with a direction and a role."""

    user = models.ForeignKey(User, on_delete=models.CASCADE)
    name = models.CharField(max_length=1024)
    # Signal name as it appears in incoming payloads.
    signal = models.CharField(max_length=256)
    hook = models.ForeignKey(Hook, on_delete=models.CASCADE)
    direction = models.CharField(max_length=255, choices=DIRECTION_CHOICES)
    received = models.IntegerField(default=0)
    # Role of the signal: entry / exit / trend.
    type = models.CharField(max_length=255, choices=SIGNAL_TYPE_CHOICES)

    def __str__(self):
        return f"{self.name} ({self.hook.name}) - {self.direction}"
class Trade(models.Model):
    """An order placed, or to be placed, on one of the user's exchange accounts."""

    user = models.ForeignKey(User, on_delete=models.CASCADE)
    account = models.ForeignKey(Account, on_delete=models.CASCADE)
    hook = models.ForeignKey(Hook, on_delete=models.CASCADE, null=True, blank=True)
    signal = models.ForeignKey(Signal, on_delete=models.CASCADE, null=True, blank=True)
    symbol = models.CharField(max_length=255)
    time_in_force = models.CharField(choices=TIF_CHOICES, max_length=255, default="gtc")
    type = models.CharField(choices=TYPE_CHOICES, max_length=255)
    amount = models.FloatField(null=True, blank=True)
    amount_usd = models.FloatField(null=True, blank=True)
    price = models.FloatField(null=True, blank=True)
    stop_loss = models.FloatField(null=True, blank=True)
    trailing_stop_loss = models.FloatField(null=True, blank=True)
    take_profit = models.FloatField(null=True, blank=True)
    status = models.CharField(max_length=255, null=True, blank=True)
    information = models.JSONField(null=True, blank=True)
    direction = models.CharField(choices=DIRECTION_CHOICES, max_length=255)
    # To populate from the trade
    order_id = models.CharField(max_length=255, null=True, blank=True)
    client_order_id = models.CharField(max_length=255, null=True, blank=True)
    response = models.JSONField(null=True, blank=True)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): this aliases the instance itself rather than taking a
        # snapshot of its field values, so ``_original`` always reflects the
        # current state. Confirm what reads it before relying on it for change
        # detection.
        self._original = self

    def post(self):
        """
        Submit the trade to its account's exchange.

        Trades whose status is "rejected" or "close" are not sent; they are
        logged and None is returned.

        :return: whatever the exchange client's ``post_trade`` returns, or None
        """
        if self.status in ["rejected", "close"]:
            log.debug(f"Trade {self.id} rejected. Not posting.")
            log.debug(f"Trade {self.id} information: {self.information}")
        else:
            return self.account.client.post_trade(self)

    def delete(self, *args, **kwargs):
        # NOTE(review): only the database row is deleted. The former
        # "close the trade" comment suggests closing it on the exchange was
        # intended, but nothing here does that.
        super().delete(*args, **kwargs)

    @classmethod
    def get_by_id(cls, trade_id, user):
        """Fetch a trade by primary key, restricted to trades owned by ``user``."""
        return cls.objects.get(id=trade_id, user=user)

    @classmethod
    def get_by_id_or_order(cls, trade_id, account_id, user):
        """
        Find a trade by primary key or, failing that, by exchange order ID.

        :param trade_id: either the trade's primary key or its ``order_id``
        :param account_id: ID of the account the trade must belong to
        :param user: owner of both the account and the trade
        :return: the matching Trade, or None if the account or trade is not found
        """
        try:
            account = Account.objects.get(id=account_id, user=user)
        except Account.DoesNotExist:
            return None
        try:
            return cls.objects.get(id=trade_id, account=account, user=user)
        except (cls.DoesNotExist, ValueError):
            # Django raises ValueError (not DoesNotExist) when trade_id cannot
            # be converted to the integer primary key, e.g. a UUID-style
            # exchange order ID; fall through to the order_id lookup instead.
            pass
        try:
            return cls.objects.get(order_id=trade_id, account=account, user=user)
        except cls.DoesNotExist:
            return None
class Callback(models.Model):
    """
    A callback record tied to a ``Hook`` and ``Signal``.

    NOTE(review): the fields look like the contents of an incoming webhook
    payload (message text plus market details); confirm against the code that
    creates these rows.
    """

    hook = models.ForeignKey(Hook, on_delete=models.CASCADE)
    signal = models.ForeignKey(Signal, on_delete=models.CASCADE)
    title = models.CharField(max_length=1024, blank=True, null=True)
    message = models.CharField(max_length=1024, blank=True, null=True)
    period = models.CharField(max_length=255, blank=True, null=True)
    # NOTE(review): integer fields; meaning (timestamp? ID?) not visible here.
    sent = models.BigIntegerField(blank=True, null=True)
    trade = models.BigIntegerField(blank=True, null=True)
    exchange = models.CharField(max_length=255, blank=True, null=True)
    base = models.CharField(max_length=255, blank=True, null=True)
    quote = models.CharField(max_length=255, blank=True, null=True)
    contract = models.CharField(max_length=255, blank=True, null=True)
    price = models.FloatField(blank=True, null=True)
    symbol = models.CharField(max_length=255)
class TradingTime(models.Model):
    """
    A recurring weekly trading window owned by a user.

    The window runs from ``start_day`` at ``start_time`` to ``end_day`` at
    ``end_time``; days use DAY_CHOICES numbering (1 = Monday .. 7 = Sunday).
    """

    user = models.ForeignKey(User, on_delete=models.CASCADE)
    name = models.CharField(max_length=255)
    description = models.TextField(null=True, blank=True)
    start_day = models.IntegerField(choices=DAY_CHOICES)
    end_day = models.IntegerField(choices=DAY_CHOICES)
    start_time = models.TimeField()
    end_time = models.TimeField()

    @staticmethod
    def _week_offset(day, clock):
        """
        Return the distance from Monday 00:00 to ``day`` at ``clock``'s time of day.

        :param day: day of week, 1 (Monday) to 7 (Sunday)
        :param clock: object with hour/minute/second/microsecond (time or datetime)
        :rtype: timedelta
        """
        return timedelta(
            days=day - 1,
            hours=clock.hour,
            minutes=clock.minute,
            seconds=clock.second,
            microseconds=clock.microsecond,
        )

    def within_range(self, ts):
        """
        Check if the specified time is within the configured trading times.

        Both ends of the window are inclusive. When the start point falls later
        in the week than the end point (e.g. Sunday 22:00 to Friday 22:00), the
        window wraps past Sunday midnight into the following week.

        Only the weekday and wall-clock time of ``ts`` are used; its tzinfo is
        not consulted. NOTE(review): confirm callers pass ``ts`` in the timezone
        the window was configured in.

        :param ts: Timestamp
        :type ts: datetime
        :return: whether or not the time is within the trading range
        :rtype: bool
        """
        # datetime.weekday() is 0-based (Monday=0); DAY_CHOICES is 1-based.
        now = self._week_offset(ts.weekday() + 1, ts)
        start = self._week_offset(self.start_day, self.start_time)
        end = self._week_offset(self.end_day, self.end_time)
        if start <= end:
            return start <= now <= end
        # Wrapping window: inside if at/after the start or at/before the end.
        return now >= start or now <= end

    def __str__(self):
        return (
            f"{self.name} ({self.get_start_day_display()} at {self.start_time} - "
            f"{self.get_end_day_display()} at {self.end_time})"
        )
class Strategy(models.Model):
    """A user's trading strategy: its signals, schedule, account and settings."""

    user = models.ForeignKey(User, on_delete=models.CASCADE)
    name = models.CharField(max_length=255)
    description = models.TextField(blank=True, null=True)
    account = models.ForeignKey(Account, on_delete=models.CASCADE)
    trading_times = models.ManyToManyField(TradingTime)

    # Signals are grouped by role; distinct related_name values keep the
    # reverse accessors on Signal from clashing.
    entry_signals = models.ManyToManyField(
        Signal, blank=True, related_name="entry_strategies"
    )
    exit_signals = models.ManyToManyField(
        Signal, blank=True, related_name="exit_strategies"
    )
    trend_signals = models.ManyToManyField(
        Signal, blank=True, related_name="trend_strategies"
    )

    enabled = models.BooleanField(default=False)
    signal_trading_enabled = models.BooleanField(default=False)
    active_management_enabled = models.BooleanField(default=False)
    trends = models.JSONField(blank=True, null=True)

    # Referenced by "app.Model" string; PROTECT blocks deleting settings that
    # a strategy still uses.
    asset_group = models.ForeignKey(
        "core.AssetGroup", blank=True, null=True, on_delete=models.PROTECT
    )
    risk_model = models.ForeignKey(
        "core.RiskModel", blank=True, null=True, on_delete=models.PROTECT
    )
    order_settings = models.ForeignKey(
        "core.OrderSettings", on_delete=models.PROTECT
    )
    active_management_policy = models.ForeignKey(
        "core.ActiveManagementPolicy",
        blank=True,
        null=True,
        on_delete=models.PROTECT,
    )

    class Meta:
        verbose_name_plural = "strategies"

    def __str__(self):
        return self.name
class NotificationSettings(models.Model):
    """Per-user notification configuration (one row per user)."""

    user = models.OneToOneField(User, on_delete=models.CASCADE)
    # ntfy topic name and server URL; both optional.
    ntfy_topic = models.CharField(max_length=255, blank=True, null=True)
    ntfy_url = models.CharField(max_length=255, blank=True, null=True)

    def __str__(self):
        return f"Notification settings for {self.user}"
class RiskModel(models.Model):
    """Risk limits: loss/exposure caps, open-trade counts and price tolerances."""

    user = models.ForeignKey(User, on_delete=models.CASCADE)
    name = models.CharField(max_length=255)
    description = models.TextField(blank=True, null=True)
    # NOTE(review): the *_percent defaults mix scales (0.05 here vs 2.5 / 0.5
    # below); confirm whether each is a fraction or a percentage.
    # Maximum amount of money to have lost from the initial balance to stop trading
    max_loss_percent = models.FloatField(default=0.05)
    # Maximum amount of money to risk on all open trades
    max_risk_percent = models.FloatField(default=0.05)
    # Maximum number of open trades overall
    max_open_trades = models.IntegerField(default=10)
    # Maximum number of open trades per symbol
    max_open_trades_per_symbol = models.IntegerField(default=2)
    price_slippage_percent = models.FloatField(default=2.5)
    callback_price_deviation_percent = models.FloatField(default=0.5)

    def __str__(self):
        return self.name
class AssetGroup(models.Model):
    """
    A named group of AssetRules plus the action taken for each rule status.

    Each ``when_*`` field maps one evaluated status to a MAPPING_CHOICES value
    (allow / deny / bullish / bearish).
    """

    user = models.ForeignKey(User, on_delete=models.CASCADE)
    name = models.CharField(max_length=255)
    description = models.TextField(blank=True, null=True)
    webhook_id = models.UUIDField(unique=True, editable=False, default=uuid.uuid4)
    # Defaults: deny on missing data (7), allow on no match / no aggregation /
    # out of bounds (6), and pass bullish (2) / bearish (3) through unchanged.
    when_no_data = models.IntegerField(default=7, choices=MAPPING_CHOICES)
    when_no_match = models.IntegerField(default=6, choices=MAPPING_CHOICES)
    when_no_aggregation = models.IntegerField(default=6, choices=MAPPING_CHOICES)
    when_not_in_bounds = models.IntegerField(default=6, choices=MAPPING_CHOICES)
    when_bullish = models.IntegerField(default=2, choices=MAPPING_CHOICES)
    when_bearish = models.IntegerField(default=3, choices=MAPPING_CHOICES)

    def __str__(self):
        return self.name

    @property
    def matches(self):
        """
        Summarise this group's rules as ``"<bullish count>/<total count>"``.
        """
        rules = AssetRule.objects.filter(group=self)
        total = rules.count()
        bullish = rules.filter(status=2).count()  # 2 == "Bullish"
        return f"{bullish}/{total}"
class AssetRule(models.Model):
    """A single asset's rule inside an AssetGroup, with its evaluated status."""

    user = models.ForeignKey(User, on_delete=models.CASCADE)
    asset = models.CharField(max_length=64)
    group = models.ForeignKey(AssetGroup, on_delete=models.CASCADE)
    aggregation = models.CharField(
        max_length=255, default="none", choices=AGGREGATION_CHOICES
    )
    value = models.FloatField(blank=True, null=True)
    # Two status columns; both use STATUS_CHOICES and start at 0 ("No data").
    original_status = models.IntegerField(default=0, choices=STATUS_CHOICES)
    status = models.IntegerField(default=0, choices=STATUS_CHOICES)
    # Optional bounds on ``value``.
    trigger_below = models.FloatField(blank=True, null=True)
    trigger_above = models.FloatField(blank=True, null=True)

    class Meta:
        # Ensure that the asset is unique per group
        unique_together = ("asset", "group")
class OrderSettings(models.Model):
    """Reusable order parameters: type, time in force, protection and sizing."""

    user = models.ForeignKey(User, on_delete=models.CASCADE)
    name = models.CharField(max_length=255)
    description = models.TextField(blank=True, null=True)
    order_type = models.CharField(
        max_length=255, default="market", choices=TYPE_CHOICES
    )
    time_in_force = models.CharField(max_length=255, default="gtc", choices=TIF_CHOICES)
    take_profit_percent = models.FloatField(default=1.5)
    stop_loss_percent = models.FloatField(default=1.0)
    # The only optional protection setting (nullable).
    trailing_stop_loss_percent = models.FloatField(blank=True, null=True, default=1.0)
    trade_size_percent = models.FloatField(default=0.5)

    def __str__(self):
        return self.name
class ActiveManagementPolicy(models.Model):
    """
    What to do when an open trade violates each rule.

    Every ``when_*`` field defaults to "none". Most fields choose none / close
    / notify; position size and protection violations may also be "adjust"ed.
    """

    user = models.ForeignKey(User, on_delete=models.CASCADE)
    name = models.CharField(max_length=255)
    description = models.TextField(blank=True, null=True)
    when_trading_time_violated = models.CharField(
        max_length=255, default="none", choices=CLOSE_NOTIFY_CHOICES
    )
    when_trends_violated = models.CharField(
        max_length=255, default="none", choices=CLOSE_NOTIFY_CHOICES
    )
    # These two also allow adjusting the violating trade.
    when_position_size_violated = models.CharField(
        max_length=255, default="none", choices=ADJUST_CLOSE_NOTIFY_CHOICES
    )
    when_protection_violated = models.CharField(
        max_length=255, default="none", choices=ADJUST_CLOSE_NOTIFY_CHOICES
    )
    when_asset_groups_violated = models.CharField(
        max_length=255, default="none", choices=CLOSE_NOTIFY_CHOICES
    )
    when_max_open_trades_violated = models.CharField(
        max_length=255, default="none", choices=CLOSE_NOTIFY_CHOICES
    )
    when_max_open_trades_per_symbol_violated = models.CharField(
        max_length=255, default="none", choices=CLOSE_NOTIFY_CHOICES
    )
    when_max_loss_violated = models.CharField(
        max_length=255, default="none", choices=CLOSE_NOTIFY_CHOICES
    )
    when_max_risk_violated = models.CharField(
        max_length=255, default="none", choices=CLOSE_NOTIFY_CHOICES
    )
    when_crossfilter_violated = models.CharField(
        max_length=255, default="none", choices=CLOSE_NOTIFY_CHOICES
    )

    def __str__(self):
        return self.name