From 611de57bf8509b4b514286af94f009b33b9f315b Mon Sep 17 00:00:00 2001 From: Mark Veidemanis Date: Sat, 7 Mar 2026 15:34:23 +0000 Subject: [PATCH] Improve security --- CLAUDE.md | 21 + INSTALL.md | 2 +- Makefile | 3 +- app/local_settings.py | 1 + app/urls.py | 5 + core/clients/xmpp.py | 752 +++++++++++++++++- core/commands/engine.py | 31 +- core/gateway/__init__.py | 1 + core/gateway/commands.py | 133 ++++ core/management/commands/task_sync_worker.py | 7 + .../0038_userxmppomemostate_and_more.py | 33 + .../0039_userxmppsecuritysettings.py | 27 + ...mmandsecuritypolicy_gatewaycommandevent.py | 100 +++ core/models.py | 110 +++ core/security/command_policy.py | 220 +++++ core/tasks/codex_approval.py | 18 +- core/tasks/engine.py | 37 + core/templates/base.html | 5 +- core/templates/pages/security.html | 695 ++++++++++++++++ core/templates/pages/tasks-hub.html | 37 +- core/templates/pages/tasks-settings.html | 57 ++ core/tests/test_attachment_security.py | 1 + core/tests/test_command_security_policy.py | 224 ++++++ core/tests/test_cross_platform_messaging.py | 382 +++++++++ core/tests/test_task_sync_worker_command.py | 9 + core/tests/test_xmpp_approval_commands.py | 175 ++++ core/tests/test_xmpp_omemo_support.py | 126 +++ core/views/system.py | 361 ++++++++- core/views/tasks.py | 89 ++- requirements.txt | 11 +- stack.env.example | 2 + 31 files changed, 3617 insertions(+), 58 deletions(-) create mode 100644 core/gateway/__init__.py create mode 100644 core/gateway/commands.py create mode 100644 core/management/commands/task_sync_worker.py create mode 100644 core/migrations/0038_userxmppomemostate_and_more.py create mode 100644 core/migrations/0039_userxmppsecuritysettings.py create mode 100644 core/migrations/0040_commandsecuritypolicy_gatewaycommandevent.py create mode 100644 core/security/command_policy.py create mode 100644 core/templates/pages/security.html create mode 100644 core/tests/test_command_security_policy.py create mode 100644 core/tests/test_cross_platform_messaging.py create mode 100644 core/tests/test_task_sync_worker_command.py create mode 100644 core/tests/test_xmpp_approval_commands.py create mode 100644 core/tests/test_xmpp_omemo_support.py diff --git a/CLAUDE.md b/CLAUDE.md index d4dfbc3..dae378b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -25,3 +25,24 @@ AI coding tools (Copilot, Claude) will reuse any values they see in context. A r Before committing test files, verify no identifier matches a real person: - No number outside the reserved fictitious ranges above - No name that corresponds to a real contact used as a literal identifier + +## Naming: Avoid Ambiguous Role Labels + +**Never use "User", "Bot", "Us", or "Them" as role labels without qualification — these terms are context-dependent and misleading in this codebase.** + +GIA acts in multiple roles simultaneously: +- It is a Django **User** (account holder) from the perspective of external services (XMPP, WhatsApp, Signal). +- It is a **component** (gateway/bot) from the perspective of contacts. +- The human who owns and operates the GIA instance is the **account holder** or **operator** (not "user", which collides with `User` model). +- Remote people the system communicates with are **contacts**. + +Preferred terms: + +| Avoid | Prefer | +| ------------------ | --------------------------------------------------------------- | +| "User" (ambiguous) | "account holder" or "operator" (for the Django `User`) | +| "Bot" | "component" or "gateway" (for the XMPP/transport layer) | +| "Us" | name the specific actor: "GIA", "the component", "the operator" | +| "Them" | "contact" or "remote party" | + +Apply this in: comments, template labels, log messages, and variable names. diff --git a/INSTALL.md b/INSTALL.md index aab81b0..e94d0ee 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -85,7 +85,7 @@ make auth Optional static token helper: ```bash -make token +make token TOKEN_USER= ``` ## 6) Logs and health checks diff --git a/Makefile b/Makefile index f01b7ed..cd03ff0 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,6 @@ QUADLET_MGR := ./scripts/quadlet/manage.sh MODULES ?= core.tests +TOKEN_USER ?= m STACK_ID_CLEAN := $(shell sid="$${GIA_STACK_ID:-$${STACK_ID:-}}"; sid=$$(printf "%s" "$$sid" | tr -cs 'a-zA-Z0-9._-' '-' | sed 's/^-*//; s/-*$$//'); printf "%s" "$$sid") STACK_SUFFIX := $(if $(STACK_ID_CLEAN),_$(STACK_ID_CLEAN),) APP_CONTAINER := gia$(STACK_SUFFIX) @@ -56,7 +57,7 @@ auth: token: @if podman ps --format '{{.Names}}' | grep -qx "$(APP_CONTAINER)"; then \ - podman exec "$(APP_CONTAINER)" sh -lc "cd /code && . /venv/bin/activate && python manage.py addstatictoken m"; \ + podman exec "$(APP_CONTAINER)" sh -lc "cd /code && . /venv/bin/activate && python manage.py addstatictoken $(TOKEN_USER)"; \ else \ echo "Container '$(APP_CONTAINER)' is not running. Start the stack first with 'make run'." >&2; \ exit 125; \ diff --git a/app/local_settings.py b/app/local_settings.py index e07639f..ab2548e 100644 --- a/app/local_settings.py +++ b/app/local_settings.py @@ -90,6 +90,7 @@ XMPP_JID = getenv("XMPP_JID") XMPP_USER_DOMAIN = getenv("XMPP_USER_DOMAIN", "") XMPP_PORT = int(getenv("XMPP_PORT", "8888") or 8888) XMPP_SECRET = getenv("XMPP_SECRET") +XMPP_OMEMO_DATA_DIR = getenv("XMPP_OMEMO_DATA_DIR", "") EVENT_LEDGER_DUAL_WRITE = getenv("EVENT_LEDGER_DUAL_WRITE", "false").lower() in trues CAPABILITY_ENFORCEMENT_ENABLED = ( diff --git a/app/urls.py b/app/urls.py index f26942c..9c0cb05 100644 --- a/app/urls.py +++ b/app/urls.py @@ -59,6 +59,11 @@ urlpatterns = [ notifications.NotificationsUpdate.as_view(), name="notifications_update", ), + path( + "settings/security/", + system.SecurityPage.as_view(), + name="security_settings", + ), path( "settings/system/", system.SystemSettings.as_view(), diff --git a/core/clients/xmpp.py b/core/clients/xmpp.py index f7cd76b..9a87176 100644 --- a/core/clients/xmpp.py +++ b/core/clients/xmpp.py @@ -1,9 +1,13 @@ import asyncio +import base64 +import json import mimetypes +import os import re import time import uuid -from urllib.parse import urlsplit +from pathlib import Path +from urllib.parse import parse_qs, urlparse, urlsplit import aiohttp from asgiref.sync import sync_to_async @@ -16,9 +20,18 @@ from slixmpp.xmlstream import register_stanza_plugin from slixmpp.xmlstream.stanzabase import ET from core.clients import ClientBase, transport +from core.gateway.commands import ( + GatewayCommandContext, + GatewayCommandRoute, + dispatch_gateway_command, +) from core.messaging import ai, history, replies, reply_sync, utils from core.models import ( ChatSession, + CodexPermissionRequest, + CodexRun, + DerivedTask, + ExternalSyncEvent, Manipulation, PatternMitigationAutoSettings, PatternMitigationCorrection, @@ -28,6 +41,7 @@ from core.models import ( Person, PersonIdentifier, User, + UserXmppOmemoState, WorkspaceConversation, ) from core.security.attachments import ( @@ -40,6 +54,7 @@ URL_PATTERN = re.compile(r"https?://[^\s<>'\"\\]+") EMOJI_ONLY_PATTERN = re.compile( r"^[\U0001F300-\U0001FAFF\u2600-\u27BF\uFE0F\u200D\u2640-\u2642\u2764]+$" ) +TOTP_BASE32_SECRET_RE = re.compile(r"^[A-Z2-7]{16,}$") def _clean_url(value): @@ -129,6 +144,135 @@ def _parse_greentext_reaction(body_text): return {"quoted_text": quoted, "emoji": emoji} +def _omemo_plugin_available() -> bool: + try: + import importlib + return importlib.util.find_spec("slixmpp_omemo") is not None + except Exception: + return False + + +def _extract_sender_omemo_client_key(stanza) -> dict: + """Extract OMEMO client key info from an encrypted stanza.""" + ns = "eu.siacs.conversations.axolotl" + header = stanza.xml.find(f".//{{{ns}}}header") + if header is None: + return {"status": "no_omemo"} + sid = str(header.attrib.get("sid") or "").strip() + key_el = header.find(f"{{{ns}}}key") + rid = str(key_el.attrib.get("rid") or "").strip() if key_el is not None else "" + if sid or rid: + return {"status": "detected", "client_key": f"sid:{sid},rid:{rid}"} + return {"status": "no_omemo"} + + +# --------------------------------------------------------------------------- +# OMEMO storage + plugin implementation +# --------------------------------------------------------------------------- + +try: + from omemo.storage import Just, Maybe, Nothing, Storage as _OmemoStorageBase + from slixmpp_omemo import XEP_0384 as _XEP_0384Base + from slixmpp_omemo.base_session_manager import TrustLevel as _OmemoTrustLevel + from slixmpp.plugins.base import register_plugin as _slixmpp_register_plugin + _OMEMO_AVAILABLE = True +except ImportError: + _OMEMO_AVAILABLE = False + _OmemoStorageBase = object + _XEP_0384Base = object + _OmemoTrustLevel = None + _slixmpp_register_plugin = None + + +if _OMEMO_AVAILABLE: + class _OmemoStorage(_OmemoStorageBase): + """JSON-file-backed OMEMO key storage.""" + + def __init__(self, path: str) -> None: + super().__init__() + self._path = path + try: + with open(path) as f: + self._data: dict = json.load(f) + except (FileNotFoundError, json.JSONDecodeError): + self._data = {} + + def _save(self) -> None: + os.makedirs(os.path.dirname(self._path), exist_ok=True) + with open(self._path, "w") as f: + json.dump(self._data, f) + + async def _load(self, key: str) -> Maybe: + if key in self._data: + return Just(self._data[key]) + return Nothing() + + async def _store(self, key: str, value) -> None: + self._data[key] = value + self._save() + + async def _delete(self, key: str) -> None: + self._data.pop(key, None) + self._save() + + class _GiaOmemoPlugin(_XEP_0384Base): + """Concrete XEP-0384 OMEMO plugin for the GIA XMPP gateway component. + + Uses BTBV (blind trust before verification) – appropriate for a + server-side bridge that processes messages on behalf of users. + """ + + name = "xep_0384" + description = "OMEMO Encryption (GIA gateway)" + dependencies = {"xep_0004", "xep_0030", "xep_0060", "xep_0163", "xep_0280", "xep_0334"} + default_config = { + "fallback_message": "This message is OMEMO encrypted.", + "data_dir": "", + } + + def plugin_init(self) -> None: + data_dir = str(self.config.get("data_dir") or "").strip() + if not data_dir: + data_dir = str(Path(settings.BASE_DIR) / "xmpp_omemo_data") + os.makedirs(data_dir, exist_ok=True) + self._storage_impl = _OmemoStorage(os.path.join(data_dir, "omemo.json")) + super().plugin_init() + + @property + def storage(self) -> _OmemoStorageBase: + return self._storage_impl + + @property + def _btbv_enabled(self) -> bool: + return True + + async def _devices_blindly_trusted(self, blindly_trusted, identifier): + import logging + logging.getLogger(__name__).info( + "OMEMO: blindly trusted %d new device(s)", len(blindly_trusted) + ) + + async def _prompt_manual_trust(self, manually_trusted, identifier): + """Auto-trust all undecided devices (gateway mode).""" + import logging + log = logging.getLogger(__name__) + log.info( + "OMEMO: auto-trusting %d undecided device(s) (gateway mode)", + len(manually_trusted), + ) + session_manager = await self.get_session_manager() + for device in manually_trusted: + try: + await session_manager.set_trust( + device.bare_jid, + device.device_id, + device.identity_key, + _OmemoTrustLevel.BLINDLY_TRUSTED.value, + ) + except Exception as exc: + log.warning("OMEMO set_trust failed for %s: %s", device.bare_jid, exc) + + class XMPPComponent(ComponentXMPP): """ @@ -147,6 +291,8 @@ class XMPPComponent(ComponentXMPP): self.log = logs.get_logger("XMPP") super().__init__(jid, secret, server, port) + # Enable message IDs so the OMEMO plugin can associate encrypted stanzas. + self.use_message_ids = True # Use one reconnect strategy (our backoff loop) to avoid reconnect churn. self.auto_reconnect = False # Register chat state plugins @@ -297,6 +443,470 @@ class XMPPComponent(ComponentXMPP): ) return plan + def _derived_omemo_fingerprint(self, jid: str) -> str: + import hashlib + return hashlib.sha256(f"xmpp-omemo-key:{jid}".encode()).hexdigest()[:32] + + def _get_omemo_plugin(self): + """Return the active XEP-0384 plugin instance, or None if not loaded.""" + try: + return self["xep_0384"] + except Exception: + return None + + async def _bootstrap_omemo_for_authentic_channel(self): + jid = str(getattr(settings, "XMPP_JID", "") or "").strip() + omemo_plugin = self._get_omemo_plugin() + omemo_enabled = omemo_plugin is not None + status = "active" if omemo_enabled else "not_available" + reason = "OMEMO plugin active" if omemo_enabled else "xep_0384 plugin not loaded" + fingerprint = self._derived_omemo_fingerprint(jid) + if omemo_enabled: + try: + import asyncio as _asyncio + session_manager = await _asyncio.wait_for( + omemo_plugin.get_session_manager(), timeout=15.0 + ) + own_devices = await session_manager.get_own_device_information() + if own_devices: + key_bytes = own_devices[0].identity_key + fingerprint = ":".join(f"{b:02X}" for b in key_bytes) + except Exception as exc: + self.log.warning("OMEMO: could not read own device fingerprint: %s", exc) + self.log.info( + "OMEMO bootstrap: jid=%s enabled=%s status=%s fingerprint=%s", + jid, omemo_enabled, status, fingerprint, + ) + transport.update_runtime_state( + "xmpp", + omemo_target_jid=jid, + omemo_fingerprint=fingerprint, + omemo_enabled=omemo_enabled, + omemo_status=status, + omemo_status_reason=reason, + ) + + async def _record_sender_omemo_state(self, user, *, sender_jid, recipient_jid, message_stanza): + parsed = _extract_sender_omemo_client_key(message_stanza) + status = str(parsed.get("status") or "no_omemo") + client_key = str(parsed.get("client_key") or "") + await sync_to_async(UserXmppOmemoState.objects.update_or_create)( + user=user, + defaults={ + "status": status, + "latest_client_key": client_key, + "last_sender_jid": str(sender_jid or ""), + "last_target_jid": str(recipient_jid or ""), + }, + ) + + _approval_event_prefix = "codex_approval" + + _APPROVAL_PROVIDER_COMMANDS = { + ".claude": "claude", + ".codex": "codex_cli", + } + + def _resolve_request_provider(self, request): + event = getattr(request, "external_sync_event", None) + if event is None: + return "" + return str(getattr(event, "provider", "") or "").strip() + + _ACTION_TO_STATUS = {"approve": "approved", "reject": "denied"} + + async def _apply_approval_decision(self, request, decision, sym): + status = self._ACTION_TO_STATUS.get(decision, decision) + request.status = status + await sync_to_async(request.save)(update_fields=["status"]) + run = None + if request.codex_run_id: + run = await sync_to_async(CodexRun.objects.get)(pk=request.codex_run_id) + run.status = "approved_waiting_resume" if status == "approved" else status + await sync_to_async(run.save)(update_fields=["status"]) + if request.external_sync_event_id: + evt = await sync_to_async(ExternalSyncEvent.objects.get)(pk=request.external_sync_event_id) + evt.status = "ok" + await sync_to_async(evt.save)(update_fields=["status"]) + user = await sync_to_async(User.objects.get)(pk=request.user_id) + task = None + if run is not None and run.task_id: + task = await sync_to_async(DerivedTask.objects.get)(pk=run.task_id) + ikey = f"{self._approval_event_prefix}:{request.approval_key}:{status}" + await sync_to_async(ExternalSyncEvent.objects.get_or_create)( + idempotency_key=ikey, + defaults={ + "user": user, + "task": task, + "provider": "codex_cli", + "status": "pending", + "payload": {}, + "error": "", + }, + ) + + async def _approval_list_pending(self, user, scope, sym): + requests = await sync_to_async(list)( + CodexPermissionRequest.objects.filter( + user=user, status="pending" + ).order_by("-requested_at")[:20] + ) + sym(f"pending={len(requests)}") + for req in requests: + sym(f" {req.approval_key}: {req.summary}") + + async def _approval_status(self, user, approval_key, sym): + try: + req = await sync_to_async( + CodexPermissionRequest.objects.get + )(user=user, approval_key=approval_key) + sym(f"status={req.status} key={req.approval_key}") + except CodexPermissionRequest.DoesNotExist: + sym(f"approval_key_not_found:{approval_key}") + + async def _handle_approval_command(self, user, body, sender_jid, sym): + command = body.strip() + for prefix, expected_provider in self._APPROVAL_PROVIDER_COMMANDS.items(): + if command.startswith(prefix + " ") or command == prefix: + sub = command[len(prefix):].strip() + parts = sub.split() + if len(parts) >= 2 and parts[0] in ("approve", "reject"): + action, approval_key = parts[0], parts[1] + try: + req = await sync_to_async( + CodexPermissionRequest.objects.select_related( + "external_sync_event" + ).get + )(user=user, approval_key=approval_key) + except CodexPermissionRequest.DoesNotExist: + sym(f"approval_key_not_found:{approval_key}") + return True + provider = self._resolve_request_provider(req) + if not provider.startswith(expected_provider): + sym(f"approval_key_not_for_provider:{approval_key} provider={provider}") + return True + await self._apply_approval_decision(req, action, sym) + sym(f"{action}d: {approval_key}") + return True + sym(f"usage: {prefix} approve|reject ") + return True + + if not command.startswith(".approval"): + return False + + rest = command[len(".approval"):].strip() + + if rest.split() and rest.split()[0] in ("approve", "reject"): + parts = rest.split() + action = parts[0] + approval_key = parts[1] if len(parts) > 1 else "" + if not approval_key: + sym("usage: .approval approve|reject ") + return True + try: + req = await sync_to_async( + CodexPermissionRequest.objects.select_related( + "external_sync_event" + ).get + )(user=user, approval_key=approval_key) + except CodexPermissionRequest.DoesNotExist: + sym(f"approval_key_not_found:{approval_key}") + return True + await self._apply_approval_decision(req, action, sym) + sym(f"{action}d: {approval_key}") + return True + + if rest.startswith("list-pending"): + scope = rest[len("list-pending"):].strip() or "mine" + await self._approval_list_pending(user, scope, sym) + return True + + if rest.startswith("status "): + approval_key = rest[len("status "):].strip() + await self._approval_status(user, approval_key, sym) + return True + + sym( + "approval: .approval approve|reject | " + ".approval list-pending [all] | " + ".approval status " + ) + return True + + async def _handle_tasks_command(self, user, body, sym): + command = body.strip() + if not command.startswith(".tasks"): + return False + rest = command[len(".tasks"):].strip() + + if rest.startswith("list"): + parts = rest.split() + status_filter = parts[1] if len(parts) > 1 else "open" + limit = int(parts[2]) if len(parts) > 2 and parts[2].isdigit() else 10 + tasks = await sync_to_async(list)( + DerivedTask.objects.filter( + user=user, status_snapshot=status_filter + ).order_by("-id")[:limit] + ) + if not tasks: + sym(f"no {status_filter} tasks") + else: + for t in tasks: + sym(f"#{t.reference_code} [{t.status_snapshot}] {t.title}") + return True + + if rest.startswith("show "): + ref = rest[len("show "):].strip().lstrip("#") + try: + task = await sync_to_async(DerivedTask.objects.get)( + user=user, reference_code=ref + ) + sym(f"#{task.reference_code} {task.title}") + sym(f"status: {task.status_snapshot}") + except DerivedTask.DoesNotExist: + sym(f"task_not_found:#{ref}") + return True + + if rest.startswith("complete "): + ref = rest[len("complete "):].strip().lstrip("#") + try: + task = await sync_to_async(DerivedTask.objects.get)( + user=user, reference_code=ref + ) + task.status_snapshot = "completed" + await sync_to_async(task.save)(update_fields=["status_snapshot"]) + sym(f"completed #{ref}") + except DerivedTask.DoesNotExist: + sym(f"task_not_found:#{ref}") + return True + + if rest.startswith("undo "): + ref = rest[len("undo "):].strip().lstrip("#") + try: + task = await sync_to_async(DerivedTask.objects.get)( + user=user, reference_code=ref + ) + await sync_to_async(task.delete)() + sym(f"removed #{ref}") + except DerivedTask.DoesNotExist: + sym(f"task_not_found:#{ref}") + return True + + sym( + "tasks: .tasks list [status] [limit] | " + ".tasks show # | " + ".tasks complete # | " + ".tasks undo #" + ) + return True + + def _extract_totp_secret_candidate(self, command_text: str) -> str: + text = str(command_text or "").strip() + if not text: + return "" + lowered = text.lower() + if lowered.startswith("otpauth://"): + parsed = urlparse(text) + query = parse_qs(parsed.query or "") + return str((query.get("secret") or [""])[0] or "").strip() + if lowered.startswith(".totp"): + rest = text[len(".totp"):].strip() + if not rest: + return "" + parts = rest.split(maxsplit=1) + action = str(parts[0] or "").strip().lower() + if action in {"enroll", "set"} and len(parts) > 1: + return str(parts[1] or "").strip() + if action in {"status", "help"}: + return "" + return rest + compact = text.replace(" ", "").strip().upper() + if TOTP_BASE32_SECRET_RE.match(compact): + return compact + return "" + + async def _handle_totp_command(self, user, body, sym): + command = str(body or "").strip() + lowered = command.lower() + if lowered.startswith(".totp status"): + exists = await sync_to_async( + lambda: __import__( + "django_otp.plugins.otp_totp.models", + fromlist=["TOTPDevice"], + ) + .TOTPDevice.objects.filter(user=user, confirmed=True) + .exists() + )() + sym("totp: configured" if exists else "totp: not configured") + return True + if lowered == ".totp help": + sym("totp: .totp enroll | .totp status") + return True + + secret_candidate = self._extract_totp_secret_candidate(command) + if not secret_candidate: + if lowered.startswith(".totp"): + sym("usage: .totp enroll ") + return True + return False + + normalized = str(secret_candidate).replace(" ", "").strip().upper() + try: + key_bytes = base64.b32decode(normalized, casefold=True) + except Exception: + sym("totp: invalid secret format") + return True + if len(key_bytes) < 10: + sym("totp: secret too short") + return True + + def _save_device(): + from django_otp.plugins.otp_totp.models import TOTPDevice + + device = ( + TOTPDevice.objects.filter(user=user) + .order_by("-id") + .first() + ) + if device is None: + device = TOTPDevice(user=user, name="gateway") + device.key = key_bytes.hex() + device.confirmed = True + device.step = 30 + device.t0 = 0 + device.digits = 6 + device.tolerance = 1 + device.drift = 0 + device.save() + return device.name + + device_name = await sync_to_async(_save_device)() + sym(f"totp: enrolled for user={user.username} device={device_name}") + return True + + async def _route_gateway_command( + self, + *, + sender_user, + body, + sender_jid, + recipient_jid, + local_message, + message_meta, + sym, + ): + command_text = str(body or "").strip() + + async def _contacts_handler(_ctx, emit): + persons = await sync_to_async(list)(Person.objects.filter(user=sender_user).order_by("name")) + if not persons: + emit("No contacts found.") + return True + emit("Contacts: " + ", ".join([p.name for p in persons])) + return True + + async def _help_handler(_ctx, emit): + for line in self._gateway_help_lines(): + emit(line) + return True + + async def _whoami_handler(_ctx, emit): + emit(str(sender_user.__dict__)) + return True + + async def _approval_handler(_ctx, emit): + return await self._handle_approval_command(sender_user, command_text, sender_jid, emit) + + async def _tasks_handler(_ctx, emit): + return await self._handle_tasks_command(sender_user, command_text, emit) + + async def _totp_handler(_ctx, emit): + return await self._handle_totp_command(sender_user, command_text, emit) + + routes = [ + GatewayCommandRoute( + name="contacts", + scope_key="gateway.contacts", + matcher=lambda text: str(text or "").strip().lower() == ".contacts", + handler=_contacts_handler, + ), + GatewayCommandRoute( + name="help", + scope_key="gateway.help", + matcher=lambda text: str(text or "").strip().lower() == ".help", + handler=_help_handler, + ), + GatewayCommandRoute( + name="whoami", + scope_key="gateway.whoami", + matcher=lambda text: str(text or "").strip().lower() == ".whoami", + handler=_whoami_handler, + ), + GatewayCommandRoute( + name="approval", + scope_key="gateway.approval", + matcher=lambda text: str(text or "").strip().lower().startswith(".approval") + or any( + str(text or "").strip().lower().startswith(prefix + " ") + or str(text or "").strip().lower() == prefix + for prefix in self._APPROVAL_PROVIDER_COMMANDS + ), + handler=_approval_handler, + ), + GatewayCommandRoute( + name="tasks", + scope_key="gateway.tasks", + matcher=lambda text: str(text or "").strip().lower().startswith(".tasks"), + handler=_tasks_handler, + ), + GatewayCommandRoute( + name="totp", + scope_key="gateway.totp", + matcher=lambda text: bool(self._extract_totp_secret_candidate(text)), + handler=_totp_handler, + ), + ] + handled = await dispatch_gateway_command( + context=GatewayCommandContext( + user=sender_user, + source_message=local_message, + service="xmpp", + channel_identifier=str(sender_jid or ""), + sender_identifier=str(sender_jid or ""), + message_text=command_text, + message_meta=dict(message_meta or {}), + payload={ + "sender_jid": str(sender_jid or ""), + "recipient_jid": str(recipient_jid or ""), + }, + ), + routes=routes, + emit=sym, + ) + if not handled and command_text.startswith("."): + sym("No such command") + return handled + + def _gateway_help_lines(self): + return [ + "Gateway commands:", + " .contacts — list contacts", + " .whoami — show current user", + " .help — show this help", + " .totp enroll — enroll TOTP for this user", + " .totp status — show whether TOTP is configured", + "Approval commands:", + " .approval list-pending [all] — list pending approval requests", + " .approval approve — approve a request", + " .approval reject — reject a request", + " .approval status — check request status", + "Task commands:", + " .tasks list [status] [limit] — list tasks", + " .tasks show # — show task details", + " .tasks complete # — mark task complete", + " .tasks undo # — remove task", + ] + async def _handle_mitigation_command(self, sender_user, body, sym): def parse_parts(raw): return [part.strip() for part in raw.split("|")] @@ -855,6 +1465,7 @@ class XMPPComponent(ComponentXMPP): # This client connects as an external component, not a user client; # XEP-0280 (carbons) is client-scoped and not valid here. self.log.debug("Skipping carbons enable for component session") + await self._bootstrap_omemo_for_authentic_channel() async def _reconnect_loop(self): try: @@ -1031,6 +1642,18 @@ class XMPPComponent(ComponentXMPP): recipient_username = recipient_jid recipient_domain = recipient_jid + # Attempt to decrypt OMEMO-encrypted messages before body extraction. + original_msg = msg + omemo_plugin = self._get_omemo_plugin() + if omemo_plugin: + try: + if omemo_plugin.is_encrypted(msg): + decrypted, _ = await omemo_plugin.decrypt_message(msg) + msg = decrypted + self.log.debug("OMEMO: decrypted message from %s", sender_jid) + except Exception as exc: + self.log.warning("OMEMO: decryption failed from %s: %s", sender_jid, exc) + # Extract message body body = msg["body"] if msg["body"] else "" parsed_reaction = _extract_xmpp_reaction(msg) @@ -1157,36 +1780,55 @@ class XMPPComponent(ComponentXMPP): self.log.warning(f"Unknown sender: {sender_username}") return + # Record the sender's OMEMO state (uses the original, pre-decryption stanza). + try: + await self._record_sender_omemo_state( + sender_user, + sender_jid=sender_jid, + recipient_jid=recipient_jid, + message_stanza=original_msg, + ) + except Exception as exc: + self.log.warning("OMEMO: failed to record sender state: %s", exc) + omemo_observation = _extract_sender_omemo_client_key(original_msg) + + # Enforce mandatory encryption policy. + try: + from core.models import UserXmppSecuritySettings + sec_settings = await sync_to_async( + lambda: UserXmppSecuritySettings.objects.filter(user=sender_user).first() + )() + if sec_settings and sec_settings.require_omemo: + omemo_status = str(omemo_observation.get("status") or "") + if omemo_status != "detected": + sym( + "⚠ This gateway requires OMEMO encryption. " + "Your message was not delivered. " + "Please enable OMEMO in your XMPP client." + ) + return + except Exception as exc: + self.log.warning("OMEMO policy check failed: %s", exc) + if recipient_jid == settings.XMPP_JID: self.log.debug("Handling command message sent to gateway JID") - if body.startswith("."): - # Messaging the gateway directly - if body == ".contacts": - # Lookup Person objects linked to sender - persons = Person.objects.filter(user=sender_user) - if not persons.exists(): - self.log.debug("No contacts found for %s", sender_username) - sym("No contacts found.") - return - - # Construct contact list response - contact_names = [person.name for person in persons] - response_text = "Contacts: " + ", ".join(contact_names) - sym(response_text) - elif body == ".help": - sym("Commands: .contacts, .whoami, .mitigation help") - elif body.startswith(".mitigation"): - handled = await self._handle_mitigation_command( - sender_user, - body, - sym, - ) - if not handled: - sym("Unknown mitigation command. Try .mitigation help") - elif body == ".whoami": - sym(str(sender_user.__dict__)) - else: - sym("No such command") + if body.startswith(".") or self._extract_totp_secret_candidate(body): + await self._route_gateway_command( + sender_user=sender_user, + body=body, + sender_jid=sender_jid, + recipient_jid=recipient_jid, + local_message=None, + message_meta={ + "xmpp": { + "sender_jid": str(sender_jid or ""), + "recipient_jid": str(recipient_jid or ""), + "omemo_status": str(omemo_observation.get("status") or ""), + "omemo_client_key": str(omemo_observation.get("client_key") or ""), + } + }, + sym=sym, + ) else: self.log.debug("Handling routed message to contact") if "|" in recipient_username: @@ -1357,7 +1999,14 @@ class XMPPComponent(ComponentXMPP): reply_source_message_id=str( reply_ref.get("reply_source_message_id") or "" ), - message_meta={}, + message_meta={ + "xmpp": { + "sender_jid": str(sender_jid or ""), + "recipient_jid": str(recipient_jid or ""), + "omemo_status": str(omemo_observation.get("status") or ""), + "omemo_client_key": str(omemo_observation.get("client_key") or ""), + } + }, ) self.log.debug("Stored outbound XMPP message in history") await self.ur.message_received( @@ -1513,6 +2162,32 @@ class XMPPComponent(ComponentXMPP): msg.xml.append(oob_element) self.log.debug("Sending XMPP message: %s", msg.xml) + + # Attempt OMEMO encryption for text-only messages (not attachments). + if not attachment_url: + omemo_plugin = self._get_omemo_plugin() + if omemo_plugin: + try: + from slixmpp.jid import JID as _JID + encrypted_msgs, enc_errors = await omemo_plugin.encrypt_message( + msg, _JID(recipient_jid) + ) + if enc_errors: + self.log.debug( + "OMEMO: non-critical encryption errors for %s: %s", + recipient_jid, enc_errors, + ) + if encrypted_msgs: + for enc_msg in encrypted_msgs.values(): + enc_msg.send() + self.log.debug("OMEMO: sent encrypted message to %s", recipient_jid) + return msg_id + except Exception as exc: + self.log.debug( + "OMEMO: encryption not available for %s, sending plaintext: %s", + recipient_jid, exc, + ) + msg.send() return msg_id @@ -1834,6 +2509,23 @@ class XMPPClient(ClientBase): self.client.register_plugin("xep_0085") # Chat State Notifications self.client.register_plugin("xep_0363") # HTTP File Upload + self._omemo_plugin_registered = False + if _OMEMO_AVAILABLE: + try: + data_dir = str(getattr(settings, "XMPP_OMEMO_DATA_DIR", "") or "").strip() + if not data_dir: + data_dir = str(Path(settings.BASE_DIR) / "xmpp_omemo_data") + # Register our concrete plugin class under the "xep_0384" name so + # that slixmpp's dependency resolver finds it. + _slixmpp_register_plugin(_GiaOmemoPlugin) + self.client.register_plugin("xep_0384", pconfig={"data_dir": data_dir}) + self._omemo_plugin_registered = True + self.log.info("OMEMO: xep_0384 plugin registered, data_dir=%s", data_dir) + except Exception as exc: + self.log.warning("OMEMO: failed to register xep_0384 plugin: %s", exc) + else: + self.log.warning("OMEMO: slixmpp_omemo not available, OMEMO disabled") + def start(self): if not self._enabled or self.client is None: return diff --git a/core/commands/engine.py b/core/commands/engine.py index cfa54c5..8d9f655 100644 --- a/core/commands/engine.py +++ b/core/commands/engine.py @@ -16,6 +16,7 @@ from core.commands.registry import get as get_handler from core.commands.registry import register from core.messaging.reply_sync import is_mirrored_origin from core.models import CommandAction, CommandChannelBinding, CommandProfile, Message +from core.security.command_policy import CommandSecurityContext, evaluate_command_policy from core.tasks.chat_defaults import ensure_default_source_for_chat from core.util import logs @@ -318,12 +319,21 @@ def _matches_trigger(profile: CommandProfile, text: str) -> bool: async def process_inbound_message(ctx: CommandContext) -> list[CommandResult]: ensure_handlers_registered() trigger_message = await sync_to_async( - lambda: Message.objects.filter(id=ctx.message_id).first() + lambda: Message.objects.select_related("user", "session", "session__identifier") + .filter(id=ctx.message_id) + .first() )() if trigger_message is None: return [] if is_mirrored_origin(trigger_message.message_meta): return [] + effective_service, effective_channel = _effective_bootstrap_scope(ctx, trigger_message) + security_context = CommandSecurityContext( + service=effective_service, + channel_identifier=effective_channel, + message_meta=dict(getattr(trigger_message, "message_meta", {}) or {}), + payload=dict(ctx.payload or {}), + ) await sync_to_async(_auto_setup_profile_bindings_for_first_command)( ctx, trigger_message, @@ -334,6 +344,25 @@ async def process_inbound_message(ctx: CommandContext) -> list[CommandResult]: for profile in profiles: if not _matches_trigger(profile, ctx.message_text): continue + decision = await sync_to_async(evaluate_command_policy)( + user=trigger_message.user, + scope_key=f"command.{profile.slug}", + context=security_context, + ) + if not decision.allowed: + results.append( + CommandResult( + ok=False, + status="skipped", + error=f"policy_denied:{decision.code}", + payload={ + "profile": profile.slug, + "scope": f"command.{profile.slug}", + "reason": decision.reason, + }, + ) + ) + continue if profile.reply_required and trigger_message.reply_to_id is None: if ( profile.slug == "bp" diff --git a/core/gateway/__init__.py b/core/gateway/__init__.py new file mode 100644 index 0000000..07190b8 --- /dev/null +++ b/core/gateway/__init__.py @@ -0,0 +1 @@ +"""Gateway command routing utilities.""" diff --git a/core/gateway/commands.py b/core/gateway/commands.py new file mode 100644 index 0000000..56d69da --- /dev/null +++ b/core/gateway/commands.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Awaitable, Callable + +from asgiref.sync import sync_to_async + +from core.models import GatewayCommandEvent +from core.security.command_policy import CommandSecurityContext, evaluate_command_policy + + +GatewayEmit = Callable[[str], None] +GatewayHandler = Callable[["GatewayCommandContext", GatewayEmit], Awaitable[bool]] +GatewayMatcher = Callable[[str], bool] + + +@dataclass(slots=True) +class GatewayCommandContext: + user: object + source_message: object + service: str + channel_identifier: str + sender_identifier: str + message_text: str + message_meta: dict + payload: dict + + +@dataclass(slots=True) +class GatewayCommandRoute: + name: str + scope_key: str + matcher: GatewayMatcher + handler: GatewayHandler + + +def _first_token(text: str) -> str: + body = str(text or "").strip() + if not body: + return "" + return str(body.split()[0] or "").strip().lower() + + +def _derive_unknown_scope(text: str) -> str: + token = _first_token(text).lstrip(".") + if not token: + token = "message" + return f"gateway.{token}" + + +async def dispatch_gateway_command( + *, + context: GatewayCommandContext, + routes: list[GatewayCommandRoute], + emit: GatewayEmit, +) -> bool: + text = str(context.message_text or "").strip() + if not text: + return False + + route = next((row for row in routes if row.matcher(text)), None) + scope_key = route.scope_key if route is not None else _derive_unknown_scope(text) + command_name = route.name if route is not None else _first_token(text).lstrip(".") + + event = await sync_to_async(GatewayCommandEvent.objects.create)( + user=context.user, + source_message=context.source_message, + service=str(context.service or "").strip().lower() or "xmpp", + channel_identifier=str(context.channel_identifier or "").strip(), + sender_identifier=str(context.sender_identifier or "").strip(), + scope_key=scope_key, + command_name=command_name, + command_text=text, + status="pending", + request_meta={ + "payload": dict(context.payload or {}), + "message_meta": dict(context.message_meta or {}), + }, + ) + + if route is None: + event.status = "ignored" + event.error = "unmatched_gateway_command" + await sync_to_async(event.save)(update_fields=["status", "error", "updated_at"]) + return False + + decision = await sync_to_async(evaluate_command_policy)( + user=context.user, + scope_key=scope_key, + context=CommandSecurityContext( + service=context.service, + channel_identifier=context.channel_identifier, + message_meta=dict(context.message_meta or {}), + payload=dict(context.payload or {}), + ), + ) + if not decision.allowed: + message = ( + f"blocked by policy: {decision.code}" + if not decision.reason + else f"blocked by policy: {decision.reason}" + ) + emit(message) + event.status = "blocked" + event.error = f"{decision.code}:{decision.reason}" + event.response_meta = {"policy_code": decision.code, "policy_reason": decision.reason} + await sync_to_async(event.save)( + update_fields=["status", "error", "response_meta", "updated_at"] + ) + return True + + responses: list[str] = [] + + def _captured_emit(value: str) -> None: + row = str(value or "") + responses.append(row) + emit(row) + + try: + handled = await route.handler(context, _captured_emit) + except Exception as exc: + event.status = "failed" + event.error = f"handler_exception:{exc}" + event.response_meta = {"responses": responses} + await sync_to_async(event.save)( + update_fields=["status", "error", "response_meta", "updated_at"] + ) + return True + + event.status = "ok" if handled else "ignored" + event.response_meta = {"responses": responses} + await sync_to_async(event.save)(update_fields=["status", "response_meta", "updated_at"]) + return bool(handled) diff --git a/core/management/commands/task_sync_worker.py b/core/management/commands/task_sync_worker.py new file mode 100644 index 0000000..df48011 --- /dev/null +++ b/core/management/commands/task_sync_worker.py @@ -0,0 +1,7 @@ +from __future__ import annotations + +from core.management.commands.codex_worker import Command as LegacyCodexWorkerCommand + + +class Command(LegacyCodexWorkerCommand): + help = "Process queued task-sync events for worker-backed providers (Codex + Claude)." diff --git a/core/migrations/0038_userxmppomemostate_and_more.py b/core/migrations/0038_userxmppomemostate_and_more.py new file mode 100644 index 0000000..930b0f2 --- /dev/null +++ b/core/migrations/0038_userxmppomemostate_and_more.py @@ -0,0 +1,33 @@ +# Generated by Django 5.2.11 on 2026-03-06 20:42 + +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('core', '0037_derivedtask_due_date_assignee_identifier'), + ] + + operations = [ + migrations.CreateModel( + name='UserXmppOmemoState', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('status', models.CharField(choices=[('pending', 'Pending'), ('detected', 'Detected'), ('no_omemo', 'No OMEMO'), ('error', 'Error')], default='pending', max_length=32)), + ('latest_client_key', models.CharField(blank=True, default='', max_length=255)), + ('last_sender_jid', models.CharField(blank=True, default='', max_length=255)), + ('last_target_jid', models.CharField(blank=True, default='', max_length=255)), + ('status_reason', models.TextField(blank=True, default='')), + ('details', models.JSONField(blank=True, default=dict)), + ('last_seen_at', models.DateTimeField(blank=True, null=True)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ('user', models.OneToOneField(on_delete=models.deletion.CASCADE, related_name='xmpp_omemo_state', to=settings.AUTH_USER_MODEL)), + ], + options={ + 'indexes': [models.Index(fields=['status', 'updated_at'], name='core_userxm_status_133ead_idx')], + }, + ), + ] diff --git a/core/migrations/0039_userxmppsecuritysettings.py b/core/migrations/0039_userxmppsecuritysettings.py new file mode 100644 index 0000000..48f29fe --- /dev/null +++ b/core/migrations/0039_userxmppsecuritysettings.py @@ -0,0 +1,27 @@ +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('core', '0038_userxmppomemostate_and_more'), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name='UserXmppSecuritySettings', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('require_omemo', models.BooleanField(default=False)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ('user', models.OneToOneField( + on_delete=models.deletion.CASCADE, + related_name='xmpp_security_settings', + to=settings.AUTH_USER_MODEL, + )), + ], + ), + ] diff --git a/core/migrations/0040_commandsecuritypolicy_gatewaycommandevent.py b/core/migrations/0040_commandsecuritypolicy_gatewaycommandevent.py new file mode 100644 index 0000000..a0b273f --- /dev/null +++ b/core/migrations/0040_commandsecuritypolicy_gatewaycommandevent.py @@ -0,0 +1,100 @@ +# Generated by Django 4.2.19 on 2026-03-07 00:00 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("core", "0039_userxmppsecuritysettings"), + ] + + operations = [ + migrations.CreateModel( + name="CommandSecurityPolicy", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("scope_key", models.CharField(default="gateway.tasks", max_length=64)), + ("enabled", models.BooleanField(default=True)), + ("require_omemo", models.BooleanField(default=False)), + ("require_trusted_omemo_fingerprint", models.BooleanField(default=False)), + ("allowed_services", models.JSONField(blank=True, default=list)), + ("allowed_channels", models.JSONField(blank=True, default=dict)), + ("settings", models.JSONField(blank=True, default=dict)), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ( + "user", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="command_security_policies", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "indexes": [ + models.Index(fields=["user", "scope_key"], name="core_comman_user_id_701379_idx"), + models.Index( + fields=["user", "enabled", "updated_at"], + name="core_comman_user_id_82e21d_idx", + ), + ], + "constraints": [ + models.UniqueConstraint( + fields=("user", "scope_key"), + name="unique_command_security_policy_per_scope", + ) + ], + }, + ), + migrations.CreateModel( + name="GatewayCommandEvent", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("service", models.CharField(choices=[("signal", "Signal"), ("whatsapp", "WhatsApp"), ("xmpp", "XMPP"), ("instagram", "Instagram"), ("web", "Web")], max_length=255)), + ("channel_identifier", models.CharField(blank=True, default="", max_length=255)), + ("sender_identifier", models.CharField(blank=True, default="", max_length=255)), + ("scope_key", models.CharField(blank=True, default="", max_length=64)), + ("command_name", models.CharField(blank=True, default="", max_length=64)), + ("command_text", models.TextField(blank=True, default="")), + ("status", models.CharField(choices=[("pending", "Pending"), ("blocked", "Blocked"), ("ok", "OK"), ("failed", "Failed"), ("ignored", "Ignored")], default="pending", max_length=32)), + ("error", models.TextField(blank=True, default="")), + ("request_meta", models.JSONField(blank=True, default=dict)), + ("response_meta", models.JSONField(blank=True, default=dict)), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ( + "source_message", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="gateway_command_events", + to="core.message", + ), + ), + ( + "user", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="gateway_command_events", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "indexes": [ + models.Index( + fields=["user", "scope_key", "created_at"], + name="core_gatewa_user_id_d997cf_idx", + ), + models.Index( + fields=["user", "status", "created_at"], + name="core_gatewa_user_id_639afe_idx", + ), + ], + }, + ), + ] diff --git a/core/models.py b/core/models.py index 54dbb03..7dbd1b0 100644 --- a/core/models.py +++ b/core/models.py @@ -2100,6 +2100,76 @@ class CommandRun(models.Model): indexes = [models.Index(fields=["user", "status", "updated_at"])] +class CommandSecurityPolicy(models.Model): + user = models.ForeignKey( + User, + on_delete=models.CASCADE, + related_name="command_security_policies", + ) + scope_key = models.CharField(max_length=64, default="gateway.tasks") + enabled = models.BooleanField(default=True) + require_omemo = models.BooleanField(default=False) + require_trusted_omemo_fingerprint = models.BooleanField(default=False) + allowed_services = models.JSONField(default=list, blank=True) + allowed_channels = models.JSONField(default=dict, blank=True) + settings = models.JSONField(default=dict, blank=True) + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + class Meta: + constraints = [ + models.UniqueConstraint( + fields=["user", "scope_key"], + name="unique_command_security_policy_per_scope", + ) + ] + indexes = [ + models.Index(fields=["user", "scope_key"]), + models.Index(fields=["user", "enabled", "updated_at"]), + ] + + +class GatewayCommandEvent(models.Model): + STATUS_CHOICES = ( + ("pending", "Pending"), + ("blocked", "Blocked"), + ("ok", "OK"), + ("failed", "Failed"), + ("ignored", "Ignored"), + ) + + user = models.ForeignKey( + User, + on_delete=models.CASCADE, + related_name="gateway_command_events", + ) + source_message = models.ForeignKey( + Message, + on_delete=models.SET_NULL, + null=True, + blank=True, + related_name="gateway_command_events", + ) + service = models.CharField(max_length=255, choices=CHANNEL_SERVICE_CHOICES) + channel_identifier = models.CharField(max_length=255, blank=True, default="") + sender_identifier = models.CharField(max_length=255, blank=True, default="") + scope_key = models.CharField(max_length=64, blank=True, default="") + command_name = models.CharField(max_length=64, blank=True, default="") + command_text = models.TextField(blank=True, default="") + status = models.CharField(max_length=32, choices=STATUS_CHOICES, default="pending") + error = models.TextField(blank=True, default="") + request_meta = models.JSONField(default=dict, blank=True) + response_meta = models.JSONField(default=dict, blank=True) + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + class Meta: + indexes = [ + models.Index(fields=["user", "scope_key", "created_at"]), + models.Index(fields=["user", "status", "created_at"]), + ] + + class TranslationBridge(models.Model): DIRECTION_CHOICES = ( ("a_to_b", "A To B"), @@ -2815,6 +2885,46 @@ class ExternalChatLink(models.Model): ] +class UserXmppOmemoState(models.Model): + STATUS_CHOICES = ( + ("pending", "Pending"), + ("detected", "Detected"), + ("no_omemo", "No OMEMO"), + ("error", "Error"), + ) + + user = models.OneToOneField( + User, + on_delete=models.CASCADE, + related_name="xmpp_omemo_state", + ) + status = models.CharField(max_length=32, choices=STATUS_CHOICES, default="pending") + latest_client_key = models.CharField(max_length=255, blank=True, default="") + last_sender_jid = models.CharField(max_length=255, blank=True, default="") + last_target_jid = models.CharField(max_length=255, blank=True, default="") + status_reason = models.TextField(blank=True, default="") + details = models.JSONField(blank=True, default=dict) + last_seen_at = models.DateTimeField(blank=True, null=True) + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + class Meta: + indexes = [ + models.Index(fields=["status", "updated_at"], name="core_userxm_status_133ead_idx"), + ] + + +class UserXmppSecuritySettings(models.Model): + user = models.OneToOneField( + User, + on_delete=models.CASCADE, + related_name="xmpp_security_settings", + ) + require_omemo = models.BooleanField(default=False) + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + class TaskCompletionPattern(models.Model): id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) user = models.ForeignKey(User, on_delete=models.CASCADE, related_name="task_completion_patterns") diff --git a/core/security/command_policy.py b/core/security/command_policy.py new file mode 100644 index 0000000..cdccf12 --- /dev/null +++ b/core/security/command_policy.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from core.models import CommandSecurityPolicy, UserXmppOmemoState + +GLOBAL_SCOPE_KEY = "global.override" +OVERRIDE_OPTIONS = {"per_scope", "on", "off"} + + +@dataclass(slots=True) +class CommandSecurityContext: + service: str + channel_identifier: str + message_meta: dict + payload: dict + + +@dataclass(slots=True) +class CommandPolicyDecision: + allowed: bool + code: str = "allowed" + reason: str = "" + + +def _normalize_service(value: str) -> str: + return str(value or "").strip().lower() + + +def _normalize_channel(value: str) -> str: + return str(value or "").strip() + + +def _normalize_list(values) -> list[str]: + rows: list[str] = [] + if not isinstance(values, list): + return rows + for row in values: + item = str(row or "").strip() + if item and item not in rows: + rows.append(item) + return rows + + +def _parse_override_value(value: str) -> str: + option = str(value or "").strip().lower() + if option == "inherit": + # Backward compatibility for previously saved values. + option = "per_scope" + if option in OVERRIDE_OPTIONS: + return option + return "per_scope" + + +def _match_channel(rule: str, channel: str) -> bool: + value = str(rule or "").strip() + current = str(channel or "").strip() + if not value: + return False + if value == "*": + return True + if value.endswith("*"): + return current.startswith(value[:-1]) + return current == value + + +def _omemo_facts(ctx: CommandSecurityContext) -> tuple[str, str]: + message_meta = dict(ctx.message_meta or {}) + payload = dict(ctx.payload or {}) + xmpp_meta = dict(message_meta.get("xmpp") or {}) + status = str( + xmpp_meta.get("omemo_status") + or payload.get("omemo_status") + or "" + ).strip().lower() + client_key = str( + xmpp_meta.get("omemo_client_key") + or payload.get("omemo_client_key") + or "" + ).strip() + return status, client_key + + +def _channel_allowed_for_rules(rules: dict, service: str, channel: str) -> bool: + service_rules = _normalize_list(rules.get(service)) + if not service_rules: + service_rules = _normalize_list(rules.get("*")) + if not service_rules: + return True + return any(_match_channel(rule, channel) for rule in service_rules) + + +def _service_allowed(allowed_services: list[str], service: str) -> bool: + if not allowed_services: + return True + return service in allowed_services + + +def _effective_bool(local_value: bool, global_override: str) -> bool: + option = _parse_override_value(global_override) + if option == "on": + return True + if option == "off": + return False + return bool(local_value) + + +def evaluate_command_policy( + *, + user, + scope_key: str, + context: CommandSecurityContext, +) -> CommandPolicyDecision: + scope = str(scope_key or "").strip().lower() + if not scope: + return CommandPolicyDecision(allowed=True) + + policy = ( + CommandSecurityPolicy.objects.filter( + user=user, + scope_key=scope, + ) + .order_by("-updated_at") + .first() + ) + global_policy = ( + CommandSecurityPolicy.objects.filter( + user=user, + scope_key=GLOBAL_SCOPE_KEY, + ) + .order_by("-updated_at") + .first() + ) + + if policy is None and global_policy is None: + return CommandPolicyDecision(allowed=True) + + global_settings = dict(getattr(global_policy, "settings", {}) or {}) + local_enabled = bool(getattr(policy, "enabled", True)) + local_require_omemo = bool(getattr(policy, "require_omemo", False)) + local_require_trusted = bool( + getattr(policy, "require_trusted_omemo_fingerprint", False) + ) + enabled = _effective_bool(local_enabled, global_settings.get("scope_enabled")) + require_omemo = _effective_bool( + local_require_omemo, global_settings.get("require_omemo") + ) + require_trusted_omemo_fingerprint = _effective_bool( + local_require_trusted, + global_settings.get("require_trusted_fingerprint"), + ) + + if not enabled: + return CommandPolicyDecision( + allowed=False, + code="policy_disabled", + reason=f"{scope} is disabled by command policy", + ) + + service = _normalize_service(context.service) + channel = _normalize_channel(context.channel_identifier) + allowed_services = [ + item.lower() for item in _normalize_list(getattr(policy, "allowed_services", [])) + ] + global_allowed_services = [ + item.lower() + for item in _normalize_list(getattr(global_policy, "allowed_services", [])) + ] + if not _service_allowed(allowed_services, service): + return CommandPolicyDecision( + allowed=False, + code="service_not_allowed", + reason=f"service={service or '-'} not allowed for scope={scope}", + ) + if not _service_allowed(global_allowed_services, service): + return CommandPolicyDecision( + allowed=False, + code="service_not_allowed", + reason=f"service={service or '-'} not allowed by global override", + ) + local_channel_rules = dict(getattr(policy, "allowed_channels", {}) or {}) + if not _channel_allowed_for_rules(local_channel_rules, service, channel): + return CommandPolicyDecision( + allowed=False, + code="channel_not_allowed", + reason=f"channel={channel or '-'} not allowed for scope={scope}", + ) + global_channel_rules = dict(getattr(global_policy, "allowed_channels", {}) or {}) + if not _channel_allowed_for_rules(global_channel_rules, service, channel): + return CommandPolicyDecision( + allowed=False, + code="channel_not_allowed", + reason=f"channel={channel or '-'} not allowed by global override", + ) + + omemo_status, omemo_client_key = _omemo_facts(context) + if require_omemo and omemo_status != "detected": + return CommandPolicyDecision( + allowed=False, + code="omemo_required", + reason=f"scope={scope} requires OMEMO", + ) + + if require_trusted_omemo_fingerprint: + if omemo_status != "detected" or not omemo_client_key: + return CommandPolicyDecision( + allowed=False, + code="trusted_fingerprint_required", + reason=f"scope={scope} requires trusted OMEMO fingerprint", + ) + state = UserXmppOmemoState.objects.filter(user=user).first() + expected_key = str(getattr(state, "latest_client_key", "") or "").strip() + if not expected_key or expected_key != omemo_client_key: + return CommandPolicyDecision( + allowed=False, + code="fingerprint_mismatch", + reason=f"scope={scope} OMEMO fingerprint does not match enrolled key", + ) + + return CommandPolicyDecision(allowed=True) diff --git a/core/tasks/codex_approval.py b/core/tasks/codex_approval.py index 84aec49..fb142f8 100644 --- a/core/tasks/codex_approval.py +++ b/core/tasks/codex_approval.py @@ -22,7 +22,9 @@ def queue_codex_event_with_pre_approval( action: str, provider_payload: dict, idempotency_key: str, + provider: str = "codex_cli", ) -> tuple[ExternalSyncEvent, CodexPermissionRequest]: + provider = str(provider or "codex_cli").strip() or "codex_cli" approval_key = _deterministic_approval_key(idempotency_key) waiting_event, _ = ExternalSyncEvent.objects.update_or_create( idempotency_key=f"codex_waiting:{idempotency_key}", @@ -30,7 +32,7 @@ def queue_codex_event_with_pre_approval( "user": user, "task": task, "task_event": task_event, - "provider": "codex_cli", + "provider": provider, "status": "waiting_approval", "payload": { "action": str(action or "append_update"), @@ -43,16 +45,18 @@ def queue_codex_event_with_pre_approval( run.error = "" run.save(update_fields=["status", "error", "updated_at"]) + provider_label = "Claude" if provider == "claude_cli" else "Codex" + xmpp_cmd = ".claude" if provider == "claude_cli" else ".codex" request, _ = CodexPermissionRequest.objects.update_or_create( approval_key=approval_key, defaults={ "user": user, "codex_run": run, "external_sync_event": waiting_event, - "summary": "Pre-submit approval required before sending to Codex", + "summary": f"Pre-submit approval required before sending to {provider_label}", "requested_permissions": { "type": "pre_submit", - "provider": "codex_cli", + "provider": provider, "action": str(action or "append_update"), }, "resume_payload": { @@ -68,7 +72,7 @@ def queue_codex_event_with_pre_approval( }, ) - cfg = TaskProviderConfig.objects.filter(user=user, provider="codex_cli", enabled=True).first() + cfg = TaskProviderConfig.objects.filter(user=user, provider=provider, enabled=True).first() settings_payload = dict(getattr(cfg, "settings", {}) or {}) approver_service = str(settings_payload.get("approver_service") or "").strip().lower() approver_identifier = str(settings_payload.get("approver_identifier") or "").strip() @@ -78,10 +82,10 @@ def queue_codex_event_with_pre_approval( approver_service, approver_identifier, text=( - f"[codex approval] key={approval_key}\n" - "summary=Pre-submit approval required before sending to Codex\n" + f"[{provider} approval] key={approval_key}\n" + f"summary=Pre-submit approval required before sending to {provider_label}\n" "requested=pre_submit\n" - f"use: .codex approve {approval_key} or .codex deny {approval_key}" + f"use: {xmpp_cmd} approve {approval_key} or {xmpp_cmd} deny {approval_key}" ), attachments=[], metadata={"origin_tag": f"codex-pre-approval:{approval_key}"}, diff --git a/core/tasks/engine.py b/core/tasks/engine.py index 05e7bfb..200ad58 100644 --- a/core/tasks/engine.py +++ b/core/tasks/engine.py @@ -26,6 +26,7 @@ from core.tasks.chat_defaults import ensure_default_source_for_chat, resolve_mes from core.tasks.codex_approval import queue_codex_event_with_pre_approval from core.tasks.providers import get_provider from core.tasks.codex_support import resolve_external_chat_id +from core.security.command_policy import CommandSecurityContext, evaluate_command_policy _TASK_HINT_RE = re.compile(r"\b(todo|task|action|need to|please)\b", re.IGNORECASE) _COMPLETION_RE = re.compile(r"\b(done|completed|fixed)\s*#([A-Za-z0-9_-]+)\b", re.IGNORECASE) @@ -699,6 +700,20 @@ def _is_task_command_candidate(text: str) -> bool: return _has_task_prefix(body.lower(), ["task:", "todo:"]) +def _is_explicit_task_command(text: str) -> bool: + body = str(text or "").strip() + if not body: + return False + return bool( + _LIST_TASKS_RE.match(body) + or _LIST_TASKS_CMD_RE.match(body) + or _TASK_SHOW_RE.match(body) + or _TASK_COMPLETE_CMD_RE.match(body) + or _UNDO_TASK_RE.match(body) + or _EPIC_CREATE_RE.match(body) + ) + + async def process_inbound_task_intelligence(message: Message) -> None: if message is None: return @@ -707,6 +722,20 @@ async def process_inbound_task_intelligence(message: Message) -> None: text = str(message.text or "").strip() if not text: return + security_context = CommandSecurityContext( + service=str(message.source_service or "").strip().lower(), + channel_identifier=str(message.source_chat_id or "").strip(), + message_meta=dict(message.message_meta or {}), + payload={}, + ) + if _is_explicit_task_command(text): + command_decision = await sync_to_async(evaluate_command_policy)( + user=message.user, + scope_key="tasks.commands", + context=security_context, + ) + if not command_decision.allowed: + return sources = await _resolve_source_mappings(message) if not sources: @@ -729,6 +758,14 @@ async def process_inbound_task_intelligence(message: Message) -> None: if await _handle_epic_create_command(message, sources, text): return + submit_decision = await sync_to_async(evaluate_command_policy)( + user=message.user, + scope_key="tasks.submit", + context=security_context, + ) + if not submit_decision.allowed: + return + completion_allowed = any(bool(_effective_flags(source).get("completion_enabled")) for source in sources) completion_rx = await _completion_regex(message) if completion_allowed else None marker_match = (completion_rx.search(text) if completion_rx else None) or (_COMPLETION_RE.search(text) if completion_allowed else None) diff --git a/core/templates/base.html b/core/templates/base.html index 3ce17ac..ea8a441 100644 --- a/core/templates/base.html +++ b/core/templates/base.html @@ -400,9 +400,12 @@ diff --git a/core/tests/test_attachment_security.py b/core/tests/test_attachment_security.py index 60fa358..4cefe7b 100644 --- a/core/tests/test_attachment_security.py +++ b/core/tests/test_attachment_security.py @@ -24,6 +24,7 @@ class AttachmentSecurityTests(SimpleTestCase): size=32, ) + @override_settings(ATTACHMENT_ALLOW_PRIVATE_URLS=False) def test_blocks_private_url_by_default(self): with self.assertRaises(ValueError): validate_attachment_url("http://localhost/internal") diff --git a/core/tests/test_command_security_policy.py b/core/tests/test_command_security_policy.py new file mode 100644 index 0000000..d1cfcc4 --- /dev/null +++ b/core/tests/test_command_security_policy.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +from asgiref.sync import async_to_sync +from django.test import TestCase + +from core.commands.base import CommandContext +from core.commands.engine import process_inbound_message +from core.gateway.commands import ( + GatewayCommandContext, + GatewayCommandRoute, + dispatch_gateway_command, +) +from core.models import ( + ChatSession, + CommandChannelBinding, + CommandProfile, + CommandSecurityPolicy, + GatewayCommandEvent, + Message, + Person, + PersonIdentifier, + User, + UserXmppOmemoState, +) +from core.security.command_policy import CommandSecurityContext, evaluate_command_policy + + +class CommandSecurityPolicyTests(TestCase): + def setUp(self): + self.user = User.objects.create_user( + username="policy-user", + email="policy-user@example.com", + password="x", + ) + self.person = Person.objects.create(user=self.user, name="Policy Person") + self.identifier = PersonIdentifier.objects.create( + user=self.user, + person=self.person, + service="xmpp", + identifier="policy-user@zm.is", + ) + self.session = ChatSession.objects.create( + user=self.user, + identifier=self.identifier, + ) + + def test_command_profile_scope_denies_disallowed_service(self): + profile = CommandProfile.objects.create( + user=self.user, + slug="bp", + name="Business Plan", + enabled=True, + trigger_token="#bp#", + reply_required=False, + exact_match_only=True, + ) + CommandChannelBinding.objects.create( + profile=profile, + direction="ingress", + service="xmpp", + channel_identifier="policy-user@zm.is", + enabled=True, + ) + CommandSecurityPolicy.objects.create( + user=self.user, + scope_key="command.bp", + enabled=True, + allowed_services=["whatsapp"], + ) + msg = Message.objects.create( + user=self.user, + session=self.session, + sender_uuid="", + text="#bp#", + ts=1000, + source_service="xmpp", + source_chat_id="policy-user@zm.is", + message_meta={}, + ) + results = async_to_sync(process_inbound_message)( + CommandContext( + service="xmpp", + channel_identifier="policy-user@zm.is", + message_id=str(msg.id), + user_id=self.user.id, + message_text="#bp#", + payload={}, + ) + ) + self.assertEqual(1, len(results)) + self.assertEqual("skipped", results[0].status) + self.assertTrue(str(results[0].error).startswith("policy_denied:service_not_allowed")) + + def test_gateway_scope_can_require_trusted_omemo_key(self): + CommandSecurityPolicy.objects.create( + user=self.user, + scope_key="gateway.tasks", + enabled=True, + require_omemo=True, + require_trusted_omemo_fingerprint=True, + ) + UserXmppOmemoState.objects.create( + user=self.user, + status="detected", + latest_client_key="sid:abc", + last_sender_jid="policy-user@zm.is/phone", + last_target_jid="jews.zm.is", + ) + outputs: list[str] = [] + + async def _tasks_handler(_ctx, emit): + emit("ok") + return True + + handled = async_to_sync(dispatch_gateway_command)( + context=GatewayCommandContext( + user=self.user, + source_message=None, + service="xmpp", + channel_identifier="policy-user@zm.is", + sender_identifier="policy-user@zm.is/phone", + message_text=".tasks list", + message_meta={"xmpp": {"omemo_status": "detected", "omemo_client_key": "sid:abc"}}, + payload={}, + ), + routes=[ + GatewayCommandRoute( + name="tasks", + scope_key="gateway.tasks", + matcher=lambda text: str(text).startswith(".tasks"), + handler=_tasks_handler, + ) + ], + emit=lambda value: outputs.append(str(value)), + ) + self.assertTrue(handled) + self.assertEqual(["ok"], outputs) + event = GatewayCommandEvent.objects.order_by("-created_at").first() + self.assertIsNotNone(event) + self.assertEqual("ok", event.status if event else "") + + def test_gateway_scope_blocks_when_omemo_required_but_missing(self): + CommandSecurityPolicy.objects.create( + user=self.user, + scope_key="gateway.tasks", + enabled=True, + require_omemo=True, + ) + outputs: list[str] = [] + + async def _tasks_handler(_ctx, emit): + emit("unexpected") + return True + + handled = async_to_sync(dispatch_gateway_command)( + context=GatewayCommandContext( + user=self.user, + source_message=None, + service="xmpp", + channel_identifier="policy-user@zm.is", + sender_identifier="policy-user@zm.is/phone", + message_text=".tasks list", + message_meta={"xmpp": {"omemo_status": "no_omemo"}}, + payload={}, + ), + routes=[ + GatewayCommandRoute( + name="tasks", + scope_key="gateway.tasks", + matcher=lambda text: str(text).startswith(".tasks"), + handler=_tasks_handler, + ) + ], + emit=lambda value: outputs.append(str(value)), + ) + self.assertTrue(handled) + self.assertTrue(outputs) + self.assertIn("blocked by policy", outputs[0].lower()) + event = GatewayCommandEvent.objects.order_by("-created_at").first() + self.assertIsNotNone(event) + self.assertEqual("blocked", event.status if event else "") + + def test_global_scope_override_can_force_scope_disabled(self): + CommandSecurityPolicy.objects.create( + user=self.user, + scope_key="gateway.tasks", + enabled=True, + ) + CommandSecurityPolicy.objects.create( + user=self.user, + scope_key="global.override", + settings={"scope_enabled": "off"}, + ) + decision = evaluate_command_policy( + user=self.user, + scope_key="gateway.tasks", + context=CommandSecurityContext( + service="xmpp", + channel_identifier="policy-user@zm.is", + message_meta={}, + payload={}, + ), + ) + self.assertFalse(decision.allowed) + self.assertEqual("policy_disabled", decision.code) + + def test_global_scope_override_allowed_services_applies_to_all_scopes(self): + CommandSecurityPolicy.objects.create( + user=self.user, + scope_key="global.override", + allowed_services=["xmpp"], + ) + decision = evaluate_command_policy( + user=self.user, + scope_key="tasks.commands", + context=CommandSecurityContext( + service="whatsapp", + channel_identifier="12035550123", + message_meta={}, + payload={}, + ), + ) + self.assertFalse(decision.allowed) + self.assertEqual("service_not_allowed", decision.code) diff --git a/core/tests/test_cross_platform_messaging.py b/core/tests/test_cross_platform_messaging.py new file mode 100644 index 0000000..87beab6 --- /dev/null +++ b/core/tests/test_cross_platform_messaging.py @@ -0,0 +1,382 @@ +""" +Cross-platform messaging tests: replies, reactions, and messages across +Signal, WhatsApp, and XMPP. + +Signal coverage is in test_signal_reply_send.py. This file fills the gaps +for WhatsApp and XMPP, and verifies the shared reply_sync infrastructure +works correctly for both services. +""" +from __future__ import annotations + +import xml.etree.ElementTree as ET +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +from asgiref.sync import async_to_sync +from django.test import SimpleTestCase, TestCase + +from core.clients import transport +from core.clients.xmpp import ( + _extract_xmpp_reaction, + _extract_xmpp_reply_target_id, + _parse_greentext_reaction, +) +from core.messaging import history, reply_sync +from core.models import ChatSession, Message, Person, PersonIdentifier, User +from core.presence.inference import now_ms + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _fake_stanza(xml_text: str) -> SimpleNamespace: + """Minimal stanza-like object with an .xml attribute.""" + return SimpleNamespace(xml=ET.fromstring(xml_text)) + + +# --------------------------------------------------------------------------- +# WhatsApp — reply extraction (pure, no DB) +# --------------------------------------------------------------------------- + +class WhatsAppReplyExtractionTests(SimpleTestCase): + def test_extract_reply_ref_from_contextinfo_stanza_id(self): + payload = { + "contextInfo": { + "stanzaId": "wa-anchor-001", + "participant": "447700900001@s.whatsapp.net", + } + } + ref = reply_sync.extract_reply_ref("whatsapp", payload) + self.assertEqual("wa-anchor-001", ref.get("reply_source_message_id")) + self.assertEqual("whatsapp", ref.get("reply_source_service")) + self.assertEqual("447700900001@s.whatsapp.net", ref.get("reply_source_chat_id")) + + def test_extract_reply_ref_from_extended_text_message(self): + payload = { + "extendedTextMessage": { + "text": "quoting you", + "contextInfo": { + "stanzaId": "wa-anchor-002", + "participant": "447700900002@s.whatsapp.net", + }, + } + } + ref = reply_sync.extract_reply_ref("whatsapp", payload) + self.assertEqual("wa-anchor-002", ref.get("reply_source_message_id")) + + def test_extract_reply_ref_returns_empty_when_no_context(self): + ref = reply_sync.extract_reply_ref("whatsapp", {"conversation": "plain text"}) + self.assertEqual({}, ref) + + def test_extract_reply_ref_from_image_message_contextinfo(self): + payload = { + "imageMessage": { + "caption": "look at this", + "contextInfo": { + "stanzaId": "wa-anchor-003", + "participant": "447700900003@s.whatsapp.net", + }, + } + } + ref = reply_sync.extract_reply_ref("whatsapp", payload) + self.assertEqual("wa-anchor-003", ref.get("reply_source_message_id")) + + +# --------------------------------------------------------------------------- +# WhatsApp — reply resolution (requires DB) +# --------------------------------------------------------------------------- + +class WhatsAppReplyResolutionTests(TestCase): + def setUp(self): + self.user = User.objects.create_user( + "wa-resolve-user", "wa-resolve@example.com", "x" + ) + self.person = Person.objects.create(user=self.user, name="WA Resolve") + self.identifier = PersonIdentifier.objects.create( + user=self.user, + person=self.person, + service="whatsapp", + identifier="447700900001@s.whatsapp.net", + ) + self.session = ChatSession.objects.create( + user=self.user, identifier=self.identifier + ) + self.anchor = Message.objects.create( + user=self.user, + session=self.session, + ts=now_ms(), + text="anchor message", + source_service="whatsapp", + source_message_id="wa-anchor-001", + source_chat_id="447700900001@s.whatsapp.net", + sender_uuid="447700900001@s.whatsapp.net", + ) + + def test_resolve_reply_target_by_source_message_id(self): + ref = { + "reply_source_message_id": "wa-anchor-001", + "reply_source_service": "whatsapp", + "reply_source_chat_id": "447700900001@s.whatsapp.net", + } + target = async_to_sync(reply_sync.resolve_reply_target)( + self.user, self.session, ref + ) + self.assertIsNotNone(target) + self.assertEqual(self.anchor.id, target.id) + + def test_resolve_returns_none_for_unknown_id(self): + ref = { + "reply_source_message_id": "wa-nonexistent-999", + "reply_source_service": "whatsapp", + "reply_source_chat_id": "447700900001@s.whatsapp.net", + } + target = async_to_sync(reply_sync.resolve_reply_target)( + self.user, self.session, ref + ) + self.assertIsNone(target) + + def test_reaction_applied_to_whatsapp_anchor(self): + async_to_sync(history.apply_reaction)( + self.user, + self.identifier, + target_message_id="wa-anchor-001", + target_ts=int(self.anchor.ts), + emoji="👍", + source_service="whatsapp", + actor="447700900001@s.whatsapp.net", + remove=False, + payload={"event": "reaction"}, + ) + self.anchor.refresh_from_db() + reactions = list((self.anchor.receipt_payload or {}).get("reactions") or []) + self.assertEqual(1, len(reactions)) + self.assertEqual("👍", reactions[0].get("emoji")) + + def test_reaction_removal_clears_flag(self): + async_to_sync(history.apply_reaction)( + self.user, + self.identifier, + target_message_id="wa-anchor-001", + target_ts=int(self.anchor.ts), + emoji="👍", + source_service="whatsapp", + actor="447700900001@s.whatsapp.net", + remove=False, + payload={}, + ) + async_to_sync(history.apply_reaction)( + self.user, + self.identifier, + target_message_id="wa-anchor-001", + target_ts=int(self.anchor.ts), + emoji="👍", + source_service="whatsapp", + actor="447700900001@s.whatsapp.net", + remove=True, + payload={}, + ) + self.anchor.refresh_from_db() + reactions = list((self.anchor.receipt_payload or {}).get("reactions") or []) + removed = [r for r in reactions if r.get("emoji") == "👍" and not r.get("removed")] + self.assertEqual(0, len(removed)) + + +# --------------------------------------------------------------------------- +# WhatsApp — outbound reply metadata +# --------------------------------------------------------------------------- + +class WhatsAppOutboundReplyTests(TestCase): + def test_transport_passes_reply_metadata_to_whatsapp_api(self): + mock_client = MagicMock() + mock_client.send_message_raw = AsyncMock(return_value="wa-sent-001") + with patch( + "core.clients.transport.get_runtime_client", + return_value=mock_client, + ), patch( + "core.clients.transport.prepare_outbound_attachments", + new=AsyncMock(return_value=[]), + ), patch( + "core.clients.transport._capability_checks_enabled", + return_value=False, + ): + result = async_to_sync(transport.send_message_raw)( + "whatsapp", + "447700900001@s.whatsapp.net", + text="reply text", + attachments=[], + metadata={ + "quote_id": "wa-anchor-001", + "quote_author": "447700900001@s.whatsapp.net", + "quote_text": "anchor message", + }, + ) + self.assertEqual("wa-sent-001", result) + mock_client.send_message_raw.assert_awaited_once() + _, call_kwargs = mock_client.send_message_raw.call_args + meta = call_kwargs.get("metadata") or {} + self.assertEqual("wa-anchor-001", meta.get("quote_id")) + + +# --------------------------------------------------------------------------- +# XMPP — reaction extraction (pure, no DB) +# --------------------------------------------------------------------------- + +class XMPPReactionExtractionTests(SimpleTestCase): + def test_extract_xep_0444_reaction(self): + stanza = _fake_stanza( + "" + "" + "👍" + "" + "" + ) + result = _extract_xmpp_reaction(stanza) + self.assertIsNotNone(result) + self.assertEqual("xmpp-anchor-001", result.get("target_id")) + self.assertEqual("👍", result.get("emoji")) + self.assertFalse(result.get("remove")) + + def test_extract_xep_0444_reaction_removal(self): + stanza = _fake_stanza( + "" + "" + "" + "" + ) + result = _extract_xmpp_reaction(stanza) + self.assertIsNotNone(result) + self.assertEqual("xmpp-anchor-002", result.get("target_id")) + self.assertTrue(result.get("remove")) + + def test_extract_returns_none_for_plain_message(self): + stanza = _fake_stanza("hello") + self.assertIsNone(_extract_xmpp_reaction(stanza)) + + def test_parse_greentext_reaction_valid(self): + result = _parse_greentext_reaction(">anchor message\n😊") + self.assertIsNotNone(result) + self.assertEqual("anchor message", result.get("quoted_text")) + self.assertEqual("😊", result.get("emoji")) + + def test_parse_greentext_reaction_rejects_non_emoji_second_line(self): + result = _parse_greentext_reaction(">anchor message\nnot an emoji") + self.assertIsNone(result) + + def test_parse_greentext_reaction_rejects_single_line(self): + result = _parse_greentext_reaction(">anchor message") + self.assertIsNone(result) + + def test_parse_greentext_reaction_rejects_no_quote_prefix(self): + result = _parse_greentext_reaction("anchor message\n😊") + self.assertIsNone(result) + + +# --------------------------------------------------------------------------- +# XMPP — reply extraction (pure, no DB) +# --------------------------------------------------------------------------- + +class XMPPReplyExtractionTests(SimpleTestCase): + def test_extract_reply_target_id_from_xep_0461_stanza(self): + stanza = _fake_stanza( + "" + "" + "quoted reply" + "" + ) + target_id = _extract_xmpp_reply_target_id(stanza) + self.assertEqual("xmpp-anchor-001", target_id) + + def test_extract_reply_target_id_returns_empty_for_plain(self): + stanza = _fake_stanza("hello") + self.assertEqual("", _extract_xmpp_reply_target_id(stanza)) + + def test_extract_reply_ref_for_xmpp_service(self): + ref = reply_sync.extract_reply_ref( + "xmpp", + { + "reply_source_message_id": "xmpp-anchor-001", + "reply_source_chat_id": "user@zm.is/mobile", + }, + ) + self.assertEqual("xmpp-anchor-001", ref.get("reply_source_message_id")) + self.assertEqual("xmpp", ref.get("reply_source_service")) + self.assertEqual("user@zm.is/mobile", ref.get("reply_source_chat_id")) + + def test_extract_reply_ref_returns_empty_for_missing_id(self): + ref = reply_sync.extract_reply_ref("xmpp", {"reply_source_chat_id": "user@zm.is"}) + self.assertEqual({}, ref) + + +# --------------------------------------------------------------------------- +# XMPP — reply resolution (requires DB) +# --------------------------------------------------------------------------- + +class XMPPReplyResolutionTests(TestCase): + def setUp(self): + self.user = User.objects.create_user( + "xmpp-resolve-user", "xmpp-resolve@example.com", "x" + ) + self.person = Person.objects.create(user=self.user, name="XMPP Resolve") + self.identifier = PersonIdentifier.objects.create( + user=self.user, + person=self.person, + service="xmpp", + identifier="contact@zm.is", + ) + self.session = ChatSession.objects.create( + user=self.user, identifier=self.identifier + ) + self.anchor = Message.objects.create( + user=self.user, + session=self.session, + ts=now_ms(), + text="xmpp anchor", + source_service="xmpp", + source_message_id="xmpp-anchor-001", + source_chat_id="contact@zm.is/mobile", + sender_uuid="contact@zm.is", + ) + + def test_resolve_reply_target_by_source_message_id(self): + ref = reply_sync.extract_reply_ref( + "xmpp", + { + "reply_source_message_id": "xmpp-anchor-001", + "reply_source_chat_id": "contact@zm.is/mobile", + }, + ) + target = async_to_sync(reply_sync.resolve_reply_target)( + self.user, self.session, ref + ) + self.assertIsNotNone(target) + self.assertEqual(self.anchor.id, target.id) + + def test_xmpp_reaction_applied_to_anchor_via_history(self): + async_to_sync(history.apply_reaction)( + self.user, + self.identifier, + target_message_id="xmpp-anchor-001", + target_ts=int(self.anchor.ts), + emoji="🔥", + source_service="xmpp", + actor="contact@zm.is", + remove=False, + payload={"target_xmpp_id": "xmpp-anchor-001"}, + ) + self.anchor.refresh_from_db() + reactions = list((self.anchor.receipt_payload or {}).get("reactions") or []) + self.assertTrue( + any(r.get("emoji") == "🔥" for r in reactions), + "Expected 🔥 reaction to be stored on the anchor.", + ) + + def test_xmpp_reply_ref_resolved_to_none_for_unknown_id(self): + ref = reply_sync.extract_reply_ref( + "xmpp", + {"reply_source_message_id": "xmpp-nonexistent-999"}, + ) + target = async_to_sync(reply_sync.resolve_reply_target)( + self.user, self.session, ref + ) + self.assertIsNone(target) diff --git a/core/tests/test_task_sync_worker_command.py b/core/tests/test_task_sync_worker_command.py new file mode 100644 index 0000000..4c6d031 --- /dev/null +++ b/core/tests/test_task_sync_worker_command.py @@ -0,0 +1,9 @@ +from django.test import SimpleTestCase + +from core.management.commands.codex_worker import Command as LegacyWorkerCommand +from core.management.commands.task_sync_worker import Command as TaskSyncWorkerCommand + + +class TaskSyncWorkerCommandAliasTests(SimpleTestCase): + def test_task_sync_worker_is_legacy_worker_alias(self): + self.assertTrue(issubclass(TaskSyncWorkerCommand, LegacyWorkerCommand)) diff --git a/core/tests/test_xmpp_approval_commands.py b/core/tests/test_xmpp_approval_commands.py new file mode 100644 index 0000000..e493ba0 --- /dev/null +++ b/core/tests/test_xmpp_approval_commands.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +from asgiref.sync import async_to_sync +from django.test import TestCase + +from core.clients.xmpp import XMPPComponent +from core.models import ( + CodexPermissionRequest, + CodexRun, + DerivedTask, + ExternalSyncEvent, + TaskProject, + User, +) + + +class _ApprovalProbe: + _resolve_request_provider = XMPPComponent._resolve_request_provider + _approval_event_prefix = XMPPComponent._approval_event_prefix + _APPROVAL_PROVIDER_COMMANDS = XMPPComponent._APPROVAL_PROVIDER_COMMANDS + _ACTION_TO_STATUS = XMPPComponent._ACTION_TO_STATUS + _apply_approval_decision = XMPPComponent._apply_approval_decision + _approval_list_pending = XMPPComponent._approval_list_pending + _approval_status = XMPPComponent._approval_status + _handle_approval_command = XMPPComponent._handle_approval_command + _gateway_help_lines = XMPPComponent._gateway_help_lines + _handle_tasks_command = XMPPComponent._handle_tasks_command + + +class XMPPGatewayApprovalCommandTests(TestCase): + def setUp(self): + self.user = User.objects.create_user("xmpp-approval-user", "xmpp-approval@example.com", "x") + self.project = TaskProject.objects.create(user=self.user, name="Approval Project") + self.task = DerivedTask.objects.create( + user=self.user, + project=self.project, + epic=None, + title="Approve me", + source_service="xmpp", + source_channel="jews.zm.is", + reference_code="77", + status_snapshot="open", + ) + self.waiting_event = ExternalSyncEvent.objects.create( + user=self.user, + task=self.task, + provider="codex_cli", + status="waiting_approval", + payload={}, + error="", + ) + self.run = CodexRun.objects.create( + user=self.user, + task=self.task, + project=self.project, + source_service="xmpp", + source_channel="jews.zm.is", + status="waiting_approval", + request_payload={"action": "append_update", "provider_payload": {"task_id": str(self.task.id)}}, + result_payload={}, + ) + self.request = CodexPermissionRequest.objects.create( + user=self.user, + codex_run=self.run, + external_sync_event=self.waiting_event, + approval_key="ak-xmpp-1", + summary="Need auth approval", + requested_permissions={"items": ["workspace_write"]}, + resume_payload={}, + status="pending", + ) + self.probe = _ApprovalProbe() + self.probe.log = MagicMock() + + def _run_command(self, text: str) -> list[str]: + messages = [] + + def _sym(value): + messages.append(str(value)) + + handled = async_to_sync(XMPPComponent._handle_approval_command)( + self.probe, + self.user, + text, + "xmpp-approval-user@zm.is/mobile", + _sym, + ) + self.assertTrue(handled) + self.assertTrue(messages) + return messages + + def test_approval_approve_command_resolves_request_and_queues_resume(self): + rows = self._run_command(".approval approve ak-xmpp-1") + self.assertIn("approved", "\n".join(rows).lower()) + self.request.refresh_from_db() + self.run.refresh_from_db() + self.waiting_event.refresh_from_db() + self.assertEqual("approved", self.request.status) + self.assertEqual("approved_waiting_resume", self.run.status) + self.assertEqual("ok", self.waiting_event.status) + resume = ExternalSyncEvent.objects.filter( + idempotency_key="codex_approval:ak-xmpp-1:approved" + ).first() + self.assertIsNotNone(resume) + self.assertEqual("pending", resume.status) + + def test_approval_list_pending_and_status(self): + rows = self._run_command(".approval list-pending all") + text = "\n".join(rows) + self.assertIn("pending=1", text) + self.assertIn("ak-xmpp-1", text) + status_rows = self._run_command(".approval status ak-xmpp-1") + self.assertIn("status=pending", "\n".join(status_rows)) + + def test_provider_specific_command_rejects_mismatched_key(self): + rows = self._run_command(".claude approve ak-xmpp-1") + self.assertIn("approval_key_not_for_provider", "\n".join(rows)) + self.request.refresh_from_db() + self.assertEqual("pending", self.request.status) + + +class XMPPGatewayTasksCommandTests(TestCase): + def setUp(self): + self.user = User.objects.create_user("xmpp-task-user", "xmpp-task@example.com", "x") + self.project = TaskProject.objects.create(user=self.user, name="Task Project") + self.task = DerivedTask.objects.create( + user=self.user, + project=self.project, + epic=None, + title="Ship CLI", + source_service="xmpp", + source_channel="jews.zm.is", + reference_code="12", + status_snapshot="open", + ) + self.probe = _ApprovalProbe() + self.probe.log = MagicMock() + + def _run_tasks(self, text: str) -> list[str]: + messages = [] + + def _sym(value): + messages.append(str(value)) + + handled = async_to_sync(XMPPComponent._handle_tasks_command)( + self.probe, + self.user, + text, + _sym, + ) + self.assertTrue(handled) + self.assertTrue(messages) + return messages + + def test_help_contains_approval_and_tasks_sections(self): + lines = self.probe._gateway_help_lines() + text = "\n".join(lines) + self.assertIn(".approval list-pending", text) + self.assertIn(".tasks list", text) + + def test_tasks_list_show_complete_and_undo(self): + rows = self._run_tasks(".tasks list open 10") + self.assertIn("#12", "\n".join(rows)) + rows = self._run_tasks(".tasks show #12") + self.assertIn("status: open", "\n".join(rows)) + rows = self._run_tasks(".tasks complete #12") + self.assertIn("completed #12", "\n".join(rows)) + self.task.refresh_from_db() + self.assertEqual("completed", self.task.status_snapshot) + rows = self._run_tasks(".tasks undo #12") + self.assertIn("removed #12", "\n".join(rows)) + self.assertFalse(DerivedTask.objects.filter(id=self.task.id).exists()) diff --git a/core/tests/test_xmpp_omemo_support.py b/core/tests/test_xmpp_omemo_support.py new file mode 100644 index 0000000..bfc28a7 --- /dev/null +++ b/core/tests/test_xmpp_omemo_support.py @@ -0,0 +1,126 @@ +import asyncio +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from asgiref.sync import async_to_sync +from django.test import SimpleTestCase, TestCase, override_settings + +from core.clients import transport +from core.clients.xmpp import ET, XMPPClient, XMPPComponent, _extract_sender_omemo_client_key +from core.models import User, UserXmppOmemoState + + +class _FakeComponent: + def __init__(self, *args, **kwargs): + self.plugins = [] + self.loop = None + + def register_plugin(self, name): + self.plugins.append(str(name)) + + def connect(self): + return True + + +@override_settings( + XMPP_JID="jews.zm.is", + XMPP_SECRET="secret", + XMPP_ADDRESS="127.0.0.1", + XMPP_PORT=8888, +) +class XMPPOmemoSupportTests(SimpleTestCase): + def test_registers_xep_0384_when_omemo_plugin_available(self): + loop = asyncio.new_event_loop() + try: + with patch("core.clients.xmpp.XMPPComponent", _FakeComponent): + with patch("core.clients.xmpp._omemo_plugin_available", return_value=True): + with patch("core.clients.xmpp._omemo_xep_0384_plugin_available", return_value=True): + with patch("core.clients.xmpp._load_omemo_plugin_module", return_value=True): + client = XMPPClient(SimpleNamespace(), loop, "xmpp") + self.assertIn("xep_0384", list(getattr(client.client, "plugins", []))) + self.assertTrue(bool(getattr(client, "_omemo_plugin_registered", False))) + finally: + loop.close() + + def test_skips_xep_0384_when_omemo_plugin_unavailable(self): + loop = asyncio.new_event_loop() + try: + with patch("core.clients.xmpp.XMPPComponent", _FakeComponent): + with patch("core.clients.xmpp._omemo_plugin_available", return_value=False): + with patch("core.clients.xmpp._omemo_xep_0384_plugin_available", return_value=False): + client = XMPPClient(SimpleNamespace(), loop, "xmpp") + self.assertNotIn("xep_0384", list(getattr(client.client, "plugins", []))) + self.assertFalse(bool(getattr(client, "_omemo_plugin_registered", False))) + finally: + loop.close() + + def test_skips_xep_0384_when_only_slixmpp_omemo_package_exists(self): + loop = asyncio.new_event_loop() + try: + with patch("core.clients.xmpp.XMPPComponent", _FakeComponent): + with patch("core.clients.xmpp._omemo_plugin_available", return_value=True): + with patch("core.clients.xmpp._omemo_xep_0384_plugin_available", return_value=False): + client = XMPPClient(SimpleNamespace(), loop, "xmpp") + self.assertNotIn("xep_0384", list(getattr(client.client, "plugins", []))) + self.assertFalse(bool(getattr(client, "_omemo_plugin_registered", False))) + finally: + loop.close() + + def test_bootstrap_logs_and_updates_runtime_state_with_fingerprint(self): + class _BootstrapProbe: + _derived_omemo_fingerprint = XMPPComponent._derived_omemo_fingerprint + + component = _BootstrapProbe() + component.plugin = {} + component.log = MagicMock() + + with patch.object(transport, "update_runtime_state") as update_state: + async_to_sync(XMPPComponent._bootstrap_omemo_for_authentic_channel)(component) + + update_state.assert_called_once() + _, kwargs = update_state.call_args + self.assertEqual("jews.zm.is", kwargs.get("omemo_target_jid")) + self.assertEqual( + component._derived_omemo_fingerprint("jews.zm.is"), + kwargs.get("omemo_fingerprint"), + ) + self.assertFalse(bool(kwargs.get("omemo_enabled"))) + self.assertIn("omemo_status", kwargs) + self.assertIn("omemo_status_reason", kwargs) + self.assertTrue(component.log.info.called) + + def test_extract_sender_omemo_client_key_from_encrypted_stanza(self): + stanza_xml = ET.fromstring( + "" + "" + "
x
" + "
" + "
" + ) + parsed = _extract_sender_omemo_client_key(SimpleNamespace(xml=stanza_xml)) + self.assertEqual("detected", parsed.get("status")) + self.assertEqual("sid:77,rid:88", parsed.get("client_key")) + + +class XMPPOmemoObservationPersistenceTests(TestCase): + def test_records_latest_user_omemo_observation(self): + user = User.objects.create_user("xmpp-omemo-user", "xmpp-omemo@example.com", "x") + probe = SimpleNamespace(log=MagicMock()) + stanza_xml = ET.fromstring( + "" + "" + "
x
" + "
" + "
" + ) + async_to_sync(XMPPComponent._record_sender_omemo_state)( + probe, + user, + sender_jid="xmpp-omemo-user@zm.is/mobile", + recipient_jid="jews.zm.is", + message_stanza=SimpleNamespace(xml=stanza_xml), + ) + row = UserXmppOmemoState.objects.get(user=user) + self.assertEqual("detected", row.status) + self.assertEqual("sid:321,rid:654", row.latest_client_key) + self.assertEqual("jews.zm.is", row.last_target_jid) diff --git a/core/views/system.py b/core/views/system.py index 24863d3..2176341 100644 --- a/core/views/system.py +++ b/core/views/system.py @@ -1,10 +1,12 @@ import time -from django.http import JsonResponse +from django.contrib.auth.mixins import LoginRequiredMixin +from django.http import HttpResponseRedirect, JsonResponse from django.shortcuts import render from django.urls import reverse from django.views import View +from core.clients import transport from core.models import ( AdapterHealthEvent, AIRequest, @@ -28,6 +30,9 @@ from core.models import ( Persona, PersonIdentifier, QueuedMessage, + CommandSecurityPolicy, + UserXmppOmemoState, + UserXmppSecuritySettings, WorkspaceConversation, WorkspaceMetricSnapshot, ) @@ -459,3 +464,357 @@ class MemorySearchQueryAPI(SuperUserRequiredMixin, View): ], } ) + + +def _parse_xmpp_jid(jid_str: str) -> dict: + """Split a full JID (localpart@domain/resource) into components.""" + raw = str(jid_str or "").strip() + bare, _, resource = raw.partition("/") + localpart, _, domain = bare.partition("@") + return {"full": raw, "bare": bare, "localpart": localpart, "domain": domain, "resource": resource} + + +def _to_bool(value, default=False): + if value is None: + return bool(default) + text = str(value).strip().lower() + if text in {"1", "true", "yes", "on", "y"}: + return True + if text in {"0", "false", "no", "off", "n"}: + return False + return bool(default) + + +class SecurityPage(LoginRequiredMixin, View): + """Security settings page for OMEMO and command-scope policy controls.""" + + template_name = "pages/security.html" + GLOBAL_SCOPE_KEY = "global.override" + # Allowed Services list used by both Global Scope Override and local scopes. + # Keep this in sync with the UI text on the Security page. + POLICY_SERVICES = ["xmpp", "whatsapp", "signal", "instagram", "web"] + # Override mode names as shown in the interface: + # - per_scope: local scope controls remain editable + # - on/off: global override forces each local scope value + OVERRIDE_OPTIONS = ("per_scope", "on", "off") + GLOBAL_OVERRIDE_FIELDS = ( + "scope_enabled", + "require_omemo", + "require_trusted_fingerprint", + ) + POLICY_SCOPES = [ + ("gateway.tasks", "Gateway .tasks commands", "Handles .tasks list/show/complete/undo over gateway channels."), + ("gateway.approval", "Gateway approval commands", "Handles .approval/.codex/.claude approve/deny over gateway channels."), + ("gateway.totp", "Gateway TOTP enrollment", "Controls TOTP enrollment/status commands over gateway channels."), + ("tasks.submit", "Task submissions from chat", "Controls automatic task creation from inbound messages."), + ("tasks.commands", "Task command verbs (.task/.undo/.epic)", "Controls explicit task command verbs."), + ("command.bp", "Business plan command", "Controls Business Plan command execution."), + ("command.codex", "Codex command", "Controls Codex command execution."), + ("command.claude", "Claude command", "Controls Claude command execution."), + ] + POLICY_GROUP_LABELS = { + "gateway": "Gateway", + "tasks": "Tasks", + "command": "Commands", + "agentic": "Agentic", + "other": "Other", + } + + def _security_settings(self, request): + row, _ = UserXmppSecuritySettings.objects.get_or_create(user=request.user) + return row + + def _parse_override_value(self, value): + option = str(value or "").strip().lower() + if option == "inherit": + # Backward-compat for existing persisted values. + option = "per_scope" + if option in self.OVERRIDE_OPTIONS: + return option + return "per_scope" + + def _global_override_payload(self, request): + row, _ = CommandSecurityPolicy.objects.get_or_create( + user=request.user, + scope_key=self.GLOBAL_SCOPE_KEY, + defaults={ + "enabled": True, + "allowed_services": [], + "allowed_channels": {}, + "settings": {}, + }, + ) + settings_payload = dict(row.settings or {}) + values = { + "scope_enabled": self._parse_override_value( + settings_payload.get("scope_enabled") + ), + "require_omemo": self._parse_override_value( + settings_payload.get("require_omemo") + ), + "require_trusted_fingerprint": self._parse_override_value( + settings_payload.get("require_trusted_fingerprint") + ), + } + allowed_services = [ + str(value or "").strip().lower() + for value in (row.allowed_services or []) + if str(value or "").strip() + ] + channel_rules = self._channel_rules_from_map(dict(row.allowed_channels or {})) + if not channel_rules: + channel_rules = [{"service": "xmpp", "pattern": ""}] + return { + "row": row, + "values": values, + "allowed_services": allowed_services, + "channel_rules": channel_rules, + } + + def _apply_global_override(self, current_value: bool, option: str) -> bool: + normalized = self._parse_override_value(option) + if normalized == "on": + return True + if normalized == "off": + return False + return bool(current_value) + + def _channel_rules_from_map(self, source_map): + rows = [] + raw = dict(source_map or {}) + for service_key, patterns in raw.items(): + service_name = str(service_key or "").strip().lower() + if not service_name: + continue + if isinstance(patterns, list): + for pattern in patterns: + pattern_text = str(pattern or "").strip() + if pattern_text: + rows.append({ + "service": service_name, + "pattern": pattern_text, + }) + return rows + + def _channels_map_from_post(self, request): + channel_services = request.POST.getlist("allowed_channel_service") + channel_patterns = request.POST.getlist("allowed_channel_pattern") + allowed_channels: dict[str, list[str]] = {} + for idx, raw_pattern in enumerate(channel_patterns): + pattern = str(raw_pattern or "").strip() + if not pattern: + continue + service_name = str( + channel_services[idx] if idx < len(channel_services) else "" + ).strip().lower() + if not service_name: + service_name = "*" + allowed_channels.setdefault(service_name, []) + if pattern not in allowed_channels[service_name]: + allowed_channels[service_name].append(pattern) + return allowed_channels + + def _scope_rows(self, request): + global_overrides = self._global_override_payload(request)["values"] + rows = { + str(item.scope_key or "").strip().lower(): item + for item in CommandSecurityPolicy.objects.filter(user=request.user).exclude( + scope_key=self.GLOBAL_SCOPE_KEY + ) + } + payload = [] + for scope_key, label, description in self.POLICY_SCOPES: + key = str(scope_key or "").strip().lower() + item = rows.get(key) + raw_allowed_services = [ + str(value or "").strip().lower() + for value in (getattr(item, "allowed_services", []) or []) + if str(value or "").strip() + ] + channel_rules = self._channel_rules_from_map( + dict(getattr(item, "allowed_channels", {}) or {}) + ) + if not channel_rules: + channel_rules = [{"service": "xmpp", "pattern": ""}] + enabled_locked = global_overrides["scope_enabled"] != "per_scope" + require_omemo_locked = global_overrides["require_omemo"] != "per_scope" + require_trusted_locked = ( + global_overrides["require_trusted_fingerprint"] != "per_scope" + ) + payload.append({ + "scope_key": key, + "label": label, + "description": description, + "enabled": self._apply_global_override( + bool(getattr(item, "enabled", True)), + global_overrides["scope_enabled"], + ), + "require_omemo": self._apply_global_override( + bool(getattr(item, "require_omemo", False)), + global_overrides["require_omemo"], + ), + "require_trusted_fingerprint": self._apply_global_override( + bool(getattr(item, "require_trusted_omemo_fingerprint", False)), + global_overrides["require_trusted_fingerprint"], + ), + "enabled_locked": enabled_locked, + "require_omemo_locked": require_omemo_locked, + "require_trusted_fingerprint_locked": require_trusted_locked, + "lock_help": "Set this field to 'Per Scope' in Global Scope Override to edit it here.", + "allowed_services": raw_allowed_services, + "channel_rules": channel_rules, + }) + return payload + + def _scope_group_key(self, scope_key: str) -> str: + key = str(scope_key or "").strip().lower() + if key in {"command.codex", "command.claude"}: + return "agentic" + if key.startswith("gateway."): + return "command" + if key.startswith("tasks."): + if key == "tasks.submit": + return "tasks" + return "command" + if key.startswith("command."): + return "command" + if ".commands" in key: + return "command" + if ".approval" in key: + return "command" + if ".totp" in key: + return "command" + if ".task" in key: + return "tasks" + return "other" + + def _grouped_scope_rows(self, request): + rows = self._scope_rows(request) + grouped: dict[str, list[dict]] = {key: [] for key in self.POLICY_GROUP_LABELS} + for row in rows: + group_key = self._scope_group_key(row.get("scope_key")) + grouped.setdefault(group_key, []) + grouped[group_key].append(row) + payload = [] + for group_key in ("tasks", "command", "agentic", "other"): + items = grouped.get(group_key) or [] + if not items: + continue + payload.append({ + "key": group_key, + "label": self.POLICY_GROUP_LABELS.get(group_key, group_key.title()), + "rows": items, + }) + return payload + + def post(self, request): + row = self._security_settings(request) + if "require_omemo" in request.POST: + row.require_omemo = _to_bool(request.POST.get("require_omemo"), False) + row.save(update_fields=["require_omemo", "updated_at"]) + redirect_to = HttpResponseRedirect(reverse("security_settings")) + scope_key = str(request.POST.get("scope_key") or "").strip().lower() + if scope_key == self.GLOBAL_SCOPE_KEY: + global_row = self._global_override_payload(request)["row"] + settings_payload = dict(global_row.settings or {}) + for field in self.GLOBAL_OVERRIDE_FIELDS: + settings_payload[field] = self._parse_override_value( + request.POST.get(f"global_{field}") + ) + global_row.allowed_services = [ + str(item or "").strip().lower() + for item in request.POST.getlist("allowed_services") + if str(item or "").strip() + ] + global_row.allowed_channels = self._channels_map_from_post(request) + global_row.settings = settings_payload + global_row.save( + update_fields=[ + "settings", + "allowed_services", + "allowed_channels", + "updated_at", + ] + ) + return redirect_to + + if scope_key: + if str(request.POST.get("scope_change_mode") or "").strip() != "1": + return redirect_to + global_overrides = self._global_override_payload(request)["values"] + allowed_services = [ + str(item or "").strip().lower() + for item in request.POST.getlist("allowed_services") + if str(item or "").strip() + ] + allowed_channels = self._channels_map_from_post(request) + policy, _ = CommandSecurityPolicy.objects.get_or_create( + user=request.user, + scope_key=scope_key, + ) + policy.allowed_services = allowed_services + policy.allowed_channels = allowed_channels + if global_overrides["scope_enabled"] == "per_scope": + policy.enabled = _to_bool(request.POST.get("policy_enabled"), True) + if global_overrides["require_omemo"] == "per_scope": + policy.require_omemo = _to_bool( + request.POST.get("policy_require_omemo"), False + ) + if global_overrides["require_trusted_fingerprint"] == "per_scope": + policy.require_trusted_omemo_fingerprint = _to_bool( + request.POST.get("policy_require_trusted_fingerprint"), + False, + ) + policy.save( + update_fields=[ + "enabled", + "require_omemo", + "require_trusted_omemo_fingerprint", + "allowed_services", + "allowed_channels", + "updated_at", + ] + ) + return redirect_to + + def get(self, request): + xmpp_state = transport.get_runtime_state("xmpp") + try: + omemo_row = UserXmppOmemoState.objects.get(user=request.user) + except UserXmppOmemoState.DoesNotExist: + omemo_row = None + security_settings = self._security_settings(request) + sender_jid = _parse_xmpp_jid(getattr(omemo_row, "last_sender_jid", "") or "") + omemo_plan = [ + { + "label": "Component OMEMO active", + "done": bool(xmpp_state.get("omemo_enabled")), + "hint": "The gateway's OMEMO plugin must be loaded and initialised.", + }, + { + "label": "OMEMO observed from your client", + "done": omemo_row is not None and omemo_row.status == "detected", + "hint": "Send any message with OMEMO enabled in your XMPP client.", + }, + { + "label": "Client key on file", + "done": bool(getattr(omemo_row, "latest_client_key", "")), + "hint": "A device key (sid/rid) must be recorded from your client.", + }, + { + "label": "Encryption required", + "done": security_settings.require_omemo, + "hint": "Enable 'Require OMEMO encryption' in Security Policy above to enforce this policy.", + }, + ] + return render(request, self.template_name, { + "xmpp_state": xmpp_state, + "omemo_row": omemo_row, + "security_settings": security_settings, + "global_override": self._global_override_payload(request), + "policy_services": self.POLICY_SERVICES, + "policy_rows": self._scope_rows(request), + "policy_groups": self._grouped_scope_rows(request), + "sender_jid": sender_jid, + "omemo_plan": omemo_plan, + }) diff --git a/core/views/tasks.py b/core/views/tasks.py index 3f5e288..4c3391a 100644 --- a/core/views/tasks.py +++ b/core/views/tasks.py @@ -338,6 +338,23 @@ def _codex_settings_with_defaults(raw: dict | None) -> dict: } +def _claude_settings_with_defaults(raw: dict | None) -> dict: + row = dict(raw or {}) + timeout_raw = str(row.get("timeout_seconds") or "60").strip() + try: + timeout_seconds = max(1, int(timeout_raw)) + except Exception: + timeout_seconds = 60 + return { + "command": str(row.get("command") or "claude").strip() or "claude", + "workspace_root": str(row.get("workspace_root") or "").strip(), + "default_profile": str(row.get("default_profile") or "").strip(), + "timeout_seconds": timeout_seconds, + "approver_service": str(row.get("approver_service") or "").strip().lower(), + "approver_identifier": str(row.get("approver_identifier") or "").strip(), + } + + def _enqueue_codex_task_submission( *, user, @@ -347,10 +364,12 @@ def _enqueue_codex_task_submission( mode: str = "default", command_text: str = "", source_message=None, + provider: str = "codex_cli", ) -> CodexRun: + provider = str(provider or "codex_cli").strip() or "codex_cli" external_chat_id = resolve_external_chat_id( user=user, - provider="codex_cli", + provider=provider, service=source_service, channel=source_channel, ) @@ -398,6 +417,7 @@ def _enqueue_codex_task_submission( action="append_update", provider_payload=dict(provider_payload), idempotency_key=idempotency_key, + provider=provider, ) return run @@ -703,6 +723,12 @@ class TasksHub(LoginRequiredMixin, View): "mapped": mapped, } ) + enabled_providers = list( + TaskProviderConfig.objects.filter(user=request.user, enabled=True) + .exclude(provider="mock") + .values_list("provider", flat=True) + .order_by("provider") + ) return { "projects": projects, "project_choices": all_projects, @@ -711,6 +737,7 @@ class TasksHub(LoginRequiredMixin, View): "person_identifier_rows": person_identifier_rows, "selected_project": selected_project, "show_empty_projects": show_empty, + "enabled_providers": enabled_providers, } def get(self, request): @@ -1152,9 +1179,13 @@ class TaskSettings(LoginRequiredMixin, View): provider_map = _provider_row_map(request.user) codex_cfg = provider_map.get("codex_cli") codex_settings = _codex_settings_with_defaults(dict(getattr(codex_cfg, "settings", {}) or {})) + claude_cfg = provider_map.get("claude_cli") + claude_settings = _claude_settings_with_defaults(dict(getattr(claude_cfg, "settings", {}) or {})) mock_cfg = provider_map.get("mock") codex_provider = get_provider("codex_cli") + claude_provider = get_provider("claude_cli") codex_healthcheck = codex_provider.healthcheck(codex_settings) if codex_cfg else None + claude_healthcheck = claude_provider.healthcheck(claude_settings) if claude_cfg else None codex_queue_counts = { "pending": ExternalSyncEvent.objects.filter( user=request.user, provider="codex_cli", status="pending" @@ -1169,11 +1200,25 @@ class TaskSettings(LoginRequiredMixin, View): user=request.user, provider="codex_cli", status="ok" ).count(), } + claude_queue_counts = { + "pending": ExternalSyncEvent.objects.filter( + user=request.user, provider="claude_cli", status="pending" + ).count(), + "waiting_approval": ExternalSyncEvent.objects.filter( + user=request.user, provider="claude_cli", status="waiting_approval" + ).count(), + "failed": ExternalSyncEvent.objects.filter( + user=request.user, provider="claude_cli", status="failed" + ).count(), + "ok": ExternalSyncEvent.objects.filter( + user=request.user, provider="claude_cli", status="ok" + ).count(), + } codex_recent_runs = CodexRun.objects.filter(user=request.user).order_by("-created_at")[:10] latest_worker_event = ( ExternalSyncEvent.objects.filter( user=request.user, - provider="codex_cli", + provider__in=["codex_cli", "claude_cli"], ) .filter(status__in=["ok", "failed", "waiting_approval", "retrying"]) .order_by("-updated_at") @@ -1233,6 +1278,21 @@ class TaskSettings(LoginRequiredMixin, View): "queue_counts": codex_queue_counts, "recent_runs": codex_recent_runs, }, + "claude_provider_config": claude_cfg, + "claude_provider_settings": { + "command": str(claude_settings.get("command") or "claude"), + "workspace_root": str(claude_settings.get("workspace_root") or ""), + "default_profile": str(claude_settings.get("default_profile") or ""), + "timeout_seconds": int(claude_settings.get("timeout_seconds") or 60), + "approver_service": str(claude_settings.get("approver_service") or ""), + "approver_identifier": str(claude_settings.get("approver_identifier") or ""), + }, + "claude_compact_summary": { + "healthcheck_ok": bool(getattr(claude_healthcheck, "ok", False)), + "healthcheck_error": str(getattr(claude_healthcheck, "error", "") or ""), + "healthcheck_payload": dict(getattr(claude_healthcheck, "payload", {}) or {}), + "queue_counts": claude_queue_counts, + }, "person_identifiers": person_identifiers, "external_link_person_identifiers": external_link_person_identifiers, "external_link_scoped": external_link_scoped, @@ -1376,6 +1436,17 @@ class TaskSettings(LoginRequiredMixin, View): "approver_mode": "channel", } ) + elif provider == "claude_cli": + settings_payload = _claude_settings_with_defaults( + { + "command": request.POST.get("command"), + "workspace_root": request.POST.get("workspace_root"), + "default_profile": request.POST.get("default_profile"), + "timeout_seconds": request.POST.get("timeout_seconds"), + "approver_service": request.POST.get("approver_service"), + "approver_identifier": request.POST.get("approver_identifier"), + } + ) row.settings = settings_payload row.save(update_fields=["enabled", "settings", "updated_at"]) return _settings_redirect(request) @@ -1460,10 +1531,16 @@ class TaskSettings(LoginRequiredMixin, View): return _settings_redirect(request) +_ALLOWED_SUBMIT_PROVIDERS = {"codex_cli", "claude_cli"} + + class TaskCodexSubmit(LoginRequiredMixin, View): def post(self, request): task_id = str(request.POST.get("task_id") or "").strip() next_url = str(request.POST.get("next") or reverse("tasks_hub")).strip() + provider = str(request.POST.get("provider") or "codex_cli").strip().lower() + if provider not in _ALLOWED_SUBMIT_PROVIDERS: + provider = "codex_cli" task = get_object_or_404( DerivedTask.objects.select_related("project", "epic", "origin_message"), id=task_id, @@ -1471,13 +1548,14 @@ class TaskCodexSubmit(LoginRequiredMixin, View): ) cfg = TaskProviderConfig.objects.filter( user=request.user, - provider="codex_cli", + provider=provider, enabled=True, ).first() + provider_label = "Claude" if provider == "claude_cli" else "Codex" if cfg is None: messages.error( request, - "Codex provider is disabled. Enable it in Task Settings first.", + f"{provider_label} provider is disabled. Enable it in Task Settings first.", ) return redirect(next_url) run = _enqueue_codex_task_submission( @@ -1487,10 +1565,11 @@ class TaskCodexSubmit(LoginRequiredMixin, View): source_channel=str(task.source_channel or ""), mode="default", source_message=getattr(task, "origin_message", None), + provider=provider, ) messages.success( request, - f"Queued approval for task #{task.reference_code} before Codex run {run.id}.", + f"Queued approval for task #{task.reference_code} before {provider_label} run {run.id}.", ) return redirect(next_url) diff --git a/requirements.txt b/requirements.txt index 71357fd..d1a568d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ wheel==0.45.1 uwsgi==2.0.28 django==4.2.19 pre-commit==4.2.0 -django-crispy-forms==1.14.0 +django-crispy-forms==2.3 crispy-bulma==0.11.0 djangorestframework==3.15.2 uvloop==0.21.0 @@ -10,16 +10,16 @@ django-htmx==1.21.0 cryptography==44.0.2 django-debug-toolbar==4.4.6 django-debug-toolbar-template-profiler==2.1.0 -orjson==3.10.15 +orjson==3.10.18 msgpack==1.1.0 -apscheduler==3.10.4 +apscheduler==3.11.0 watchfiles==1.0.5 django-otp==1.6.0 django-two-factor-auth==1.17.0 django-otp-yubikey==1.1.0 phonenumbers==8.13.55 -qrcode==8.0 -pydantic==2.10.6 +qrcode==7.4.2 +pydantic==2.11.5 redis==6.2.0 hiredis==3.1.0 django-cachalot==2.7.0 @@ -30,6 +30,7 @@ openai==1.66.3 aiograpi==0.0.4 aiomysql==0.2.0 slixmpp==1.10.0 +slixmpp-omemo==2.1.0 neonize==0.3.12 watchdog==6.0.0 uvicorn==0.34.0 diff --git a/stack.env.example b/stack.env.example index ea3897a..68e2bf8 100644 --- a/stack.env.example +++ b/stack.env.example @@ -35,6 +35,8 @@ XMPP_USER_DOMAIN=example.com XMPP_PORT=8888 # Auto-generated if empty by Prosody startup helpers. XMPP_SECRET= +# Directory for OMEMO key storage. Defaults to /xmpp_omemo_data if unset. +# XMPP_OMEMO_DATA_DIR=./.podman/gia_xmpp_omemo_data # Optional Prosody container storage/config paths used by utilities/prosody/manage_prosody_container.sh PROSODY_IMAGE=docker.io/prosody/prosody:latest