diff --git a/core/tests/trading/test_assetfilter.py b/core/tests/trading/test_assetfilter.py index dcbdcad..d610d0b 100644 --- a/core/tests/trading/test_assetfilter.py +++ b/core/tests/trading/test_assetfilter.py @@ -16,16 +16,55 @@ class AssetfilterTestCase(TestCase): description="Test group", ) - def test_get_allowed(self): + def test_get_allowed_negative(self): """ - Test that the asset filter works. + Test that the asset filter works on negative aggregations. """ - self.group.allowed = {"EUR_USD": True, "EUR_GBP": False} - self.assertTrue(assetfilter.get_allowed(self.group, "EUR_USD", "buy")) - self.assertFalse(assetfilter.get_allowed(self.group, "EUR_GBP", "sell")) + # We have negative news about EUR + self.group.allowed = {"EUR": False} + self.group.save() - # Default true - self.assertTrue(assetfilter.get_allowed(self.group, "nonexistent", "sell")) + # This means that: + # * base == EUR: long is not allowed, short is allowed + # * quote == EUR: long is allowed, short is not allowed + # We are betting on it going down, so we short base, and long quote. + + # Test that short on base of EUR is allowed + self.assertTrue(assetfilter.get_allowed(self.group, "EUR", "USD", "short")) + + # Test that long on quote of EUR is allowed + self.assertTrue(assetfilter.get_allowed(self.group, "USD", "EUR", "long")) + + # Test that long on base of EUR is not allowed + self.assertFalse(assetfilter.get_allowed(self.group, "EUR", "USD", "long")) + + # Test that short on quote of EUR is not allowed + self.assertFalse(assetfilter.get_allowed(self.group, "USD", "EUR", "short")) + + def test_get_allowed_positive(self): + """ + Test that the asset filter works on positive aggregations. + """ + # We have positive news about EUR + self.group.allowed = {"EUR": True} + self.group.save() + + # This means that: + # * base == EUR: long is allowed, short is not allowed + # * quote == EUR: long is not allowed, short is allowed + # We are betting on it going up, so we long base, and short quote. + + # Test that long on base of EUR is allowed + self.assertTrue(assetfilter.get_allowed(self.group, "EUR", "USD", "long")) + + # Test that short on quote of EUR is allowed + self.assertTrue(assetfilter.get_allowed(self.group, "USD", "EUR", "short")) + + # Test that short on base of EUR is not allowed + self.assertFalse(assetfilter.get_allowed(self.group, "EUR", "USD", "short")) + + # Test that long on quote of EUR is not allowed + self.assertFalse(assetfilter.get_allowed(self.group, "USD", "EUR", "long")) def test_check_asset_aggregation(self): """ diff --git a/core/trading/assetfilter.py b/core/trading/assetfilter.py index db77667..1cdbee3 100644 --- a/core/trading/assetfilter.py +++ b/core/trading/assetfilter.py @@ -1,22 +1,44 @@ -def get_allowed(group, symbol, direction): +def get_allowed(group, base, quote, direction): """ - Determine whether the trade is allowed according to the Asset Groups - linked to the strategy. + Determine whether the trade is allowed according to the group. + See tests for examples. The logic requires trading knowledge. + :param group: The group to check + :param base: The base currency + :param quote: The quote currency + :param direction: The direction of the trade """ - # TODO: figure out what to do with direction allowed = group.allowed if not isinstance(allowed, dict): return False - if symbol not in allowed: - return True - return allowed[symbol] + # If our base has allowed == False, we can only short it, or long the quote + if base in allowed: + if not allowed[base]: + if direction == "long": + return False + else: + if direction == "short": + return False + + # If our quote has allowed == False, we can only long it, or short the base + if quote in allowed: + if not allowed[quote]: + if direction == "short": + return False + else: + if direction == "long": + return False + + return True def check_asset_aggregation(value, trigger_above, trigger_below): """ - Check if the value is within the bounds of the aggregation + Check if the value is within the bounds of the aggregation. + :param value: The value to check + :param trigger_above: Only trigger if the value is above this + :param trigger_below: Only trigger if the value is below this """ # If both are defined if trigger_above is not None and trigger_below is not None: diff --git a/core/trading/crossfilter.py b/core/trading/crossfilter.py index 8e3006a..12428bd 100644 --- a/core/trading/crossfilter.py +++ b/core/trading/crossfilter.py @@ -63,6 +63,22 @@ def check_conflicting_position( new_symbol: str, trade_side_opposite: str, ): + """ + Determine if we have a conflicting position open. + :param func: Whether we are checking entries, exits or trends + :param position: Position dict + :param open_base: Base currency of the open position + :param open_quote: Quote currency of the open position + :param open_side: Side of the open position + :param open_symbol: Symbol of the open position + :param open_units: Units of the open position + :param new_base: Base currency of the new position + :param new_quote: Quote currency of the new position + :param new_side: Side of the new position + :param new_symbol: Symbol of the new position + :param trade_side_opposite: Opposite side of the trade + :return: dict of action and opposing position, or None + """ if open_base == new_quote or open_quote == new_base: # If we have a long on GBP/AUD, we can only place shorts on XAU/GBP. if open_side != trade_side_opposite: @@ -85,10 +101,10 @@ def crossfilter(account, new_symbol, new_direction, func): Checks open positions for the account, rejecting the trade if there is one with an opposite direction to this one. :param account: Account object - :param symbol: Symbol - :param direction: Direction of the trade - :param func: Whether we are checking entries or exits - :return: dict of action and opposing position, or False + :param new_symbol: Symbol of the new position + :param new_direction: Direction of the new position + :param func: Whether we are checking entries, exits or trends + :return: dict of action and opposing position, False or None """ try: # Only get the data we need diff --git a/core/trading/market.py b/core/trading/market.py index e39d7f7..3b943e8 100644 --- a/core/trading/market.py +++ b/core/trading/market.py @@ -329,7 +329,7 @@ def execute_strategy(callback, strategy, func): # Check against the asset groups if func == "entry" and strategy.assetgroup is not None: - allowed = assetfilter.get_allowed(strategy, symbol, direction) + allowed = assetfilter.get_allowed(strategy, base, quote, direction) if not allowed: log.debug( f"Denied trading {symbol} due to asset filter {strategy.assetgroup}"