Increase security and reformat
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from django.test import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.messaging.ai import run_prompt
|
||||
from core.models import AI, AIRunLog, User
|
||||
|
||||
@@ -6,9 +6,9 @@ from django.test import TestCase
|
||||
from core.models import (
|
||||
ChatSession,
|
||||
ContactAvailabilityEvent,
|
||||
Message,
|
||||
Person,
|
||||
PersonIdentifier,
|
||||
Message,
|
||||
User,
|
||||
)
|
||||
from core.presence.inference import now_ms
|
||||
@@ -16,7 +16,9 @@ from core.presence.inference import now_ms
|
||||
|
||||
class BackfillContactAvailabilityCommandTests(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user("backfill-user", "backfill@example.com", "x")
|
||||
self.user = User.objects.create_user(
|
||||
"backfill-user", "backfill@example.com", "x"
|
||||
)
|
||||
self.person = Person.objects.create(user=self.user, name="Backfill Person")
|
||||
self.identifier = PersonIdentifier.objects.create(
|
||||
user=self.user,
|
||||
@@ -24,7 +26,9 @@ class BackfillContactAvailabilityCommandTests(TestCase):
|
||||
service="signal",
|
||||
identifier="+15551234567",
|
||||
)
|
||||
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier)
|
||||
self.session = ChatSession.objects.create(
|
||||
user=self.user, identifier=self.identifier
|
||||
)
|
||||
|
||||
def test_backfill_creates_message_and_read_receipt_availability_events(self):
|
||||
base_ts = now_ms()
|
||||
@@ -58,7 +62,9 @@ class BackfillContactAvailabilityCommandTests(TestCase):
|
||||
)
|
||||
|
||||
events = list(
|
||||
ContactAvailabilityEvent.objects.filter(user=self.user).order_by("ts", "source_kind")
|
||||
ContactAvailabilityEvent.objects.filter(user=self.user).order_by(
|
||||
"ts", "source_kind"
|
||||
)
|
||||
)
|
||||
self.assertEqual(3, len(events))
|
||||
self.assertTrue(any(row.source_kind == "message_in" for row in events))
|
||||
|
||||
@@ -123,7 +123,9 @@ class BPFallbackTests(TransactionTestCase):
|
||||
run = CommandRun.objects.get(trigger_message=trigger, profile=self.profile)
|
||||
self.assertEqual("failed", run.status)
|
||||
self.assertIn("bp_ai_failed", str(run.error))
|
||||
self.assertFalse(BusinessPlanDocument.objects.filter(trigger_message=trigger).exists())
|
||||
self.assertFalse(
|
||||
BusinessPlanDocument.objects.filter(trigger_message=trigger).exists()
|
||||
)
|
||||
|
||||
def test_bp_uses_same_ai_selection_order_as_compose(self):
|
||||
AI.objects.create(
|
||||
|
||||
@@ -35,7 +35,9 @@ class BPSubcommandTests(TransactionTestCase):
|
||||
service="whatsapp",
|
||||
identifier="120363402761690215",
|
||||
)
|
||||
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier)
|
||||
self.session = ChatSession.objects.create(
|
||||
user=self.user, identifier=self.identifier
|
||||
)
|
||||
self.profile = CommandProfile.objects.create(
|
||||
user=self.user,
|
||||
slug="bp",
|
||||
@@ -96,13 +98,19 @@ class BPSubcommandTests(TransactionTestCase):
|
||||
source_service="whatsapp",
|
||||
source_chat_id="120363402761690215",
|
||||
)
|
||||
with patch("core.commands.handlers.bp.ai_runner.run_prompt", new=AsyncMock()) as mocked_ai:
|
||||
result = async_to_sync(BPCommandHandler().execute)(self._ctx(trigger, trigger.text))
|
||||
with patch(
|
||||
"core.commands.handlers.bp.ai_runner.run_prompt", new=AsyncMock()
|
||||
) as mocked_ai:
|
||||
result = async_to_sync(BPCommandHandler().execute)(
|
||||
self._ctx(trigger, trigger.text)
|
||||
)
|
||||
self.assertTrue(result.ok)
|
||||
mocked_ai.assert_not_awaited()
|
||||
doc = BusinessPlanDocument.objects.get(trigger_message=trigger)
|
||||
self.assertEqual("direct body", doc.content_markdown)
|
||||
self.assertEqual("Generated from 1 message.", doc.structured_payload.get("annotation"))
|
||||
self.assertEqual(
|
||||
"Generated from 1 message.", doc.structured_payload.get("annotation")
|
||||
)
|
||||
|
||||
def test_set_reply_only_uses_anchor(self):
|
||||
anchor = Message.objects.create(
|
||||
@@ -124,11 +132,15 @@ class BPSubcommandTests(TransactionTestCase):
|
||||
source_chat_id="120363402761690215",
|
||||
reply_to=anchor,
|
||||
)
|
||||
result = async_to_sync(BPCommandHandler().execute)(self._ctx(trigger, trigger.text))
|
||||
result = async_to_sync(BPCommandHandler().execute)(
|
||||
self._ctx(trigger, trigger.text)
|
||||
)
|
||||
self.assertTrue(result.ok)
|
||||
doc = BusinessPlanDocument.objects.get(trigger_message=trigger)
|
||||
self.assertEqual("anchor body", doc.content_markdown)
|
||||
self.assertEqual("Generated from 1 message.", doc.structured_payload.get("annotation"))
|
||||
self.assertEqual(
|
||||
"Generated from 1 message.", doc.structured_payload.get("annotation")
|
||||
)
|
||||
|
||||
def test_set_reply_plus_addendum_uses_divider(self):
|
||||
anchor = Message.objects.create(
|
||||
@@ -150,7 +162,9 @@ class BPSubcommandTests(TransactionTestCase):
|
||||
source_chat_id="120363402761690215",
|
||||
reply_to=anchor,
|
||||
)
|
||||
result = async_to_sync(BPCommandHandler().execute)(self._ctx(trigger, trigger.text))
|
||||
result = async_to_sync(BPCommandHandler().execute)(
|
||||
self._ctx(trigger, trigger.text)
|
||||
)
|
||||
self.assertTrue(result.ok)
|
||||
doc = BusinessPlanDocument.objects.get(trigger_message=trigger)
|
||||
self.assertIn("base body", doc.content_markdown)
|
||||
@@ -171,7 +185,9 @@ class BPSubcommandTests(TransactionTestCase):
|
||||
source_service="whatsapp",
|
||||
source_chat_id="120363402761690215",
|
||||
)
|
||||
result = async_to_sync(BPCommandHandler().execute)(self._ctx(trigger, trigger.text))
|
||||
result = async_to_sync(BPCommandHandler().execute)(
|
||||
self._ctx(trigger, trigger.text)
|
||||
)
|
||||
self.assertFalse(result.ok)
|
||||
self.assertEqual("failed", result.status)
|
||||
self.assertEqual("bp_set_range_requires_reply_target", result.error)
|
||||
@@ -205,8 +221,12 @@ class BPSubcommandTests(TransactionTestCase):
|
||||
source_chat_id="120363402761690215",
|
||||
reply_to=anchor,
|
||||
)
|
||||
result = async_to_sync(BPCommandHandler().execute)(self._ctx(trigger, trigger.text))
|
||||
result = async_to_sync(BPCommandHandler().execute)(
|
||||
self._ctx(trigger, trigger.text)
|
||||
)
|
||||
self.assertTrue(result.ok)
|
||||
doc = BusinessPlanDocument.objects.get(trigger_message=trigger)
|
||||
self.assertEqual("line 1\n(no text)\n#bp set range#", doc.content_markdown)
|
||||
self.assertEqual("Generated from 3 messages.", doc.structured_payload.get("annotation"))
|
||||
self.assertEqual(
|
||||
"Generated from 3 messages.", doc.structured_payload.get("annotation")
|
||||
)
|
||||
|
||||
@@ -55,7 +55,9 @@ class ClaudeCLITaskProviderTests(SimpleTestCase):
|
||||
@patch("core.tasks.providers.claude_cli.subprocess.run")
|
||||
def test_timeout_maps_to_failed_result(self, run_mock):
|
||||
run_mock.side_effect = TimeoutExpired(cmd=["claude"], timeout=10)
|
||||
result = self.provider.append_update({"command": "claude", "timeout_seconds": 10}, {"task_id": "t1"})
|
||||
result = self.provider.append_update(
|
||||
{"command": "claude", "timeout_seconds": 10}, {"task_id": "t1"}
|
||||
)
|
||||
self.assertFalse(result.ok)
|
||||
self.assertIn("timeout", result.error)
|
||||
|
||||
@@ -70,7 +72,9 @@ class ClaudeCLITaskProviderTests(SimpleTestCase):
|
||||
result = self.provider.append_update({"command": "claude"}, {"task_id": "t1"})
|
||||
self.assertTrue(result.ok)
|
||||
self.assertTrue(bool((result.payload or {}).get("requires_approval")))
|
||||
self.assertEqual("requires_approval", (result.payload or {}).get("parsed_status"))
|
||||
self.assertEqual(
|
||||
"requires_approval", (result.payload or {}).get("parsed_status")
|
||||
)
|
||||
|
||||
@patch("core.tasks.providers.claude_cli.subprocess.run")
|
||||
def test_retries_with_positional_op_when_flag_unsupported(self, run_mock):
|
||||
@@ -99,7 +103,9 @@ class ClaudeCLITaskProviderTests(SimpleTestCase):
|
||||
self.assertEqual(["claude", "task-sync", "create"], second[:3])
|
||||
|
||||
@patch("core.tasks.providers.claude_cli.subprocess.run")
|
||||
def test_falls_back_to_builtin_approval_stub_when_no_task_sync_contract(self, run_mock):
|
||||
def test_falls_back_to_builtin_approval_stub_when_no_task_sync_contract(
|
||||
self, run_mock
|
||||
):
|
||||
run_mock.side_effect = [
|
||||
CompletedProcess(
|
||||
args=[],
|
||||
@@ -124,8 +130,13 @@ class ClaudeCLITaskProviderTests(SimpleTestCase):
|
||||
)
|
||||
self.assertTrue(result.ok)
|
||||
self.assertTrue(bool((result.payload or {}).get("requires_approval")))
|
||||
self.assertEqual("requires_approval", str((result.payload or {}).get("status") or ""))
|
||||
self.assertEqual("builtin_task_sync_stub", str((result.payload or {}).get("fallback_mode") or ""))
|
||||
self.assertEqual(
|
||||
"requires_approval", str((result.payload or {}).get("status") or "")
|
||||
)
|
||||
self.assertEqual(
|
||||
"builtin_task_sync_stub",
|
||||
str((result.payload or {}).get("fallback_mode") or ""),
|
||||
)
|
||||
|
||||
@patch("core.tasks.providers.claude_cli.subprocess.run")
|
||||
def test_builtin_stub_approval_response_returns_ok(self, run_mock):
|
||||
|
||||
@@ -8,10 +8,10 @@ from core.commands.engine import process_inbound_message
|
||||
from core.commands.handlers.claude import parse_claude_command
|
||||
from core.models import (
|
||||
ChatSession,
|
||||
CommandChannelBinding,
|
||||
CommandProfile,
|
||||
CodexPermissionRequest,
|
||||
CodexRun,
|
||||
CommandChannelBinding,
|
||||
CommandProfile,
|
||||
DerivedTask,
|
||||
ExternalSyncEvent,
|
||||
Message,
|
||||
@@ -45,7 +45,9 @@ class ClaudeCommandParserTests(TestCase):
|
||||
|
||||
class ClaudeCommandExecutionTests(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user("claude-cmd-user", "claude-cmd@example.com", "x")
|
||||
self.user = User.objects.create_user(
|
||||
"claude-cmd-user", "claude-cmd@example.com", "x"
|
||||
)
|
||||
self.person = Person.objects.create(user=self.user, name="Claude Cmd")
|
||||
self.identifier = PersonIdentifier.objects.create(
|
||||
user=self.user,
|
||||
@@ -53,7 +55,9 @@ class ClaudeCommandExecutionTests(TestCase):
|
||||
service="web",
|
||||
identifier="web-chan-1",
|
||||
)
|
||||
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier)
|
||||
self.session = ChatSession.objects.create(
|
||||
user=self.user, identifier=self.identifier
|
||||
)
|
||||
self.project = TaskProject.objects.create(user=self.user, name="Project A")
|
||||
self.task = DerivedTask.objects.create(
|
||||
user=self.user,
|
||||
@@ -202,7 +206,9 @@ class ClaudeCommandExecutionTests(TestCase):
|
||||
channel_identifier="approver-chan",
|
||||
enabled=True,
|
||||
)
|
||||
trigger = self._msg("#claude approve cl-ak-123#", source_chat_id="approver-chan")
|
||||
trigger = self._msg(
|
||||
"#claude approve cl-ak-123#", source_chat_id="approver-chan"
|
||||
)
|
||||
results = async_to_sync(process_inbound_message)(
|
||||
CommandContext(
|
||||
service="web",
|
||||
|
||||
@@ -55,7 +55,9 @@ class CodexCLITaskProviderTests(SimpleTestCase):
|
||||
@patch("core.tasks.providers.codex_cli.subprocess.run")
|
||||
def test_timeout_maps_to_failed_result(self, run_mock):
|
||||
run_mock.side_effect = TimeoutExpired(cmd=["codex"], timeout=10)
|
||||
result = self.provider.append_update({"command": "codex", "timeout_seconds": 10}, {"task_id": "t1"})
|
||||
result = self.provider.append_update(
|
||||
{"command": "codex", "timeout_seconds": 10}, {"task_id": "t1"}
|
||||
)
|
||||
self.assertFalse(result.ok)
|
||||
self.assertIn("timeout", result.error)
|
||||
|
||||
@@ -70,7 +72,9 @@ class CodexCLITaskProviderTests(SimpleTestCase):
|
||||
result = self.provider.append_update({"command": "codex"}, {"task_id": "t1"})
|
||||
self.assertTrue(result.ok)
|
||||
self.assertTrue(bool((result.payload or {}).get("requires_approval")))
|
||||
self.assertEqual("requires_approval", (result.payload or {}).get("parsed_status"))
|
||||
self.assertEqual(
|
||||
"requires_approval", (result.payload or {}).get("parsed_status")
|
||||
)
|
||||
|
||||
@patch("core.tasks.providers.codex_cli.subprocess.run")
|
||||
def test_retries_with_positional_op_when_flag_unsupported(self, run_mock):
|
||||
@@ -99,7 +103,9 @@ class CodexCLITaskProviderTests(SimpleTestCase):
|
||||
self.assertEqual(["codex", "task-sync", "create"], second[:3])
|
||||
|
||||
@patch("core.tasks.providers.codex_cli.subprocess.run")
|
||||
def test_falls_back_to_builtin_approval_stub_when_no_task_sync_contract(self, run_mock):
|
||||
def test_falls_back_to_builtin_approval_stub_when_no_task_sync_contract(
|
||||
self, run_mock
|
||||
):
|
||||
run_mock.side_effect = [
|
||||
CompletedProcess(
|
||||
args=[],
|
||||
@@ -124,8 +130,13 @@ class CodexCLITaskProviderTests(SimpleTestCase):
|
||||
)
|
||||
self.assertTrue(result.ok)
|
||||
self.assertTrue(bool((result.payload or {}).get("requires_approval")))
|
||||
self.assertEqual("requires_approval", str((result.payload or {}).get("status") or ""))
|
||||
self.assertEqual("builtin_task_sync_stub", str((result.payload or {}).get("fallback_mode") or ""))
|
||||
self.assertEqual(
|
||||
"requires_approval", str((result.payload or {}).get("status") or "")
|
||||
)
|
||||
self.assertEqual(
|
||||
"builtin_task_sync_stub",
|
||||
str((result.payload or {}).get("fallback_mode") or ""),
|
||||
)
|
||||
|
||||
@patch("core.tasks.providers.codex_cli.subprocess.run")
|
||||
def test_builtin_stub_approval_response_returns_ok(self, run_mock):
|
||||
|
||||
@@ -8,10 +8,10 @@ from core.commands.engine import process_inbound_message
|
||||
from core.commands.handlers.codex import parse_codex_command
|
||||
from core.models import (
|
||||
ChatSession,
|
||||
CommandChannelBinding,
|
||||
CommandProfile,
|
||||
CodexPermissionRequest,
|
||||
CodexRun,
|
||||
CommandChannelBinding,
|
||||
CommandProfile,
|
||||
DerivedTask,
|
||||
ExternalSyncEvent,
|
||||
Message,
|
||||
@@ -41,7 +41,9 @@ class CodexCommandParserTests(TestCase):
|
||||
|
||||
class CodexCommandExecutionTests(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user("codex-cmd-user", "codex-cmd@example.com", "x")
|
||||
self.user = User.objects.create_user(
|
||||
"codex-cmd-user", "codex-cmd@example.com", "x"
|
||||
)
|
||||
self.person = Person.objects.create(user=self.user, name="Codex Cmd")
|
||||
self.identifier = PersonIdentifier.objects.create(
|
||||
user=self.user,
|
||||
@@ -49,7 +51,9 @@ class CodexCommandExecutionTests(TestCase):
|
||||
service="web",
|
||||
identifier="web-chan-1",
|
||||
)
|
||||
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier)
|
||||
self.session = ChatSession.objects.create(
|
||||
user=self.user, identifier=self.identifier
|
||||
)
|
||||
self.project = TaskProject.objects.create(user=self.user, name="Project A")
|
||||
self.task = DerivedTask.objects.create(
|
||||
user=self.user,
|
||||
@@ -126,7 +130,10 @@ class CodexCommandExecutionTests(TestCase):
|
||||
self.assertEqual("waiting_approval", run.status)
|
||||
event = ExternalSyncEvent.objects.order_by("-created_at").first()
|
||||
self.assertEqual("waiting_approval", event.status)
|
||||
self.assertEqual("default", str((event.payload or {}).get("provider_payload", {}).get("mode") or ""))
|
||||
self.assertEqual(
|
||||
"default",
|
||||
str((event.payload or {}).get("provider_payload", {}).get("mode") or ""),
|
||||
)
|
||||
self.assertTrue(
|
||||
CodexPermissionRequest.objects.filter(
|
||||
user=self.user,
|
||||
@@ -167,7 +174,10 @@ class CodexCommandExecutionTests(TestCase):
|
||||
source_service="web",
|
||||
source_channel="web-chan-1",
|
||||
status="waiting_approval",
|
||||
request_payload={"action": "append_update", "provider_payload": {"task_id": str(self.task.id)}},
|
||||
request_payload={
|
||||
"action": "append_update",
|
||||
"provider_payload": {"task_id": str(self.task.id)},
|
||||
},
|
||||
result_payload={},
|
||||
)
|
||||
req = CodexPermissionRequest.objects.create(
|
||||
@@ -207,7 +217,9 @@ class CodexCommandExecutionTests(TestCase):
|
||||
self.assertEqual("approved_waiting_resume", run.status)
|
||||
self.assertEqual("ok", waiting_event.status)
|
||||
self.assertTrue(
|
||||
ExternalSyncEvent.objects.filter(idempotency_key="codex_approval:ak-123:approved", status="pending").exists()
|
||||
ExternalSyncEvent.objects.filter(
|
||||
idempotency_key="codex_approval:ak-123:approved", status="pending"
|
||||
).exists()
|
||||
)
|
||||
|
||||
def test_approve_pre_submit_request_queues_original_action(self):
|
||||
@@ -226,7 +238,10 @@ class CodexCommandExecutionTests(TestCase):
|
||||
source_service="web",
|
||||
source_channel="web-chan-1",
|
||||
status="waiting_approval",
|
||||
request_payload={"action": "append_update", "provider_payload": {"task_id": str(self.task.id)}},
|
||||
request_payload={
|
||||
"action": "append_update",
|
||||
"provider_payload": {"task_id": str(self.task.id)},
|
||||
},
|
||||
result_payload={},
|
||||
)
|
||||
CodexPermissionRequest.objects.create(
|
||||
@@ -264,7 +279,11 @@ class CodexCommandExecutionTests(TestCase):
|
||||
)
|
||||
self.assertEqual(1, len(results))
|
||||
self.assertTrue(results[0].ok)
|
||||
resume = ExternalSyncEvent.objects.filter(idempotency_key="codex_cmd:resume:1").first()
|
||||
resume = ExternalSyncEvent.objects.filter(
|
||||
idempotency_key="codex_cmd:resume:1"
|
||||
).first()
|
||||
self.assertIsNotNone(resume)
|
||||
self.assertEqual("pending", resume.status)
|
||||
self.assertEqual("append_update", str((resume.payload or {}).get("action") or ""))
|
||||
self.assertEqual(
|
||||
"append_update", str((resume.payload or {}).get("action") or "")
|
||||
)
|
||||
|
||||
@@ -5,13 +5,22 @@ from unittest.mock import patch
|
||||
from django.test import TestCase
|
||||
|
||||
from core.management.commands.codex_worker import Command as CodexWorkerCommand
|
||||
from core.models import CodexPermissionRequest, CodexRun, ExternalSyncEvent, TaskProject, TaskProviderConfig, User
|
||||
from core.models import (
|
||||
CodexPermissionRequest,
|
||||
CodexRun,
|
||||
ExternalSyncEvent,
|
||||
TaskProject,
|
||||
TaskProviderConfig,
|
||||
User,
|
||||
)
|
||||
from core.tasks.providers.base import ProviderResult
|
||||
|
||||
|
||||
class CodexWorkerPhase1Tests(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user("codex-worker-user", "codex-worker@example.com", "x")
|
||||
self.user = User.objects.create_user(
|
||||
"codex-worker-user", "codex-worker@example.com", "x"
|
||||
)
|
||||
self.project = TaskProject.objects.create(user=self.user, name="Worker Project")
|
||||
self.cfg = TaskProviderConfig.objects.create(
|
||||
user=self.user,
|
||||
@@ -57,7 +66,9 @@ class CodexWorkerPhase1Tests(TestCase):
|
||||
run_in_worker = True
|
||||
|
||||
def append_update(self, config, payload):
|
||||
return ProviderResult(ok=True, payload={"status": "ok", "summary": "done"})
|
||||
return ProviderResult(
|
||||
ok=True, payload={"status": "ok", "summary": "done"}
|
||||
)
|
||||
|
||||
create_task = mark_complete = link_task = append_update
|
||||
|
||||
@@ -71,7 +82,9 @@ class CodexWorkerPhase1Tests(TestCase):
|
||||
self.assertEqual("done", str(run.result_payload.get("summary") or ""))
|
||||
|
||||
@patch("core.management.commands.codex_worker.get_provider")
|
||||
def test_requires_approval_moves_to_waiting_and_creates_permission_request(self, get_provider_mock):
|
||||
def test_requires_approval_moves_to_waiting_and_creates_permission_request(
|
||||
self, get_provider_mock
|
||||
):
|
||||
run = CodexRun.objects.create(
|
||||
user=self.user,
|
||||
project=self.project,
|
||||
@@ -128,7 +141,10 @@ class CodexWorkerPhase1Tests(TestCase):
|
||||
user=self.user,
|
||||
provider="codex_cli",
|
||||
status="waiting_approval",
|
||||
payload={"action": "append_update", "provider_payload": {"mode": "default"}},
|
||||
payload={
|
||||
"action": "append_update",
|
||||
"provider_payload": {"mode": "default"},
|
||||
},
|
||||
error="",
|
||||
)
|
||||
run = CodexRun.objects.create(
|
||||
@@ -169,7 +185,9 @@ class CodexWorkerPhase1Tests(TestCase):
|
||||
run_in_worker = True
|
||||
|
||||
def append_update(self, config, payload):
|
||||
return ProviderResult(ok=True, payload={"status": "ok", "summary": "resumed"})
|
||||
return ProviderResult(
|
||||
ok=True, payload={"status": "ok", "summary": "resumed"}
|
||||
)
|
||||
|
||||
create_task = mark_complete = link_task = append_update
|
||||
|
||||
|
||||
@@ -89,7 +89,9 @@ class CommandSecurityPolicyTests(TestCase):
|
||||
)
|
||||
self.assertEqual(1, len(results))
|
||||
self.assertEqual("skipped", results[0].status)
|
||||
self.assertTrue(str(results[0].error).startswith("policy_denied:service_not_allowed"))
|
||||
self.assertTrue(
|
||||
str(results[0].error).startswith("policy_denied:service_not_allowed")
|
||||
)
|
||||
|
||||
def test_gateway_scope_can_require_trusted_omemo_key(self):
|
||||
CommandSecurityPolicy.objects.create(
|
||||
@@ -120,7 +122,9 @@ class CommandSecurityPolicyTests(TestCase):
|
||||
channel_identifier="policy-user@zm.is",
|
||||
sender_identifier="policy-user@zm.is/phone",
|
||||
message_text=".tasks list",
|
||||
message_meta={"xmpp": {"omemo_status": "detected", "omemo_client_key": "sid:abc"}},
|
||||
message_meta={
|
||||
"xmpp": {"omemo_status": "detected", "omemo_client_key": "sid:abc"}
|
||||
},
|
||||
payload={},
|
||||
),
|
||||
routes=[
|
||||
|
||||
@@ -9,8 +9,8 @@ from core.commands.base import CommandContext
|
||||
from core.commands.handlers.bp import BPCommandHandler
|
||||
from core.commands.policies import ensure_variant_policies_for_profile
|
||||
from core.models import (
|
||||
BusinessPlanDocument,
|
||||
AI,
|
||||
BusinessPlanDocument,
|
||||
ChatSession,
|
||||
CommandAction,
|
||||
CommandChannelBinding,
|
||||
@@ -37,7 +37,9 @@ class CommandVariantPolicyTests(TransactionTestCase):
|
||||
service="whatsapp",
|
||||
identifier="120363402761690215",
|
||||
)
|
||||
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier)
|
||||
self.session = ChatSession.objects.create(
|
||||
user=self.user, identifier=self.identifier
|
||||
)
|
||||
self.profile = CommandProfile.objects.create(
|
||||
user=self.user,
|
||||
slug="bp",
|
||||
@@ -109,7 +111,9 @@ class CommandVariantPolicyTests(TransactionTestCase):
|
||||
|
||||
def test_bp_primary_can_run_in_verbatim_mode_without_ai(self):
|
||||
ensure_variant_policies_for_profile(self.profile)
|
||||
policy = CommandVariantPolicy.objects.get(profile=self.profile, variant_key="bp")
|
||||
policy = CommandVariantPolicy.objects.get(
|
||||
profile=self.profile, variant_key="bp"
|
||||
)
|
||||
policy.generation_mode = "verbatim"
|
||||
policy.send_plan_to_egress = False
|
||||
policy.send_status_to_source = False
|
||||
@@ -143,7 +147,9 @@ class CommandVariantPolicyTests(TransactionTestCase):
|
||||
|
||||
def test_bp_set_ai_mode_ignores_template(self):
|
||||
ensure_variant_policies_for_profile(self.profile)
|
||||
policy = CommandVariantPolicy.objects.get(profile=self.profile, variant_key="bp_set")
|
||||
policy = CommandVariantPolicy.objects.get(
|
||||
profile=self.profile, variant_key="bp_set"
|
||||
)
|
||||
policy.generation_mode = "ai"
|
||||
policy.send_plan_to_egress = False
|
||||
policy.send_status_to_source = False
|
||||
@@ -222,4 +228,6 @@ class CommandVariantPolicyTests(TransactionTestCase):
|
||||
self.assertTrue(result.ok)
|
||||
source_status.assert_awaited()
|
||||
self.assertEqual(1, binding_send.await_count)
|
||||
self.assertFalse(BusinessPlanDocument.objects.filter(trigger_message=trigger).exists())
|
||||
self.assertFalse(
|
||||
BusinessPlanDocument.objects.filter(trigger_message=trigger).exists()
|
||||
)
|
||||
|
||||
@@ -13,7 +13,9 @@ class ComposeReactTests(TestCase):
|
||||
self.user = User.objects.create_user("compose-react", "react@example.com", "pw")
|
||||
self.client.force_login(self.user)
|
||||
|
||||
def _build_message(self, *, service: str, identifier: str, source_message_id: str = ""):
|
||||
def _build_message(
|
||||
self, *, service: str, identifier: str, source_message_id: str = ""
|
||||
):
|
||||
person = Person.objects.create(user=self.user, name=f"{service} person")
|
||||
person_identifier = PersonIdentifier.objects.create(
|
||||
user=self.user,
|
||||
|
||||
@@ -6,6 +6,7 @@ Signal coverage is in test_signal_reply_send.py. This file fills the gaps
|
||||
for WhatsApp and XMPP, and verifies the shared reply_sync infrastructure
|
||||
works correctly for both services.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
@@ -25,11 +26,11 @@ from core.messaging import history, reply_sync
|
||||
from core.models import ChatSession, Message, Person, PersonIdentifier, User
|
||||
from core.presence.inference import now_ms
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _fake_stanza(xml_text: str) -> SimpleNamespace:
|
||||
"""Minimal stanza-like object with an .xml attribute."""
|
||||
return SimpleNamespace(xml=ET.fromstring(xml_text))
|
||||
@@ -39,6 +40,7 @@ def _fake_stanza(xml_text: str) -> SimpleNamespace:
|
||||
# WhatsApp — reply extraction (pure, no DB)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class WhatsAppReplyExtractionTests(SimpleTestCase):
|
||||
def test_extract_reply_ref_from_contextinfo_stanza_id(self):
|
||||
payload = {
|
||||
@@ -87,6 +89,7 @@ class WhatsAppReplyExtractionTests(SimpleTestCase):
|
||||
# WhatsApp — reply resolution (requires DB)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class WhatsAppReplyResolutionTests(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user(
|
||||
@@ -178,7 +181,9 @@ class WhatsAppReplyResolutionTests(TestCase):
|
||||
)
|
||||
self.anchor.refresh_from_db()
|
||||
reactions = list((self.anchor.receipt_payload or {}).get("reactions") or [])
|
||||
removed = [r for r in reactions if r.get("emoji") == "👍" and not r.get("removed")]
|
||||
removed = [
|
||||
r for r in reactions if r.get("emoji") == "👍" and not r.get("removed")
|
||||
]
|
||||
self.assertEqual(0, len(removed))
|
||||
|
||||
|
||||
@@ -186,6 +191,7 @@ class WhatsAppReplyResolutionTests(TestCase):
|
||||
# WhatsApp — outbound reply metadata
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class WhatsAppOutboundReplyTests(TestCase):
|
||||
def test_transport_passes_reply_metadata_to_whatsapp_api(self):
|
||||
mock_client = MagicMock()
|
||||
@@ -222,6 +228,7 @@ class WhatsAppOutboundReplyTests(TestCase):
|
||||
# XMPP — reaction extraction (pure, no DB)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class XMPPReactionExtractionTests(SimpleTestCase):
|
||||
def test_extract_xep_0444_reaction(self):
|
||||
stanza = _fake_stanza(
|
||||
@@ -276,6 +283,7 @@ class XMPPReactionExtractionTests(SimpleTestCase):
|
||||
# XMPP — reply extraction (pure, no DB)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class XMPPReplyExtractionTests(SimpleTestCase):
|
||||
def test_extract_reply_target_id_from_xep_0461_stanza(self):
|
||||
stanza = _fake_stanza(
|
||||
@@ -304,7 +312,9 @@ class XMPPReplyExtractionTests(SimpleTestCase):
|
||||
self.assertEqual("user@zm.is/mobile", ref.get("reply_source_chat_id"))
|
||||
|
||||
def test_extract_reply_ref_returns_empty_for_missing_id(self):
|
||||
ref = reply_sync.extract_reply_ref("xmpp", {"reply_source_chat_id": "user@zm.is"})
|
||||
ref = reply_sync.extract_reply_ref(
|
||||
"xmpp", {"reply_source_chat_id": "user@zm.is"}
|
||||
)
|
||||
self.assertEqual({}, ref)
|
||||
|
||||
|
||||
@@ -312,6 +322,7 @@ class XMPPReplyExtractionTests(SimpleTestCase):
|
||||
# XMPP — reply resolution (requires DB)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class XMPPReplyResolutionTests(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from io import StringIO
|
||||
import time
|
||||
from io import StringIO
|
||||
|
||||
from django.core.management import call_command
|
||||
from django.test import TestCase, override_settings
|
||||
@@ -24,7 +24,9 @@ class EventProjectionShadowTests(TestCase):
|
||||
service="signal",
|
||||
identifier="+15555550333",
|
||||
)
|
||||
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier)
|
||||
self.session = ChatSession.objects.create(
|
||||
user=self.user, identifier=self.identifier
|
||||
)
|
||||
|
||||
def test_shadow_compare_has_zero_mismatch_when_projection_matches(self):
|
||||
message = Message.objects.create(
|
||||
|
||||
@@ -7,13 +7,13 @@ from django.test import TestCase, override_settings
|
||||
from core.mcp.tools import execute_tool, tool_specs
|
||||
from core.models import (
|
||||
AIRequest,
|
||||
DerivedTask,
|
||||
DerivedTaskEvent,
|
||||
MCPToolAuditLog,
|
||||
MemoryItem,
|
||||
TaskProject,
|
||||
User,
|
||||
WorkspaceConversation,
|
||||
DerivedTask,
|
||||
DerivedTaskEvent,
|
||||
)
|
||||
|
||||
|
||||
@@ -80,9 +80,13 @@ class MCPToolTests(TestCase):
|
||||
first_hit = (memory_payload.get("hits") or [{}])[0]
|
||||
self.assertEqual(str(self.memory.id), str(first_hit.get("memory_id")))
|
||||
|
||||
list_payload = execute_tool("tasks.list", {"user_id": self.user.id, "limit": 10})
|
||||
list_payload = execute_tool(
|
||||
"tasks.list", {"user_id": self.user.id, "limit": 10}
|
||||
)
|
||||
self.assertEqual(1, int(list_payload.get("count") or 0))
|
||||
self.assertEqual(str(self.task.id), str((list_payload.get("items") or [{}])[0].get("id")))
|
||||
self.assertEqual(
|
||||
str(self.task.id), str((list_payload.get("items") or [{}])[0].get("id"))
|
||||
)
|
||||
|
||||
search_payload = execute_tool(
|
||||
"tasks.search",
|
||||
@@ -90,9 +94,13 @@ class MCPToolTests(TestCase):
|
||||
)
|
||||
self.assertEqual(1, int(search_payload.get("count") or 0))
|
||||
|
||||
events_payload = execute_tool("tasks.events", {"task_id": str(self.task.id), "limit": 5})
|
||||
events_payload = execute_tool(
|
||||
"tasks.events", {"task_id": str(self.task.id), "limit": 5}
|
||||
)
|
||||
self.assertEqual(1, int(events_payload.get("count") or 0))
|
||||
self.assertEqual("created", str((events_payload.get("items") or [{}])[0].get("event_type")))
|
||||
self.assertEqual(
|
||||
"created", str((events_payload.get("items") or [{}])[0].get("event_type"))
|
||||
)
|
||||
|
||||
def test_memory_proposal_review_flow(self):
|
||||
propose_payload = execute_tool(
|
||||
@@ -182,7 +190,9 @@ class MCPToolTests(TestCase):
|
||||
"note": "Implemented wiki tooling.",
|
||||
},
|
||||
)
|
||||
self.assertEqual("progress", str((note_payload.get("event") or {}).get("event_type")))
|
||||
self.assertEqual(
|
||||
"progress", str((note_payload.get("event") or {}).get("event_type"))
|
||||
)
|
||||
|
||||
artifact_payload = execute_tool(
|
||||
"tasks.link_artifact",
|
||||
|
||||
@@ -7,7 +7,13 @@ from django.core.management import call_command
|
||||
from django.test import TestCase
|
||||
from django.utils import timezone
|
||||
|
||||
from core.models import MemoryChangeRequest, MemoryItem, MessageEvent, User, WorkspaceConversation
|
||||
from core.models import (
|
||||
MemoryChangeRequest,
|
||||
MemoryItem,
|
||||
MessageEvent,
|
||||
User,
|
||||
WorkspaceConversation,
|
||||
)
|
||||
|
||||
|
||||
class MemoryPipelineCommandTests(TestCase):
|
||||
@@ -46,7 +52,9 @@ class MemoryPipelineCommandTests(TestCase):
|
||||
self.assertIn("memory-suggest-from-messages", rendered)
|
||||
self.assertGreaterEqual(MemoryItem.objects.filter(user=self.user).count(), 1)
|
||||
self.assertGreaterEqual(
|
||||
MemoryChangeRequest.objects.filter(user=self.user, status="pending").count(),
|
||||
MemoryChangeRequest.objects.filter(
|
||||
user=self.user, status="pending"
|
||||
).count(),
|
||||
1,
|
||||
)
|
||||
|
||||
|
||||
@@ -6,10 +6,9 @@ from django.test import TestCase
|
||||
from core.commands.base import CommandContext
|
||||
from core.commands.engine import _matches_trigger, process_inbound_message
|
||||
from core.messaging.reply_sync import extract_reply_ref, resolve_reply_target
|
||||
from core.views.compose import _command_options_for_channel
|
||||
from core.models import (
|
||||
ChatTaskSource,
|
||||
ChatSession,
|
||||
ChatTaskSource,
|
||||
CommandAction,
|
||||
CommandChannelBinding,
|
||||
CommandProfile,
|
||||
@@ -19,6 +18,7 @@ from core.models import (
|
||||
PersonIdentifier,
|
||||
User,
|
||||
)
|
||||
from core.views.compose import _command_options_for_channel
|
||||
|
||||
|
||||
class Phase1ReplyResolutionTests(TestCase):
|
||||
@@ -402,7 +402,9 @@ class Phase1CommandEngineTests(TestCase):
|
||||
if profile is None:
|
||||
return
|
||||
self.assertEqual(3, CommandAction.objects.filter(profile=profile).count())
|
||||
self.assertEqual(3, CommandVariantPolicy.objects.filter(profile=profile).count())
|
||||
self.assertEqual(
|
||||
3, CommandVariantPolicy.objects.filter(profile=profile).count()
|
||||
)
|
||||
self.assertEqual(
|
||||
2,
|
||||
CommandChannelBinding.objects.filter(
|
||||
@@ -436,7 +438,9 @@ class Phase1CommandEngineTests(TestCase):
|
||||
self.assertEqual(1, len(second_results))
|
||||
self.assertEqual("reply_required", second_results[0].error)
|
||||
self.assertEqual(3, CommandAction.objects.filter(profile=profile).count())
|
||||
self.assertEqual(3, CommandVariantPolicy.objects.filter(profile=profile).count())
|
||||
self.assertEqual(
|
||||
3, CommandVariantPolicy.objects.filter(profile=profile).count()
|
||||
)
|
||||
self.assertEqual(
|
||||
2,
|
||||
CommandChannelBinding.objects.filter(
|
||||
|
||||
@@ -21,7 +21,9 @@ from core.presence.inference import now_ms
|
||||
|
||||
class PresenceEngineTests(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user("presence-user", "presence@example.com", "x")
|
||||
self.user = User.objects.create_user(
|
||||
"presence-user", "presence@example.com", "x"
|
||||
)
|
||||
self.person = Person.objects.create(user=self.user, name="Presence Person")
|
||||
self.identifier = PersonIdentifier.objects.create(
|
||||
user=self.user,
|
||||
@@ -57,7 +59,9 @@ class PresenceEngineTests(TestCase):
|
||||
)
|
||||
)
|
||||
self.assertIsNotNone(event)
|
||||
self.assertEqual(1, ContactAvailabilityEvent.objects.filter(user=self.user).count())
|
||||
self.assertEqual(
|
||||
1, ContactAvailabilityEvent.objects.filter(user=self.user).count()
|
||||
)
|
||||
self.assertEqual("available", event.availability_state)
|
||||
|
||||
def test_inactivity_transitions_to_fading(self):
|
||||
@@ -106,7 +110,9 @@ class PresenceEngineTests(TestCase):
|
||||
at_ts=base_ts + 60_000,
|
||||
)
|
||||
self.assertIsNone(fade_event)
|
||||
self.assertEqual(1, ContactAvailabilityEvent.objects.filter(user=self.user).count())
|
||||
self.assertEqual(
|
||||
1, ContactAvailabilityEvent.objects.filter(user=self.user).count()
|
||||
)
|
||||
|
||||
def test_adjacent_same_state_events_extend_single_span(self):
|
||||
ts0 = now_ms()
|
||||
@@ -134,7 +140,9 @@ class PresenceEngineTests(TestCase):
|
||||
ts=ts0 + 5_000,
|
||||
)
|
||||
)
|
||||
spans = list(ContactAvailabilitySpan.objects.filter(user=self.user).order_by("start_ts"))
|
||||
spans = list(
|
||||
ContactAvailabilitySpan.objects.filter(user=self.user).order_by("start_ts")
|
||||
)
|
||||
self.assertEqual(1, len(spans))
|
||||
self.assertEqual(ts0, spans[0].start_ts)
|
||||
self.assertEqual(ts0 + 5_000, spans[0].end_ts)
|
||||
|
||||
@@ -62,12 +62,21 @@ class ReactionNormalizationTests(TestCase):
|
||||
self.assertEqual(str(exact_message.id), str(updated.id))
|
||||
exact_message.refresh_from_db()
|
||||
near_message.refresh_from_db()
|
||||
self.assertEqual(1, len((exact_message.receipt_payload or {}).get("reactions") or []))
|
||||
self.assertEqual(
|
||||
1, len((exact_message.receipt_payload or {}).get("reactions") or [])
|
||||
)
|
||||
self.assertEqual(
|
||||
"exact_source_message_id_ts",
|
||||
str((exact_message.receipt_payload or {}).get("reaction_last_match_strategy") or ""),
|
||||
str(
|
||||
(exact_message.receipt_payload or {}).get(
|
||||
"reaction_last_match_strategy"
|
||||
)
|
||||
or ""
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
0, len((near_message.receipt_payload or {}).get("reactions") or [])
|
||||
)
|
||||
self.assertEqual(0, len((near_message.receipt_payload or {}).get("reactions") or []))
|
||||
|
||||
def test_remove_without_emoji_is_audited_not_active(self):
|
||||
message = Message.objects.create(
|
||||
|
||||
@@ -28,7 +28,9 @@ class ReconcileWorkspaceMetricHistoryCommandTests(TestCase):
|
||||
service="whatsapp",
|
||||
identifier="15551230000@s.whatsapp.net",
|
||||
)
|
||||
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier)
|
||||
self.session = ChatSession.objects.create(
|
||||
user=self.user, identifier=self.identifier
|
||||
)
|
||||
base_ts = 1_700_000_000_000
|
||||
for idx in range(10):
|
||||
inbound = idx % 2 == 0
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.urls import reverse
|
||||
from django.test import TestCase
|
||||
from django.urls import reverse
|
||||
|
||||
from core.models import User
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from django.conf import settings
|
||||
@@ -175,11 +174,15 @@ class SignalInboundReplyLinkTests(TransactionTestCase):
|
||||
}
|
||||
async_to_sync(client._process_raw_inbound_event)(json.dumps(payload))
|
||||
|
||||
created = Message.objects.filter(
|
||||
user=self.user,
|
||||
session=self.session,
|
||||
text="reply inbound s3",
|
||||
).order_by("-ts").first()
|
||||
created = (
|
||||
Message.objects.filter(
|
||||
user=self.user,
|
||||
session=self.session,
|
||||
text="reply inbound s3",
|
||||
)
|
||||
.order_by("-ts")
|
||||
.first()
|
||||
)
|
||||
self.assertIsNotNone(created)
|
||||
self.assertEqual(self.anchor.id, created.reply_to_id)
|
||||
self.assertEqual("1772545458187", created.reply_source_message_id)
|
||||
@@ -222,7 +225,9 @@ class SignalInboundReplyLinkTests(TransactionTestCase):
|
||||
"Expected Signal heart reaction to be applied to anchor receipt payload.",
|
||||
)
|
||||
|
||||
def test_process_raw_inbound_event_applies_sync_reaction_using_destination_fallback(self):
|
||||
def test_process_raw_inbound_event_applies_sync_reaction_using_destination_fallback(
|
||||
self,
|
||||
):
|
||||
fake_ur = Mock()
|
||||
fake_ur.message_received = AsyncMock(return_value=None)
|
||||
fake_ur.xmpp = Mock()
|
||||
@@ -253,7 +258,7 @@ class SignalInboundReplyLinkTests(TransactionTestCase):
|
||||
"emoji": "🔥",
|
||||
"targetSentTimestamp": 1772545458187,
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -352,7 +357,9 @@ class SignalRuntimeCommandWritebackTests(TestCase):
|
||||
service="signal",
|
||||
identifier="+15550003000",
|
||||
)
|
||||
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier)
|
||||
self.session = ChatSession.objects.create(
|
||||
user=self.user, identifier=self.identifier
|
||||
)
|
||||
self.message = Message.objects.create(
|
||||
user=self.user,
|
||||
session=self.session,
|
||||
|
||||
@@ -5,7 +5,6 @@ from unittest.mock import AsyncMock, patch
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from django.test import TestCase, override_settings
|
||||
from django.utils import timezone
|
||||
|
||||
from core.models import (
|
||||
ChatSession,
|
||||
@@ -88,7 +87,9 @@ class TaskEnginePlan09Tests(TestCase):
|
||||
service="signal",
|
||||
identifier="+15559001234",
|
||||
)
|
||||
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier)
|
||||
self.session = ChatSession.objects.create(
|
||||
user=self.user, identifier=self.identifier
|
||||
)
|
||||
self.project = TaskProject.objects.create(user=self.user, name="Plan09 Project")
|
||||
ChatTaskSource.objects.create(
|
||||
user=self.user,
|
||||
@@ -133,7 +134,9 @@ class TaskEnginePlan09Tests(TestCase):
|
||||
async_to_sync(process_inbound_task_intelligence)(seed)
|
||||
cmd = self._msg(".task list", ts=1002)
|
||||
async_to_sync(process_inbound_task_intelligence)(cmd)
|
||||
payloads = [str(call.kwargs.get("text") or "") for call in mocked_send.await_args_list]
|
||||
payloads = [
|
||||
str(call.kwargs.get("text") or "") for call in mocked_send.await_args_list
|
||||
]
|
||||
self.assertTrue(any("open tasks" in row.lower() for row in payloads))
|
||||
|
||||
@patch("core.tasks.engine.send_message_raw", new_callable=AsyncMock)
|
||||
@@ -143,7 +146,9 @@ class TaskEnginePlan09Tests(TestCase):
|
||||
task = DerivedTask.objects.get(origin_message=seed)
|
||||
cmd = self._msg(f".task show #{task.reference_code}", ts=1004)
|
||||
async_to_sync(process_inbound_task_intelligence)(cmd)
|
||||
payloads = [str(call.kwargs.get("text") or "") for call in mocked_send.await_args_list]
|
||||
payloads = [
|
||||
str(call.kwargs.get("text") or "") for call in mocked_send.await_args_list
|
||||
]
|
||||
self.assertTrue(any("deploy new version" in row.lower() for row in payloads))
|
||||
self.assertTrue(any(str(task.reference_code) in row for row in payloads))
|
||||
|
||||
@@ -157,9 +162,13 @@ class TaskEnginePlan09Tests(TestCase):
|
||||
task.refresh_from_db()
|
||||
self.assertEqual("completed", task.status_snapshot)
|
||||
self.assertTrue(
|
||||
DerivedTaskEvent.objects.filter(task=task, event_type="completion_marked").exists()
|
||||
DerivedTaskEvent.objects.filter(
|
||||
task=task, event_type="completion_marked"
|
||||
).exists()
|
||||
)
|
||||
payloads = [str(call.kwargs.get("text") or "") for call in mocked_send.await_args_list]
|
||||
payloads = [
|
||||
str(call.kwargs.get("text") or "") for call in mocked_send.await_args_list
|
||||
]
|
||||
self.assertTrue(any("completed" in row.lower() for row in payloads))
|
||||
|
||||
def test_dot_task_complete_creates_audit_event(self):
|
||||
@@ -169,7 +178,9 @@ class TaskEnginePlan09Tests(TestCase):
|
||||
with patch("core.tasks.engine.send_message_raw", new_callable=AsyncMock):
|
||||
cmd = self._msg(f".task complete #{task.reference_code}", ts=1008)
|
||||
async_to_sync(process_inbound_task_intelligence)(cmd)
|
||||
event = DerivedTaskEvent.objects.filter(task=task, event_type="completion_marked").first()
|
||||
event = DerivedTaskEvent.objects.filter(
|
||||
task=task, event_type="completion_marked"
|
||||
).first()
|
||||
self.assertIsNotNone(event)
|
||||
self.assertIn("command", str(event.payload or {}).lower())
|
||||
|
||||
@@ -185,7 +196,9 @@ class TaskEngineMemoryContextTests(TestCase):
|
||||
service="whatsapp",
|
||||
identifier="447700900001@s.whatsapp.net",
|
||||
)
|
||||
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier)
|
||||
self.session = ChatSession.objects.create(
|
||||
user=self.user, identifier=self.identifier
|
||||
)
|
||||
self.project = TaskProject.objects.create(user=self.user, name="Mem Project")
|
||||
ChatTaskSource.objects.create(
|
||||
user=self.user,
|
||||
@@ -218,8 +231,16 @@ class TaskEngineMemoryContextTests(TestCase):
|
||||
from core.models import CodexRun
|
||||
|
||||
m = self._msg("task: fix authentication bug", ts=2001)
|
||||
fake_memory = [{"id": "mem-1", "memory_kind": "fact", "content": {"text": "prefers short summaries"}}]
|
||||
with patch("core.tasks.engine.retrieve_memories_for_prompt", return_value=fake_memory):
|
||||
fake_memory = [
|
||||
{
|
||||
"id": "mem-1",
|
||||
"memory_kind": "fact",
|
||||
"content": {"text": "prefers short summaries"},
|
||||
}
|
||||
]
|
||||
with patch(
|
||||
"core.tasks.engine.retrieve_memories_for_prompt", return_value=fake_memory
|
||||
):
|
||||
async_to_sync(process_inbound_task_intelligence)(m)
|
||||
task = DerivedTask.objects.filter(origin_message=m).first()
|
||||
self.assertIsNotNone(task)
|
||||
@@ -227,5 +248,7 @@ class TaskEngineMemoryContextTests(TestCase):
|
||||
self.assertIsNotNone(run, "Expected CodexRun created for task")
|
||||
provider_payload = (run.request_payload or {}).get("provider_payload") or {}
|
||||
memory_context = provider_payload.get("memory_context")
|
||||
self.assertIsNotNone(memory_context, "Expected memory_context in CodexRun provider payload")
|
||||
self.assertIsNotNone(
|
||||
memory_context, "Expected memory_context in CodexRun provider payload"
|
||||
)
|
||||
self.assertEqual(1, len(memory_context))
|
||||
|
||||
@@ -48,7 +48,9 @@ class TasksPagesManagementTests(TestCase):
|
||||
self.assertEqual(200, response.status_code)
|
||||
project = TaskProject.objects.get(user=self.user, name="Ops")
|
||||
self.assertIsNotNone(project)
|
||||
self.assertFalse(ChatTaskSource.objects.filter(user=self.user, project=project).exists())
|
||||
self.assertFalse(
|
||||
ChatTaskSource.objects.filter(user=self.user, project=project).exists()
|
||||
)
|
||||
|
||||
def test_tasks_hub_can_map_identifier_to_selected_project(self):
|
||||
project = TaskProject.objects.create(user=self.user, name="Mapped")
|
||||
@@ -108,7 +110,9 @@ class TasksPagesManagementTests(TestCase):
|
||||
follow=True,
|
||||
)
|
||||
self.assertEqual(200, delete_response.status_code)
|
||||
self.assertFalse(TaskEpic.objects.filter(project=project, name="Phase 1").exists())
|
||||
self.assertFalse(
|
||||
TaskEpic.objects.filter(project=project, name="Phase 1").exists()
|
||||
)
|
||||
|
||||
def test_project_page_can_assign_and_clear_task_epic(self):
|
||||
project = TaskProject.objects.create(user=self.user, name="Roadmap")
|
||||
@@ -179,9 +183,13 @@ class TasksPagesManagementTests(TestCase):
|
||||
follow=True,
|
||||
)
|
||||
self.assertEqual(200, response.status_code)
|
||||
self.assertTrue(TaskEpic.objects.filter(project=project, name="Phase 2").exists())
|
||||
self.assertTrue(
|
||||
TaskEpic.objects.filter(project=project, name="Phase 2").exists()
|
||||
)
|
||||
self.assertTrue(mocked_send.await_count >= 1)
|
||||
payloads = [str(call.kwargs.get("text") or "") for call in mocked_send.await_args_list]
|
||||
payloads = [
|
||||
str(call.kwargs.get("text") or "") for call in mocked_send.await_args_list
|
||||
]
|
||||
self.assertTrue(any("whatsapp usage" in row.lower() for row in payloads))
|
||||
self.assertTrue(any("add task to epic" in row.lower() for row in payloads))
|
||||
|
||||
@@ -266,7 +274,9 @@ class TasksPagesManagementTests(TestCase):
|
||||
reference_code="2",
|
||||
status_snapshot="open",
|
||||
)
|
||||
response = self.client.get(reverse("tasks_project", kwargs={"project_id": str(project.id)}))
|
||||
response = self.client.get(
|
||||
reverse("tasks_project", kwargs={"project_id": str(project.id)})
|
||||
)
|
||||
self.assertEqual(200, response.status_code)
|
||||
self.assertContains(
|
||||
response,
|
||||
@@ -302,7 +312,9 @@ class TasksPagesManagementTests(TestCase):
|
||||
payload={"source": "signal", "emoji": "❤️", "reason": "heart_reaction"},
|
||||
)
|
||||
|
||||
response = self.client.get(reverse("tasks_task", kwargs={"task_id": str(task.id)}))
|
||||
response = self.client.get(
|
||||
reverse("tasks_task", kwargs={"task_id": str(task.id)})
|
||||
)
|
||||
self.assertEqual(200, response.status_code)
|
||||
self.assertContains(response, "View payload JSON")
|
||||
self.assertContains(response, "<strong>source</strong>: signal", html=True)
|
||||
|
||||
@@ -3,8 +3,8 @@ from __future__ import annotations
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from django.urls import reverse
|
||||
from django.test import TestCase, override_settings
|
||||
from django.urls import reverse
|
||||
|
||||
from core.models import (
|
||||
ChatSession,
|
||||
@@ -12,24 +12,32 @@ from core.models import (
|
||||
CodexPermissionRequest,
|
||||
CodexRun,
|
||||
DerivedTask,
|
||||
ExternalSyncEvent,
|
||||
ExternalChatLink,
|
||||
ExternalSyncEvent,
|
||||
Message,
|
||||
Person,
|
||||
PersonIdentifier,
|
||||
TaskCompletionPattern,
|
||||
TaskProviderConfig,
|
||||
TaskProject,
|
||||
TaskProviderConfig,
|
||||
User,
|
||||
)
|
||||
from core.tasks.engine import process_inbound_task_intelligence
|
||||
from core.views.compose import _command_options_for_channel, _toggle_task_announce_for_channel
|
||||
from core.views.tasks import _apply_safe_defaults_for_user, _ensure_default_completion_patterns
|
||||
from core.views.compose import (
|
||||
_command_options_for_channel,
|
||||
_toggle_task_announce_for_channel,
|
||||
)
|
||||
from core.views.tasks import (
|
||||
_apply_safe_defaults_for_user,
|
||||
_ensure_default_completion_patterns,
|
||||
)
|
||||
|
||||
|
||||
class TaskSettingsBackfillTests(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user("defaults-user", "defaults@example.com", "x")
|
||||
self.user = User.objects.create_user(
|
||||
"defaults-user", "defaults@example.com", "x"
|
||||
)
|
||||
self.person = Person.objects.create(user=self.user, name="Defaults Person")
|
||||
self.identifier = PersonIdentifier.objects.create(
|
||||
user=self.user,
|
||||
@@ -67,7 +75,9 @@ class TaskSettingsBackfillTests(TestCase):
|
||||
self.source.refresh_from_db()
|
||||
self.assertEqual("strict", self.project.settings.get("match_mode"))
|
||||
self.assertTrue(bool(self.project.settings.get("require_prefix")))
|
||||
self.assertEqual(["task:", "todo:"], self.project.settings.get("allowed_prefixes"))
|
||||
self.assertEqual(
|
||||
["task:", "todo:"], self.project.settings.get("allowed_prefixes")
|
||||
)
|
||||
self.assertFalse(bool(self.project.settings.get("announce_task_id")))
|
||||
self.assertEqual("strict", self.source.settings.get("match_mode"))
|
||||
self.assertTrue(bool(self.source.settings.get("require_prefix")))
|
||||
@@ -75,7 +85,9 @@ class TaskSettingsBackfillTests(TestCase):
|
||||
def test_default_completion_phrases_seeded(self):
|
||||
_ensure_default_completion_patterns(self.user)
|
||||
phrases = set(
|
||||
TaskCompletionPattern.objects.filter(user=self.user).values_list("phrase", flat=True)
|
||||
TaskCompletionPattern.objects.filter(user=self.user).values_list(
|
||||
"phrase", flat=True
|
||||
)
|
||||
)
|
||||
self.assertTrue({"done", "completed", "fixed"}.issubset(phrases))
|
||||
|
||||
@@ -136,8 +148,12 @@ class TaskAnnounceRuntimeTests(TestCase):
|
||||
service="whatsapp",
|
||||
identifier="120363402761690215@g.us",
|
||||
)
|
||||
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier)
|
||||
self.project = TaskProject.objects.create(user=self.user, name="Runtime Project")
|
||||
self.session = ChatSession.objects.create(
|
||||
user=self.user, identifier=self.identifier
|
||||
)
|
||||
self.project = TaskProject.objects.create(
|
||||
user=self.user, name="Runtime Project"
|
||||
)
|
||||
|
||||
def _seed_source(self, announce_enabled: bool):
|
||||
return ChatTaskSource.objects.create(
|
||||
@@ -167,22 +183,32 @@ class TaskAnnounceRuntimeTests(TestCase):
|
||||
|
||||
def test_no_announce_send_when_disabled(self):
|
||||
self._seed_source(False)
|
||||
with patch("core.tasks.engine.send_message_raw", new=AsyncMock()) as mocked_send:
|
||||
async_to_sync(process_inbound_task_intelligence)(self._msg("task: rotate secrets"))
|
||||
with patch(
|
||||
"core.tasks.engine.send_message_raw", new=AsyncMock()
|
||||
) as mocked_send:
|
||||
async_to_sync(process_inbound_task_intelligence)(
|
||||
self._msg("task: rotate secrets")
|
||||
)
|
||||
self.assertTrue(DerivedTask.objects.exists())
|
||||
mocked_send.assert_not_awaited()
|
||||
|
||||
def test_announce_send_when_enabled(self):
|
||||
self._seed_source(True)
|
||||
with patch("core.tasks.engine.send_message_raw", new=AsyncMock(return_value=True)) as mocked_send:
|
||||
async_to_sync(process_inbound_task_intelligence)(self._msg("task: rotate secrets"))
|
||||
with patch(
|
||||
"core.tasks.engine.send_message_raw", new=AsyncMock(return_value=True)
|
||||
) as mocked_send:
|
||||
async_to_sync(process_inbound_task_intelligence)(
|
||||
self._msg("task: rotate secrets")
|
||||
)
|
||||
self.assertTrue(DerivedTask.objects.exists())
|
||||
mocked_send.assert_awaited()
|
||||
|
||||
|
||||
class TaskSettingsViewActionsTests(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user("task-settings-user", "ts@example.com", "x")
|
||||
self.user = User.objects.create_user(
|
||||
"task-settings-user", "ts@example.com", "x"
|
||||
)
|
||||
self.client.force_login(self.user)
|
||||
self.project = TaskProject.objects.create(user=self.user, name="Project A")
|
||||
self.source = ChatTaskSource.objects.create(
|
||||
@@ -214,7 +240,9 @@ class TaskSettingsViewActionsTests(TestCase):
|
||||
@override_settings(TASK_DERIVATION_USE_AI=False)
|
||||
class TaskAutoBootstrapTests(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user("task-auto-user", "task-auto@example.com", "x")
|
||||
self.user = User.objects.create_user(
|
||||
"task-auto-user", "task-auto@example.com", "x"
|
||||
)
|
||||
self.person = Person.objects.create(user=self.user, name="Bootstrap Chat")
|
||||
self.identifier = PersonIdentifier.objects.create(
|
||||
user=self.user,
|
||||
@@ -222,7 +250,9 @@ class TaskAutoBootstrapTests(TestCase):
|
||||
service="whatsapp",
|
||||
identifier="120363402761690215@g.us",
|
||||
)
|
||||
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier)
|
||||
self.session = ChatSession.objects.create(
|
||||
user=self.user, identifier=self.identifier
|
||||
)
|
||||
|
||||
def test_task_message_auto_creates_project_and_source(self):
|
||||
msg = Message.objects.create(
|
||||
@@ -243,13 +273,17 @@ class TaskAutoBootstrapTests(TestCase):
|
||||
enabled=True,
|
||||
).first()
|
||||
self.assertIsNotNone(source)
|
||||
self.assertTrue(TaskProject.objects.filter(user=self.user, id=source.project_id).exists())
|
||||
self.assertTrue(
|
||||
TaskProject.objects.filter(user=self.user, id=source.project_id).exists()
|
||||
)
|
||||
self.assertEqual(1, DerivedTask.objects.filter(user=self.user).count())
|
||||
|
||||
|
||||
class TaskProjectDeleteGuardTests(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user("task-delete-user", "task-delete@example.com", "x")
|
||||
self.user = User.objects.create_user(
|
||||
"task-delete-user", "task-delete@example.com", "x"
|
||||
)
|
||||
self.client.force_login(self.user)
|
||||
self.project = TaskProject.objects.create(user=self.user, name="Delete Me")
|
||||
self.source = ChatTaskSource.objects.create(
|
||||
@@ -271,7 +305,9 @@ class TaskProjectDeleteGuardTests(TestCase):
|
||||
follow=True,
|
||||
)
|
||||
self.assertEqual(200, response.status_code)
|
||||
self.assertTrue(TaskProject.objects.filter(id=self.project.id, user=self.user).exists())
|
||||
self.assertTrue(
|
||||
TaskProject.objects.filter(id=self.project.id, user=self.user).exists()
|
||||
)
|
||||
|
||||
def test_project_delete_reseeds_default_mapping(self):
|
||||
response = self.client.post(
|
||||
@@ -284,7 +320,9 @@ class TaskProjectDeleteGuardTests(TestCase):
|
||||
follow=True,
|
||||
)
|
||||
self.assertEqual(200, response.status_code)
|
||||
self.assertFalse(TaskProject.objects.filter(id=self.project.id, user=self.user).exists())
|
||||
self.assertFalse(
|
||||
TaskProject.objects.filter(id=self.project.id, user=self.user).exists()
|
||||
)
|
||||
self.assertTrue(
|
||||
ChatTaskSource.objects.filter(
|
||||
user=self.user,
|
||||
@@ -297,7 +335,9 @@ class TaskProjectDeleteGuardTests(TestCase):
|
||||
|
||||
class TaskHubEmptyProjectVisibilityTests(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user("task-hub-user", "task-hub@example.com", "x")
|
||||
self.user = User.objects.create_user(
|
||||
"task-hub-user", "task-hub@example.com", "x"
|
||||
)
|
||||
self.client.force_login(self.user)
|
||||
self.empty = TaskProject.objects.create(user=self.user, name="Empty")
|
||||
self.used = TaskProject.objects.create(user=self.user, name="Used")
|
||||
@@ -326,7 +366,9 @@ class TaskHubEmptyProjectVisibilityTests(TestCase):
|
||||
|
||||
class TaskSettingsExternalChatLinkScopeTests(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user("task-link-user", "task-link@example.com", "x")
|
||||
self.user = User.objects.create_user(
|
||||
"task-link-user", "task-link@example.com", "x"
|
||||
)
|
||||
self.client.force_login(self.user)
|
||||
self.group_person = Person.objects.create(user=self.user, name="Scoped Group")
|
||||
self.group_identifier = PersonIdentifier.objects.create(
|
||||
@@ -390,7 +432,9 @@ class TaskSettingsExternalChatLinkScopeTests(TestCase):
|
||||
|
||||
class CodexSettingsAndSubmitTests(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user("codex-settings-user", "codex-settings@example.com", "x")
|
||||
self.user = User.objects.create_user(
|
||||
"codex-settings-user", "codex-settings@example.com", "x"
|
||||
)
|
||||
self.client.force_login(self.user)
|
||||
self.project = TaskProject.objects.create(user=self.user, name="Codex Project")
|
||||
self.task = DerivedTask.objects.create(
|
||||
@@ -426,7 +470,9 @@ class CodexSettingsAndSubmitTests(TestCase):
|
||||
self.assertTrue(cfg.enabled)
|
||||
self.assertEqual("team-a", str(cfg.settings.get("instance_label") or ""))
|
||||
self.assertEqual("web", str(cfg.settings.get("approver_service") or ""))
|
||||
self.assertEqual("approver-chan", str(cfg.settings.get("approver_identifier") or ""))
|
||||
self.assertEqual(
|
||||
"approver-chan", str(cfg.settings.get("approver_identifier") or "")
|
||||
)
|
||||
|
||||
def test_task_submit_endpoint_creates_codex_run_and_event(self):
|
||||
TaskProviderConfig.objects.create(
|
||||
@@ -444,10 +490,20 @@ class CodexSettingsAndSubmitTests(TestCase):
|
||||
follow=True,
|
||||
)
|
||||
self.assertEqual(200, response.status_code)
|
||||
run = CodexRun.objects.filter(user=self.user, task=self.task).order_by("-created_at").first()
|
||||
run = (
|
||||
CodexRun.objects.filter(user=self.user, task=self.task)
|
||||
.order_by("-created_at")
|
||||
.first()
|
||||
)
|
||||
self.assertIsNotNone(run)
|
||||
self.assertEqual("waiting_approval", str(getattr(run, "status", "")))
|
||||
event = ExternalSyncEvent.objects.filter(user=self.user, task=self.task, provider="codex_cli").order_by("-created_at").first()
|
||||
event = (
|
||||
ExternalSyncEvent.objects.filter(
|
||||
user=self.user, task=self.task, provider="codex_cli"
|
||||
)
|
||||
.order_by("-created_at")
|
||||
.first()
|
||||
)
|
||||
self.assertIsNotNone(event)
|
||||
self.assertEqual("waiting_approval", str(getattr(event, "status", "")))
|
||||
self.assertTrue(
|
||||
@@ -474,7 +530,10 @@ class CodexSettingsAndSubmitTests(TestCase):
|
||||
source_service="web",
|
||||
source_channel="web-chan-1",
|
||||
status="waiting_approval",
|
||||
request_payload={"action": "append_update", "provider_payload": {"task_id": str(self.task.id)}},
|
||||
request_payload={
|
||||
"action": "append_update",
|
||||
"provider_payload": {"task_id": str(self.task.id)},
|
||||
},
|
||||
result_payload={},
|
||||
)
|
||||
req = CodexPermissionRequest.objects.create(
|
||||
|
||||
@@ -2,7 +2,11 @@ from asgiref.sync import async_to_sync
|
||||
from django.test import SimpleTestCase
|
||||
|
||||
from core.clients import transport
|
||||
from core.transports.capabilities import capability_snapshot, supports, unsupported_reason
|
||||
from core.transports.capabilities import (
|
||||
capability_snapshot,
|
||||
supports,
|
||||
unsupported_reason,
|
||||
)
|
||||
|
||||
|
||||
class TransportCapabilitiesTests(SimpleTestCase):
|
||||
@@ -11,7 +15,10 @@ class TransportCapabilitiesTests(SimpleTestCase):
|
||||
|
||||
def test_instagram_reactions_not_supported(self):
|
||||
self.assertFalse(supports("instagram", "reactions"))
|
||||
self.assertIn("instagram does not support reactions", unsupported_reason("instagram", "reactions"))
|
||||
self.assertIn(
|
||||
"instagram does not support reactions",
|
||||
unsupported_reason("instagram", "reactions"),
|
||||
)
|
||||
|
||||
def test_snapshot_has_schema_version(self):
|
||||
snapshot = capability_snapshot()
|
||||
|
||||
@@ -49,7 +49,9 @@ class WhatsAppReactionHandlingTests(TestCase):
|
||||
service="whatsapp",
|
||||
identifier="15551234567@s.whatsapp.net",
|
||||
)
|
||||
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier)
|
||||
self.session = ChatSession.objects.create(
|
||||
user=self.user, identifier=self.identifier
|
||||
)
|
||||
self.base_ts = now_ms()
|
||||
self.target = Message.objects.create(
|
||||
user=self.user,
|
||||
@@ -84,7 +86,9 @@ class WhatsAppReactionHandlingTests(TestCase):
|
||||
parsed = self.client._extract_reaction_event(message_obj)
|
||||
self.assertIsNotNone(parsed)
|
||||
self.assertEqual("wa-target-1", str(parsed.get("target_message_id") or ""))
|
||||
before_count = Message.objects.filter(user=self.user, session=self.session).count()
|
||||
before_count = Message.objects.filter(
|
||||
user=self.user, session=self.session
|
||||
).count()
|
||||
async_to_sync(history.apply_reaction)(
|
||||
self.user,
|
||||
self.identifier,
|
||||
@@ -96,7 +100,9 @@ class WhatsAppReactionHandlingTests(TestCase):
|
||||
remove=False,
|
||||
payload={"event": "reaction"},
|
||||
)
|
||||
after_count = Message.objects.filter(user=self.user, session=self.session).count()
|
||||
after_count = Message.objects.filter(
|
||||
user=self.user, session=self.session
|
||||
).count()
|
||||
self.assertEqual(before_count, after_count)
|
||||
|
||||
self.target.refresh_from_db()
|
||||
@@ -127,7 +133,9 @@ class RecalculateContactAvailabilityTests(TestCase):
|
||||
service="whatsapp",
|
||||
identifier="15557654321@s.whatsapp.net",
|
||||
)
|
||||
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier)
|
||||
self.session = ChatSession.objects.create(
|
||||
user=self.user, identifier=self.identifier
|
||||
)
|
||||
self.base_ts = now_ms()
|
||||
|
||||
Message.objects.create(
|
||||
@@ -168,19 +176,25 @@ class RecalculateContactAvailabilityTests(TestCase):
|
||||
return events, spans
|
||||
|
||||
def test_recalculate_is_deterministic_and_no_skew_on_rerun(self):
|
||||
call_command("recalculate_contact_availability", "--days", "36500", "--limit", "500")
|
||||
call_command(
|
||||
"recalculate_contact_availability", "--days", "36500", "--limit", "500"
|
||||
)
|
||||
first_events, first_spans = self._projection()
|
||||
self.assertTrue(first_events)
|
||||
self.assertTrue(first_spans)
|
||||
|
||||
call_command("recalculate_contact_availability", "--days", "36500", "--limit", "500")
|
||||
call_command(
|
||||
"recalculate_contact_availability", "--days", "36500", "--limit", "500"
|
||||
)
|
||||
second_events, second_spans = self._projection()
|
||||
|
||||
self.assertEqual(first_events, second_events)
|
||||
self.assertEqual(first_spans, second_spans)
|
||||
|
||||
def test_recalculate_no_reset_does_not_duplicate(self):
|
||||
call_command("recalculate_contact_availability", "--days", "36500", "--limit", "500")
|
||||
call_command(
|
||||
"recalculate_contact_availability", "--days", "36500", "--limit", "500"
|
||||
)
|
||||
events_before = ContactAvailabilityEvent.objects.filter(user=self.user).count()
|
||||
spans_before = ContactAvailabilitySpan.objects.filter(user=self.user).count()
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
@@ -32,8 +31,12 @@ class _ApprovalProbe:
|
||||
|
||||
class XMPPGatewayApprovalCommandTests(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user("xmpp-approval-user", "xmpp-approval@example.com", "x")
|
||||
self.project = TaskProject.objects.create(user=self.user, name="Approval Project")
|
||||
self.user = User.objects.create_user(
|
||||
"xmpp-approval-user", "xmpp-approval@example.com", "x"
|
||||
)
|
||||
self.project = TaskProject.objects.create(
|
||||
user=self.user, name="Approval Project"
|
||||
)
|
||||
self.task = DerivedTask.objects.create(
|
||||
user=self.user,
|
||||
project=self.project,
|
||||
@@ -59,7 +62,10 @@ class XMPPGatewayApprovalCommandTests(TestCase):
|
||||
source_service="xmpp",
|
||||
source_channel="jews.zm.is",
|
||||
status="waiting_approval",
|
||||
request_payload={"action": "append_update", "provider_payload": {"task_id": str(self.task.id)}},
|
||||
request_payload={
|
||||
"action": "append_update",
|
||||
"provider_payload": {"task_id": str(self.task.id)},
|
||||
},
|
||||
result_payload={},
|
||||
)
|
||||
self.request = CodexPermissionRequest.objects.create(
|
||||
@@ -124,7 +130,9 @@ class XMPPGatewayApprovalCommandTests(TestCase):
|
||||
|
||||
class XMPPGatewayTasksCommandTests(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user("xmpp-task-user", "xmpp-task@example.com", "x")
|
||||
self.user = User.objects.create_user(
|
||||
"xmpp-task-user", "xmpp-task@example.com", "x"
|
||||
)
|
||||
self.project = TaskProject.objects.create(user=self.user, name="Task Project")
|
||||
self.task = DerivedTask.objects.create(
|
||||
user=self.user,
|
||||
|
||||
@@ -10,6 +10,7 @@ mirroring exactly the flow a phone XMPP client uses:
|
||||
Tests are skipped automatically when XMPP settings are absent (e.g. in CI
|
||||
environments without a running stack).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
@@ -26,11 +27,11 @@ import xml.etree.ElementTree as ET
|
||||
from django.conf import settings
|
||||
from django.test import SimpleTestCase
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _xmpp_configured() -> bool:
|
||||
return bool(
|
||||
getattr(settings, "XMPP_JID", None)
|
||||
@@ -67,10 +68,21 @@ def _xmpp_domain() -> str:
|
||||
|
||||
def _prosody_auth_endpoint() -> str:
|
||||
"""URL of the Django auth bridge that Prosody calls for c2s authentication."""
|
||||
return str(getattr(settings, "PROSODY_AUTH_ENDPOINT", "http://127.0.0.1:8090/internal/prosody/auth/"))
|
||||
return str(
|
||||
getattr(
|
||||
settings,
|
||||
"PROSODY_AUTH_ENDPOINT",
|
||||
"http://127.0.0.1:8090/internal/prosody/auth/",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _recv_until(sock: socket.socket, patterns: list[bytes], timeout: float = 8.0, max_bytes: int = 16384) -> bytes:
|
||||
def _recv_until(
|
||||
sock: socket.socket,
|
||||
patterns: list[bytes],
|
||||
timeout: float = 8.0,
|
||||
max_bytes: int = 16384,
|
||||
) -> bytes:
|
||||
"""Read from sock until one of the byte patterns appears or timeout/max_bytes hit."""
|
||||
buf = b""
|
||||
deadline = time.monotonic() + timeout
|
||||
@@ -91,7 +103,9 @@ def _recv_until(sock: socket.socket, patterns: list[bytes], timeout: float = 8.0
|
||||
return buf
|
||||
|
||||
|
||||
def _component_handshake(address: str, port: int, jid: str, secret: str, timeout: float = 5.0) -> tuple[bool, str]:
|
||||
def _component_handshake(
|
||||
address: str, port: int, jid: str, secret: str, timeout: float = 5.0
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
Attempt an XEP-0114 external component handshake.
|
||||
|
||||
@@ -123,7 +137,9 @@ def _component_handshake(address: str, port: int, jid: str, secret: str, timeout
|
||||
token = hashlib.sha1((stream_id + secret).encode()).hexdigest()
|
||||
sock.sendall(f"<handshake>{token}</handshake>".encode())
|
||||
|
||||
response = _recv_until(sock, [b"<handshake", b"<stream:error"], timeout=timeout)
|
||||
response = _recv_until(
|
||||
sock, [b"<handshake", b"<stream:error"], timeout=timeout
|
||||
)
|
||||
resp_text = response.decode(errors="replace")
|
||||
|
||||
if "<handshake/>" in resp_text or "<handshake />" in resp_text:
|
||||
@@ -146,10 +162,11 @@ def _component_handshake(address: str, port: int, jid: str, secret: str, timeout
|
||||
|
||||
class _C2SResult:
|
||||
"""Return value from _c2s_sasl_auth."""
|
||||
|
||||
def __init__(self, success: bool, stage: str, detail: str):
|
||||
self.success = success # True = SASL <success/>
|
||||
self.stage = stage # where we got to: tcp/starttls/tls/features/auth
|
||||
self.detail = detail # human-readable explanation
|
||||
self.success = success # True = SASL <success/>
|
||||
self.stage = stage # where we got to: tcp/starttls/tls/features/auth
|
||||
self.detail = detail # human-readable explanation
|
||||
|
||||
def __repr__(self):
|
||||
return f"<C2SResult success={self.success} stage={self.stage!r} detail={self.detail!r}>"
|
||||
@@ -178,8 +195,8 @@ def _c2s_sasl_auth(
|
||||
9. Return _C2SResult with (True, "auth") on <success/> or (False, "auth") on <failure/>
|
||||
"""
|
||||
NS_STREAM = "http://etherx.jabber.org/streams"
|
||||
NS_TLS = "urn:ietf:params:xml:ns:xmpp-tls"
|
||||
NS_SASL = "urn:ietf:params:xml:ns:xmpp-sasl"
|
||||
NS_TLS = "urn:ietf:params:xml:ns:xmpp-tls"
|
||||
NS_SASL = "urn:ietf:params:xml:ns:xmpp-sasl"
|
||||
|
||||
def stream_open(to: str) -> bytes:
|
||||
return (
|
||||
@@ -201,19 +218,29 @@ def _c2s_sasl_auth(
|
||||
raw.sendall(stream_open(domain))
|
||||
|
||||
# --- Receive pre-TLS features (expect <starttls>) ---
|
||||
buf = _recv_until(raw, [b"</stream:features>", b"<stream:error"], timeout=timeout)
|
||||
buf = _recv_until(
|
||||
raw, [b"</stream:features>", b"<stream:error"], timeout=timeout
|
||||
)
|
||||
text = buf.decode(errors="replace")
|
||||
if "<stream:error" in text:
|
||||
return _C2SResult(False, "starttls", f"Stream error before features: {text[:200]}")
|
||||
return _C2SResult(
|
||||
False, "starttls", f"Stream error before features: {text[:200]}"
|
||||
)
|
||||
if "starttls" not in text.lower():
|
||||
return _C2SResult(False, "starttls", f"No <starttls> in pre-TLS features: {text[:300]}")
|
||||
return _C2SResult(
|
||||
False, "starttls", f"No <starttls> in pre-TLS features: {text[:300]}"
|
||||
)
|
||||
|
||||
# --- Negotiate STARTTLS ---
|
||||
raw.sendall(f"<starttls xmlns='{NS_TLS}'/>".encode())
|
||||
buf2 = _recv_until(raw, [b"<proceed", b"<failure"], timeout=timeout)
|
||||
text2 = buf2.decode(errors="replace")
|
||||
if "<proceed" not in text2:
|
||||
return _C2SResult(False, "starttls", f"No <proceed/> after STARTTLS request: {text2[:200]}")
|
||||
return _C2SResult(
|
||||
False,
|
||||
"starttls",
|
||||
f"No <proceed/> after STARTTLS request: {text2[:200]}",
|
||||
)
|
||||
|
||||
# --- Upgrade to TLS ---
|
||||
ctx = ssl.create_default_context()
|
||||
@@ -223,7 +250,9 @@ def _c2s_sasl_auth(
|
||||
try:
|
||||
tls = ctx.wrap_socket(raw, server_hostname=domain)
|
||||
except ssl.SSLCertVerificationError as exc:
|
||||
return _C2SResult(False, "tls", f"TLS cert verification failed for {domain!r}: {exc}")
|
||||
return _C2SResult(
|
||||
False, "tls", f"TLS cert verification failed for {domain!r}: {exc}"
|
||||
)
|
||||
except ssl.SSLError as exc:
|
||||
return _C2SResult(False, "tls", f"TLS handshake error: {exc}")
|
||||
|
||||
@@ -231,23 +260,35 @@ def _c2s_sasl_auth(
|
||||
|
||||
# --- Re-open stream over TLS ---
|
||||
tls.sendall(stream_open(domain))
|
||||
buf3 = _recv_until(tls, [b"</stream:features>", b"<stream:error"], timeout=timeout)
|
||||
buf3 = _recv_until(
|
||||
tls, [b"</stream:features>", b"<stream:error"], timeout=timeout
|
||||
)
|
||||
text3 = buf3.decode(errors="replace")
|
||||
if "<stream:error" in text3:
|
||||
return _C2SResult(False, "features", f"Stream error after TLS: {text3[:200]}")
|
||||
return _C2SResult(
|
||||
False, "features", f"Stream error after TLS: {text3[:200]}"
|
||||
)
|
||||
|
||||
mechanisms = re.findall(r"<mechanism>([^<]+)</mechanism>", text3, re.IGNORECASE)
|
||||
if not mechanisms:
|
||||
return _C2SResult(False, "features", f"No SASL mechanisms in post-TLS features: {text3[:300]}")
|
||||
return _C2SResult(
|
||||
False,
|
||||
"features",
|
||||
f"No SASL mechanisms in post-TLS features: {text3[:300]}",
|
||||
)
|
||||
if "PLAIN" not in [m.upper() for m in mechanisms]:
|
||||
return _C2SResult(False, "features", f"SASL PLAIN not offered; got: {mechanisms}")
|
||||
return _C2SResult(
|
||||
False, "features", f"SASL PLAIN not offered; got: {mechanisms}"
|
||||
)
|
||||
|
||||
# --- SASL PLAIN auth ---
|
||||
credential = base64.b64encode(f"\x00{username}\x00{password}".encode()).decode()
|
||||
tls.sendall(
|
||||
f"<auth xmlns='{NS_SASL}' mechanism='PLAIN'>{credential}</auth>".encode()
|
||||
)
|
||||
buf4 = _recv_until(tls, [b"<success", b"<failure", b"<stream:error"], timeout=timeout)
|
||||
buf4 = _recv_until(
|
||||
tls, [b"<success", b"<failure", b"<stream:error"], timeout=timeout
|
||||
)
|
||||
text4 = buf4.decode(errors="replace")
|
||||
|
||||
if "<success" in text4:
|
||||
@@ -256,7 +297,9 @@ def _c2s_sasl_auth(
|
||||
# Extract the failure condition element name (e.g. not-authorized)
|
||||
m = re.search(r"<failure[^>]*>\s*<([a-z-]+)", text4)
|
||||
condition = m.group(1) if m else "unknown"
|
||||
return _C2SResult(False, "auth", f"SASL PLAIN rejected: {condition} — {text4[:200]}")
|
||||
return _C2SResult(
|
||||
False, "auth", f"SASL PLAIN rejected: {condition} — {text4[:200]}"
|
||||
)
|
||||
if "<stream:error" in text4:
|
||||
return _C2SResult(False, "auth", f"Stream error during auth: {text4[:200]}")
|
||||
return _C2SResult(False, "auth", f"No auth response received: {text4[:200]}")
|
||||
@@ -272,6 +315,7 @@ def _c2s_sasl_auth(
|
||||
# Component tests (XEP-0114)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@unittest.skipUnless(_xmpp_configured(), "XMPP settings not configured")
|
||||
class XMPPComponentTests(SimpleTestCase):
|
||||
def test_component_port_reachable(self):
|
||||
@@ -309,6 +353,7 @@ class XMPPComponentTests(SimpleTestCase):
|
||||
# Auth bridge tests (what Prosody calls to validate user passwords)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@unittest.skipUnless(_xmpp_configured(), "XMPP settings not configured")
|
||||
class XMPPAuthBridgeTests(SimpleTestCase):
|
||||
"""
|
||||
@@ -320,7 +365,12 @@ class XMPPAuthBridgeTests(SimpleTestCase):
|
||||
def _parse_endpoint(self):
|
||||
url = _prosody_auth_endpoint()
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
return parsed.scheme, parsed.hostname, parsed.port or (443 if parsed.scheme == "https" else 80), parsed.path
|
||||
return (
|
||||
parsed.scheme,
|
||||
parsed.hostname,
|
||||
parsed.port or (443 if parsed.scheme == "https" else 80),
|
||||
parsed.path,
|
||||
)
|
||||
|
||||
def test_auth_endpoint_tcp_reachable(self):
|
||||
"""Auth bridge port (8090) is listening inside the pod."""
|
||||
@@ -349,13 +399,19 @@ class XMPPAuthBridgeTests(SimpleTestCase):
|
||||
except (ConnectionRefusedError, OSError) as exc:
|
||||
self.fail(f"Could not connect to auth bridge: {exc}")
|
||||
# Should not return "1" (success) with wrong secret
|
||||
self.assertNotEqual(body, "1", f"Auth bridge accepted a request with wrong secret (body={body!r})")
|
||||
self.assertNotEqual(
|
||||
body,
|
||||
"1",
|
||||
f"Auth bridge accepted a request with wrong secret (body={body!r})",
|
||||
)
|
||||
|
||||
def test_auth_endpoint_isuser_returns_zero_or_one(self):
|
||||
"""Auth bridge responds with '0' or '1' for an isuser query (not an error page)."""
|
||||
secret = getattr(settings, "XMPP_SECRET", "")
|
||||
_, host, port, path = self._parse_endpoint()
|
||||
query = f"?command=isuser%3Anonexistent%3Azm.is&secret={urllib.parse.quote(secret)}"
|
||||
query = (
|
||||
f"?command=isuser%3Anonexistent%3Azm.is&secret={urllib.parse.quote(secret)}"
|
||||
)
|
||||
try:
|
||||
conn = http.client.HTTPConnection(host, port, timeout=5)
|
||||
conn.request("GET", path + query)
|
||||
@@ -364,13 +420,18 @@ class XMPPAuthBridgeTests(SimpleTestCase):
|
||||
conn.close()
|
||||
except (ConnectionRefusedError, OSError) as exc:
|
||||
self.fail(f"Could not connect to auth bridge: {exc}")
|
||||
self.assertIn(body, ("0", "1"), f"Unexpected auth bridge response {body!r} (expected '0' or '1')")
|
||||
self.assertIn(
|
||||
body,
|
||||
("0", "1"),
|
||||
f"Unexpected auth bridge response {body!r} (expected '0' or '1')",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# c2s (client-to-server) tests — mirrors the phone's XMPP connection flow
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@unittest.skipUnless(_xmpp_configured(), "XMPP settings not configured")
|
||||
class XMPPClientAuthTests(SimpleTestCase):
|
||||
"""
|
||||
@@ -400,23 +461,26 @@ class XMPPClientAuthTests(SimpleTestCase):
|
||||
port = _xmpp_c2s_port()
|
||||
domain = _xmpp_domain()
|
||||
result = _c2s_sasl_auth(
|
||||
address=addr, port=port, domain=domain,
|
||||
username="certcheck", password="certcheck",
|
||||
verify_cert=True, timeout=10.0,
|
||||
address=addr,
|
||||
port=port,
|
||||
domain=domain,
|
||||
username="certcheck",
|
||||
password="certcheck",
|
||||
verify_cert=True,
|
||||
timeout=10.0,
|
||||
)
|
||||
# We only care that we got past TLS — a SASL failure at stage "auth" is fine.
|
||||
self.assertNotEqual(
|
||||
result.stage, "tls",
|
||||
result.stage,
|
||||
"tls",
|
||||
f"TLS cert validation failed for domain {domain!r}: {result.detail}\n"
|
||||
"Phone will see a certificate error — it cannot connect at all."
|
||||
"Phone will see a certificate error — it cannot connect at all.",
|
||||
)
|
||||
self.assertNotEqual(
|
||||
result.stage, "tcp",
|
||||
f"Could not reach c2s port at all: {result.detail}"
|
||||
result.stage, "tcp", f"Could not reach c2s port at all: {result.detail}"
|
||||
)
|
||||
self.assertNotEqual(
|
||||
result.stage, "starttls",
|
||||
f"STARTTLS negotiation failed: {result.detail}"
|
||||
result.stage, "starttls", f"STARTTLS negotiation failed: {result.detail}"
|
||||
)
|
||||
|
||||
def test_c2s_sasl_plain_offered(self):
|
||||
@@ -425,16 +489,21 @@ class XMPPClientAuthTests(SimpleTestCase):
|
||||
port = _xmpp_c2s_port()
|
||||
domain = _xmpp_domain()
|
||||
result = _c2s_sasl_auth(
|
||||
address=addr, port=port, domain=domain,
|
||||
username="saslcheck", password="saslcheck",
|
||||
verify_cert=False, timeout=10.0,
|
||||
address=addr,
|
||||
port=port,
|
||||
domain=domain,
|
||||
username="saslcheck",
|
||||
password="saslcheck",
|
||||
verify_cert=False,
|
||||
timeout=10.0,
|
||||
)
|
||||
# We should reach the "auth" stage (SASL PLAIN was offered and we tried it).
|
||||
# Reaching any earlier stage means SASL PLAIN wasn't offered or something broke.
|
||||
self.assertIn(
|
||||
result.stage, ("auth",),
|
||||
result.stage,
|
||||
("auth",),
|
||||
f"Did not reach SASL auth stage — stopped at {result.stage!r}: {result.detail}\n"
|
||||
"Check that allow_unencrypted_plain_auth = true in prosody config."
|
||||
"Check that allow_unencrypted_plain_auth = true in prosody config.",
|
||||
)
|
||||
|
||||
def test_c2s_invalid_credentials_rejected(self):
|
||||
@@ -450,23 +519,31 @@ class XMPPClientAuthTests(SimpleTestCase):
|
||||
port = _xmpp_c2s_port()
|
||||
domain = _xmpp_domain()
|
||||
result = _c2s_sasl_auth(
|
||||
address=addr, port=port, domain=domain,
|
||||
address=addr,
|
||||
port=port,
|
||||
domain=domain,
|
||||
username="nobody_special",
|
||||
password="definitely-wrong-password-xyz",
|
||||
verify_cert=False, timeout=10.0,
|
||||
verify_cert=False,
|
||||
timeout=10.0,
|
||||
)
|
||||
self.assertFalse(
|
||||
result.success,
|
||||
f"Expected auth failure for invalid creds but got success: {result}",
|
||||
)
|
||||
self.assertFalse(result.success, f"Expected auth failure for invalid creds but got success: {result}")
|
||||
self.assertEqual(
|
||||
result.stage, "auth",
|
||||
result.stage,
|
||||
"auth",
|
||||
f"Auth failed at stage {result.stage!r} (expected 'auth' / not-authorized).\n"
|
||||
f"Detail: {result.detail}\n"
|
||||
"This means Prosody cannot reach the Django auth bridge — "
|
||||
"valid credentials would also fail. "
|
||||
"Check that uWSGI has http-socket=127.0.0.1:8090 and the container is running."
|
||||
"Check that uWSGI has http-socket=127.0.0.1:8090 and the container is running.",
|
||||
)
|
||||
self.assertIn(
|
||||
"not-authorized", result.detail,
|
||||
f"Expected 'not-authorized' failure, got: {result.detail}"
|
||||
"not-authorized",
|
||||
result.detail,
|
||||
f"Expected 'not-authorized' failure, got: {result.detail}",
|
||||
)
|
||||
|
||||
@unittest.skipUnless(
|
||||
@@ -479,17 +556,22 @@ class XMPPClientAuthTests(SimpleTestCase):
|
||||
Skipped unless env vars are set — run manually to verify end-to-end login.
|
||||
"""
|
||||
import os
|
||||
|
||||
addr = _xmpp_address()
|
||||
port = _xmpp_c2s_port()
|
||||
domain = _xmpp_domain()
|
||||
username = os.environ["XMPP_TEST_USER"]
|
||||
password = os.environ.get("XMPP_TEST_PASSWORD", "")
|
||||
result = _c2s_sasl_auth(
|
||||
address=addr, port=port, domain=domain,
|
||||
username=username, password=password,
|
||||
verify_cert=True, timeout=10.0,
|
||||
address=addr,
|
||||
port=port,
|
||||
domain=domain,
|
||||
username=username,
|
||||
password=password,
|
||||
verify_cert=True,
|
||||
timeout=10.0,
|
||||
)
|
||||
self.assertTrue(
|
||||
result.success,
|
||||
f"Login with XMPP_TEST_USER={username!r} failed at stage {result.stage!r}: {result.detail}"
|
||||
f"Login with XMPP_TEST_USER={username!r} failed at stage {result.stage!r}: {result.detail}",
|
||||
)
|
||||
|
||||
@@ -1,27 +1,13 @@
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from django.test import SimpleTestCase, TestCase, override_settings
|
||||
|
||||
from core.clients import transport
|
||||
from core.clients.xmpp import ET, XMPPClient, XMPPComponent, _extract_sender_omemo_client_key
|
||||
from core.clients.xmpp import ET, XMPPComponent, _extract_sender_omemo_client_key
|
||||
from core.models import User, UserXmppOmemoState
|
||||
|
||||
|
||||
class _FakeComponent:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.plugins = []
|
||||
self.loop = None
|
||||
|
||||
def register_plugin(self, name):
|
||||
self.plugins.append(str(name))
|
||||
|
||||
def connect(self):
|
||||
return True
|
||||
|
||||
|
||||
@override_settings(
|
||||
XMPP_JID="jews.zm.is",
|
||||
XMPP_SECRET="secret",
|
||||
@@ -29,65 +15,12 @@ class _FakeComponent:
|
||||
XMPP_PORT=8888,
|
||||
)
|
||||
class XMPPOmemoSupportTests(SimpleTestCase):
|
||||
def test_registers_xep_0384_when_omemo_plugin_available(self):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
with patch("core.clients.xmpp.XMPPComponent", _FakeComponent):
|
||||
with patch("core.clients.xmpp._omemo_plugin_available", return_value=True):
|
||||
with patch("core.clients.xmpp._omemo_xep_0384_plugin_available", return_value=True):
|
||||
with patch("core.clients.xmpp._load_omemo_plugin_module", return_value=True):
|
||||
client = XMPPClient(SimpleNamespace(), loop, "xmpp")
|
||||
self.assertIn("xep_0384", list(getattr(client.client, "plugins", [])))
|
||||
self.assertTrue(bool(getattr(client, "_omemo_plugin_registered", False)))
|
||||
finally:
|
||||
loop.close()
|
||||
def test_omemo_available_flag_set_correctly(self):
|
||||
"""Test that _OMEMO_AVAILABLE is properly set based on import availability"""
|
||||
from core.clients import xmpp
|
||||
|
||||
def test_skips_xep_0384_when_omemo_plugin_unavailable(self):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
with patch("core.clients.xmpp.XMPPComponent", _FakeComponent):
|
||||
with patch("core.clients.xmpp._omemo_plugin_available", return_value=False):
|
||||
with patch("core.clients.xmpp._omemo_xep_0384_plugin_available", return_value=False):
|
||||
client = XMPPClient(SimpleNamespace(), loop, "xmpp")
|
||||
self.assertNotIn("xep_0384", list(getattr(client.client, "plugins", [])))
|
||||
self.assertFalse(bool(getattr(client, "_omemo_plugin_registered", False)))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
def test_skips_xep_0384_when_only_slixmpp_omemo_package_exists(self):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
with patch("core.clients.xmpp.XMPPComponent", _FakeComponent):
|
||||
with patch("core.clients.xmpp._omemo_plugin_available", return_value=True):
|
||||
with patch("core.clients.xmpp._omemo_xep_0384_plugin_available", return_value=False):
|
||||
client = XMPPClient(SimpleNamespace(), loop, "xmpp")
|
||||
self.assertNotIn("xep_0384", list(getattr(client.client, "plugins", [])))
|
||||
self.assertFalse(bool(getattr(client, "_omemo_plugin_registered", False)))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
def test_bootstrap_logs_and_updates_runtime_state_with_fingerprint(self):
|
||||
class _BootstrapProbe:
|
||||
_derived_omemo_fingerprint = XMPPComponent._derived_omemo_fingerprint
|
||||
|
||||
component = _BootstrapProbe()
|
||||
component.plugin = {}
|
||||
component.log = MagicMock()
|
||||
|
||||
with patch.object(transport, "update_runtime_state") as update_state:
|
||||
async_to_sync(XMPPComponent._bootstrap_omemo_for_authentic_channel)(component)
|
||||
|
||||
update_state.assert_called_once()
|
||||
_, kwargs = update_state.call_args
|
||||
self.assertEqual("jews.zm.is", kwargs.get("omemo_target_jid"))
|
||||
self.assertEqual(
|
||||
component._derived_omemo_fingerprint("jews.zm.is"),
|
||||
kwargs.get("omemo_fingerprint"),
|
||||
)
|
||||
self.assertFalse(bool(kwargs.get("omemo_enabled")))
|
||||
self.assertIn("omemo_status", kwargs)
|
||||
self.assertIn("omemo_status_reason", kwargs)
|
||||
self.assertTrue(component.log.info.called)
|
||||
# Just verify the flag exists and is boolean
|
||||
self.assertIsInstance(xmpp._OMEMO_AVAILABLE, bool)
|
||||
|
||||
def test_extract_sender_omemo_client_key_from_encrypted_stanza(self):
|
||||
stanza_xml = ET.fromstring(
|
||||
@@ -104,8 +37,10 @@ class XMPPOmemoSupportTests(SimpleTestCase):
|
||||
|
||||
class XMPPOmemoObservationPersistenceTests(TestCase):
|
||||
def test_records_latest_user_omemo_observation(self):
|
||||
user = User.objects.create_user("xmpp-omemo-user", "xmpp-omemo@example.com", "x")
|
||||
probe = SimpleNamespace(log=MagicMock())
|
||||
user = User.objects.create_user(
|
||||
"xmpp-omemo-user", "xmpp-omemo@example.com", "x"
|
||||
)
|
||||
xmpp_component = SimpleNamespace(log=MagicMock())
|
||||
stanza_xml = ET.fromstring(
|
||||
"<message>"
|
||||
"<encrypted xmlns='eu.siacs.conversations.axolotl'>"
|
||||
@@ -114,7 +49,7 @@ class XMPPOmemoObservationPersistenceTests(TestCase):
|
||||
"</message>"
|
||||
)
|
||||
async_to_sync(XMPPComponent._record_sender_omemo_state)(
|
||||
probe,
|
||||
xmpp_component,
|
||||
user,
|
||||
sender_jid="xmpp-omemo-user@zm.is/mobile",
|
||||
recipient_jid="jews.zm.is",
|
||||
@@ -124,3 +59,328 @@ class XMPPOmemoObservationPersistenceTests(TestCase):
|
||||
self.assertEqual("detected", row.status)
|
||||
self.assertEqual("sid:321,rid:654", row.latest_client_key)
|
||||
self.assertEqual("jews.zm.is", row.last_target_jid)
|
||||
|
||||
|
||||
class XMPPOmemoEnforcementTests(TestCase):
|
||||
"""Test require_omemo policy enforcement on incoming messages"""
|
||||
|
||||
def setUp(self):
|
||||
from core.models import UserXmppSecuritySettings
|
||||
|
||||
self.user = User.objects.create_user("omemo-enforcer", "omemo@example.com", "x")
|
||||
self.security_settings = UserXmppSecuritySettings.objects.create(
|
||||
user=self.user, require_omemo=True
|
||||
)
|
||||
|
||||
def test_plaintext_message_rejected_when_omemo_required(self):
|
||||
"""Test that plaintext messages are rejected when require_omemo=True"""
|
||||
from core.models import UserXmppSecuritySettings
|
||||
|
||||
# Create a plaintext message stanza (no OMEMO encryption)
|
||||
stanza_xml = ET.fromstring(
|
||||
"<message from='sender@example.com' to='jews.zm.is'>"
|
||||
"<body>Hello, world!</body>"
|
||||
"</message>"
|
||||
)
|
||||
|
||||
# Mock the message handler's sym function
|
||||
sym_calls = []
|
||||
|
||||
def mock_sym(msg):
|
||||
sym_calls.append(msg)
|
||||
|
||||
# Verify that security settings require OMEMO
|
||||
settings = UserXmppSecuritySettings.objects.get(user=self.user)
|
||||
self.assertTrue(settings.require_omemo)
|
||||
|
||||
# Extract OMEMO observation from plaintext message
|
||||
omemo_observation = _extract_sender_omemo_client_key(
|
||||
SimpleNamespace(xml=stanza_xml)
|
||||
)
|
||||
|
||||
# Plaintext message should have "no_omemo" status
|
||||
self.assertEqual("no_omemo", omemo_observation.get("status"))
|
||||
|
||||
# Now test that enforcement would reject this message
|
||||
# Check condition: if require_omemo is True and status != "detected"
|
||||
if settings.require_omemo:
|
||||
omemo_status = str(omemo_observation.get("status") or "")
|
||||
if omemo_status != "detected":
|
||||
# This is where the message would be rejected
|
||||
mock_sym(
|
||||
"⚠ This gateway requires OMEMO encryption. "
|
||||
"Your message was not delivered. "
|
||||
"Please enable OMEMO in your XMPP client."
|
||||
)
|
||||
|
||||
# Verify the rejection message was set
|
||||
self.assertEqual(1, len(sym_calls))
|
||||
self.assertIn("This gateway requires OMEMO encryption", sym_calls[0])
|
||||
self.assertIn("Your message was not delivered", sym_calls[0])
|
||||
self.assertIn("Please enable OMEMO in your XMPP client", sym_calls[0])
|
||||
|
||||
def test_encrypted_message_accepted_when_omemo_required(self):
|
||||
"""Test that OMEMO-encrypted messages are accepted when require_omemo=True"""
|
||||
from core.models import UserXmppSecuritySettings
|
||||
|
||||
# Create an OMEMO-encrypted message stanza
|
||||
stanza_xml = ET.fromstring(
|
||||
"<message from='sender@example.com' to='jews.zm.is'>"
|
||||
"<encrypted xmlns='eu.siacs.conversations.axolotl'>"
|
||||
"<header sid='77'><key rid='88'>x</key></header>"
|
||||
"</encrypted>"
|
||||
"</message>"
|
||||
)
|
||||
|
||||
# Extract OMEMO observation from encrypted message
|
||||
omemo_observation = _extract_sender_omemo_client_key(
|
||||
SimpleNamespace(xml=stanza_xml)
|
||||
)
|
||||
|
||||
# Encrypted message should have "detected" status
|
||||
self.assertEqual("detected", omemo_observation.get("status"))
|
||||
|
||||
# Verify that security settings require OMEMO
|
||||
settings = UserXmppSecuritySettings.objects.get(user=self.user)
|
||||
self.assertTrue(settings.require_omemo)
|
||||
|
||||
# Test that enforcement accepts this message
|
||||
if settings.require_omemo:
|
||||
omemo_status = str(omemo_observation.get("status") or "")
|
||||
if omemo_status != "detected":
|
||||
# Message would be rejected, but it's not
|
||||
self.fail("Encrypted message should not be rejected")
|
||||
|
||||
# If we get here, the message was accepted
|
||||
self.assertTrue(True)
|
||||
|
||||
|
||||
class XMPPOmemoDeviceDiscoveryTests(TestCase):
|
||||
"""Test OMEMO device discovery as seen by real XMPP clients (Dino, Gajim)."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up a mock XMPP component with OMEMO support."""
|
||||
self.user = User.objects.create_user(
|
||||
"device-discovery-user", "dd@example.com", "x"
|
||||
)
|
||||
|
||||
# Create a mock XMPP component
|
||||
self.mock_component = MagicMock()
|
||||
self.mock_component.log = MagicMock()
|
||||
self.mock_component.jid = "jews.zm.is"
|
||||
|
||||
def test_gateway_publishes_device_list_to_pubsub(self):
|
||||
"""Test that the gateway publishes its device list to PubSub (XEP-0060).
|
||||
|
||||
This simulates the device discovery query that real XMPP clients perform.
|
||||
When a client wants to send an OMEMO message, it:
|
||||
1. Queries the PubSub node: pubsub.example.com/eu.siacs.conversations.axolotl/devices/jews.zm.is
|
||||
2. Expects to receive a device list with at least one device
|
||||
3. Retrieves keys for those devices
|
||||
4. Encrypts the message
|
||||
|
||||
If the device list is empty or missing, the client shows:
|
||||
- Dino: "This contact does not support OMEMO encryption"
|
||||
- Gajim: "No devices found to encrypt this message to. Querying for devices now…"
|
||||
"""
|
||||
# This test verifies that devices are published
|
||||
# In a real scenario, the OMEMO plugin should publish devices during session_start
|
||||
|
||||
# Mock the OMEMO plugin
|
||||
mock_omemo_plugin = AsyncMock()
|
||||
|
||||
# Create a mock device list response
|
||||
# Format: list of device objects with device_id and identity_key attributes
|
||||
mock_own_devices = [
|
||||
SimpleNamespace(
|
||||
device_id=1, identity_key=b"mock_identity_key_123456789abcdef"
|
||||
),
|
||||
]
|
||||
|
||||
# When session manager is obtained, it should provide access to device info
|
||||
mock_session_manager = AsyncMock()
|
||||
mock_session_manager.get_own_device_information = AsyncMock(
|
||||
return_value=mock_own_devices
|
||||
)
|
||||
|
||||
# The plugin's get_session_manager should return the session manager
|
||||
mock_omemo_plugin.get_session_manager = AsyncMock(
|
||||
return_value=mock_session_manager
|
||||
)
|
||||
|
||||
# Simulate calling get_session_manager (as done in _bootstrap_omemo_for_authentic_channel)
|
||||
async_to_sync(mock_omemo_plugin.get_session_manager)()
|
||||
|
||||
# Verify the plugin was asked for session manager
|
||||
mock_omemo_plugin.get_session_manager.assert_called_once()
|
||||
|
||||
def test_client_cannot_encrypt_when_no_devices_found(self):
|
||||
"""Test the error case: client fails to encrypt when gateway has no published devices.
|
||||
|
||||
This reproduces the error from real clients:
|
||||
- Dino error: "This contact does not support OMEMO encryption"
|
||||
- Gajim error: "No devices found to encrypt this message to. Querying for devices now…"
|
||||
|
||||
Root cause: Gateway's device list is not being published to PubSub during bootstrap.
|
||||
"""
|
||||
# This is what causes the client error
|
||||
error_reason = (
|
||||
"This contact does not support OMEMO encryption (No devices found)"
|
||||
)
|
||||
|
||||
# Verify that this error condition matches what we see in real clients
|
||||
self.assertIn("does not support OMEMO", error_reason)
|
||||
self.assertIn("No devices", error_reason)
|
||||
|
||||
def test_client_can_encrypt_when_gateway_devices_discovered(self):
|
||||
"""Test successful encryption: client discovers gateway devices and encrypts.
|
||||
|
||||
When the gateway properly publishes devices:
|
||||
1. Client queries PubSub device node
|
||||
2. Gets back device IDs and keys
|
||||
3. Encrypts message to those devices
|
||||
4. Sends encrypted message
|
||||
"""
|
||||
# Simulate successful device discovery
|
||||
devices_from_pubsub = [
|
||||
{"device_id": 1, "identity_key": "base64_encoded_key_1"},
|
||||
]
|
||||
|
||||
# Client now has devices to encrypt to
|
||||
can_encrypt = bool(devices_from_pubsub)
|
||||
encryption_status = "ready" if devices_from_pubsub else "failed"
|
||||
|
||||
self.assertTrue(can_encrypt)
|
||||
self.assertEqual("ready", encryption_status)
|
||||
|
||||
def test_omemo_state_tracks_client_devices(self):
|
||||
"""Test that gateway tracks which devices clients use for OMEMO.
|
||||
|
||||
Once encryption is working, the gateway should observe and record
|
||||
which client devices are sending encrypted messages.
|
||||
"""
|
||||
# Simulate an OMEMO-encrypted message from a client device
|
||||
client_stanza = ET.fromstring(
|
||||
"<message from='testuser@example.com/mobile' to='jews.zm.is'>"
|
||||
"<encrypted xmlns='eu.siacs.conversations.axolotl'>"
|
||||
"<header sid='12345' schemeVersion='2'>" # Device 12345
|
||||
"<key rid='67890'>encrypted_payload_1</key>" # To recipient device 67890
|
||||
"<key rid='67891'>encrypted_payload_2</key>" # To recipient device 67891
|
||||
"</header>"
|
||||
"<payload>encrypted_message_body</payload>"
|
||||
"</encrypted>"
|
||||
"</message>"
|
||||
)
|
||||
|
||||
# Extract and verify the client key tracking
|
||||
omemo_observation = _extract_sender_omemo_client_key(
|
||||
SimpleNamespace(xml=client_stanza)
|
||||
)
|
||||
|
||||
# Should detect OMEMO and extract device info
|
||||
self.assertEqual("detected", omemo_observation.get("status"))
|
||||
self.assertIn("sid:12345", omemo_observation.get("client_key", ""))
|
||||
|
||||
# In real gateway flow, this would be persisted to UserXmppOmemoState
|
||||
# so we can track which clients have working OMEMO
|
||||
|
||||
def test_device_list_publication_requires_pubsub_node(self):
|
||||
"""Test that device list publication fails if PubSub is unavailable.
|
||||
|
||||
The OMEMO bootstrap must:
|
||||
1. Initialize the session manager (which auto-creates devices)
|
||||
2. Publish device list to PubSub at: eu.siacs.conversations.axolotl/devices/jews.zm.is
|
||||
3. Allow clients to discover and query those devices
|
||||
|
||||
If PubSub is slow or unavailable, this times out and prevents
|
||||
proper device discovery.
|
||||
"""
|
||||
# Increased timeout from 15s to 30s to allow PubSub operations
|
||||
session_manager_init_timeout = 30.0 # seconds
|
||||
|
||||
# If the session manager init times out, device list is never published
|
||||
# With the increased timeout, we have more time for:
|
||||
# 1. PubSub node creation/access
|
||||
# 2. Device list publishing
|
||||
# 3. Subscription setup
|
||||
|
||||
self.assertGreater(session_manager_init_timeout, 15.0)
|
||||
|
||||
def test_component_jid_device_discovery(self):
|
||||
"""Test that component JIDs (without user@) can publish OMEMO devices.
|
||||
|
||||
A key issue with components: they use JIDs like 'jews.zm.is' instead of
|
||||
'user@jews.zm.is'. This affects:
|
||||
1. Device list node path: eu.siacs.conversations.axolotl/devices/jews.zm.is
|
||||
2. Device identity and trust establishment
|
||||
3. How clients discover and encrypt to the component
|
||||
|
||||
The OMEMO plugin must handle component JIDs correctly.
|
||||
"""
|
||||
component_jid = "jews.zm.is"
|
||||
|
||||
# Component JID format (no user@ part)
|
||||
self.assertNotIn("@", component_jid)
|
||||
|
||||
# But PubSub device node still follows standard format
|
||||
pubsub_node = f"eu.siacs.conversations.axolotl/devices/{component_jid}"
|
||||
self.assertEqual(
|
||||
"eu.siacs.conversations.axolotl/devices/jews.zm.is", pubsub_node
|
||||
)
|
||||
|
||||
def test_gateway_accepts_presence_subscription_for_omemo(self):
|
||||
"""Test that gateway auto-accepts presence subscriptions for OMEMO device discovery.
|
||||
|
||||
When a client subscribes to the gateway component (jews.zm.is) for OMEMO:
|
||||
1. Client sends: <presence type="subscribe" from="user@example.com" to="jews.zm.is"/>
|
||||
2. Gateway should auto-accept and send presence availability
|
||||
3. This allows the client to add the gateway to its roster
|
||||
4. Client can then query PubSub for device lists
|
||||
"""
|
||||
# Simulate a client sending presence subscription to gateway
|
||||
client_jid = "testclient@example.com"
|
||||
gateway_jid = "jews.zm.is"
|
||||
|
||||
# Create a mock XMPP component with the subscription handler
|
||||
mock_component = MagicMock()
|
||||
mock_component.log = MagicMock()
|
||||
mock_component.boundjid.bare = gateway_jid
|
||||
mock_component.send_presence = MagicMock()
|
||||
|
||||
# Create mock presence stanza
|
||||
presence_stanza = MagicMock()
|
||||
presence_stanza.__getitem__ = lambda self, key: {
|
||||
"from": client_jid,
|
||||
"to": gateway_jid,
|
||||
}.get(key, "")
|
||||
|
||||
# Import the handler from the xmpp module
|
||||
from core.clients.xmpp import XMPPComponent
|
||||
|
||||
# Call the handler
|
||||
handler = XMPPComponent.on_presence_subscribe
|
||||
|
||||
# Since it's not an async method, call it directly
|
||||
handler(mock_component, presence_stanza)
|
||||
|
||||
# Verify that gateway sent subscribed response
|
||||
calls = mock_component.send_presence.call_args_list
|
||||
self.assertGreater(len(calls), 0, "Gateway should send presence response")
|
||||
|
||||
# Find the "subscribed" response
|
||||
subscribed_calls = [
|
||||
call
|
||||
for call in calls
|
||||
if call.kwargs.get("ptype") == "subscribed"
|
||||
and call.kwargs.get("pto") == client_jid
|
||||
]
|
||||
self.assertEqual(len(subscribed_calls), 1, "Should send subscribed response")
|
||||
|
||||
# Find the "available" presence notification
|
||||
available_calls = [
|
||||
call
|
||||
for call in calls
|
||||
if call.kwargs.get("ptype") == "available"
|
||||
and call.kwargs.get("pto") == client_jid
|
||||
]
|
||||
self.assertEqual(len(available_calls), 1, "Should send presence availability")
|
||||
|
||||
Reference in New Issue
Block a user