Add more hooks to active management

master
Mark Veidemanis 1 year ago
parent dd3b3521d9
commit 1dbb3fcf79
Signed by: m
GPG Key ID: 5ACFCEED46C0904F

@ -30,6 +30,7 @@ from core.views import (
limits, limits,
notifications, notifications,
ordersettings, ordersettings,
policies,
positions, positions,
profit, profit,
risk, risk,
@ -306,4 +307,25 @@ urlpatterns = [
ordersettings.OrderSettingsDelete.as_view(), ordersettings.OrderSettingsDelete.as_view(),
name="ordersettings_delete", name="ordersettings_delete",
), ),
# Active Management Policies
path(
"ams/<str:type>/",
policies.ActiveManagementPolicyList.as_view(),
name="ams",
),
path(
"ams/<str:type>/create/",
policies.ActiveManagementPolicyCreate.as_view(),
name="ams_create",
),
path(
"ams/<str:type>/update/<str:pk>/",
policies.ActiveManagementPolicyUpdate.as_view(),
name="ams_update",
),
path(
"ams/<str:type>/delete/<str:pk>/",
policies.ActiveManagementPolicyDelete.as_view(),
name="ams_delete",
),
] + static(settings.STATIC_URL, document_root=settings.STATIC_ROOT) ] + static(settings.STATIC_URL, document_root=settings.STATIC_ROOT)

@ -33,11 +33,12 @@ def get_pair(account, base, quote, invert=False):
:param invert: Invert the pair :param invert: Invert the pair
:return: currency symbol, e.g. BTC_USD, BTC/USD, etc. :return: currency symbol, e.g. BTC_USD, BTC/USD, etc.
""" """
# Currently we only have two exchanges with different pair separators
if account.exchange == "alpaca": if account.exchange == "alpaca":
separator = "/" separator = "/"
elif account.exchange == "oanda": elif account.exchange == "oanda":
separator = "_" separator = "_"
else:
separator = "_"
# Flip the pair if needed # Flip the pair if needed
if invert: if invert:
@ -50,6 +51,16 @@ def get_pair(account, base, quote, invert=False):
return symbol return symbol
def get_symbol_price(account, price_index, symbol):
try:
prices = account.client.get_currencies([symbol])
except GenericAPIError as e:
log.error(f"Error getting currencies and inverted currencies: {e}")
return None
price = D(prices["prices"][0][price_index][0]["price"])
return price
def to_currency(direction, account, amount, from_currency, to_currency): def to_currency(direction, account, amount, from_currency, to_currency):
""" """
Convert an amount from one currency to another. Convert an amount from one currency to another.
@ -79,12 +90,7 @@ def to_currency(direction, account, amount, from_currency, to_currency):
if not symbol: if not symbol:
log.error(f"Could not find symbol for {from_currency} -> {to_currency}") log.error(f"Could not find symbol for {from_currency} -> {to_currency}")
raise Exception("Could not find symbol") raise Exception("Could not find symbol")
try: price = get_symbol_price(account, price_index, symbol)
prices = account.client.get_currencies([symbol])
except GenericAPIError as e:
log.error(f"Error getting currencies and inverted currencies: {e}")
return None
price = D(prices["prices"][0][price_index][0]["price"])
# If we had to flip base and quote, we need to use the reciprocal of the price # If we had to flip base and quote, we need to use the reciprocal of the price
if inverted: if inverted:

@ -6,6 +6,7 @@ from mixins.restrictions import RestrictedFormMixin
from .models import ( # AssetRestriction, from .models import ( # AssetRestriction,
Account, Account,
ActiveManagementPolicy,
AssetGroup, AssetGroup,
AssetRule, AssetRule,
Hook, Hook,
@ -132,6 +133,7 @@ class StrategyForm(RestrictedFormMixin, ModelForm):
"trend_signals", "trend_signals",
"signal_trading_enabled", "signal_trading_enabled",
"active_management_enabled", "active_management_enabled",
"active_management_policy",
"enabled", "enabled",
) )
@ -148,6 +150,7 @@ class StrategyForm(RestrictedFormMixin, ModelForm):
"trend_signals": "Callbacks received to these signals will limit the trading direction of the given symbol to the callback direction until further notice.", "trend_signals": "Callbacks received to these signals will limit the trading direction of the given symbol to the callback direction until further notice.",
"signal_trading_enabled": "Whether the strategy will place trades based on signals.", "signal_trading_enabled": "Whether the strategy will place trades based on signals.",
"active_management_enabled": "Whether the strategy will amend/remove trades on the account that violate the rules.", "active_management_enabled": "Whether the strategy will amend/remove trades on the account that violate the rules.",
"active_management_policy": "The policy to use for active management.",
"enabled": "Whether the strategy is enabled.", "enabled": "Whether the strategy is enabled.",
} }
@ -174,9 +177,9 @@ class StrategyForm(RestrictedFormMixin, ModelForm):
) )
def clean(self): def clean(self):
super(StrategyForm, self).clean() cleaned_data = super(StrategyForm, self).clean()
entry_signals = self.cleaned_data.get("entry_signals") entry_signals = cleaned_data.get("entry_signals")
exit_signals = self.cleaned_data.get("exit_signals") exit_signals = cleaned_data.get("exit_signals")
for entry in entry_signals.all(): for entry in entry_signals.all():
if entry in exit_signals.all(): if entry in exit_signals.all():
self._errors["entry_signals"] = self.error_class( self._errors["entry_signals"] = self.error_class(
@ -213,6 +216,14 @@ class StrategyForm(RestrictedFormMixin, ModelForm):
"You cannot have entry and exit signals that are the same direction. At least one must be opposing." "You cannot have entry and exit signals that are the same direction. At least one must be opposing."
] ]
) )
if cleaned_data.get("active_management_enabled"):
if not cleaned_data.get("active_management_policy"):
self.add_error(
"active_management_policy",
"You must select an active management policy if active management is enabled.",
)
return
return cleaned_data
class TradeForm(RestrictedFormMixin, ModelForm): class TradeForm(RestrictedFormMixin, ModelForm):
@ -381,3 +392,34 @@ class OrderSettingsForm(RestrictedFormMixin, ModelForm):
"trailing_stop_loss_percent": "The trailing stop loss will be set at this percentage above/below the entry price. A trailing stop loss will follow the price as it moves in your favor.", "trailing_stop_loss_percent": "The trailing stop loss will be set at this percentage above/below the entry price. A trailing stop loss will follow the price as it moves in your favor.",
"trade_size_percent": "Percentage of the account balance to use for each trade.", "trade_size_percent": "Percentage of the account balance to use for each trade.",
} }
class ActiveManagementPolicyForm(RestrictedFormMixin, ModelForm):
class Meta:
model = ActiveManagementPolicy
fields = (
"name",
"when_trading_time_violated",
"when_trends_violated",
"when_position_size_violated",
"when_protection_violated",
"when_asset_groups_violated",
"when_max_open_trades_violated",
"when_max_open_trades_per_symbol_violated",
"when_max_loss_violated",
"when_max_risk_violated",
"when_crossfilter_violated",
)
help_texts = {
"name": "Name of the active management policy. Informational only.",
"when_trading_time_violated": "The action to take when the trading time is violated.",
"when_trends_violated": "The action to take a trade against the trend is discovered.",
"when_position_size_violated": "The action to take when a trade exceeding the position size is discovered.",
"when_protection_violated": "The action to take when a trade violating/lacking defined TP/SL/TSL is discovered.",
"when_asset_groups_violated": "The action to take when a trade violating the asset group rules is discovered.",
"when_max_open_trades_violated": "The action to take when a trade puts the account above the maximum open trades.",
"when_max_open_trades_per_symbol_violated": "The action to take when a trade puts the account above the maximum open trades per symbol.",
"when_max_loss_violated": "The action to take when a trade puts the account above the maximum loss.",
"when_max_risk_violated": "The action to take when a trade exposes the account to more than the maximum risk.",
"when_crossfilter_violated": "The action to take when a trade is deemed to conflict with another -- e.g. a buy and sell on the same asset.",
}

@ -46,6 +46,20 @@ class OpenPositions(BaseModel):
lastTransactionID: str lastTransactionID: str
def parse_time(x):
"""
Parse the time from the Oanda API.
"""
if "openTime" in x:
ts_split = x["openTime"].split(".")
else:
ts_split = x["trade"]["openTime"].split(".")
microseconds = ts_split[1].replace("Z", "")
microseconds_6 = microseconds[:6]
new_ts = ts_split[0] + "." + microseconds_6 + "Z"
return new_ts
def prevent_hedging(x): def prevent_hedging(x):
""" """
Our implementation breaks if a position has both. Our implementation breaks if a position has both.
@ -522,7 +536,7 @@ OpenTradesSchema = {
"id": "id", "id": "id",
"symbol": "instrument", "symbol": "instrument",
"price": "price", "price": "price",
"openTime": "openTime", "openTime": parse_time,
"initialUnits": "initialUnits", "initialUnits": "initialUnits",
"initialMarginRequired": "initialMarginRequired", "initialMarginRequired": "initialMarginRequired",
"state": "state", "state": "state",
@ -680,7 +694,7 @@ TradeDetailsSchema = {
"id": "trade.id", "id": "trade.id",
"symbol": "trade.instrument", "symbol": "trade.instrument",
"price": "trade.price", "price": "trade.price",
"openTime": "trade.openTime", "openTime": parse_time,
"initialUnits": "trade.initialUnits", "initialUnits": "trade.initialUnits",
"initialMarginRequired": "trade.initialMarginRequired", "initialMarginRequired": "trade.initialMarginRequired",
"state": "trade.state", "state": "trade.state",

@ -0,0 +1,30 @@
# Generated by Django 4.1.7 on 2023-02-17 11:50
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
('core', '0070_strategy_active_management_enabled_and_more'),
]
operations = [
migrations.AlterField(
model_name='account',
name='exchange',
field=models.CharField(choices=[('alpaca', 'Alpaca'), ('oanda', 'OANDA'), ('fake', 'Fake')], max_length=255),
),
migrations.CreateModel(
name='ActiveManagementPolicy',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(max_length=255)),
('description', models.TextField(blank=True, null=True)),
('when_trading_time_violated', models.CharField(choices=[('none', 'None'), ('close', 'Close violating trades'), ('notify', 'Notify only')], default='none', max_length=255)),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
),
]

@ -0,0 +1,58 @@
# Generated by Django 4.1.7 on 2023-02-17 11:58
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('core', '0071_alter_account_exchange_activemanagementpolicy'),
]
operations = [
migrations.AddField(
model_name='activemanagementpolicy',
name='when_asset_groups_violated',
field=models.CharField(choices=[('none', 'None'), ('close', 'Close violating trades'), ('notify', 'Notify only')], default='none', max_length=255),
),
migrations.AddField(
model_name='activemanagementpolicy',
name='when_crossfilter_violated',
field=models.CharField(choices=[('none', 'None'), ('close', 'Close violating trades'), ('notify', 'Notify only')], default='none', max_length=255),
),
migrations.AddField(
model_name='activemanagementpolicy',
name='when_max_loss_violated',
field=models.CharField(choices=[('none', 'None'), ('close', 'Close violating trades'), ('notify', 'Notify only')], default='none', max_length=255),
),
migrations.AddField(
model_name='activemanagementpolicy',
name='when_max_open_trades_per_symbol_violated',
field=models.CharField(choices=[('none', 'None'), ('close', 'Close violating trades'), ('notify', 'Notify only')], default='none', max_length=255),
),
migrations.AddField(
model_name='activemanagementpolicy',
name='when_max_open_trades_violated',
field=models.CharField(choices=[('none', 'None'), ('close', 'Close violating trades'), ('notify', 'Notify only')], default='none', max_length=255),
),
migrations.AddField(
model_name='activemanagementpolicy',
name='when_max_risk_violated',
field=models.CharField(choices=[('none', 'None'), ('close', 'Close violating trades'), ('notify', 'Notify only')], default='none', max_length=255),
),
migrations.AddField(
model_name='activemanagementpolicy',
name='when_position_size_violated',
field=models.CharField(choices=[('none', 'None'), ('close', 'Close violating trades'), ('notify', 'Notify only'), ('adjust', 'Adjust violating trades')], default='none', max_length=255),
),
migrations.AddField(
model_name='activemanagementpolicy',
name='when_protection_violated',
field=models.CharField(choices=[('none', 'None'), ('close', 'Close violating trades'), ('notify', 'Notify only'), ('adjust', 'Adjust violating trades')], default='none', max_length=255),
),
migrations.AddField(
model_name='activemanagementpolicy',
name='when_trends_violated',
field=models.CharField(choices=[('none', 'None'), ('close', 'Close violating trades'), ('notify', 'Notify only')], default='none', max_length=255),
),
]

@ -0,0 +1,19 @@
# Generated by Django 4.1.7 on 2023-02-17 13:16
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
('core', '0072_activemanagementpolicy_when_asset_groups_violated_and_more'),
]
operations = [
migrations.AddField(
model_name='strategy',
name='active_management_policy',
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.PROTECT, to='core.activemanagementpolicy'),
),
]

@ -65,6 +65,19 @@ MAPPING_CHOICES = (
(3, "Bearish"), (3, "Bearish"),
) )
CLOSE_NOTIFY_CHOICES = (
("none", "None"),
("close", "Close violating trades"),
("notify", "Notify only"),
)
ADJUST_CLOSE_NOTIFY_CHOICES = (
("none", "None"),
("close", "Close violating trades"),
("notify", "Notify only"),
("adjust", "Adjust violating trades"),
)
class Plan(models.Model): class Plan(models.Model):
name = models.CharField(max_length=255, unique=True) name = models.CharField(max_length=255, unique=True)
@ -395,6 +408,12 @@ class Strategy(models.Model):
"core.OrderSettings", "core.OrderSettings",
on_delete=models.PROTECT, on_delete=models.PROTECT,
) )
active_management_policy = models.ForeignKey(
"core.ActiveManagementPolicy",
on_delete=models.PROTECT,
null=True,
blank=True,
)
class Meta: class Meta:
verbose_name_plural = "strategies" verbose_name_plural = "strategies"
@ -493,3 +512,42 @@ class OrderSettings(models.Model):
def __str__(self): def __str__(self):
return self.name return self.name
class ActiveManagementPolicy(models.Model):
user = models.ForeignKey(User, on_delete=models.CASCADE)
name = models.CharField(max_length=255)
description = models.TextField(null=True, blank=True)
when_trading_time_violated = models.CharField(
choices=CLOSE_NOTIFY_CHOICES, max_length=255, default="none"
)
when_trends_violated = models.CharField(
choices=CLOSE_NOTIFY_CHOICES, max_length=255, default="none"
)
when_position_size_violated = models.CharField(
choices=ADJUST_CLOSE_NOTIFY_CHOICES, max_length=255, default="none"
)
when_protection_violated = models.CharField(
choices=ADJUST_CLOSE_NOTIFY_CHOICES, max_length=255, default="none"
)
when_asset_groups_violated = models.CharField(
choices=CLOSE_NOTIFY_CHOICES, max_length=255, default="none"
)
when_max_open_trades_violated = models.CharField(
choices=CLOSE_NOTIFY_CHOICES, max_length=255, default="none"
)
when_max_open_trades_per_symbol_violated = models.CharField(
choices=CLOSE_NOTIFY_CHOICES, max_length=255, default="none"
)
when_max_loss_violated = models.CharField(
choices=CLOSE_NOTIFY_CHOICES, max_length=255, default="none"
)
when_max_risk_violated = models.CharField(
choices=CLOSE_NOTIFY_CHOICES, max_length=255, default="none"
)
when_crossfilter_violated = models.CharField(
choices=CLOSE_NOTIFY_CHOICES, max_length=255, default="none"
)
def __str__(self):
return self.name

@ -267,6 +267,9 @@
<a class="navbar-item" href="{% url 'assetgroups' type='page' %}"> <a class="navbar-item" href="{% url 'assetgroups' type='page' %}">
Asset Groups Asset Groups
</a> </a>
<a class="navbar-item" href="{% url 'ams' type='page' %}">
Active Management
</a>
</div> </div>
</div> </div>
<div class="navbar-item has-dropdown is-hoverable"> <div class="navbar-item has-dropdown is-hoverable">

@ -0,0 +1,61 @@
{% load cache %}
{% load cachalot cache %}
{% get_last_invalidation 'core.AssetManagementPolicy' as last %}
{% include 'mixins/partials/notify.html' %}
{% cache 600 objects_active_management request.user.id object_list type last %}
<table
class="table is-fullwidth is-hoverable"
hx-target="#{{ context_object_name }}-table"
id="{{ context_object_name }}-table"
hx-swap="outerHTML"
hx-trigger="{{ context_object_name_singular }}Event from:body"
hx-get="{{ list_url }}">
<thead>
<th>id</th>
<th>user</th>
<th>name</th>
<th>description</th>
<th>actions</th>
</thead>
{% for item in object_list %}
<tr>
<td>{{ item.id }}</td>
<td>{{ item.user }}</td>
<td>{{ item.name }}</td>
<td>{{ item.description|truncatechars:80 }}</td>
<td>
<div class="buttons">
<button
hx-headers='{"X-CSRFToken": "{{ csrf_token }}"}'
hx-get="{% url 'ams_update' type=type pk=item.id %}"
hx-trigger="click"
hx-target="#{{ type }}s-here"
hx-swap="innerHTML"
class="button">
<span class="icon-text">
<span class="icon">
<i class="fa-solid fa-pencil"></i>
</span>
</span>
</button>
<button
hx-headers='{"X-CSRFToken": "{{ csrf_token }}"}'
hx-delete="{% url 'ams_delete' type=type pk=item.id %}"
hx-trigger="click"
hx-target="#modals-here"
hx-swap="innerHTML"
hx-confirm="Are you sure you wish to delete {{ item.name }}?"
class="button">
<span class="icon-text">
<span class="icon">
<i class="fa-solid fa-xmark"></i>
</span>
</span>
</button>
</div>
</td>
</tr>
{% endfor %}
</table>
{% endcache %}

@ -1,7 +1,17 @@
from datetime import time
from os import getenv from os import getenv
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from core.models import Account, User from core.models import (
Account,
Hook,
OrderSettings,
RiskModel,
Signal,
Strategy,
TradingTime,
User,
)
# Create patch mixin to mock out the Elastic client # Create patch mixin to mock out the Elastic client
@ -31,6 +41,21 @@ class ElasticMock:
cls.patcher.stop() cls.patcher.stop()
class SymbolPriceMock:
@classmethod
def setUpClass(cls):
super(SymbolPriceMock, cls).setUpClass()
cls.patcher = patch("core.exchanges.common.get_symbol_price")
patcher = cls.patcher.start()
patcher.return_value = 1
@classmethod
def tearDownClass(cls):
super(SymbolPriceMock, cls).tearDownClass()
cls.patcher.stop()
class LiveBase: class LiveBase:
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
@ -92,3 +117,40 @@ If you have done this, please see the following line for more information:
def setUp(self): def setUp(self):
if self.fail: if self.fail:
self.skipTest("Live tests aborted") self.skipTest("Live tests aborted")
class StrategyMixin:
def setUp(self):
super().setUp()
self.time_8 = time(8, 0, 0)
self.time_16 = time(16, 0, 0)
self.order_settings = OrderSettings.objects.create(
user=self.user, name="Default"
)
self.trading_time_now = TradingTime.objects.create(
user=self.user,
name="Test Trading Time",
start_day=1, # Monday
start_time=self.time_8,
end_day=1, # Monday
end_time=self.time_16,
)
self.risk_model = RiskModel.objects.create(
user=self.user,
name="Test Risk Model",
max_loss_percent=50,
max_risk_percent=10,
max_open_trades=10,
max_open_trades_per_symbol=5,
)
self.strategy = Strategy.objects.create(
user=self.user,
name="Test Strategy",
account=self.account,
order_settings=self.order_settings,
risk_model=self.risk_model,
active_management_enabled=True,
)
self.strategy.trading_times.set([self.trading_time_now])
self.strategy.save()

@ -0,0 +1,211 @@
from django.test import TestCase
from core.tests.helpers import StrategyMixin, SymbolPriceMock
from core.trading.active_management import ActiveManagement
from core.models import User, Account, ActiveManagementPolicy, Hook, Signal
from unittest.mock import Mock, patch
from core.lib.schemas.oanda_s import parse_time
class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase):
def setUp(self):
self.user = User.objects.create_user(
username="testuser", email="test@example.com", password="test"
)
self.account = Account.objects.create(
user=self.user,
name="Test Account",
exchange="fake",
currency="USD",
)
self.account.supported_symbols = ["EUR_USD", "EUR_XXX", "USD_EUR", "XXX_EUR"]
self.account.save()
super().setUp()
self.active_management_policy = ActiveManagementPolicy.objects.create(
user=self.user,
name="Test Policy",
when_trading_time_violated="close",
when_trends_violated="close",
when_position_size_violated="close",
when_protection_violated="close",
when_asset_groups_violated="close",
when_max_open_trades_violated="close",
when_max_open_trades_per_symbol_violated="close",
when_max_loss_violated="close",
when_max_risk_violated="close",
when_crossfilter_violated="close",
)
self.strategy.active_management_policy = self.active_management_policy
self.strategy.save()
self.ams = ActiveManagement(self.strategy)
self.trades = [
{
"id": "20083",
"symbol": "EUR_USD",
"price": "1.06331",
"openTime": "2023-02-13T11:38:06.302917985Z", # Monday at 11:38
"initialUnits": "10",
"initialMarginRequired": "0.2966",
"state": "OPEN",
"currentUnits": "10",
"realizedPL": "0.0000",
"financing": "0.0000",
"dividendAdjustment": "0.0000",
"unrealizedPL": "-0.0008",
"marginUsed": "0.2966",
"takeProfitOrder": None,
"stopLossOrder": None,
"trailingStopLossOrder": None,
"trailingStopValue": None,
"side": "long",
},
{
"id": "20083",
"symbol": "EUR_USD",
"price": "1.06331",
"openTime": "2023-02-13T11:38:06.302917985Z", # Monday at 11:38
"initialUnits": "10",
"initialMarginRequired": "0.2966",
"state": "OPEN",
"currentUnits": "10",
"realizedPL": "0.0000",
"financing": "0.0000",
"dividendAdjustment": "0.0000",
"unrealizedPL": "-0.0008",
"marginUsed": "0.2966",
"takeProfitOrder": None,
"stopLossOrder": None,
"trailingStopLossOrder": None,
"trailingStopValue": None,
"side": "long",
}
]
# Run parse_time on all items in trades
for trade in self.trades:
trade["openTime"] = parse_time(trade)
self.ams.get_trades = self.fake_get_trades
self.ams.get_balance = self.fake_get_balance
# self.ams.trades = self.trades
def fake_get_trades(self):
self.ams.trades = self.trades
return self.trades
def fake_get_balance(self):
return 10000
def fake_get_currencies(self, symbols):
pass
def test_get_trades(self):
trades = self.ams.get_trades()
self.assertEqual(trades, self.trades)
def test_get_balance(self):
balance = self.ams.get_balance()
self.assertEqual(balance, 10000)
def check_violation(self, violation, calls, expected_action, expected_trades):
"""
Check that the violation was called with the expected action and trades.
Matches the first argument of the call to the violation name.
:param: violation: type of the violation to check against
:param: calls: list of calls to the violation
:param: expected_action: expected action to be called, close, notify, etc.
:param: expected_trades: list of expected trades to be passed to the violation
"""
calls = list(calls)
violation_calls = []
for call in calls:
if call[0][0] == violation:
violation_calls.append(call)
self.assertEqual(len(violation_calls), len(expected_trades))
for call in violation_calls:
# Ensure the correct action has been called, like close
self.assertEqual(call[0][1], expected_action)
# Ensure the correct trade has been passed to the violation
self.assertIn(call[0][2], expected_trades)
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_run_checks(self, handle_violation):
self.ams.run_checks()
self.assertEqual(handle_violation.call_count, 0)
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_trading_time_violated(self, handle_violation):
self.trades[0]["openTime"] = "2023-02-17T11:38:06.302917Z" # Friday
self.ams.run_checks()
self.check_violation("trading_time", handle_violation.call_args_list, "close", [self.trades[0]])
def create_hook_signal(self):
hook = Hook.objects.create(
user=self.user,
name="Test Hook",
)
signal = Signal.objects.create(
user=self.user,
name="Test Signal",
hook=hook,
type="trend",
)
return signal
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_trends_violated(self, handle_violation):
signal = self.create_hook_signal()
self.strategy.trend_signals.set([signal])
self.strategy.trends = {"EUR_USD": "sell"}
self.strategy.save()
self.ams.run_checks()
self.check_violation("trends", handle_violation.call_args_list, "close", self.trades)
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_trends_violated_none(self, handle_violation):
signal = self.create_hook_signal()
self.strategy.trend_signals.set([signal])
self.strategy.trends = {"EUR_USD": "buy"}
self.strategy.save()
self.ams.run_checks()
self.check_violation("trends", handle_violation.call_args_list, "close", [])
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_trends_violated_partial(self, handle_violation):
signal = self.create_hook_signal()
self.strategy.trend_signals.set([signal])
self.strategy.trends = {"EUR_USD": "sell"}
self.strategy.save()
# Change the side of the first trade to match the trends
self.trades[0]["side"] = "short"
self.ams.run_checks()
self.check_violation("trends", handle_violation.call_args_list, "close", [self.trades[1]])
@patch("core.trading.active_management.ActiveManagement.handle_violation")
def test_position_size_violated(self, handle_violation):
self.trades[0]["currentUnits"] = "100000"
self.ams.run_checks()
self.check_violation("position_size", handle_violation.call_args_list, "close", [self.trades[0]])
def test_protection_violated(self):
pass
def test_asset_groups_violated(self):
pass
def test_max_open_trades_violated(self):
pass
def test_max_open_trades_per_symbol_violated(self):
pass
def test_max_loss_violated(self):
pass
def test_max_risk_violated(self):
pass
def test_crossfilter_violated(self):
pass

@ -13,13 +13,12 @@ from core.models import (
TradingTime, TradingTime,
User, User,
) )
from core.tests.helpers import StrategyMixin
from core.trading import checks from core.trading import checks
class ChecksTestCase(TestCase): class ChecksTestCase(StrategyMixin, TestCase):
def setUp(self): def setUp(self):
self.time_8 = time(8, 0, 0)
self.time_16 = time(16, 0, 0)
self.user = User.objects.create_user( self.user = User.objects.create_user(
username="testuser", email="test@example.com", password="test" username="testuser", email="test@example.com", password="test"
) )
@ -28,36 +27,7 @@ class ChecksTestCase(TestCase):
name="Test Account", name="Test Account",
exchange="fake", exchange="fake",
) )
self.order_settings = OrderSettings.objects.create( super().setUp()
user=self.user, name="Default"
)
self.trading_time_now = TradingTime.objects.create(
user=self.user,
name="Test Trading Time",
start_day=1, # Monday
start_time=self.time_8,
end_day=1, # Monday
end_time=self.time_16,
)
self.risk_model = RiskModel.objects.create(
user=self.user,
name="Test Risk Model",
max_loss_percent=50,
max_risk_percent=10,
max_open_trades=10,
max_open_trades_per_symbol=5,
)
self.strategy = Strategy.objects.create(
user=self.user,
name="Test Strategy",
account=self.account,
order_settings=self.order_settings,
risk_model=self.risk_model,
active_management_enabled=True,
)
self.strategy.trading_times.set([self.trading_time_now])
self.strategy.save()
@freezegun.freeze_time("2023-02-13T09:00:00") # Monday at 09:00 @freezegun.freeze_time("2023-02-13T09:00:00") # Monday at 09:00
def test_within_trading_times_pass(self): def test_within_trading_times_pass(self):

@ -1,9 +1,78 @@
from datetime import datetime
from decimal import Decimal as D
from core.exchanges.convert import side_to_direction
from core.trading import checks
from core.trading.market import get_base_quote, get_trade_size_in_base
class ActiveManagement(object): class ActiveManagement(object):
def __init__(self, strategy): def __init__(self, strategy):
self.strategy = strategy self.strategy = strategy
self.policy = strategy.active_management_policy
self.trades = []
self.balance = None
def get_trades(self):
if not self.trades:
self.trades = self.strategy.account.client.get_all_open_trades()
return self.trades
def get_balance(self):
if self.balance is None:
self.balance = self.strategy.account.client.get_balance()
else:
return self.balance
def handle_violation(self, check_type, action, trade):
print("VIOLATION", check_type, action, trade)
def check_trading_time(self, trade):
open_ts = trade["openTime"]
open_ts_as_date = datetime.strptime(open_ts, "%Y-%m-%dT%H:%M:%S.%fZ")
trading_time_pass = checks.within_trading_times(self.strategy, open_ts_as_date)
if not trading_time_pass:
self.handle_violation(
"trading_time", self.policy.when_trading_time_violated, trade
)
def check_trends(self, trade):
direction = side_to_direction(trade["side"])
symbol = trade["symbol"]
trends_pass = checks.within_trends(self.strategy, symbol, direction)
if not trends_pass:
print("VIOLATION", "trends", self.policy.when_trends_violated, trade)
self.handle_violation("trends", self.policy.when_trends_violated, trade)
def check_position_size(self, trade):
balance = self.get_balance()
print("BALANCE", balance)
direction = side_to_direction(trade["side"])
symbol = trade["symbol"]
base, quote = get_base_quote(self.strategy.account.exchange, symbol)
expected_trade_size = get_trade_size_in_base(
direction, self.strategy.account, self.strategy, balance, base
)
print("TRADE SIZE", expected_trade_size)
deviation = D(0.05) # 5%
actual_trade_size = D(trade["currentUnits"])
# Ensure the trade size not above the expected trade size by more than 5%
max_trade_size = expected_trade_size + (deviation * expected_trade_size)
within_max_trade_size = actual_trade_size <= max_trade_size
if not within_max_trade_size:
self.handle_violation(
"position_size", self.policy.when_position_size_violated, trade
)
def run_checks(self): def run_checks(self):
pass for trade in self.get_trades():
self.check_trading_time(trade)
self.check_trends(trade)
self.check_position_size(trade)
# Trading Time # Trading Time
# Max loss # Max loss
# Trends # Trends

@ -53,7 +53,9 @@ def within_max_loss(strategy):
def within_trends(strategy, symbol, direction): def within_trends(strategy, symbol, direction):
print("WITHIN TRENDS", symbol, direction)
if strategy.trend_signals.exists(): if strategy.trend_signals.exists():
print("TREND SIGNALS EXIST")
if strategy.trends is None: if strategy.trends is None:
log.debug("Refusing to trade with no trend signals received") log.debug("Refusing to trade with no trend signals received")
sendmsg( sendmsg(
@ -82,6 +84,9 @@ def within_trends(strategy, symbol, direction):
else: else:
log.debug(f"Trend check passed for {symbol} - {direction}") log.debug(f"Trend check passed for {symbol} - {direction}")
return True return True
else:
log.debug("No trend signals configured")
return True
def within_position_size(strategy): def within_position_size(strategy):

@ -50,6 +50,8 @@ def get_base_quote(exchange, symbol):
separator = "/" separator = "/"
elif exchange == "oanda": elif exchange == "oanda":
separator = "_" separator = "_"
else:
separator = "_"
base, quote = symbol.split(separator) base, quote = symbol.split(separator)
return (base, quote) return (base, quote)

@ -0,0 +1,37 @@
from django.contrib.auth.mixins import LoginRequiredMixin
from mixins.views import ObjectCreate, ObjectDelete, ObjectList, ObjectUpdate
from core.forms import ActiveManagementPolicyForm
from core.models import ActiveManagementPolicy
from core.util import logs
log = logs.get_logger(__name__)
class ActiveManagementPolicyList(LoginRequiredMixin, ObjectList):
list_template = "partials/activemanagement-list.html"
model = ActiveManagementPolicy
page_title = "List of active management policies. Linked to strategies."
list_url_name = "ams"
list_url_args = ["type"]
submit_url_name = "ams_create"
class ActiveManagementPolicyCreate(LoginRequiredMixin, ObjectCreate):
model = ActiveManagementPolicy
form_class = ActiveManagementPolicyForm
submit_url_name = "ams_create"
class ActiveManagementPolicyUpdate(LoginRequiredMixin, ObjectUpdate):
model = ActiveManagementPolicy
form_class = ActiveManagementPolicyForm
submit_url_name = "ams_update"
class ActiveManagementPolicyDelete(LoginRequiredMixin, ObjectDelete):
model = ActiveManagementPolicy
Loading…
Cancel
Save