from __future__ import annotations from dataclasses import dataclass from core.models import CommandSecurityPolicy, UserXmppOmemoState GLOBAL_SCOPE_KEY = "global.override" OVERRIDE_OPTIONS = {"per_scope", "on", "off"} @dataclass(slots=True) class CommandSecurityContext: service: str channel_identifier: str message_meta: dict payload: dict @dataclass(slots=True) class CommandPolicyDecision: allowed: bool code: str = "allowed" reason: str = "" def _normalize_service(value: str) -> str: return str(value or "").strip().lower() def _normalize_channel(value: str) -> str: return str(value or "").strip() def _normalize_list(values) -> list[str]: rows: list[str] = [] if not isinstance(values, list): return rows for row in values: item = str(row or "").strip() if item and item not in rows: rows.append(item) return rows def _parse_override_value(value: str) -> str: option = str(value or "").strip().lower() if option == "inherit": # Backward compatibility for previously saved values. option = "per_scope" if option in OVERRIDE_OPTIONS: return option return "per_scope" def _match_channel(rule: str, channel: str) -> bool: value = str(rule or "").strip() current = str(channel or "").strip() if not value: return False if value == "*": return True if value.endswith("*"): return current.startswith(value[:-1]) return current == value def _omemo_facts(ctx: CommandSecurityContext) -> tuple[str, str]: message_meta = dict(ctx.message_meta or {}) payload = dict(ctx.payload or {}) xmpp_meta = dict(message_meta.get("xmpp") or {}) status = str( xmpp_meta.get("omemo_status") or payload.get("omemo_status") or "" ).strip().lower() client_key = str( xmpp_meta.get("omemo_client_key") or payload.get("omemo_client_key") or "" ).strip() return status, client_key def _channel_allowed_for_rules(rules: dict, service: str, channel: str) -> bool: service_rules = _normalize_list(rules.get(service)) if not service_rules: service_rules = _normalize_list(rules.get("*")) if not service_rules: return True return any(_match_channel(rule, channel) for rule in service_rules) def _service_allowed(allowed_services: list[str], service: str) -> bool: if not allowed_services: return True return service in allowed_services def _effective_bool(local_value: bool, global_override: str) -> bool: option = _parse_override_value(global_override) if option == "on": return True if option == "off": return False return bool(local_value) def evaluate_command_policy( *, user, scope_key: str, context: CommandSecurityContext, ) -> CommandPolicyDecision: scope = str(scope_key or "").strip().lower() if not scope: return CommandPolicyDecision(allowed=True) policy = ( CommandSecurityPolicy.objects.filter( user=user, scope_key=scope, ) .order_by("-updated_at") .first() ) global_policy = ( CommandSecurityPolicy.objects.filter( user=user, scope_key=GLOBAL_SCOPE_KEY, ) .order_by("-updated_at") .first() ) if policy is None and global_policy is None: return CommandPolicyDecision(allowed=True) global_settings = dict(getattr(global_policy, "settings", {}) or {}) local_enabled = bool(getattr(policy, "enabled", True)) local_require_omemo = bool(getattr(policy, "require_omemo", False)) local_require_trusted = bool( getattr(policy, "require_trusted_omemo_fingerprint", False) ) enabled = _effective_bool(local_enabled, global_settings.get("scope_enabled")) require_omemo = _effective_bool( local_require_omemo, global_settings.get("require_omemo") ) require_trusted_omemo_fingerprint = _effective_bool( local_require_trusted, global_settings.get("require_trusted_fingerprint"), ) if not enabled: return CommandPolicyDecision( allowed=False, code="policy_disabled", reason=f"{scope} is disabled by command policy", ) service = _normalize_service(context.service) channel = _normalize_channel(context.channel_identifier) allowed_services = [ item.lower() for item in _normalize_list(getattr(policy, "allowed_services", [])) ] global_allowed_services = [ item.lower() for item in _normalize_list(getattr(global_policy, "allowed_services", [])) ] if not _service_allowed(allowed_services, service): return CommandPolicyDecision( allowed=False, code="service_not_allowed", reason=f"service={service or '-'} not allowed for scope={scope}", ) if not _service_allowed(global_allowed_services, service): return CommandPolicyDecision( allowed=False, code="service_not_allowed", reason=f"service={service or '-'} not allowed by global override", ) local_channel_rules = dict(getattr(policy, "allowed_channels", {}) or {}) if not _channel_allowed_for_rules(local_channel_rules, service, channel): return CommandPolicyDecision( allowed=False, code="channel_not_allowed", reason=f"channel={channel or '-'} not allowed for scope={scope}", ) global_channel_rules = dict(getattr(global_policy, "allowed_channels", {}) or {}) if not _channel_allowed_for_rules(global_channel_rules, service, channel): return CommandPolicyDecision( allowed=False, code="channel_not_allowed", reason=f"channel={channel or '-'} not allowed by global override", ) omemo_status, omemo_client_key = _omemo_facts(context) if require_omemo and omemo_status != "detected": return CommandPolicyDecision( allowed=False, code="omemo_required", reason=f"scope={scope} requires OMEMO", ) if require_trusted_omemo_fingerprint: if omemo_status != "detected" or not omemo_client_key: return CommandPolicyDecision( allowed=False, code="trusted_fingerprint_required", reason=f"scope={scope} requires trusted OMEMO fingerprint", ) state = UserXmppOmemoState.objects.filter(user=user).first() expected_key = str(getattr(state, "latest_client_key", "") or "").strip() if not expected_key or expected_key != omemo_client_key: return CommandPolicyDecision( allowed=False, code="fingerprint_mismatch", reason=f"scope={scope} OMEMO fingerprint does not match enrolled key", ) return CommandPolicyDecision(allowed=True)