Improve security
This commit is contained in:
220
core/security/command_policy.py
Normal file
220
core/security/command_policy.py
Normal file
@@ -0,0 +1,220 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from core.models import CommandSecurityPolicy, UserXmppOmemoState
|
||||
|
||||
GLOBAL_SCOPE_KEY = "global.override"
|
||||
OVERRIDE_OPTIONS = {"per_scope", "on", "off"}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class CommandSecurityContext:
|
||||
service: str
|
||||
channel_identifier: str
|
||||
message_meta: dict
|
||||
payload: dict
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class CommandPolicyDecision:
|
||||
allowed: bool
|
||||
code: str = "allowed"
|
||||
reason: str = ""
|
||||
|
||||
|
||||
def _normalize_service(value: str) -> str:
|
||||
return str(value or "").strip().lower()
|
||||
|
||||
|
||||
def _normalize_channel(value: str) -> str:
|
||||
return str(value or "").strip()
|
||||
|
||||
|
||||
def _normalize_list(values) -> list[str]:
|
||||
rows: list[str] = []
|
||||
if not isinstance(values, list):
|
||||
return rows
|
||||
for row in values:
|
||||
item = str(row or "").strip()
|
||||
if item and item not in rows:
|
||||
rows.append(item)
|
||||
return rows
|
||||
|
||||
|
||||
def _parse_override_value(value: str) -> str:
|
||||
option = str(value or "").strip().lower()
|
||||
if option == "inherit":
|
||||
# Backward compatibility for previously saved values.
|
||||
option = "per_scope"
|
||||
if option in OVERRIDE_OPTIONS:
|
||||
return option
|
||||
return "per_scope"
|
||||
|
||||
|
||||
def _match_channel(rule: str, channel: str) -> bool:
|
||||
value = str(rule or "").strip()
|
||||
current = str(channel or "").strip()
|
||||
if not value:
|
||||
return False
|
||||
if value == "*":
|
||||
return True
|
||||
if value.endswith("*"):
|
||||
return current.startswith(value[:-1])
|
||||
return current == value
|
||||
|
||||
|
||||
def _omemo_facts(ctx: CommandSecurityContext) -> tuple[str, str]:
|
||||
message_meta = dict(ctx.message_meta or {})
|
||||
payload = dict(ctx.payload or {})
|
||||
xmpp_meta = dict(message_meta.get("xmpp") or {})
|
||||
status = str(
|
||||
xmpp_meta.get("omemo_status")
|
||||
or payload.get("omemo_status")
|
||||
or ""
|
||||
).strip().lower()
|
||||
client_key = str(
|
||||
xmpp_meta.get("omemo_client_key")
|
||||
or payload.get("omemo_client_key")
|
||||
or ""
|
||||
).strip()
|
||||
return status, client_key
|
||||
|
||||
|
||||
def _channel_allowed_for_rules(rules: dict, service: str, channel: str) -> bool:
|
||||
service_rules = _normalize_list(rules.get(service))
|
||||
if not service_rules:
|
||||
service_rules = _normalize_list(rules.get("*"))
|
||||
if not service_rules:
|
||||
return True
|
||||
return any(_match_channel(rule, channel) for rule in service_rules)
|
||||
|
||||
|
||||
def _service_allowed(allowed_services: list[str], service: str) -> bool:
|
||||
if not allowed_services:
|
||||
return True
|
||||
return service in allowed_services
|
||||
|
||||
|
||||
def _effective_bool(local_value: bool, global_override: str) -> bool:
|
||||
option = _parse_override_value(global_override)
|
||||
if option == "on":
|
||||
return True
|
||||
if option == "off":
|
||||
return False
|
||||
return bool(local_value)
|
||||
|
||||
|
||||
def evaluate_command_policy(
|
||||
*,
|
||||
user,
|
||||
scope_key: str,
|
||||
context: CommandSecurityContext,
|
||||
) -> CommandPolicyDecision:
|
||||
scope = str(scope_key or "").strip().lower()
|
||||
if not scope:
|
||||
return CommandPolicyDecision(allowed=True)
|
||||
|
||||
policy = (
|
||||
CommandSecurityPolicy.objects.filter(
|
||||
user=user,
|
||||
scope_key=scope,
|
||||
)
|
||||
.order_by("-updated_at")
|
||||
.first()
|
||||
)
|
||||
global_policy = (
|
||||
CommandSecurityPolicy.objects.filter(
|
||||
user=user,
|
||||
scope_key=GLOBAL_SCOPE_KEY,
|
||||
)
|
||||
.order_by("-updated_at")
|
||||
.first()
|
||||
)
|
||||
|
||||
if policy is None and global_policy is None:
|
||||
return CommandPolicyDecision(allowed=True)
|
||||
|
||||
global_settings = dict(getattr(global_policy, "settings", {}) or {})
|
||||
local_enabled = bool(getattr(policy, "enabled", True))
|
||||
local_require_omemo = bool(getattr(policy, "require_omemo", False))
|
||||
local_require_trusted = bool(
|
||||
getattr(policy, "require_trusted_omemo_fingerprint", False)
|
||||
)
|
||||
enabled = _effective_bool(local_enabled, global_settings.get("scope_enabled"))
|
||||
require_omemo = _effective_bool(
|
||||
local_require_omemo, global_settings.get("require_omemo")
|
||||
)
|
||||
require_trusted_omemo_fingerprint = _effective_bool(
|
||||
local_require_trusted,
|
||||
global_settings.get("require_trusted_fingerprint"),
|
||||
)
|
||||
|
||||
if not enabled:
|
||||
return CommandPolicyDecision(
|
||||
allowed=False,
|
||||
code="policy_disabled",
|
||||
reason=f"{scope} is disabled by command policy",
|
||||
)
|
||||
|
||||
service = _normalize_service(context.service)
|
||||
channel = _normalize_channel(context.channel_identifier)
|
||||
allowed_services = [
|
||||
item.lower() for item in _normalize_list(getattr(policy, "allowed_services", []))
|
||||
]
|
||||
global_allowed_services = [
|
||||
item.lower()
|
||||
for item in _normalize_list(getattr(global_policy, "allowed_services", []))
|
||||
]
|
||||
if not _service_allowed(allowed_services, service):
|
||||
return CommandPolicyDecision(
|
||||
allowed=False,
|
||||
code="service_not_allowed",
|
||||
reason=f"service={service or '-'} not allowed for scope={scope}",
|
||||
)
|
||||
if not _service_allowed(global_allowed_services, service):
|
||||
return CommandPolicyDecision(
|
||||
allowed=False,
|
||||
code="service_not_allowed",
|
||||
reason=f"service={service or '-'} not allowed by global override",
|
||||
)
|
||||
local_channel_rules = dict(getattr(policy, "allowed_channels", {}) or {})
|
||||
if not _channel_allowed_for_rules(local_channel_rules, service, channel):
|
||||
return CommandPolicyDecision(
|
||||
allowed=False,
|
||||
code="channel_not_allowed",
|
||||
reason=f"channel={channel or '-'} not allowed for scope={scope}",
|
||||
)
|
||||
global_channel_rules = dict(getattr(global_policy, "allowed_channels", {}) or {})
|
||||
if not _channel_allowed_for_rules(global_channel_rules, service, channel):
|
||||
return CommandPolicyDecision(
|
||||
allowed=False,
|
||||
code="channel_not_allowed",
|
||||
reason=f"channel={channel or '-'} not allowed by global override",
|
||||
)
|
||||
|
||||
omemo_status, omemo_client_key = _omemo_facts(context)
|
||||
if require_omemo and omemo_status != "detected":
|
||||
return CommandPolicyDecision(
|
||||
allowed=False,
|
||||
code="omemo_required",
|
||||
reason=f"scope={scope} requires OMEMO",
|
||||
)
|
||||
|
||||
if require_trusted_omemo_fingerprint:
|
||||
if omemo_status != "detected" or not omemo_client_key:
|
||||
return CommandPolicyDecision(
|
||||
allowed=False,
|
||||
code="trusted_fingerprint_required",
|
||||
reason=f"scope={scope} requires trusted OMEMO fingerprint",
|
||||
)
|
||||
state = UserXmppOmemoState.objects.filter(user=user).first()
|
||||
expected_key = str(getattr(state, "latest_client_key", "") or "").strip()
|
||||
if not expected_key or expected_key != omemo_client_key:
|
||||
return CommandPolicyDecision(
|
||||
allowed=False,
|
||||
code="fingerprint_mismatch",
|
||||
reason=f"scope={scope} OMEMO fingerprint does not match enrolled key",
|
||||
)
|
||||
|
||||
return CommandPolicyDecision(allowed=True)
|
||||
Reference in New Issue
Block a user