from __future__ import annotations

from asgiref.sync import async_to_sync
from django.test import TestCase

from core.commands.base import CommandContext
from core.commands.engine import process_inbound_message
from core.gateway.commands import (
    GatewayCommandContext,
    GatewayCommandRoute,
    dispatch_gateway_command,
)
from core.models import (
    ChatSession,
    CommandChannelBinding,
    CommandProfile,
    CommandSecurityPolicy,
    GatewayCommandEvent,
    Message,
    Person,
    PersonIdentifier,
    User,
    UserXmppOmemoState,
)
from core.security.command_policy import CommandSecurityContext, evaluate_command_policy


class CommandSecurityPolicyTests(TestCase):
    """Exercise ``CommandSecurityPolicy`` enforcement.

    Covers the inbound command engine (``process_inbound_message``), the
    gateway dispatcher (``dispatch_gateway_command``) and direct calls to
    ``evaluate_command_policy``, including ``global.override`` policies.
    """

    def setUp(self):
        # One user reachable over XMPP as policy-user@zm.is, plus a chat
        # session that the inbound test message can be attached to.
        self.user = User.objects.create_user(
            username="policy-user",
            email="policy-user@example.com",
            password="x",
        )
        self.person = Person.objects.create(user=self.user, name="Policy Person")
        self.identifier = PersonIdentifier.objects.create(
            user=self.user,
            person=self.person,
            service="xmpp",
            identifier="policy-user@zm.is",
        )
        self.session = ChatSession.objects.create(
            user=self.user,
            identifier=self.identifier,
        )

    def _dispatch_tasks(self, *, message_meta, handler, outputs):
        """Dispatch ``.tasks list`` from the test XMPP sender via one ``gateway.tasks`` route.

        Args:
            message_meta: Message metadata passed through on the gateway context.
            handler: Async route handler invoked as ``handler(ctx, emit)``.
            outputs: List that receives every emitted value, stringified.

        Returns:
            Whatever ``dispatch_gateway_command`` returns (the "handled" flag).
        """
        return async_to_sync(dispatch_gateway_command)(
            context=GatewayCommandContext(
                user=self.user,
                source_message=None,
                service="xmpp",
                channel_identifier="policy-user@zm.is",
                sender_identifier="policy-user@zm.is/phone",
                message_text=".tasks list",
                message_meta=message_meta,
                payload={},
            ),
            routes=[
                GatewayCommandRoute(
                    name="tasks",
                    scope_key="gateway.tasks",
                    matcher=lambda text: str(text).startswith(".tasks"),
                    handler=handler,
                )
            ],
            emit=lambda value: outputs.append(str(value)),
        )

    @staticmethod
    def _latest_gateway_event():
        """Return the most recently created ``GatewayCommandEvent``, or ``None``."""
        return GatewayCommandEvent.objects.order_by("-created_at").first()

    def test_command_profile_scope_denies_disallowed_service(self):
        """An xmpp ``#bp#`` message is skipped when ``command.bp`` only allows whatsapp."""
        profile = CommandProfile.objects.create(
            user=self.user,
            slug="bp",
            name="Business Plan",
            enabled=True,
            trigger_token="#bp#",
            reply_required=False,
            exact_match_only=True,
        )
        # Bind the profile to the same XMPP channel the message arrives on.
        CommandChannelBinding.objects.create(
            profile=profile,
            direction="ingress",
            service="xmpp",
            channel_identifier="policy-user@zm.is",
            enabled=True,
        )
        CommandSecurityPolicy.objects.create(
            user=self.user,
            scope_key="command.bp",
            enabled=True,
            allowed_services=["whatsapp"],
        )
        msg = Message.objects.create(
            user=self.user,
            session=self.session,
            sender_uuid="",
            text="#bp#",
            ts=1000,
            source_service="xmpp",
            source_chat_id="policy-user@zm.is",
            message_meta={},
        )

        results = async_to_sync(process_inbound_message)(
            CommandContext(
                service="xmpp",
                channel_identifier="policy-user@zm.is",
                message_id=str(msg.id),
                user_id=self.user.id,
                message_text="#bp#",
                payload={},
            )
        )

        self.assertEqual(1, len(results))
        self.assertEqual("skipped", results[0].status)
        error = str(results[0].error)
        self.assertTrue(
            error.startswith("policy_denied:service_not_allowed"),
            f"unexpected error: {error!r}",
        )

    def test_gateway_scope_can_require_trusted_omemo_key(self):
        """A gateway command runs when OMEMO and a trusted fingerprint are required and present.

        The message's ``omemo_client_key`` equals the stored
        ``UserXmppOmemoState.latest_client_key``; presumably that match is what
        the policy treats as trusted (TODO confirm against the policy code).
        """
        CommandSecurityPolicy.objects.create(
            user=self.user,
            scope_key="gateway.tasks",
            enabled=True,
            require_omemo=True,
            require_trusted_omemo_fingerprint=True,
        )
        UserXmppOmemoState.objects.create(
            user=self.user,
            status="detected",
            latest_client_key="sid:abc",
            last_sender_jid="policy-user@zm.is/phone",
            last_target_jid="jews.zm.is",
        )
        outputs: list[str] = []

        async def _tasks_handler(_ctx, emit):
            emit("ok")
            return True

        handled = self._dispatch_tasks(
            message_meta={"xmpp": {"omemo_status": "detected", "omemo_client_key": "sid:abc"}},
            handler=_tasks_handler,
            outputs=outputs,
        )

        self.assertTrue(handled)
        self.assertEqual(["ok"], outputs)
        event = self._latest_gateway_event()
        self.assertIsNotNone(event)
        self.assertEqual("ok", event.status if event else "")

    def test_gateway_scope_blocks_when_omemo_required_but_missing(self):
        """A gateway command is blocked, and its handler skipped, when OMEMO is required but absent."""
        CommandSecurityPolicy.objects.create(
            user=self.user,
            scope_key="gateway.tasks",
            enabled=True,
            require_omemo=True,
        )
        outputs: list[str] = []

        async def _tasks_handler(_ctx, emit):
            # Emits a sentinel so the test can prove the handler never ran.
            emit("unexpected")
            return True

        handled = self._dispatch_tasks(
            message_meta={"xmpp": {"omemo_status": "no_omemo"}},
            handler=_tasks_handler,
            outputs=outputs,
        )

        # The dispatcher still reports the command as handled; it answers with
        # a policy notice instead of running the route.
        self.assertTrue(handled)
        self.assertTrue(outputs)
        self.assertIn("blocked by policy", outputs[0].lower())
        self.assertNotIn("unexpected", outputs)
        event = self._latest_gateway_event()
        self.assertIsNotNone(event)
        self.assertEqual("blocked", event.status if event else "")

    def test_global_scope_override_can_force_scope_disabled(self):
        """``global.override`` with ``scope_enabled: off`` disables an otherwise enabled scope."""
        CommandSecurityPolicy.objects.create(
            user=self.user,
            scope_key="gateway.tasks",
            enabled=True,
        )
        CommandSecurityPolicy.objects.create(
            user=self.user,
            scope_key="global.override",
            settings={"scope_enabled": "off"},
        )

        decision = evaluate_command_policy(
            user=self.user,
            scope_key="gateway.tasks",
            context=CommandSecurityContext(
                service="xmpp",
                channel_identifier="policy-user@zm.is",
                message_meta={},
                payload={},
            ),
        )

        self.assertFalse(decision.allowed)
        self.assertEqual("policy_disabled", decision.code)

    def test_global_scope_override_allowed_services_applies_to_all_scopes(self):
        """``global.override`` allowed_services also restricts a scope with no policy of its own."""
        CommandSecurityPolicy.objects.create(
            user=self.user,
            scope_key="global.override",
            allowed_services=["xmpp"],
        )

        decision = evaluate_command_policy(
            user=self.user,
            scope_key="tasks.commands",
            context=CommandSecurityContext(
                service="whatsapp",
                channel_identifier="12035550123",
                message_meta={},
                payload={},
            ),
        )

        self.assertFalse(decision.allowed)
        self.assertEqual("service_not_allowed", decision.code)