Implement plans
This commit is contained in:
@@ -1,18 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.conf import settings
|
||||
|
||||
from core.commands.base import CommandContext, CommandResult
|
||||
from core.commands.handlers.bp import (
|
||||
BPCommandHandler,
|
||||
bp_reply_is_optional_for_trigger,
|
||||
bp_subcommands_enabled,
|
||||
bp_trigger_matches,
|
||||
)
|
||||
from core.commands.handlers.codex import CodexCommandHandler, codex_trigger_matches
|
||||
from core.commands.policies import ensure_variant_policies_for_profile
|
||||
from core.commands.registry import get as get_handler
|
||||
from core.commands.registry import register
|
||||
from core.messaging.reply_sync import is_mirrored_origin
|
||||
from core.models import CommandChannelBinding, CommandProfile, Message
|
||||
from core.models import CommandAction, CommandChannelBinding, CommandProfile, Message
|
||||
from core.util import logs
|
||||
|
||||
log = logs.get_logger("command_engine")
|
||||
@@ -36,17 +38,178 @@ def _channel_variants(service: str, channel_identifier: str) -> list[str]:
|
||||
return variants
|
||||
|
||||
|
||||
def _canonical_channel_identifier(service: str, channel_identifier: str) -> str:
|
||||
value = str(channel_identifier or "").strip()
|
||||
if not value:
|
||||
return ""
|
||||
if str(service or "").strip().lower() == "whatsapp":
|
||||
return value.split("@", 1)[0].strip()
|
||||
return value
|
||||
|
||||
|
||||
def _effective_bootstrap_scope(
|
||||
ctx: CommandContext,
|
||||
trigger_message: Message,
|
||||
) -> tuple[str, str]:
|
||||
service = str(ctx.service or "").strip().lower()
|
||||
identifier = str(ctx.channel_identifier or "").strip()
|
||||
if service != "web":
|
||||
return service, identifier
|
||||
session_identifier = getattr(getattr(trigger_message, "session", None), "identifier", None)
|
||||
fallback_service = str(getattr(session_identifier, "service", "") or "").strip().lower()
|
||||
fallback_identifier = str(getattr(session_identifier, "identifier", "") or "").strip()
|
||||
if fallback_service and fallback_identifier and fallback_service != "web":
|
||||
return fallback_service, fallback_identifier
|
||||
return service, identifier
|
||||
|
||||
|
||||
def _ensure_bp_profile(user_id: int) -> CommandProfile:
|
||||
profile, _ = CommandProfile.objects.get_or_create(
|
||||
user_id=user_id,
|
||||
slug="bp",
|
||||
defaults={
|
||||
"name": "Business Plan",
|
||||
"enabled": True,
|
||||
"trigger_token": ".bp",
|
||||
"reply_required": True,
|
||||
"exact_match_only": True,
|
||||
"window_scope": "conversation",
|
||||
"visibility_mode": "status_in_source",
|
||||
},
|
||||
)
|
||||
updated = False
|
||||
if not profile.enabled:
|
||||
profile.enabled = True
|
||||
updated = True
|
||||
if updated:
|
||||
profile.save(update_fields=["enabled", "updated_at"])
|
||||
if str(profile.trigger_token or "").strip() != ".bp":
|
||||
profile.trigger_token = ".bp"
|
||||
profile.save(update_fields=["trigger_token", "updated_at"])
|
||||
for action_type, position in (("extract_bp", 0), ("save_document", 1), ("post_result", 2)):
|
||||
action, created = CommandAction.objects.get_or_create(
|
||||
profile=profile,
|
||||
action_type=action_type,
|
||||
defaults={"enabled": True, "position": position},
|
||||
)
|
||||
if (not created) and (not action.enabled):
|
||||
action.enabled = True
|
||||
action.save(update_fields=["enabled", "updated_at"])
|
||||
ensure_variant_policies_for_profile(profile)
|
||||
return profile
|
||||
|
||||
|
||||
def _ensure_codex_profile(user_id: int) -> CommandProfile:
|
||||
profile, _ = CommandProfile.objects.get_or_create(
|
||||
user_id=user_id,
|
||||
slug="codex",
|
||||
defaults={
|
||||
"name": "Codex",
|
||||
"enabled": True,
|
||||
"trigger_token": ".codex",
|
||||
"reply_required": False,
|
||||
"exact_match_only": False,
|
||||
"window_scope": "conversation",
|
||||
"visibility_mode": "status_in_source",
|
||||
},
|
||||
)
|
||||
if not profile.enabled:
|
||||
profile.enabled = True
|
||||
profile.save(update_fields=["enabled", "updated_at"])
|
||||
if str(profile.trigger_token or "").strip() != ".codex":
|
||||
profile.trigger_token = ".codex"
|
||||
profile.save(update_fields=["trigger_token", "updated_at"])
|
||||
return profile
|
||||
|
||||
|
||||
def _ensure_profile_for_slug(user_id: int, slug: str) -> CommandProfile | None:
|
||||
if slug == "bp":
|
||||
return _ensure_bp_profile(user_id)
|
||||
if slug == "codex":
|
||||
return _ensure_codex_profile(user_id)
|
||||
return None
|
||||
|
||||
|
||||
def _detected_bootstrap_slugs(message_text: str) -> list[str]:
|
||||
slugs: list[str] = []
|
||||
if bp_trigger_matches(message_text, ".bp", False):
|
||||
slugs.append("bp")
|
||||
if codex_trigger_matches(message_text, ".codex", False):
|
||||
slugs.append("codex")
|
||||
return slugs
|
||||
|
||||
|
||||
def _auto_setup_profile_bindings_for_first_command(
|
||||
ctx: CommandContext,
|
||||
trigger_message: Message,
|
||||
) -> None:
|
||||
author = str(getattr(trigger_message, "custom_author", "") or "").strip().upper()
|
||||
if author != "USER":
|
||||
return
|
||||
slugs = _detected_bootstrap_slugs(ctx.message_text)
|
||||
if not slugs:
|
||||
return
|
||||
service, identifier = _effective_bootstrap_scope(ctx, trigger_message)
|
||||
service = str(service or "").strip().lower()
|
||||
canonical = _canonical_channel_identifier(service, identifier)
|
||||
variants = _channel_variants(service, canonical)
|
||||
if not service or not variants:
|
||||
return
|
||||
for slug in slugs:
|
||||
profile = _ensure_profile_for_slug(ctx.user_id, slug)
|
||||
if profile is None:
|
||||
continue
|
||||
already_enabled = CommandChannelBinding.objects.filter(
|
||||
profile=profile,
|
||||
enabled=True,
|
||||
direction="ingress",
|
||||
service=service,
|
||||
channel_identifier__in=variants,
|
||||
).exists()
|
||||
if already_enabled:
|
||||
continue
|
||||
for direction in ("ingress", "egress"):
|
||||
binding, _ = CommandChannelBinding.objects.get_or_create(
|
||||
profile=profile,
|
||||
direction=direction,
|
||||
service=service,
|
||||
channel_identifier=canonical,
|
||||
defaults={"enabled": True},
|
||||
)
|
||||
if not binding.enabled:
|
||||
binding.enabled = True
|
||||
binding.save(update_fields=["enabled", "updated_at"])
|
||||
alternate_variants = [value for value in variants if value != canonical]
|
||||
if alternate_variants:
|
||||
CommandChannelBinding.objects.filter(
|
||||
profile=profile,
|
||||
direction=direction,
|
||||
service=service,
|
||||
channel_identifier__in=alternate_variants,
|
||||
).update(enabled=False)
|
||||
|
||||
|
||||
def ensure_handlers_registered():
|
||||
global _REGISTERED
|
||||
if _REGISTERED:
|
||||
return
|
||||
register(BPCommandHandler())
|
||||
register(CodexCommandHandler())
|
||||
_REGISTERED = True
|
||||
|
||||
|
||||
async def _eligible_profiles(ctx: CommandContext) -> list[CommandProfile]:
|
||||
def _load():
|
||||
trigger = (
|
||||
Message.objects.select_related("session", "session__identifier")
|
||||
.filter(id=ctx.message_id, user_id=ctx.user_id)
|
||||
.first()
|
||||
)
|
||||
direct_variants = _channel_variants(ctx.service, ctx.channel_identifier)
|
||||
source_channel = str(getattr(trigger, "source_chat_id", "") or "").strip()
|
||||
for expanded in _channel_variants(ctx.service, source_channel):
|
||||
if expanded and expanded not in direct_variants:
|
||||
direct_variants.append(expanded)
|
||||
if not direct_variants:
|
||||
return []
|
||||
direct = list(
|
||||
@@ -65,15 +228,13 @@ async def _eligible_profiles(ctx: CommandContext) -> list[CommandProfile]:
|
||||
# underlying conversation is mapped to a platform identifier.
|
||||
if str(ctx.service or "").strip().lower() != "web":
|
||||
return []
|
||||
trigger = (
|
||||
Message.objects.select_related("session", "session__identifier")
|
||||
.filter(id=ctx.message_id, user_id=ctx.user_id)
|
||||
.first()
|
||||
)
|
||||
identifier = getattr(getattr(trigger, "session", None), "identifier", None)
|
||||
fallback_service = str(getattr(identifier, "service", "") or "").strip().lower()
|
||||
fallback_identifier = str(getattr(identifier, "identifier", "") or "").strip()
|
||||
fallback_variants = _channel_variants(fallback_service, fallback_identifier)
|
||||
for expanded in _channel_variants(fallback_service, source_channel):
|
||||
if expanded and expanded not in fallback_variants:
|
||||
fallback_variants.append(expanded)
|
||||
if not fallback_service or not fallback_variants:
|
||||
return []
|
||||
return list(
|
||||
@@ -91,12 +252,18 @@ async def _eligible_profiles(ctx: CommandContext) -> list[CommandProfile]:
|
||||
|
||||
|
||||
def _matches_trigger(profile: CommandProfile, text: str) -> bool:
|
||||
if profile.slug == "bp" and bool(getattr(settings, "BP_SUBCOMMANDS_V1", True)):
|
||||
if profile.slug == "bp" and bp_subcommands_enabled():
|
||||
return bp_trigger_matches(
|
||||
message_text=text,
|
||||
trigger_token=profile.trigger_token,
|
||||
exact_match_only=profile.exact_match_only,
|
||||
)
|
||||
if profile.slug == "codex":
|
||||
return codex_trigger_matches(
|
||||
message_text=text,
|
||||
trigger_token=profile.trigger_token,
|
||||
exact_match_only=profile.exact_match_only,
|
||||
)
|
||||
body = str(text or "").strip()
|
||||
trigger = str(profile.trigger_token or "").strip()
|
||||
if not trigger:
|
||||
@@ -115,6 +282,10 @@ async def process_inbound_message(ctx: CommandContext) -> list[CommandResult]:
|
||||
return []
|
||||
if is_mirrored_origin(trigger_message.message_meta):
|
||||
return []
|
||||
await sync_to_async(_auto_setup_profile_bindings_for_first_command)(
|
||||
ctx,
|
||||
trigger_message,
|
||||
)
|
||||
|
||||
profiles = await _eligible_profiles(ctx)
|
||||
results: list[CommandResult] = []
|
||||
@@ -124,7 +295,7 @@ async def process_inbound_message(ctx: CommandContext) -> list[CommandResult]:
|
||||
if profile.reply_required and trigger_message.reply_to_id is None:
|
||||
if (
|
||||
profile.slug == "bp"
|
||||
and bool(getattr(settings, "BP_SUBCOMMANDS_V1", True))
|
||||
and bp_subcommands_enabled()
|
||||
and bp_reply_is_optional_for_trigger(ctx.message_text)
|
||||
):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user