Files
GIA/core/commands/engine.py

342 lines
12 KiB
Python

from __future__ import annotations
from asgiref.sync import sync_to_async
from core.commands.base import CommandContext, CommandResult
from core.commands.handlers.bp import (
BPCommandHandler,
bp_reply_is_optional_for_trigger,
bp_subcommands_enabled,
bp_trigger_matches,
)
from core.commands.handlers.codex import CodexCommandHandler, codex_trigger_matches
from core.commands.policies import ensure_variant_policies_for_profile
from core.commands.registry import get as get_handler
from core.commands.registry import register
from core.messaging.reply_sync import is_mirrored_origin
from core.models import CommandAction, CommandChannelBinding, CommandProfile, Message
from core.tasks.chat_defaults import ensure_default_source_for_chat
from core.util import logs
# Module-level logger for the command engine.
log = logs.get_logger("command_engine")
# Set to True by ensure_handlers_registered() once the built-in handlers have
# been registered, so later calls do not register them again.
_REGISTERED: bool = False
def _channel_variants(service: str, channel_identifier: str) -> list[str]:
value = str(channel_identifier or "").strip()
if not value:
return []
variants = [value]
svc = str(service or "").strip().lower()
if svc == "whatsapp":
bare = value.split("@", 1)[0].strip()
if bare and bare not in variants:
variants.append(bare)
group = f"{bare}@g.us" if bare else ""
if group and group not in variants:
variants.append(group)
return variants
def _canonical_channel_identifier(service: str, channel_identifier: str) -> str:
value = str(channel_identifier or "").strip()
if not value:
return ""
if str(service or "").strip().lower() == "whatsapp":
return value.split("@", 1)[0].strip()
return value
def _effective_bootstrap_scope(
ctx: CommandContext,
trigger_message: Message,
) -> tuple[str, str]:
service = str(ctx.service or "").strip().lower()
identifier = str(ctx.channel_identifier or "").strip()
if service != "web":
return service, identifier
session_identifier = getattr(getattr(trigger_message, "session", None), "identifier", None)
fallback_service = str(getattr(session_identifier, "service", "") or "").strip().lower()
fallback_identifier = str(getattr(session_identifier, "identifier", "") or "").strip()
if fallback_service and fallback_identifier and fallback_service != "web":
return fallback_service, fallback_identifier
return service, identifier
def _ensure_bp_profile(user_id: int) -> CommandProfile:
    """Get or create the user's ``bp`` (Business Plan) profile and repair drift.

    A missing profile is created with the Business Plan defaults. An existing
    profile is re-enabled and its trigger token reset to ``.bp`` when either
    has drifted; all such repairs are persisted in a single save. The
    ``extract_bp``, ``save_document`` and ``post_result`` actions are created
    at positions 0, 1 and 2, or re-enabled if they exist but are disabled.
    Finally, variant policies are ensured for the profile.

    Args:
        user_id: Owner of the profile.

    Returns:
        The existing or newly created ``bp`` profile.
    """
    profile, _ = CommandProfile.objects.get_or_create(
        user_id=user_id,
        slug="bp",
        defaults={
            "name": "Business Plan",
            "enabled": True,
            "trigger_token": ".bp",
            "reply_required": True,
            "exact_match_only": True,
            "window_scope": "conversation",
            "visibility_mode": "status_in_source",
        },
    )
    changed_fields: list[str] = []
    if not profile.enabled:
        profile.enabled = True
        changed_fields.append("enabled")
    if str(profile.trigger_token or "").strip() != ".bp":
        profile.trigger_token = ".bp"
        changed_fields.append("trigger_token")
    if changed_fields:
        # One write for all repaired fields instead of one save per field.
        profile.save(update_fields=[*changed_fields, "updated_at"])
    for action_type, position in (("extract_bp", 0), ("save_document", 1), ("post_result", 2)):
        action, created = CommandAction.objects.get_or_create(
            profile=profile,
            action_type=action_type,
            defaults={"enabled": True, "position": position},
        )
        # Existing actions keep their stored position; only a disabled flag is repaired.
        if not created and not action.enabled:
            action.enabled = True
            action.save(update_fields=["enabled", "updated_at"])
    ensure_variant_policies_for_profile(profile)
    return profile
def _ensure_codex_profile(user_id: int) -> CommandProfile:
    """Get or create the user's ``codex`` profile and repair drift.

    A missing profile is created with the Codex defaults (reply not required,
    ``exact_match_only`` off). An existing profile is re-enabled and its
    trigger token reset to ``.codex`` when either has drifted; all such
    repairs are persisted in a single save. Unlike the ``bp`` profile, no
    actions or variant policies are provisioned here.

    Args:
        user_id: Owner of the profile.

    Returns:
        The existing or newly created ``codex`` profile.
    """
    profile, _ = CommandProfile.objects.get_or_create(
        user_id=user_id,
        slug="codex",
        defaults={
            "name": "Codex",
            "enabled": True,
            "trigger_token": ".codex",
            "reply_required": False,
            "exact_match_only": False,
            "window_scope": "conversation",
            "visibility_mode": "status_in_source",
        },
    )
    changed_fields: list[str] = []
    if not profile.enabled:
        profile.enabled = True
        changed_fields.append("enabled")
    if str(profile.trigger_token or "").strip() != ".codex":
        profile.trigger_token = ".codex"
        changed_fields.append("trigger_token")
    if changed_fields:
        # One write for all repaired fields instead of one save per field.
        profile.save(update_fields=[*changed_fields, "updated_at"])
    return profile
def _ensure_profile_for_slug(user_id: int, slug: str) -> CommandProfile | None:
if slug == "bp":
return _ensure_bp_profile(user_id)
if slug == "codex":
return _ensure_codex_profile(user_id)
return None
def _detected_bootstrap_slugs(message_text: str) -> list[str]:
    """Return the built-in slugs whose default trigger matches *message_text*.

    Each matcher is called with its default token (``.bp`` / ``.codex``)
    and ``exact_match_only=False``. Both matchers are always evaluated, and
    ``bp`` precedes ``codex`` in the result when both match.
    """
    checks = (
        ("bp", bp_trigger_matches(message_text, ".bp", False)),
        ("codex", codex_trigger_matches(message_text, ".codex", False)),
    )
    return [slug for slug, matched in checks if matched]
def _auto_setup_profile_bindings_for_first_command(
    ctx: CommandContext,
    trigger_message: Message,
) -> None:
    """Bind built-in command profiles to the current chat when a user first uses them.

    Only messages whose ``custom_author`` is ``USER`` (whitespace-stripped,
    case-insensitive) and whose text matches a built-in trigger (see
    ``_detected_bootstrap_slugs``) are considered. The chat scope comes from
    ``_effective_bootstrap_scope``, so compose (``web``) messages may be bound
    to the underlying platform chat.

    For each detected slug the profile is ensured. If no enabled ingress
    binding exists on any variant of the chat identifier, enabled ingress and
    egress bindings are created (or re-enabled) on the canonical identifier,
    and same-direction bindings on the other variants are disabled.

    Args:
        ctx: Command context of the inbound message.
        trigger_message: The stored message that carried the command.
    """
    author = str(getattr(trigger_message, "custom_author", "") or "").strip().upper()
    if author != "USER":
        # Only user-authored messages may bootstrap bindings.
        return
    slugs = _detected_bootstrap_slugs(ctx.message_text)
    if not slugs:
        return
    service, identifier = _effective_bootstrap_scope(ctx, trigger_message)
    service = str(service or "").strip().lower()
    canonical = _canonical_channel_identifier(service, identifier)
    # Variants are derived from the canonical form. For WhatsApp that means
    # the bare id and its "@g.us" form; the original "@<domain>" spelling of a
    # non-group JID is not among them.
    variants = _channel_variants(service, canonical)
    if not service or not variants:
        return
    for slug in slugs:
        profile = _ensure_profile_for_slug(ctx.user_id, slug)
        if profile is None:
            continue
        # An enabled *ingress* binding on any variant counts as "already set up";
        # the egress side is not checked here.
        already_enabled = CommandChannelBinding.objects.filter(
            profile=profile,
            enabled=True,
            direction="ingress",
            service=service,
            channel_identifier__in=variants,
        ).exists()
        if already_enabled:
            continue
        for direction in ("ingress", "egress"):
            binding, _ = CommandChannelBinding.objects.get_or_create(
                profile=profile,
                direction=direction,
                service=service,
                channel_identifier=canonical,
                defaults={"enabled": True},
            )
            if not binding.enabled:
                binding.enabled = True
                binding.save(update_fields=["enabled", "updated_at"])
            # Keep only the canonical spelling active for this direction;
            # bindings on other spellings of the same chat are disabled.
            alternate_variants = [value for value in variants if value != canonical]
            if alternate_variants:
                CommandChannelBinding.objects.filter(
                    profile=profile,
                    direction=direction,
                    service=service,
                    channel_identifier__in=alternate_variants,
                ).update(enabled=False)
    # NOTE(review): as laid out here this runs once per qualifying command,
    # after the slug loop, even when every slug was already bound. Confirm it
    # was not meant to run only when new bindings are created.
    ensure_default_source_for_chat(
        user=trigger_message.user,
        service=service,
        channel_identifier=canonical,
        message=trigger_message,
    )
def ensure_handlers_registered():
    """Register the built-in command handlers on the first call only.

    Instantiates and registers ``BPCommandHandler`` and then
    ``CodexCommandHandler``, and records that in the module-level
    ``_REGISTERED`` flag. Subsequent calls are no-ops. The flag is set only
    after both registrations succeed.
    """
    global _REGISTERED
    if not _REGISTERED:
        for handler_cls in (BPCommandHandler, CodexCommandHandler):
            register(handler_cls())
        _REGISTERED = True
def _merged_channel_variants(service: str, *identifiers: str) -> list[str]:
    """Return the ordered, de-duplicated union of ``_channel_variants`` over *identifiers*."""
    merged: list[str] = []
    for identifier in identifiers:
        for variant in _channel_variants(service, identifier):
            if variant and variant not in merged:
                merged.append(variant)
    return merged


def _ingress_bound_profiles(
    user_id: int,
    service: str,
    variants: list[str],
) -> list[CommandProfile]:
    """Return *user_id*'s enabled profiles with an enabled ingress binding on *service*/*variants*."""
    # All channel_bindings__ conditions sit in one filter() call, so they must
    # hold for the same binding row.
    return list(
        CommandProfile.objects.filter(
            user_id=user_id,
            enabled=True,
            channel_bindings__enabled=True,
            channel_bindings__direction="ingress",
            channel_bindings__service=service,
            channel_bindings__channel_identifier__in=variants,
        ).distinct()
    )


async def _eligible_profiles(ctx: CommandContext) -> list[CommandProfile]:
    """Return the enabled profiles bound (ingress) to the chat of *ctx*.

    The service is stripped and lowercased before querying. This matches
    how ``_auto_setup_profile_bindings_for_first_command`` stores bindings.
    Candidate identifiers are the variants of ``ctx.channel_identifier``
    plus those of the trigger message's ``source_chat_id``.

    If no profile matches and the message came from the ``web`` (compose)
    service, the lookup is retried against the service/identifier of the
    trigger message's session identifier, again widened with
    ``source_chat_id`` variants.

    Args:
        ctx: Command context of the inbound message.

    Returns:
        Distinct matching profiles; an empty list when nothing matches.
    """

    def _load() -> list[CommandProfile]:
        trigger = (
            Message.objects.select_related("session", "session__identifier")
            .filter(id=ctx.message_id, user_id=ctx.user_id)
            .first()
        )
        # Normalize once so the query agrees with how bindings are stored and
        # with the "web" comparison below.
        service = str(ctx.service or "").strip().lower()
        source_channel = str(getattr(trigger, "source_chat_id", "") or "").strip()
        direct_variants = _merged_channel_variants(
            service, ctx.channel_identifier, source_channel
        )
        if not direct_variants:
            return []
        direct = _ingress_bound_profiles(ctx.user_id, service, direct_variants)
        # Compose-originated messages use `web` service even when the
        # underlying conversation is mapped to a platform identifier.
        if direct or service != "web":
            return direct
        session_identifier = getattr(getattr(trigger, "session", None), "identifier", None)
        fallback_service = str(getattr(session_identifier, "service", "") or "").strip().lower()
        fallback_identifier = str(getattr(session_identifier, "identifier", "") or "").strip()
        fallback_variants = _merged_channel_variants(
            fallback_service, fallback_identifier, source_channel
        )
        if not fallback_service or not fallback_variants:
            return []
        return _ingress_bound_profiles(ctx.user_id, fallback_service, fallback_variants)

    return await sync_to_async(_load)()
def _matches_trigger(profile: CommandProfile, text: str) -> bool:
if profile.slug == "bp" and bp_subcommands_enabled():
return bp_trigger_matches(
message_text=text,
trigger_token=profile.trigger_token,
exact_match_only=profile.exact_match_only,
)
if profile.slug == "codex":
return codex_trigger_matches(
message_text=text,
trigger_token=profile.trigger_token,
exact_match_only=profile.exact_match_only,
)
body = str(text or "").strip()
trigger = str(profile.trigger_token or "").strip()
if not trigger:
return False
if profile.exact_match_only:
return body == trigger
return trigger in body
async def process_inbound_message(ctx: CommandContext) -> list[CommandResult]:
    """Run every eligible command profile whose trigger matches the inbound message.

    Steps:
    1. Make sure the built-in handlers are registered.
    2. Load the message by ``ctx.message_id``. Return ``[]`` if it is missing
       or has a mirrored origin.
    3. Auto-bootstrap profile bindings for the chat when applicable.
    4. For each eligible profile whose trigger matches:
       - profiles that require a reply get a ``skipped``/``reply_required``
         result when the message is not a reply, unless the trigger is a bp
         subcommand (with subcommands enabled) for which a reply is optional;
       - a profile with no registered handler gets a ``failed`` result;
       - otherwise the handler is executed, and any exception it raises is
         logged and turned into a ``failed`` result.

    Every non-success result produced here carries ``{"profile": slug}`` in its
    payload, so callers can attribute it to a profile.

    Args:
        ctx: Command context of the inbound message.

    Returns:
        One result per matching profile, in the order returned by
        ``_eligible_profiles``.
    """
    ensure_handlers_registered()
    trigger_message = await sync_to_async(
        lambda: Message.objects.filter(id=ctx.message_id).first()
    )()
    if trigger_message is None:
        return []
    if is_mirrored_origin(trigger_message.message_meta):
        return []
    await sync_to_async(_auto_setup_profile_bindings_for_first_command)(
        ctx,
        trigger_message,
    )
    profiles = await _eligible_profiles(ctx)
    results: list[CommandResult] = []
    for profile in profiles:
        if not _matches_trigger(profile, ctx.message_text):
            continue
        # The bp waiver is only evaluated when a reply is required but missing.
        if (
            profile.reply_required
            and trigger_message.reply_to_id is None
            and not (
                profile.slug == "bp"
                and bp_subcommands_enabled()
                and bp_reply_is_optional_for_trigger(ctx.message_text)
            )
        ):
            results.append(
                CommandResult(
                    ok=False,
                    status="skipped",
                    error="reply_required",
                    payload={"profile": profile.slug},
                )
            )
            continue
        handler = get_handler(profile.slug)
        if handler is None:
            results.append(
                CommandResult(
                    ok=False,
                    status="failed",
                    error=f"missing_handler:{profile.slug}",
                    payload={"profile": profile.slug},
                )
            )
            continue
        try:
            result = await handler.execute(ctx)
        except Exception as exc:
            # Boundary catch: one failing handler must not stop the other profiles.
            log.exception("command execution failed for profile=%s: %s", profile.slug, exc)
            results.append(
                CommandResult(
                    ok=False,
                    status="failed",
                    error=f"handler_exception:{exc}",
                    payload={"profile": profile.slug},
                )
            )
        else:
            results.append(result)
    return results