Improve security

This commit is contained in:
2026-03-07 15:34:23 +00:00
parent add685a326
commit 611de57bf8
31 changed files with 3617 additions and 58 deletions

View File

@@ -0,0 +1,220 @@
from __future__ import annotations
from dataclasses import dataclass
from core.models import CommandSecurityPolicy, UserXmppOmemoState
GLOBAL_SCOPE_KEY = "global.override"
OVERRIDE_OPTIONS = {"per_scope", "on", "off"}
@dataclass(slots=True)
class CommandSecurityContext:
service: str
channel_identifier: str
message_meta: dict
payload: dict
@dataclass(slots=True)
class CommandPolicyDecision:
allowed: bool
code: str = "allowed"
reason: str = ""
def _normalize_service(value: str) -> str:
return str(value or "").strip().lower()
def _normalize_channel(value: str) -> str:
return str(value or "").strip()
def _normalize_list(values) -> list[str]:
rows: list[str] = []
if not isinstance(values, list):
return rows
for row in values:
item = str(row or "").strip()
if item and item not in rows:
rows.append(item)
return rows
def _parse_override_value(value: str) -> str:
option = str(value or "").strip().lower()
if option == "inherit":
# Backward compatibility for previously saved values.
option = "per_scope"
if option in OVERRIDE_OPTIONS:
return option
return "per_scope"
def _match_channel(rule: str, channel: str) -> bool:
value = str(rule or "").strip()
current = str(channel or "").strip()
if not value:
return False
if value == "*":
return True
if value.endswith("*"):
return current.startswith(value[:-1])
return current == value
def _omemo_facts(ctx: CommandSecurityContext) -> tuple[str, str]:
message_meta = dict(ctx.message_meta or {})
payload = dict(ctx.payload or {})
xmpp_meta = dict(message_meta.get("xmpp") or {})
status = str(
xmpp_meta.get("omemo_status")
or payload.get("omemo_status")
or ""
).strip().lower()
client_key = str(
xmpp_meta.get("omemo_client_key")
or payload.get("omemo_client_key")
or ""
).strip()
return status, client_key
def _channel_allowed_for_rules(rules: dict, service: str, channel: str) -> bool:
service_rules = _normalize_list(rules.get(service))
if not service_rules:
service_rules = _normalize_list(rules.get("*"))
if not service_rules:
return True
return any(_match_channel(rule, channel) for rule in service_rules)
def _service_allowed(allowed_services: list[str], service: str) -> bool:
if not allowed_services:
return True
return service in allowed_services
def _effective_bool(local_value: bool, global_override: str) -> bool:
option = _parse_override_value(global_override)
if option == "on":
return True
if option == "off":
return False
return bool(local_value)
def evaluate_command_policy(
*,
user,
scope_key: str,
context: CommandSecurityContext,
) -> CommandPolicyDecision:
scope = str(scope_key or "").strip().lower()
if not scope:
return CommandPolicyDecision(allowed=True)
policy = (
CommandSecurityPolicy.objects.filter(
user=user,
scope_key=scope,
)
.order_by("-updated_at")
.first()
)
global_policy = (
CommandSecurityPolicy.objects.filter(
user=user,
scope_key=GLOBAL_SCOPE_KEY,
)
.order_by("-updated_at")
.first()
)
if policy is None and global_policy is None:
return CommandPolicyDecision(allowed=True)
global_settings = dict(getattr(global_policy, "settings", {}) or {})
local_enabled = bool(getattr(policy, "enabled", True))
local_require_omemo = bool(getattr(policy, "require_omemo", False))
local_require_trusted = bool(
getattr(policy, "require_trusted_omemo_fingerprint", False)
)
enabled = _effective_bool(local_enabled, global_settings.get("scope_enabled"))
require_omemo = _effective_bool(
local_require_omemo, global_settings.get("require_omemo")
)
require_trusted_omemo_fingerprint = _effective_bool(
local_require_trusted,
global_settings.get("require_trusted_fingerprint"),
)
if not enabled:
return CommandPolicyDecision(
allowed=False,
code="policy_disabled",
reason=f"{scope} is disabled by command policy",
)
service = _normalize_service(context.service)
channel = _normalize_channel(context.channel_identifier)
allowed_services = [
item.lower() for item in _normalize_list(getattr(policy, "allowed_services", []))
]
global_allowed_services = [
item.lower()
for item in _normalize_list(getattr(global_policy, "allowed_services", []))
]
if not _service_allowed(allowed_services, service):
return CommandPolicyDecision(
allowed=False,
code="service_not_allowed",
reason=f"service={service or '-'} not allowed for scope={scope}",
)
if not _service_allowed(global_allowed_services, service):
return CommandPolicyDecision(
allowed=False,
code="service_not_allowed",
reason=f"service={service or '-'} not allowed by global override",
)
local_channel_rules = dict(getattr(policy, "allowed_channels", {}) or {})
if not _channel_allowed_for_rules(local_channel_rules, service, channel):
return CommandPolicyDecision(
allowed=False,
code="channel_not_allowed",
reason=f"channel={channel or '-'} not allowed for scope={scope}",
)
global_channel_rules = dict(getattr(global_policy, "allowed_channels", {}) or {})
if not _channel_allowed_for_rules(global_channel_rules, service, channel):
return CommandPolicyDecision(
allowed=False,
code="channel_not_allowed",
reason=f"channel={channel or '-'} not allowed by global override",
)
omemo_status, omemo_client_key = _omemo_facts(context)
if require_omemo and omemo_status != "detected":
return CommandPolicyDecision(
allowed=False,
code="omemo_required",
reason=f"scope={scope} requires OMEMO",
)
if require_trusted_omemo_fingerprint:
if omemo_status != "detected" or not omemo_client_key:
return CommandPolicyDecision(
allowed=False,
code="trusted_fingerprint_required",
reason=f"scope={scope} requires trusted OMEMO fingerprint",
)
state = UserXmppOmemoState.objects.filter(user=user).first()
expected_key = str(getattr(state, "latest_client_key", "") or "").strip()
if not expected_key or expected_key != omemo_client_key:
return CommandPolicyDecision(
allowed=False,
code="fingerprint_mismatch",
reason=f"scope={scope} OMEMO fingerprint does not match enrolled key",
)
return CommandPolicyDecision(allowed=True)