From 313c7f79d0c27a448481f94015cbae9ab743a46b Mon Sep 17 00:00:00 2001 From: Mark Veidemanis Date: Sat, 11 Feb 2023 18:22:49 +0000 Subject: [PATCH] Write tests for asset filter --- .../0057_alter_assetgroup_account_and_more.py | 2 +- core/models.py | 4 ++- core/tests/trading/test_assetfilter.py | 28 +++++++++++++++++++ core/trading/assetfilter.py | 11 ++++++-- core/trading/market.py | 9 +++++- 5 files changed, 49 insertions(+), 5 deletions(-) create mode 100644 core/tests/trading/test_assetfilter.py diff --git a/core/migrations/0057_alter_assetgroup_account_and_more.py b/core/migrations/0057_alter_assetgroup_account_and_more.py index b0afd61..9f3ba32 100644 --- a/core/migrations/0057_alter_assetgroup_account_and_more.py +++ b/core/migrations/0057_alter_assetgroup_account_and_more.py @@ -1,7 +1,7 @@ # Generated by Django 4.1.6 on 2023-02-11 18:17 -from django.db import migrations, models import django.db.models.deletion +from django.db import migrations, models class Migration(migrations.Migration): diff --git a/core/models.py b/core/models.py index 09a50dc..d65d33e 100644 --- a/core/models.py +++ b/core/models.py @@ -412,7 +412,9 @@ class AssetGroup(models.Model): description = models.TextField(null=True, blank=True) # Account for checking pairs on children if specified - account = models.ForeignKey(Account, on_delete=models.PROTECT, null=True, blank=True) + account = models.ForeignKey( + Account, on_delete=models.PROTECT, null=True, blank=True + ) # Dict like {"RUB": True, "USD": False} allowed = models.JSONField(null=True, blank=True, default=dict) diff --git a/core/tests/trading/test_assetfilter.py b/core/tests/trading/test_assetfilter.py new file mode 100644 index 0000000..5664589 --- /dev/null +++ b/core/tests/trading/test_assetfilter.py @@ -0,0 +1,28 @@ +from django.test import TestCase + +from core.trading.assetfilter import get_allowed +from core.models import AssetGroup, User + +class AssetfilterTestCase(TestCase): + def setUp(self): + self.user = User.objects.create_user( + username="testuser", + email="test@example.com", + ) + self.group = AssetGroup.objects.create( + user=self.user, + name="Group1", + description="Test group", + ) + + + def test_get_allowed(self): + """ + Test that the asset filter works. + """ + self.group.allowed = {"EUR_USD": True, "EUR_GBP": False} + self.assertTrue(get_allowed(self.group, "EUR_USD", "buy")) + self.assertFalse(get_allowed(self.group, "EUR_GBP", "sell")) + + # Default true + self.assertTrue(get_allowed(self.group, "nonexistent", "sell")) \ No newline at end of file diff --git a/core/trading/assetfilter.py b/core/trading/assetfilter.py index 6ccda02..7ef4f97 100644 --- a/core/trading/assetfilter.py +++ b/core/trading/assetfilter.py @@ -1,7 +1,14 @@ -def get_allowed(strategy, symbol, direction): +def get_allowed(group, symbol, direction): """ Determine whether the trade is allowed according to the Asset Groups linked to the strategy. """ + # TODO: figure out what to do with direction - # asset_group = strategy. + allowed = group.allowed + if not isinstance(allowed, dict): + return False + if symbol not in allowed: + return True + + return allowed[symbol] diff --git a/core/trading/market.py b/core/trading/market.py index 38a5b16..345f623 100644 --- a/core/trading/market.py +++ b/core/trading/market.py @@ -326,8 +326,15 @@ def execute_strategy(callback, strategy, func): price_bound = round(price_bound, display_precision) # Callback now verified + + # Check against the asset groups if func == "entry": - allowed = assetfilter.get_allowed(strategy, symbol, direction) + if strategy.assetgroup is not None: + allowed = assetfilter.get_allowed(strategy, symbol, direction) + if not allowed: + log.debug(f"Asset trading not allowed for {strategy}: {symbol}") + return + if func == "exit": check_exit = crossfilter(account, symbol, direction, func) if check_exit is None: