From bca4d6898f6751d58f6babb969bbb673dd7afaf4 Mon Sep 17 00:00:00 2001 From: Mark Veidemanis Date: Sat, 7 Mar 2026 20:52:13 +0000 Subject: [PATCH] Increase security and reformat --- app/test_settings.py | 1 + app/urls.py | 10 +- core/assist/repeat_answer.py | 21 +- core/clients/__init__.py | 3 +- core/clients/signal.py | 122 +- core/clients/transport.py | 6 +- core/clients/whatsapp.py | 95 +- core/clients/xmpp.py | 616 ++++++- core/commands/delivery.py | 14 +- core/commands/engine.py | 26 +- core/commands/handlers/bp.py | 137 +- core/commands/handlers/claude.py | 183 ++- core/commands/handlers/codex.py | 192 ++- core/commands/policies.py | 11 +- core/context_processors.py | 10 +- core/db/sql.py | 1 - core/events/ledger.py | 12 +- core/events/projection.py | 59 +- core/forms.py | 1 + core/gateway/commands.py | 10 +- .../commands/backfill_contact_availability.py | 24 +- core/management/commands/codex_worker.py | 127 +- .../commands/event_projection_shadow.py | 4 +- .../recalculate_contact_availability.py | 26 +- .../reconcile_workspace_metric_history.py | 64 +- core/management/commands/task_sync_worker.py | 4 +- core/mcp/server.py | 4 +- core/mcp/tools.py | 36 +- core/memory/__init__.py | 2 +- core/memory/pipeline.py | 12 +- core/memory/retrieval.py | 20 +- core/memory/search_backend.py | 22 +- core/messaging/history.py | 11 +- core/messaging/reply_sync.py | 12 +- ...suggestionevent_chattasksource_and_more.py | 3 +- ...ontactavailability_and_externalchatlink.py | 2 +- ...odexrun_codexpermissionrequest_and_more.py | 2 +- ...35_conversationevent_adapterhealthevent.py | 2 +- ...mmandsecuritypolicy_gatewaycommandevent.py | 2 +- .../0042_userxmppomemotrustedkey_and_more.py | 39 + ...ngs_encrypt_contact_messages_with_omemo.py | 18 + core/models.py | 109 +- core/modules/router.py | 14 +- core/presence/engine.py | 19 +- core/presence/query.py | 5 +- core/security/__init__.py | 1 - core/security/attachments.py | 4 +- core/security/command_policy.py | 17 +- core/tasks/chat_defaults.py | 4 +- core/tasks/codex_approval.py | 8 +- core/tasks/codex_support.py | 4 +- core/tasks/engine.py | 156 +- core/tasks/providers/claude_cli.py | 36 +- core/tasks/providers/codex_cli.py | 26 +- core/tasks/providers/mock.py | 26 +- core/templates/base.html | 83 +- .../pages/accessibility-settings.html | 40 +- core/templates/pages/ai-execution-log.html | 542 +++---- .../pages/availability-settings.html | 124 +- .../templates/pages/business-plan-editor.html | 2 +- core/templates/pages/business-plan-inbox.html | 99 ++ core/templates/pages/codex-settings.html | 264 +-- core/templates/pages/command-routing.html | 5 +- core/templates/pages/security.html | 1409 +++++++++-------- core/templates/pages/settings-category.html | 34 +- core/templates/pages/tasks-detail.html | 246 +-- core/templates/pages/tasks-epic.html | 38 +- core/templates/pages/tasks-group.html | 202 +-- core/templates/pages/tasks-hub.html | 372 ++--- core/templates/pages/tasks-project.html | 184 +-- core/templates/pages/tasks-settings.html | 1276 +++++++-------- .../partials/ai-workspace-ai-result.html | 6 +- .../ai-workspace-mitigation-panel.html | 64 +- .../partials/ai-workspace-person-widget.html | 238 +-- core/templates/partials/compose-panel.html | 46 +- .../compose-workspace-contacts-widget.html | 78 +- .../partials/osint/search-panel.html | 16 +- core/templates/partials/queue-list.html | 6 +- .../partials/settings-hierarchy-nav.html | 6 +- .../templates/partials/signal-chats-list.html | 274 ++-- .../partials/whatsapp-account-add.html | 12 +- .../templates/two_factor/_wizard_actions.html | 6 +- .../two_factor/core/backup_tokens.html | 2 +- core/templates/two_factor/core/login.html | 4 +- .../two_factor/core/otp_required.html | 2 +- .../two_factor/core/setup_complete.html | 6 +- .../templates/two_factor/profile/disable.html | 2 +- .../templates/two_factor/profile/profile.html | 8 +- core/tests/test_ai_run_log.py | 3 +- .../test_backfill_contact_availability.py | 14 +- core/tests/test_bp_fallback.py | 4 +- core/tests/test_bp_subcommands.py | 40 +- core/tests/test_claude_cli_provider.py | 21 +- core/tests/test_claude_commands_phase1.py | 16 +- core/tests/test_codex_cli_provider.py | 21 +- core/tests/test_codex_commands_phase1.py | 39 +- core/tests/test_codex_worker_phase1.py | 30 +- core/tests/test_command_security_policy.py | 8 +- core/tests/test_command_variant_policy.py | 18 +- core/tests/test_compose_react.py | 4 +- core/tests/test_cross_platform_messaging.py | 17 +- core/tests/test_event_projection_shadow.py | 6 +- core/tests/test_mcp_tools.py | 24 +- core/tests/test_memory_pipeline_commands.py | 12 +- core/tests/test_phase1_command_reply.py | 12 +- core/tests/test_presence_engine.py | 16 +- core/tests/test_reaction_normalization.py | 15 +- ...test_reconcile_workspace_metric_history.py | 4 +- core/tests/test_signal_relink.py | 2 +- core/tests/test_signal_reply_send.py | 27 +- core/tests/test_task_engine_plan09.py | 45 +- core/tests/test_tasks_pages_management.py | 24 +- core/tests/test_tasks_settings_and_toggle.py | 115 +- core/tests/test_transport_capabilities.py | 11 +- .../test_whatsapp_reaction_and_recalc.py | 28 +- core/tests/test_xmpp_approval_commands.py | 18 +- core/tests/test_xmpp_integration.py | 182 ++- core/tests/test_xmpp_omemo_support.py | 414 ++++- core/translation/engine.py | 8 +- core/util/django_settings_export.py | 1 + core/views/ais.py | 2 +- core/views/automation.py | 113 +- core/views/availability.py | 10 +- core/views/compose.py | 162 +- core/views/groups.py | 2 +- core/views/identifiers.py | 2 +- core/views/manipulations.py | 2 +- core/views/messages.py | 2 +- core/views/notifications.py | 2 +- core/views/osint.py | 89 +- core/views/people.py | 2 +- core/views/personas.py | 2 +- core/views/queues.py | 4 +- core/views/sessions.py | 2 +- core/views/signal.py | 16 +- core/views/system.py | 368 ++++- core/views/tasks.py | 483 ++++-- core/views/whatsapp.py | 12 +- core/views/workspace.py | 28 +- core/workspace/sampling.py | 14 +- docker/watch_and_restart.py | 1 + docker/watch_simple.py | 1 + manage.py | 1 + scripts/quadlet/render_units.py | 85 +- 144 files changed, 6735 insertions(+), 3960 deletions(-) create mode 100644 core/migrations/0042_userxmppomemotrustedkey_and_more.py create mode 100644 core/migrations/0043_userxmppsecuritysettings_encrypt_contact_messages_with_omemo.py create mode 100644 core/templates/pages/business-plan-inbox.html diff --git a/app/test_settings.py b/app/test_settings.py index 6401bfa..191fd80 100644 --- a/app/test_settings.py +++ b/app/test_settings.py @@ -1,4 +1,5 @@ """Test-only settings overrides — used via DJANGO_SETTINGS_MODULE=app.test_settings.""" + from app.settings import * # noqa: F401, F403 CACHES = { diff --git a/app/urls.py b/app/urls.py index 452f8f1..f770299 100644 --- a/app/urls.py +++ b/app/urls.py @@ -13,12 +13,13 @@ Including another URLconf 1. Import the include() function: from django.urls import include, path 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) """ + from django.conf import settings from django.conf.urls.static import static from django.contrib import admin from django.contrib.auth.views import LogoutView -from django.views.generic import RedirectView from django.urls import include, path +from django.views.generic import RedirectView from two_factor.urls import urlpatterns as tf_urls from two_factor.views.profile import ProfileView @@ -41,8 +42,8 @@ from core.views import ( queues, sessions, signal, - tasks, system, + tasks, whatsapp, workspace, ) @@ -188,6 +189,11 @@ urlpatterns = [ automation.TranslationSettings.as_view(), name="translation_settings", ), + path( + "settings/business-plans/", + automation.BusinessPlanInbox.as_view(), + name="business_plan_inbox", + ), path( "settings/business-plan//", automation.BusinessPlanEditor.as_view(), diff --git a/core/assist/repeat_answer.py b/core/assist/repeat_answer.py index 70a3ebd..ffef7a1 100644 --- a/core/assist/repeat_answer.py +++ b/core/assist/repeat_answer.py @@ -38,14 +38,31 @@ def _is_question(text: str) -> bool: if not body: return False low = body.lower() - return body.endswith("?") or low.startswith(("what", "why", "how", "when", "where", "who", "can ", "do ", "did ", "is ", "are ")) + return body.endswith("?") or low.startswith( + ( + "what", + "why", + "how", + "when", + "where", + "who", + "can ", + "do ", + "did ", + "is ", + "are ", + ) + ) def _is_group_channel(message: Message) -> bool: channel = str(getattr(message, "source_chat_id", "") or "").strip().lower() if channel.endswith("@g.us"): return True - return str(getattr(message, "source_service", "") or "").strip().lower() == "xmpp" and "conference." in channel + return ( + str(getattr(message, "source_service", "") or "").strip().lower() == "xmpp" + and "conference." in channel + ) async def learn_from_message(message: Message) -> None: diff --git a/core/clients/__init__.py b/core/clients/__init__.py index d04b9f5..06c27f9 100644 --- a/core/clients/__init__.py +++ b/core/clients/__init__.py @@ -12,8 +12,7 @@ class ClientBase(ABC): self.log.info(f"{self.service.capitalize()} client initialising...") @abstractmethod - def start(self): - ... + def start(self): ... # @abstractmethod # async def send_message(self, recipient, message): diff --git a/core/clients/signal.py b/core/clients/signal.py index d1361db..71c0746 100644 --- a/core/clients/signal.py +++ b/core/clients/signal.py @@ -12,7 +12,15 @@ from django.urls import reverse from signalbot import Command, Context, SignalBot from core.clients import ClientBase, signalapi, transport -from core.messaging import ai, history, media_bridge, natural, replies, reply_sync, utils +from core.messaging import ( + ai, + history, + media_bridge, + natural, + replies, + reply_sync, + utils, +) from core.models import ( Chat, Manipulation, @@ -402,7 +410,9 @@ class NewSignalBot(SignalBot): seen_user_ids.add(pi.user_id) users.append(pi.user) if not users: - self.log.debug("[Signal] _upsert_groups: no PersonIdentifiers found — skipping") + self.log.debug( + "[Signal] _upsert_groups: no PersonIdentifiers found — skipping" + ) return for user in users: @@ -423,7 +433,9 @@ class NewSignalBot(SignalBot): }, ) - self.log.info("[Signal] upserted %d groups for %d users", len(groups), len(users)) + self.log.info( + "[Signal] upserted %d groups for %d users", len(groups), len(users) + ) async def _detect_groups(self): await super()._detect_groups() @@ -505,7 +517,9 @@ class HandleMessage(Command): source_uuid_norm and dest_norm and source_uuid_norm == dest_norm ) - is_from_bot = bool(bot_uuid and source_uuid_norm and source_uuid_norm == bot_uuid) + is_from_bot = bool( + bot_uuid and source_uuid_norm and source_uuid_norm == bot_uuid + ) if (not is_from_bot) and bot_phone_digits and source_phone_digits: is_from_bot = source_phone_digits == bot_phone_digits # Inbound deliveries usually do not have destination fields populated. @@ -596,9 +610,9 @@ class HandleMessage(Command): candidate_digits = {value for value in candidate_digits if value} if candidate_digits: signal_rows = await sync_to_async(list)( - PersonIdentifier.objects.filter(service=self.service).select_related( - "user" - ) + PersonIdentifier.objects.filter( + service=self.service + ).select_related("user") ) matched = [] for row in signal_rows: @@ -718,13 +732,13 @@ class HandleMessage(Command): target_ts=int(reaction_payload.get("target_ts") or 0), emoji=str(reaction_payload.get("emoji") or ""), source_service="signal", - actor=( - effective_source_uuid or effective_source_number or "" - ), + actor=(effective_source_uuid or effective_source_number or ""), target_author=str( (reaction_payload.get("raw") or {}).get("targetAuthorUuid") or (reaction_payload.get("raw") or {}).get("targetAuthor") - or (reaction_payload.get("raw") or {}).get("targetAuthorNumber") + or (reaction_payload.get("raw") or {}).get( + "targetAuthorNumber" + ) or "" ), remove=bool(reaction_payload.get("remove")), @@ -741,9 +755,7 @@ class HandleMessage(Command): remove=bool(reaction_payload.get("remove")), upstream_message_id="", upstream_ts=int(reaction_payload.get("target_ts") or 0), - actor=( - effective_source_uuid or effective_source_number or "" - ), + actor=(effective_source_uuid or effective_source_number or ""), payload=reaction_payload.get("raw") or {}, ) except Exception as exc: @@ -840,9 +852,7 @@ class HandleMessage(Command): source_ref={ "upstream_message_id": "", "upstream_author": str( - effective_source_uuid - or effective_source_number - or "" + effective_source_uuid or effective_source_number or "" ), "upstream_ts": int(ts or 0), }, @@ -1134,7 +1144,9 @@ class SignalClient(ClientBase): if int(message_row.delivered_ts or 0) <= 0: message_row.delivered_ts = int(result) update_fields.append("delivered_ts") - if str(message_row.source_message_id or "").strip() != str(result): + if str(message_row.source_message_id or "").strip() != str( + result + ): message_row.source_message_id = str(result) update_fields.append("source_message_id") if update_fields: @@ -1146,9 +1158,11 @@ class SignalClient(ClientBase): command_id, { "ok": True, - "timestamp": int(result) - if isinstance(result, int) - else int(time.time() * 1000), + "timestamp": ( + int(result) + if isinstance(result, int) + else int(time.time() * 1000) + ), }, ) except Exception as exc: @@ -1248,7 +1262,9 @@ class SignalClient(ClientBase): if _digits_only(getattr(row, "identifier", "")) in candidate_digits ] - async def _auto_link_single_user_signal_identifier(self, source_uuid: str, source_number: str): + async def _auto_link_single_user_signal_identifier( + self, source_uuid: str, source_number: str + ): owner_rows = await sync_to_async(list)( PersonIdentifier.objects.filter(service=self.service) .select_related("user") @@ -1292,7 +1308,9 @@ class SignalClient(ClientBase): payload = json.loads(raw_message or "{}") except Exception: return - exception_payload = payload.get("exception") if isinstance(payload, dict) else None + exception_payload = ( + payload.get("exception") if isinstance(payload, dict) else None + ) if isinstance(exception_payload, dict): err_type = str(exception_payload.get("type") or "").strip() err_msg = str(exception_payload.get("message") or "").strip() @@ -1322,7 +1340,9 @@ class SignalClient(ClientBase): (envelope.get("timestamp") if isinstance(envelope, dict) else 0) or int(time.time() * 1000) ), - last_inbound_exception_account=str(payload.get("account") or "").strip(), + last_inbound_exception_account=str( + payload.get("account") or "" + ).strip(), last_inbound_exception_source_uuid=envelope_source_uuid, last_inbound_exception_source_number=envelope_source_number, last_inbound_exception_envelope_ts=envelope_ts, @@ -1346,7 +1366,11 @@ class SignalClient(ClientBase): raw_text = sync_sent_message.get("message") if isinstance(raw_text, dict): text = _extract_signal_text( - {"envelope": {"syncMessage": {"sentMessage": {"message": raw_text}}}}, + { + "envelope": { + "syncMessage": {"sentMessage": {"message": raw_text}} + } + }, str( raw_text.get("message") or raw_text.get("text") @@ -1396,9 +1420,15 @@ class SignalClient(ClientBase): source_service="signal", actor=(source_uuid or source_number or ""), target_author=str( - (reaction_payload.get("raw") or {}).get("targetAuthorUuid") - or (reaction_payload.get("raw") or {}).get("targetAuthor") - or (reaction_payload.get("raw") or {}).get("targetAuthorNumber") + (reaction_payload.get("raw") or {}).get( + "targetAuthorUuid" + ) + or (reaction_payload.get("raw") or {}).get( + "targetAuthor" + ) + or (reaction_payload.get("raw") or {}).get( + "targetAuthorNumber" + ) or "" ), remove=bool(reaction_payload.get("remove")), @@ -1505,7 +1535,9 @@ class SignalClient(ClientBase): source_chat_id = destination_number or destination_uuid or sender_key reply_ref = reply_sync.extract_reply_ref(self.service, payload) for identifier in identifiers: - session = await history.get_chat_session(identifier.user, identifier) + session = await history.get_chat_session( + identifier.user, identifier + ) reply_target = await reply_sync.resolve_reply_target( identifier.user, session, @@ -1552,13 +1584,19 @@ class SignalClient(ClientBase): if not isinstance(data_message, dict): return - source_uuid = str(envelope.get("sourceUuid") or envelope.get("source") or "").strip() + source_uuid = str( + envelope.get("sourceUuid") or envelope.get("source") or "" + ).strip() source_number = str(envelope.get("sourceNumber") or "").strip() bot_uuid = str(getattr(self.client, "bot_uuid", "") or "").strip() bot_phone = str(getattr(self.client, "phone_number", "") or "").strip() if source_uuid and bot_uuid and source_uuid == bot_uuid: return - if source_number and bot_phone and _digits_only(source_number) == _digits_only(bot_phone): + if ( + source_number + and bot_phone + and _digits_only(source_number) == _digits_only(bot_phone) + ): return identifiers = await self._resolve_signal_identifiers(source_uuid, source_number) @@ -1610,14 +1648,18 @@ class SignalClient(ClientBase): target_author=str( (reaction_payload.get("raw") or {}).get("targetAuthorUuid") or (reaction_payload.get("raw") or {}).get("targetAuthor") - or (reaction_payload.get("raw") or {}).get("targetAuthorNumber") + or (reaction_payload.get("raw") or {}).get( + "targetAuthorNumber" + ) or "" ), remove=bool(reaction_payload.get("remove")), payload=reaction_payload.get("raw") or {}, ) except Exception as exc: - self.log.warning("signal raw reaction history apply failed: %s", exc) + self.log.warning( + "signal raw reaction history apply failed: %s", exc + ) try: await self.ur.xmpp.client.apply_external_reaction( identifier.user, @@ -1631,7 +1673,9 @@ class SignalClient(ClientBase): payload=reaction_payload.get("raw") or {}, ) except Exception as exc: - self.log.warning("signal raw reaction relay to XMPP failed: %s", exc) + self.log.warning( + "signal raw reaction relay to XMPP failed: %s", exc + ) transport.update_runtime_state( self.service, last_inbound_ok_ts=int(time.time() * 1000), @@ -1683,7 +1727,9 @@ class SignalClient(ClientBase): ) return - text = _extract_signal_text(payload, str(data_message.get("message") or "").strip()) + text = _extract_signal_text( + payload, str(data_message.get("message") or "").strip() + ) if not text: return @@ -1702,7 +1748,11 @@ class SignalClient(ClientBase): or envelope.get("timestamp") or ts ).strip() - sender_key = source_uuid or source_number or (identifiers[0].identifier if identifiers else "") + sender_key = ( + source_uuid + or source_number + or (identifiers[0].identifier if identifiers else "") + ) source_chat_id = source_number or source_uuid or sender_key reply_ref = reply_sync.extract_reply_ref(self.service, payload) diff --git a/core/clients/transport.py b/core/clients/transport.py index 6eb4e6d..f9a8212 100644 --- a/core/clients/transport.py +++ b/core/clients/transport.py @@ -927,7 +927,11 @@ async def send_reaction( service_key = _service_key(service) if _capability_checks_enabled() and not supports(service_key, "reactions"): reason = unsupported_reason(service_key, "reactions") - log.warning("capability-check failed service=%s feature=reactions: %s", service_key, reason) + log.warning( + "capability-check failed service=%s feature=reactions: %s", + service_key, + reason, + ) return False if not str(emoji or "").strip() and not remove: return False diff --git a/core/clients/whatsapp.py b/core/clients/whatsapp.py index c4287cb..be80b95 100644 --- a/core/clients/whatsapp.py +++ b/core/clients/whatsapp.py @@ -173,9 +173,7 @@ class WhatsAppClient(ClientBase): if db_dir: os.makedirs(db_dir, exist_ok=True) if db_dir and not os.access(db_dir, os.W_OK): - raise PermissionError( - f"session db directory is not writable: {db_dir}" - ) + raise PermissionError(f"session db directory is not writable: {db_dir}") except Exception as exc: self._publish_state( connected=False, @@ -772,9 +770,11 @@ class WhatsAppClient(ClientBase): command_id, { "ok": True, - "timestamp": int(result) - if isinstance(result, int) - else int(time.time() * 1000), + "timestamp": ( + int(result) + if isinstance(result, int) + else int(time.time() * 1000) + ), }, ) self.log.debug( @@ -1910,9 +1910,7 @@ class WhatsAppClient(ClientBase): jid_value = self._jid_to_identifier( self._pluck(group, "JID") or self._pluck(group, "jid") ) - identifier = ( - jid_value.split("@", 1)[0].strip() if jid_value else "" - ) + identifier = jid_value.split("@", 1)[0].strip() if jid_value else "" if not identifier: continue name = ( @@ -2362,12 +2360,22 @@ class WhatsAppClient(ClientBase): node = ( self._pluck(message_obj, "reactionMessage") or self._pluck(message_obj, "reaction_message") - or self._pluck(message_obj, "ephemeralMessage", "message", "reactionMessage") - or self._pluck(message_obj, "ephemeral_message", "message", "reaction_message") + or self._pluck( + message_obj, "ephemeralMessage", "message", "reactionMessage" + ) + or self._pluck( + message_obj, "ephemeral_message", "message", "reaction_message" + ) or self._pluck(message_obj, "viewOnceMessage", "message", "reactionMessage") - or self._pluck(message_obj, "view_once_message", "message", "reaction_message") - or self._pluck(message_obj, "viewOnceMessageV2", "message", "reactionMessage") - or self._pluck(message_obj, "view_once_message_v2", "message", "reaction_message") + or self._pluck( + message_obj, "view_once_message", "message", "reaction_message" + ) + or self._pluck( + message_obj, "viewOnceMessageV2", "message", "reactionMessage" + ) + or self._pluck( + message_obj, "view_once_message_v2", "message", "reaction_message" + ) or self._pluck( message_obj, "viewOnceMessageV2Extension", @@ -2410,7 +2418,9 @@ class WhatsAppClient(ClientBase): explicit_remove = self._pluck(node, "remove") or self._pluck(node, "isRemove") if explicit_remove is None: explicit_remove = self._pluck(node, "is_remove") - remove = bool(explicit_remove) if explicit_remove is not None else bool(not emoji) + remove = ( + bool(explicit_remove) if explicit_remove is not None else bool(not emoji) + ) if not target_msg_id: return None return { @@ -2418,7 +2428,11 @@ class WhatsAppClient(ClientBase): "target_message_id": target_msg_id, "remove": remove, "target_ts": int(target_ts or 0), - "raw": self._proto_to_dict(node) or dict(node or {}) if isinstance(node, dict) else {}, + "raw": ( + self._proto_to_dict(node) or dict(node or {}) + if isinstance(node, dict) + else {} + ), } async def _download_event_media(self, event): @@ -2760,7 +2774,9 @@ class WhatsAppClient(ClientBase): or self._pluck(msg_obj, "MessageContextInfo") or {}, "message": { - "extendedTextMessage": self._pluck(msg_obj, "extendedTextMessage") + "extendedTextMessage": self._pluck( + msg_obj, "extendedTextMessage" + ) or self._pluck(msg_obj, "ExtendedTextMessage") or {}, "imageMessage": self._pluck(msg_obj, "imageMessage") or {}, @@ -2768,9 +2784,9 @@ class WhatsAppClient(ClientBase): "videoMessage": self._pluck(msg_obj, "videoMessage") or {}, "VideoMessage": self._pluck(msg_obj, "VideoMessage") or {}, "documentMessage": self._pluck(msg_obj, "documentMessage") - or {}, + or {}, "DocumentMessage": self._pluck(msg_obj, "DocumentMessage") - or {}, + or {}, "ephemeralMessage": self._pluck(msg_obj, "ephemeralMessage") or {}, "EphemeralMessage": self._pluck(msg_obj, "EphemeralMessage") @@ -2814,7 +2830,9 @@ class WhatsAppClient(ClientBase): or {}, "viewOnceMessage": self._pluck(msg_obj, "viewOnceMessage") or {}, - "viewOnceMessageV2": self._pluck(msg_obj, "viewOnceMessageV2") + "viewOnceMessageV2": self._pluck( + msg_obj, "viewOnceMessageV2" + ) or {}, "viewOnceMessageV2Extension": self._pluck( msg_obj, "viewOnceMessageV2Extension" @@ -2840,12 +2858,12 @@ class WhatsAppClient(ClientBase): reply_sync.extract_origin_tag(payload), ) if self._chat_matches_reply_debug(chat): - info_obj = self._proto_to_dict(self._pluck(event_obj, "Info")) or self._pluck( - event_obj, "Info" - ) - raw_obj = self._proto_to_dict(self._pluck(event_obj, "Raw")) or self._pluck( - event_obj, "Raw" - ) + info_obj = self._proto_to_dict( + self._pluck(event_obj, "Info") + ) or self._pluck(event_obj, "Info") + raw_obj = self._proto_to_dict( + self._pluck(event_obj, "Raw") + ) or self._pluck(event_obj, "Raw") message_meta["wa_reply_debug"] = { "reply_ref": reply_ref, "reply_target_id": str(getattr(reply_target, "id", "") or ""), @@ -3087,9 +3105,11 @@ class WhatsAppClient(ClientBase): ) matched = False for candidate in candidates: - candidate_local = str(self._jid_to_identifier(candidate) or "").split( - "@", 1 - )[0].strip() + candidate_local = ( + str(self._jid_to_identifier(candidate) or "") + .split("@", 1)[0] + .strip() + ) if candidate_local and candidate_local == local: matched = True break @@ -3124,7 +3144,12 @@ class WhatsAppClient(ClientBase): # WhatsApp group ids are numeric and usually very long (commonly start # with 120...). Treat those as groups when no explicit mapping exists. digits = re.sub(r"[^0-9]", "", local) - if digits and digits == local and len(digits) >= 15 and digits.startswith("120"): + if ( + digits + and digits == local + and len(digits) >= 15 + and digits.startswith("120") + ): return f"{digits}@g.us" return "" @@ -3264,7 +3289,9 @@ class WhatsAppClient(ClientBase): person_identifier = await sync_to_async( lambda: ( Message.objects.filter(id=legacy_message_id) - .select_related("session__identifier__user", "session__identifier__person") + .select_related( + "session__identifier__user", "session__identifier__person" + ) .first() ) )() @@ -3274,7 +3301,9 @@ class WhatsAppClient(ClientBase): ) if ( person_identifier is not None - and str(getattr(person_identifier, "service", "") or "").strip().lower() + and str(getattr(person_identifier, "service", "") or "") + .strip() + .lower() != "whatsapp" ): person_identifier = None @@ -3418,6 +3447,8 @@ class WhatsAppClient(ClientBase): from neonize.proto.waE2E.WAWebProtobufsE2E_pb2 import ( ContextInfo, ExtendedTextMessage, + ) + from neonize.proto.waE2E.WAWebProtobufsE2E_pb2 import ( Message as WAProtoMessage, ) diff --git a/core/clients/xmpp.py b/core/clients/xmpp.py index 9a87176..3d6cbd0 100644 --- a/core/clients/xmpp.py +++ b/core/clients/xmpp.py @@ -1,7 +1,7 @@ import asyncio import base64 import json -import mimetypes +import logging import os import re import time @@ -17,6 +17,8 @@ from slixmpp.componentxmpp import ComponentXMPP from slixmpp.plugins.xep_0085.stanza import Active, Composing, Gone, Inactive, Paused from slixmpp.stanza import Message from slixmpp.xmlstream import register_stanza_plugin +from slixmpp.xmlstream.handler import Callback +from slixmpp.xmlstream.matcher import StanzaPath from slixmpp.xmlstream.stanzabase import ET from core.clients import ClientBase, transport @@ -55,6 +57,9 @@ EMOJI_ONLY_PATTERN = re.compile( r"^[\U0001F300-\U0001FAFF\u2600-\u27BF\uFE0F\u200D\u2640-\u2642\u2764]+$" ) TOTP_BASE32_SECRET_RE = re.compile(r"^[A-Z2-7]{16,}$") +PUBSUB_NS = "http://jabber.org/protocol/pubsub" +OMEMO_OLD_NS = "eu.siacs.conversations.axolotl" +OMEMO_OLD_DEVICELIST_NODE = "eu.siacs.conversations.axolotl.devicelist" def _clean_url(value): @@ -147,6 +152,7 @@ def _parse_greentext_reaction(body_text): def _omemo_plugin_available() -> bool: try: import importlib + return importlib.util.find_spec("slixmpp_omemo") is not None except Exception: return False @@ -166,15 +172,30 @@ def _extract_sender_omemo_client_key(stanza) -> dict: return {"status": "no_omemo"} +def _format_omemo_identity_fingerprint(identity_key) -> str: + if isinstance(identity_key, (bytes, bytearray)): + key_bytes = bytes(identity_key) + else: + return "" + try: + from omemo.session_manager import SessionManager as _OmemoSessionManager + + return " ".join(_OmemoSessionManager.format_identity_key(key_bytes)).upper() + except Exception: + return ":".join(f"{b:02X}" for b in key_bytes) + + # --------------------------------------------------------------------------- # OMEMO storage + plugin implementation # --------------------------------------------------------------------------- try: - from omemo.storage import Just, Maybe, Nothing, Storage as _OmemoStorageBase + from omemo.storage import Just, Maybe, Nothing + from omemo.storage import Storage as _OmemoStorageBase + from slixmpp.plugins.base import register_plugin as _slixmpp_register_plugin from slixmpp_omemo import XEP_0384 as _XEP_0384Base from slixmpp_omemo.base_session_manager import TrustLevel as _OmemoTrustLevel - from slixmpp.plugins.base import register_plugin as _slixmpp_register_plugin + _OMEMO_AVAILABLE = True except ImportError: _OMEMO_AVAILABLE = False @@ -185,6 +206,7 @@ except ImportError: if _OMEMO_AVAILABLE: + class _OmemoStorage(_OmemoStorageBase): """JSON-file-backed OMEMO key storage.""" @@ -224,7 +246,14 @@ if _OMEMO_AVAILABLE: name = "xep_0384" description = "OMEMO Encryption (GIA gateway)" - dependencies = {"xep_0004", "xep_0030", "xep_0060", "xep_0163", "xep_0280", "xep_0334"} + dependencies = { + "xep_0004", + "xep_0030", + "xep_0060", + "xep_0163", + "xep_0280", + "xep_0334", + } default_config = { "fallback_message": "This message is OMEMO encrypted.", "data_dir": "", @@ -248,6 +277,7 @@ if _OMEMO_AVAILABLE: async def _devices_blindly_trusted(self, blindly_trusted, identifier): import logging + logging.getLogger(__name__).info( "OMEMO: blindly trusted %d new device(s)", len(blindly_trusted) ) @@ -255,6 +285,7 @@ if _OMEMO_AVAILABLE: async def _prompt_manual_trust(self, manually_trusted, identifier): """Auto-trust all undecided devices (gateway mode).""" import logging + log = logging.getLogger(__name__) log.info( "OMEMO: auto-trusting %d undecided device(s) (gateway mode)", @@ -270,11 +301,12 @@ if _OMEMO_AVAILABLE: _OmemoTrustLevel.BLINDLY_TRUSTED.value, ) except Exception as exc: - log.warning("OMEMO set_trust failed for %s: %s", device.bare_jid, exc) + log.warning( + "OMEMO set_trust failed for %s: %s", device.bare_jid, exc + ) class XMPPComponent(ComponentXMPP): - """ A simple Slixmpp component that echoes messages. """ @@ -289,6 +321,7 @@ class XMPPComponent(ComponentXMPP): self._session_live = False self.log = logs.get_logger("XMPP") + logging.getLogger("slixmpp_omemo").setLevel(logging.DEBUG) super().__init__(jid, secret, server, port) # Enable message IDs so the OMEMO plugin can associate encrypted stanzas. @@ -305,6 +338,20 @@ class XMPPComponent(ComponentXMPP): self.add_event_handler("session_start", self.session_start) self.add_event_handler("disconnected", self.on_disconnected) self.add_event_handler("message", self.message) + self.register_handler( + Callback( + "OMEMOPubSubItemsGet", + StanzaPath("iq/pubsub/items"), + self._handle_pubsub_iq_get, + ) + ) + self.register_handler( + Callback( + "OMEMOPubSubPublishSet", + StanzaPath("iq/pubsub/publish"), + self._handle_pubsub_iq_set, + ) + ) # Presence event handlers self.add_event_handler("presence_available", self.on_presence_available) @@ -327,6 +374,137 @@ class XMPPComponent(ComponentXMPP): self.add_event_handler("chatstate_paused", self.on_chatstate_paused) self.add_event_handler("chatstate_inactive", self.on_chatstate_inactive) self.add_event_handler("chatstate_gone", self.on_chatstate_gone) + self._omemo_component_pubsub = self._seed_component_omemo_pubsub() + + @staticmethod + def _clone_xml(node): + return ET.fromstring(ET.tostring(node, encoding="unicode")) + + def _seed_component_omemo_pubsub(self): + return {} + + def _sync_component_omemo_pubsub(self, device_id, identity_key): + device_id = str(device_id or "").strip() + if not device_id: + return + if isinstance(identity_key, (bytes, bytearray)): + key_bytes = bytes(identity_key) + else: + key_bytes = str(identity_key or "").encode() + identity_b64 = base64.b64encode(key_bytes).decode() + + devicelist = ET.Element(f"{{{OMEMO_OLD_NS}}}list") + dev = ET.SubElement(devicelist, f"{{{OMEMO_OLD_NS}}}device") + dev.set("id", device_id) + + bundle = ET.Element(f"{{{OMEMO_OLD_NS}}}bundle") + ik = ET.SubElement(bundle, f"{{{OMEMO_OLD_NS}}}identityKey") + ik.text = identity_b64 + + self._omemo_component_pubsub = { + OMEMO_OLD_DEVICELIST_NODE: { + "item_id": "current", + "payload": devicelist, + }, + f"eu.siacs.conversations.axolotl.bundles:{device_id}": { + "item_id": "current", + "payload": bundle, + }, + } + + async def _reply_pubsub_items(self, iq, node_name): + reply = iq.reply() + pubsub = ET.Element(f"{{{PUBSUB_NS}}}pubsub") + items = ET.SubElement(pubsub, f"{{{PUBSUB_NS}}}items") + items.set("node", node_name) + stored = self._omemo_component_pubsub.get(node_name) + if stored is None and node_name.startswith( + "eu.siacs.conversations.axolotl.bundles:" + ): + bundle_nodes = [ + key + for key in self._omemo_component_pubsub + if key.startswith("eu.siacs.conversations.axolotl.bundles:") + ] + if len(bundle_nodes) == 1: + stored = self._omemo_component_pubsub.get(bundle_nodes[0]) + if stored: + item = ET.SubElement(items, f"{{{PUBSUB_NS}}}item") + item.set("id", str(stored.get("item_id") or "current")) + payload = stored.get("payload") + if payload is not None: + item.append(self._clone_xml(payload)) + reply.append(pubsub) + await reply.send() + + async def on_iq_get(self, iq): + try: + iq_to = str(iq["to"] or "").split("/", 1)[0].strip().lower() + except Exception: + iq_to = "" + own_jid = str(getattr(self.boundjid, "bare", "") or "").strip().lower() + if iq_to and own_jid and iq_to != own_jid: + return + items = iq.xml.find(f".//{{{PUBSUB_NS}}}items") + if items is None: + return + node_name = str(items.attrib.get("node") or "").strip() + if not node_name: + return + self.log.debug( + "OMEMO IQ-GET pubsub items node=%s from=%s to=%s", + node_name, + iq.get("from"), + iq.get("to"), + ) + if ( + node_name.startswith("eu.siacs.conversations.axolotl.bundles:") + or node_name == OMEMO_OLD_DEVICELIST_NODE + or node_name == "urn:xmpp:omemo:2:devices" + or node_name == "urn:xmpp:omemo:2:bundles" + ): + await self._reply_pubsub_items(iq, node_name) + + async def on_iq_set(self, iq): + try: + iq_to = str(iq["to"] or "").split("/", 1)[0].strip().lower() + except Exception: + iq_to = "" + own_jid = str(getattr(self.boundjid, "bare", "") or "").strip().lower() + if iq_to and own_jid and iq_to != own_jid: + return + publish = iq.xml.find(f".//{{{PUBSUB_NS}}}publish") + if publish is None: + return + node_name = str(publish.attrib.get("node") or "").strip() + if not node_name: + return + self.log.debug( + "OMEMO IQ-SET pubsub publish node=%s from=%s to=%s", + node_name, + iq.get("from"), + iq.get("to"), + ) + item = publish.find(f"{{{PUBSUB_NS}}}item") + payload = None + item_id = "current" + if item is not None: + item_id = str(item.attrib.get("id") or "current") + for child in list(item): + payload = child + break + if payload is not None: + self._omemo_component_pubsub[node_name] = { + "item_id": item_id, + "payload": self._clone_xml(payload), + } + await iq.reply().send() + + def _handle_pubsub_iq_get(self, iq): + asyncio.create_task(self.on_iq_get(iq)) + + def _handle_pubsub_iq_set(self, iq): + asyncio.create_task(self.on_iq_set(iq)) def _user_xmpp_domain(self): domain = str(getattr(settings, "XMPP_USER_DOMAIN", "") or "").strip() @@ -354,8 +532,6 @@ class XMPPComponent(ComponentXMPP): self.log.error(f"Failed to enable Carbons: {e}") def get_identifier(self, msg): - xmpp_message_id = str(msg.get("id") or "").strip() - # Extract sender JID (full format: user@domain/resource) sender_jid = str(msg["from"]) @@ -445,6 +621,7 @@ class XMPPComponent(ComponentXMPP): def _derived_omemo_fingerprint(self, jid: str) -> str: import hashlib + return hashlib.sha256(f"xmpp-omemo-key:{jid}".encode()).hexdigest()[:32] def _get_omemo_plugin(self): @@ -456,26 +633,198 @@ class XMPPComponent(ComponentXMPP): async def _bootstrap_omemo_for_authentic_channel(self): jid = str(getattr(settings, "XMPP_JID", "") or "").strip() + omemo_plugin = self._get_omemo_plugin() omemo_enabled = omemo_plugin is not None status = "active" if omemo_enabled else "not_available" - reason = "OMEMO plugin active" if omemo_enabled else "xep_0384 plugin not loaded" + reason = ( + "OMEMO plugin active" if omemo_enabled else "xep_0384 plugin not loaded" + ) fingerprint = self._derived_omemo_fingerprint(jid) if omemo_enabled: try: import asyncio as _asyncio + + # OMEMO session-manager bootstrap can take longer in component mode because + # it performs consistency checks and initial pubsub interactions before + # device publication is complete. session_manager = await _asyncio.wait_for( - omemo_plugin.get_session_manager(), timeout=15.0 + omemo_plugin.get_session_manager(), timeout=180.0 ) own_devices = await session_manager.get_own_device_information() + device_id = None + if own_devices: - key_bytes = own_devices[0].identity_key - fingerprint = ":".join(f"{b:02X}" for b in key_bytes) + # own_devices is a tuple: (own_device, other_devices) + own_device = ( + own_devices[0] + if isinstance(own_devices, (tuple, list)) + else own_devices + ) + key_bytes = own_device.identity_key + try: + from omemo.session_manager import ( + SessionManager as _OmemoSessionManager, + ) + + fingerprint = " ".join( + _OmemoSessionManager.format_identity_key(key_bytes) + ).upper() + except Exception: + fingerprint = ":".join(f"{b:02X}" for b in key_bytes) + device_id = own_device.device_id + self._sync_component_omemo_pubsub(device_id, key_bytes) + self.log.info( + "OMEMO: own device created, device_id=%s fingerprint=%s", + device_id, + fingerprint, + ) + else: + # Fallback: session manager may not auto-create devices for component JIDs + # Manually generate a device ID and use it + import random + + device_id = random.randint(1, 2**31 - 1) + self.log.warning( + "OMEMO: session manager did not create device (component JID limitation), " + "using fallback device_id=%s", + device_id, + ) + + # CRITICAL FIX: Publish BOTH device list AND device bundle + # Clients need both: + # 1. Device list node: eu.siacs.conversations.axolotl/devices/{jid} (lists device IDs) + # 2. Device bundle nodes: eu.siacs.conversations.axolotl/bundles:{device_id} (contains keys) + # Without bundles, clients can't encrypt (they hang in "pending...") + if device_id: + try: + namespace = "eu.siacs.conversations.axolotl" + pubsub_service = ( + str(getattr(settings, "XMPP_PUBSUB_SERVICE", "pubsub")) + or "pubsub" + ) + + # Step 1: Publish device list + try: + device_list_dict = {device_id: None} + await session_manager._upload_device_list( + namespace, device_list_dict + ) + self.log.info( + "OMEMO: device list uploaded via session manager for %s", + jid, + ) + except (AttributeError, TypeError): + # Fallback: Manual publish via XEP-0060 + try: + device_item = ET.Element("list") + device_item.set("xmlns", namespace) + for dev_id in device_list_dict: + device_elem = ET.SubElement(device_item, "device") + device_elem.set("id", str(dev_id)) + + node_name = f"{namespace}/devices/{jid}" + await self["xep_0060"].publish( + pubsub_service, node_name, payload=device_item + ) + self.log.info( + "OMEMO: device list published via XEP-0060 for %s", + jid, + ) + except Exception as e: + self.log.warning( + "OMEMO: device list publish failed: %s", e + ) + + # Step 2: Publish device bundle (CRITICAL - contains the keys!) + # This is what was missing - clients couldn't find the keys + try: + if own_devices: + # Get actual identity key from device + own_device = ( + own_devices[0] + if isinstance(own_devices, (tuple, list)) + else own_devices + ) + identity_key = own_device.identity_key + signed_prekey = getattr( + own_device, "signed_prekey", None + ) + + # Build bundle item with actual keys + bundle_item = ET.Element("bundle") + bundle_item.set("xmlns", namespace) + + # Add signed prekey + if signed_prekey: + spk_elem = ET.SubElement( + bundle_item, "signedPreKeyPublic" + ) + spk_elem.text = signed_prekey + else: + # Fallback: use hash of identity key + import base64 + + spk_elem = ET.SubElement( + bundle_item, "signedPreKeyPublic" + ) + spk_elem.text = base64.b64encode( + identity_key + ).decode() + + # Add identity key + ik_elem = ET.SubElement(bundle_item, "identityKey") + import base64 + + ik_elem.text = base64.b64encode(identity_key).decode() + + # Publish bundle + bundle_node = f"{namespace}/bundles:{device_id}" + await self["xep_0060"].publish( + pubsub_service, bundle_node, payload=bundle_item + ) + self.log.info( + "OMEMO: device bundle published for %s (device_id=%s)", + jid, + device_id, + ) + except Exception as bundle_exc: + self.log.warning( + "OMEMO: device bundle publish failed: %s", bundle_exc + ) + + except Exception as upload_exc: + self.log.warning( + "OMEMO: device/bundle upload error: %s", upload_exc + ) + + # Try to refresh device list to ensure all devices are properly registered. + try: + await session_manager.refresh_device_lists(jid) + self.log.info("OMEMO: device list refreshed for %s", jid) + except Exception as refresh_exc: + self.log.debug("OMEMO: device list refresh: %s", refresh_exc) + except _asyncio.TimeoutError: + self.log.error( + "OMEMO: session manager initialization timeout after 180s. " + "Device list may not be published to PubSub. " + "Clients will not be able to discover gateway devices. " + "Check PubSub server connectivity and latency." + ) + status = "timeout" + reason = ( + "Session manager initialization timeout - device list not published" + ) except Exception as exc: - self.log.warning("OMEMO: could not read own device fingerprint: %s", exc) + self.log.warning( + "OMEMO: could not initialize device information: %s", exc + ) self.log.info( "OMEMO bootstrap: jid=%s enabled=%s status=%s fingerprint=%s", - jid, omemo_enabled, status, fingerprint, + jid, + omemo_enabled, + status, + fingerprint, ) transport.update_runtime_state( "xmpp", @@ -486,19 +835,41 @@ class XMPPComponent(ComponentXMPP): omemo_status_reason=reason, ) - async def _record_sender_omemo_state(self, user, *, sender_jid, recipient_jid, message_stanza): + async def _record_sender_omemo_state( + self, + user, + *, + sender_jid, + recipient_jid, + message_stanza, + sender_fingerprint="", + ): parsed = _extract_sender_omemo_client_key(message_stanza) status = str(parsed.get("status") or "no_omemo") client_key = str(parsed.get("client_key") or "") - await sync_to_async(UserXmppOmemoState.objects.update_or_create)( - user=user, - defaults={ - "status": status, - "latest_client_key": client_key, - "last_sender_jid": str(sender_jid or ""), - "last_target_jid": str(recipient_jid or ""), - }, - ) + + def _save_row(): + row, _ = UserXmppOmemoState.objects.get_or_create(user=user) + details = dict(row.details or {}) + if sender_fingerprint: + details["latest_client_fingerprint"] = str(sender_fingerprint) + row.status = status + row.latest_client_key = client_key + row.last_sender_jid = str(sender_jid or "") + row.last_target_jid = str(recipient_jid or "") + row.details = details + row.save( + update_fields=[ + "status", + "latest_client_key", + "last_sender_jid", + "last_target_jid", + "details", + "updated_at", + ] + ) + + await sync_to_async(_save_row)() _approval_event_prefix = "codex_approval" @@ -525,7 +896,9 @@ class XMPPComponent(ComponentXMPP): run.status = "approved_waiting_resume" if status == "approved" else status await sync_to_async(run.save)(update_fields=["status"]) if request.external_sync_event_id: - evt = await sync_to_async(ExternalSyncEvent.objects.get)(pk=request.external_sync_event_id) + evt = await sync_to_async(ExternalSyncEvent.objects.get)( + pk=request.external_sync_event_id + ) evt.status = "ok" await sync_to_async(evt.save)(update_fields=["status"]) user = await sync_to_async(User.objects.get)(pk=request.user_id) @@ -547,9 +920,9 @@ class XMPPComponent(ComponentXMPP): async def _approval_list_pending(self, user, scope, sym): requests = await sync_to_async(list)( - CodexPermissionRequest.objects.filter( - user=user, status="pending" - ).order_by("-requested_at")[:20] + CodexPermissionRequest.objects.filter(user=user, status="pending").order_by( + "-requested_at" + )[:20] ) sym(f"pending={len(requests)}") for req in requests: @@ -557,9 +930,9 @@ class XMPPComponent(ComponentXMPP): async def _approval_status(self, user, approval_key, sym): try: - req = await sync_to_async( - CodexPermissionRequest.objects.get - )(user=user, approval_key=approval_key) + req = await sync_to_async(CodexPermissionRequest.objects.get)( + user=user, approval_key=approval_key + ) sym(f"status={req.status} key={req.approval_key}") except CodexPermissionRequest.DoesNotExist: sym(f"approval_key_not_found:{approval_key}") @@ -568,7 +941,7 @@ class XMPPComponent(ComponentXMPP): command = body.strip() for prefix, expected_provider in self._APPROVAL_PROVIDER_COMMANDS.items(): if command.startswith(prefix + " ") or command == prefix: - sub = command[len(prefix):].strip() + sub = command[len(prefix) :].strip() parts = sub.split() if len(parts) >= 2 and parts[0] in ("approve", "reject"): action, approval_key = parts[0], parts[1] @@ -583,7 +956,9 @@ class XMPPComponent(ComponentXMPP): return True provider = self._resolve_request_provider(req) if not provider.startswith(expected_provider): - sym(f"approval_key_not_for_provider:{approval_key} provider={provider}") + sym( + f"approval_key_not_for_provider:{approval_key} provider={provider}" + ) return True await self._apply_approval_decision(req, action, sym) sym(f"{action}d: {approval_key}") @@ -594,7 +969,7 @@ class XMPPComponent(ComponentXMPP): if not command.startswith(".approval"): return False - rest = command[len(".approval"):].strip() + rest = command[len(".approval") :].strip() if rest.split() and rest.split()[0] in ("approve", "reject"): parts = rest.split() @@ -617,12 +992,12 @@ class XMPPComponent(ComponentXMPP): return True if rest.startswith("list-pending"): - scope = rest[len("list-pending"):].strip() or "mine" + scope = rest[len("list-pending") :].strip() or "mine" await self._approval_list_pending(user, scope, sym) return True if rest.startswith("status "): - approval_key = rest[len("status "):].strip() + approval_key = rest[len("status ") :].strip() await self._approval_status(user, approval_key, sym) return True @@ -637,7 +1012,7 @@ class XMPPComponent(ComponentXMPP): command = body.strip() if not command.startswith(".tasks"): return False - rest = command[len(".tasks"):].strip() + rest = command[len(".tasks") :].strip() if rest.startswith("list"): parts = rest.split() @@ -656,7 +1031,7 @@ class XMPPComponent(ComponentXMPP): return True if rest.startswith("show "): - ref = rest[len("show "):].strip().lstrip("#") + ref = rest[len("show ") :].strip().lstrip("#") try: task = await sync_to_async(DerivedTask.objects.get)( user=user, reference_code=ref @@ -668,7 +1043,7 @@ class XMPPComponent(ComponentXMPP): return True if rest.startswith("complete "): - ref = rest[len("complete "):].strip().lstrip("#") + ref = rest[len("complete ") :].strip().lstrip("#") try: task = await sync_to_async(DerivedTask.objects.get)( user=user, reference_code=ref @@ -681,7 +1056,7 @@ class XMPPComponent(ComponentXMPP): return True if rest.startswith("undo "): - ref = rest[len("undo "):].strip().lstrip("#") + ref = rest[len("undo ") :].strip().lstrip("#") try: task = await sync_to_async(DerivedTask.objects.get)( user=user, reference_code=ref @@ -710,7 +1085,7 @@ class XMPPComponent(ComponentXMPP): query = parse_qs(parsed.query or "") return str((query.get("secret") or [""])[0] or "").strip() if lowered.startswith(".totp"): - rest = text[len(".totp"):].strip() + rest = text[len(".totp") :].strip() if not rest: return "" parts = rest.split(maxsplit=1) @@ -720,9 +1095,6 @@ class XMPPComponent(ComponentXMPP): if action in {"status", "help"}: return "" return rest - compact = text.replace(" ", "").strip().upper() - if TOTP_BASE32_SECRET_RE.match(compact): - return compact return "" async def _handle_totp_command(self, user, body, sym): @@ -763,11 +1135,7 @@ class XMPPComponent(ComponentXMPP): def _save_device(): from django_otp.plugins.otp_totp.models import TOTPDevice - device = ( - TOTPDevice.objects.filter(user=user) - .order_by("-id") - .first() - ) + device = TOTPDevice.objects.filter(user=user).order_by("-id").first() if device is None: device = TOTPDevice(user=user, name="gateway") device.key = key_bytes.hex() @@ -798,7 +1166,9 @@ class XMPPComponent(ComponentXMPP): command_text = str(body or "").strip() async def _contacts_handler(_ctx, emit): - persons = await sync_to_async(list)(Person.objects.filter(user=sender_user).order_by("name")) + persons = await sync_to_async(list)( + Person.objects.filter(user=sender_user).order_by("name") + ) if not persons: emit("No contacts found.") return True @@ -815,7 +1185,9 @@ class XMPPComponent(ComponentXMPP): return True async def _approval_handler(_ctx, emit): - return await self._handle_approval_command(sender_user, command_text, sender_jid, emit) + return await self._handle_approval_command( + sender_user, command_text, sender_jid, emit + ) async def _tasks_handler(_ctx, emit): return await self._handle_tasks_command(sender_user, command_text, emit) @@ -845,7 +1217,10 @@ class XMPPComponent(ComponentXMPP): GatewayCommandRoute( name="approval", scope_key="gateway.approval", - matcher=lambda text: str(text or "").strip().lower().startswith(".approval") + matcher=lambda text: str(text or "") + .strip() + .lower() + .startswith(".approval") or any( str(text or "").strip().lower().startswith(prefix + " ") or str(text or "").strip().lower() == prefix @@ -856,7 +1231,10 @@ class XMPPComponent(ComponentXMPP): GatewayCommandRoute( name="tasks", scope_key="gateway.tasks", - matcher=lambda text: str(text or "").strip().lower().startswith(".tasks"), + matcher=lambda text: str(text or "") + .strip() + .lower() + .startswith(".tasks"), handler=_tasks_handler, ), GatewayCommandRoute( @@ -1365,16 +1743,36 @@ class XMPPComponent(ComponentXMPP): def on_presence_subscribe(self, pres): """ Handle incoming presence subscription requests. - Accept only if the recipient has a contact matching the sender. + Accept subscriptions to: + 1. The gateway component itself (jews.zm.is) - for OMEMO device discovery + 2. User contacts at the gateway (user|service@jews.zm.is) - for user-to-user messaging """ sender_jid = str(pres["from"]).split("/")[0] # Bare JID (user@domain) recipient_jid = str(pres["to"]).split("/")[0] + component_jid = str(self.boundjid.bare) self.log.debug( f"Received subscription request from {sender_jid} to {recipient_jid}" ) + # Check if subscription is to the gateway component itself + if recipient_jid == component_jid: + # Auto-accept subscriptions to the gateway component (for OMEMO) + self.log.info( + "Auto-accepting subscription to gateway component from %s", sender_jid + ) + self.send_presence(ptype="subscribed", pto=sender_jid, pfrom=component_jid) + + # Send presence availability to enable device discovery + self.send_presence(ptype="available", pto=sender_jid, pfrom=component_jid) + self.log.info( + "Gateway component is available to %s (OMEMO device discovery enabled)", + sender_jid, + ) + return + + # Otherwise, handle user-to-user subscription (existing logic) try: # Extract sender and recipient usernames user_username, _ = sender_jid.split("@", 1) @@ -1400,24 +1798,24 @@ class XMPPComponent(ComponentXMPP): self.log.debug("Resolving subscription identifier service=%s", service) PersonIdentifier.objects.get(user=user, person=person, service=service) - component_jid = f"{person_name.lower()}|{service}@{self.boundjid.bare}" + contact_jid = f"{person_name.lower()}|{service}@{self.boundjid.bare}" # Accept the subscription - self.send_presence(ptype="subscribed", pto=sender_jid, pfrom=component_jid) + self.send_presence(ptype="subscribed", pto=sender_jid, pfrom=contact_jid) self.log.debug( - f"Accepted subscription from {sender_jid}, sent from {component_jid}" + f"Accepted subscription from {sender_jid}, sent from {contact_jid}" ) # Send a presence request **from the recipient to the sender** (ASKS THEM TO ACCEPT BACK) - # self.send_presence(ptype="subscribe", pto=sender_jid, pfrom=component_jid) + # self.send_presence(ptype="subscribe", pto=sender_jid, pfrom=contact_jid) # Add sender to roster # self.update_roster(sender_jid, name=sender_jid.split("@")[0]) # Send presence update to sender **from the correct JID** - self.send_presence(ptype="available", pto=sender_jid, pfrom=component_jid) + self.send_presence(ptype="available", pto=sender_jid, pfrom=contact_jid) self.log.debug( - "Sent presence update from %s to %s", component_jid, sender_jid + "Sent presence update from %s to %s", contact_jid, sender_jid ) except (User.DoesNotExist, Person.DoesNotExist, PersonIdentifier.DoesNotExist): @@ -1645,14 +2043,20 @@ class XMPPComponent(ComponentXMPP): # Attempt to decrypt OMEMO-encrypted messages before body extraction. original_msg = msg omemo_plugin = self._get_omemo_plugin() + sender_omemo_fingerprint = "" if omemo_plugin: try: if omemo_plugin.is_encrypted(msg): - decrypted, _ = await omemo_plugin.decrypt_message(msg) + decrypted, sender_device = await omemo_plugin.decrypt_message(msg) msg = decrypted + sender_omemo_fingerprint = _format_omemo_identity_fingerprint( + getattr(sender_device, "identity_key", b"") + ) self.log.debug("OMEMO: decrypted message from %s", sender_jid) except Exception as exc: - self.log.warning("OMEMO: decryption failed from %s: %s", sender_jid, exc) + self.log.warning( + "OMEMO: decryption failed from %s: %s", sender_jid, exc + ) # Extract message body body = msg["body"] if msg["body"] else "" @@ -1678,7 +2082,9 @@ class XMPPComponent(ComponentXMPP): or "application/octet-stream", ) except Exception as exc: - self.log.warning("xmpp dropped unsafe attachment url=%s: %s", url_value, exc) + self.log.warning( + "xmpp dropped unsafe attachment url=%s: %s", url_value, exc + ) continue attachments.append( { @@ -1723,7 +2129,9 @@ class XMPPComponent(ComponentXMPP): content_type=_content_type_from_filename_or_url(safe_url), ) except Exception as exc: - self.log.warning("xmpp dropped extracted unsafe url=%s: %s", url_value, exc) + self.log.warning( + "xmpp dropped extracted unsafe url=%s: %s", url_value, exc + ) continue attachments.append( { @@ -1787,6 +2195,7 @@ class XMPPComponent(ComponentXMPP): sender_jid=sender_jid, recipient_jid=recipient_jid, message_stanza=original_msg, + sender_fingerprint=sender_omemo_fingerprint, ) except Exception as exc: self.log.warning("OMEMO: failed to record sender state: %s", exc) @@ -1795,8 +2204,11 @@ class XMPPComponent(ComponentXMPP): # Enforce mandatory encryption policy. try: from core.models import UserXmppSecuritySettings + sec_settings = await sync_to_async( - lambda: UserXmppSecuritySettings.objects.filter(user=sender_user).first() + lambda: UserXmppSecuritySettings.objects.filter( + user=sender_user + ).first() )() if sec_settings and sec_settings.require_omemo: omemo_status = str(omemo_observation.get("status") or "") @@ -1812,7 +2224,7 @@ class XMPPComponent(ComponentXMPP): if recipient_jid == settings.XMPP_JID: self.log.debug("Handling command message sent to gateway JID") - if body.startswith(".") or self._extract_totp_secret_candidate(body): + if body.startswith("."): await self._route_gateway_command( sender_user=sender_user, body=body, @@ -1824,7 +2236,9 @@ class XMPPComponent(ComponentXMPP): "sender_jid": str(sender_jid or ""), "recipient_jid": str(recipient_jid or ""), "omemo_status": str(omemo_observation.get("status") or ""), - "omemo_client_key": str(omemo_observation.get("client_key") or ""), + "omemo_client_key": str( + omemo_observation.get("client_key") or "" + ), } }, sym=sym, @@ -2004,7 +2418,9 @@ class XMPPComponent(ComponentXMPP): "sender_jid": str(sender_jid or ""), "recipient_jid": str(recipient_jid or ""), "omemo_status": str(omemo_observation.get("status") or ""), - "omemo_client_key": str(omemo_observation.get("client_key") or ""), + "omemo_client_key": str( + omemo_observation.get("client_key") or "" + ), } }, ) @@ -2127,9 +2543,7 @@ class XMPPComponent(ComponentXMPP): f"Upload failed: {response.status} {await response.text()}" ) return None - self.log.debug( - "Successfully uploaded %s to %s", filename, upload_url - ) + self.log.debug("Successfully uploaded %s to %s", filename, upload_url) # Send XMPP message immediately after successful upload xmpp_msg_id = await self.send_xmpp_message( @@ -2145,7 +2559,13 @@ class XMPPComponent(ComponentXMPP): return None async def send_xmpp_message( - self, recipient_jid, sender_jid, body_text, attachment_url=None + self, + recipient_jid, + sender_jid, + body_text, + attachment_url=None, + *, + use_omemo_encryption=True, ): """Sends an XMPP message with either text or an attachment URL.""" msg = self.make_message(mto=recipient_jid, mfrom=sender_jid, mtype="chat") @@ -2163,29 +2583,35 @@ class XMPPComponent(ComponentXMPP): self.log.debug("Sending XMPP message: %s", msg.xml) - # Attempt OMEMO encryption for text-only messages (not attachments). - if not attachment_url: + # Attempt OMEMO encryption for text-only messages (not attachments) + # when outbound policy allows it. + if not attachment_url and use_omemo_encryption: omemo_plugin = self._get_omemo_plugin() if omemo_plugin: try: from slixmpp.jid import JID as _JID + encrypted_msgs, enc_errors = await omemo_plugin.encrypt_message( msg, _JID(recipient_jid) ) if enc_errors: self.log.debug( "OMEMO: non-critical encryption errors for %s: %s", - recipient_jid, enc_errors, + recipient_jid, + enc_errors, ) if encrypted_msgs: for enc_msg in encrypted_msgs.values(): enc_msg.send() - self.log.debug("OMEMO: sent encrypted message to %s", recipient_jid) + self.log.debug( + "OMEMO: sent encrypted message to %s", recipient_jid + ) return msg_id except Exception as exc: self.log.debug( "OMEMO: encryption not available for %s, sending plaintext: %s", - recipient_jid, exc, + recipient_jid, + exc, ) msg.send() @@ -2347,6 +2773,19 @@ class XMPPComponent(ComponentXMPP): sender_jid = f"{person_identifier.person.name.lower()}|{person_identifier.service}@{settings.XMPP_JID}" recipient_jid = self._user_jid(person_identifier.user.username) + relay_encrypt_with_omemo = True + try: + from core.models import UserXmppSecuritySettings + + sec_settings = await sync_to_async( + lambda: UserXmppSecuritySettings.objects.filter(user=user).first() + )() + if sec_settings is not None: + relay_encrypt_with_omemo = bool( + getattr(sec_settings, "encrypt_contact_messages_with_omemo", True) + ) + except Exception as exc: + self.log.warning("XMPP relay OMEMO settings lookup failed: %s", exc) if is_outgoing_message: xmpp_id = await self.send_xmpp_message( recipient_jid, @@ -2385,7 +2824,12 @@ class XMPPComponent(ComponentXMPP): # Step 1: Send text message separately elif text: - xmpp_id = await self.send_xmpp_message(recipient_jid, sender_jid, text) + xmpp_id = await self.send_xmpp_message( + recipient_jid, + sender_jid, + text, + use_omemo_encryption=relay_encrypt_with_omemo, + ) transport.record_bridge_mapping( user_id=user.id, person_id=person_identifier.person_id, @@ -2512,17 +2956,25 @@ class XMPPClient(ClientBase): self._omemo_plugin_registered = False if _OMEMO_AVAILABLE: try: - data_dir = str(getattr(settings, "XMPP_OMEMO_DATA_DIR", "") or "").strip() + data_dir = str( + getattr(settings, "XMPP_OMEMO_DATA_DIR", "") or "" + ).strip() if not data_dir: data_dir = str(Path(settings.BASE_DIR) / "xmpp_omemo_data") # Register our concrete plugin class under the "xep_0384" name so # that slixmpp's dependency resolver finds it. _slixmpp_register_plugin(_GiaOmemoPlugin) - self.client.register_plugin("xep_0384", pconfig={"data_dir": data_dir}) + self.client.register_plugin( + "xep_0384", pconfig={"data_dir": data_dir} + ) self._omemo_plugin_registered = True - self.log.info("OMEMO: xep_0384 plugin registered, data_dir=%s", data_dir) + self.log.info( + "OMEMO: xep_0384 plugin registered, data_dir=%s", data_dir + ) except Exception as exc: - self.log.warning("OMEMO: failed to register xep_0384 plugin: %s", exc) + self.log.warning( + "OMEMO: failed to register xep_0384 plugin: %s", exc + ) else: self.log.warning("OMEMO: slixmpp_omemo not available, OMEMO disabled") diff --git a/core/commands/delivery.py b/core/commands/delivery.py index eb37a88..a42002a 100644 --- a/core/commands/delivery.py +++ b/core/commands/delivery.py @@ -31,7 +31,9 @@ def chunk_for_transport(text: str, limit: int = 3000) -> list[str]: return [part for part in parts if part] -async def post_status_in_source(trigger_message: Message, text: str, origin_tag: str) -> bool: +async def post_status_in_source( + trigger_message: Message, text: str, origin_tag: str +) -> bool: service = str(trigger_message.source_service or "").strip().lower() if service not in STATUS_VISIBLE_SOURCE_SERVICES: return False @@ -76,9 +78,10 @@ async def post_to_channel_binding( channel_identifier = str(binding_channel_identifier or "").strip() if service == "web": session = None - if channel_identifier and channel_identifier == str( - trigger_message.source_chat_id or "" - ).strip(): + if ( + channel_identifier + and channel_identifier == str(trigger_message.source_chat_id or "").strip() + ): session = trigger_message.session if session is None and channel_identifier: session = await sync_to_async( @@ -99,7 +102,8 @@ async def post_to_channel_binding( ts=int(time.time() * 1000), custom_author="BOT", source_service="web", - source_chat_id=channel_identifier or str(trigger_message.source_chat_id or ""), + source_chat_id=channel_identifier + or str(trigger_message.source_chat_id or ""), message_meta={"origin_tag": origin_tag}, ) return True diff --git a/core/commands/engine.py b/core/commands/engine.py index 8d9f655..8643808 100644 --- a/core/commands/engine.py +++ b/core/commands/engine.py @@ -58,9 +58,15 @@ def _effective_bootstrap_scope( identifier = str(ctx.channel_identifier or "").strip() if service != "web": return service, identifier - session_identifier = getattr(getattr(trigger_message, "session", None), "identifier", None) - fallback_service = str(getattr(session_identifier, "service", "") or "").strip().lower() - fallback_identifier = str(getattr(session_identifier, "identifier", "") or "").strip() + session_identifier = getattr( + getattr(trigger_message, "session", None), "identifier", None + ) + fallback_service = ( + str(getattr(session_identifier, "service", "") or "").strip().lower() + ) + fallback_identifier = str( + getattr(session_identifier, "identifier", "") or "" + ).strip() if fallback_service and fallback_identifier and fallback_service != "web": return fallback_service, fallback_identifier return service, identifier @@ -89,7 +95,11 @@ def _ensure_bp_profile(user_id: int) -> CommandProfile: if str(profile.trigger_token or "").strip() != ".bp": profile.trigger_token = ".bp" profile.save(update_fields=["trigger_token", "updated_at"]) - for action_type, position in (("extract_bp", 0), ("save_document", 1), ("post_result", 2)): + for action_type, position in ( + ("extract_bp", 0), + ("save_document", 1), + ("post_result", 2), + ): action, created = CommandAction.objects.get_or_create( profile=profile, action_type=action_type, @@ -327,7 +337,9 @@ async def process_inbound_message(ctx: CommandContext) -> list[CommandResult]: return [] if is_mirrored_origin(trigger_message.message_meta): return [] - effective_service, effective_channel = _effective_bootstrap_scope(ctx, trigger_message) + effective_service, effective_channel = _effective_bootstrap_scope( + ctx, trigger_message + ) security_context = CommandSecurityContext( service=effective_service, channel_identifier=effective_channel, @@ -394,7 +406,9 @@ async def process_inbound_message(ctx: CommandContext) -> list[CommandResult]: result = await handler.execute(ctx) results.append(result) except Exception as exc: - log.exception("command execution failed for profile=%s: %s", profile.slug, exc) + log.exception( + "command execution failed for profile=%s: %s", profile.slug, exc + ) results.append( CommandResult( ok=False, diff --git a/core/commands/handlers/bp.py b/core/commands/handlers/bp.py index a238a27..2107143 100644 --- a/core/commands/handlers/bp.py +++ b/core/commands/handlers/bp.py @@ -45,14 +45,15 @@ class BPParsedCommand(dict): return str(self.get("remainder_text") or "") - def parse_bp_subcommand(text: str) -> BPParsedCommand: body = str(text or "") if _BP_SET_RANGE_RE.match(body): return BPParsedCommand(command="set_range", remainder_text="") match = _BP_SET_RE.match(body) if match: - return BPParsedCommand(command="set", remainder_text=str(match.group("rest") or "").strip()) + return BPParsedCommand( + command="set", remainder_text=str(match.group("rest") or "").strip() + ) return BPParsedCommand(command=None, remainder_text="") @@ -63,7 +64,9 @@ def bp_subcommands_enabled() -> bool: return bool(raw) -def bp_trigger_matches(message_text: str, trigger_token: str, exact_match_only: bool) -> bool: +def bp_trigger_matches( + message_text: str, trigger_token: str, exact_match_only: bool +) -> bool: body = str(message_text or "").strip() trigger = str(trigger_token or "").strip() parsed = parse_bp_subcommand(body) @@ -144,7 +147,8 @@ class BPCommandHandler(CommandHandler): "enabled": True, "generation_mode": "ai" if variant_key == "bp" else "verbatim", "send_plan_to_egress": "post_result" in action_types, - "send_status_to_source": str(profile.visibility_mode or "") == "status_in_source", + "send_status_to_source": str(profile.visibility_mode or "") + == "status_in_source", "send_status_to_egress": False, "store_document": True, } @@ -224,10 +228,14 @@ class BPCommandHandler(CommandHandler): ts__lte=int(trigger.ts or 0), ) .order_by("ts") - .select_related("session", "session__identifier", "session__identifier__person") + .select_related( + "session", "session__identifier", "session__identifier__person" + ) ) - def _annotation(self, mode: str, message_count: int, has_addendum: bool = False) -> str: + def _annotation( + self, mode: str, message_count: int, has_addendum: bool = False + ) -> str: if mode == "set" and has_addendum: return "Generated from 1 message + 1 addendum." if message_count == 1: @@ -291,21 +299,29 @@ class BPCommandHandler(CommandHandler): if anchor is None: run.status = "failed" run.error = "bp_set_range_requires_reply_target" - await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) + await sync_to_async(run.save)( + update_fields=["status", "error", "updated_at"] + ) return CommandResult(ok=False, status="failed", error=run.error) rows = await self._load_window(trigger, anchor) deterministic_content = plain_text_blob(rows) if not deterministic_content.strip(): run.status = "failed" run.error = "bp_set_range_empty_content" - await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) + await sync_to_async(run.save)( + update_fields=["status", "error", "updated_at"] + ) return CommandResult(ok=False, status="failed", error=run.error) if str(policy.get("generation_mode") or "verbatim") == "ai": - ai_obj = await sync_to_async(lambda: AI.objects.filter(user=trigger.user).first())() + ai_obj = await sync_to_async( + lambda: AI.objects.filter(user=trigger.user).first() + )() if ai_obj is None: run.status = "failed" run.error = "ai_not_configured" - await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) + await sync_to_async(run.save)( + update_fields=["status", "error", "updated_at"] + ) return CommandResult(ok=False, status="failed", error=run.error) prompt = [ { @@ -329,12 +345,16 @@ class BPCommandHandler(CommandHandler): except Exception as exc: run.status = "failed" run.error = f"bp_ai_failed:{exc}" - await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) + await sync_to_async(run.save)( + update_fields=["status", "error", "updated_at"] + ) return CommandResult(ok=False, status="failed", error=run.error) if not content: run.status = "failed" run.error = "empty_ai_response" - await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) + await sync_to_async(run.save)( + update_fields=["status", "error", "updated_at"] + ) return CommandResult(ok=False, status="failed", error=run.error) else: content = deterministic_content @@ -360,9 +380,7 @@ class BPCommandHandler(CommandHandler): elif anchor is not None and remainder: base = str(anchor.text or "").strip() or "(no text)" content = ( - f"{base}\n" - "--- Addendum (newer message text) ---\n" - f"{remainder}" + f"{base}\n" "--- Addendum (newer message text) ---\n" f"{remainder}" ) source_ids.extend([str(anchor.id), str(trigger.id)]) has_addendum = True @@ -373,15 +391,21 @@ class BPCommandHandler(CommandHandler): else: run.status = "failed" run.error = "bp_set_empty_content" - await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) + await sync_to_async(run.save)( + update_fields=["status", "error", "updated_at"] + ) return CommandResult(ok=False, status="failed", error=run.error) if str(policy.get("generation_mode") or "verbatim") == "ai": - ai_obj = await sync_to_async(lambda: AI.objects.filter(user=trigger.user).first())() + ai_obj = await sync_to_async( + lambda: AI.objects.filter(user=trigger.user).first() + )() if ai_obj is None: run.status = "failed" run.error = "ai_not_configured" - await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) + await sync_to_async(run.save)( + update_fields=["status", "error", "updated_at"] + ) return CommandResult(ok=False, status="failed", error=run.error) prompt = [ { @@ -405,16 +429,22 @@ class BPCommandHandler(CommandHandler): except Exception as exc: run.status = "failed" run.error = f"bp_ai_failed:{exc}" - await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) + await sync_to_async(run.save)( + update_fields=["status", "error", "updated_at"] + ) return CommandResult(ok=False, status="failed", error=run.error) if not ai_content: run.status = "failed" run.error = "empty_ai_response" - await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) + await sync_to_async(run.save)( + update_fields=["status", "error", "updated_at"] + ) return CommandResult(ok=False, status="failed", error=run.error) content = ai_content - annotation = self._annotation("set", 1 if not has_addendum else 2, has_addendum) + annotation = self._annotation( + "set", 1 if not has_addendum else 2, has_addendum + ) doc = None if bool(policy.get("store_document", True)): doc = await self._persist_document( @@ -430,7 +460,9 @@ class BPCommandHandler(CommandHandler): else: run.status = "failed" run.error = "bp_unknown_subcommand" - await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) + await sync_to_async(run.save)( + update_fields=["status", "error", "updated_at"] + ) return CommandResult(ok=False, status="failed", error=run.error) fanout_stats = {"sent_bindings": 0, "failed_bindings": 0} @@ -479,7 +511,9 @@ class BPCommandHandler(CommandHandler): if trigger.reply_to_id is None: run.status = "failed" run.error = "bp_requires_reply_target" - await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) + await sync_to_async(run.save)( + update_fields=["status", "error", "updated_at"] + ) return CommandResult(ok=False, status="failed", error=run.error) anchor = trigger.reply_to @@ -488,7 +522,9 @@ class BPCommandHandler(CommandHandler): rows, author_rewrites={"USER": "Operator", "BOT": "Assistant"}, ) - max_transcript_chars = int(getattr(settings, "BP_MAX_TRANSCRIPT_CHARS", 12000) or 12000) + max_transcript_chars = int( + getattr(settings, "BP_MAX_TRANSCRIPT_CHARS", 12000) or 12000 + ) transcript = _clamp_transcript(transcript, max_transcript_chars) default_template = ( "Business Plan:\n" @@ -499,7 +535,9 @@ class BPCommandHandler(CommandHandler): "- Risks" ) template_text = profile.template_text or default_template - max_template_chars = int(getattr(settings, "BP_MAX_TEMPLATE_CHARS", 5000) or 5000) + max_template_chars = int( + getattr(settings, "BP_MAX_TEMPLATE_CHARS", 5000) or 5000 + ) template_text = str(template_text or "")[:max_template_chars] generation_mode = str(policy.get("generation_mode") or "ai") if generation_mode == "verbatim": @@ -507,14 +545,20 @@ class BPCommandHandler(CommandHandler): if not summary.strip(): run.status = "failed" run.error = "bp_verbatim_empty_content" - await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) + await sync_to_async(run.save)( + update_fields=["status", "error", "updated_at"] + ) return CommandResult(ok=False, status="failed", error=run.error) else: - ai_obj = await sync_to_async(lambda: AI.objects.filter(user=trigger.user).first())() + ai_obj = await sync_to_async( + lambda: AI.objects.filter(user=trigger.user).first() + )() if ai_obj is None: run.status = "failed" run.error = "ai_not_configured" - await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) + await sync_to_async(run.save)( + update_fields=["status", "error", "updated_at"] + ) return CommandResult(ok=False, status="failed", error=run.error) prompt = [ @@ -530,13 +574,20 @@ class BPCommandHandler(CommandHandler): }, ] try: - summary = str(await ai_runner.run_prompt(prompt, ai_obj, operation="command_bp_extract") or "").strip() + summary = str( + await ai_runner.run_prompt( + prompt, ai_obj, operation="command_bp_extract" + ) + or "" + ).strip() if not summary: raise RuntimeError("empty_ai_response") except Exception as exc: run.status = "failed" run.error = f"bp_ai_failed:{exc}" - await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) + await sync_to_async(run.save)( + update_fields=["status", "error", "updated_at"] + ) return CommandResult(ok=False, status="failed", error=run.error) annotation = self._annotation("legacy", len(rows)) @@ -588,23 +639,31 @@ class BPCommandHandler(CommandHandler): async def execute(self, ctx: CommandContext) -> CommandResult: trigger = await sync_to_async( - lambda: Message.objects.select_related("user", "session").filter(id=ctx.message_id).first() + lambda: Message.objects.select_related("user", "session") + .filter(id=ctx.message_id) + .first() )() if trigger is None: return CommandResult(ok=False, status="failed", error="trigger_not_found") profile = await sync_to_async( - lambda: trigger.user.commandprofile_set.filter(slug=self.slug, enabled=True).first() + lambda: trigger.user.commandprofile_set.filter( + slug=self.slug, enabled=True + ).first() )() if profile is None: return CommandResult(ok=False, status="skipped", error="profile_missing") actions = await sync_to_async(list)( - CommandAction.objects.filter(profile=profile, enabled=True).order_by("position", "id") + CommandAction.objects.filter(profile=profile, enabled=True).order_by( + "position", "id" + ) ) action_types = {row.action_type for row in actions} if "extract_bp" not in action_types: - return CommandResult(ok=False, status="skipped", error="extract_bp_disabled") + return CommandResult( + ok=False, status="skipped", error="extract_bp_disabled" + ) run, created = await sync_to_async(CommandRun.objects.get_or_create)( profile=profile, @@ -612,7 +671,11 @@ class BPCommandHandler(CommandHandler): defaults={"user": trigger.user, "status": "running"}, ) if not created and run.status in {"ok", "running"}: - return CommandResult(ok=True, status="ok", payload={"document_id": str(run.result_ref_id or "")}) + return CommandResult( + ok=True, + status="ok", + payload={"document_id": str(run.result_ref_id or "")}, + ) run.status = "running" run.error = "" @@ -627,7 +690,9 @@ class BPCommandHandler(CommandHandler): if not bool(policy.get("enabled")): run.status = "skipped" run.error = f"variant_disabled:{variant_key}" - await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) + await sync_to_async(run.save)( + update_fields=["status", "error", "updated_at"] + ) return CommandResult(ok=False, status="skipped", error=run.error) parsed = parse_bp_subcommand(ctx.message_text) diff --git a/core/commands/handlers/claude.py b/core/commands/handlers/claude.py index ff1cb0a..4bed6ea 100644 --- a/core/commands/handlers/claude.py +++ b/core/commands/handlers/claude.py @@ -20,8 +20,8 @@ from core.models import ( TaskProject, TaskProviderConfig, ) -from core.tasks.codex_support import channel_variants, resolve_external_chat_id from core.tasks.codex_approval import queue_codex_event_with_pre_approval +from core.tasks.codex_support import channel_variants, resolve_external_chat_id _CLAUDE_DEFAULT_RE = re.compile( r"^\s*(?:\.claude\b|#claude#?)(?P.*)$", @@ -31,7 +31,9 @@ _CLAUDE_PLAN_RE = re.compile( r"^\s*(?:\.claude\s+plan\b|#claude\s+plan#?)(?P.*)$", re.IGNORECASE | re.DOTALL, ) -_CLAUDE_STATUS_RE = re.compile(r"^\s*(?:\.claude\s+status\b|#claude\s+status#?)\s*$", re.IGNORECASE) +_CLAUDE_STATUS_RE = re.compile( + r"^\s*(?:\.claude\s+status\b|#claude\s+status#?)\s*$", re.IGNORECASE +) _CLAUDE_APPROVE_DENY_RE = re.compile( r"^\s*(?:\.claude|#claude)\s+(?Papprove|deny)\s+(?P[A-Za-z0-9._:-]+)#?\s*$", re.IGNORECASE, @@ -83,7 +85,9 @@ def parse_claude_command(text: str) -> ClaudeParsedCommand: return ClaudeParsedCommand(command=None, body_text="", approval_key="") -def claude_trigger_matches(message_text: str, trigger_token: str, exact_match_only: bool) -> bool: +def claude_trigger_matches( + message_text: str, trigger_token: str, exact_match_only: bool +) -> bool: body = str(message_text or "").strip() parsed = parse_claude_command(body) if parsed.command: @@ -103,7 +107,9 @@ class ClaudeCommandHandler(CommandHandler): async def _load_trigger(self, message_id: str) -> Message | None: return await sync_to_async( - lambda: Message.objects.select_related("user", "session", "session__identifier", "reply_to") + lambda: Message.objects.select_related( + "user", "session", "session__identifier", "reply_to" + ) .filter(id=message_id) .first() )() @@ -114,11 +120,18 @@ class ClaudeCommandHandler(CommandHandler): identifier = getattr(getattr(trigger, "session", None), "identifier", None) fallback_service = str(getattr(identifier, "service", "") or "").strip().lower() fallback_identifier = str(getattr(identifier, "identifier", "") or "").strip() - if service == "web" and fallback_service and fallback_identifier and fallback_service != "web": + if ( + service == "web" + and fallback_service + and fallback_identifier + and fallback_service != "web" + ): return fallback_service, fallback_identifier return service or "web", channel - async def _mapped_sources(self, user, service: str, channel: str) -> list[ChatTaskSource]: + async def _mapped_sources( + self, user, service: str, channel: str + ) -> list[ChatTaskSource]: variants = channel_variants(service, channel) if not variants: return [] @@ -131,7 +144,9 @@ class ClaudeCommandHandler(CommandHandler): ).select_related("project", "epic") ) - async def _linked_task_from_reply(self, user, reply_to: Message | None) -> DerivedTask | None: + async def _linked_task_from_reply( + self, user, reply_to: Message | None + ) -> DerivedTask | None: if reply_to is None: return None by_origin = await sync_to_async( @@ -143,7 +158,9 @@ class ClaudeCommandHandler(CommandHandler): if by_origin is not None: return by_origin return await sync_to_async( - lambda: DerivedTask.objects.filter(user=user, events__source_message=reply_to) + lambda: DerivedTask.objects.filter( + user=user, events__source_message=reply_to + ) .select_related("project", "epic") .order_by("-created_at") .first() @@ -164,10 +181,14 @@ class ClaudeCommandHandler(CommandHandler): return "" return str(m.group(1) or "").strip() - async def _resolve_task(self, user, reference_code: str, reply_task: DerivedTask | None) -> DerivedTask | None: + async def _resolve_task( + self, user, reference_code: str, reply_task: DerivedTask | None + ) -> DerivedTask | None: if reference_code: return await sync_to_async( - lambda: DerivedTask.objects.filter(user=user, reference_code=reference_code) + lambda: DerivedTask.objects.filter( + user=user, reference_code=reference_code + ) .select_related("project", "epic") .order_by("-created_at") .first() @@ -190,7 +211,9 @@ class ClaudeCommandHandler(CommandHandler): return reply_task.project, "" if project_token: project = await sync_to_async( - lambda: TaskProject.objects.filter(user=user, name__iexact=project_token).first() + lambda: TaskProject.objects.filter( + user=user, name__iexact=project_token + ).first() )() if project is not None: return project, "" @@ -199,20 +222,31 @@ class ClaudeCommandHandler(CommandHandler): mapped = await self._mapped_sources(user, service, channel) project_ids = sorted({str(row.project_id) for row in mapped if row.project_id}) if len(project_ids) == 1: - project = next((row.project for row in mapped if str(row.project_id) == project_ids[0]), None) + project = next( + ( + row.project + for row in mapped + if str(row.project_id) == project_ids[0] + ), + None, + ) return project, "" if len(project_ids) > 1: return None, "project_required:[project:Name]" return None, "project_unresolved" - async def _post_source_status(self, trigger: Message, text: str, suffix: str) -> None: + async def _post_source_status( + self, trigger: Message, text: str, suffix: str + ) -> None: await post_status_in_source( trigger_message=trigger, text=text, origin_tag=f"claude-status:{suffix}", ) - async def _run_status(self, trigger: Message, service: str, channel: str, project: TaskProject | None) -> CommandResult: + async def _run_status( + self, trigger: Message, service: str, channel: str, project: TaskProject | None + ) -> CommandResult: def _load_runs(): qs = CodexRun.objects.filter(user=trigger.user) if service: @@ -225,7 +259,9 @@ class ClaudeCommandHandler(CommandHandler): runs = await sync_to_async(_load_runs)() if not runs: - await self._post_source_status(trigger, "[claude] no recent runs for this scope.", "empty") + await self._post_source_status( + trigger, "[claude] no recent runs for this scope.", "empty" + ) return CommandResult(ok=True, status="ok", payload={"count": 0}) lines = ["[claude] recent runs:"] for row in runs: @@ -249,24 +285,38 @@ class ClaudeCommandHandler(CommandHandler): ).first() )() settings_payload = dict(getattr(cfg, "settings", {}) or {}) - approver_service = str(settings_payload.get("approver_service") or "").strip().lower() - approver_identifier = str(settings_payload.get("approver_identifier") or "").strip() + approver_service = ( + str(settings_payload.get("approver_service") or "").strip().lower() + ) + approver_identifier = str( + settings_payload.get("approver_identifier") or "" + ).strip() if not approver_service or not approver_identifier: - return CommandResult(ok=False, status="failed", error="approver_channel_not_configured") + return CommandResult( + ok=False, status="failed", error="approver_channel_not_configured" + ) - if str(current_service or "").strip().lower() != approver_service or str(current_channel or "").strip() not in set( - channel_variants(approver_service, approver_identifier) - ): - return CommandResult(ok=False, status="failed", error="approval_command_not_allowed_in_this_channel") + if str(current_service or "").strip().lower() != approver_service or str( + current_channel or "" + ).strip() not in set(channel_variants(approver_service, approver_identifier)): + return CommandResult( + ok=False, + status="failed", + error="approval_command_not_allowed_in_this_channel", + ) approval_key = parsed.approval_key request = await sync_to_async( - lambda: CodexPermissionRequest.objects.select_related("codex_run", "external_sync_event") + lambda: CodexPermissionRequest.objects.select_related( + "codex_run", "external_sync_event" + ) .filter(user=trigger.user, approval_key=approval_key) .first() )() if request is None: - return CommandResult(ok=False, status="failed", error="approval_key_not_found") + return CommandResult( + ok=False, status="failed", error="approval_key_not_found" + ) now = timezone.now() if parsed.command == "approve": @@ -283,14 +333,20 @@ class ClaudeCommandHandler(CommandHandler): ] ) if request.external_sync_event_id: - await sync_to_async(ExternalSyncEvent.objects.filter(id=request.external_sync_event_id).update)( + await sync_to_async( + ExternalSyncEvent.objects.filter( + id=request.external_sync_event_id + ).update + )( status="ok", error="", ) run = request.codex_run run.status = "approved_waiting_resume" run.error = "" - await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) + await sync_to_async(run.save)( + update_fields=["status", "error", "updated_at"] + ) source_service = str(run.source_service or "") source_channel = str(run.source_channel or "") resume_payload = dict(request.resume_payload or {}) @@ -302,14 +358,18 @@ class ClaudeCommandHandler(CommandHandler): provider_payload["source_service"] = source_service provider_payload["source_channel"] = source_channel event_action = resume_action - resume_idempotency_key = str(resume_payload.get("idempotency_key") or "").strip() + resume_idempotency_key = str( + resume_payload.get("idempotency_key") or "" + ).strip() resume_event_key = ( resume_idempotency_key if resume_idempotency_key else f"{self._approval_prefix}:{approval_key}:approved" ) else: - provider_payload = dict(run.request_payload.get("provider_payload") or {}) + provider_payload = dict( + run.request_payload.get("provider_payload") or {} + ) provider_payload.update( { "mode": "approval_response", @@ -337,17 +397,30 @@ class ClaudeCommandHandler(CommandHandler): "error": "", }, ) - return CommandResult(ok=True, status="ok", payload={"approval_key": approval_key, "resolution": "approved"}) + return CommandResult( + ok=True, + status="ok", + payload={"approval_key": approval_key, "resolution": "approved"}, + ) request.status = "denied" request.resolved_at = now request.resolved_by_identifier = current_channel request.resolution_note = "denied via claude command" await sync_to_async(request.save)( - update_fields=["status", "resolved_at", "resolved_by_identifier", "resolution_note"] + update_fields=[ + "status", + "resolved_at", + "resolved_by_identifier", + "resolution_note", + ] ) if request.external_sync_event_id: - await sync_to_async(ExternalSyncEvent.objects.filter(id=request.external_sync_event_id).update)( + await sync_to_async( + ExternalSyncEvent.objects.filter( + id=request.external_sync_event_id + ).update + )( status="failed", error="approval_denied", ) @@ -374,7 +447,11 @@ class ClaudeCommandHandler(CommandHandler): "error": "approval_denied", }, ) - return CommandResult(ok=True, status="ok", payload={"approval_key": approval_key, "resolution": "denied"}) + return CommandResult( + ok=True, + status="ok", + payload={"approval_key": approval_key, "resolution": "denied"}, + ) async def _create_submission( self, @@ -391,7 +468,9 @@ class ClaudeCommandHandler(CommandHandler): ).first() )() if cfg is None: - return CommandResult(ok=False, status="failed", error="provider_disabled_or_missing") + return CommandResult( + ok=False, status="failed", error="provider_disabled_or_missing" + ) service, channel = self._effective_scope(trigger) external_chat_id = await sync_to_async(resolve_external_chat_id)( @@ -418,7 +497,9 @@ class ClaudeCommandHandler(CommandHandler): if mode == "plan": anchor = trigger.reply_to if anchor is None: - return CommandResult(ok=False, status="failed", error="reply_required_for_claude_plan") + return CommandResult( + ok=False, status="failed", error="reply_required_for_claude_plan" + ) rows = await sync_to_async(list)( Message.objects.filter( user=trigger.user, @@ -427,7 +508,9 @@ class ClaudeCommandHandler(CommandHandler): ts__lte=int(trigger.ts or 0), ) .order_by("ts") - .select_related("session", "session__identifier", "session__identifier__person") + .select_related( + "session", "session__identifier", "session__identifier__person" + ) ) payload["reply_context"] = { "anchor_message_id": str(anchor.id), @@ -446,12 +529,18 @@ class ClaudeCommandHandler(CommandHandler): source_channel=channel, external_chat_id=external_chat_id, status="waiting_approval", - request_payload={"action": "append_update", "provider_payload": dict(payload)}, + request_payload={ + "action": "append_update", + "provider_payload": dict(payload), + }, result_payload={}, error="", ) payload["codex_run_id"] = str(run.id) - run.request_payload = {"action": "append_update", "provider_payload": dict(payload)} + run.request_payload = { + "action": "append_update", + "provider_payload": dict(payload), + } await sync_to_async(run.save)(update_fields=["request_payload", "updated_at"]) idempotency_key = f"claude_cmd:{trigger.id}:{mode}:{task.id}:{hashlib.sha1(str(body_text or '').encode('utf-8')).hexdigest()[:12]}" @@ -476,20 +565,26 @@ class ClaudeCommandHandler(CommandHandler): return CommandResult(ok=False, status="failed", error="trigger_not_found") profile = await sync_to_async( - lambda: CommandProfile.objects.filter(user=trigger.user, slug=self.slug, enabled=True).first() + lambda: CommandProfile.objects.filter( + user=trigger.user, slug=self.slug, enabled=True + ).first() )() if profile is None: return CommandResult(ok=False, status="skipped", error="profile_missing") parsed = parse_claude_command(ctx.message_text) if not parsed.command: - return CommandResult(ok=False, status="skipped", error="claude_command_not_matched") + return CommandResult( + ok=False, status="skipped", error="claude_command_not_matched" + ) service, channel = self._effective_scope(trigger) if parsed.command == "status": project = None - reply_task = await self._linked_task_from_reply(trigger.user, trigger.reply_to) + reply_task = await self._linked_task_from_reply( + trigger.user, trigger.reply_to + ) if reply_task is not None: project = reply_task.project return await self._run_status(trigger, service, channel, project) @@ -507,7 +602,9 @@ class ClaudeCommandHandler(CommandHandler): reply_task = await self._linked_task_from_reply(trigger.user, trigger.reply_to) task = await self._resolve_task(trigger.user, reference_code, reply_task) if task is None: - return CommandResult(ok=False, status="failed", error="task_target_required") + return CommandResult( + ok=False, status="failed", error="task_target_required" + ) project, project_error = await self._resolve_project( user=trigger.user, @@ -518,7 +615,9 @@ class ClaudeCommandHandler(CommandHandler): project_token=project_token, ) if project is None: - return CommandResult(ok=False, status="failed", error=project_error or "project_unresolved") + return CommandResult( + ok=False, status="failed", error=project_error or "project_unresolved" + ) mode = "plan" if parsed.command == "plan" else "default" return await self._create_submission( diff --git a/core/commands/handlers/codex.py b/core/commands/handlers/codex.py index 56b21db..89b8746 100644 --- a/core/commands/handlers/codex.py +++ b/core/commands/handlers/codex.py @@ -20,8 +20,8 @@ from core.models import ( TaskProject, TaskProviderConfig, ) -from core.tasks.codex_support import channel_variants, resolve_external_chat_id from core.tasks.codex_approval import queue_codex_event_with_pre_approval +from core.tasks.codex_support import channel_variants, resolve_external_chat_id _CODEX_DEFAULT_RE = re.compile( r"^\s*(?:\.codex\b|#codex#?)(?P.*)$", @@ -31,7 +31,9 @@ _CODEX_PLAN_RE = re.compile( r"^\s*(?:\.codex\s+plan\b|#codex\s+plan#?)(?P.*)$", re.IGNORECASE | re.DOTALL, ) -_CODEX_STATUS_RE = re.compile(r"^\s*(?:\.codex\s+status\b|#codex\s+status#?)\s*$", re.IGNORECASE) +_CODEX_STATUS_RE = re.compile( + r"^\s*(?:\.codex\s+status\b|#codex\s+status#?)\s*$", re.IGNORECASE +) _CODEX_APPROVE_DENY_RE = re.compile( r"^\s*(?:\.codex|#codex)\s+(?Papprove|deny)\s+(?P[A-Za-z0-9._:-]+)#?\s*$", re.IGNORECASE, @@ -55,7 +57,6 @@ class CodexParsedCommand(dict): return str(self.get("approval_key") or "") - def parse_codex_command(text: str) -> CodexParsedCommand: body = str(text or "") m = _CODEX_APPROVE_DENY_RE.match(body) @@ -84,7 +85,9 @@ def parse_codex_command(text: str) -> CodexParsedCommand: return CodexParsedCommand(command=None, body_text="", approval_key="") -def codex_trigger_matches(message_text: str, trigger_token: str, exact_match_only: bool) -> bool: +def codex_trigger_matches( + message_text: str, trigger_token: str, exact_match_only: bool +) -> bool: body = str(message_text or "").strip() parsed = parse_codex_command(body) if parsed.command: @@ -102,7 +105,9 @@ class CodexCommandHandler(CommandHandler): async def _load_trigger(self, message_id: str) -> Message | None: return await sync_to_async( - lambda: Message.objects.select_related("user", "session", "session__identifier", "reply_to") + lambda: Message.objects.select_related( + "user", "session", "session__identifier", "reply_to" + ) .filter(id=message_id) .first() )() @@ -113,11 +118,18 @@ class CodexCommandHandler(CommandHandler): identifier = getattr(getattr(trigger, "session", None), "identifier", None) fallback_service = str(getattr(identifier, "service", "") or "").strip().lower() fallback_identifier = str(getattr(identifier, "identifier", "") or "").strip() - if service == "web" and fallback_service and fallback_identifier and fallback_service != "web": + if ( + service == "web" + and fallback_service + and fallback_identifier + and fallback_service != "web" + ): return fallback_service, fallback_identifier return service or "web", channel - async def _mapped_sources(self, user, service: str, channel: str) -> list[ChatTaskSource]: + async def _mapped_sources( + self, user, service: str, channel: str + ) -> list[ChatTaskSource]: variants = channel_variants(service, channel) if not variants: return [] @@ -130,7 +142,9 @@ class CodexCommandHandler(CommandHandler): ).select_related("project", "epic") ) - async def _linked_task_from_reply(self, user, reply_to: Message | None) -> DerivedTask | None: + async def _linked_task_from_reply( + self, user, reply_to: Message | None + ) -> DerivedTask | None: if reply_to is None: return None by_origin = await sync_to_async( @@ -142,7 +156,9 @@ class CodexCommandHandler(CommandHandler): if by_origin is not None: return by_origin return await sync_to_async( - lambda: DerivedTask.objects.filter(user=user, events__source_message=reply_to) + lambda: DerivedTask.objects.filter( + user=user, events__source_message=reply_to + ) .select_related("project", "epic") .order_by("-created_at") .first() @@ -163,10 +179,14 @@ class CodexCommandHandler(CommandHandler): return "" return str(m.group(1) or "").strip() - async def _resolve_task(self, user, reference_code: str, reply_task: DerivedTask | None) -> DerivedTask | None: + async def _resolve_task( + self, user, reference_code: str, reply_task: DerivedTask | None + ) -> DerivedTask | None: if reference_code: return await sync_to_async( - lambda: DerivedTask.objects.filter(user=user, reference_code=reference_code) + lambda: DerivedTask.objects.filter( + user=user, reference_code=reference_code + ) .select_related("project", "epic") .order_by("-created_at") .first() @@ -189,7 +209,9 @@ class CodexCommandHandler(CommandHandler): return reply_task.project, "" if project_token: project = await sync_to_async( - lambda: TaskProject.objects.filter(user=user, name__iexact=project_token).first() + lambda: TaskProject.objects.filter( + user=user, name__iexact=project_token + ).first() )() if project is not None: return project, "" @@ -198,20 +220,31 @@ class CodexCommandHandler(CommandHandler): mapped = await self._mapped_sources(user, service, channel) project_ids = sorted({str(row.project_id) for row in mapped if row.project_id}) if len(project_ids) == 1: - project = next((row.project for row in mapped if str(row.project_id) == project_ids[0]), None) + project = next( + ( + row.project + for row in mapped + if str(row.project_id) == project_ids[0] + ), + None, + ) return project, "" if len(project_ids) > 1: return None, "project_required:[project:Name]" return None, "project_unresolved" - async def _post_source_status(self, trigger: Message, text: str, suffix: str) -> None: + async def _post_source_status( + self, trigger: Message, text: str, suffix: str + ) -> None: await post_status_in_source( trigger_message=trigger, text=text, origin_tag=f"codex-status:{suffix}", ) - async def _run_status(self, trigger: Message, service: str, channel: str, project: TaskProject | None) -> CommandResult: + async def _run_status( + self, trigger: Message, service: str, channel: str, project: TaskProject | None + ) -> CommandResult: def _load_runs(): qs = CodexRun.objects.filter(user=trigger.user) if service: @@ -224,7 +257,9 @@ class CodexCommandHandler(CommandHandler): runs = await sync_to_async(_load_runs)() if not runs: - await self._post_source_status(trigger, "[codex] no recent runs for this scope.", "empty") + await self._post_source_status( + trigger, "[codex] no recent runs for this scope.", "empty" + ) return CommandResult(ok=True, status="ok", payload={"count": 0}) lines = ["[codex] recent runs:"] for row in runs: @@ -243,27 +278,43 @@ class CodexCommandHandler(CommandHandler): current_channel: str, ) -> CommandResult: cfg = await sync_to_async( - lambda: TaskProviderConfig.objects.filter(user=trigger.user, provider="codex_cli").first() + lambda: TaskProviderConfig.objects.filter( + user=trigger.user, provider="codex_cli" + ).first() )() settings_payload = dict(getattr(cfg, "settings", {}) or {}) - approver_service = str(settings_payload.get("approver_service") or "").strip().lower() - approver_identifier = str(settings_payload.get("approver_identifier") or "").strip() + approver_service = ( + str(settings_payload.get("approver_service") or "").strip().lower() + ) + approver_identifier = str( + settings_payload.get("approver_identifier") or "" + ).strip() if not approver_service or not approver_identifier: - return CommandResult(ok=False, status="failed", error="approver_channel_not_configured") + return CommandResult( + ok=False, status="failed", error="approver_channel_not_configured" + ) - if str(current_service or "").strip().lower() != approver_service or str(current_channel or "").strip() not in set( - channel_variants(approver_service, approver_identifier) - ): - return CommandResult(ok=False, status="failed", error="approval_command_not_allowed_in_this_channel") + if str(current_service or "").strip().lower() != approver_service or str( + current_channel or "" + ).strip() not in set(channel_variants(approver_service, approver_identifier)): + return CommandResult( + ok=False, + status="failed", + error="approval_command_not_allowed_in_this_channel", + ) approval_key = parsed.approval_key request = await sync_to_async( - lambda: CodexPermissionRequest.objects.select_related("codex_run", "external_sync_event") + lambda: CodexPermissionRequest.objects.select_related( + "codex_run", "external_sync_event" + ) .filter(user=trigger.user, approval_key=approval_key) .first() )() if request is None: - return CommandResult(ok=False, status="failed", error="approval_key_not_found") + return CommandResult( + ok=False, status="failed", error="approval_key_not_found" + ) now = timezone.now() if parsed.command == "approve": @@ -280,14 +331,20 @@ class CodexCommandHandler(CommandHandler): ] ) if request.external_sync_event_id: - await sync_to_async(ExternalSyncEvent.objects.filter(id=request.external_sync_event_id).update)( + await sync_to_async( + ExternalSyncEvent.objects.filter( + id=request.external_sync_event_id + ).update + )( status="ok", error="", ) run = request.codex_run run.status = "approved_waiting_resume" run.error = "" - await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) + await sync_to_async(run.save)( + update_fields=["status", "error", "updated_at"] + ) source_service = str(run.source_service or "") source_channel = str(run.source_channel or "") resume_payload = dict(request.resume_payload or {}) @@ -299,14 +356,18 @@ class CodexCommandHandler(CommandHandler): provider_payload["source_service"] = source_service provider_payload["source_channel"] = source_channel event_action = resume_action - resume_idempotency_key = str(resume_payload.get("idempotency_key") or "").strip() + resume_idempotency_key = str( + resume_payload.get("idempotency_key") or "" + ).strip() resume_event_key = ( resume_idempotency_key if resume_idempotency_key else f"codex_approval:{approval_key}:approved" ) else: - provider_payload = dict(run.request_payload.get("provider_payload") or {}) + provider_payload = dict( + run.request_payload.get("provider_payload") or {} + ) provider_payload.update( { "mode": "approval_response", @@ -334,17 +395,30 @@ class CodexCommandHandler(CommandHandler): "error": "", }, ) - return CommandResult(ok=True, status="ok", payload={"approval_key": approval_key, "resolution": "approved"}) + return CommandResult( + ok=True, + status="ok", + payload={"approval_key": approval_key, "resolution": "approved"}, + ) request.status = "denied" request.resolved_at = now request.resolved_by_identifier = current_channel request.resolution_note = "denied via command" await sync_to_async(request.save)( - update_fields=["status", "resolved_at", "resolved_by_identifier", "resolution_note"] + update_fields=[ + "status", + "resolved_at", + "resolved_by_identifier", + "resolution_note", + ] ) if request.external_sync_event_id: - await sync_to_async(ExternalSyncEvent.objects.filter(id=request.external_sync_event_id).update)( + await sync_to_async( + ExternalSyncEvent.objects.filter( + id=request.external_sync_event_id + ).update + )( status="failed", error="approval_denied", ) @@ -371,7 +445,11 @@ class CodexCommandHandler(CommandHandler): "error": "approval_denied", }, ) - return CommandResult(ok=True, status="ok", payload={"approval_key": approval_key, "resolution": "denied"}) + return CommandResult( + ok=True, + status="ok", + payload={"approval_key": approval_key, "resolution": "denied"}, + ) async def _create_submission( self, @@ -383,10 +461,14 @@ class CodexCommandHandler(CommandHandler): project: TaskProject, ) -> CommandResult: cfg = await sync_to_async( - lambda: TaskProviderConfig.objects.filter(user=trigger.user, provider="codex_cli", enabled=True).first() + lambda: TaskProviderConfig.objects.filter( + user=trigger.user, provider="codex_cli", enabled=True + ).first() )() if cfg is None: - return CommandResult(ok=False, status="failed", error="provider_disabled_or_missing") + return CommandResult( + ok=False, status="failed", error="provider_disabled_or_missing" + ) service, channel = self._effective_scope(trigger) external_chat_id = await sync_to_async(resolve_external_chat_id)( @@ -413,7 +495,9 @@ class CodexCommandHandler(CommandHandler): if mode == "plan": anchor = trigger.reply_to if anchor is None: - return CommandResult(ok=False, status="failed", error="reply_required_for_codex_plan") + return CommandResult( + ok=False, status="failed", error="reply_required_for_codex_plan" + ) rows = await sync_to_async(list)( Message.objects.filter( user=trigger.user, @@ -422,7 +506,9 @@ class CodexCommandHandler(CommandHandler): ts__lte=int(trigger.ts or 0), ) .order_by("ts") - .select_related("session", "session__identifier", "session__identifier__person") + .select_related( + "session", "session__identifier", "session__identifier__person" + ) ) payload["reply_context"] = { "anchor_message_id": str(anchor.id), @@ -441,12 +527,18 @@ class CodexCommandHandler(CommandHandler): source_channel=channel, external_chat_id=external_chat_id, status="waiting_approval", - request_payload={"action": "append_update", "provider_payload": dict(payload)}, + request_payload={ + "action": "append_update", + "provider_payload": dict(payload), + }, result_payload={}, error="", ) payload["codex_run_id"] = str(run.id) - run.request_payload = {"action": "append_update", "provider_payload": dict(payload)} + run.request_payload = { + "action": "append_update", + "provider_payload": dict(payload), + } await sync_to_async(run.save)(update_fields=["request_payload", "updated_at"]) idempotency_key = f"codex_cmd:{trigger.id}:{mode}:{task.id}:{hashlib.sha1(str(body_text or '').encode('utf-8')).hexdigest()[:12]}" @@ -471,20 +563,26 @@ class CodexCommandHandler(CommandHandler): return CommandResult(ok=False, status="failed", error="trigger_not_found") profile = await sync_to_async( - lambda: CommandProfile.objects.filter(user=trigger.user, slug=self.slug, enabled=True).first() + lambda: CommandProfile.objects.filter( + user=trigger.user, slug=self.slug, enabled=True + ).first() )() if profile is None: return CommandResult(ok=False, status="skipped", error="profile_missing") parsed = parse_codex_command(ctx.message_text) if not parsed.command: - return CommandResult(ok=False, status="skipped", error="codex_command_not_matched") + return CommandResult( + ok=False, status="skipped", error="codex_command_not_matched" + ) service, channel = self._effective_scope(trigger) if parsed.command == "status": project = None - reply_task = await self._linked_task_from_reply(trigger.user, trigger.reply_to) + reply_task = await self._linked_task_from_reply( + trigger.user, trigger.reply_to + ) if reply_task is not None: project = reply_task.project return await self._run_status(trigger, service, channel, project) @@ -502,7 +600,9 @@ class CodexCommandHandler(CommandHandler): reply_task = await self._linked_task_from_reply(trigger.user, trigger.reply_to) task = await self._resolve_task(trigger.user, reference_code, reply_task) if task is None: - return CommandResult(ok=False, status="failed", error="task_target_required") + return CommandResult( + ok=False, status="failed", error="task_target_required" + ) project, project_error = await self._resolve_project( user=trigger.user, @@ -513,7 +613,9 @@ class CodexCommandHandler(CommandHandler): project_token=project_token, ) if project is None: - return CommandResult(ok=False, status="failed", error=project_error or "project_unresolved") + return CommandResult( + ok=False, status="failed", error=project_error or "project_unresolved" + ) mode = "plan" if parsed.command == "plan" else "default" return await self._create_submission( diff --git a/core/commands/policies.py b/core/commands/policies.py index 6d9af9f..028c3cb 100644 --- a/core/commands/policies.py +++ b/core/commands/policies.py @@ -32,7 +32,8 @@ def _legacy_defaults(profile: CommandProfile, post_result_enabled: bool) -> dict "enabled": True, "generation_mode": "ai", "send_plan_to_egress": bool(post_result_enabled), - "send_status_to_source": str(profile.visibility_mode or "") == "status_in_source", + "send_status_to_source": str(profile.visibility_mode or "") + == "status_in_source", "send_status_to_egress": False, "store_document": True, } @@ -56,7 +57,9 @@ def ensure_variant_policies_for_profile( *, action_rows: Iterable[CommandAction] | None = None, ) -> dict[str, CommandVariantPolicy]: - actions = list(action_rows) if action_rows is not None else list(profile.actions.all()) + actions = ( + list(action_rows) if action_rows is not None else list(profile.actions.all()) + ) post_result_enabled = any( row.action_type == "post_result" and bool(row.enabled) for row in actions ) @@ -91,7 +94,9 @@ def ensure_variant_policies_for_profile( return result -def load_variant_policy(profile: CommandProfile, variant_key: str) -> CommandVariantPolicy | None: +def load_variant_policy( + profile: CommandProfile, variant_key: str +) -> CommandVariantPolicy | None: key = str(variant_key or "").strip() if not key: return None diff --git a/core/context_processors.py b/core/context_processors.py index 8871b7b..ce82229 100644 --- a/core/context_processors.py +++ b/core/context_processors.py @@ -27,6 +27,7 @@ def settings_hierarchy_nav(request): ai_models_href = reverse("ai_models") ai_traces_href = reverse("ai_execution_log") commands_href = reverse("command_routing") + business_plans_href = reverse("business_plan_inbox") tasks_href = reverse("tasks_settings") translation_href = reverse("translation_settings") availability_href = reverse("availability_settings") @@ -55,6 +56,8 @@ def settings_hierarchy_nav(request): modules_routes = { "modules_settings", "command_routing", + "business_plan_inbox", + "business_plan_editor", "tasks_settings", "translation_settings", "availability_settings", @@ -106,7 +109,12 @@ def settings_hierarchy_nav(request): "title": "Modules", "tabs": [ _tab("Commands", commands_href, path == commands_href), - _tab("Tasks", tasks_href, path == tasks_href), + _tab( + "Business Plans", + business_plans_href, + url_name in {"business_plan_inbox", "business_plan_editor"}, + ), + _tab("Task Automation", tasks_href, path == tasks_href), _tab("Translation", translation_href, path == translation_href), _tab("Availability", availability_href, path == availability_href), ], diff --git a/core/db/sql.py b/core/db/sql.py index 882e940..b9ee732 100644 --- a/core/db/sql.py +++ b/core/db/sql.py @@ -24,7 +24,6 @@ async def init_mysql_pool(): async def close_mysql_pool(): """Close the MySQL connection pool properly.""" - global mysql_pool if mysql_pool: mysql_pool.close() await mysql_pool.wait_closed() diff --git a/core/events/ledger.py b/core/events/ledger.py index 4253c3a..41957f5 100644 --- a/core/events/ledger.py +++ b/core/events/ledger.py @@ -15,8 +15,12 @@ def event_ledger_enabled() -> bool: def event_ledger_status() -> dict: return { - "event_ledger_dual_write": bool(getattr(settings, "EVENT_LEDGER_DUAL_WRITE", False)), - "event_primary_write_path": bool(getattr(settings, "EVENT_PRIMARY_WRITE_PATH", False)), + "event_ledger_dual_write": bool( + getattr(settings, "EVENT_LEDGER_DUAL_WRITE", False) + ), + "event_primary_write_path": bool( + getattr(settings, "EVENT_PRIMARY_WRITE_PATH", False) + ), } @@ -61,9 +65,7 @@ def append_event_sync( if not normalized_type: raise ValueError("event_type is required") - candidates = { - str(choice[0]) for choice in ConversationEvent.EVENT_TYPE_CHOICES - } + candidates = {str(choice[0]) for choice in ConversationEvent.EVENT_TYPE_CHOICES} if normalized_type not in candidates: raise ValueError(f"unsupported event_type: {normalized_type}") diff --git a/core/events/projection.py b/core/events/projection.py index 9ae951b..8426790 100644 --- a/core/events/projection.py +++ b/core/events/projection.py @@ -90,7 +90,9 @@ def project_session_from_events(session: ChatSession) -> list[dict]: order.append(message_id) state.ts = _safe_int(payload.get("message_ts"), _safe_int(event.ts)) state.text = str(payload.get("text") or state.text or "") - delivered_default = _safe_int(payload.get("delivered_ts"), _safe_int(event.ts)) + delivered_default = _safe_int( + payload.get("delivered_ts"), _safe_int(event.ts) + ) if state.delivered_ts is None: state.delivered_ts = delivered_default or None continue @@ -111,7 +113,11 @@ def project_session_from_events(session: ChatSession) -> list[dict]: continue if event_type in {"reaction_added", "reaction_removed"}: - source_service = str(payload.get("source_service") or event.origin_transport or "").strip().lower() + source_service = ( + str(payload.get("source_service") or event.origin_transport or "") + .strip() + .lower() + ) actor = str(payload.get("actor") or event.actor_identifier or "").strip() emoji = str(payload.get("emoji") or "").strip() if not source_service and not actor and not emoji: @@ -121,7 +127,9 @@ def project_session_from_events(session: ChatSession) -> list[dict]: "source_service": source_service, "actor": actor, "emoji": emoji, - "removed": bool(event_type == "reaction_removed" or payload.get("remove")), + "removed": bool( + event_type == "reaction_removed" or payload.get("remove") + ), } output = [] @@ -135,12 +143,12 @@ def project_session_from_events(session: ChatSession) -> list[dict]: "ts": int(state.ts or 0), "text": str(state.text or ""), "delivered_ts": ( - int(state.delivered_ts) - if state.delivered_ts is not None - else None + int(state.delivered_ts) if state.delivered_ts is not None else None ), "read_ts": int(state.read_ts) if state.read_ts is not None else None, - "reactions": _normalize_reactions(list((state.reactions or {}).values())), + "reactions": _normalize_reactions( + list((state.reactions or {}).values()) + ), } ) return output @@ -182,7 +190,9 @@ def shadow_compare_session(session: ChatSession, detail_limit: int = 50) -> dict cause_samples = {key: [] for key in cause_counts.keys()} cause_sample_limit = min(5, max(0, int(detail_limit))) - def _record_detail(message_id: str, issue: str, cause: str, extra: dict | None = None): + def _record_detail( + message_id: str, issue: str, cause: str, extra: dict | None = None + ): if cause in cause_counts: cause_counts[cause] += 1 row = {"message_id": message_id, "issue": issue, "cause": cause} @@ -224,13 +234,10 @@ def shadow_compare_session(session: ChatSession, detail_limit: int = 50) -> dict db_delivered_ts = db_row.get("delivered_ts") projected_delivered_ts = projected.get("delivered_ts") - if ( - (db_delivered_ts is None) != (projected_delivered_ts is None) - or ( - db_delivered_ts is not None - and projected_delivered_ts is not None - and int(db_delivered_ts) != int(projected_delivered_ts) - ) + if (db_delivered_ts is None) != (projected_delivered_ts is None) or ( + db_delivered_ts is not None + and projected_delivered_ts is not None + and int(db_delivered_ts) != int(projected_delivered_ts) ): counters["delivered_ts_mismatch"] += 1 _record_detail( @@ -245,13 +252,10 @@ def shadow_compare_session(session: ChatSession, detail_limit: int = 50) -> dict db_read_ts = db_row.get("read_ts") projected_read_ts = projected.get("read_ts") - if ( - (db_read_ts is None) != (projected_read_ts is None) - or ( - db_read_ts is not None - and projected_read_ts is not None - and int(db_read_ts) != int(projected_read_ts) - ) + if (db_read_ts is None) != (projected_read_ts is None) or ( + db_read_ts is not None + and projected_read_ts is not None + and int(db_read_ts) != int(projected_read_ts) ): counters["read_ts_mismatch"] += 1 _record_detail( @@ -264,12 +268,19 @@ def shadow_compare_session(session: ChatSession, detail_limit: int = 50) -> dict db_reactions = _normalize_reactions( list((db_row.get("receipt_payload") or {}).get("reactions") or []) ) - projected_reactions = _normalize_reactions(list(projected.get("reactions") or [])) + projected_reactions = _normalize_reactions( + list(projected.get("reactions") or []) + ) if db_reactions != projected_reactions: counters["reactions_mismatch"] += 1 cause = "payload_normalization_gap" strategy = str( - ((db_row.get("receipt_payload") or {}).get("reaction_last_match_strategy") or "") + ( + (db_row.get("receipt_payload") or {}).get( + "reaction_last_match_strategy" + ) + or "" + ) ).strip() if strategy == "nearest_ts_window": cause = "ambiguous_reaction_target" diff --git a/core/forms.py b/core/forms.py index 9fd0446..6a7dbc1 100644 --- a/core/forms.py +++ b/core/forms.py @@ -1,6 +1,7 @@ from django import forms from django.contrib.auth.forms import UserCreationForm from django.forms import ModelForm + from mixins.restrictions import RestrictedFormMixin from .models import ( diff --git a/core/gateway/commands.py b/core/gateway/commands.py index 56d69da..36ca660 100644 --- a/core/gateway/commands.py +++ b/core/gateway/commands.py @@ -8,7 +8,6 @@ from asgiref.sync import sync_to_async from core.models import GatewayCommandEvent from core.security.command_policy import CommandSecurityContext, evaluate_command_policy - GatewayEmit = Callable[[str], None] GatewayHandler = Callable[["GatewayCommandContext", GatewayEmit], Awaitable[bool]] GatewayMatcher = Callable[[str], bool] @@ -103,7 +102,10 @@ async def dispatch_gateway_command( emit(message) event.status = "blocked" event.error = f"{decision.code}:{decision.reason}" - event.response_meta = {"policy_code": decision.code, "policy_reason": decision.reason} + event.response_meta = { + "policy_code": decision.code, + "policy_reason": decision.reason, + } await sync_to_async(event.save)( update_fields=["status", "error", "response_meta", "updated_at"] ) @@ -129,5 +131,7 @@ async def dispatch_gateway_command( event.status = "ok" if handled else "ignored" event.response_meta = {"responses": responses} - await sync_to_async(event.save)(update_fields=["status", "response_meta", "updated_at"]) + await sync_to_async(event.save)( + update_fields=["status", "response_meta", "updated_at"] + ) return bool(handled) diff --git a/core/management/commands/backfill_contact_availability.py b/core/management/commands/backfill_contact_availability.py index b0d2b25..29b72e2 100644 --- a/core/management/commands/backfill_contact_availability.py +++ b/core/management/commands/backfill_contact_availability.py @@ -4,7 +4,7 @@ from typing import Iterable from django.core.management.base import BaseCommand -from core.models import Message, User +from core.models import Message from core.presence import AvailabilitySignal, record_inferred_signal from core.presence.inference import now_ms @@ -19,7 +19,9 @@ class Command(BaseCommand): parser.add_argument("--user-id", default="") parser.add_argument("--dry-run", action="store_true", default=False) - def _iter_messages(self, *, days: int, limit: int, service: str, user_id: str) -> Iterable[Message]: + def _iter_messages( + self, *, days: int, limit: int, service: str, user_id: str + ) -> Iterable[Message]: cutoff_ts = now_ms() - (max(1, int(days)) * 24 * 60 * 60 * 1000) qs = Message.objects.filter(ts__gte=cutoff_ts).select_related( "user", "session", "session__identifier", "session__identifier__person" @@ -40,7 +42,9 @@ class Command(BaseCommand): created = 0 scanned = 0 - for msg in self._iter_messages(days=days, limit=limit, service=service_filter, user_id=user_filter): + for msg in self._iter_messages( + days=days, limit=limit, service=service_filter, user_id=user_filter + ): scanned += 1 identifier = getattr(getattr(msg, "session", None), "identifier", None) person = getattr(identifier, "person", None) @@ -48,12 +52,18 @@ class Command(BaseCommand): if not identifier or not person or not user: continue - service = str(getattr(msg, "source_service", "") or identifier.service or "").strip().lower() + service = ( + str(getattr(msg, "source_service", "") or identifier.service or "") + .strip() + .lower() + ) if not service: continue base_ts = int(getattr(msg, "ts", 0) or 0) - message_author = str(getattr(msg, "custom_author", "") or "").strip().upper() + message_author = ( + str(getattr(msg, "custom_author", "") or "").strip().upper() + ) outgoing = message_author in {"USER", "BOT"} candidates = [] @@ -84,7 +94,9 @@ class Command(BaseCommand): "origin": "backfill_contact_availability", "message_id": str(msg.id), "inferred_from": "read_receipt", - "read_by": str(getattr(msg, "read_by_identifier", "") or ""), + "read_by": str( + getattr(msg, "read_by_identifier", "") or "" + ), }, } ) diff --git a/core/management/commands/codex_worker.py b/core/management/commands/codex_worker.py index 2eebe19..0ffa48c 100644 --- a/core/management/commands/codex_worker.py +++ b/core/management/commands/codex_worker.py @@ -7,7 +7,12 @@ from asgiref.sync import async_to_sync from django.core.management.base import BaseCommand from core.clients.transport import send_message_raw -from core.models import CodexPermissionRequest, CodexRun, ExternalSyncEvent, TaskProviderConfig +from core.models import ( + CodexPermissionRequest, + CodexRun, + ExternalSyncEvent, + TaskProviderConfig, +) from core.tasks.providers import get_provider from core.util import logs @@ -15,7 +20,9 @@ log = logs.get_logger("codex_worker") class Command(BaseCommand): - help = "Process queued external sync events for worker-backed providers (codex_cli)." + help = ( + "Process queued external sync events for worker-backed providers (codex_cli)." + ) def add_arguments(self, parser): parser.add_argument("--once", action="store_true", default=False) @@ -73,7 +80,9 @@ class Command(BaseCommand): payload = dict(event.payload or {}) action = str(payload.get("action") or "append_update").strip().lower() provider_payload = dict(payload.get("provider_payload") or payload) - run_id = str(provider_payload.get("codex_run_id") or payload.get("codex_run_id") or "").strip() + run_id = str( + provider_payload.get("codex_run_id") or payload.get("codex_run_id") or "" + ).strip() codex_run = None if run_id: codex_run = CodexRun.objects.filter(id=run_id, user=event.user).first() @@ -104,9 +113,13 @@ class Command(BaseCommand): result_payload = dict(result.payload or {}) requires_approval = bool(result_payload.get("requires_approval")) if requires_approval: - approval_key = str(result_payload.get("approval_key") or uuid.uuid4().hex[:12]).strip() + approval_key = str( + result_payload.get("approval_key") or uuid.uuid4().hex[:12] + ).strip() permission_request = dict(result_payload.get("permission_request") or {}) - summary = str(result_payload.get("summary") or permission_request.get("summary") or "").strip() + summary = str( + result_payload.get("summary") or permission_request.get("summary") or "" + ).strip() requested_permissions = permission_request.get("requested_permissions") if not isinstance(requested_permissions, (list, dict)): requested_permissions = permission_request or {} @@ -121,28 +134,42 @@ class Command(BaseCommand): codex_run.status = "waiting_approval" codex_run.result_payload = dict(result_payload) codex_run.error = "" - codex_run.save(update_fields=["status", "result_payload", "error", "updated_at"]) + codex_run.save( + update_fields=["status", "result_payload", "error", "updated_at"] + ) CodexPermissionRequest.objects.update_or_create( approval_key=approval_key, defaults={ "user": event.user, - "codex_run": codex_run if codex_run is not None else CodexRun.objects.create( - user=event.user, - task=event.task, - derived_task_event=event.task_event, - source_service=str(provider_payload.get("source_service") or ""), - source_channel=str(provider_payload.get("source_channel") or ""), - external_chat_id=str(provider_payload.get("external_chat_id") or ""), - status="waiting_approval", - request_payload=dict(payload or {}), - result_payload=dict(result_payload), - error="", + "codex_run": ( + codex_run + if codex_run is not None + else CodexRun.objects.create( + user=event.user, + task=event.task, + derived_task_event=event.task_event, + source_service=str( + provider_payload.get("source_service") or "" + ), + source_channel=str( + provider_payload.get("source_channel") or "" + ), + external_chat_id=str( + provider_payload.get("external_chat_id") or "" + ), + status="waiting_approval", + request_payload=dict(payload or {}), + result_payload=dict(result_payload), + error="", + ) ), "external_sync_event": event, "summary": summary, - "requested_permissions": requested_permissions if isinstance(requested_permissions, dict) else { - "items": list(requested_permissions or []) - }, + "requested_permissions": ( + requested_permissions + if isinstance(requested_permissions, dict) + else {"items": list(requested_permissions or [])} + ), "resume_payload": dict(resume_payload or {}), "status": "pending", "resolved_at": None, @@ -150,9 +177,17 @@ class Command(BaseCommand): "resolution_note": "", }, ) - approver_service = str((cfg.settings or {}).get("approver_service") or "").strip().lower() - approver_identifier = str((cfg.settings or {}).get("approver_identifier") or "").strip() - requested_text = result_payload.get("permission_request") or result_payload.get("requested_permissions") or {} + approver_service = ( + str((cfg.settings or {}).get("approver_service") or "").strip().lower() + ) + approver_identifier = str( + (cfg.settings or {}).get("approver_identifier") or "" + ).strip() + requested_text = ( + result_payload.get("permission_request") + or result_payload.get("requested_permissions") + or {} + ) if approver_service and approver_identifier: try: async_to_sync(send_message_raw)( @@ -168,10 +203,17 @@ class Command(BaseCommand): metadata={"origin_tag": f"codex-approval:{approval_key}"}, ) except Exception: - log.exception("failed to notify approver channel for approval_key=%s", approval_key) + log.exception( + "failed to notify approver channel for approval_key=%s", + approval_key, + ) else: - source_service = str(provider_payload.get("source_service") or "").strip().lower() - source_channel = str(provider_payload.get("source_channel") or "").strip() + source_service = ( + str(provider_payload.get("source_service") or "").strip().lower() + ) + source_channel = str( + provider_payload.get("source_channel") or "" + ).strip() if source_service and source_channel: try: async_to_sync(send_message_raw)( @@ -185,7 +227,9 @@ class Command(BaseCommand): metadata={"origin_tag": "codex-approval-missing-target"}, ) except Exception: - log.exception("failed to notify source channel for missing approver target") + log.exception( + "failed to notify source channel for missing approver target" + ) return event.status = "ok" if result.ok else "failed" @@ -201,18 +245,24 @@ class Command(BaseCommand): approval_key = str(provider_payload.get("approval_key") or "").strip() if mode == "approval_response" and approval_key: req = ( - CodexPermissionRequest.objects.select_related("external_sync_event", "codex_run") + CodexPermissionRequest.objects.select_related( + "external_sync_event", "codex_run" + ) .filter(user=event.user, approval_key=approval_key) .first() ) if req and req.external_sync_event_id: if result.ok: - ExternalSyncEvent.objects.filter(id=req.external_sync_event_id).update( + ExternalSyncEvent.objects.filter( + id=req.external_sync_event_id + ).update( status="ok", error="", ) elif str(event.error or "").strip() == "approval_denied": - ExternalSyncEvent.objects.filter(id=req.external_sync_event_id).update( + ExternalSyncEvent.objects.filter( + id=req.external_sync_event_id + ).update( status="failed", error="approval_denied", ) @@ -220,9 +270,16 @@ class Command(BaseCommand): codex_run.status = "ok" if result.ok else "failed" codex_run.error = str(result.error or "") codex_run.result_payload = result_payload - codex_run.save(update_fields=["status", "error", "result_payload", "updated_at"]) + codex_run.save( + update_fields=["status", "error", "result_payload", "updated_at"] + ) - if result.ok and result.external_key and event.task_id and not str(event.task.external_key or "").strip(): + if ( + result.ok + and result.external_key + and event.task_id + and not str(event.task.external_key or "").strip() + ): event.task.external_key = str(result.external_key) event.task.save(update_fields=["external_key"]) @@ -250,7 +307,11 @@ class Command(BaseCommand): continue for row_id in claimed_ids: - event = ExternalSyncEvent.objects.filter(id=row_id).select_related("task", "user").first() + event = ( + ExternalSyncEvent.objects.filter(id=row_id) + .select_related("task", "user") + .first() + ) if event is None: continue try: diff --git a/core/management/commands/event_projection_shadow.py b/core/management/commands/event_projection_shadow.py index 3be8861..1a863d8 100644 --- a/core/management/commands/event_projection_shadow.py +++ b/core/management/commands/event_projection_shadow.py @@ -85,7 +85,9 @@ class Command(BaseCommand): compared = shadow_compare_session(session, detail_limit=detail_limit) aggregate["sessions_scanned"] += 1 aggregate["db_message_count"] += int(compared.get("db_message_count") or 0) - aggregate["projected_message_count"] += int(compared.get("projected_message_count") or 0) + aggregate["projected_message_count"] += int( + compared.get("projected_message_count") or 0 + ) aggregate["mismatch_total"] += int(compared.get("mismatch_total") or 0) for key in aggregate["counters"].keys(): aggregate["counters"][key] += int( diff --git a/core/management/commands/recalculate_contact_availability.py b/core/management/commands/recalculate_contact_availability.py index 03d358f..1c2e667 100644 --- a/core/management/commands/recalculate_contact_availability.py +++ b/core/management/commands/recalculate_contact_availability.py @@ -1,14 +1,11 @@ from __future__ import annotations -from collections import defaultdict - from django.core.management.base import BaseCommand from core.models import ContactAvailabilityEvent, ContactAvailabilitySpan, Message from core.presence import AvailabilitySignal, record_native_signal from core.presence.inference import now_ms - _SOURCE_ORDER = { "message_in": 10, "message_out": 20, @@ -51,9 +48,14 @@ class Command(BaseCommand): if not identifier or not person or not user: continue - service = str( - getattr(msg, "source_service", "") or getattr(identifier, "service", "") - ).strip().lower() + service = ( + str( + getattr(msg, "source_service", "") + or getattr(identifier, "service", "") + ) + .strip() + .lower() + ) if not service: continue @@ -95,12 +97,16 @@ class Command(BaseCommand): "origin": "recalculate_contact_availability", "message_id": str(msg.id), "inferred_from": "read_receipt", - "read_by": str(getattr(msg, "read_by_identifier", "") or ""), + "read_by": str( + getattr(msg, "read_by_identifier", "") or "" + ), }, } ) - reactions = list((getattr(msg, "receipt_payload", {}) or {}).get("reactions") or []) + reactions = list( + (getattr(msg, "receipt_payload", {}) or {}).get("reactions") or [] + ) for reaction in reactions: item = dict(reaction or {}) if bool(item.get("removed")): @@ -124,7 +130,9 @@ class Command(BaseCommand): "inferred_from": "reaction", "emoji": str(item.get("emoji") or ""), "actor": str(item.get("actor") or ""), - "source_service": str(item.get("source_service") or service), + "source_service": str( + item.get("source_service") or service + ), }, } ) diff --git a/core/management/commands/reconcile_workspace_metric_history.py b/core/management/commands/reconcile_workspace_metric_history.py index 45227fa..37540c9 100644 --- a/core/management/commands/reconcile_workspace_metric_history.py +++ b/core/management/commands/reconcile_workspace_metric_history.py @@ -67,7 +67,9 @@ def _compute_payload(rows, identifier_values): pending_out_ts = None first_ts = int(rows[0]["ts"] or 0) last_ts = int(rows[-1]["ts"] or 0) - latest_service = str(rows[-1].get("session__identifier__service") or "").strip().lower() + latest_service = ( + str(rows[-1].get("session__identifier__service") or "").strip().lower() + ) for row in rows: ts = int(row.get("ts") or 0) @@ -162,18 +164,18 @@ def _compute_payload(rows, identifier_values): payload = { "source_event_ts": last_ts, "stability_state": stability_state, - "stability_score": float(stability_score_value) - if stability_score_value is not None - else None, + "stability_score": ( + float(stability_score_value) if stability_score_value is not None else None + ), "stability_confidence": round(confidence, 3), "stability_sample_messages": message_count, "stability_sample_days": sample_days, - "commitment_inbound_score": float(commitment_in_value) - if commitment_in_value is not None - else None, - "commitment_outbound_score": float(commitment_out_value) - if commitment_out_value is not None - else None, + "commitment_inbound_score": ( + float(commitment_in_value) if commitment_in_value is not None else None + ), + "commitment_outbound_score": ( + float(commitment_out_value) if commitment_out_value is not None else None + ), "commitment_confidence": round(confidence, 3), "inbound_messages": inbound_count, "outbound_messages": outbound_count, @@ -232,15 +234,17 @@ class Command(BaseCommand): dry_run = bool(options.get("dry_run")) reset = not bool(options.get("no_reset")) compact_enabled = not bool(options.get("skip_compact")) - today_start = dj_timezone.now().astimezone(timezone.utc).replace( - hour=0, - minute=0, - second=0, - microsecond=0, - ) - cutoff_ts = int( - (today_start.timestamp() * 1000) - (days * 24 * 60 * 60 * 1000) + today_start = ( + dj_timezone.now() + .astimezone(timezone.utc) + .replace( + hour=0, + minute=0, + second=0, + microsecond=0, + ) ) + cutoff_ts = int((today_start.timestamp() * 1000) - (days * 24 * 60 * 60 * 1000)) people_qs = Person.objects.all() if user_id: @@ -256,14 +260,18 @@ class Command(BaseCommand): compacted_deleted = 0 for person in people: - identifiers_qs = PersonIdentifier.objects.filter(user=person.user, person=person) + identifiers_qs = PersonIdentifier.objects.filter( + user=person.user, person=person + ) if service: identifiers_qs = identifiers_qs.filter(service=service) identifiers = list(identifiers_qs) if not identifiers: continue identifier_values = { - str(row.identifier or "").strip() for row in identifiers if row.identifier + str(row.identifier or "").strip() + for row in identifiers + if row.identifier } if not identifier_values: continue @@ -350,7 +358,9 @@ class Command(BaseCommand): snapshots_created += 1 if dry_run: continue - WorkspaceMetricSnapshot.objects.create(conversation=conversation, **payload) + WorkspaceMetricSnapshot.objects.create( + conversation=conversation, **payload + ) existing_signatures.add(signature) if not latest_payload: @@ -368,7 +378,9 @@ class Command(BaseCommand): "updated_at": dj_timezone.now().isoformat(), } if not dry_run: - conversation.platform_type = latest_service or conversation.platform_type + conversation.platform_type = ( + latest_service or conversation.platform_type + ) conversation.last_event_ts = latest_payload.get("source_event_ts") conversation.stability_state = str( latest_payload.get("stability_state") @@ -416,7 +428,9 @@ class Command(BaseCommand): ) if compact_enabled: snapshot_rows = list( - WorkspaceMetricSnapshot.objects.filter(conversation=conversation) + WorkspaceMetricSnapshot.objects.filter( + conversation=conversation + ) .order_by("computed_at", "id") .values("id", "computed_at", "source_event_ts") ) @@ -428,7 +442,9 @@ class Command(BaseCommand): ) if keep_ids: compacted_deleted += ( - WorkspaceMetricSnapshot.objects.filter(conversation=conversation) + WorkspaceMetricSnapshot.objects.filter( + conversation=conversation + ) .exclude(id__in=list(keep_ids)) .delete()[0] ) diff --git a/core/management/commands/task_sync_worker.py b/core/management/commands/task_sync_worker.py index df48011..dcfe443 100644 --- a/core/management/commands/task_sync_worker.py +++ b/core/management/commands/task_sync_worker.py @@ -4,4 +4,6 @@ from core.management.commands.codex_worker import Command as LegacyCodexWorkerCo class Command(LegacyCodexWorkerCommand): - help = "Process queued task-sync events for worker-backed providers (Codex + Claude)." + help = ( + "Process queued task-sync events for worker-backed providers (Codex + Claude)." + ) diff --git a/core/mcp/server.py b/core/mcp/server.py index 903198b..dc92c05 100644 --- a/core/mcp/server.py +++ b/core/mcp/server.py @@ -123,7 +123,9 @@ def _handle_message(message: dict[str, Any]) -> dict[str, Any] | None: msg_id, { "isError": True, - "content": [{"type": "text", "text": json.dumps({"error": str(exc)})}], + "content": [ + {"type": "text", "text": json.dumps({"error": str(exc)})} + ], }, ) diff --git a/core/mcp/tools.py b/core/mcp/tools.py index 54973cf..317a58a 100644 --- a/core/mcp/tools.py +++ b/core/mcp/tools.py @@ -216,7 +216,9 @@ def _next_unique_slug(*, user_id: int, requested_slug: str) -> str: raise ValueError("slug cannot be empty") candidate = base idx = 2 - while KnowledgeArticle.objects.filter(user_id=int(user_id), slug=candidate).exists(): + while KnowledgeArticle.objects.filter( + user_id=int(user_id), slug=candidate + ).exists(): suffix = f"-{idx}" candidate = f"{base[: max(1, 255 - len(suffix))]}{suffix}" idx += 1 @@ -645,9 +647,7 @@ def tool_wiki_update_article(arguments: dict[str, Any]) -> dict[str, Any]: ) if status_marker and status == "archived" and article.status != "archived": if not approve_archive: - raise ValueError( - "approve_archive=true is required to archive an article" - ) + raise ValueError("approve_archive=true is required to archive an article") if title: article.title = title @@ -705,7 +705,9 @@ def tool_wiki_list(arguments: dict[str, Any]) -> dict[str, Any]: def tool_wiki_get(arguments: dict[str, Any]) -> dict[str, Any]: article = _get_article_for_user(arguments) include_revisions = bool(arguments.get("include_revisions")) - revision_limit = _safe_limit(arguments.get("revision_limit"), default=20, low=1, high=200) + revision_limit = _safe_limit( + arguments.get("revision_limit"), default=20, low=1, high=200 + ) payload = {"article": _article_payload(article)} if include_revisions: revisions = article.revisions.order_by("-revision")[:revision_limit] @@ -714,7 +716,9 @@ def tool_wiki_get(arguments: dict[str, Any]) -> dict[str, Any]: def tool_project_get_guidelines(arguments: dict[str, Any]) -> dict[str, Any]: - max_chars = _safe_limit(arguments.get("max_chars"), default=16000, low=500, high=50000) + max_chars = _safe_limit( + arguments.get("max_chars"), default=16000, low=500, high=50000 + ) base = Path(settings.BASE_DIR).resolve() file_names = ["AGENTS.md", "LLM_CODING_STANDARDS.md", "INSTALL.md", "README.md"] payload = [] @@ -734,7 +738,9 @@ def tool_project_get_guidelines(arguments: dict[str, Any]) -> dict[str, Any]: def tool_project_get_layout(arguments: dict[str, Any]) -> dict[str, Any]: - max_entries = _safe_limit(arguments.get("max_entries"), default=300, low=50, high=4000) + max_entries = _safe_limit( + arguments.get("max_entries"), default=300, low=50, high=4000 + ) base = Path(settings.BASE_DIR).resolve() roots = ["app", "core", "scripts", "utilities", "artifacts"] items: list[str] = [] @@ -754,7 +760,9 @@ def tool_project_get_layout(arguments: dict[str, Any]) -> dict[str, Any]: def tool_project_get_runbook(arguments: dict[str, Any]) -> dict[str, Any]: - max_chars = _safe_limit(arguments.get("max_chars"), default=16000, low=500, high=50000) + max_chars = _safe_limit( + arguments.get("max_chars"), default=16000, low=500, high=50000 + ) base = Path(settings.BASE_DIR).resolve() file_names = [ "INSTALL.md", @@ -792,7 +800,11 @@ def tool_docs_append_run_note(arguments: dict[str, Any]) -> dict[str, Any]: path = Path("/tmp/gia-mcp-run-notes.md") else: candidate = Path(raw_path) - path = candidate.resolve() if candidate.is_absolute() else (base / candidate).resolve() + path = ( + candidate.resolve() + if candidate.is_absolute() + else (base / candidate).resolve() + ) allowed_roots = [base, Path("/tmp").resolve()] if not any(str(path).startswith(str(root)) for root in allowed_roots): raise ValueError("path must be within project root or /tmp") @@ -812,7 +824,11 @@ def tool_docs_append_run_note(arguments: dict[str, Any]) -> dict[str, Any]: TOOL_DEFS: dict[str, dict[str, Any]] = { "manticore.status": { "description": "Report configured memory backend status (django or manticore).", - "inputSchema": {"type": "object", "properties": {}, "additionalProperties": False}, + "inputSchema": { + "type": "object", + "properties": {}, + "additionalProperties": False, + }, "handler": tool_manticore_status, }, "manticore.query": { diff --git a/core/memory/__init__.py b/core/memory/__init__.py index e4f2781..4897c75 100644 --- a/core/memory/__init__.py +++ b/core/memory/__init__.py @@ -1,4 +1,4 @@ -from .search_backend import get_memory_search_backend from .retrieval import retrieve_memories_for_prompt +from .search_backend import get_memory_search_backend __all__ = ["get_memory_search_backend", "retrieve_memories_for_prompt"] diff --git a/core/memory/pipeline.py b/core/memory/pipeline.py index b0b2c3b..2624aba 100644 --- a/core/memory/pipeline.py +++ b/core/memory/pipeline.py @@ -224,7 +224,9 @@ def create_memory_change_request( person_id=person_id or (str(memory.person_id or "") if memory else "") or None, action=normalized_action, status="pending", - proposed_memory_kind=str(memory_kind or (memory.memory_kind if memory else "")).strip(), + proposed_memory_kind=str( + memory_kind or (memory.memory_kind if memory else "") + ).strip(), proposed_content=dict(content or {}), proposed_confidence_score=( float(confidence_score) @@ -335,7 +337,9 @@ def review_memory_change_request( @transaction.atomic -def run_memory_hygiene(*, user_id: int | None = None, dry_run: bool = False) -> dict[str, int]: +def run_memory_hygiene( + *, user_id: int | None = None, dry_run: bool = False +) -> dict[str, int]: now = timezone.now() queryset = MemoryItem.objects.filter(status="active") if user_id is not None: @@ -357,7 +361,9 @@ def run_memory_hygiene(*, user_id: int | None = None, dry_run: bool = False) -> for item in queryset.select_related("conversation", "person"): content = item.content or {} field = str(content.get("field") or content.get("key") or "").strip().lower() - text = _clean_value(str(content.get("text") or content.get("value") or "")).lower() + text = _clean_value( + str(content.get("text") or content.get("value") or "") + ).lower() if not field or not text: continue scope = ( diff --git a/core/memory/retrieval.py b/core/memory/retrieval.py index 8f19d86..31e192d 100644 --- a/core/memory/retrieval.py +++ b/core/memory/retrieval.py @@ -59,7 +59,11 @@ def retrieve_memories_for_prompt( limit=safe_limit, include_statuses=statuses, ) - ids = [str(hit.memory_id or "").strip() for hit in hits if str(hit.memory_id or "").strip()] + ids = [ + str(hit.memory_id or "").strip() + for hit in hits + if str(hit.memory_id or "").strip() + ] scoped = _base_queryset( user_id=int(user_id), person_id=person_id, @@ -82,11 +86,17 @@ def retrieve_memories_for_prompt( "content": item.content or {}, "provenance": item.provenance or {}, "confidence_score": float(item.confidence_score or 0.0), - "expires_at": item.expires_at.isoformat() if item.expires_at else "", - "last_verified_at": ( - item.last_verified_at.isoformat() if item.last_verified_at else "" + "expires_at": ( + item.expires_at.isoformat() if item.expires_at else "" + ), + "last_verified_at": ( + item.last_verified_at.isoformat() + if item.last_verified_at + else "" + ), + "updated_at": ( + item.updated_at.isoformat() if item.updated_at else "" ), - "updated_at": item.updated_at.isoformat() if item.updated_at else "", "search_score": float(hit.score or 0.0), "search_summary": str(hit.summary or ""), } diff --git a/core/memory/search_backend.py b/core/memory/search_backend.py index 5f34ac1..1732dda 100644 --- a/core/memory/search_backend.py +++ b/core/memory/search_backend.py @@ -1,7 +1,6 @@ from __future__ import annotations import hashlib -import json import time from dataclasses import dataclass from typing import Any @@ -144,9 +143,10 @@ class ManticoreMemorySearchBackend(BaseMemorySearchBackend): self.base_url = str( getattr(settings, "MANTICORE_HTTP_URL", "http://localhost:9308") ).rstrip("/") - self.table = str( - getattr(settings, "MANTICORE_MEMORY_TABLE", "gia_memory_items") - ).strip() or "gia_memory_items" + self.table = ( + str(getattr(settings, "MANTICORE_MEMORY_TABLE", "gia_memory_items")).strip() + or "gia_memory_items" + ) self.timeout_seconds = int(getattr(settings, "MANTICORE_HTTP_TIMEOUT", 5) or 5) self._table_cache_key = f"{self.base_url}|{self.table}" @@ -163,7 +163,9 @@ class ManticoreMemorySearchBackend(BaseMemorySearchBackend): return dict(payload or {}) def ensure_table(self) -> None: - last_ready = float(self._table_ready_cache.get(self._table_cache_key, 0.0) or 0.0) + last_ready = float( + self._table_ready_cache.get(self._table_cache_key, 0.0) or 0.0 + ) if (time.time() - last_ready) <= float(self._table_ready_ttl_seconds): return self._sql( @@ -254,7 +256,9 @@ class ManticoreMemorySearchBackend(BaseMemorySearchBackend): try: values.append(self._build_upsert_values_clause(item)) except Exception as exc: - log.warning("memory-search upsert build failed id=%s err=%s", item.id, exc) + log.warning( + "memory-search upsert build failed id=%s err=%s", item.id, exc + ) continue if len(values) >= batch_size: self._sql( @@ -290,7 +294,11 @@ class ManticoreMemorySearchBackend(BaseMemorySearchBackend): where_parts = [f"user_id={int(user_id)}", f"MATCH('{self._escape(needle)}')"] if conversation_id: where_parts.append(f"conversation_id='{self._escape(conversation_id)}'") - statuses = [str(item or "").strip() for item in include_statuses if str(item or "").strip()] + statuses = [ + str(item or "").strip() + for item in include_statuses + if str(item or "").strip() + ] if statuses: in_clause = ",".join(f"'{self._escape(item)}'" for item in statuses) where_parts.append(f"status IN ({in_clause})") diff --git a/core/messaging/history.py b/core/messaging/history.py index e1906e3..fd5ad55 100644 --- a/core/messaging/history.py +++ b/core/messaging/history.py @@ -1,12 +1,13 @@ -from asgiref.sync import sync_to_async -from django.conf import settings import time import uuid +from asgiref.sync import sync_to_async +from django.conf import settings + from core.events.ledger import append_event from core.messaging.utils import messages_to_string -from core.observability.tracing import ensure_trace_id from core.models import ChatSession, Message, QueuedMessage +from core.observability.tracing import ensure_trace_id from core.util import logs log = logs.get_logger("history") @@ -272,7 +273,9 @@ async def store_own_message( trace_id=ensure_trace_id(trace_id, message_meta or {}), ) except Exception as exc: - log.warning("Event ledger append failed for own message=%s: %s", msg.id, exc) + log.warning( + "Event ledger append failed for own message=%s: %s", msg.id, exc + ) return msg diff --git a/core/messaging/reply_sync.py b/core/messaging/reply_sync.py index 99473fe..98e3215 100644 --- a/core/messaging/reply_sync.py +++ b/core/messaging/reply_sync.py @@ -335,8 +335,12 @@ def extract_reply_ref(service: str, raw_payload: dict[str, Any]) -> dict[str, st svc = _clean(service).lower() payload = _as_dict(raw_payload) if svc == "xmpp": - reply_id = _clean(payload.get("reply_source_message_id") or payload.get("reply_id")) - reply_chat = _clean(payload.get("reply_source_chat_id") or payload.get("reply_chat_id")) + reply_id = _clean( + payload.get("reply_source_message_id") or payload.get("reply_id") + ) + reply_chat = _clean( + payload.get("reply_source_chat_id") or payload.get("reply_chat_id") + ) if reply_id: return { "reply_source_message_id": reply_id, @@ -363,7 +367,9 @@ def extract_origin_tag(raw_payload: dict[str, Any] | None) -> str: return _find_origin_tag(_as_dict(raw_payload)) -async def resolve_reply_target(user, session, reply_ref: dict[str, str]) -> Message | None: +async def resolve_reply_target( + user, session, reply_ref: dict[str, str] +) -> Message | None: if not reply_ref or session is None: return None reply_source_message_id = _clean(reply_ref.get("reply_source_message_id")) diff --git a/core/migrations/0029_answermemory_answersuggestionevent_chattasksource_and_more.py b/core/migrations/0029_answermemory_answersuggestionevent_chattasksource_and_more.py index e38ae5e..46cd34d 100644 --- a/core/migrations/0029_answermemory_answersuggestionevent_chattasksource_and_more.py +++ b/core/migrations/0029_answermemory_answersuggestionevent_chattasksource_and_more.py @@ -1,7 +1,8 @@ # Generated by Django 5.2.11 on 2026-03-02 11:55 -import django.db.models.deletion import uuid + +import django.db.models.deletion from django.conf import settings from django.db import migrations, models diff --git a/core/migrations/0033_contactavailability_and_externalchatlink.py b/core/migrations/0033_contactavailability_and_externalchatlink.py index 93eefc1..aeb5dc0 100644 --- a/core/migrations/0033_contactavailability_and_externalchatlink.py +++ b/core/migrations/0033_contactavailability_and_externalchatlink.py @@ -1,6 +1,6 @@ +import django.db.models.deletion from django.conf import settings from django.db import migrations, models -import django.db.models.deletion class Migration(migrations.Migration): diff --git a/core/migrations/0034_codexrun_codexpermissionrequest_and_more.py b/core/migrations/0034_codexrun_codexpermissionrequest_and_more.py index 49c74aa..a242da2 100644 --- a/core/migrations/0034_codexrun_codexpermissionrequest_and_more.py +++ b/core/migrations/0034_codexrun_codexpermissionrequest_and_more.py @@ -1,8 +1,8 @@ import uuid +import django.db.models.deletion from django.conf import settings from django.db import migrations, models -import django.db.models.deletion class Migration(migrations.Migration): diff --git a/core/migrations/0035_conversationevent_adapterhealthevent.py b/core/migrations/0035_conversationevent_adapterhealthevent.py index 1ed6429..a393676 100644 --- a/core/migrations/0035_conversationevent_adapterhealthevent.py +++ b/core/migrations/0035_conversationevent_adapterhealthevent.py @@ -1,8 +1,8 @@ import uuid +import django.db.models.deletion from django.conf import settings from django.db import migrations, models -import django.db.models.deletion class Migration(migrations.Migration): diff --git a/core/migrations/0040_commandsecuritypolicy_gatewaycommandevent.py b/core/migrations/0040_commandsecuritypolicy_gatewaycommandevent.py index a0b273f..9da74e7 100644 --- a/core/migrations/0040_commandsecuritypolicy_gatewaycommandevent.py +++ b/core/migrations/0040_commandsecuritypolicy_gatewaycommandevent.py @@ -1,8 +1,8 @@ # Generated by Django 4.2.19 on 2026-03-07 00:00 +import django.db.models.deletion from django.conf import settings from django.db import migrations, models -import django.db.models.deletion class Migration(migrations.Migration): diff --git a/core/migrations/0042_userxmppomemotrustedkey_and_more.py b/core/migrations/0042_userxmppomemotrustedkey_and_more.py new file mode 100644 index 0000000..4d80b18 --- /dev/null +++ b/core/migrations/0042_userxmppomemotrustedkey_and_more.py @@ -0,0 +1,39 @@ +# Generated by Django 5.2.7 on 2026-03-07 20:12 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('core', '0041_useraccessibilitysettings'), + ] + + operations = [ + migrations.CreateModel( + name='UserXmppOmemoTrustedKey', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('jid', models.CharField(blank=True, default='', max_length=255)), + ('key_type', models.CharField(choices=[('fingerprint', 'Fingerprint'), ('client_key', 'Client key')], default='fingerprint', max_length=32)), + ('key_id', models.CharField(max_length=255)), + ('trusted', models.BooleanField(default=False)), + ('source', models.CharField(blank=True, default='', max_length=64)), + ('last_seen_at', models.DateTimeField(blank=True, null=True)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='xmpp_omemo_trusted_keys', to=settings.AUTH_USER_MODEL)), + ], + options={ + 'indexes': [ + models.Index(fields=['user', 'trusted', 'updated_at'], name='core_userxomemo_trusted_idx'), + models.Index(fields=['user', 'jid', 'updated_at'], name='core_userxomemo_jid_idx'), + ], + 'constraints': [ + models.UniqueConstraint(fields=('user', 'jid', 'key_type', 'key_id'), name='unique_user_xmpp_omemo_trusted_key'), + ], + }, + ), + ] diff --git a/core/migrations/0043_userxmppsecuritysettings_encrypt_contact_messages_with_omemo.py b/core/migrations/0043_userxmppsecuritysettings_encrypt_contact_messages_with_omemo.py new file mode 100644 index 0000000..3b1c11a --- /dev/null +++ b/core/migrations/0043_userxmppsecuritysettings_encrypt_contact_messages_with_omemo.py @@ -0,0 +1,18 @@ +# Generated by Django 5.2.7 on 2026-03-07 20:23 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("core", "0042_userxmppomemotrustedkey_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="userxmppsecuritysettings", + name="encrypt_contact_messages_with_omemo", + field=models.BooleanField(default=True), + ), + ] diff --git a/core/models.py b/core/models.py index 8ba8a4f..e8672c9 100644 --- a/core/models.py +++ b/core/models.py @@ -20,14 +20,14 @@ SERVICE_CHOICES = ( ) CHANNEL_SERVICE_CHOICES = SERVICE_CHOICES + (("web", "Web"),) MBTI_CHOICES = ( - ("INTJ", "INTJ - Architect"),# ;) + ("INTJ", "INTJ - Architect"), # ;) ("INTP", "INTP - Logician"), ("ENTJ", "ENTJ - Commander"), ("ENTP", "ENTP - Debater"), ("INFJ", "INFJ - Advocate"), ("INFP", "INFP - Mediator"), ("ENFJ", "ENFJ - Protagonist"), - ("ENFP", "ENFP - Campaigner"), # <3 + ("ENFP", "ENFP - Campaigner"), # <3 ("ISTJ", "ISTJ - Logistician"), ("ISFJ", "ISFJ - Defender"), ("ESTJ", "ESTJ - Executive"), @@ -241,17 +241,13 @@ class PlatformChatLink(models.Model): raise ValidationError("Person must belong to the same user.") if self.person_identifier_id: if self.person_identifier.user_id != self.user_id: - raise ValidationError( - "Person identifier must belong to the same user." - ) + raise ValidationError("Person identifier must belong to the same user.") if self.person_identifier.person_id != self.person_id: raise ValidationError( "Person identifier must belong to the selected person." ) if self.person_identifier.service != self.service: - raise ValidationError( - "Chat links cannot be linked across platforms." - ) + raise ValidationError("Chat links cannot be linked across platforms.") def save(self, *args, **kwargs): value = str(self.chat_identifier or "").strip() @@ -1869,9 +1865,7 @@ class PatternArtifactExport(models.Model): class CommandProfile(models.Model): - WINDOW_SCOPE_CHOICES = ( - ("conversation", "Conversation"), - ) + WINDOW_SCOPE_CHOICES = (("conversation", "Conversation"),) VISIBILITY_CHOICES = ( ("status_in_source", "Status In Source"), ("silent", "Silent"), @@ -2039,7 +2033,9 @@ class BusinessPlanDocument(models.Model): class Meta: indexes = [ models.Index(fields=["user", "status", "updated_at"]), - models.Index(fields=["user", "source_service", "source_channel_identifier"]), + models.Index( + fields=["user", "source_service", "source_channel_identifier"] + ), ] @@ -2243,7 +2239,9 @@ class TranslationEventLog(models.Model): class AnswerMemory(models.Model): id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) - user = models.ForeignKey(User, on_delete=models.CASCADE, related_name="answer_memory") + user = models.ForeignKey( + User, on_delete=models.CASCADE, related_name="answer_memory" + ) service = models.CharField(max_length=255, choices=CHANNEL_SERVICE_CHOICES) channel_identifier = models.CharField(max_length=255) question_fingerprint = models.CharField(max_length=128) @@ -2261,7 +2259,9 @@ class AnswerMemory(models.Model): class Meta: indexes = [ - models.Index(fields=["user", "service", "channel_identifier", "created_at"]), + models.Index( + fields=["user", "service", "channel_identifier", "created_at"] + ), models.Index(fields=["user", "question_fingerprint", "created_at"]), ] @@ -2284,7 +2284,9 @@ class AnswerSuggestionEvent(models.Model): on_delete=models.CASCADE, related_name="answer_suggestion_events", ) - status = models.CharField(max_length=32, choices=STATUS_CHOICES, default="suggested") + status = models.CharField( + max_length=32, choices=STATUS_CHOICES, default="suggested" + ) candidate_answer = models.ForeignKey( AnswerMemory, on_delete=models.SET_NULL, @@ -2305,7 +2307,9 @@ class AnswerSuggestionEvent(models.Model): class TaskProject(models.Model): id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) - user = models.ForeignKey(User, on_delete=models.CASCADE, related_name="task_projects") + user = models.ForeignKey( + User, on_delete=models.CASCADE, related_name="task_projects" + ) name = models.CharField(max_length=255) external_key = models.CharField(max_length=255, blank=True, default="") active = models.BooleanField(default=True) @@ -2349,7 +2353,9 @@ class TaskEpic(models.Model): class ChatTaskSource(models.Model): id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) - user = models.ForeignKey(User, on_delete=models.CASCADE, related_name="chat_task_sources") + user = models.ForeignKey( + User, on_delete=models.CASCADE, related_name="chat_task_sources" + ) service = models.CharField(max_length=255, choices=CHANNEL_SERVICE_CHOICES) channel_identifier = models.CharField(max_length=255) project = models.ForeignKey( @@ -2378,7 +2384,9 @@ class ChatTaskSource(models.Model): class DerivedTask(models.Model): id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) - user = models.ForeignKey(User, on_delete=models.CASCADE, related_name="derived_tasks") + user = models.ForeignKey( + User, on_delete=models.CASCADE, related_name="derived_tasks" + ) project = models.ForeignKey( TaskProject, on_delete=models.CASCADE, @@ -2574,7 +2582,9 @@ class ExternalSyncEvent(models.Model): ) id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) - user = models.ForeignKey(User, on_delete=models.CASCADE, related_name="external_sync_events") + user = models.ForeignKey( + User, on_delete=models.CASCADE, related_name="external_sync_events" + ) task = models.ForeignKey( DerivedTask, on_delete=models.SET_NULL, @@ -2606,7 +2616,9 @@ class ExternalSyncEvent(models.Model): class TaskProviderConfig(models.Model): id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) - user = models.ForeignKey(User, on_delete=models.CASCADE, related_name="task_provider_configs") + user = models.ForeignKey( + User, on_delete=models.CASCADE, related_name="task_provider_configs" + ) provider = models.CharField(max_length=64, default="mock") enabled = models.BooleanField(default=False) settings = models.JSONField(default=dict, blank=True) @@ -2684,7 +2696,9 @@ class CodexRun(models.Model): class Meta: indexes = [ models.Index(fields=["user", "status", "updated_at"]), - models.Index(fields=["user", "source_service", "source_channel", "created_at"]), + models.Index( + fields=["user", "source_service", "source_channel", "created_at"] + ), ] @@ -2697,7 +2711,9 @@ class CodexPermissionRequest(models.Model): ) id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) - user = models.ForeignKey(User, on_delete=models.CASCADE, related_name="codex_permission_requests") + user = models.ForeignKey( + User, on_delete=models.CASCADE, related_name="codex_permission_requests" + ) codex_run = models.ForeignKey( CodexRun, on_delete=models.CASCADE, @@ -2910,7 +2926,49 @@ class UserXmppOmemoState(models.Model): class Meta: indexes = [ - models.Index(fields=["status", "updated_at"], name="core_userxm_status_133ead_idx"), + models.Index( + fields=["status", "updated_at"], name="core_userxm_status_133ead_idx" + ), + ] + + +class UserXmppOmemoTrustedKey(models.Model): + KEY_TYPE_CHOICES = ( + ("fingerprint", "Fingerprint"), + ("client_key", "Client key"), + ) + + user = models.ForeignKey( + User, + on_delete=models.CASCADE, + related_name="xmpp_omemo_trusted_keys", + ) + jid = models.CharField(max_length=255, blank=True, default="") + key_type = models.CharField( + max_length=32, choices=KEY_TYPE_CHOICES, default="fingerprint" + ) + key_id = models.CharField(max_length=255) + trusted = models.BooleanField(default=False) + source = models.CharField(max_length=64, blank=True, default="") + last_seen_at = models.DateTimeField(blank=True, null=True) + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + class Meta: + constraints = [ + models.UniqueConstraint( + fields=["user", "jid", "key_type", "key_id"], + name="unique_user_xmpp_omemo_trusted_key", + ), + ] + indexes = [ + models.Index( + fields=["user", "trusted", "updated_at"], + name="core_userxomemo_trusted_idx", + ), + models.Index( + fields=["user", "jid", "updated_at"], name="core_userxomemo_jid_idx" + ), ] @@ -2921,6 +2979,7 @@ class UserXmppSecuritySettings(models.Model): related_name="xmpp_security_settings", ) require_omemo = models.BooleanField(default=False) + encrypt_contact_messages_with_omemo = models.BooleanField(default=True) created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) @@ -2938,7 +2997,9 @@ class UserAccessibilitySettings(models.Model): class TaskCompletionPattern(models.Model): id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) - user = models.ForeignKey(User, on_delete=models.CASCADE, related_name="task_completion_patterns") + user = models.ForeignKey( + User, on_delete=models.CASCADE, related_name="task_completion_patterns" + ) phrase = models.CharField(max_length=64) enabled = models.BooleanField(default=True) position = models.PositiveIntegerField(default=0) diff --git a/core/modules/router.py b/core/modules/router.py index 320cef4..ea7988a 100644 --- a/core/modules/router.py +++ b/core/modules/router.py @@ -4,22 +4,22 @@ import re from asgiref.sync import sync_to_async from django.conf import settings +from core.assist.engine import process_inbound_assist from core.clients import transport -from core.events import event_ledger_status from core.clients.instagram import InstagramClient from core.clients.signal import SignalClient from core.clients.whatsapp import WhatsAppClient from core.clients.xmpp import XMPPClient -from core.assist.engine import process_inbound_assist from core.commands.base import CommandContext from core.commands.engine import process_inbound_message +from core.events import event_ledger_status from core.messaging import history from core.models import PersonIdentifier +from core.observability.tracing import ensure_trace_id from core.presence import AvailabilitySignal, record_native_signal from core.realtime.typing_state import set_person_typing_state from core.translation.engine import process_inbound_translation from core.util import logs -from core.observability.tracing import ensure_trace_id class UnifiedRouter(object): @@ -119,7 +119,9 @@ class UnifiedRouter(object): return identifiers = await self._resolve_identifier_objects(protocol, identifier) if identifiers: - outgoing = str(getattr(local_message, "custom_author", "") or "").strip().upper() in { + outgoing = str( + getattr(local_message, "custom_author", "") or "" + ).strip().upper() in { "USER", "BOT", } @@ -268,7 +270,9 @@ class UnifiedRouter(object): ts=int(read_ts or 0), payload={ "origin": "router.message_read", - "message_timestamps": [int(v) for v in list(timestamps or []) if str(v).isdigit()], + "message_timestamps": [ + int(v) for v in list(timestamps or []) if str(v).isdigit() + ], "read_by": str(read_by or row.identifier), }, ) diff --git a/core/presence/engine.py b/core/presence/engine.py index 5d51e21..d32ea97 100644 --- a/core/presence/engine.py +++ b/core/presence/engine.py @@ -12,9 +12,15 @@ from core.models import ( PersonIdentifier, User, ) + from .inference import fade_confidence, now_ms, should_fade -POSITIVE_SOURCE_KINDS = {"native_presence", "read_receipt", "typing_start", "message_in"} +POSITIVE_SOURCE_KINDS = { + "native_presence", + "read_receipt", + "typing_start", + "message_in", +} @dataclass(slots=True) @@ -99,7 +105,8 @@ def record_native_signal(signal: AvailabilitySignal) -> ContactAvailabilityEvent person_identifier=signal.person_identifier, service=str(signal.service or "").strip().lower() or "signal", source_kind=str(signal.source_kind or "").strip() or "native_presence", - availability_state=str(signal.availability_state or "unknown").strip() or "unknown", + availability_state=str(signal.availability_state or "unknown").strip() + or "unknown", confidence=float(signal.confidence or 0.0), ts=_normalize_ts(signal.ts), payload=dict(signal.payload or {}), @@ -109,7 +116,9 @@ def record_native_signal(signal: AvailabilitySignal) -> ContactAvailabilityEvent return event -def record_inferred_signal(signal: AvailabilitySignal) -> ContactAvailabilityEvent | None: +def record_inferred_signal( + signal: AvailabilitySignal, +) -> ContactAvailabilityEvent | None: settings_row = get_settings(signal.user) if not settings_row.enabled or not settings_row.inference_enabled: return None @@ -151,7 +160,9 @@ def ensure_fading_state( return None if latest.source_kind not in POSITIVE_SOURCE_KINDS: return None - if not should_fade(int(latest.ts or 0), current_ts, settings_row.fade_threshold_seconds): + if not should_fade( + int(latest.ts or 0), current_ts, settings_row.fade_threshold_seconds + ): return None elapsed = max(0, current_ts - int(latest.ts or 0)) diff --git a/core/presence/query.py b/core/presence/query.py index 82f548c..06d1217 100644 --- a/core/presence/query.py +++ b/core/presence/query.py @@ -3,6 +3,7 @@ from __future__ import annotations from django.db.models import Q from core.models import ContactAvailabilityEvent, ContactAvailabilitySpan, Person, User + from .engine import ensure_fading_state from .inference import now_ms @@ -19,9 +20,7 @@ def spans_for_range( qs = ContactAvailabilitySpan.objects.filter( user=user, person=person, - ).filter( - Q(start_ts__lte=end_ts) & Q(end_ts__gte=start_ts) - ) + ).filter(Q(start_ts__lte=end_ts) & Q(end_ts__gte=start_ts)) if service: qs = qs.filter(service=str(service).strip().lower()) diff --git a/core/security/__init__.py b/core/security/__init__.py index ce85308..d6f58b8 100644 --- a/core/security/__init__.py +++ b/core/security/__init__.py @@ -1,2 +1 @@ """Security helpers shared across transport adapters.""" - diff --git a/core/security/attachments.py b/core/security/attachments.py index 3fe24e7..861f4ad 100644 --- a/core/security/attachments.py +++ b/core/security/attachments.py @@ -101,7 +101,9 @@ def validate_attachment_metadata( raise ValueError(f"blocked_mime_type:{normalized_type}") allow_unmatched = bool(getattr(settings, "ATTACHMENT_ALLOW_UNKNOWN_MIME", False)) - if not any(fnmatch(normalized_type, pattern) for pattern in _allowed_mime_patterns()): + if not any( + fnmatch(normalized_type, pattern) for pattern in _allowed_mime_patterns() + ): if not allow_unmatched: raise ValueError(f"unsupported_mime_type:{normalized_type}") diff --git a/core/security/command_policy.py b/core/security/command_policy.py index cdccf12..f3e0d6f 100644 --- a/core/security/command_policy.py +++ b/core/security/command_policy.py @@ -68,15 +68,13 @@ def _omemo_facts(ctx: CommandSecurityContext) -> tuple[str, str]: message_meta = dict(ctx.message_meta or {}) payload = dict(ctx.payload or {}) xmpp_meta = dict(message_meta.get("xmpp") or {}) - status = str( - xmpp_meta.get("omemo_status") - or payload.get("omemo_status") - or "" - ).strip().lower() + status = ( + str(xmpp_meta.get("omemo_status") or payload.get("omemo_status") or "") + .strip() + .lower() + ) client_key = str( - xmpp_meta.get("omemo_client_key") - or payload.get("omemo_client_key") - or "" + xmpp_meta.get("omemo_client_key") or payload.get("omemo_client_key") or "" ).strip() return status, client_key @@ -160,7 +158,8 @@ def evaluate_command_policy( service = _normalize_service(context.service) channel = _normalize_channel(context.channel_identifier) allowed_services = [ - item.lower() for item in _normalize_list(getattr(policy, "allowed_services", [])) + item.lower() + for item in _normalize_list(getattr(policy, "allowed_services", [])) ] global_allowed_services = [ item.lower() diff --git a/core/tasks/chat_defaults.py b/core/tasks/chat_defaults.py index 61cad49..f84440d 100644 --- a/core/tasks/chat_defaults.py +++ b/core/tasks/chat_defaults.py @@ -83,7 +83,9 @@ def ensure_default_source_for_chat( message=None, ): service_key = str(service or "").strip().lower() - normalized_identifier = normalize_channel_identifier(service_key, channel_identifier) + normalized_identifier = normalize_channel_identifier( + service_key, channel_identifier + ) variants = channel_variants(service_key, normalized_identifier) if not service_key or not variants: return None diff --git a/core/tasks/codex_approval.py b/core/tasks/codex_approval.py index fb142f8..dcd6bbd 100644 --- a/core/tasks/codex_approval.py +++ b/core/tasks/codex_approval.py @@ -72,9 +72,13 @@ def queue_codex_event_with_pre_approval( }, ) - cfg = TaskProviderConfig.objects.filter(user=user, provider=provider, enabled=True).first() + cfg = TaskProviderConfig.objects.filter( + user=user, provider=provider, enabled=True + ).first() settings_payload = dict(getattr(cfg, "settings", {}) or {}) - approver_service = str(settings_payload.get("approver_service") or "").strip().lower() + approver_service = ( + str(settings_payload.get("approver_service") or "").strip().lower() + ) approver_identifier = str(settings_payload.get("approver_identifier") or "").strip() if approver_service and approver_identifier: try: diff --git a/core/tasks/codex_support.py b/core/tasks/codex_support.py index 3d1f850..18cbc38 100644 --- a/core/tasks/codex_support.py +++ b/core/tasks/codex_support.py @@ -57,7 +57,9 @@ def resolve_external_chat_id(*, user, provider: str, service: str, channel: str) provider=provider, enabled=True, ) - .filter(Q(person_identifier=person_identifier) | Q(person=person_identifier.person)) + .filter( + Q(person_identifier=person_identifier) | Q(person=person_identifier.person) + ) .order_by("-updated_at", "-id") .first() ) diff --git a/core/tasks/engine.py b/core/tasks/engine.py index 200ad58..99b03ba 100644 --- a/core/tasks/engine.py +++ b/core/tasks/engine.py @@ -22,16 +22,23 @@ from core.models import ( TaskEpic, TaskProviderConfig, ) -from core.tasks.chat_defaults import ensure_default_source_for_chat, resolve_message_scope -from core.tasks.codex_approval import queue_codex_event_with_pre_approval -from core.tasks.providers import get_provider -from core.tasks.codex_support import resolve_external_chat_id from core.security.command_policy import CommandSecurityContext, evaluate_command_policy +from core.tasks.chat_defaults import ( + ensure_default_source_for_chat, + resolve_message_scope, +) +from core.tasks.codex_approval import queue_codex_event_with_pre_approval +from core.tasks.codex_support import resolve_external_chat_id +from core.tasks.providers import get_provider _TASK_HINT_RE = re.compile(r"\b(todo|task|action|need to|please)\b", re.IGNORECASE) -_COMPLETION_RE = re.compile(r"\b(done|completed|fixed)\s*#([A-Za-z0-9_-]+)\b", re.IGNORECASE) +_COMPLETION_RE = re.compile( + r"\b(done|completed|fixed)\s*#([A-Za-z0-9_-]+)\b", re.IGNORECASE +) _BALANCED_HINT_RE = re.compile(r"\b(todo|task|action item|action)\b", re.IGNORECASE) -_BROAD_HINT_RE = re.compile(r"\b(todo|task|action|need to|please|reminder)\b", re.IGNORECASE) +_BROAD_HINT_RE = re.compile( + r"\b(todo|task|action|need to|please|reminder)\b", re.IGNORECASE +) _PREFIX_HEAD_TRIM = " \t\r\n`'\"([{<*#-–—_>.,:;!/?\\|" _LIST_TASKS_RE = re.compile( r"^\s*(?:\.l(?:\s+list(?:\s+tasks?)?)?|\.list(?:\s+tasks?)?)\s*$", @@ -151,15 +158,23 @@ async def _resolve_source_mappings(message: Message) -> list[ChatTaskSource]: lookup_service = str(message.source_service or "").strip().lower() variants = _channel_variants(lookup_service, message.source_chat_id or "") session_identifier = getattr(getattr(message, "session", None), "identifier", None) - canonical_service = str(getattr(session_identifier, "service", "") or "").strip().lower() - canonical_identifier = str(getattr(session_identifier, "identifier", "") or "").strip() + canonical_service = ( + str(getattr(session_identifier, "service", "") or "").strip().lower() + ) + canonical_identifier = str( + getattr(session_identifier, "identifier", "") or "" + ).strip() if lookup_service == "web" and canonical_service and canonical_service != "web": lookup_service = canonical_service variants = _channel_variants(lookup_service, message.source_chat_id or "") for expanded in _channel_variants(lookup_service, canonical_identifier): if expanded and expanded not in variants: variants.append(expanded) - elif canonical_service and canonical_identifier and canonical_service == lookup_service: + elif ( + canonical_service + and canonical_identifier + and canonical_service == lookup_service + ): for expanded in _channel_variants(canonical_service, canonical_identifier): if expanded and expanded not in variants: variants.append(expanded) @@ -170,10 +185,14 @@ async def _resolve_source_mappings(message: Message) -> list[ChatTaskSource]: if not signal_value: continue companions += await sync_to_async(list)( - Chat.objects.filter(source_uuid=signal_value).values_list("source_number", flat=True) + Chat.objects.filter(source_uuid=signal_value).values_list( + "source_number", flat=True + ) ) companions += await sync_to_async(list)( - Chat.objects.filter(source_number=signal_value).values_list("source_uuid", flat=True) + Chat.objects.filter(source_number=signal_value).values_list( + "source_uuid", flat=True + ) ) for candidate in companions: for expanded in _channel_variants("signal", str(candidate or "").strip()): @@ -271,7 +290,8 @@ def _normalize_flags(raw: dict | None) -> dict: row = dict(raw or {}) return { "derive_enabled": _to_bool(row.get("derive_enabled"), True), - "match_mode": str(row.get("match_mode") or "balanced").strip().lower() or "balanced", + "match_mode": str(row.get("match_mode") or "balanced").strip().lower() + or "balanced", "require_prefix": _to_bool(row.get("require_prefix"), False), "allowed_prefixes": _parse_prefixes(row.get("allowed_prefixes")), "completion_enabled": _to_bool(row.get("completion_enabled"), True), @@ -287,7 +307,9 @@ def _normalize_partial_flags(raw: dict | None) -> dict: if "derive_enabled" in row: out["derive_enabled"] = _to_bool(row.get("derive_enabled"), True) if "match_mode" in row: - out["match_mode"] = str(row.get("match_mode") or "balanced").strip().lower() or "balanced" + out["match_mode"] = ( + str(row.get("match_mode") or "balanced").strip().lower() or "balanced" + ) if "require_prefix" in row: out["require_prefix"] = _to_bool(row.get("require_prefix"), False) if "allowed_prefixes" in row: @@ -304,7 +326,9 @@ def _normalize_partial_flags(raw: dict | None) -> dict: def _effective_flags(source: ChatTaskSource) -> dict: - project_flags = _normalize_flags(getattr(getattr(source, "project", None), "settings", {}) or {}) + project_flags = _normalize_flags( + getattr(getattr(source, "project", None), "settings", {}) or {} + ) source_flags = _normalize_partial_flags(getattr(source, "settings", {}) or {}) merged = dict(project_flags) merged.update(source_flags) @@ -360,7 +384,10 @@ async def _derive_title(message: Message) -> str: {"role": "user", "content": text[:2000]}, ] try: - title = str(await ai_runner.run_prompt(prompt, ai_obj, operation="task_derive_title") or "").strip() + title = str( + await ai_runner.run_prompt(prompt, ai_obj, operation="task_derive_title") + or "" + ).strip() except Exception: title = "" return (title or text)[:255] @@ -376,9 +403,13 @@ async def _derive_title_with_flags(message: Message, flags: dict) -> str: return (cleaned or title or "Untitled task")[:255] -async def _emit_sync_event(task: DerivedTask, event: DerivedTaskEvent, action: str) -> None: +async def _emit_sync_event( + task: DerivedTask, event: DerivedTaskEvent, action: str +) -> None: cfg = await sync_to_async( - lambda: TaskProviderConfig.objects.filter(user=task.user, enabled=True).order_by("provider").first() + lambda: TaskProviderConfig.objects.filter(user=task.user, enabled=True) + .order_by("provider") + .first() )() provider_name = str(getattr(cfg, "provider", "mock") or "mock") provider_settings = dict(getattr(cfg, "settings", {}) or {}) @@ -416,7 +447,11 @@ async def _emit_sync_event(task: DerivedTask, event: DerivedTaskEvent, action: s "source_channel": str(task.source_channel or ""), "external_chat_id": external_chat_id, "origin_message_id": str(getattr(task, "origin_message_id", "") or ""), - "trigger_message_id": str(getattr(event, "source_message_id", "") or getattr(task, "origin_message_id", "") or ""), + "trigger_message_id": str( + getattr(event, "source_message_id", "") + or getattr(task, "origin_message_id", "") + or "" + ), "mode": "default", "payload": event.payload, "memory_context": memory_context, @@ -495,7 +530,9 @@ async def _emit_sync_event(task: DerivedTask, event: DerivedTaskEvent, action: s codex_run.status = status codex_run.result_payload = dict(result.payload or {}) codex_run.error = str(result.error or "") - await sync_to_async(codex_run.save)(update_fields=["status", "result_payload", "error", "updated_at"]) + await sync_to_async(codex_run.save)( + update_fields=["status", "result_payload", "error", "updated_at"] + ) if result.ok and result.external_key and not task.external_key: task.external_key = str(result.external_key) await sync_to_async(task.save)(update_fields=["external_key"]) @@ -503,15 +540,28 @@ async def _emit_sync_event(task: DerivedTask, event: DerivedTaskEvent, action: s async def _completion_regex(message: Message) -> re.Pattern: patterns = await sync_to_async(list)( - TaskCompletionPattern.objects.filter(user=message.user, enabled=True).order_by("position", "created_at") + TaskCompletionPattern.objects.filter(user=message.user, enabled=True).order_by( + "position", "created_at" + ) ) - phrases = [str(row.phrase or "").strip() for row in patterns if str(row.phrase or "").strip()] + phrases = [ + str(row.phrase or "").strip() + for row in patterns + if str(row.phrase or "").strip() + ] if not phrases: phrases = ["done", "completed", "fixed"] - return re.compile(r"\\b(?:" + "|".join(re.escape(p) for p in phrases) + r")\\s*#([A-Za-z0-9_-]+)\\b", re.IGNORECASE) + return re.compile( + r"\\b(?:" + + "|".join(re.escape(p) for p in phrases) + + r")\\s*#([A-Za-z0-9_-]+)\\b", + re.IGNORECASE, + ) -async def _send_scope_message(source: ChatTaskSource, message: Message, text: str) -> None: +async def _send_scope_message( + source: ChatTaskSource, message: Message, text: str +) -> None: await send_message_raw( source.service or message.source_service or "web", source.channel_identifier or message.source_chat_id or "", @@ -521,7 +571,9 @@ async def _send_scope_message(source: ChatTaskSource, message: Message, text: st ) -async def _handle_scope_task_commands(message: Message, sources: list[ChatTaskSource], text: str) -> bool: +async def _handle_scope_task_commands( + message: Message, sources: list[ChatTaskSource], text: str +) -> bool: if not sources: return False body = str(text or "").strip() @@ -538,7 +590,9 @@ async def _handle_scope_task_commands(message: Message, sources: list[ChatTaskSo .order_by("-created_at")[:20] ) if not open_rows: - await _send_scope_message(source, message, "[task] no open tasks in this chat.") + await _send_scope_message( + source, message, "[task] no open tasks in this chat." + ) return True lines = ["[task] open tasks:"] for row in open_rows: @@ -573,7 +627,9 @@ async def _handle_scope_task_commands(message: Message, sources: list[ChatTaskSo .first() )() if task is None: - await _send_scope_message(source, message, "[task] nothing to undo in this chat.") + await _send_scope_message( + source, message, "[task] nothing to undo in this chat." + ) return True ref = str(task.reference_code or "") title = str(task.title or "") @@ -596,10 +652,16 @@ async def _handle_scope_task_commands(message: Message, sources: list[ChatTaskSo .first() )() if task is None: - await _send_scope_message(source, message, f"[task] #{reference} not found.") + await _send_scope_message( + source, message, f"[task] #{reference} not found." + ) return True due_str = f"\ndue: {task.due_date}" if task.due_date else "" - assignee_str = f"\nassignee: {task.assignee_identifier}" if task.assignee_identifier else "" + assignee_str = ( + f"\nassignee: {task.assignee_identifier}" + if task.assignee_identifier + else "" + ) detail = ( f"[task] #{task.reference_code}: {task.title}" f"\nstatus: {task.status_snapshot}" @@ -624,7 +686,9 @@ async def _handle_scope_task_commands(message: Message, sources: list[ChatTaskSo .first() )() if task is None: - await _send_scope_message(source, message, f"[task] #{reference} not found.") + await _send_scope_message( + source, message, f"[task] #{reference} not found." + ) return True task.status_snapshot = "completed" await sync_to_async(task.save)(update_fields=["status_snapshot"]) @@ -633,10 +697,16 @@ async def _handle_scope_task_commands(message: Message, sources: list[ChatTaskSo event_type="completion_marked", actor_identifier=str(message.sender_uuid or ""), source_message=message, - payload={"marker": reference, "command": ".task complete", "via": "chat_command"}, + payload={ + "marker": reference, + "command": ".task complete", + "via": "chat_command", + }, ) await _emit_sync_event(task, event, "complete") - await _send_scope_message(source, message, f"[task] completed #{task.reference_code}: {task.title}") + await _send_scope_message( + source, message, f"[task] completed #{task.reference_code}: {task.title}" + ) return True return False @@ -656,7 +726,9 @@ def _strip_epic_token(text: str) -> str: return re.sub(r"\s{2,}", " ", cleaned).strip() -async def _handle_epic_create_command(message: Message, sources: list[ChatTaskSource], text: str) -> bool: +async def _handle_epic_create_command( + message: Message, sources: list[ChatTaskSource], text: str +) -> bool: match = _EPIC_CREATE_RE.match(str(text or "")) if not match or not sources: return False @@ -766,13 +838,21 @@ async def process_inbound_task_intelligence(message: Message) -> None: if not submit_decision.allowed: return - completion_allowed = any(bool(_effective_flags(source).get("completion_enabled")) for source in sources) + completion_allowed = any( + bool(_effective_flags(source).get("completion_enabled")) for source in sources + ) completion_rx = await _completion_regex(message) if completion_allowed else None - marker_match = (completion_rx.search(text) if completion_rx else None) or (_COMPLETION_RE.search(text) if completion_allowed else None) + marker_match = (completion_rx.search(text) if completion_rx else None) or ( + _COMPLETION_RE.search(text) if completion_allowed else None + ) if marker_match: ref_code = str(marker_match.group(marker_match.lastindex or 1) or "").strip() task = await sync_to_async( - lambda: DerivedTask.objects.filter(user=message.user, reference_code=ref_code).order_by("-created_at").first() + lambda: DerivedTask.objects.filter( + user=message.user, reference_code=ref_code + ) + .order_by("-created_at") + .first() )() if not task: # parser warning event attached to a newly derived placeholder in mapped project @@ -848,7 +928,11 @@ async def process_inbound_task_intelligence(message: Message) -> None: status_snapshot="open", due_date=parsed_due_date, assignee_identifier=parsed_assignee, - immutable_payload={"origin_text": text, "task_text": task_text, "flags": flags}, + immutable_payload={ + "origin_text": text, + "task_text": task_text, + "flags": flags, + }, ) event = await sync_to_async(DerivedTaskEvent.objects.create)( task=task, diff --git a/core/tasks/providers/claude_cli.py b/core/tasks/providers/claude_cli.py index 428cb41..2a66250 100644 --- a/core/tasks/providers/claude_cli.py +++ b/core/tasks/providers/claude_cli.py @@ -40,13 +40,21 @@ class ClaudeCLITaskProvider(TaskProvider): return True if "unrecognized subcommand 'create'" in text and "usage: claude" in text: return True - if "unrecognized subcommand 'append_update'" in text and "usage: claude" in text: + if ( + "unrecognized subcommand 'append_update'" in text + and "usage: claude" in text + ): return True - if "unrecognized subcommand 'mark_complete'" in text and "usage: claude" in text: + if ( + "unrecognized subcommand 'mark_complete'" in text + and "usage: claude" in text + ): return True return False - def _builtin_stub_result(self, op: str, payload: dict, stderr: str) -> ProviderResult: + def _builtin_stub_result( + self, op: str, payload: dict, stderr: str + ) -> ProviderResult: mode = str(payload.get("mode") or "default").strip().lower() external_key = ( str(payload.get("external_key") or "").strip() @@ -117,7 +125,10 @@ class ClaudeCLITaskProvider(TaskProvider): cwd=workspace if workspace else None, ) stderr_probe = str(completed.stderr or "").lower() - if completed.returncode != 0 and "unexpected argument '--op'" in stderr_probe: + if ( + completed.returncode != 0 + and "unexpected argument '--op'" in stderr_probe + ): completed = subprocess.run( fallback_cmd, capture_output=True, @@ -133,7 +144,9 @@ class ClaudeCLITaskProvider(TaskProvider): payload={"op": op, "timeout_seconds": command_timeout}, ) except Exception as exc: - return ProviderResult(ok=False, error=f"claude_cli_exec_error:{exc}", payload={"op": op}) + return ProviderResult( + ok=False, error=f"claude_cli_exec_error:{exc}", payload={"op": op} + ) stdout = str(completed.stdout or "").strip() stderr = str(completed.stderr or "").strip() @@ -172,7 +185,12 @@ class ClaudeCLITaskProvider(TaskProvider): out_payload.update(parsed) if (not ok) and self._is_task_sync_contract_mismatch(stderr): return self._builtin_stub_result(op, dict(payload or {}), stderr) - return ProviderResult(ok=ok, external_key=ext, error=("" if ok else stderr[:4000]), payload=out_payload) + return ProviderResult( + ok=ok, + external_key=ext, + error=("" if ok else stderr[:4000]), + payload=out_payload, + ) def healthcheck(self, config: dict) -> ProviderResult: command = self._command(config) @@ -193,7 +211,11 @@ class ClaudeCLITaskProvider(TaskProvider): "stdout": str(completed.stdout or "").strip()[:1000], "stderr": str(completed.stderr or "").strip()[:1000], }, - error=("" if completed.returncode == 0 else str(completed.stderr or "").strip()[:1000]), + error=( + "" + if completed.returncode == 0 + else str(completed.stderr or "").strip()[:1000] + ), ) def create_task(self, config: dict, payload: dict) -> ProviderResult: diff --git a/core/tasks/providers/codex_cli.py b/core/tasks/providers/codex_cli.py index 46957db..9d00493 100644 --- a/core/tasks/providers/codex_cli.py +++ b/core/tasks/providers/codex_cli.py @@ -46,7 +46,9 @@ class CodexCLITaskProvider(TaskProvider): return True return False - def _builtin_stub_result(self, op: str, payload: dict, stderr: str) -> ProviderResult: + def _builtin_stub_result( + self, op: str, payload: dict, stderr: str + ) -> ProviderResult: mode = str(payload.get("mode") or "default").strip().lower() external_key = ( str(payload.get("external_key") or "").strip() @@ -117,7 +119,10 @@ class CodexCLITaskProvider(TaskProvider): cwd=workspace if workspace else None, ) stderr_probe = str(completed.stderr or "").lower() - if completed.returncode != 0 and "unexpected argument '--op'" in stderr_probe: + if ( + completed.returncode != 0 + and "unexpected argument '--op'" in stderr_probe + ): completed = subprocess.run( fallback_cmd, capture_output=True, @@ -133,7 +138,9 @@ class CodexCLITaskProvider(TaskProvider): payload={"op": op, "timeout_seconds": command_timeout}, ) except Exception as exc: - return ProviderResult(ok=False, error=f"codex_cli_exec_error:{exc}", payload={"op": op}) + return ProviderResult( + ok=False, error=f"codex_cli_exec_error:{exc}", payload={"op": op} + ) stdout = str(completed.stdout or "").strip() stderr = str(completed.stderr or "").strip() @@ -172,7 +179,12 @@ class CodexCLITaskProvider(TaskProvider): out_payload.update(parsed) if (not ok) and self._is_task_sync_contract_mismatch(stderr): return self._builtin_stub_result(op, dict(payload or {}), stderr) - return ProviderResult(ok=ok, external_key=ext, error=("" if ok else stderr[:4000]), payload=out_payload) + return ProviderResult( + ok=ok, + external_key=ext, + error=("" if ok else stderr[:4000]), + payload=out_payload, + ) def healthcheck(self, config: dict) -> ProviderResult: command = self._command(config) @@ -193,7 +205,11 @@ class CodexCLITaskProvider(TaskProvider): "stdout": str(completed.stdout or "").strip()[:1000], "stderr": str(completed.stderr or "").strip()[:1000], }, - error=("" if completed.returncode == 0 else str(completed.stderr or "").strip()[:1000]), + error=( + "" + if completed.returncode == 0 + else str(completed.stderr or "").strip()[:1000] + ), ) def create_task(self, config: dict, payload: dict) -> ProviderResult: diff --git a/core/tasks/providers/mock.py b/core/tasks/providers/mock.py index 044dae3..b798f29 100644 --- a/core/tasks/providers/mock.py +++ b/core/tasks/providers/mock.py @@ -12,14 +12,30 @@ class MockTaskProvider(TaskProvider): return ProviderResult(ok=True, payload={"provider": self.name}) def create_task(self, config: dict, payload: dict) -> ProviderResult: - ext = str(payload.get("external_key") or "") or f"mock-{int(time.time() * 1000)}" - return ProviderResult(ok=True, external_key=ext, payload={"action": "create_task"}) + ext = ( + str(payload.get("external_key") or "") or f"mock-{int(time.time() * 1000)}" + ) + return ProviderResult( + ok=True, external_key=ext, payload={"action": "create_task"} + ) def append_update(self, config: dict, payload: dict) -> ProviderResult: - return ProviderResult(ok=True, external_key=str(payload.get("external_key") or ""), payload={"action": "append_update"}) + return ProviderResult( + ok=True, + external_key=str(payload.get("external_key") or ""), + payload={"action": "append_update"}, + ) def mark_complete(self, config: dict, payload: dict) -> ProviderResult: - return ProviderResult(ok=True, external_key=str(payload.get("external_key") or ""), payload={"action": "mark_complete"}) + return ProviderResult( + ok=True, + external_key=str(payload.get("external_key") or ""), + payload={"action": "mark_complete"}, + ) def link_task(self, config: dict, payload: dict) -> ProviderResult: - return ProviderResult(ok=True, external_key=str(payload.get("external_key") or ""), payload={"action": "link_task"}) + return ProviderResult( + ok=True, + external_key=str(payload.get("external_key") or ""), + payload={"action": "link_task"}, + ) diff --git a/core/templates/base.html b/core/templates/base.html index ad07418..2dd8ad2 100644 --- a/core/templates/base.html +++ b/core/templates/base.html @@ -342,7 +342,7 @@ hx-trigger="click" hx-swap="innerHTML"> - Message + Compose - Tasks + Task Inbox AI - Search - - Queue - - - - OSINT - {% endif %} - - Install -