from __future__ import annotations import time import uuid from asgiref.sync import async_to_sync from django.core.management.base import BaseCommand from core.clients.transport import send_message_raw from core.models import CodexPermissionRequest, CodexRun, ExternalSyncEvent, TaskProviderConfig from core.tasks.providers import get_provider from core.util import logs log = logs.get_logger("codex_worker") class Command(BaseCommand): help = "Process queued external sync events for worker-backed providers (codex_cli)." def add_arguments(self, parser): parser.add_argument("--once", action="store_true", default=False) parser.add_argument("--sleep-seconds", type=float, default=2.0) parser.add_argument("--batch-size", type=int, default=20) parser.add_argument("--provider", default="codex_cli") def _claim_batch(self, provider: str, batch_size: int) -> list[str]: ids: list[str] = [] rows = list( ExternalSyncEvent.objects.filter( provider=provider, status__in=["pending", "retrying"], ) .order_by("updated_at")[: max(1, batch_size)] .values_list("id", flat=True) ) for row_id in rows: updated = ExternalSyncEvent.objects.filter( id=row_id, provider=provider, status__in=["pending", "retrying"], ).update(status="retrying") if updated: ids.append(str(row_id)) return ids def _run_event(self, event: ExternalSyncEvent) -> None: provider = get_provider(event.provider) if not bool(getattr(provider, "run_in_worker", False)): return cfg = ( TaskProviderConfig.objects.filter( user=event.user, provider=event.provider, enabled=True, ) .order_by("-updated_at") .first() ) if cfg is None: event.status = "failed" event.error = "provider_disabled_or_missing" event.save(update_fields=["status", "error", "updated_at"]) provider_payload = dict((event.payload or {}).get("provider_payload") or {}) run_id = str(provider_payload.get("codex_run_id") or "").strip() if run_id: CodexRun.objects.filter(id=run_id, user=event.user).update( status="failed", error="provider_disabled_or_missing", ) return payload = dict(event.payload or {}) action = str(payload.get("action") or "append_update").strip().lower() provider_payload = dict(payload.get("provider_payload") or payload) run_id = str(provider_payload.get("codex_run_id") or payload.get("codex_run_id") or "").strip() codex_run = None if run_id: codex_run = CodexRun.objects.filter(id=run_id, user=event.user).first() if codex_run is None and event.task_id: codex_run = ( CodexRun.objects.filter( user=event.user, task_id=event.task_id, status__in=["queued", "running", "approved_waiting_resume"], ) .order_by("-updated_at") .first() ) if codex_run is not None: codex_run.status = "running" codex_run.error = "" codex_run.save(update_fields=["status", "error", "updated_at"]) if action == "create": result = provider.create_task(dict(cfg.settings or {}), provider_payload) elif action == "complete": result = provider.mark_complete(dict(cfg.settings or {}), provider_payload) elif action == "link_task": result = provider.link_task(dict(cfg.settings or {}), provider_payload) else: result = provider.append_update(dict(cfg.settings or {}), provider_payload) result_payload = dict(result.payload or {}) requires_approval = bool(result_payload.get("requires_approval")) if requires_approval: approval_key = str(result_payload.get("approval_key") or uuid.uuid4().hex[:12]).strip() permission_request = dict(result_payload.get("permission_request") or {}) summary = str(result_payload.get("summary") or permission_request.get("summary") or "").strip() requested_permissions = permission_request.get("requested_permissions") if not isinstance(requested_permissions, (list, dict)): requested_permissions = permission_request or {} resume_payload = result_payload.get("resume_payload") if not isinstance(resume_payload, dict): resume_payload = {} event.status = "waiting_approval" event.error = "" event.payload = dict(payload, worker_processed=True, result=result_payload) event.save(update_fields=["status", "error", "payload", "updated_at"]) if codex_run is not None: codex_run.status = "waiting_approval" codex_run.result_payload = dict(result_payload) codex_run.error = "" codex_run.save(update_fields=["status", "result_payload", "error", "updated_at"]) CodexPermissionRequest.objects.update_or_create( approval_key=approval_key, defaults={ "user": event.user, "codex_run": codex_run if codex_run is not None else CodexRun.objects.create( user=event.user, task=event.task, derived_task_event=event.task_event, source_service=str(provider_payload.get("source_service") or ""), source_channel=str(provider_payload.get("source_channel") or ""), external_chat_id=str(provider_payload.get("external_chat_id") or ""), status="waiting_approval", request_payload=dict(payload or {}), result_payload=dict(result_payload), error="", ), "external_sync_event": event, "summary": summary, "requested_permissions": requested_permissions if isinstance(requested_permissions, dict) else { "items": list(requested_permissions or []) }, "resume_payload": dict(resume_payload or {}), "status": "pending", "resolved_at": None, "resolved_by_identifier": "", "resolution_note": "", }, ) approver_service = str((cfg.settings or {}).get("approver_service") or "").strip().lower() approver_identifier = str((cfg.settings or {}).get("approver_identifier") or "").strip() requested_text = result_payload.get("permission_request") or result_payload.get("requested_permissions") or {} if approver_service and approver_identifier: try: async_to_sync(send_message_raw)( approver_service, approver_identifier, text=( f"[codex approval] key={approval_key}\\n" f"summary={summary or 'Codex run requires approval'}\\n" f"requested={requested_text}\\n" f"use: .codex approve {approval_key} or .codex deny {approval_key}" ), attachments=[], metadata={"origin_tag": f"codex-approval:{approval_key}"}, ) except Exception: log.exception("failed to notify approver channel for approval_key=%s", approval_key) else: source_service = str(provider_payload.get("source_service") or "").strip().lower() source_channel = str(provider_payload.get("source_channel") or "").strip() if source_service and source_channel: try: async_to_sync(send_message_raw)( source_service, source_channel, text=( "[codex approval] approval is pending but no approver channel is configured. " "Set approver_service and approver_identifier in Codex settings." ), attachments=[], metadata={"origin_tag": "codex-approval-missing-target"}, ) except Exception: log.exception("failed to notify source channel for missing approver target") return event.status = "ok" if result.ok else "failed" event.error = str(result.error or "") event.payload = dict( payload, worker_processed=True, result=result_payload, ) event.save(update_fields=["status", "error", "payload", "updated_at"]) if codex_run is not None: codex_run.status = "ok" if result.ok else "failed" codex_run.error = str(result.error or "") codex_run.result_payload = result_payload codex_run.save(update_fields=["status", "error", "result_payload", "updated_at"]) if result.ok and result.external_key and event.task_id and not str(event.task.external_key or "").strip(): event.task.external_key = str(result.external_key) event.task.save(update_fields=["external_key"]) def handle(self, *args, **options): once = bool(options.get("once")) sleep_seconds = max(0.2, float(options.get("sleep_seconds") or 2.0)) batch_size = max(1, int(options.get("batch_size") or 20)) provider_name = str(options.get("provider") or "codex_cli").strip().lower() log.info( "codex_worker started provider=%s once=%s sleep=%s batch_size=%s", provider_name, once, sleep_seconds, batch_size, ) while True: claimed_ids = self._claim_batch(provider_name, batch_size) if not claimed_ids: if once: log.info("codex_worker exiting: no pending events") return time.sleep(sleep_seconds) continue for row_id in claimed_ids: event = ExternalSyncEvent.objects.filter(id=row_id).select_related("task", "user").first() if event is None: continue try: self._run_event(event) except Exception as exc: log.exception("codex_worker failed processing id=%s", row_id) ExternalSyncEvent.objects.filter(id=row_id).update( status="failed", error=f"worker_exception:{exc}", ) if once: log.info("codex_worker processed %s event(s)", len(claimed_ids)) return