diff --git a/core/exchanges/common.py b/core/exchanges/common.py index 76c0f76..db1dc87 100644 --- a/core/exchanges/common.py +++ b/core/exchanges/common.py @@ -98,5 +98,4 @@ def to_currency(direction, account, amount, from_currency, to_currency): # Convert the amount to the destination currency converted = D(amount) * price - return converted diff --git a/core/forms.py b/core/forms.py index 2ab1273..4f54b2c 100644 --- a/core/forms.py +++ b/core/forms.py @@ -419,7 +419,7 @@ class ActiveManagementPolicyForm(RestrictedFormMixin, ModelForm): "when_asset_groups_violated": "The action to take when a trade violating the asset group rules is discovered.", "when_max_open_trades_violated": "The action to take when a trade puts the account above the maximum open trades.", "when_max_open_trades_per_symbol_violated": "The action to take when a trade puts the account above the maximum open trades per symbol.", - "when_max_loss_violated": "The action to take when a trade puts the account above the maximum loss.", + "when_max_loss_violated": "The action to take when the account exceeds its maximum loss. NOTE: The close action will close all trades.", "when_max_risk_violated": "The action to take when a trade exposes the account to more than the maximum risk.", "when_crossfilter_violated": "The action to take when a trade is deemed to conflict with another -- e.g. a buy and sell on the same asset.", } diff --git a/core/tests/helpers.py b/core/tests/helpers.py index 198fb58..40da4cc 100644 --- a/core/tests/helpers.py +++ b/core/tests/helpers.py @@ -1,17 +1,9 @@ from datetime import time +from decimal import Decimal as D from os import getenv from unittest.mock import Mock, patch -from core.models import ( - Account, - Hook, - OrderSettings, - RiskModel, - Signal, - Strategy, - TradingTime, - User, -) +from core.models import Account, OrderSettings, RiskModel, Strategy, TradingTime, User # Create patch mixin to mock out the Elastic client @@ -48,7 +40,7 @@ class SymbolPriceMock: cls.patcher = patch("core.exchanges.common.get_symbol_price") patcher = cls.patcher.start() - patcher.return_value = 1 + patcher.return_value = D(1) @classmethod def tearDownClass(cls): @@ -141,7 +133,7 @@ class StrategyMixin: max_loss_percent=50, max_risk_percent=10, max_open_trades=10, - max_open_trades_per_symbol=2, + max_open_trades_per_symbol=5, ) self.strategy = Strategy.objects.create( diff --git a/core/tests/trading/test_active_management.py b/core/tests/trading/test_active_management.py index 727651c..cc92179 100644 --- a/core/tests/trading/test_active_management.py +++ b/core/tests/trading/test_active_management.py @@ -1,5 +1,5 @@ from decimal import Decimal as D -from unittest.mock import Mock, patch +from unittest.mock import patch from django.test import TestCase @@ -28,6 +28,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): name="Test Account", exchange="fake", currency="USD", + initial_balance=100000, ) self.account.supported_symbols = ["EUR_USD", "EUR_XXX", "USD_EUR", "XXX_EUR"] self.account.save() @@ -75,7 +76,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): "id": "20084", "symbol": "EUR_USD", "price": "1.06331", - "openTime": "2023-02-13T11:39:06.302917985Z", # Monday at 11:38 + "openTime": "2023-02-13T11:39:06.302917985Z", # Monday at 11:39 "initialUnits": "10", "initialMarginRequired": "0.2966", "state": "OPEN", @@ -95,6 +96,9 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): # Run parse_time on all items in trades for trade in self.trades: trade["openTime"] = parse_time(trade) + + self.balance = 100000 + self.balance_usd = 120000 self.ams.get_trades = self.fake_get_trades self.ams.get_balance = self.fake_get_balance # self.ams.trades = self.trades @@ -123,12 +127,25 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): trade["openTime"] = parse_time(trade) self.trades.append(trade) + def amend_tp_sl_flip_side(self): + """ + Amend the take profit and stop loss orders to be the opposite side. + This lets the protection tests pass, so we can only test one violation + per test. + """ + for trade in self.trades: + if trade["side"] == "short": + trade["stopLossOrder"]["price"] = "1.07386" + trade["takeProfitOrder"]["price"] = "1.04728" + def fake_get_trades(self): self.ams.trades = self.trades return self.trades - def fake_get_balance(self): - return 10000 + def fake_get_balance(self, return_usd=None): + if return_usd: + return self.balance_usd + return self.balance def fake_get_currencies(self, symbols): pass @@ -139,7 +156,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): def test_get_balance(self): balance = self.ams.get_balance() - self.assertEqual(balance, 10000) + self.assertEqual(balance, self.balance) def check_violation( self, violation, calls, expected_action, expected_trades, expected_args=None @@ -153,6 +170,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): :param: expected_trades: list of expected trades to be passed to the violation :param: expected_args: optional, expected args to be passed to the violation """ + self.assertEqual(len(calls), len(expected_trades)) calls = list(calls) violation_calls = [] for call in calls: @@ -160,11 +178,18 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): violation_calls.append(call) self.assertEqual(len(violation_calls), len(expected_trades)) - expected_trades = convert_trades(expected_trades) + if all(expected_trades): + expected_trades = convert_trades(expected_trades) for call in violation_calls: # Ensure the correct action has been called, like close self.assertEqual(call[0][1], expected_action) # Ensure the correct trade has been passed to the violation + trade = call[0][2] + if trade: + for field in list(trade.keys()): + if "_usd" in field: + if field in trade.keys(): + del trade[field] self.assertIn(call[0][2], expected_trades) if expected_args: self.assertEqual(call[0][3], expected_args) @@ -172,7 +197,6 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): @patch("core.trading.active_management.ActiveManagement.handle_violation") def test_run_checks(self, handle_violation): self.ams.run_checks() - print("handle_violation.call_count", handle_violation.call_args_list) self.assertEqual(handle_violation.call_count, 0) @patch("core.trading.active_management.ActiveManagement.handle_violation") @@ -201,7 +225,9 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): signal = self.create_hook_signal() self.strategy.trend_signals.set([signal]) self.strategy.trends = {"EUR_USD": "sell"} + self.amend_tp_sl_flip_side() self.strategy.save() + self.ams.run_checks() self.check_violation( "trends", handle_violation.call_args_list, "close", self.trades @@ -216,8 +242,14 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): self.ams.run_checks() self.check_violation("trends", handle_violation.call_args_list, "close", []) + # Mock crossfilter here since we want to allow this conflict in order to test that + # trends only close trades that are in the wrong direction + @patch( + "core.trading.active_management.ActiveManagement.check_crossfilter", + return_value=None, + ) @patch("core.trading.active_management.ActiveManagement.handle_violation") - def test_trends_violated_partial(self, handle_violation): + def test_trends_violated_partial(self, handle_violation, check_crossfilter): signal = self.create_hook_signal() self.strategy.trend_signals.set([signal]) self.strategy.trends = {"EUR_USD": "sell"} @@ -225,6 +257,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): # Change the side of the first trade to match the trends self.trades[0]["side"] = "short" + self.amend_tp_sl_flip_side() self.ams.run_checks() self.check_violation( @@ -241,7 +274,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): handle_violation.call_args_list, "close", [self.trades[0]], - {"size": 50}, + {"size": 500}, ) @patch("core.trading.active_management.ActiveManagement.handle_violation") @@ -270,7 +303,6 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): self.trades[0]["takeProfitOrder"] = None self.trades[0]["stopLossOrder"] = None self.ams.run_checks() - print("CALLS", handle_violation.call_args_list) self.assertEqual(handle_violation.call_count, 0) @@ -301,6 +333,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): def test_asset_groups_violated_invert(self, handle_violation): self.trades[0]["side"] = "short" self.trades[1]["side"] = "short" + self.amend_tp_sl_flip_side() asset_group = AssetGroup.objects.create( user=self.user, name="Test Asset Group", @@ -325,6 +358,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): @patch("core.trading.active_management.ActiveManagement.handle_violation") def test_crossfilter_violated_side(self, handle_violation): self.trades[1]["side"] = "short" + self.amend_tp_sl_flip_side() self.ams.run_checks() self.check_violation( @@ -336,16 +370,19 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): @patch("core.trading.active_management.ActiveManagement.handle_violation") def test_crossfilter_violated_side_multiple(self, handle_violation): - self.add_trade("20085", "EUR_USD", "short", "2023-02-13T12:39:06.302917985Z") - self.add_trade("20086", "EUR_USD", "short", "2023-02-14T12:39:06.302917985Z") - self.add_trade("20087", "EUR_USD", "short", "2023-02-10T12:39:06.302917985Z") + self.add_trade( + "20085", "EUR_USD", "short", "2023-02-13T12:39:06.302917985Z" + ) # 2: + self.add_trade("20086", "EUR_USD", "short", "2023-02-13T13:39:07.302917985Z") + self.add_trade("20087", "EUR_USD", "short", "2023-02-13T14:39:06.302917985Z") + self.amend_tp_sl_flip_side() self.ams.run_checks() self.check_violation( "crossfilter", handle_violation.call_args_list, "close", - self.trades[0:4], # Only close newer trades + self.trades[2:], # Only close newer trades ) @patch("core.trading.active_management.ActiveManagement.handle_violation") @@ -363,16 +400,18 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): @patch("core.trading.active_management.ActiveManagement.handle_violation") def test_crossfilter_violated_symbol_multiple(self, handle_violation): - self.add_trade("20085", "USD_EUR", "long", "2023-02-13T12:39:06.302917985Z") - self.add_trade("20086", "USD_EUR", "long", "2023-02-14T12:39:06.302917985Z") - self.add_trade("20087", "USD_EUR", "long", "2023-02-10T12:39:06.302917985Z") + self.add_trade( + "20085", "USD_EUR", "long", "2023-02-13T12:39:06.302917985Z" + ) # 2: + self.add_trade("20086", "USD_EUR", "long", "2023-02-13T13:39:06.302917985Z") + self.add_trade("20087", "USD_EUR", "long", "2023-02-13T14:39:06.302917985Z") self.ams.run_checks() self.check_violation( "crossfilter", handle_violation.call_args_list, "close", - self.trades[0:4], # Only close newer trades + self.trades[2:], # Only close newer trades ) @patch("core.trading.active_management.ActiveManagement.handle_violation") @@ -380,7 +419,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): for x in range(9): self.add_trade( str(x), - "EUR_USD", + f"EUR_USD{x}", # Vary symbol to prevent max open trades per symbol "long", f"2023-02-13T12:39:1{x}.302917985Z", ) @@ -395,7 +434,7 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): @patch("core.trading.active_management.ActiveManagement.handle_violation") def test_max_open_trades_per_symbol_violated(self, handle_violation): - for x in range(2): + for x in range(4): self.add_trade( str(x), "EUR_USD", @@ -408,11 +447,102 @@ class ActiveManagementTestCase(StrategyMixin, SymbolPriceMock, TestCase): "max_open_trades_per_symbol", handle_violation.call_args_list, "close", - self.trades[2:], # Only close newer trades + self.trades[5:], # Only close newer trades ) - def test_max_loss_violated(self): - pass + # Mock position size as we have no way of checking the balance at the start of the + # trade. + # TODO: Fix this when we have a way of checking the balance at the start of the + # trade. + # Max risk is also mocked as this puts us over the limit, due to the low account + # size. + @patch( + "core.trading.active_management.ActiveManagement.check_max_risk", + return_value=None, + ) + @patch( + "core.trading.active_management.ActiveManagement.check_position_size", + return_value=None, + ) + @patch("core.trading.active_management.ActiveManagement.handle_violation") + def test_max_loss_violated( + self, handle_violation, check_position_size, check_max_risk + ): + self.balance = D("1") + self.balance_usd = D("0.69") + self.ams.run_checks() - def test_max_risk_violated(self): - pass + self.check_violation( + "max_loss", + handle_violation.call_args_list, + "close", + [None], + ) + + @patch( + "core.trading.active_management.ActiveManagement.check_position_size", + return_value=None, + ) + @patch( + "core.trading.active_management.ActiveManagement.check_protection", + return_value=None, + ) + @patch("core.trading.active_management.ActiveManagement.handle_violation") + def test_max_risk_violated( + self, handle_violation, check_protection, check_position_size + ): + self.add_trade( + "20085", + "EUR_USD", + "long", + "2023-02-13T15:39:19.302917985Z", + ) + self.trades[2]["stopLossOrder"]["price"] = "0.001" + self.trades[2]["currentUnits"] = "13000" + + self.ams.run_checks() + + self.check_violation( + "max_risk", + handle_violation.call_args_list, + "close", + [self.trades[2]], + ) + + @patch( + "core.trading.active_management.ActiveManagement.check_position_size", + return_value=None, + ) + @patch( + "core.trading.active_management.ActiveManagement.check_protection", + return_value=None, + ) + @patch("core.trading.active_management.ActiveManagement.handle_violation") + def test_max_risk_violated_multiple( + self, handle_violation, check_protection, check_position_size + ): + self.add_trade( + "20085", + "EUR_USD", + "long", + "2023-02-13T15:39:19.302917985Z", + ) + self.add_trade( + "20086", + "EUR_USD", + "long", + "2023-02-13T15:45:19.302917985Z", + ) + self.trades[2]["stopLossOrder"]["price"] = "0.001" + self.trades[2]["currentUnits"] = "13000" + self.trades[3]["stopLossOrder"]["price"] = "0.001" + self.trades[3]["currentUnits"] = "13000" + + self.ams.run_checks() + + self.check_violation( + "max_risk", + handle_violation.call_args_list, + "close", + [self.trades[2], self.trades[3]], + ) diff --git a/core/tests/trading/test_checks.py b/core/tests/trading/test_checks.py index 04e7012..07d4b14 100644 --- a/core/tests/trading/test_checks.py +++ b/core/tests/trading/test_checks.py @@ -1,18 +1,7 @@ -from datetime import time - import freezegun from django.test import TestCase -from core.models import ( - Account, - Hook, - OrderSettings, - RiskModel, - Signal, - Strategy, - TradingTime, - User, -) +from core.models import Account, Hook, Signal, User from core.tests.helpers import StrategyMixin from core.trading import checks diff --git a/core/tests/trading/test_risk.py b/core/tests/trading/test_risk.py index d45d2af..ef2fb6a 100644 --- a/core/tests/trading/test_risk.py +++ b/core/tests/trading/test_risk.py @@ -1,9 +1,8 @@ from django.test import TestCase +import core.trading.market # noqa # to avoid messy circular import from core.exchanges import convert from core.models import RiskModel, User -import core.trading.market # to avoid messy circular import - from core.trading import risk diff --git a/core/trading/active_management.py b/core/trading/active_management.py index c78e7b9..da2c5a2 100644 --- a/core/trading/active_management.py +++ b/core/trading/active_management.py @@ -1,14 +1,13 @@ +from copy import deepcopy from datetime import datetime from decimal import Decimal as D -import core.trading.market # to avoid messy circular import from core.exchanges.convert import ( convert_trades, - side_to_direction, sl_percent_to_price, tp_percent_to_price, ) -from core.trading import assetfilter, checks, risk +from core.trading import assetfilter, checks, market, risk from core.trading.crossfilter import crossfilter from core.trading.market import get_base_quote, get_trade_size_in_base @@ -20,17 +19,28 @@ class ActiveManagement(object): self.trades = [] self.balance = None + self.balance_usd = None def get_trades(self): if not self.trades: self.trades = self.strategy.account.client.get_all_open_trades() return self.trades - def get_balance(self): - if self.balance is None: - self.balance = self.strategy.account.client.get_balance() + def get_balance(self, return_usd=False): + if return_usd: + if self.balance_usd is None: + self.balance_usd = self.strategy.account.client.get_balance( + return_usd=True + ) + else: + return self.balance_usd else: - return self.balance + if self.balance is None: + self.balance = self.strategy.account.client.get_balance( + return_usd=False + ) + else: + return self.balance def handle_violation(self, check_type, action, trade, **kwargs): print("VIOLATION", check_type, action, trade, kwargs) @@ -54,8 +64,10 @@ class ActiveManagement(object): def check_position_size(self, trade): """ Check the position size is within the allowed deviation. - WARNING: This uses the current balance, not the balance at the time of the trade. - WARNING: This uses the current symbol prices, not those at the time of the trade. + WARNING: This uses the current balance, not the balance at the time of the + trade. + WARNING: This uses the current symbol prices, not those at the time of the + trade. This should normally be run every 5 seconds, so this is fine. """ # TODO: add the trade value to the balance @@ -86,10 +98,12 @@ class ActiveManagement(object): def check_protection(self, trade): deviation = D(0.05) # 5% + # fmt: off matches = { "stop_loss_percent": self.strategy.order_settings.stop_loss_percent, "take_profit_percent": self.strategy.order_settings.take_profit_percent, - "trailing_stop_percent": self.strategy.order_settings.trailing_stop_loss_percent, + "trailing_stop_percent": + self.strategy.order_settings.trailing_stop_loss_percent, } violations = {} @@ -155,7 +169,7 @@ class ActiveManagement(object): ) def get_sorted_trades_copy(self, trades, reverse=True): - trades_copy = trades.copy() + trades_copy = deepcopy(trades) # sort by open time, newest first trades_copy.sort( key=lambda x: datetime.strptime(x["open_time"], "%Y-%m-%dT%H:%M:%S.%fZ"), @@ -170,7 +184,8 @@ class ActiveManagement(object): iterations = 0 finished = [] - # Recursively run crossfilter on the newest-first list until we have no more conflicts + # Recursively run crossfilter on the newest-first list until we have no more + # conflicts while not len(finished) == len(trades): iterations += 1 if iterations > 10000: @@ -207,17 +222,7 @@ class ActiveManagement(object): close_trades.append(trade) if not close_trades: return - # For each conflicting symbol, identify the oldest trades - # removed_trades = [] - # for symbol in conflict: - # newest_trade = max(conflict, key=lambda x: datetime.strptime(x["open_time"], "%Y-%m-%dT%H:%M:%S.%fZ")) - # removed_trades.append(newest_trade) - # print("KEEP TRADES", keep_trade_ids) - # close_trades = [] - # for x in keep_trade_ids: - # for position in conflict[x]: - # if position["id"] not in keep_trade_ids[x]: - # close_trades.append(position) + if close_trades: for trade in close_trades: self.handle_violation( @@ -225,57 +230,90 @@ class ActiveManagement(object): ) def check_max_open_trades(self, trades): - if self.strategy.risk_model is not None: - max_open_pass = risk.check_max_open_trades(self.strategy.risk_model, trades) - if not max_open_pass: - trades_copy = self.get_sorted_trades_copy(trades, reverse=False) - print("TRADES COPY", [x["id"] for x in trades_copy]) - print("MAX", self.strategy.risk_model.max_open_trades) - trades_over_limit = trades_copy[ - self.strategy.risk_model.max_open_trades : - ] - for trade in trades_over_limit: - self.handle_violation( - "max_open_trades", - self.policy.when_max_open_trades_violated, - trade, - ) - print("TRADES OVER LIMNIT", trades_over_limit) + if self.strategy.risk_model is None: + return + max_open_pass = risk.check_max_open_trades(self.strategy.risk_model, trades) + if not max_open_pass: + trades_copy = self.get_sorted_trades_copy(trades, reverse=False) + # fmt: off + trades_over_limit = trades_copy[self.strategy.risk_model.max_open_trades:] + for trade in trades_over_limit: + self.handle_violation( + "max_open_trades", + self.policy.when_max_open_trades_violated, + trade, + ) def check_max_open_trades_per_symbol(self, trades): - if self.strategy.risk_model is not None: - max_open_pass = risk.check_max_open_trades_per_symbol( - self.strategy.risk_model, trades, return_symbols=True - ) - print("max_open_pass", max_open_pass) - max_open_pass = list(max_open_pass) - print("MAX OPEN PASS", max_open_pass) - if max_open_pass: - trades_copy = self.get_sorted_trades_copy(trades, reverse=False) - trades_over_limit = [] - for symbol in max_open_pass: - print("SYMBOL", symbol) - print("TRADES", trades) - symbol_trades = [x for x in trades_copy if x["symbol"] == symbol] - exceeding_limit = symbol_trades[ - self.strategy.risk_model.max_open_trades_per_symbol : - ] - for x in exceeding_limit: - trades_over_limit.append(x) - - for trade in trades_over_limit: - self.handle_violation( - "max_open_trades_per_symbol", - self.policy.when_max_open_trades_violated, - trade, - ) - print("TRADES OVER LIMNIT", trades_over_limit) + if self.strategy.risk_model is None: + return + max_open_pass = risk.check_max_open_trades_per_symbol( + self.strategy.risk_model, trades, return_symbols=True + ) + max_open_pass = list(max_open_pass) + if max_open_pass: + trades_copy = self.get_sorted_trades_copy(trades, reverse=False) + trades_over_limit = [] + for symbol in max_open_pass: + symbol_trades = [x for x in trades_copy if x["symbol"] == symbol] + # fmt: off + exceeding_limit = symbol_trades[ + self.strategy.risk_model.max_open_trades_per_symbol: + ] + for x in exceeding_limit: + trades_over_limit.append(x) + + for trade in trades_over_limit: + self.handle_violation( + "max_open_trades_per_symbol", + self.policy.when_max_open_trades_violated, + trade, + ) def check_max_loss(self): - check_passed = risk.check_max_loss(self.strategy.risk_model, self.strategy.account.initial_balance, self.get_balance()) + if self.strategy.risk_model is None: + return + check_passed = risk.check_max_loss( + self.strategy.risk_model, + self.strategy.account.initial_balance, + self.get_balance(), + ) + if not check_passed: + self.handle_violation( + "max_loss", self.policy.when_max_loss_violated, None # Close all trades + ) def check_max_risk(self, trades): - pass + if self.strategy.risk_model is None: + return + close_trades = [] + + trades_copy = self.get_sorted_trades_copy(trades, reverse=False) + market.convert_trades_to_usd(self.strategy.account, trades_copy) + + iterations = 0 + finished = False + while not finished: + iterations += 1 + if iterations > 10000: + raise Exception("Too many iterations") + + check_passed = risk.check_max_risk( + self.strategy.risk_model, + self.get_balance(return_usd=True), + trades_copy, + ) + if check_passed: + finished = True + else: + # Add the newest trade to close_trades and remove it from trades_copy + close_trades.append(trades_copy[-1]) + trades_copy = trades_copy[:-1] + if close_trades: + for trade in close_trades: + self.handle_violation( + "max_risk", self.policy.when_max_risk_violated, trade + ) def run_checks(self): converted_trades = convert_trades(self.get_trades()) diff --git a/core/trading/market.py b/core/trading/market.py index 3c5960f..2420edc 100644 --- a/core/trading/market.py +++ b/core/trading/market.py @@ -20,7 +20,7 @@ def convert_trades_to_usd(account, trades): :return: List of trades, with amount_usd added """ for trade in trades: - amount = trade["amount"] + amount = D(trade["amount"]) symbol = trade["symbol"] side = trade["side"] direction = side_to_direction(side) diff --git a/core/trading/risk.py b/core/trading/risk.py index 1762957..3774fa4 100644 --- a/core/trading/risk.py +++ b/core/trading/risk.py @@ -26,6 +26,7 @@ def check_max_risk(risk_model, account_balance_usd, account_trades): # Calculate the max risk of the account in USD max_risk_usd = account_balance_usd * (max_risk_percent / D(100)) total_risk = 0 + for trade in account_trades: max_tmp = [] # Need to calculate the max risk in base account currency @@ -36,7 +37,8 @@ def check_max_risk(risk_model, account_balance_usd, account_trades): if "trailing_stop_loss_usd" in trade: max_tmp.append(trade["trailing_stop_loss_usd"]) if max_tmp: - total_risk += max(max_tmp) + max_risk = max(max_tmp) + total_risk += max_risk allowed = total_risk < max_risk_usd return allowed @@ -59,7 +61,6 @@ def check_max_open_trades_per_symbol(risk_model, account_trades, return_symbols= if symbol not in symbol_map: symbol_map[symbol] = 0 symbol_map[symbol] += 1 - print("Symbol map: ", symbol_map) violating_symbols = [] for symbol, count in symbol_map.items(): if count >= risk_model.max_open_trades_per_symbol: