Write tests for asset filter

This commit is contained in:
Mark Veidemanis 2023-02-11 18:22:49 +00:00
parent ce0b75ae2d
commit 313c7f79d0
Signed by: m
GPG Key ID: 5ACFCEED46C0904F
5 changed files with 49 additions and 5 deletions

View File

@ -1,7 +1,7 @@
# Generated by Django 4.1.6 on 2023-02-11 18:17
from django.db import migrations, models
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):

View File

@ -412,7 +412,9 @@ class AssetGroup(models.Model):
description = models.TextField(null=True, blank=True)
# Account for checking pairs on children if specified
account = models.ForeignKey(Account, on_delete=models.PROTECT, null=True, blank=True)
account = models.ForeignKey(
Account, on_delete=models.PROTECT, null=True, blank=True
)
# Dict like {"RUB": True, "USD": False}
allowed = models.JSONField(null=True, blank=True, default=dict)

View File

@ -0,0 +1,28 @@
from django.test import TestCase
from core.trading.assetfilter import get_allowed
from core.models import AssetGroup, User
class AssetfilterTestCase(TestCase):
def setUp(self):
self.user = User.objects.create_user(
username="testuser",
email="test@example.com",
)
self.group = AssetGroup.objects.create(
user=self.user,
name="Group1",
description="Test group",
)
def test_get_allowed(self):
"""
Test that the asset filter works.
"""
self.group.allowed = {"EUR_USD": True, "EUR_GBP": False}
self.assertTrue(get_allowed(self.group, "EUR_USD", "buy"))
self.assertFalse(get_allowed(self.group, "EUR_GBP", "sell"))
# Default true
self.assertTrue(get_allowed(self.group, "nonexistent", "sell"))

View File

@ -1,7 +1,14 @@
def get_allowed(strategy, symbol, direction):
def get_allowed(group, symbol, direction):
"""
Determine whether the trade is allowed according to the Asset Groups
linked to the strategy.
"""
# TODO: figure out what to do with direction
# asset_group = strategy.
allowed = group.allowed
if not isinstance(allowed, dict):
return False
if symbol not in allowed:
return True
return allowed[symbol]

View File

@ -326,8 +326,15 @@ def execute_strategy(callback, strategy, func):
price_bound = round(price_bound, display_precision)
# Callback now verified
# Check against the asset groups
if func == "entry":
allowed = assetfilter.get_allowed(strategy, symbol, direction)
if strategy.assetgroup is not None:
allowed = assetfilter.get_allowed(strategy, symbol, direction)
if not allowed:
log.debug(f"Asset trading not allowed for {strategy}: {symbol}")
return
if func == "exit":
check_exit = crossfilter(account, symbol, direction, func)
if check_exit is None: