Write tests for asset filter
This commit is contained in:
parent
ce0b75ae2d
commit
313c7f79d0
|
@ -1,7 +1,7 @@
|
||||||
# Generated by Django 4.1.6 on 2023-02-11 18:17
|
# Generated by Django 4.1.6 on 2023-02-11 18:17
|
||||||
|
|
||||||
from django.db import migrations, models
|
|
||||||
import django.db.models.deletion
|
import django.db.models.deletion
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
class Migration(migrations.Migration):
|
||||||
|
|
|
@ -412,7 +412,9 @@ class AssetGroup(models.Model):
|
||||||
description = models.TextField(null=True, blank=True)
|
description = models.TextField(null=True, blank=True)
|
||||||
|
|
||||||
# Account for checking pairs on children if specified
|
# Account for checking pairs on children if specified
|
||||||
account = models.ForeignKey(Account, on_delete=models.PROTECT, null=True, blank=True)
|
account = models.ForeignKey(
|
||||||
|
Account, on_delete=models.PROTECT, null=True, blank=True
|
||||||
|
)
|
||||||
|
|
||||||
# Dict like {"RUB": True, "USD": False}
|
# Dict like {"RUB": True, "USD": False}
|
||||||
allowed = models.JSONField(null=True, blank=True, default=dict)
|
allowed = models.JSONField(null=True, blank=True, default=dict)
|
||||||
|
|
|
@ -0,0 +1,28 @@
|
||||||
|
from django.test import TestCase
|
||||||
|
|
||||||
|
from core.trading.assetfilter import get_allowed
|
||||||
|
from core.models import AssetGroup, User
|
||||||
|
|
||||||
|
class AssetfilterTestCase(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.user = User.objects.create_user(
|
||||||
|
username="testuser",
|
||||||
|
email="test@example.com",
|
||||||
|
)
|
||||||
|
self.group = AssetGroup.objects.create(
|
||||||
|
user=self.user,
|
||||||
|
name="Group1",
|
||||||
|
description="Test group",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_allowed(self):
|
||||||
|
"""
|
||||||
|
Test that the asset filter works.
|
||||||
|
"""
|
||||||
|
self.group.allowed = {"EUR_USD": True, "EUR_GBP": False}
|
||||||
|
self.assertTrue(get_allowed(self.group, "EUR_USD", "buy"))
|
||||||
|
self.assertFalse(get_allowed(self.group, "EUR_GBP", "sell"))
|
||||||
|
|
||||||
|
# Default true
|
||||||
|
self.assertTrue(get_allowed(self.group, "nonexistent", "sell"))
|
|
@ -1,7 +1,14 @@
|
||||||
def get_allowed(strategy, symbol, direction):
|
def get_allowed(group, symbol, direction):
|
||||||
"""
|
"""
|
||||||
Determine whether the trade is allowed according to the Asset Groups
|
Determine whether the trade is allowed according to the Asset Groups
|
||||||
linked to the strategy.
|
linked to the strategy.
|
||||||
"""
|
"""
|
||||||
|
# TODO: figure out what to do with direction
|
||||||
|
|
||||||
# asset_group = strategy.
|
allowed = group.allowed
|
||||||
|
if not isinstance(allowed, dict):
|
||||||
|
return False
|
||||||
|
if symbol not in allowed:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return allowed[symbol]
|
||||||
|
|
|
@ -326,8 +326,15 @@ def execute_strategy(callback, strategy, func):
|
||||||
price_bound = round(price_bound, display_precision)
|
price_bound = round(price_bound, display_precision)
|
||||||
|
|
||||||
# Callback now verified
|
# Callback now verified
|
||||||
|
|
||||||
|
# Check against the asset groups
|
||||||
if func == "entry":
|
if func == "entry":
|
||||||
allowed = assetfilter.get_allowed(strategy, symbol, direction)
|
if strategy.assetgroup is not None:
|
||||||
|
allowed = assetfilter.get_allowed(strategy, symbol, direction)
|
||||||
|
if not allowed:
|
||||||
|
log.debug(f"Asset trading not allowed for {strategy}: {symbol}")
|
||||||
|
return
|
||||||
|
|
||||||
if func == "exit":
|
if func == "exit":
|
||||||
check_exit = crossfilter(account, symbol, direction, func)
|
check_exit = crossfilter(account, symbol, direction, func)
|
||||||
if check_exit is None:
|
if check_exit is None:
|
||||||
|
|
Loading…
Reference in New Issue