229 lines
7.6 KiB
Python
229 lines
7.6 KiB
Python
from __future__ import annotations
|
|
|
|
from asgiref.sync import async_to_sync
|
|
from django.test import TestCase
|
|
|
|
from core.commands.base import CommandContext
|
|
from core.commands.engine import process_inbound_message
|
|
from core.gateway.commands import (
|
|
GatewayCommandContext,
|
|
GatewayCommandRoute,
|
|
dispatch_gateway_command,
|
|
)
|
|
from core.models import (
|
|
ChatSession,
|
|
CommandChannelBinding,
|
|
CommandProfile,
|
|
CommandSecurityPolicy,
|
|
GatewayCommandEvent,
|
|
Message,
|
|
Person,
|
|
PersonIdentifier,
|
|
User,
|
|
UserXmppOmemoState,
|
|
)
|
|
from core.security.command_policy import CommandSecurityContext, evaluate_command_policy
|
|
|
|
|
|
class CommandSecurityPolicyTests(TestCase):
    """Tests for ``CommandSecurityPolicy`` enforcement.

    Exercises the three entry points that consult policy rows:
    ``process_inbound_message`` (command engine), ``dispatch_gateway_command``
    (gateway routes) and ``evaluate_command_policy`` (direct evaluation),
    including the ``global.override`` scope.
    """

    def setUp(self):
        """Create a user with an XMPP person identifier and a chat session."""
        self.user = User.objects.create_user(
            username="policy-user",
            email="policy-user@example.com",
            password="x",
        )
        self.person = Person.objects.create(user=self.user, name="Policy Person")
        self.identifier = PersonIdentifier.objects.create(
            user=self.user,
            person=self.person,
            service="xmpp",
            identifier="policy-user@zm.is",
        )
        self.session = ChatSession.objects.create(
            user=self.user,
            identifier=self.identifier,
        )

    def _dispatch_tasks_command(self, *, message_meta, handler):
        """Send ``.tasks list`` over XMPP through a single ``gateway.tasks`` route.

        Args:
            message_meta: Metadata attached to the gateway context (e.g. the
                ``xmpp`` OMEMO status block).
            handler: Async route handler, invoked by the dispatcher with the
                context and an ``emit`` callback.

        Returns:
            ``(handled, outputs)``: the dispatcher's return value and every value
            passed to the top-level ``emit`` callback, converted with ``str``.
        """
        outputs: list[str] = []
        handled = async_to_sync(dispatch_gateway_command)(
            context=GatewayCommandContext(
                user=self.user,
                source_message=None,
                service="xmpp",
                channel_identifier="policy-user@zm.is",
                sender_identifier="policy-user@zm.is/phone",
                message_text=".tasks list",
                message_meta=message_meta,
                payload={},
            ),
            routes=[
                GatewayCommandRoute(
                    name="tasks",
                    scope_key="gateway.tasks",
                    matcher=lambda text: str(text).startswith(".tasks"),
                    handler=handler,
                )
            ],
            emit=lambda value: outputs.append(str(value)),
        )
        return handled, outputs

    def _latest_gateway_event_status(self):
        """Return the status of the newest ``GatewayCommandEvent``; fail if none exists."""
        event = GatewayCommandEvent.objects.order_by("-created_at").first()
        self.assertIsNotNone(event)
        # The conditional keeps type checkers happy; assertIsNotNone already failed otherwise.
        return event.status if event else ""

    def test_command_profile_scope_denies_disallowed_service(self):
        """A ``command.bp`` policy allowing only WhatsApp skips the ``#bp#`` command sent over XMPP."""
        profile = CommandProfile.objects.create(
            user=self.user,
            slug="bp",
            name="Business Plan",
            enabled=True,
            trigger_token="#bp#",
            reply_required=False,
            exact_match_only=True,
        )
        CommandChannelBinding.objects.create(
            profile=profile,
            direction="ingress",
            service="xmpp",
            channel_identifier="policy-user@zm.is",
            enabled=True,
        )
        CommandSecurityPolicy.objects.create(
            user=self.user,
            scope_key="command.bp",
            enabled=True,
            allowed_services=["whatsapp"],
        )
        msg = Message.objects.create(
            user=self.user,
            session=self.session,
            sender_uuid="",
            text="#bp#",
            ts=1000,
            source_service="xmpp",
            source_chat_id="policy-user@zm.is",
            message_meta={},
        )

        results = async_to_sync(process_inbound_message)(
            CommandContext(
                service="xmpp",
                channel_identifier="policy-user@zm.is",
                message_id=str(msg.id),
                user_id=self.user.id,
                message_text="#bp#",
                payload={},
            )
        )

        self.assertEqual(1, len(results))
        self.assertEqual("skipped", results[0].status)
        error = str(results[0].error)
        self.assertTrue(
            error.startswith("policy_denied:service_not_allowed"),
            f"unexpected error: {error!r}",
        )

    def test_gateway_scope_can_require_trusted_omemo_key(self):
        """The route runs when the message's OMEMO client key matches the stored state.

        The policy requires OMEMO and a trusted fingerprint. The message carries
        ``omemo_client_key="sid:abc"``, the same value stored as the user's
        ``latest_client_key`` (presumably what marks the key as trusted).
        """
        CommandSecurityPolicy.objects.create(
            user=self.user,
            scope_key="gateway.tasks",
            enabled=True,
            require_omemo=True,
            require_trusted_omemo_fingerprint=True,
        )
        UserXmppOmemoState.objects.create(
            user=self.user,
            status="detected",
            latest_client_key="sid:abc",
            last_sender_jid="policy-user@zm.is/phone",
            last_target_jid="jews.zm.is",
        )

        async def _tasks_handler(_ctx, emit):
            emit("ok")
            return True

        handled, outputs = self._dispatch_tasks_command(
            message_meta={
                "xmpp": {"omemo_status": "detected", "omemo_client_key": "sid:abc"}
            },
            handler=_tasks_handler,
        )

        self.assertTrue(handled)
        self.assertEqual(["ok"], outputs)
        self.assertEqual("ok", self._latest_gateway_event_status())

    def test_gateway_scope_blocks_when_omemo_required_but_missing(self):
        """Without OMEMO the route is blocked: the handler never runs and the event is ``blocked``."""
        CommandSecurityPolicy.objects.create(
            user=self.user,
            scope_key="gateway.tasks",
            enabled=True,
            require_omemo=True,
        )
        # Records handler invocations so the test proves the handler was skipped,
        # rather than only checking that a block message was emitted first.
        handler_calls: list[bool] = []

        async def _tasks_handler(_ctx, emit):
            handler_calls.append(True)
            emit("unexpected")
            return True

        handled, outputs = self._dispatch_tasks_command(
            message_meta={"xmpp": {"omemo_status": "no_omemo"}},
            handler=_tasks_handler,
        )

        self.assertTrue(handled)
        self.assertEqual(
            [], handler_calls, "handler must not run when the policy blocks the route"
        )
        self.assertNotIn("unexpected", outputs)
        self.assertTrue(outputs)
        self.assertIn("blocked by policy", outputs[0].lower())
        self.assertEqual("blocked", self._latest_gateway_event_status())

    def test_global_scope_override_can_force_scope_disabled(self):
        """``global.override`` with ``scope_enabled: "off"`` disables an enabled scope."""
        CommandSecurityPolicy.objects.create(
            user=self.user,
            scope_key="gateway.tasks",
            enabled=True,
        )
        CommandSecurityPolicy.objects.create(
            user=self.user,
            scope_key="global.override",
            settings={"scope_enabled": "off"},
        )

        decision = evaluate_command_policy(
            user=self.user,
            scope_key="gateway.tasks",
            context=CommandSecurityContext(
                service="xmpp",
                channel_identifier="policy-user@zm.is",
                message_meta={},
                payload={},
            ),
        )

        self.assertFalse(decision.allowed)
        self.assertEqual("policy_disabled", decision.code)

    def test_global_scope_override_allowed_services_applies_to_all_scopes(self):
        """``global.override`` ``allowed_services`` applies to a scope that has no policy row of its own."""
        CommandSecurityPolicy.objects.create(
            user=self.user,
            scope_key="global.override",
            allowed_services=["xmpp"],
        )

        decision = evaluate_command_policy(
            user=self.user,
            scope_key="tasks.commands",
            context=CommandSecurityContext(
                service="whatsapp",
                channel_identifier="12035550123",
                message_meta={},
                payload={},
            ),
        )

        self.assertFalse(decision.allowed)
        self.assertEqual("service_not_allowed", decision.code)
|