Increase security and reformat

This commit is contained in:
2026-03-07 20:52:13 +00:00
parent 10588a18b9
commit bca4d6898f
144 changed files with 6735 additions and 3960 deletions

View File

@@ -1,4 +1,5 @@
"""Test-only settings overrides — used via DJANGO_SETTINGS_MODULE=app.test_settings.""" """Test-only settings overrides — used via DJANGO_SETTINGS_MODULE=app.test_settings."""
from app.settings import * # noqa: F401, F403 from app.settings import * # noqa: F401, F403
CACHES = { CACHES = {

View File

@@ -13,12 +13,13 @@ Including another URLconf
1. Import the include() function: from django.urls import include, path 1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) 2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
""" """
from django.conf import settings from django.conf import settings
from django.conf.urls.static import static from django.conf.urls.static import static
from django.contrib import admin from django.contrib import admin
from django.contrib.auth.views import LogoutView from django.contrib.auth.views import LogoutView
from django.views.generic import RedirectView
from django.urls import include, path from django.urls import include, path
from django.views.generic import RedirectView
from two_factor.urls import urlpatterns as tf_urls from two_factor.urls import urlpatterns as tf_urls
from two_factor.views.profile import ProfileView from two_factor.views.profile import ProfileView
@@ -41,8 +42,8 @@ from core.views import (
queues, queues,
sessions, sessions,
signal, signal,
tasks,
system, system,
tasks,
whatsapp, whatsapp,
workspace, workspace,
) )
@@ -188,6 +189,11 @@ urlpatterns = [
automation.TranslationSettings.as_view(), automation.TranslationSettings.as_view(),
name="translation_settings", name="translation_settings",
), ),
path(
"settings/business-plans/",
automation.BusinessPlanInbox.as_view(),
name="business_plan_inbox",
),
path( path(
"settings/business-plan/<str:doc_id>/", "settings/business-plan/<str:doc_id>/",
automation.BusinessPlanEditor.as_view(), automation.BusinessPlanEditor.as_view(),

View File

@@ -38,14 +38,31 @@ def _is_question(text: str) -> bool:
if not body: if not body:
return False return False
low = body.lower() low = body.lower()
return body.endswith("?") or low.startswith(("what", "why", "how", "when", "where", "who", "can ", "do ", "did ", "is ", "are ")) return body.endswith("?") or low.startswith(
(
"what",
"why",
"how",
"when",
"where",
"who",
"can ",
"do ",
"did ",
"is ",
"are ",
)
)
def _is_group_channel(message: Message) -> bool: def _is_group_channel(message: Message) -> bool:
channel = str(getattr(message, "source_chat_id", "") or "").strip().lower() channel = str(getattr(message, "source_chat_id", "") or "").strip().lower()
if channel.endswith("@g.us"): if channel.endswith("@g.us"):
return True return True
return str(getattr(message, "source_service", "") or "").strip().lower() == "xmpp" and "conference." in channel return (
str(getattr(message, "source_service", "") or "").strip().lower() == "xmpp"
and "conference." in channel
)
async def learn_from_message(message: Message) -> None: async def learn_from_message(message: Message) -> None:

View File

@@ -12,8 +12,7 @@ class ClientBase(ABC):
self.log.info(f"{self.service.capitalize()} client initialising...") self.log.info(f"{self.service.capitalize()} client initialising...")
@abstractmethod @abstractmethod
def start(self): def start(self): ...
...
# @abstractmethod # @abstractmethod
# async def send_message(self, recipient, message): # async def send_message(self, recipient, message):

View File

@@ -12,7 +12,15 @@ from django.urls import reverse
from signalbot import Command, Context, SignalBot from signalbot import Command, Context, SignalBot
from core.clients import ClientBase, signalapi, transport from core.clients import ClientBase, signalapi, transport
from core.messaging import ai, history, media_bridge, natural, replies, reply_sync, utils from core.messaging import (
ai,
history,
media_bridge,
natural,
replies,
reply_sync,
utils,
)
from core.models import ( from core.models import (
Chat, Chat,
Manipulation, Manipulation,
@@ -402,7 +410,9 @@ class NewSignalBot(SignalBot):
seen_user_ids.add(pi.user_id) seen_user_ids.add(pi.user_id)
users.append(pi.user) users.append(pi.user)
if not users: if not users:
self.log.debug("[Signal] _upsert_groups: no PersonIdentifiers found — skipping") self.log.debug(
"[Signal] _upsert_groups: no PersonIdentifiers found — skipping"
)
return return
for user in users: for user in users:
@@ -423,7 +433,9 @@ class NewSignalBot(SignalBot):
}, },
) )
self.log.info("[Signal] upserted %d groups for %d users", len(groups), len(users)) self.log.info(
"[Signal] upserted %d groups for %d users", len(groups), len(users)
)
async def _detect_groups(self): async def _detect_groups(self):
await super()._detect_groups() await super()._detect_groups()
@@ -505,7 +517,9 @@ class HandleMessage(Command):
source_uuid_norm and dest_norm and source_uuid_norm == dest_norm source_uuid_norm and dest_norm and source_uuid_norm == dest_norm
) )
is_from_bot = bool(bot_uuid and source_uuid_norm and source_uuid_norm == bot_uuid) is_from_bot = bool(
bot_uuid and source_uuid_norm and source_uuid_norm == bot_uuid
)
if (not is_from_bot) and bot_phone_digits and source_phone_digits: if (not is_from_bot) and bot_phone_digits and source_phone_digits:
is_from_bot = source_phone_digits == bot_phone_digits is_from_bot = source_phone_digits == bot_phone_digits
# Inbound deliveries usually do not have destination fields populated. # Inbound deliveries usually do not have destination fields populated.
@@ -596,9 +610,9 @@ class HandleMessage(Command):
candidate_digits = {value for value in candidate_digits if value} candidate_digits = {value for value in candidate_digits if value}
if candidate_digits: if candidate_digits:
signal_rows = await sync_to_async(list)( signal_rows = await sync_to_async(list)(
PersonIdentifier.objects.filter(service=self.service).select_related( PersonIdentifier.objects.filter(
"user" service=self.service
) ).select_related("user")
) )
matched = [] matched = []
for row in signal_rows: for row in signal_rows:
@@ -718,13 +732,13 @@ class HandleMessage(Command):
target_ts=int(reaction_payload.get("target_ts") or 0), target_ts=int(reaction_payload.get("target_ts") or 0),
emoji=str(reaction_payload.get("emoji") or ""), emoji=str(reaction_payload.get("emoji") or ""),
source_service="signal", source_service="signal",
actor=( actor=(effective_source_uuid or effective_source_number or ""),
effective_source_uuid or effective_source_number or ""
),
target_author=str( target_author=str(
(reaction_payload.get("raw") or {}).get("targetAuthorUuid") (reaction_payload.get("raw") or {}).get("targetAuthorUuid")
or (reaction_payload.get("raw") or {}).get("targetAuthor") or (reaction_payload.get("raw") or {}).get("targetAuthor")
or (reaction_payload.get("raw") or {}).get("targetAuthorNumber") or (reaction_payload.get("raw") or {}).get(
"targetAuthorNumber"
)
or "" or ""
), ),
remove=bool(reaction_payload.get("remove")), remove=bool(reaction_payload.get("remove")),
@@ -741,9 +755,7 @@ class HandleMessage(Command):
remove=bool(reaction_payload.get("remove")), remove=bool(reaction_payload.get("remove")),
upstream_message_id="", upstream_message_id="",
upstream_ts=int(reaction_payload.get("target_ts") or 0), upstream_ts=int(reaction_payload.get("target_ts") or 0),
actor=( actor=(effective_source_uuid or effective_source_number or ""),
effective_source_uuid or effective_source_number or ""
),
payload=reaction_payload.get("raw") or {}, payload=reaction_payload.get("raw") or {},
) )
except Exception as exc: except Exception as exc:
@@ -840,9 +852,7 @@ class HandleMessage(Command):
source_ref={ source_ref={
"upstream_message_id": "", "upstream_message_id": "",
"upstream_author": str( "upstream_author": str(
effective_source_uuid effective_source_uuid or effective_source_number or ""
or effective_source_number
or ""
), ),
"upstream_ts": int(ts or 0), "upstream_ts": int(ts or 0),
}, },
@@ -1134,7 +1144,9 @@ class SignalClient(ClientBase):
if int(message_row.delivered_ts or 0) <= 0: if int(message_row.delivered_ts or 0) <= 0:
message_row.delivered_ts = int(result) message_row.delivered_ts = int(result)
update_fields.append("delivered_ts") update_fields.append("delivered_ts")
if str(message_row.source_message_id or "").strip() != str(result): if str(message_row.source_message_id or "").strip() != str(
result
):
message_row.source_message_id = str(result) message_row.source_message_id = str(result)
update_fields.append("source_message_id") update_fields.append("source_message_id")
if update_fields: if update_fields:
@@ -1146,9 +1158,11 @@ class SignalClient(ClientBase):
command_id, command_id,
{ {
"ok": True, "ok": True,
"timestamp": int(result) "timestamp": (
int(result)
if isinstance(result, int) if isinstance(result, int)
else int(time.time() * 1000), else int(time.time() * 1000)
),
}, },
) )
except Exception as exc: except Exception as exc:
@@ -1248,7 +1262,9 @@ class SignalClient(ClientBase):
if _digits_only(getattr(row, "identifier", "")) in candidate_digits if _digits_only(getattr(row, "identifier", "")) in candidate_digits
] ]
async def _auto_link_single_user_signal_identifier(self, source_uuid: str, source_number: str): async def _auto_link_single_user_signal_identifier(
self, source_uuid: str, source_number: str
):
owner_rows = await sync_to_async(list)( owner_rows = await sync_to_async(list)(
PersonIdentifier.objects.filter(service=self.service) PersonIdentifier.objects.filter(service=self.service)
.select_related("user") .select_related("user")
@@ -1292,7 +1308,9 @@ class SignalClient(ClientBase):
payload = json.loads(raw_message or "{}") payload = json.loads(raw_message or "{}")
except Exception: except Exception:
return return
exception_payload = payload.get("exception") if isinstance(payload, dict) else None exception_payload = (
payload.get("exception") if isinstance(payload, dict) else None
)
if isinstance(exception_payload, dict): if isinstance(exception_payload, dict):
err_type = str(exception_payload.get("type") or "").strip() err_type = str(exception_payload.get("type") or "").strip()
err_msg = str(exception_payload.get("message") or "").strip() err_msg = str(exception_payload.get("message") or "").strip()
@@ -1322,7 +1340,9 @@ class SignalClient(ClientBase):
(envelope.get("timestamp") if isinstance(envelope, dict) else 0) (envelope.get("timestamp") if isinstance(envelope, dict) else 0)
or int(time.time() * 1000) or int(time.time() * 1000)
), ),
last_inbound_exception_account=str(payload.get("account") or "").strip(), last_inbound_exception_account=str(
payload.get("account") or ""
).strip(),
last_inbound_exception_source_uuid=envelope_source_uuid, last_inbound_exception_source_uuid=envelope_source_uuid,
last_inbound_exception_source_number=envelope_source_number, last_inbound_exception_source_number=envelope_source_number,
last_inbound_exception_envelope_ts=envelope_ts, last_inbound_exception_envelope_ts=envelope_ts,
@@ -1346,7 +1366,11 @@ class SignalClient(ClientBase):
raw_text = sync_sent_message.get("message") raw_text = sync_sent_message.get("message")
if isinstance(raw_text, dict): if isinstance(raw_text, dict):
text = _extract_signal_text( text = _extract_signal_text(
{"envelope": {"syncMessage": {"sentMessage": {"message": raw_text}}}}, {
"envelope": {
"syncMessage": {"sentMessage": {"message": raw_text}}
}
},
str( str(
raw_text.get("message") raw_text.get("message")
or raw_text.get("text") or raw_text.get("text")
@@ -1396,9 +1420,15 @@ class SignalClient(ClientBase):
source_service="signal", source_service="signal",
actor=(source_uuid or source_number or ""), actor=(source_uuid or source_number or ""),
target_author=str( target_author=str(
(reaction_payload.get("raw") or {}).get("targetAuthorUuid") (reaction_payload.get("raw") or {}).get(
or (reaction_payload.get("raw") or {}).get("targetAuthor") "targetAuthorUuid"
or (reaction_payload.get("raw") or {}).get("targetAuthorNumber") )
or (reaction_payload.get("raw") or {}).get(
"targetAuthor"
)
or (reaction_payload.get("raw") or {}).get(
"targetAuthorNumber"
)
or "" or ""
), ),
remove=bool(reaction_payload.get("remove")), remove=bool(reaction_payload.get("remove")),
@@ -1505,7 +1535,9 @@ class SignalClient(ClientBase):
source_chat_id = destination_number or destination_uuid or sender_key source_chat_id = destination_number or destination_uuid or sender_key
reply_ref = reply_sync.extract_reply_ref(self.service, payload) reply_ref = reply_sync.extract_reply_ref(self.service, payload)
for identifier in identifiers: for identifier in identifiers:
session = await history.get_chat_session(identifier.user, identifier) session = await history.get_chat_session(
identifier.user, identifier
)
reply_target = await reply_sync.resolve_reply_target( reply_target = await reply_sync.resolve_reply_target(
identifier.user, identifier.user,
session, session,
@@ -1552,13 +1584,19 @@ class SignalClient(ClientBase):
if not isinstance(data_message, dict): if not isinstance(data_message, dict):
return return
source_uuid = str(envelope.get("sourceUuid") or envelope.get("source") or "").strip() source_uuid = str(
envelope.get("sourceUuid") or envelope.get("source") or ""
).strip()
source_number = str(envelope.get("sourceNumber") or "").strip() source_number = str(envelope.get("sourceNumber") or "").strip()
bot_uuid = str(getattr(self.client, "bot_uuid", "") or "").strip() bot_uuid = str(getattr(self.client, "bot_uuid", "") or "").strip()
bot_phone = str(getattr(self.client, "phone_number", "") or "").strip() bot_phone = str(getattr(self.client, "phone_number", "") or "").strip()
if source_uuid and bot_uuid and source_uuid == bot_uuid: if source_uuid and bot_uuid and source_uuid == bot_uuid:
return return
if source_number and bot_phone and _digits_only(source_number) == _digits_only(bot_phone): if (
source_number
and bot_phone
and _digits_only(source_number) == _digits_only(bot_phone)
):
return return
identifiers = await self._resolve_signal_identifiers(source_uuid, source_number) identifiers = await self._resolve_signal_identifiers(source_uuid, source_number)
@@ -1610,14 +1648,18 @@ class SignalClient(ClientBase):
target_author=str( target_author=str(
(reaction_payload.get("raw") or {}).get("targetAuthorUuid") (reaction_payload.get("raw") or {}).get("targetAuthorUuid")
or (reaction_payload.get("raw") or {}).get("targetAuthor") or (reaction_payload.get("raw") or {}).get("targetAuthor")
or (reaction_payload.get("raw") or {}).get("targetAuthorNumber") or (reaction_payload.get("raw") or {}).get(
"targetAuthorNumber"
)
or "" or ""
), ),
remove=bool(reaction_payload.get("remove")), remove=bool(reaction_payload.get("remove")),
payload=reaction_payload.get("raw") or {}, payload=reaction_payload.get("raw") or {},
) )
except Exception as exc: except Exception as exc:
self.log.warning("signal raw reaction history apply failed: %s", exc) self.log.warning(
"signal raw reaction history apply failed: %s", exc
)
try: try:
await self.ur.xmpp.client.apply_external_reaction( await self.ur.xmpp.client.apply_external_reaction(
identifier.user, identifier.user,
@@ -1631,7 +1673,9 @@ class SignalClient(ClientBase):
payload=reaction_payload.get("raw") or {}, payload=reaction_payload.get("raw") or {},
) )
except Exception as exc: except Exception as exc:
self.log.warning("signal raw reaction relay to XMPP failed: %s", exc) self.log.warning(
"signal raw reaction relay to XMPP failed: %s", exc
)
transport.update_runtime_state( transport.update_runtime_state(
self.service, self.service,
last_inbound_ok_ts=int(time.time() * 1000), last_inbound_ok_ts=int(time.time() * 1000),
@@ -1683,7 +1727,9 @@ class SignalClient(ClientBase):
) )
return return
text = _extract_signal_text(payload, str(data_message.get("message") or "").strip()) text = _extract_signal_text(
payload, str(data_message.get("message") or "").strip()
)
if not text: if not text:
return return
@@ -1702,7 +1748,11 @@ class SignalClient(ClientBase):
or envelope.get("timestamp") or envelope.get("timestamp")
or ts or ts
).strip() ).strip()
sender_key = source_uuid or source_number or (identifiers[0].identifier if identifiers else "") sender_key = (
source_uuid
or source_number
or (identifiers[0].identifier if identifiers else "")
)
source_chat_id = source_number or source_uuid or sender_key source_chat_id = source_number or source_uuid or sender_key
reply_ref = reply_sync.extract_reply_ref(self.service, payload) reply_ref = reply_sync.extract_reply_ref(self.service, payload)

View File

@@ -927,7 +927,11 @@ async def send_reaction(
service_key = _service_key(service) service_key = _service_key(service)
if _capability_checks_enabled() and not supports(service_key, "reactions"): if _capability_checks_enabled() and not supports(service_key, "reactions"):
reason = unsupported_reason(service_key, "reactions") reason = unsupported_reason(service_key, "reactions")
log.warning("capability-check failed service=%s feature=reactions: %s", service_key, reason) log.warning(
"capability-check failed service=%s feature=reactions: %s",
service_key,
reason,
)
return False return False
if not str(emoji or "").strip() and not remove: if not str(emoji or "").strip() and not remove:
return False return False

View File

@@ -173,9 +173,7 @@ class WhatsAppClient(ClientBase):
if db_dir: if db_dir:
os.makedirs(db_dir, exist_ok=True) os.makedirs(db_dir, exist_ok=True)
if db_dir and not os.access(db_dir, os.W_OK): if db_dir and not os.access(db_dir, os.W_OK):
raise PermissionError( raise PermissionError(f"session db directory is not writable: {db_dir}")
f"session db directory is not writable: {db_dir}"
)
except Exception as exc: except Exception as exc:
self._publish_state( self._publish_state(
connected=False, connected=False,
@@ -772,9 +770,11 @@ class WhatsAppClient(ClientBase):
command_id, command_id,
{ {
"ok": True, "ok": True,
"timestamp": int(result) "timestamp": (
int(result)
if isinstance(result, int) if isinstance(result, int)
else int(time.time() * 1000), else int(time.time() * 1000)
),
}, },
) )
self.log.debug( self.log.debug(
@@ -1910,9 +1910,7 @@ class WhatsAppClient(ClientBase):
jid_value = self._jid_to_identifier( jid_value = self._jid_to_identifier(
self._pluck(group, "JID") or self._pluck(group, "jid") self._pluck(group, "JID") or self._pluck(group, "jid")
) )
identifier = ( identifier = jid_value.split("@", 1)[0].strip() if jid_value else ""
jid_value.split("@", 1)[0].strip() if jid_value else ""
)
if not identifier: if not identifier:
continue continue
name = ( name = (
@@ -2362,12 +2360,22 @@ class WhatsAppClient(ClientBase):
node = ( node = (
self._pluck(message_obj, "reactionMessage") self._pluck(message_obj, "reactionMessage")
or self._pluck(message_obj, "reaction_message") or self._pluck(message_obj, "reaction_message")
or self._pluck(message_obj, "ephemeralMessage", "message", "reactionMessage") or self._pluck(
or self._pluck(message_obj, "ephemeral_message", "message", "reaction_message") message_obj, "ephemeralMessage", "message", "reactionMessage"
)
or self._pluck(
message_obj, "ephemeral_message", "message", "reaction_message"
)
or self._pluck(message_obj, "viewOnceMessage", "message", "reactionMessage") or self._pluck(message_obj, "viewOnceMessage", "message", "reactionMessage")
or self._pluck(message_obj, "view_once_message", "message", "reaction_message") or self._pluck(
or self._pluck(message_obj, "viewOnceMessageV2", "message", "reactionMessage") message_obj, "view_once_message", "message", "reaction_message"
or self._pluck(message_obj, "view_once_message_v2", "message", "reaction_message") )
or self._pluck(
message_obj, "viewOnceMessageV2", "message", "reactionMessage"
)
or self._pluck(
message_obj, "view_once_message_v2", "message", "reaction_message"
)
or self._pluck( or self._pluck(
message_obj, message_obj,
"viewOnceMessageV2Extension", "viewOnceMessageV2Extension",
@@ -2410,7 +2418,9 @@ class WhatsAppClient(ClientBase):
explicit_remove = self._pluck(node, "remove") or self._pluck(node, "isRemove") explicit_remove = self._pluck(node, "remove") or self._pluck(node, "isRemove")
if explicit_remove is None: if explicit_remove is None:
explicit_remove = self._pluck(node, "is_remove") explicit_remove = self._pluck(node, "is_remove")
remove = bool(explicit_remove) if explicit_remove is not None else bool(not emoji) remove = (
bool(explicit_remove) if explicit_remove is not None else bool(not emoji)
)
if not target_msg_id: if not target_msg_id:
return None return None
return { return {
@@ -2418,7 +2428,11 @@ class WhatsAppClient(ClientBase):
"target_message_id": target_msg_id, "target_message_id": target_msg_id,
"remove": remove, "remove": remove,
"target_ts": int(target_ts or 0), "target_ts": int(target_ts or 0),
"raw": self._proto_to_dict(node) or dict(node or {}) if isinstance(node, dict) else {}, "raw": (
self._proto_to_dict(node) or dict(node or {})
if isinstance(node, dict)
else {}
),
} }
async def _download_event_media(self, event): async def _download_event_media(self, event):
@@ -2760,7 +2774,9 @@ class WhatsAppClient(ClientBase):
or self._pluck(msg_obj, "MessageContextInfo") or self._pluck(msg_obj, "MessageContextInfo")
or {}, or {},
"message": { "message": {
"extendedTextMessage": self._pluck(msg_obj, "extendedTextMessage") "extendedTextMessage": self._pluck(
msg_obj, "extendedTextMessage"
)
or self._pluck(msg_obj, "ExtendedTextMessage") or self._pluck(msg_obj, "ExtendedTextMessage")
or {}, or {},
"imageMessage": self._pluck(msg_obj, "imageMessage") or {}, "imageMessage": self._pluck(msg_obj, "imageMessage") or {},
@@ -2814,7 +2830,9 @@ class WhatsAppClient(ClientBase):
or {}, or {},
"viewOnceMessage": self._pluck(msg_obj, "viewOnceMessage") "viewOnceMessage": self._pluck(msg_obj, "viewOnceMessage")
or {}, or {},
"viewOnceMessageV2": self._pluck(msg_obj, "viewOnceMessageV2") "viewOnceMessageV2": self._pluck(
msg_obj, "viewOnceMessageV2"
)
or {}, or {},
"viewOnceMessageV2Extension": self._pluck( "viewOnceMessageV2Extension": self._pluck(
msg_obj, "viewOnceMessageV2Extension" msg_obj, "viewOnceMessageV2Extension"
@@ -2840,12 +2858,12 @@ class WhatsAppClient(ClientBase):
reply_sync.extract_origin_tag(payload), reply_sync.extract_origin_tag(payload),
) )
if self._chat_matches_reply_debug(chat): if self._chat_matches_reply_debug(chat):
info_obj = self._proto_to_dict(self._pluck(event_obj, "Info")) or self._pluck( info_obj = self._proto_to_dict(
event_obj, "Info" self._pluck(event_obj, "Info")
) ) or self._pluck(event_obj, "Info")
raw_obj = self._proto_to_dict(self._pluck(event_obj, "Raw")) or self._pluck( raw_obj = self._proto_to_dict(
event_obj, "Raw" self._pluck(event_obj, "Raw")
) ) or self._pluck(event_obj, "Raw")
message_meta["wa_reply_debug"] = { message_meta["wa_reply_debug"] = {
"reply_ref": reply_ref, "reply_ref": reply_ref,
"reply_target_id": str(getattr(reply_target, "id", "") or ""), "reply_target_id": str(getattr(reply_target, "id", "") or ""),
@@ -3087,9 +3105,11 @@ class WhatsAppClient(ClientBase):
) )
matched = False matched = False
for candidate in candidates: for candidate in candidates:
candidate_local = str(self._jid_to_identifier(candidate) or "").split( candidate_local = (
"@", 1 str(self._jid_to_identifier(candidate) or "")
)[0].strip() .split("@", 1)[0]
.strip()
)
if candidate_local and candidate_local == local: if candidate_local and candidate_local == local:
matched = True matched = True
break break
@@ -3124,7 +3144,12 @@ class WhatsAppClient(ClientBase):
# WhatsApp group ids are numeric and usually very long (commonly start # WhatsApp group ids are numeric and usually very long (commonly start
# with 120...). Treat those as groups when no explicit mapping exists. # with 120...). Treat those as groups when no explicit mapping exists.
digits = re.sub(r"[^0-9]", "", local) digits = re.sub(r"[^0-9]", "", local)
if digits and digits == local and len(digits) >= 15 and digits.startswith("120"): if (
digits
and digits == local
and len(digits) >= 15
and digits.startswith("120")
):
return f"{digits}@g.us" return f"{digits}@g.us"
return "" return ""
@@ -3264,7 +3289,9 @@ class WhatsAppClient(ClientBase):
person_identifier = await sync_to_async( person_identifier = await sync_to_async(
lambda: ( lambda: (
Message.objects.filter(id=legacy_message_id) Message.objects.filter(id=legacy_message_id)
.select_related("session__identifier__user", "session__identifier__person") .select_related(
"session__identifier__user", "session__identifier__person"
)
.first() .first()
) )
)() )()
@@ -3274,7 +3301,9 @@ class WhatsAppClient(ClientBase):
) )
if ( if (
person_identifier is not None person_identifier is not None
and str(getattr(person_identifier, "service", "") or "").strip().lower() and str(getattr(person_identifier, "service", "") or "")
.strip()
.lower()
!= "whatsapp" != "whatsapp"
): ):
person_identifier = None person_identifier = None
@@ -3418,6 +3447,8 @@ class WhatsAppClient(ClientBase):
from neonize.proto.waE2E.WAWebProtobufsE2E_pb2 import ( from neonize.proto.waE2E.WAWebProtobufsE2E_pb2 import (
ContextInfo, ContextInfo,
ExtendedTextMessage, ExtendedTextMessage,
)
from neonize.proto.waE2E.WAWebProtobufsE2E_pb2 import (
Message as WAProtoMessage, Message as WAProtoMessage,
) )

View File

@@ -1,7 +1,7 @@
import asyncio import asyncio
import base64 import base64
import json import json
import mimetypes import logging
import os import os
import re import re
import time import time
@@ -17,6 +17,8 @@ from slixmpp.componentxmpp import ComponentXMPP
from slixmpp.plugins.xep_0085.stanza import Active, Composing, Gone, Inactive, Paused from slixmpp.plugins.xep_0085.stanza import Active, Composing, Gone, Inactive, Paused
from slixmpp.stanza import Message from slixmpp.stanza import Message
from slixmpp.xmlstream import register_stanza_plugin from slixmpp.xmlstream import register_stanza_plugin
from slixmpp.xmlstream.handler import Callback
from slixmpp.xmlstream.matcher import StanzaPath
from slixmpp.xmlstream.stanzabase import ET from slixmpp.xmlstream.stanzabase import ET
from core.clients import ClientBase, transport from core.clients import ClientBase, transport
@@ -55,6 +57,9 @@ EMOJI_ONLY_PATTERN = re.compile(
r"^[\U0001F300-\U0001FAFF\u2600-\u27BF\uFE0F\u200D\u2640-\u2642\u2764]+$" r"^[\U0001F300-\U0001FAFF\u2600-\u27BF\uFE0F\u200D\u2640-\u2642\u2764]+$"
) )
TOTP_BASE32_SECRET_RE = re.compile(r"^[A-Z2-7]{16,}$") TOTP_BASE32_SECRET_RE = re.compile(r"^[A-Z2-7]{16,}$")
PUBSUB_NS = "http://jabber.org/protocol/pubsub"
OMEMO_OLD_NS = "eu.siacs.conversations.axolotl"
OMEMO_OLD_DEVICELIST_NODE = "eu.siacs.conversations.axolotl.devicelist"
def _clean_url(value): def _clean_url(value):
@@ -147,6 +152,7 @@ def _parse_greentext_reaction(body_text):
def _omemo_plugin_available() -> bool: def _omemo_plugin_available() -> bool:
try: try:
import importlib import importlib
return importlib.util.find_spec("slixmpp_omemo") is not None return importlib.util.find_spec("slixmpp_omemo") is not None
except Exception: except Exception:
return False return False
@@ -166,15 +172,30 @@ def _extract_sender_omemo_client_key(stanza) -> dict:
return {"status": "no_omemo"} return {"status": "no_omemo"}
def _format_omemo_identity_fingerprint(identity_key) -> str:
if isinstance(identity_key, (bytes, bytearray)):
key_bytes = bytes(identity_key)
else:
return ""
try:
from omemo.session_manager import SessionManager as _OmemoSessionManager
return " ".join(_OmemoSessionManager.format_identity_key(key_bytes)).upper()
except Exception:
return ":".join(f"{b:02X}" for b in key_bytes)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# OMEMO storage + plugin implementation # OMEMO storage + plugin implementation
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
try: try:
from omemo.storage import Just, Maybe, Nothing, Storage as _OmemoStorageBase from omemo.storage import Just, Maybe, Nothing
from omemo.storage import Storage as _OmemoStorageBase
from slixmpp.plugins.base import register_plugin as _slixmpp_register_plugin
from slixmpp_omemo import XEP_0384 as _XEP_0384Base from slixmpp_omemo import XEP_0384 as _XEP_0384Base
from slixmpp_omemo.base_session_manager import TrustLevel as _OmemoTrustLevel from slixmpp_omemo.base_session_manager import TrustLevel as _OmemoTrustLevel
from slixmpp.plugins.base import register_plugin as _slixmpp_register_plugin
_OMEMO_AVAILABLE = True _OMEMO_AVAILABLE = True
except ImportError: except ImportError:
_OMEMO_AVAILABLE = False _OMEMO_AVAILABLE = False
@@ -185,6 +206,7 @@ except ImportError:
if _OMEMO_AVAILABLE: if _OMEMO_AVAILABLE:
class _OmemoStorage(_OmemoStorageBase): class _OmemoStorage(_OmemoStorageBase):
"""JSON-file-backed OMEMO key storage.""" """JSON-file-backed OMEMO key storage."""
@@ -224,7 +246,14 @@ if _OMEMO_AVAILABLE:
name = "xep_0384" name = "xep_0384"
description = "OMEMO Encryption (GIA gateway)" description = "OMEMO Encryption (GIA gateway)"
dependencies = {"xep_0004", "xep_0030", "xep_0060", "xep_0163", "xep_0280", "xep_0334"} dependencies = {
"xep_0004",
"xep_0030",
"xep_0060",
"xep_0163",
"xep_0280",
"xep_0334",
}
default_config = { default_config = {
"fallback_message": "This message is OMEMO encrypted.", "fallback_message": "This message is OMEMO encrypted.",
"data_dir": "", "data_dir": "",
@@ -248,6 +277,7 @@ if _OMEMO_AVAILABLE:
async def _devices_blindly_trusted(self, blindly_trusted, identifier): async def _devices_blindly_trusted(self, blindly_trusted, identifier):
import logging import logging
logging.getLogger(__name__).info( logging.getLogger(__name__).info(
"OMEMO: blindly trusted %d new device(s)", len(blindly_trusted) "OMEMO: blindly trusted %d new device(s)", len(blindly_trusted)
) )
@@ -255,6 +285,7 @@ if _OMEMO_AVAILABLE:
async def _prompt_manual_trust(self, manually_trusted, identifier): async def _prompt_manual_trust(self, manually_trusted, identifier):
"""Auto-trust all undecided devices (gateway mode).""" """Auto-trust all undecided devices (gateway mode)."""
import logging import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.info( log.info(
"OMEMO: auto-trusting %d undecided device(s) (gateway mode)", "OMEMO: auto-trusting %d undecided device(s) (gateway mode)",
@@ -270,11 +301,12 @@ if _OMEMO_AVAILABLE:
_OmemoTrustLevel.BLINDLY_TRUSTED.value, _OmemoTrustLevel.BLINDLY_TRUSTED.value,
) )
except Exception as exc: except Exception as exc:
log.warning("OMEMO set_trust failed for %s: %s", device.bare_jid, exc) log.warning(
"OMEMO set_trust failed for %s: %s", device.bare_jid, exc
)
class XMPPComponent(ComponentXMPP): class XMPPComponent(ComponentXMPP):
""" """
A simple Slixmpp component that echoes messages. A simple Slixmpp component that echoes messages.
""" """
@@ -289,6 +321,7 @@ class XMPPComponent(ComponentXMPP):
self._session_live = False self._session_live = False
self.log = logs.get_logger("XMPP") self.log = logs.get_logger("XMPP")
logging.getLogger("slixmpp_omemo").setLevel(logging.DEBUG)
super().__init__(jid, secret, server, port) super().__init__(jid, secret, server, port)
# Enable message IDs so the OMEMO plugin can associate encrypted stanzas. # Enable message IDs so the OMEMO plugin can associate encrypted stanzas.
@@ -305,6 +338,20 @@ class XMPPComponent(ComponentXMPP):
self.add_event_handler("session_start", self.session_start) self.add_event_handler("session_start", self.session_start)
self.add_event_handler("disconnected", self.on_disconnected) self.add_event_handler("disconnected", self.on_disconnected)
self.add_event_handler("message", self.message) self.add_event_handler("message", self.message)
self.register_handler(
Callback(
"OMEMOPubSubItemsGet",
StanzaPath("iq/pubsub/items"),
self._handle_pubsub_iq_get,
)
)
self.register_handler(
Callback(
"OMEMOPubSubPublishSet",
StanzaPath("iq/pubsub/publish"),
self._handle_pubsub_iq_set,
)
)
# Presence event handlers # Presence event handlers
self.add_event_handler("presence_available", self.on_presence_available) self.add_event_handler("presence_available", self.on_presence_available)
@@ -327,6 +374,137 @@ class XMPPComponent(ComponentXMPP):
self.add_event_handler("chatstate_paused", self.on_chatstate_paused) self.add_event_handler("chatstate_paused", self.on_chatstate_paused)
self.add_event_handler("chatstate_inactive", self.on_chatstate_inactive) self.add_event_handler("chatstate_inactive", self.on_chatstate_inactive)
self.add_event_handler("chatstate_gone", self.on_chatstate_gone) self.add_event_handler("chatstate_gone", self.on_chatstate_gone)
self._omemo_component_pubsub = self._seed_component_omemo_pubsub()
@staticmethod
def _clone_xml(node):
return ET.fromstring(ET.tostring(node, encoding="unicode"))
def _seed_component_omemo_pubsub(self):
return {}
def _sync_component_omemo_pubsub(self, device_id, identity_key):
device_id = str(device_id or "").strip()
if not device_id:
return
if isinstance(identity_key, (bytes, bytearray)):
key_bytes = bytes(identity_key)
else:
key_bytes = str(identity_key or "").encode()
identity_b64 = base64.b64encode(key_bytes).decode()
devicelist = ET.Element(f"{{{OMEMO_OLD_NS}}}list")
dev = ET.SubElement(devicelist, f"{{{OMEMO_OLD_NS}}}device")
dev.set("id", device_id)
bundle = ET.Element(f"{{{OMEMO_OLD_NS}}}bundle")
ik = ET.SubElement(bundle, f"{{{OMEMO_OLD_NS}}}identityKey")
ik.text = identity_b64
self._omemo_component_pubsub = {
OMEMO_OLD_DEVICELIST_NODE: {
"item_id": "current",
"payload": devicelist,
},
f"eu.siacs.conversations.axolotl.bundles:{device_id}": {
"item_id": "current",
"payload": bundle,
},
}
async def _reply_pubsub_items(self, iq, node_name):
reply = iq.reply()
pubsub = ET.Element(f"{{{PUBSUB_NS}}}pubsub")
items = ET.SubElement(pubsub, f"{{{PUBSUB_NS}}}items")
items.set("node", node_name)
stored = self._omemo_component_pubsub.get(node_name)
if stored is None and node_name.startswith(
"eu.siacs.conversations.axolotl.bundles:"
):
bundle_nodes = [
key
for key in self._omemo_component_pubsub
if key.startswith("eu.siacs.conversations.axolotl.bundles:")
]
if len(bundle_nodes) == 1:
stored = self._omemo_component_pubsub.get(bundle_nodes[0])
if stored:
item = ET.SubElement(items, f"{{{PUBSUB_NS}}}item")
item.set("id", str(stored.get("item_id") or "current"))
payload = stored.get("payload")
if payload is not None:
item.append(self._clone_xml(payload))
reply.append(pubsub)
await reply.send()
async def on_iq_get(self, iq):
try:
iq_to = str(iq["to"] or "").split("/", 1)[0].strip().lower()
except Exception:
iq_to = ""
own_jid = str(getattr(self.boundjid, "bare", "") or "").strip().lower()
if iq_to and own_jid and iq_to != own_jid:
return
items = iq.xml.find(f".//{{{PUBSUB_NS}}}items")
if items is None:
return
node_name = str(items.attrib.get("node") or "").strip()
if not node_name:
return
self.log.debug(
"OMEMO IQ-GET pubsub items node=%s from=%s to=%s",
node_name,
iq.get("from"),
iq.get("to"),
)
if (
node_name.startswith("eu.siacs.conversations.axolotl.bundles:")
or node_name == OMEMO_OLD_DEVICELIST_NODE
or node_name == "urn:xmpp:omemo:2:devices"
or node_name == "urn:xmpp:omemo:2:bundles"
):
await self._reply_pubsub_items(iq, node_name)
async def on_iq_set(self, iq):
try:
iq_to = str(iq["to"] or "").split("/", 1)[0].strip().lower()
except Exception:
iq_to = ""
own_jid = str(getattr(self.boundjid, "bare", "") or "").strip().lower()
if iq_to and own_jid and iq_to != own_jid:
return
publish = iq.xml.find(f".//{{{PUBSUB_NS}}}publish")
if publish is None:
return
node_name = str(publish.attrib.get("node") or "").strip()
if not node_name:
return
self.log.debug(
"OMEMO IQ-SET pubsub publish node=%s from=%s to=%s",
node_name,
iq.get("from"),
iq.get("to"),
)
item = publish.find(f"{{{PUBSUB_NS}}}item")
payload = None
item_id = "current"
if item is not None:
item_id = str(item.attrib.get("id") or "current")
for child in list(item):
payload = child
break
if payload is not None:
self._omemo_component_pubsub[node_name] = {
"item_id": item_id,
"payload": self._clone_xml(payload),
}
await iq.reply().send()
def _handle_pubsub_iq_get(self, iq):
asyncio.create_task(self.on_iq_get(iq))
def _handle_pubsub_iq_set(self, iq):
asyncio.create_task(self.on_iq_set(iq))
def _user_xmpp_domain(self): def _user_xmpp_domain(self):
domain = str(getattr(settings, "XMPP_USER_DOMAIN", "") or "").strip() domain = str(getattr(settings, "XMPP_USER_DOMAIN", "") or "").strip()
@@ -354,8 +532,6 @@ class XMPPComponent(ComponentXMPP):
self.log.error(f"Failed to enable Carbons: {e}") self.log.error(f"Failed to enable Carbons: {e}")
def get_identifier(self, msg): def get_identifier(self, msg):
xmpp_message_id = str(msg.get("id") or "").strip()
# Extract sender JID (full format: user@domain/resource) # Extract sender JID (full format: user@domain/resource)
sender_jid = str(msg["from"]) sender_jid = str(msg["from"])
@@ -445,6 +621,7 @@ class XMPPComponent(ComponentXMPP):
def _derived_omemo_fingerprint(self, jid: str) -> str: def _derived_omemo_fingerprint(self, jid: str) -> str:
import hashlib import hashlib
return hashlib.sha256(f"xmpp-omemo-key:{jid}".encode()).hexdigest()[:32] return hashlib.sha256(f"xmpp-omemo-key:{jid}".encode()).hexdigest()[:32]
def _get_omemo_plugin(self): def _get_omemo_plugin(self):
@@ -456,26 +633,198 @@ class XMPPComponent(ComponentXMPP):
async def _bootstrap_omemo_for_authentic_channel(self): async def _bootstrap_omemo_for_authentic_channel(self):
jid = str(getattr(settings, "XMPP_JID", "") or "").strip() jid = str(getattr(settings, "XMPP_JID", "") or "").strip()
omemo_plugin = self._get_omemo_plugin() omemo_plugin = self._get_omemo_plugin()
omemo_enabled = omemo_plugin is not None omemo_enabled = omemo_plugin is not None
status = "active" if omemo_enabled else "not_available" status = "active" if omemo_enabled else "not_available"
reason = "OMEMO plugin active" if omemo_enabled else "xep_0384 plugin not loaded" reason = (
"OMEMO plugin active" if omemo_enabled else "xep_0384 plugin not loaded"
)
fingerprint = self._derived_omemo_fingerprint(jid) fingerprint = self._derived_omemo_fingerprint(jid)
if omemo_enabled: if omemo_enabled:
try: try:
import asyncio as _asyncio import asyncio as _asyncio
# OMEMO session-manager bootstrap can take longer in component mode because
# it performs consistency checks and initial pubsub interactions before
# device publication is complete.
session_manager = await _asyncio.wait_for( session_manager = await _asyncio.wait_for(
omemo_plugin.get_session_manager(), timeout=15.0 omemo_plugin.get_session_manager(), timeout=180.0
) )
own_devices = await session_manager.get_own_device_information() own_devices = await session_manager.get_own_device_information()
device_id = None
if own_devices: if own_devices:
key_bytes = own_devices[0].identity_key # own_devices is a tuple: (own_device, other_devices)
own_device = (
own_devices[0]
if isinstance(own_devices, (tuple, list))
else own_devices
)
key_bytes = own_device.identity_key
try:
from omemo.session_manager import (
SessionManager as _OmemoSessionManager,
)
fingerprint = " ".join(
_OmemoSessionManager.format_identity_key(key_bytes)
).upper()
except Exception:
fingerprint = ":".join(f"{b:02X}" for b in key_bytes) fingerprint = ":".join(f"{b:02X}" for b in key_bytes)
device_id = own_device.device_id
self._sync_component_omemo_pubsub(device_id, key_bytes)
self.log.info(
"OMEMO: own device created, device_id=%s fingerprint=%s",
device_id,
fingerprint,
)
else:
# Fallback: session manager may not auto-create devices for component JIDs
# Manually generate a device ID and use it
import random
device_id = random.randint(1, 2**31 - 1)
self.log.warning(
"OMEMO: session manager did not create device (component JID limitation), "
"using fallback device_id=%s",
device_id,
)
# CRITICAL FIX: Publish BOTH device list AND device bundle
# Clients need both:
# 1. Device list node: eu.siacs.conversations.axolotl/devices/{jid} (lists device IDs)
# 2. Device bundle nodes: eu.siacs.conversations.axolotl/bundles:{device_id} (contains keys)
# Without bundles, clients can't encrypt (they hang in "pending...")
if device_id:
try:
namespace = "eu.siacs.conversations.axolotl"
pubsub_service = (
str(getattr(settings, "XMPP_PUBSUB_SERVICE", "pubsub"))
or "pubsub"
)
# Step 1: Publish device list
try:
device_list_dict = {device_id: None}
await session_manager._upload_device_list(
namespace, device_list_dict
)
self.log.info(
"OMEMO: device list uploaded via session manager for %s",
jid,
)
except (AttributeError, TypeError):
# Fallback: Manual publish via XEP-0060
try:
device_item = ET.Element("list")
device_item.set("xmlns", namespace)
for dev_id in device_list_dict:
device_elem = ET.SubElement(device_item, "device")
device_elem.set("id", str(dev_id))
node_name = f"{namespace}/devices/{jid}"
await self["xep_0060"].publish(
pubsub_service, node_name, payload=device_item
)
self.log.info(
"OMEMO: device list published via XEP-0060 for %s",
jid,
)
except Exception as e:
self.log.warning(
"OMEMO: device list publish failed: %s", e
)
# Step 2: Publish device bundle (CRITICAL - contains the keys!)
# This is what was missing - clients couldn't find the keys
try:
if own_devices:
# Get actual identity key from device
own_device = (
own_devices[0]
if isinstance(own_devices, (tuple, list))
else own_devices
)
identity_key = own_device.identity_key
signed_prekey = getattr(
own_device, "signed_prekey", None
)
# Build bundle item with actual keys
bundle_item = ET.Element("bundle")
bundle_item.set("xmlns", namespace)
# Add signed prekey
if signed_prekey:
spk_elem = ET.SubElement(
bundle_item, "signedPreKeyPublic"
)
spk_elem.text = signed_prekey
else:
# Fallback: use hash of identity key
import base64
spk_elem = ET.SubElement(
bundle_item, "signedPreKeyPublic"
)
spk_elem.text = base64.b64encode(
identity_key
).decode()
# Add identity key
ik_elem = ET.SubElement(bundle_item, "identityKey")
import base64
ik_elem.text = base64.b64encode(identity_key).decode()
# Publish bundle
bundle_node = f"{namespace}/bundles:{device_id}"
await self["xep_0060"].publish(
pubsub_service, bundle_node, payload=bundle_item
)
self.log.info(
"OMEMO: device bundle published for %s (device_id=%s)",
jid,
device_id,
)
except Exception as bundle_exc:
self.log.warning(
"OMEMO: device bundle publish failed: %s", bundle_exc
)
except Exception as upload_exc:
self.log.warning(
"OMEMO: device/bundle upload error: %s", upload_exc
)
# Try to refresh device list to ensure all devices are properly registered.
try:
await session_manager.refresh_device_lists(jid)
self.log.info("OMEMO: device list refreshed for %s", jid)
except Exception as refresh_exc:
self.log.debug("OMEMO: device list refresh: %s", refresh_exc)
except _asyncio.TimeoutError:
self.log.error(
"OMEMO: session manager initialization timeout after 180s. "
"Device list may not be published to PubSub. "
"Clients will not be able to discover gateway devices. "
"Check PubSub server connectivity and latency."
)
status = "timeout"
reason = (
"Session manager initialization timeout - device list not published"
)
except Exception as exc: except Exception as exc:
self.log.warning("OMEMO: could not read own device fingerprint: %s", exc) self.log.warning(
"OMEMO: could not initialize device information: %s", exc
)
self.log.info( self.log.info(
"OMEMO bootstrap: jid=%s enabled=%s status=%s fingerprint=%s", "OMEMO bootstrap: jid=%s enabled=%s status=%s fingerprint=%s",
jid, omemo_enabled, status, fingerprint, jid,
omemo_enabled,
status,
fingerprint,
) )
transport.update_runtime_state( transport.update_runtime_state(
"xmpp", "xmpp",
@@ -486,20 +835,42 @@ class XMPPComponent(ComponentXMPP):
omemo_status_reason=reason, omemo_status_reason=reason,
) )
async def _record_sender_omemo_state(self, user, *, sender_jid, recipient_jid, message_stanza): async def _record_sender_omemo_state(
self,
user,
*,
sender_jid,
recipient_jid,
message_stanza,
sender_fingerprint="",
):
parsed = _extract_sender_omemo_client_key(message_stanza) parsed = _extract_sender_omemo_client_key(message_stanza)
status = str(parsed.get("status") or "no_omemo") status = str(parsed.get("status") or "no_omemo")
client_key = str(parsed.get("client_key") or "") client_key = str(parsed.get("client_key") or "")
await sync_to_async(UserXmppOmemoState.objects.update_or_create)(
user=user, def _save_row():
defaults={ row, _ = UserXmppOmemoState.objects.get_or_create(user=user)
"status": status, details = dict(row.details or {})
"latest_client_key": client_key, if sender_fingerprint:
"last_sender_jid": str(sender_jid or ""), details["latest_client_fingerprint"] = str(sender_fingerprint)
"last_target_jid": str(recipient_jid or ""), row.status = status
}, row.latest_client_key = client_key
row.last_sender_jid = str(sender_jid or "")
row.last_target_jid = str(recipient_jid or "")
row.details = details
row.save(
update_fields=[
"status",
"latest_client_key",
"last_sender_jid",
"last_target_jid",
"details",
"updated_at",
]
) )
await sync_to_async(_save_row)()
_approval_event_prefix = "codex_approval" _approval_event_prefix = "codex_approval"
_APPROVAL_PROVIDER_COMMANDS = { _APPROVAL_PROVIDER_COMMANDS = {
@@ -525,7 +896,9 @@ class XMPPComponent(ComponentXMPP):
run.status = "approved_waiting_resume" if status == "approved" else status run.status = "approved_waiting_resume" if status == "approved" else status
await sync_to_async(run.save)(update_fields=["status"]) await sync_to_async(run.save)(update_fields=["status"])
if request.external_sync_event_id: if request.external_sync_event_id:
evt = await sync_to_async(ExternalSyncEvent.objects.get)(pk=request.external_sync_event_id) evt = await sync_to_async(ExternalSyncEvent.objects.get)(
pk=request.external_sync_event_id
)
evt.status = "ok" evt.status = "ok"
await sync_to_async(evt.save)(update_fields=["status"]) await sync_to_async(evt.save)(update_fields=["status"])
user = await sync_to_async(User.objects.get)(pk=request.user_id) user = await sync_to_async(User.objects.get)(pk=request.user_id)
@@ -547,9 +920,9 @@ class XMPPComponent(ComponentXMPP):
async def _approval_list_pending(self, user, scope, sym): async def _approval_list_pending(self, user, scope, sym):
requests = await sync_to_async(list)( requests = await sync_to_async(list)(
CodexPermissionRequest.objects.filter( CodexPermissionRequest.objects.filter(user=user, status="pending").order_by(
user=user, status="pending" "-requested_at"
).order_by("-requested_at")[:20] )[:20]
) )
sym(f"pending={len(requests)}") sym(f"pending={len(requests)}")
for req in requests: for req in requests:
@@ -557,9 +930,9 @@ class XMPPComponent(ComponentXMPP):
async def _approval_status(self, user, approval_key, sym): async def _approval_status(self, user, approval_key, sym):
try: try:
req = await sync_to_async( req = await sync_to_async(CodexPermissionRequest.objects.get)(
CodexPermissionRequest.objects.get user=user, approval_key=approval_key
)(user=user, approval_key=approval_key) )
sym(f"status={req.status} key={req.approval_key}") sym(f"status={req.status} key={req.approval_key}")
except CodexPermissionRequest.DoesNotExist: except CodexPermissionRequest.DoesNotExist:
sym(f"approval_key_not_found:{approval_key}") sym(f"approval_key_not_found:{approval_key}")
@@ -583,7 +956,9 @@ class XMPPComponent(ComponentXMPP):
return True return True
provider = self._resolve_request_provider(req) provider = self._resolve_request_provider(req)
if not provider.startswith(expected_provider): if not provider.startswith(expected_provider):
sym(f"approval_key_not_for_provider:{approval_key} provider={provider}") sym(
f"approval_key_not_for_provider:{approval_key} provider={provider}"
)
return True return True
await self._apply_approval_decision(req, action, sym) await self._apply_approval_decision(req, action, sym)
sym(f"{action}d: {approval_key}") sym(f"{action}d: {approval_key}")
@@ -720,9 +1095,6 @@ class XMPPComponent(ComponentXMPP):
if action in {"status", "help"}: if action in {"status", "help"}:
return "" return ""
return rest return rest
compact = text.replace(" ", "").strip().upper()
if TOTP_BASE32_SECRET_RE.match(compact):
return compact
return "" return ""
async def _handle_totp_command(self, user, body, sym): async def _handle_totp_command(self, user, body, sym):
@@ -763,11 +1135,7 @@ class XMPPComponent(ComponentXMPP):
def _save_device(): def _save_device():
from django_otp.plugins.otp_totp.models import TOTPDevice from django_otp.plugins.otp_totp.models import TOTPDevice
device = ( device = TOTPDevice.objects.filter(user=user).order_by("-id").first()
TOTPDevice.objects.filter(user=user)
.order_by("-id")
.first()
)
if device is None: if device is None:
device = TOTPDevice(user=user, name="gateway") device = TOTPDevice(user=user, name="gateway")
device.key = key_bytes.hex() device.key = key_bytes.hex()
@@ -798,7 +1166,9 @@ class XMPPComponent(ComponentXMPP):
command_text = str(body or "").strip() command_text = str(body or "").strip()
async def _contacts_handler(_ctx, emit): async def _contacts_handler(_ctx, emit):
persons = await sync_to_async(list)(Person.objects.filter(user=sender_user).order_by("name")) persons = await sync_to_async(list)(
Person.objects.filter(user=sender_user).order_by("name")
)
if not persons: if not persons:
emit("No contacts found.") emit("No contacts found.")
return True return True
@@ -815,7 +1185,9 @@ class XMPPComponent(ComponentXMPP):
return True return True
async def _approval_handler(_ctx, emit): async def _approval_handler(_ctx, emit):
return await self._handle_approval_command(sender_user, command_text, sender_jid, emit) return await self._handle_approval_command(
sender_user, command_text, sender_jid, emit
)
async def _tasks_handler(_ctx, emit): async def _tasks_handler(_ctx, emit):
return await self._handle_tasks_command(sender_user, command_text, emit) return await self._handle_tasks_command(sender_user, command_text, emit)
@@ -845,7 +1217,10 @@ class XMPPComponent(ComponentXMPP):
GatewayCommandRoute( GatewayCommandRoute(
name="approval", name="approval",
scope_key="gateway.approval", scope_key="gateway.approval",
matcher=lambda text: str(text or "").strip().lower().startswith(".approval") matcher=lambda text: str(text or "")
.strip()
.lower()
.startswith(".approval")
or any( or any(
str(text or "").strip().lower().startswith(prefix + " ") str(text or "").strip().lower().startswith(prefix + " ")
or str(text or "").strip().lower() == prefix or str(text or "").strip().lower() == prefix
@@ -856,7 +1231,10 @@ class XMPPComponent(ComponentXMPP):
GatewayCommandRoute( GatewayCommandRoute(
name="tasks", name="tasks",
scope_key="gateway.tasks", scope_key="gateway.tasks",
matcher=lambda text: str(text or "").strip().lower().startswith(".tasks"), matcher=lambda text: str(text or "")
.strip()
.lower()
.startswith(".tasks"),
handler=_tasks_handler, handler=_tasks_handler,
), ),
GatewayCommandRoute( GatewayCommandRoute(
@@ -1365,16 +1743,36 @@ class XMPPComponent(ComponentXMPP):
def on_presence_subscribe(self, pres): def on_presence_subscribe(self, pres):
""" """
Handle incoming presence subscription requests. Handle incoming presence subscription requests.
Accept only if the recipient has a contact matching the sender. Accept subscriptions to:
1. The gateway component itself (jews.zm.is) - for OMEMO device discovery
2. User contacts at the gateway (user|service@jews.zm.is) - for user-to-user messaging
""" """
sender_jid = str(pres["from"]).split("/")[0] # Bare JID (user@domain) sender_jid = str(pres["from"]).split("/")[0] # Bare JID (user@domain)
recipient_jid = str(pres["to"]).split("/")[0] recipient_jid = str(pres["to"]).split("/")[0]
component_jid = str(self.boundjid.bare)
self.log.debug( self.log.debug(
f"Received subscription request from {sender_jid} to {recipient_jid}" f"Received subscription request from {sender_jid} to {recipient_jid}"
) )
# Check if subscription is to the gateway component itself
if recipient_jid == component_jid:
# Auto-accept subscriptions to the gateway component (for OMEMO)
self.log.info(
"Auto-accepting subscription to gateway component from %s", sender_jid
)
self.send_presence(ptype="subscribed", pto=sender_jid, pfrom=component_jid)
# Send presence availability to enable device discovery
self.send_presence(ptype="available", pto=sender_jid, pfrom=component_jid)
self.log.info(
"Gateway component is available to %s (OMEMO device discovery enabled)",
sender_jid,
)
return
# Otherwise, handle user-to-user subscription (existing logic)
try: try:
# Extract sender and recipient usernames # Extract sender and recipient usernames
user_username, _ = sender_jid.split("@", 1) user_username, _ = sender_jid.split("@", 1)
@@ -1400,24 +1798,24 @@ class XMPPComponent(ComponentXMPP):
self.log.debug("Resolving subscription identifier service=%s", service) self.log.debug("Resolving subscription identifier service=%s", service)
PersonIdentifier.objects.get(user=user, person=person, service=service) PersonIdentifier.objects.get(user=user, person=person, service=service)
component_jid = f"{person_name.lower()}|{service}@{self.boundjid.bare}" contact_jid = f"{person_name.lower()}|{service}@{self.boundjid.bare}"
# Accept the subscription # Accept the subscription
self.send_presence(ptype="subscribed", pto=sender_jid, pfrom=component_jid) self.send_presence(ptype="subscribed", pto=sender_jid, pfrom=contact_jid)
self.log.debug( self.log.debug(
f"Accepted subscription from {sender_jid}, sent from {component_jid}" f"Accepted subscription from {sender_jid}, sent from {contact_jid}"
) )
# Send a presence request **from the recipient to the sender** (ASKS THEM TO ACCEPT BACK) # Send a presence request **from the recipient to the sender** (ASKS THEM TO ACCEPT BACK)
# self.send_presence(ptype="subscribe", pto=sender_jid, pfrom=component_jid) # self.send_presence(ptype="subscribe", pto=sender_jid, pfrom=contact_jid)
# Add sender to roster # Add sender to roster
# self.update_roster(sender_jid, name=sender_jid.split("@")[0]) # self.update_roster(sender_jid, name=sender_jid.split("@")[0])
# Send presence update to sender **from the correct JID** # Send presence update to sender **from the correct JID**
self.send_presence(ptype="available", pto=sender_jid, pfrom=component_jid) self.send_presence(ptype="available", pto=sender_jid, pfrom=contact_jid)
self.log.debug( self.log.debug(
"Sent presence update from %s to %s", component_jid, sender_jid "Sent presence update from %s to %s", contact_jid, sender_jid
) )
except (User.DoesNotExist, Person.DoesNotExist, PersonIdentifier.DoesNotExist): except (User.DoesNotExist, Person.DoesNotExist, PersonIdentifier.DoesNotExist):
@@ -1645,14 +2043,20 @@ class XMPPComponent(ComponentXMPP):
# Attempt to decrypt OMEMO-encrypted messages before body extraction. # Attempt to decrypt OMEMO-encrypted messages before body extraction.
original_msg = msg original_msg = msg
omemo_plugin = self._get_omemo_plugin() omemo_plugin = self._get_omemo_plugin()
sender_omemo_fingerprint = ""
if omemo_plugin: if omemo_plugin:
try: try:
if omemo_plugin.is_encrypted(msg): if omemo_plugin.is_encrypted(msg):
decrypted, _ = await omemo_plugin.decrypt_message(msg) decrypted, sender_device = await omemo_plugin.decrypt_message(msg)
msg = decrypted msg = decrypted
sender_omemo_fingerprint = _format_omemo_identity_fingerprint(
getattr(sender_device, "identity_key", b"")
)
self.log.debug("OMEMO: decrypted message from %s", sender_jid) self.log.debug("OMEMO: decrypted message from %s", sender_jid)
except Exception as exc: except Exception as exc:
self.log.warning("OMEMO: decryption failed from %s: %s", sender_jid, exc) self.log.warning(
"OMEMO: decryption failed from %s: %s", sender_jid, exc
)
# Extract message body # Extract message body
body = msg["body"] if msg["body"] else "" body = msg["body"] if msg["body"] else ""
@@ -1678,7 +2082,9 @@ class XMPPComponent(ComponentXMPP):
or "application/octet-stream", or "application/octet-stream",
) )
except Exception as exc: except Exception as exc:
self.log.warning("xmpp dropped unsafe attachment url=%s: %s", url_value, exc) self.log.warning(
"xmpp dropped unsafe attachment url=%s: %s", url_value, exc
)
continue continue
attachments.append( attachments.append(
{ {
@@ -1723,7 +2129,9 @@ class XMPPComponent(ComponentXMPP):
content_type=_content_type_from_filename_or_url(safe_url), content_type=_content_type_from_filename_or_url(safe_url),
) )
except Exception as exc: except Exception as exc:
self.log.warning("xmpp dropped extracted unsafe url=%s: %s", url_value, exc) self.log.warning(
"xmpp dropped extracted unsafe url=%s: %s", url_value, exc
)
continue continue
attachments.append( attachments.append(
{ {
@@ -1787,6 +2195,7 @@ class XMPPComponent(ComponentXMPP):
sender_jid=sender_jid, sender_jid=sender_jid,
recipient_jid=recipient_jid, recipient_jid=recipient_jid,
message_stanza=original_msg, message_stanza=original_msg,
sender_fingerprint=sender_omemo_fingerprint,
) )
except Exception as exc: except Exception as exc:
self.log.warning("OMEMO: failed to record sender state: %s", exc) self.log.warning("OMEMO: failed to record sender state: %s", exc)
@@ -1795,8 +2204,11 @@ class XMPPComponent(ComponentXMPP):
# Enforce mandatory encryption policy. # Enforce mandatory encryption policy.
try: try:
from core.models import UserXmppSecuritySettings from core.models import UserXmppSecuritySettings
sec_settings = await sync_to_async( sec_settings = await sync_to_async(
lambda: UserXmppSecuritySettings.objects.filter(user=sender_user).first() lambda: UserXmppSecuritySettings.objects.filter(
user=sender_user
).first()
)() )()
if sec_settings and sec_settings.require_omemo: if sec_settings and sec_settings.require_omemo:
omemo_status = str(omemo_observation.get("status") or "") omemo_status = str(omemo_observation.get("status") or "")
@@ -1812,7 +2224,7 @@ class XMPPComponent(ComponentXMPP):
if recipient_jid == settings.XMPP_JID: if recipient_jid == settings.XMPP_JID:
self.log.debug("Handling command message sent to gateway JID") self.log.debug("Handling command message sent to gateway JID")
if body.startswith(".") or self._extract_totp_secret_candidate(body): if body.startswith("."):
await self._route_gateway_command( await self._route_gateway_command(
sender_user=sender_user, sender_user=sender_user,
body=body, body=body,
@@ -1824,7 +2236,9 @@ class XMPPComponent(ComponentXMPP):
"sender_jid": str(sender_jid or ""), "sender_jid": str(sender_jid or ""),
"recipient_jid": str(recipient_jid or ""), "recipient_jid": str(recipient_jid or ""),
"omemo_status": str(omemo_observation.get("status") or ""), "omemo_status": str(omemo_observation.get("status") or ""),
"omemo_client_key": str(omemo_observation.get("client_key") or ""), "omemo_client_key": str(
omemo_observation.get("client_key") or ""
),
} }
}, },
sym=sym, sym=sym,
@@ -2004,7 +2418,9 @@ class XMPPComponent(ComponentXMPP):
"sender_jid": str(sender_jid or ""), "sender_jid": str(sender_jid or ""),
"recipient_jid": str(recipient_jid or ""), "recipient_jid": str(recipient_jid or ""),
"omemo_status": str(omemo_observation.get("status") or ""), "omemo_status": str(omemo_observation.get("status") or ""),
"omemo_client_key": str(omemo_observation.get("client_key") or ""), "omemo_client_key": str(
omemo_observation.get("client_key") or ""
),
} }
}, },
) )
@@ -2127,9 +2543,7 @@ class XMPPComponent(ComponentXMPP):
f"Upload failed: {response.status} {await response.text()}" f"Upload failed: {response.status} {await response.text()}"
) )
return None return None
self.log.debug( self.log.debug("Successfully uploaded %s to %s", filename, upload_url)
"Successfully uploaded %s to %s", filename, upload_url
)
# Send XMPP message immediately after successful upload # Send XMPP message immediately after successful upload
xmpp_msg_id = await self.send_xmpp_message( xmpp_msg_id = await self.send_xmpp_message(
@@ -2145,7 +2559,13 @@ class XMPPComponent(ComponentXMPP):
return None return None
async def send_xmpp_message( async def send_xmpp_message(
self, recipient_jid, sender_jid, body_text, attachment_url=None self,
recipient_jid,
sender_jid,
body_text,
attachment_url=None,
*,
use_omemo_encryption=True,
): ):
"""Sends an XMPP message with either text or an attachment URL.""" """Sends an XMPP message with either text or an attachment URL."""
msg = self.make_message(mto=recipient_jid, mfrom=sender_jid, mtype="chat") msg = self.make_message(mto=recipient_jid, mfrom=sender_jid, mtype="chat")
@@ -2163,29 +2583,35 @@ class XMPPComponent(ComponentXMPP):
self.log.debug("Sending XMPP message: %s", msg.xml) self.log.debug("Sending XMPP message: %s", msg.xml)
# Attempt OMEMO encryption for text-only messages (not attachments). # Attempt OMEMO encryption for text-only messages (not attachments)
if not attachment_url: # when outbound policy allows it.
if not attachment_url and use_omemo_encryption:
omemo_plugin = self._get_omemo_plugin() omemo_plugin = self._get_omemo_plugin()
if omemo_plugin: if omemo_plugin:
try: try:
from slixmpp.jid import JID as _JID from slixmpp.jid import JID as _JID
encrypted_msgs, enc_errors = await omemo_plugin.encrypt_message( encrypted_msgs, enc_errors = await omemo_plugin.encrypt_message(
msg, _JID(recipient_jid) msg, _JID(recipient_jid)
) )
if enc_errors: if enc_errors:
self.log.debug( self.log.debug(
"OMEMO: non-critical encryption errors for %s: %s", "OMEMO: non-critical encryption errors for %s: %s",
recipient_jid, enc_errors, recipient_jid,
enc_errors,
) )
if encrypted_msgs: if encrypted_msgs:
for enc_msg in encrypted_msgs.values(): for enc_msg in encrypted_msgs.values():
enc_msg.send() enc_msg.send()
self.log.debug("OMEMO: sent encrypted message to %s", recipient_jid) self.log.debug(
"OMEMO: sent encrypted message to %s", recipient_jid
)
return msg_id return msg_id
except Exception as exc: except Exception as exc:
self.log.debug( self.log.debug(
"OMEMO: encryption not available for %s, sending plaintext: %s", "OMEMO: encryption not available for %s, sending plaintext: %s",
recipient_jid, exc, recipient_jid,
exc,
) )
msg.send() msg.send()
@@ -2347,6 +2773,19 @@ class XMPPComponent(ComponentXMPP):
sender_jid = f"{person_identifier.person.name.lower()}|{person_identifier.service}@{settings.XMPP_JID}" sender_jid = f"{person_identifier.person.name.lower()}|{person_identifier.service}@{settings.XMPP_JID}"
recipient_jid = self._user_jid(person_identifier.user.username) recipient_jid = self._user_jid(person_identifier.user.username)
relay_encrypt_with_omemo = True
try:
from core.models import UserXmppSecuritySettings
sec_settings = await sync_to_async(
lambda: UserXmppSecuritySettings.objects.filter(user=user).first()
)()
if sec_settings is not None:
relay_encrypt_with_omemo = bool(
getattr(sec_settings, "encrypt_contact_messages_with_omemo", True)
)
except Exception as exc:
self.log.warning("XMPP relay OMEMO settings lookup failed: %s", exc)
if is_outgoing_message: if is_outgoing_message:
xmpp_id = await self.send_xmpp_message( xmpp_id = await self.send_xmpp_message(
recipient_jid, recipient_jid,
@@ -2385,7 +2824,12 @@ class XMPPComponent(ComponentXMPP):
# Step 1: Send text message separately # Step 1: Send text message separately
elif text: elif text:
xmpp_id = await self.send_xmpp_message(recipient_jid, sender_jid, text) xmpp_id = await self.send_xmpp_message(
recipient_jid,
sender_jid,
text,
use_omemo_encryption=relay_encrypt_with_omemo,
)
transport.record_bridge_mapping( transport.record_bridge_mapping(
user_id=user.id, user_id=user.id,
person_id=person_identifier.person_id, person_id=person_identifier.person_id,
@@ -2512,17 +2956,25 @@ class XMPPClient(ClientBase):
self._omemo_plugin_registered = False self._omemo_plugin_registered = False
if _OMEMO_AVAILABLE: if _OMEMO_AVAILABLE:
try: try:
data_dir = str(getattr(settings, "XMPP_OMEMO_DATA_DIR", "") or "").strip() data_dir = str(
getattr(settings, "XMPP_OMEMO_DATA_DIR", "") or ""
).strip()
if not data_dir: if not data_dir:
data_dir = str(Path(settings.BASE_DIR) / "xmpp_omemo_data") data_dir = str(Path(settings.BASE_DIR) / "xmpp_omemo_data")
# Register our concrete plugin class under the "xep_0384" name so # Register our concrete plugin class under the "xep_0384" name so
# that slixmpp's dependency resolver finds it. # that slixmpp's dependency resolver finds it.
_slixmpp_register_plugin(_GiaOmemoPlugin) _slixmpp_register_plugin(_GiaOmemoPlugin)
self.client.register_plugin("xep_0384", pconfig={"data_dir": data_dir}) self.client.register_plugin(
"xep_0384", pconfig={"data_dir": data_dir}
)
self._omemo_plugin_registered = True self._omemo_plugin_registered = True
self.log.info("OMEMO: xep_0384 plugin registered, data_dir=%s", data_dir) self.log.info(
"OMEMO: xep_0384 plugin registered, data_dir=%s", data_dir
)
except Exception as exc: except Exception as exc:
self.log.warning("OMEMO: failed to register xep_0384 plugin: %s", exc) self.log.warning(
"OMEMO: failed to register xep_0384 plugin: %s", exc
)
else: else:
self.log.warning("OMEMO: slixmpp_omemo not available, OMEMO disabled") self.log.warning("OMEMO: slixmpp_omemo not available, OMEMO disabled")

View File

@@ -31,7 +31,9 @@ def chunk_for_transport(text: str, limit: int = 3000) -> list[str]:
return [part for part in parts if part] return [part for part in parts if part]
async def post_status_in_source(trigger_message: Message, text: str, origin_tag: str) -> bool: async def post_status_in_source(
trigger_message: Message, text: str, origin_tag: str
) -> bool:
service = str(trigger_message.source_service or "").strip().lower() service = str(trigger_message.source_service or "").strip().lower()
if service not in STATUS_VISIBLE_SOURCE_SERVICES: if service not in STATUS_VISIBLE_SOURCE_SERVICES:
return False return False
@@ -76,9 +78,10 @@ async def post_to_channel_binding(
channel_identifier = str(binding_channel_identifier or "").strip() channel_identifier = str(binding_channel_identifier or "").strip()
if service == "web": if service == "web":
session = None session = None
if channel_identifier and channel_identifier == str( if (
trigger_message.source_chat_id or "" channel_identifier
).strip(): and channel_identifier == str(trigger_message.source_chat_id or "").strip()
):
session = trigger_message.session session = trigger_message.session
if session is None and channel_identifier: if session is None and channel_identifier:
session = await sync_to_async( session = await sync_to_async(
@@ -99,7 +102,8 @@ async def post_to_channel_binding(
ts=int(time.time() * 1000), ts=int(time.time() * 1000),
custom_author="BOT", custom_author="BOT",
source_service="web", source_service="web",
source_chat_id=channel_identifier or str(trigger_message.source_chat_id or ""), source_chat_id=channel_identifier
or str(trigger_message.source_chat_id or ""),
message_meta={"origin_tag": origin_tag}, message_meta={"origin_tag": origin_tag},
) )
return True return True

View File

@@ -58,9 +58,15 @@ def _effective_bootstrap_scope(
identifier = str(ctx.channel_identifier or "").strip() identifier = str(ctx.channel_identifier or "").strip()
if service != "web": if service != "web":
return service, identifier return service, identifier
session_identifier = getattr(getattr(trigger_message, "session", None), "identifier", None) session_identifier = getattr(
fallback_service = str(getattr(session_identifier, "service", "") or "").strip().lower() getattr(trigger_message, "session", None), "identifier", None
fallback_identifier = str(getattr(session_identifier, "identifier", "") or "").strip() )
fallback_service = (
str(getattr(session_identifier, "service", "") or "").strip().lower()
)
fallback_identifier = str(
getattr(session_identifier, "identifier", "") or ""
).strip()
if fallback_service and fallback_identifier and fallback_service != "web": if fallback_service and fallback_identifier and fallback_service != "web":
return fallback_service, fallback_identifier return fallback_service, fallback_identifier
return service, identifier return service, identifier
@@ -89,7 +95,11 @@ def _ensure_bp_profile(user_id: int) -> CommandProfile:
if str(profile.trigger_token or "").strip() != ".bp": if str(profile.trigger_token or "").strip() != ".bp":
profile.trigger_token = ".bp" profile.trigger_token = ".bp"
profile.save(update_fields=["trigger_token", "updated_at"]) profile.save(update_fields=["trigger_token", "updated_at"])
for action_type, position in (("extract_bp", 0), ("save_document", 1), ("post_result", 2)): for action_type, position in (
("extract_bp", 0),
("save_document", 1),
("post_result", 2),
):
action, created = CommandAction.objects.get_or_create( action, created = CommandAction.objects.get_or_create(
profile=profile, profile=profile,
action_type=action_type, action_type=action_type,
@@ -327,7 +337,9 @@ async def process_inbound_message(ctx: CommandContext) -> list[CommandResult]:
return [] return []
if is_mirrored_origin(trigger_message.message_meta): if is_mirrored_origin(trigger_message.message_meta):
return [] return []
effective_service, effective_channel = _effective_bootstrap_scope(ctx, trigger_message) effective_service, effective_channel = _effective_bootstrap_scope(
ctx, trigger_message
)
security_context = CommandSecurityContext( security_context = CommandSecurityContext(
service=effective_service, service=effective_service,
channel_identifier=effective_channel, channel_identifier=effective_channel,
@@ -394,7 +406,9 @@ async def process_inbound_message(ctx: CommandContext) -> list[CommandResult]:
result = await handler.execute(ctx) result = await handler.execute(ctx)
results.append(result) results.append(result)
except Exception as exc: except Exception as exc:
log.exception("command execution failed for profile=%s: %s", profile.slug, exc) log.exception(
"command execution failed for profile=%s: %s", profile.slug, exc
)
results.append( results.append(
CommandResult( CommandResult(
ok=False, ok=False,

View File

@@ -45,14 +45,15 @@ class BPParsedCommand(dict):
return str(self.get("remainder_text") or "") return str(self.get("remainder_text") or "")
def parse_bp_subcommand(text: str) -> BPParsedCommand: def parse_bp_subcommand(text: str) -> BPParsedCommand:
body = str(text or "") body = str(text or "")
if _BP_SET_RANGE_RE.match(body): if _BP_SET_RANGE_RE.match(body):
return BPParsedCommand(command="set_range", remainder_text="") return BPParsedCommand(command="set_range", remainder_text="")
match = _BP_SET_RE.match(body) match = _BP_SET_RE.match(body)
if match: if match:
return BPParsedCommand(command="set", remainder_text=str(match.group("rest") or "").strip()) return BPParsedCommand(
command="set", remainder_text=str(match.group("rest") or "").strip()
)
return BPParsedCommand(command=None, remainder_text="") return BPParsedCommand(command=None, remainder_text="")
@@ -63,7 +64,9 @@ def bp_subcommands_enabled() -> bool:
return bool(raw) return bool(raw)
def bp_trigger_matches(message_text: str, trigger_token: str, exact_match_only: bool) -> bool: def bp_trigger_matches(
message_text: str, trigger_token: str, exact_match_only: bool
) -> bool:
body = str(message_text or "").strip() body = str(message_text or "").strip()
trigger = str(trigger_token or "").strip() trigger = str(trigger_token or "").strip()
parsed = parse_bp_subcommand(body) parsed = parse_bp_subcommand(body)
@@ -144,7 +147,8 @@ class BPCommandHandler(CommandHandler):
"enabled": True, "enabled": True,
"generation_mode": "ai" if variant_key == "bp" else "verbatim", "generation_mode": "ai" if variant_key == "bp" else "verbatim",
"send_plan_to_egress": "post_result" in action_types, "send_plan_to_egress": "post_result" in action_types,
"send_status_to_source": str(profile.visibility_mode or "") == "status_in_source", "send_status_to_source": str(profile.visibility_mode or "")
== "status_in_source",
"send_status_to_egress": False, "send_status_to_egress": False,
"store_document": True, "store_document": True,
} }
@@ -224,10 +228,14 @@ class BPCommandHandler(CommandHandler):
ts__lte=int(trigger.ts or 0), ts__lte=int(trigger.ts or 0),
) )
.order_by("ts") .order_by("ts")
.select_related("session", "session__identifier", "session__identifier__person") .select_related(
"session", "session__identifier", "session__identifier__person"
)
) )
def _annotation(self, mode: str, message_count: int, has_addendum: bool = False) -> str: def _annotation(
self, mode: str, message_count: int, has_addendum: bool = False
) -> str:
if mode == "set" and has_addendum: if mode == "set" and has_addendum:
return "Generated from 1 message + 1 addendum." return "Generated from 1 message + 1 addendum."
if message_count == 1: if message_count == 1:
@@ -291,21 +299,29 @@ class BPCommandHandler(CommandHandler):
if anchor is None: if anchor is None:
run.status = "failed" run.status = "failed"
run.error = "bp_set_range_requires_reply_target" run.error = "bp_set_range_requires_reply_target"
await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) await sync_to_async(run.save)(
update_fields=["status", "error", "updated_at"]
)
return CommandResult(ok=False, status="failed", error=run.error) return CommandResult(ok=False, status="failed", error=run.error)
rows = await self._load_window(trigger, anchor) rows = await self._load_window(trigger, anchor)
deterministic_content = plain_text_blob(rows) deterministic_content = plain_text_blob(rows)
if not deterministic_content.strip(): if not deterministic_content.strip():
run.status = "failed" run.status = "failed"
run.error = "bp_set_range_empty_content" run.error = "bp_set_range_empty_content"
await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) await sync_to_async(run.save)(
update_fields=["status", "error", "updated_at"]
)
return CommandResult(ok=False, status="failed", error=run.error) return CommandResult(ok=False, status="failed", error=run.error)
if str(policy.get("generation_mode") or "verbatim") == "ai": if str(policy.get("generation_mode") or "verbatim") == "ai":
ai_obj = await sync_to_async(lambda: AI.objects.filter(user=trigger.user).first())() ai_obj = await sync_to_async(
lambda: AI.objects.filter(user=trigger.user).first()
)()
if ai_obj is None: if ai_obj is None:
run.status = "failed" run.status = "failed"
run.error = "ai_not_configured" run.error = "ai_not_configured"
await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) await sync_to_async(run.save)(
update_fields=["status", "error", "updated_at"]
)
return CommandResult(ok=False, status="failed", error=run.error) return CommandResult(ok=False, status="failed", error=run.error)
prompt = [ prompt = [
{ {
@@ -329,12 +345,16 @@ class BPCommandHandler(CommandHandler):
except Exception as exc: except Exception as exc:
run.status = "failed" run.status = "failed"
run.error = f"bp_ai_failed:{exc}" run.error = f"bp_ai_failed:{exc}"
await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) await sync_to_async(run.save)(
update_fields=["status", "error", "updated_at"]
)
return CommandResult(ok=False, status="failed", error=run.error) return CommandResult(ok=False, status="failed", error=run.error)
if not content: if not content:
run.status = "failed" run.status = "failed"
run.error = "empty_ai_response" run.error = "empty_ai_response"
await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) await sync_to_async(run.save)(
update_fields=["status", "error", "updated_at"]
)
return CommandResult(ok=False, status="failed", error=run.error) return CommandResult(ok=False, status="failed", error=run.error)
else: else:
content = deterministic_content content = deterministic_content
@@ -360,9 +380,7 @@ class BPCommandHandler(CommandHandler):
elif anchor is not None and remainder: elif anchor is not None and remainder:
base = str(anchor.text or "").strip() or "(no text)" base = str(anchor.text or "").strip() or "(no text)"
content = ( content = (
f"{base}\n" f"{base}\n" "--- Addendum (newer message text) ---\n" f"{remainder}"
"--- Addendum (newer message text) ---\n"
f"{remainder}"
) )
source_ids.extend([str(anchor.id), str(trigger.id)]) source_ids.extend([str(anchor.id), str(trigger.id)])
has_addendum = True has_addendum = True
@@ -373,15 +391,21 @@ class BPCommandHandler(CommandHandler):
else: else:
run.status = "failed" run.status = "failed"
run.error = "bp_set_empty_content" run.error = "bp_set_empty_content"
await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) await sync_to_async(run.save)(
update_fields=["status", "error", "updated_at"]
)
return CommandResult(ok=False, status="failed", error=run.error) return CommandResult(ok=False, status="failed", error=run.error)
if str(policy.get("generation_mode") or "verbatim") == "ai": if str(policy.get("generation_mode") or "verbatim") == "ai":
ai_obj = await sync_to_async(lambda: AI.objects.filter(user=trigger.user).first())() ai_obj = await sync_to_async(
lambda: AI.objects.filter(user=trigger.user).first()
)()
if ai_obj is None: if ai_obj is None:
run.status = "failed" run.status = "failed"
run.error = "ai_not_configured" run.error = "ai_not_configured"
await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) await sync_to_async(run.save)(
update_fields=["status", "error", "updated_at"]
)
return CommandResult(ok=False, status="failed", error=run.error) return CommandResult(ok=False, status="failed", error=run.error)
prompt = [ prompt = [
{ {
@@ -405,16 +429,22 @@ class BPCommandHandler(CommandHandler):
except Exception as exc: except Exception as exc:
run.status = "failed" run.status = "failed"
run.error = f"bp_ai_failed:{exc}" run.error = f"bp_ai_failed:{exc}"
await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) await sync_to_async(run.save)(
update_fields=["status", "error", "updated_at"]
)
return CommandResult(ok=False, status="failed", error=run.error) return CommandResult(ok=False, status="failed", error=run.error)
if not ai_content: if not ai_content:
run.status = "failed" run.status = "failed"
run.error = "empty_ai_response" run.error = "empty_ai_response"
await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) await sync_to_async(run.save)(
update_fields=["status", "error", "updated_at"]
)
return CommandResult(ok=False, status="failed", error=run.error) return CommandResult(ok=False, status="failed", error=run.error)
content = ai_content content = ai_content
annotation = self._annotation("set", 1 if not has_addendum else 2, has_addendum) annotation = self._annotation(
"set", 1 if not has_addendum else 2, has_addendum
)
doc = None doc = None
if bool(policy.get("store_document", True)): if bool(policy.get("store_document", True)):
doc = await self._persist_document( doc = await self._persist_document(
@@ -430,7 +460,9 @@ class BPCommandHandler(CommandHandler):
else: else:
run.status = "failed" run.status = "failed"
run.error = "bp_unknown_subcommand" run.error = "bp_unknown_subcommand"
await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) await sync_to_async(run.save)(
update_fields=["status", "error", "updated_at"]
)
return CommandResult(ok=False, status="failed", error=run.error) return CommandResult(ok=False, status="failed", error=run.error)
fanout_stats = {"sent_bindings": 0, "failed_bindings": 0} fanout_stats = {"sent_bindings": 0, "failed_bindings": 0}
@@ -479,7 +511,9 @@ class BPCommandHandler(CommandHandler):
if trigger.reply_to_id is None: if trigger.reply_to_id is None:
run.status = "failed" run.status = "failed"
run.error = "bp_requires_reply_target" run.error = "bp_requires_reply_target"
await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) await sync_to_async(run.save)(
update_fields=["status", "error", "updated_at"]
)
return CommandResult(ok=False, status="failed", error=run.error) return CommandResult(ok=False, status="failed", error=run.error)
anchor = trigger.reply_to anchor = trigger.reply_to
@@ -488,7 +522,9 @@ class BPCommandHandler(CommandHandler):
rows, rows,
author_rewrites={"USER": "Operator", "BOT": "Assistant"}, author_rewrites={"USER": "Operator", "BOT": "Assistant"},
) )
max_transcript_chars = int(getattr(settings, "BP_MAX_TRANSCRIPT_CHARS", 12000) or 12000) max_transcript_chars = int(
getattr(settings, "BP_MAX_TRANSCRIPT_CHARS", 12000) or 12000
)
transcript = _clamp_transcript(transcript, max_transcript_chars) transcript = _clamp_transcript(transcript, max_transcript_chars)
default_template = ( default_template = (
"Business Plan:\n" "Business Plan:\n"
@@ -499,7 +535,9 @@ class BPCommandHandler(CommandHandler):
"- Risks" "- Risks"
) )
template_text = profile.template_text or default_template template_text = profile.template_text or default_template
max_template_chars = int(getattr(settings, "BP_MAX_TEMPLATE_CHARS", 5000) or 5000) max_template_chars = int(
getattr(settings, "BP_MAX_TEMPLATE_CHARS", 5000) or 5000
)
template_text = str(template_text or "")[:max_template_chars] template_text = str(template_text or "")[:max_template_chars]
generation_mode = str(policy.get("generation_mode") or "ai") generation_mode = str(policy.get("generation_mode") or "ai")
if generation_mode == "verbatim": if generation_mode == "verbatim":
@@ -507,14 +545,20 @@ class BPCommandHandler(CommandHandler):
if not summary.strip(): if not summary.strip():
run.status = "failed" run.status = "failed"
run.error = "bp_verbatim_empty_content" run.error = "bp_verbatim_empty_content"
await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) await sync_to_async(run.save)(
update_fields=["status", "error", "updated_at"]
)
return CommandResult(ok=False, status="failed", error=run.error) return CommandResult(ok=False, status="failed", error=run.error)
else: else:
ai_obj = await sync_to_async(lambda: AI.objects.filter(user=trigger.user).first())() ai_obj = await sync_to_async(
lambda: AI.objects.filter(user=trigger.user).first()
)()
if ai_obj is None: if ai_obj is None:
run.status = "failed" run.status = "failed"
run.error = "ai_not_configured" run.error = "ai_not_configured"
await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) await sync_to_async(run.save)(
update_fields=["status", "error", "updated_at"]
)
return CommandResult(ok=False, status="failed", error=run.error) return CommandResult(ok=False, status="failed", error=run.error)
prompt = [ prompt = [
@@ -530,13 +574,20 @@ class BPCommandHandler(CommandHandler):
}, },
] ]
try: try:
summary = str(await ai_runner.run_prompt(prompt, ai_obj, operation="command_bp_extract") or "").strip() summary = str(
await ai_runner.run_prompt(
prompt, ai_obj, operation="command_bp_extract"
)
or ""
).strip()
if not summary: if not summary:
raise RuntimeError("empty_ai_response") raise RuntimeError("empty_ai_response")
except Exception as exc: except Exception as exc:
run.status = "failed" run.status = "failed"
run.error = f"bp_ai_failed:{exc}" run.error = f"bp_ai_failed:{exc}"
await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) await sync_to_async(run.save)(
update_fields=["status", "error", "updated_at"]
)
return CommandResult(ok=False, status="failed", error=run.error) return CommandResult(ok=False, status="failed", error=run.error)
annotation = self._annotation("legacy", len(rows)) annotation = self._annotation("legacy", len(rows))
@@ -588,23 +639,31 @@ class BPCommandHandler(CommandHandler):
async def execute(self, ctx: CommandContext) -> CommandResult: async def execute(self, ctx: CommandContext) -> CommandResult:
trigger = await sync_to_async( trigger = await sync_to_async(
lambda: Message.objects.select_related("user", "session").filter(id=ctx.message_id).first() lambda: Message.objects.select_related("user", "session")
.filter(id=ctx.message_id)
.first()
)() )()
if trigger is None: if trigger is None:
return CommandResult(ok=False, status="failed", error="trigger_not_found") return CommandResult(ok=False, status="failed", error="trigger_not_found")
profile = await sync_to_async( profile = await sync_to_async(
lambda: trigger.user.commandprofile_set.filter(slug=self.slug, enabled=True).first() lambda: trigger.user.commandprofile_set.filter(
slug=self.slug, enabled=True
).first()
)() )()
if profile is None: if profile is None:
return CommandResult(ok=False, status="skipped", error="profile_missing") return CommandResult(ok=False, status="skipped", error="profile_missing")
actions = await sync_to_async(list)( actions = await sync_to_async(list)(
CommandAction.objects.filter(profile=profile, enabled=True).order_by("position", "id") CommandAction.objects.filter(profile=profile, enabled=True).order_by(
"position", "id"
)
) )
action_types = {row.action_type for row in actions} action_types = {row.action_type for row in actions}
if "extract_bp" not in action_types: if "extract_bp" not in action_types:
return CommandResult(ok=False, status="skipped", error="extract_bp_disabled") return CommandResult(
ok=False, status="skipped", error="extract_bp_disabled"
)
run, created = await sync_to_async(CommandRun.objects.get_or_create)( run, created = await sync_to_async(CommandRun.objects.get_or_create)(
profile=profile, profile=profile,
@@ -612,7 +671,11 @@ class BPCommandHandler(CommandHandler):
defaults={"user": trigger.user, "status": "running"}, defaults={"user": trigger.user, "status": "running"},
) )
if not created and run.status in {"ok", "running"}: if not created and run.status in {"ok", "running"}:
return CommandResult(ok=True, status="ok", payload={"document_id": str(run.result_ref_id or "")}) return CommandResult(
ok=True,
status="ok",
payload={"document_id": str(run.result_ref_id or "")},
)
run.status = "running" run.status = "running"
run.error = "" run.error = ""
@@ -627,7 +690,9 @@ class BPCommandHandler(CommandHandler):
if not bool(policy.get("enabled")): if not bool(policy.get("enabled")):
run.status = "skipped" run.status = "skipped"
run.error = f"variant_disabled:{variant_key}" run.error = f"variant_disabled:{variant_key}"
await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) await sync_to_async(run.save)(
update_fields=["status", "error", "updated_at"]
)
return CommandResult(ok=False, status="skipped", error=run.error) return CommandResult(ok=False, status="skipped", error=run.error)
parsed = parse_bp_subcommand(ctx.message_text) parsed = parse_bp_subcommand(ctx.message_text)

View File

@@ -20,8 +20,8 @@ from core.models import (
TaskProject, TaskProject,
TaskProviderConfig, TaskProviderConfig,
) )
from core.tasks.codex_support import channel_variants, resolve_external_chat_id
from core.tasks.codex_approval import queue_codex_event_with_pre_approval from core.tasks.codex_approval import queue_codex_event_with_pre_approval
from core.tasks.codex_support import channel_variants, resolve_external_chat_id
_CLAUDE_DEFAULT_RE = re.compile( _CLAUDE_DEFAULT_RE = re.compile(
r"^\s*(?:\.claude\b|#claude#?)(?P<body>.*)$", r"^\s*(?:\.claude\b|#claude#?)(?P<body>.*)$",
@@ -31,7 +31,9 @@ _CLAUDE_PLAN_RE = re.compile(
r"^\s*(?:\.claude\s+plan\b|#claude\s+plan#?)(?P<body>.*)$", r"^\s*(?:\.claude\s+plan\b|#claude\s+plan#?)(?P<body>.*)$",
re.IGNORECASE | re.DOTALL, re.IGNORECASE | re.DOTALL,
) )
_CLAUDE_STATUS_RE = re.compile(r"^\s*(?:\.claude\s+status\b|#claude\s+status#?)\s*$", re.IGNORECASE) _CLAUDE_STATUS_RE = re.compile(
r"^\s*(?:\.claude\s+status\b|#claude\s+status#?)\s*$", re.IGNORECASE
)
_CLAUDE_APPROVE_DENY_RE = re.compile( _CLAUDE_APPROVE_DENY_RE = re.compile(
r"^\s*(?:\.claude|#claude)\s+(?P<action>approve|deny)\s+(?P<approval_key>[A-Za-z0-9._:-]+)#?\s*$", r"^\s*(?:\.claude|#claude)\s+(?P<action>approve|deny)\s+(?P<approval_key>[A-Za-z0-9._:-]+)#?\s*$",
re.IGNORECASE, re.IGNORECASE,
@@ -83,7 +85,9 @@ def parse_claude_command(text: str) -> ClaudeParsedCommand:
return ClaudeParsedCommand(command=None, body_text="", approval_key="") return ClaudeParsedCommand(command=None, body_text="", approval_key="")
def claude_trigger_matches(message_text: str, trigger_token: str, exact_match_only: bool) -> bool: def claude_trigger_matches(
message_text: str, trigger_token: str, exact_match_only: bool
) -> bool:
body = str(message_text or "").strip() body = str(message_text or "").strip()
parsed = parse_claude_command(body) parsed = parse_claude_command(body)
if parsed.command: if parsed.command:
@@ -103,7 +107,9 @@ class ClaudeCommandHandler(CommandHandler):
async def _load_trigger(self, message_id: str) -> Message | None: async def _load_trigger(self, message_id: str) -> Message | None:
return await sync_to_async( return await sync_to_async(
lambda: Message.objects.select_related("user", "session", "session__identifier", "reply_to") lambda: Message.objects.select_related(
"user", "session", "session__identifier", "reply_to"
)
.filter(id=message_id) .filter(id=message_id)
.first() .first()
)() )()
@@ -114,11 +120,18 @@ class ClaudeCommandHandler(CommandHandler):
identifier = getattr(getattr(trigger, "session", None), "identifier", None) identifier = getattr(getattr(trigger, "session", None), "identifier", None)
fallback_service = str(getattr(identifier, "service", "") or "").strip().lower() fallback_service = str(getattr(identifier, "service", "") or "").strip().lower()
fallback_identifier = str(getattr(identifier, "identifier", "") or "").strip() fallback_identifier = str(getattr(identifier, "identifier", "") or "").strip()
if service == "web" and fallback_service and fallback_identifier and fallback_service != "web": if (
service == "web"
and fallback_service
and fallback_identifier
and fallback_service != "web"
):
return fallback_service, fallback_identifier return fallback_service, fallback_identifier
return service or "web", channel return service or "web", channel
async def _mapped_sources(self, user, service: str, channel: str) -> list[ChatTaskSource]: async def _mapped_sources(
self, user, service: str, channel: str
) -> list[ChatTaskSource]:
variants = channel_variants(service, channel) variants = channel_variants(service, channel)
if not variants: if not variants:
return [] return []
@@ -131,7 +144,9 @@ class ClaudeCommandHandler(CommandHandler):
).select_related("project", "epic") ).select_related("project", "epic")
) )
async def _linked_task_from_reply(self, user, reply_to: Message | None) -> DerivedTask | None: async def _linked_task_from_reply(
self, user, reply_to: Message | None
) -> DerivedTask | None:
if reply_to is None: if reply_to is None:
return None return None
by_origin = await sync_to_async( by_origin = await sync_to_async(
@@ -143,7 +158,9 @@ class ClaudeCommandHandler(CommandHandler):
if by_origin is not None: if by_origin is not None:
return by_origin return by_origin
return await sync_to_async( return await sync_to_async(
lambda: DerivedTask.objects.filter(user=user, events__source_message=reply_to) lambda: DerivedTask.objects.filter(
user=user, events__source_message=reply_to
)
.select_related("project", "epic") .select_related("project", "epic")
.order_by("-created_at") .order_by("-created_at")
.first() .first()
@@ -164,10 +181,14 @@ class ClaudeCommandHandler(CommandHandler):
return "" return ""
return str(m.group(1) or "").strip() return str(m.group(1) or "").strip()
async def _resolve_task(self, user, reference_code: str, reply_task: DerivedTask | None) -> DerivedTask | None: async def _resolve_task(
self, user, reference_code: str, reply_task: DerivedTask | None
) -> DerivedTask | None:
if reference_code: if reference_code:
return await sync_to_async( return await sync_to_async(
lambda: DerivedTask.objects.filter(user=user, reference_code=reference_code) lambda: DerivedTask.objects.filter(
user=user, reference_code=reference_code
)
.select_related("project", "epic") .select_related("project", "epic")
.order_by("-created_at") .order_by("-created_at")
.first() .first()
@@ -190,7 +211,9 @@ class ClaudeCommandHandler(CommandHandler):
return reply_task.project, "" return reply_task.project, ""
if project_token: if project_token:
project = await sync_to_async( project = await sync_to_async(
lambda: TaskProject.objects.filter(user=user, name__iexact=project_token).first() lambda: TaskProject.objects.filter(
user=user, name__iexact=project_token
).first()
)() )()
if project is not None: if project is not None:
return project, "" return project, ""
@@ -199,20 +222,31 @@ class ClaudeCommandHandler(CommandHandler):
mapped = await self._mapped_sources(user, service, channel) mapped = await self._mapped_sources(user, service, channel)
project_ids = sorted({str(row.project_id) for row in mapped if row.project_id}) project_ids = sorted({str(row.project_id) for row in mapped if row.project_id})
if len(project_ids) == 1: if len(project_ids) == 1:
project = next((row.project for row in mapped if str(row.project_id) == project_ids[0]), None) project = next(
(
row.project
for row in mapped
if str(row.project_id) == project_ids[0]
),
None,
)
return project, "" return project, ""
if len(project_ids) > 1: if len(project_ids) > 1:
return None, "project_required:[project:Name]" return None, "project_required:[project:Name]"
return None, "project_unresolved" return None, "project_unresolved"
async def _post_source_status(self, trigger: Message, text: str, suffix: str) -> None: async def _post_source_status(
self, trigger: Message, text: str, suffix: str
) -> None:
await post_status_in_source( await post_status_in_source(
trigger_message=trigger, trigger_message=trigger,
text=text, text=text,
origin_tag=f"claude-status:{suffix}", origin_tag=f"claude-status:{suffix}",
) )
async def _run_status(self, trigger: Message, service: str, channel: str, project: TaskProject | None) -> CommandResult: async def _run_status(
self, trigger: Message, service: str, channel: str, project: TaskProject | None
) -> CommandResult:
def _load_runs(): def _load_runs():
qs = CodexRun.objects.filter(user=trigger.user) qs = CodexRun.objects.filter(user=trigger.user)
if service: if service:
@@ -225,7 +259,9 @@ class ClaudeCommandHandler(CommandHandler):
runs = await sync_to_async(_load_runs)() runs = await sync_to_async(_load_runs)()
if not runs: if not runs:
await self._post_source_status(trigger, "[claude] no recent runs for this scope.", "empty") await self._post_source_status(
trigger, "[claude] no recent runs for this scope.", "empty"
)
return CommandResult(ok=True, status="ok", payload={"count": 0}) return CommandResult(ok=True, status="ok", payload={"count": 0})
lines = ["[claude] recent runs:"] lines = ["[claude] recent runs:"]
for row in runs: for row in runs:
@@ -249,24 +285,38 @@ class ClaudeCommandHandler(CommandHandler):
).first() ).first()
)() )()
settings_payload = dict(getattr(cfg, "settings", {}) or {}) settings_payload = dict(getattr(cfg, "settings", {}) or {})
approver_service = str(settings_payload.get("approver_service") or "").strip().lower() approver_service = (
approver_identifier = str(settings_payload.get("approver_identifier") or "").strip() str(settings_payload.get("approver_service") or "").strip().lower()
)
approver_identifier = str(
settings_payload.get("approver_identifier") or ""
).strip()
if not approver_service or not approver_identifier: if not approver_service or not approver_identifier:
return CommandResult(ok=False, status="failed", error="approver_channel_not_configured") return CommandResult(
ok=False, status="failed", error="approver_channel_not_configured"
)
if str(current_service or "").strip().lower() != approver_service or str(current_channel or "").strip() not in set( if str(current_service or "").strip().lower() != approver_service or str(
channel_variants(approver_service, approver_identifier) current_channel or ""
): ).strip() not in set(channel_variants(approver_service, approver_identifier)):
return CommandResult(ok=False, status="failed", error="approval_command_not_allowed_in_this_channel") return CommandResult(
ok=False,
status="failed",
error="approval_command_not_allowed_in_this_channel",
)
approval_key = parsed.approval_key approval_key = parsed.approval_key
request = await sync_to_async( request = await sync_to_async(
lambda: CodexPermissionRequest.objects.select_related("codex_run", "external_sync_event") lambda: CodexPermissionRequest.objects.select_related(
"codex_run", "external_sync_event"
)
.filter(user=trigger.user, approval_key=approval_key) .filter(user=trigger.user, approval_key=approval_key)
.first() .first()
)() )()
if request is None: if request is None:
return CommandResult(ok=False, status="failed", error="approval_key_not_found") return CommandResult(
ok=False, status="failed", error="approval_key_not_found"
)
now = timezone.now() now = timezone.now()
if parsed.command == "approve": if parsed.command == "approve":
@@ -283,14 +333,20 @@ class ClaudeCommandHandler(CommandHandler):
] ]
) )
if request.external_sync_event_id: if request.external_sync_event_id:
await sync_to_async(ExternalSyncEvent.objects.filter(id=request.external_sync_event_id).update)( await sync_to_async(
ExternalSyncEvent.objects.filter(
id=request.external_sync_event_id
).update
)(
status="ok", status="ok",
error="", error="",
) )
run = request.codex_run run = request.codex_run
run.status = "approved_waiting_resume" run.status = "approved_waiting_resume"
run.error = "" run.error = ""
await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) await sync_to_async(run.save)(
update_fields=["status", "error", "updated_at"]
)
source_service = str(run.source_service or "") source_service = str(run.source_service or "")
source_channel = str(run.source_channel or "") source_channel = str(run.source_channel or "")
resume_payload = dict(request.resume_payload or {}) resume_payload = dict(request.resume_payload or {})
@@ -302,14 +358,18 @@ class ClaudeCommandHandler(CommandHandler):
provider_payload["source_service"] = source_service provider_payload["source_service"] = source_service
provider_payload["source_channel"] = source_channel provider_payload["source_channel"] = source_channel
event_action = resume_action event_action = resume_action
resume_idempotency_key = str(resume_payload.get("idempotency_key") or "").strip() resume_idempotency_key = str(
resume_payload.get("idempotency_key") or ""
).strip()
resume_event_key = ( resume_event_key = (
resume_idempotency_key resume_idempotency_key
if resume_idempotency_key if resume_idempotency_key
else f"{self._approval_prefix}:{approval_key}:approved" else f"{self._approval_prefix}:{approval_key}:approved"
) )
else: else:
provider_payload = dict(run.request_payload.get("provider_payload") or {}) provider_payload = dict(
run.request_payload.get("provider_payload") or {}
)
provider_payload.update( provider_payload.update(
{ {
"mode": "approval_response", "mode": "approval_response",
@@ -337,17 +397,30 @@ class ClaudeCommandHandler(CommandHandler):
"error": "", "error": "",
}, },
) )
return CommandResult(ok=True, status="ok", payload={"approval_key": approval_key, "resolution": "approved"}) return CommandResult(
ok=True,
status="ok",
payload={"approval_key": approval_key, "resolution": "approved"},
)
request.status = "denied" request.status = "denied"
request.resolved_at = now request.resolved_at = now
request.resolved_by_identifier = current_channel request.resolved_by_identifier = current_channel
request.resolution_note = "denied via claude command" request.resolution_note = "denied via claude command"
await sync_to_async(request.save)( await sync_to_async(request.save)(
update_fields=["status", "resolved_at", "resolved_by_identifier", "resolution_note"] update_fields=[
"status",
"resolved_at",
"resolved_by_identifier",
"resolution_note",
]
) )
if request.external_sync_event_id: if request.external_sync_event_id:
await sync_to_async(ExternalSyncEvent.objects.filter(id=request.external_sync_event_id).update)( await sync_to_async(
ExternalSyncEvent.objects.filter(
id=request.external_sync_event_id
).update
)(
status="failed", status="failed",
error="approval_denied", error="approval_denied",
) )
@@ -374,7 +447,11 @@ class ClaudeCommandHandler(CommandHandler):
"error": "approval_denied", "error": "approval_denied",
}, },
) )
return CommandResult(ok=True, status="ok", payload={"approval_key": approval_key, "resolution": "denied"}) return CommandResult(
ok=True,
status="ok",
payload={"approval_key": approval_key, "resolution": "denied"},
)
async def _create_submission( async def _create_submission(
self, self,
@@ -391,7 +468,9 @@ class ClaudeCommandHandler(CommandHandler):
).first() ).first()
)() )()
if cfg is None: if cfg is None:
return CommandResult(ok=False, status="failed", error="provider_disabled_or_missing") return CommandResult(
ok=False, status="failed", error="provider_disabled_or_missing"
)
service, channel = self._effective_scope(trigger) service, channel = self._effective_scope(trigger)
external_chat_id = await sync_to_async(resolve_external_chat_id)( external_chat_id = await sync_to_async(resolve_external_chat_id)(
@@ -418,7 +497,9 @@ class ClaudeCommandHandler(CommandHandler):
if mode == "plan": if mode == "plan":
anchor = trigger.reply_to anchor = trigger.reply_to
if anchor is None: if anchor is None:
return CommandResult(ok=False, status="failed", error="reply_required_for_claude_plan") return CommandResult(
ok=False, status="failed", error="reply_required_for_claude_plan"
)
rows = await sync_to_async(list)( rows = await sync_to_async(list)(
Message.objects.filter( Message.objects.filter(
user=trigger.user, user=trigger.user,
@@ -427,7 +508,9 @@ class ClaudeCommandHandler(CommandHandler):
ts__lte=int(trigger.ts or 0), ts__lte=int(trigger.ts or 0),
) )
.order_by("ts") .order_by("ts")
.select_related("session", "session__identifier", "session__identifier__person") .select_related(
"session", "session__identifier", "session__identifier__person"
)
) )
payload["reply_context"] = { payload["reply_context"] = {
"anchor_message_id": str(anchor.id), "anchor_message_id": str(anchor.id),
@@ -446,12 +529,18 @@ class ClaudeCommandHandler(CommandHandler):
source_channel=channel, source_channel=channel,
external_chat_id=external_chat_id, external_chat_id=external_chat_id,
status="waiting_approval", status="waiting_approval",
request_payload={"action": "append_update", "provider_payload": dict(payload)}, request_payload={
"action": "append_update",
"provider_payload": dict(payload),
},
result_payload={}, result_payload={},
error="", error="",
) )
payload["codex_run_id"] = str(run.id) payload["codex_run_id"] = str(run.id)
run.request_payload = {"action": "append_update", "provider_payload": dict(payload)} run.request_payload = {
"action": "append_update",
"provider_payload": dict(payload),
}
await sync_to_async(run.save)(update_fields=["request_payload", "updated_at"]) await sync_to_async(run.save)(update_fields=["request_payload", "updated_at"])
idempotency_key = f"claude_cmd:{trigger.id}:{mode}:{task.id}:{hashlib.sha1(str(body_text or '').encode('utf-8')).hexdigest()[:12]}" idempotency_key = f"claude_cmd:{trigger.id}:{mode}:{task.id}:{hashlib.sha1(str(body_text or '').encode('utf-8')).hexdigest()[:12]}"
@@ -476,20 +565,26 @@ class ClaudeCommandHandler(CommandHandler):
return CommandResult(ok=False, status="failed", error="trigger_not_found") return CommandResult(ok=False, status="failed", error="trigger_not_found")
profile = await sync_to_async( profile = await sync_to_async(
lambda: CommandProfile.objects.filter(user=trigger.user, slug=self.slug, enabled=True).first() lambda: CommandProfile.objects.filter(
user=trigger.user, slug=self.slug, enabled=True
).first()
)() )()
if profile is None: if profile is None:
return CommandResult(ok=False, status="skipped", error="profile_missing") return CommandResult(ok=False, status="skipped", error="profile_missing")
parsed = parse_claude_command(ctx.message_text) parsed = parse_claude_command(ctx.message_text)
if not parsed.command: if not parsed.command:
return CommandResult(ok=False, status="skipped", error="claude_command_not_matched") return CommandResult(
ok=False, status="skipped", error="claude_command_not_matched"
)
service, channel = self._effective_scope(trigger) service, channel = self._effective_scope(trigger)
if parsed.command == "status": if parsed.command == "status":
project = None project = None
reply_task = await self._linked_task_from_reply(trigger.user, trigger.reply_to) reply_task = await self._linked_task_from_reply(
trigger.user, trigger.reply_to
)
if reply_task is not None: if reply_task is not None:
project = reply_task.project project = reply_task.project
return await self._run_status(trigger, service, channel, project) return await self._run_status(trigger, service, channel, project)
@@ -507,7 +602,9 @@ class ClaudeCommandHandler(CommandHandler):
reply_task = await self._linked_task_from_reply(trigger.user, trigger.reply_to) reply_task = await self._linked_task_from_reply(trigger.user, trigger.reply_to)
task = await self._resolve_task(trigger.user, reference_code, reply_task) task = await self._resolve_task(trigger.user, reference_code, reply_task)
if task is None: if task is None:
return CommandResult(ok=False, status="failed", error="task_target_required") return CommandResult(
ok=False, status="failed", error="task_target_required"
)
project, project_error = await self._resolve_project( project, project_error = await self._resolve_project(
user=trigger.user, user=trigger.user,
@@ -518,7 +615,9 @@ class ClaudeCommandHandler(CommandHandler):
project_token=project_token, project_token=project_token,
) )
if project is None: if project is None:
return CommandResult(ok=False, status="failed", error=project_error or "project_unresolved") return CommandResult(
ok=False, status="failed", error=project_error or "project_unresolved"
)
mode = "plan" if parsed.command == "plan" else "default" mode = "plan" if parsed.command == "plan" else "default"
return await self._create_submission( return await self._create_submission(

View File

@@ -20,8 +20,8 @@ from core.models import (
TaskProject, TaskProject,
TaskProviderConfig, TaskProviderConfig,
) )
from core.tasks.codex_support import channel_variants, resolve_external_chat_id
from core.tasks.codex_approval import queue_codex_event_with_pre_approval from core.tasks.codex_approval import queue_codex_event_with_pre_approval
from core.tasks.codex_support import channel_variants, resolve_external_chat_id
_CODEX_DEFAULT_RE = re.compile( _CODEX_DEFAULT_RE = re.compile(
r"^\s*(?:\.codex\b|#codex#?)(?P<body>.*)$", r"^\s*(?:\.codex\b|#codex#?)(?P<body>.*)$",
@@ -31,7 +31,9 @@ _CODEX_PLAN_RE = re.compile(
r"^\s*(?:\.codex\s+plan\b|#codex\s+plan#?)(?P<body>.*)$", r"^\s*(?:\.codex\s+plan\b|#codex\s+plan#?)(?P<body>.*)$",
re.IGNORECASE | re.DOTALL, re.IGNORECASE | re.DOTALL,
) )
_CODEX_STATUS_RE = re.compile(r"^\s*(?:\.codex\s+status\b|#codex\s+status#?)\s*$", re.IGNORECASE) _CODEX_STATUS_RE = re.compile(
r"^\s*(?:\.codex\s+status\b|#codex\s+status#?)\s*$", re.IGNORECASE
)
_CODEX_APPROVE_DENY_RE = re.compile( _CODEX_APPROVE_DENY_RE = re.compile(
r"^\s*(?:\.codex|#codex)\s+(?P<action>approve|deny)\s+(?P<approval_key>[A-Za-z0-9._:-]+)#?\s*$", r"^\s*(?:\.codex|#codex)\s+(?P<action>approve|deny)\s+(?P<approval_key>[A-Za-z0-9._:-]+)#?\s*$",
re.IGNORECASE, re.IGNORECASE,
@@ -55,7 +57,6 @@ class CodexParsedCommand(dict):
return str(self.get("approval_key") or "") return str(self.get("approval_key") or "")
def parse_codex_command(text: str) -> CodexParsedCommand: def parse_codex_command(text: str) -> CodexParsedCommand:
body = str(text or "") body = str(text or "")
m = _CODEX_APPROVE_DENY_RE.match(body) m = _CODEX_APPROVE_DENY_RE.match(body)
@@ -84,7 +85,9 @@ def parse_codex_command(text: str) -> CodexParsedCommand:
return CodexParsedCommand(command=None, body_text="", approval_key="") return CodexParsedCommand(command=None, body_text="", approval_key="")
def codex_trigger_matches(message_text: str, trigger_token: str, exact_match_only: bool) -> bool: def codex_trigger_matches(
message_text: str, trigger_token: str, exact_match_only: bool
) -> bool:
body = str(message_text or "").strip() body = str(message_text or "").strip()
parsed = parse_codex_command(body) parsed = parse_codex_command(body)
if parsed.command: if parsed.command:
@@ -102,7 +105,9 @@ class CodexCommandHandler(CommandHandler):
async def _load_trigger(self, message_id: str) -> Message | None: async def _load_trigger(self, message_id: str) -> Message | None:
return await sync_to_async( return await sync_to_async(
lambda: Message.objects.select_related("user", "session", "session__identifier", "reply_to") lambda: Message.objects.select_related(
"user", "session", "session__identifier", "reply_to"
)
.filter(id=message_id) .filter(id=message_id)
.first() .first()
)() )()
@@ -113,11 +118,18 @@ class CodexCommandHandler(CommandHandler):
identifier = getattr(getattr(trigger, "session", None), "identifier", None) identifier = getattr(getattr(trigger, "session", None), "identifier", None)
fallback_service = str(getattr(identifier, "service", "") or "").strip().lower() fallback_service = str(getattr(identifier, "service", "") or "").strip().lower()
fallback_identifier = str(getattr(identifier, "identifier", "") or "").strip() fallback_identifier = str(getattr(identifier, "identifier", "") or "").strip()
if service == "web" and fallback_service and fallback_identifier and fallback_service != "web": if (
service == "web"
and fallback_service
and fallback_identifier
and fallback_service != "web"
):
return fallback_service, fallback_identifier return fallback_service, fallback_identifier
return service or "web", channel return service or "web", channel
async def _mapped_sources(self, user, service: str, channel: str) -> list[ChatTaskSource]: async def _mapped_sources(
self, user, service: str, channel: str
) -> list[ChatTaskSource]:
variants = channel_variants(service, channel) variants = channel_variants(service, channel)
if not variants: if not variants:
return [] return []
@@ -130,7 +142,9 @@ class CodexCommandHandler(CommandHandler):
).select_related("project", "epic") ).select_related("project", "epic")
) )
async def _linked_task_from_reply(self, user, reply_to: Message | None) -> DerivedTask | None: async def _linked_task_from_reply(
self, user, reply_to: Message | None
) -> DerivedTask | None:
if reply_to is None: if reply_to is None:
return None return None
by_origin = await sync_to_async( by_origin = await sync_to_async(
@@ -142,7 +156,9 @@ class CodexCommandHandler(CommandHandler):
if by_origin is not None: if by_origin is not None:
return by_origin return by_origin
return await sync_to_async( return await sync_to_async(
lambda: DerivedTask.objects.filter(user=user, events__source_message=reply_to) lambda: DerivedTask.objects.filter(
user=user, events__source_message=reply_to
)
.select_related("project", "epic") .select_related("project", "epic")
.order_by("-created_at") .order_by("-created_at")
.first() .first()
@@ -163,10 +179,14 @@ class CodexCommandHandler(CommandHandler):
return "" return ""
return str(m.group(1) or "").strip() return str(m.group(1) or "").strip()
async def _resolve_task(self, user, reference_code: str, reply_task: DerivedTask | None) -> DerivedTask | None: async def _resolve_task(
self, user, reference_code: str, reply_task: DerivedTask | None
) -> DerivedTask | None:
if reference_code: if reference_code:
return await sync_to_async( return await sync_to_async(
lambda: DerivedTask.objects.filter(user=user, reference_code=reference_code) lambda: DerivedTask.objects.filter(
user=user, reference_code=reference_code
)
.select_related("project", "epic") .select_related("project", "epic")
.order_by("-created_at") .order_by("-created_at")
.first() .first()
@@ -189,7 +209,9 @@ class CodexCommandHandler(CommandHandler):
return reply_task.project, "" return reply_task.project, ""
if project_token: if project_token:
project = await sync_to_async( project = await sync_to_async(
lambda: TaskProject.objects.filter(user=user, name__iexact=project_token).first() lambda: TaskProject.objects.filter(
user=user, name__iexact=project_token
).first()
)() )()
if project is not None: if project is not None:
return project, "" return project, ""
@@ -198,20 +220,31 @@ class CodexCommandHandler(CommandHandler):
mapped = await self._mapped_sources(user, service, channel) mapped = await self._mapped_sources(user, service, channel)
project_ids = sorted({str(row.project_id) for row in mapped if row.project_id}) project_ids = sorted({str(row.project_id) for row in mapped if row.project_id})
if len(project_ids) == 1: if len(project_ids) == 1:
project = next((row.project for row in mapped if str(row.project_id) == project_ids[0]), None) project = next(
(
row.project
for row in mapped
if str(row.project_id) == project_ids[0]
),
None,
)
return project, "" return project, ""
if len(project_ids) > 1: if len(project_ids) > 1:
return None, "project_required:[project:Name]" return None, "project_required:[project:Name]"
return None, "project_unresolved" return None, "project_unresolved"
async def _post_source_status(self, trigger: Message, text: str, suffix: str) -> None: async def _post_source_status(
self, trigger: Message, text: str, suffix: str
) -> None:
await post_status_in_source( await post_status_in_source(
trigger_message=trigger, trigger_message=trigger,
text=text, text=text,
origin_tag=f"codex-status:{suffix}", origin_tag=f"codex-status:{suffix}",
) )
async def _run_status(self, trigger: Message, service: str, channel: str, project: TaskProject | None) -> CommandResult: async def _run_status(
self, trigger: Message, service: str, channel: str, project: TaskProject | None
) -> CommandResult:
def _load_runs(): def _load_runs():
qs = CodexRun.objects.filter(user=trigger.user) qs = CodexRun.objects.filter(user=trigger.user)
if service: if service:
@@ -224,7 +257,9 @@ class CodexCommandHandler(CommandHandler):
runs = await sync_to_async(_load_runs)() runs = await sync_to_async(_load_runs)()
if not runs: if not runs:
await self._post_source_status(trigger, "[codex] no recent runs for this scope.", "empty") await self._post_source_status(
trigger, "[codex] no recent runs for this scope.", "empty"
)
return CommandResult(ok=True, status="ok", payload={"count": 0}) return CommandResult(ok=True, status="ok", payload={"count": 0})
lines = ["[codex] recent runs:"] lines = ["[codex] recent runs:"]
for row in runs: for row in runs:
@@ -243,27 +278,43 @@ class CodexCommandHandler(CommandHandler):
current_channel: str, current_channel: str,
) -> CommandResult: ) -> CommandResult:
cfg = await sync_to_async( cfg = await sync_to_async(
lambda: TaskProviderConfig.objects.filter(user=trigger.user, provider="codex_cli").first() lambda: TaskProviderConfig.objects.filter(
user=trigger.user, provider="codex_cli"
).first()
)() )()
settings_payload = dict(getattr(cfg, "settings", {}) or {}) settings_payload = dict(getattr(cfg, "settings", {}) or {})
approver_service = str(settings_payload.get("approver_service") or "").strip().lower() approver_service = (
approver_identifier = str(settings_payload.get("approver_identifier") or "").strip() str(settings_payload.get("approver_service") or "").strip().lower()
)
approver_identifier = str(
settings_payload.get("approver_identifier") or ""
).strip()
if not approver_service or not approver_identifier: if not approver_service or not approver_identifier:
return CommandResult(ok=False, status="failed", error="approver_channel_not_configured") return CommandResult(
ok=False, status="failed", error="approver_channel_not_configured"
)
if str(current_service or "").strip().lower() != approver_service or str(current_channel or "").strip() not in set( if str(current_service or "").strip().lower() != approver_service or str(
channel_variants(approver_service, approver_identifier) current_channel or ""
): ).strip() not in set(channel_variants(approver_service, approver_identifier)):
return CommandResult(ok=False, status="failed", error="approval_command_not_allowed_in_this_channel") return CommandResult(
ok=False,
status="failed",
error="approval_command_not_allowed_in_this_channel",
)
approval_key = parsed.approval_key approval_key = parsed.approval_key
request = await sync_to_async( request = await sync_to_async(
lambda: CodexPermissionRequest.objects.select_related("codex_run", "external_sync_event") lambda: CodexPermissionRequest.objects.select_related(
"codex_run", "external_sync_event"
)
.filter(user=trigger.user, approval_key=approval_key) .filter(user=trigger.user, approval_key=approval_key)
.first() .first()
)() )()
if request is None: if request is None:
return CommandResult(ok=False, status="failed", error="approval_key_not_found") return CommandResult(
ok=False, status="failed", error="approval_key_not_found"
)
now = timezone.now() now = timezone.now()
if parsed.command == "approve": if parsed.command == "approve":
@@ -280,14 +331,20 @@ class CodexCommandHandler(CommandHandler):
] ]
) )
if request.external_sync_event_id: if request.external_sync_event_id:
await sync_to_async(ExternalSyncEvent.objects.filter(id=request.external_sync_event_id).update)( await sync_to_async(
ExternalSyncEvent.objects.filter(
id=request.external_sync_event_id
).update
)(
status="ok", status="ok",
error="", error="",
) )
run = request.codex_run run = request.codex_run
run.status = "approved_waiting_resume" run.status = "approved_waiting_resume"
run.error = "" run.error = ""
await sync_to_async(run.save)(update_fields=["status", "error", "updated_at"]) await sync_to_async(run.save)(
update_fields=["status", "error", "updated_at"]
)
source_service = str(run.source_service or "") source_service = str(run.source_service or "")
source_channel = str(run.source_channel or "") source_channel = str(run.source_channel or "")
resume_payload = dict(request.resume_payload or {}) resume_payload = dict(request.resume_payload or {})
@@ -299,14 +356,18 @@ class CodexCommandHandler(CommandHandler):
provider_payload["source_service"] = source_service provider_payload["source_service"] = source_service
provider_payload["source_channel"] = source_channel provider_payload["source_channel"] = source_channel
event_action = resume_action event_action = resume_action
resume_idempotency_key = str(resume_payload.get("idempotency_key") or "").strip() resume_idempotency_key = str(
resume_payload.get("idempotency_key") or ""
).strip()
resume_event_key = ( resume_event_key = (
resume_idempotency_key resume_idempotency_key
if resume_idempotency_key if resume_idempotency_key
else f"codex_approval:{approval_key}:approved" else f"codex_approval:{approval_key}:approved"
) )
else: else:
provider_payload = dict(run.request_payload.get("provider_payload") or {}) provider_payload = dict(
run.request_payload.get("provider_payload") or {}
)
provider_payload.update( provider_payload.update(
{ {
"mode": "approval_response", "mode": "approval_response",
@@ -334,17 +395,30 @@ class CodexCommandHandler(CommandHandler):
"error": "", "error": "",
}, },
) )
return CommandResult(ok=True, status="ok", payload={"approval_key": approval_key, "resolution": "approved"}) return CommandResult(
ok=True,
status="ok",
payload={"approval_key": approval_key, "resolution": "approved"},
)
request.status = "denied" request.status = "denied"
request.resolved_at = now request.resolved_at = now
request.resolved_by_identifier = current_channel request.resolved_by_identifier = current_channel
request.resolution_note = "denied via command" request.resolution_note = "denied via command"
await sync_to_async(request.save)( await sync_to_async(request.save)(
update_fields=["status", "resolved_at", "resolved_by_identifier", "resolution_note"] update_fields=[
"status",
"resolved_at",
"resolved_by_identifier",
"resolution_note",
]
) )
if request.external_sync_event_id: if request.external_sync_event_id:
await sync_to_async(ExternalSyncEvent.objects.filter(id=request.external_sync_event_id).update)( await sync_to_async(
ExternalSyncEvent.objects.filter(
id=request.external_sync_event_id
).update
)(
status="failed", status="failed",
error="approval_denied", error="approval_denied",
) )
@@ -371,7 +445,11 @@ class CodexCommandHandler(CommandHandler):
"error": "approval_denied", "error": "approval_denied",
}, },
) )
return CommandResult(ok=True, status="ok", payload={"approval_key": approval_key, "resolution": "denied"}) return CommandResult(
ok=True,
status="ok",
payload={"approval_key": approval_key, "resolution": "denied"},
)
async def _create_submission( async def _create_submission(
self, self,
@@ -383,10 +461,14 @@ class CodexCommandHandler(CommandHandler):
project: TaskProject, project: TaskProject,
) -> CommandResult: ) -> CommandResult:
cfg = await sync_to_async( cfg = await sync_to_async(
lambda: TaskProviderConfig.objects.filter(user=trigger.user, provider="codex_cli", enabled=True).first() lambda: TaskProviderConfig.objects.filter(
user=trigger.user, provider="codex_cli", enabled=True
).first()
)() )()
if cfg is None: if cfg is None:
return CommandResult(ok=False, status="failed", error="provider_disabled_or_missing") return CommandResult(
ok=False, status="failed", error="provider_disabled_or_missing"
)
service, channel = self._effective_scope(trigger) service, channel = self._effective_scope(trigger)
external_chat_id = await sync_to_async(resolve_external_chat_id)( external_chat_id = await sync_to_async(resolve_external_chat_id)(
@@ -413,7 +495,9 @@ class CodexCommandHandler(CommandHandler):
if mode == "plan": if mode == "plan":
anchor = trigger.reply_to anchor = trigger.reply_to
if anchor is None: if anchor is None:
return CommandResult(ok=False, status="failed", error="reply_required_for_codex_plan") return CommandResult(
ok=False, status="failed", error="reply_required_for_codex_plan"
)
rows = await sync_to_async(list)( rows = await sync_to_async(list)(
Message.objects.filter( Message.objects.filter(
user=trigger.user, user=trigger.user,
@@ -422,7 +506,9 @@ class CodexCommandHandler(CommandHandler):
ts__lte=int(trigger.ts or 0), ts__lte=int(trigger.ts or 0),
) )
.order_by("ts") .order_by("ts")
.select_related("session", "session__identifier", "session__identifier__person") .select_related(
"session", "session__identifier", "session__identifier__person"
)
) )
payload["reply_context"] = { payload["reply_context"] = {
"anchor_message_id": str(anchor.id), "anchor_message_id": str(anchor.id),
@@ -441,12 +527,18 @@ class CodexCommandHandler(CommandHandler):
source_channel=channel, source_channel=channel,
external_chat_id=external_chat_id, external_chat_id=external_chat_id,
status="waiting_approval", status="waiting_approval",
request_payload={"action": "append_update", "provider_payload": dict(payload)}, request_payload={
"action": "append_update",
"provider_payload": dict(payload),
},
result_payload={}, result_payload={},
error="", error="",
) )
payload["codex_run_id"] = str(run.id) payload["codex_run_id"] = str(run.id)
run.request_payload = {"action": "append_update", "provider_payload": dict(payload)} run.request_payload = {
"action": "append_update",
"provider_payload": dict(payload),
}
await sync_to_async(run.save)(update_fields=["request_payload", "updated_at"]) await sync_to_async(run.save)(update_fields=["request_payload", "updated_at"])
idempotency_key = f"codex_cmd:{trigger.id}:{mode}:{task.id}:{hashlib.sha1(str(body_text or '').encode('utf-8')).hexdigest()[:12]}" idempotency_key = f"codex_cmd:{trigger.id}:{mode}:{task.id}:{hashlib.sha1(str(body_text or '').encode('utf-8')).hexdigest()[:12]}"
@@ -471,20 +563,26 @@ class CodexCommandHandler(CommandHandler):
return CommandResult(ok=False, status="failed", error="trigger_not_found") return CommandResult(ok=False, status="failed", error="trigger_not_found")
profile = await sync_to_async( profile = await sync_to_async(
lambda: CommandProfile.objects.filter(user=trigger.user, slug=self.slug, enabled=True).first() lambda: CommandProfile.objects.filter(
user=trigger.user, slug=self.slug, enabled=True
).first()
)() )()
if profile is None: if profile is None:
return CommandResult(ok=False, status="skipped", error="profile_missing") return CommandResult(ok=False, status="skipped", error="profile_missing")
parsed = parse_codex_command(ctx.message_text) parsed = parse_codex_command(ctx.message_text)
if not parsed.command: if not parsed.command:
return CommandResult(ok=False, status="skipped", error="codex_command_not_matched") return CommandResult(
ok=False, status="skipped", error="codex_command_not_matched"
)
service, channel = self._effective_scope(trigger) service, channel = self._effective_scope(trigger)
if parsed.command == "status": if parsed.command == "status":
project = None project = None
reply_task = await self._linked_task_from_reply(trigger.user, trigger.reply_to) reply_task = await self._linked_task_from_reply(
trigger.user, trigger.reply_to
)
if reply_task is not None: if reply_task is not None:
project = reply_task.project project = reply_task.project
return await self._run_status(trigger, service, channel, project) return await self._run_status(trigger, service, channel, project)
@@ -502,7 +600,9 @@ class CodexCommandHandler(CommandHandler):
reply_task = await self._linked_task_from_reply(trigger.user, trigger.reply_to) reply_task = await self._linked_task_from_reply(trigger.user, trigger.reply_to)
task = await self._resolve_task(trigger.user, reference_code, reply_task) task = await self._resolve_task(trigger.user, reference_code, reply_task)
if task is None: if task is None:
return CommandResult(ok=False, status="failed", error="task_target_required") return CommandResult(
ok=False, status="failed", error="task_target_required"
)
project, project_error = await self._resolve_project( project, project_error = await self._resolve_project(
user=trigger.user, user=trigger.user,
@@ -513,7 +613,9 @@ class CodexCommandHandler(CommandHandler):
project_token=project_token, project_token=project_token,
) )
if project is None: if project is None:
return CommandResult(ok=False, status="failed", error=project_error or "project_unresolved") return CommandResult(
ok=False, status="failed", error=project_error or "project_unresolved"
)
mode = "plan" if parsed.command == "plan" else "default" mode = "plan" if parsed.command == "plan" else "default"
return await self._create_submission( return await self._create_submission(

View File

@@ -32,7 +32,8 @@ def _legacy_defaults(profile: CommandProfile, post_result_enabled: bool) -> dict
"enabled": True, "enabled": True,
"generation_mode": "ai", "generation_mode": "ai",
"send_plan_to_egress": bool(post_result_enabled), "send_plan_to_egress": bool(post_result_enabled),
"send_status_to_source": str(profile.visibility_mode or "") == "status_in_source", "send_status_to_source": str(profile.visibility_mode or "")
== "status_in_source",
"send_status_to_egress": False, "send_status_to_egress": False,
"store_document": True, "store_document": True,
} }
@@ -56,7 +57,9 @@ def ensure_variant_policies_for_profile(
*, *,
action_rows: Iterable[CommandAction] | None = None, action_rows: Iterable[CommandAction] | None = None,
) -> dict[str, CommandVariantPolicy]: ) -> dict[str, CommandVariantPolicy]:
actions = list(action_rows) if action_rows is not None else list(profile.actions.all()) actions = (
list(action_rows) if action_rows is not None else list(profile.actions.all())
)
post_result_enabled = any( post_result_enabled = any(
row.action_type == "post_result" and bool(row.enabled) for row in actions row.action_type == "post_result" and bool(row.enabled) for row in actions
) )
@@ -91,7 +94,9 @@ def ensure_variant_policies_for_profile(
return result return result
def load_variant_policy(profile: CommandProfile, variant_key: str) -> CommandVariantPolicy | None: def load_variant_policy(
profile: CommandProfile, variant_key: str
) -> CommandVariantPolicy | None:
key = str(variant_key or "").strip() key = str(variant_key or "").strip()
if not key: if not key:
return None return None

View File

@@ -27,6 +27,7 @@ def settings_hierarchy_nav(request):
ai_models_href = reverse("ai_models") ai_models_href = reverse("ai_models")
ai_traces_href = reverse("ai_execution_log") ai_traces_href = reverse("ai_execution_log")
commands_href = reverse("command_routing") commands_href = reverse("command_routing")
business_plans_href = reverse("business_plan_inbox")
tasks_href = reverse("tasks_settings") tasks_href = reverse("tasks_settings")
translation_href = reverse("translation_settings") translation_href = reverse("translation_settings")
availability_href = reverse("availability_settings") availability_href = reverse("availability_settings")
@@ -55,6 +56,8 @@ def settings_hierarchy_nav(request):
modules_routes = { modules_routes = {
"modules_settings", "modules_settings",
"command_routing", "command_routing",
"business_plan_inbox",
"business_plan_editor",
"tasks_settings", "tasks_settings",
"translation_settings", "translation_settings",
"availability_settings", "availability_settings",
@@ -106,7 +109,12 @@ def settings_hierarchy_nav(request):
"title": "Modules", "title": "Modules",
"tabs": [ "tabs": [
_tab("Commands", commands_href, path == commands_href), _tab("Commands", commands_href, path == commands_href),
_tab("Tasks", tasks_href, path == tasks_href), _tab(
"Business Plans",
business_plans_href,
url_name in {"business_plan_inbox", "business_plan_editor"},
),
_tab("Task Automation", tasks_href, path == tasks_href),
_tab("Translation", translation_href, path == translation_href), _tab("Translation", translation_href, path == translation_href),
_tab("Availability", availability_href, path == availability_href), _tab("Availability", availability_href, path == availability_href),
], ],

View File

@@ -24,7 +24,6 @@ async def init_mysql_pool():
async def close_mysql_pool(): async def close_mysql_pool():
"""Close the MySQL connection pool properly.""" """Close the MySQL connection pool properly."""
global mysql_pool
if mysql_pool: if mysql_pool:
mysql_pool.close() mysql_pool.close()
await mysql_pool.wait_closed() await mysql_pool.wait_closed()

View File

@@ -15,8 +15,12 @@ def event_ledger_enabled() -> bool:
def event_ledger_status() -> dict: def event_ledger_status() -> dict:
return { return {
"event_ledger_dual_write": bool(getattr(settings, "EVENT_LEDGER_DUAL_WRITE", False)), "event_ledger_dual_write": bool(
"event_primary_write_path": bool(getattr(settings, "EVENT_PRIMARY_WRITE_PATH", False)), getattr(settings, "EVENT_LEDGER_DUAL_WRITE", False)
),
"event_primary_write_path": bool(
getattr(settings, "EVENT_PRIMARY_WRITE_PATH", False)
),
} }
@@ -61,9 +65,7 @@ def append_event_sync(
if not normalized_type: if not normalized_type:
raise ValueError("event_type is required") raise ValueError("event_type is required")
candidates = { candidates = {str(choice[0]) for choice in ConversationEvent.EVENT_TYPE_CHOICES}
str(choice[0]) for choice in ConversationEvent.EVENT_TYPE_CHOICES
}
if normalized_type not in candidates: if normalized_type not in candidates:
raise ValueError(f"unsupported event_type: {normalized_type}") raise ValueError(f"unsupported event_type: {normalized_type}")

View File

@@ -90,7 +90,9 @@ def project_session_from_events(session: ChatSession) -> list[dict]:
order.append(message_id) order.append(message_id)
state.ts = _safe_int(payload.get("message_ts"), _safe_int(event.ts)) state.ts = _safe_int(payload.get("message_ts"), _safe_int(event.ts))
state.text = str(payload.get("text") or state.text or "") state.text = str(payload.get("text") or state.text or "")
delivered_default = _safe_int(payload.get("delivered_ts"), _safe_int(event.ts)) delivered_default = _safe_int(
payload.get("delivered_ts"), _safe_int(event.ts)
)
if state.delivered_ts is None: if state.delivered_ts is None:
state.delivered_ts = delivered_default or None state.delivered_ts = delivered_default or None
continue continue
@@ -111,7 +113,11 @@ def project_session_from_events(session: ChatSession) -> list[dict]:
continue continue
if event_type in {"reaction_added", "reaction_removed"}: if event_type in {"reaction_added", "reaction_removed"}:
source_service = str(payload.get("source_service") or event.origin_transport or "").strip().lower() source_service = (
str(payload.get("source_service") or event.origin_transport or "")
.strip()
.lower()
)
actor = str(payload.get("actor") or event.actor_identifier or "").strip() actor = str(payload.get("actor") or event.actor_identifier or "").strip()
emoji = str(payload.get("emoji") or "").strip() emoji = str(payload.get("emoji") or "").strip()
if not source_service and not actor and not emoji: if not source_service and not actor and not emoji:
@@ -121,7 +127,9 @@ def project_session_from_events(session: ChatSession) -> list[dict]:
"source_service": source_service, "source_service": source_service,
"actor": actor, "actor": actor,
"emoji": emoji, "emoji": emoji,
"removed": bool(event_type == "reaction_removed" or payload.get("remove")), "removed": bool(
event_type == "reaction_removed" or payload.get("remove")
),
} }
output = [] output = []
@@ -135,12 +143,12 @@ def project_session_from_events(session: ChatSession) -> list[dict]:
"ts": int(state.ts or 0), "ts": int(state.ts or 0),
"text": str(state.text or ""), "text": str(state.text or ""),
"delivered_ts": ( "delivered_ts": (
int(state.delivered_ts) int(state.delivered_ts) if state.delivered_ts is not None else None
if state.delivered_ts is not None
else None
), ),
"read_ts": int(state.read_ts) if state.read_ts is not None else None, "read_ts": int(state.read_ts) if state.read_ts is not None else None,
"reactions": _normalize_reactions(list((state.reactions or {}).values())), "reactions": _normalize_reactions(
list((state.reactions or {}).values())
),
} }
) )
return output return output
@@ -182,7 +190,9 @@ def shadow_compare_session(session: ChatSession, detail_limit: int = 50) -> dict
cause_samples = {key: [] for key in cause_counts.keys()} cause_samples = {key: [] for key in cause_counts.keys()}
cause_sample_limit = min(5, max(0, int(detail_limit))) cause_sample_limit = min(5, max(0, int(detail_limit)))
def _record_detail(message_id: str, issue: str, cause: str, extra: dict | None = None): def _record_detail(
message_id: str, issue: str, cause: str, extra: dict | None = None
):
if cause in cause_counts: if cause in cause_counts:
cause_counts[cause] += 1 cause_counts[cause] += 1
row = {"message_id": message_id, "issue": issue, "cause": cause} row = {"message_id": message_id, "issue": issue, "cause": cause}
@@ -224,13 +234,10 @@ def shadow_compare_session(session: ChatSession, detail_limit: int = 50) -> dict
db_delivered_ts = db_row.get("delivered_ts") db_delivered_ts = db_row.get("delivered_ts")
projected_delivered_ts = projected.get("delivered_ts") projected_delivered_ts = projected.get("delivered_ts")
if ( if (db_delivered_ts is None) != (projected_delivered_ts is None) or (
(db_delivered_ts is None) != (projected_delivered_ts is None)
or (
db_delivered_ts is not None db_delivered_ts is not None
and projected_delivered_ts is not None and projected_delivered_ts is not None
and int(db_delivered_ts) != int(projected_delivered_ts) and int(db_delivered_ts) != int(projected_delivered_ts)
)
): ):
counters["delivered_ts_mismatch"] += 1 counters["delivered_ts_mismatch"] += 1
_record_detail( _record_detail(
@@ -245,13 +252,10 @@ def shadow_compare_session(session: ChatSession, detail_limit: int = 50) -> dict
db_read_ts = db_row.get("read_ts") db_read_ts = db_row.get("read_ts")
projected_read_ts = projected.get("read_ts") projected_read_ts = projected.get("read_ts")
if ( if (db_read_ts is None) != (projected_read_ts is None) or (
(db_read_ts is None) != (projected_read_ts is None)
or (
db_read_ts is not None db_read_ts is not None
and projected_read_ts is not None and projected_read_ts is not None
and int(db_read_ts) != int(projected_read_ts) and int(db_read_ts) != int(projected_read_ts)
)
): ):
counters["read_ts_mismatch"] += 1 counters["read_ts_mismatch"] += 1
_record_detail( _record_detail(
@@ -264,12 +268,19 @@ def shadow_compare_session(session: ChatSession, detail_limit: int = 50) -> dict
db_reactions = _normalize_reactions( db_reactions = _normalize_reactions(
list((db_row.get("receipt_payload") or {}).get("reactions") or []) list((db_row.get("receipt_payload") or {}).get("reactions") or [])
) )
projected_reactions = _normalize_reactions(list(projected.get("reactions") or [])) projected_reactions = _normalize_reactions(
list(projected.get("reactions") or [])
)
if db_reactions != projected_reactions: if db_reactions != projected_reactions:
counters["reactions_mismatch"] += 1 counters["reactions_mismatch"] += 1
cause = "payload_normalization_gap" cause = "payload_normalization_gap"
strategy = str( strategy = str(
((db_row.get("receipt_payload") or {}).get("reaction_last_match_strategy") or "") (
(db_row.get("receipt_payload") or {}).get(
"reaction_last_match_strategy"
)
or ""
)
).strip() ).strip()
if strategy == "nearest_ts_window": if strategy == "nearest_ts_window":
cause = "ambiguous_reaction_target" cause = "ambiguous_reaction_target"

View File

@@ -1,6 +1,7 @@
from django import forms from django import forms
from django.contrib.auth.forms import UserCreationForm from django.contrib.auth.forms import UserCreationForm
from django.forms import ModelForm from django.forms import ModelForm
from mixins.restrictions import RestrictedFormMixin from mixins.restrictions import RestrictedFormMixin
from .models import ( from .models import (

View File

@@ -8,7 +8,6 @@ from asgiref.sync import sync_to_async
from core.models import GatewayCommandEvent from core.models import GatewayCommandEvent
from core.security.command_policy import CommandSecurityContext, evaluate_command_policy from core.security.command_policy import CommandSecurityContext, evaluate_command_policy
GatewayEmit = Callable[[str], None] GatewayEmit = Callable[[str], None]
GatewayHandler = Callable[["GatewayCommandContext", GatewayEmit], Awaitable[bool]] GatewayHandler = Callable[["GatewayCommandContext", GatewayEmit], Awaitable[bool]]
GatewayMatcher = Callable[[str], bool] GatewayMatcher = Callable[[str], bool]
@@ -103,7 +102,10 @@ async def dispatch_gateway_command(
emit(message) emit(message)
event.status = "blocked" event.status = "blocked"
event.error = f"{decision.code}:{decision.reason}" event.error = f"{decision.code}:{decision.reason}"
event.response_meta = {"policy_code": decision.code, "policy_reason": decision.reason} event.response_meta = {
"policy_code": decision.code,
"policy_reason": decision.reason,
}
await sync_to_async(event.save)( await sync_to_async(event.save)(
update_fields=["status", "error", "response_meta", "updated_at"] update_fields=["status", "error", "response_meta", "updated_at"]
) )
@@ -129,5 +131,7 @@ async def dispatch_gateway_command(
event.status = "ok" if handled else "ignored" event.status = "ok" if handled else "ignored"
event.response_meta = {"responses": responses} event.response_meta = {"responses": responses}
await sync_to_async(event.save)(update_fields=["status", "response_meta", "updated_at"]) await sync_to_async(event.save)(
update_fields=["status", "response_meta", "updated_at"]
)
return bool(handled) return bool(handled)

View File

@@ -4,7 +4,7 @@ from typing import Iterable
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from core.models import Message, User from core.models import Message
from core.presence import AvailabilitySignal, record_inferred_signal from core.presence import AvailabilitySignal, record_inferred_signal
from core.presence.inference import now_ms from core.presence.inference import now_ms
@@ -19,7 +19,9 @@ class Command(BaseCommand):
parser.add_argument("--user-id", default="") parser.add_argument("--user-id", default="")
parser.add_argument("--dry-run", action="store_true", default=False) parser.add_argument("--dry-run", action="store_true", default=False)
def _iter_messages(self, *, days: int, limit: int, service: str, user_id: str) -> Iterable[Message]: def _iter_messages(
self, *, days: int, limit: int, service: str, user_id: str
) -> Iterable[Message]:
cutoff_ts = now_ms() - (max(1, int(days)) * 24 * 60 * 60 * 1000) cutoff_ts = now_ms() - (max(1, int(days)) * 24 * 60 * 60 * 1000)
qs = Message.objects.filter(ts__gte=cutoff_ts).select_related( qs = Message.objects.filter(ts__gte=cutoff_ts).select_related(
"user", "session", "session__identifier", "session__identifier__person" "user", "session", "session__identifier", "session__identifier__person"
@@ -40,7 +42,9 @@ class Command(BaseCommand):
created = 0 created = 0
scanned = 0 scanned = 0
for msg in self._iter_messages(days=days, limit=limit, service=service_filter, user_id=user_filter): for msg in self._iter_messages(
days=days, limit=limit, service=service_filter, user_id=user_filter
):
scanned += 1 scanned += 1
identifier = getattr(getattr(msg, "session", None), "identifier", None) identifier = getattr(getattr(msg, "session", None), "identifier", None)
person = getattr(identifier, "person", None) person = getattr(identifier, "person", None)
@@ -48,12 +52,18 @@ class Command(BaseCommand):
if not identifier or not person or not user: if not identifier or not person or not user:
continue continue
service = str(getattr(msg, "source_service", "") or identifier.service or "").strip().lower() service = (
str(getattr(msg, "source_service", "") or identifier.service or "")
.strip()
.lower()
)
if not service: if not service:
continue continue
base_ts = int(getattr(msg, "ts", 0) or 0) base_ts = int(getattr(msg, "ts", 0) or 0)
message_author = str(getattr(msg, "custom_author", "") or "").strip().upper() message_author = (
str(getattr(msg, "custom_author", "") or "").strip().upper()
)
outgoing = message_author in {"USER", "BOT"} outgoing = message_author in {"USER", "BOT"}
candidates = [] candidates = []
@@ -84,7 +94,9 @@ class Command(BaseCommand):
"origin": "backfill_contact_availability", "origin": "backfill_contact_availability",
"message_id": str(msg.id), "message_id": str(msg.id),
"inferred_from": "read_receipt", "inferred_from": "read_receipt",
"read_by": str(getattr(msg, "read_by_identifier", "") or ""), "read_by": str(
getattr(msg, "read_by_identifier", "") or ""
),
}, },
} }
) )

View File

@@ -7,7 +7,12 @@ from asgiref.sync import async_to_sync
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from core.clients.transport import send_message_raw from core.clients.transport import send_message_raw
from core.models import CodexPermissionRequest, CodexRun, ExternalSyncEvent, TaskProviderConfig from core.models import (
CodexPermissionRequest,
CodexRun,
ExternalSyncEvent,
TaskProviderConfig,
)
from core.tasks.providers import get_provider from core.tasks.providers import get_provider
from core.util import logs from core.util import logs
@@ -15,7 +20,9 @@ log = logs.get_logger("codex_worker")
class Command(BaseCommand): class Command(BaseCommand):
help = "Process queued external sync events for worker-backed providers (codex_cli)." help = (
"Process queued external sync events for worker-backed providers (codex_cli)."
)
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument("--once", action="store_true", default=False) parser.add_argument("--once", action="store_true", default=False)
@@ -73,7 +80,9 @@ class Command(BaseCommand):
payload = dict(event.payload or {}) payload = dict(event.payload or {})
action = str(payload.get("action") or "append_update").strip().lower() action = str(payload.get("action") or "append_update").strip().lower()
provider_payload = dict(payload.get("provider_payload") or payload) provider_payload = dict(payload.get("provider_payload") or payload)
run_id = str(provider_payload.get("codex_run_id") or payload.get("codex_run_id") or "").strip() run_id = str(
provider_payload.get("codex_run_id") or payload.get("codex_run_id") or ""
).strip()
codex_run = None codex_run = None
if run_id: if run_id:
codex_run = CodexRun.objects.filter(id=run_id, user=event.user).first() codex_run = CodexRun.objects.filter(id=run_id, user=event.user).first()
@@ -104,9 +113,13 @@ class Command(BaseCommand):
result_payload = dict(result.payload or {}) result_payload = dict(result.payload or {})
requires_approval = bool(result_payload.get("requires_approval")) requires_approval = bool(result_payload.get("requires_approval"))
if requires_approval: if requires_approval:
approval_key = str(result_payload.get("approval_key") or uuid.uuid4().hex[:12]).strip() approval_key = str(
result_payload.get("approval_key") or uuid.uuid4().hex[:12]
).strip()
permission_request = dict(result_payload.get("permission_request") or {}) permission_request = dict(result_payload.get("permission_request") or {})
summary = str(result_payload.get("summary") or permission_request.get("summary") or "").strip() summary = str(
result_payload.get("summary") or permission_request.get("summary") or ""
).strip()
requested_permissions = permission_request.get("requested_permissions") requested_permissions = permission_request.get("requested_permissions")
if not isinstance(requested_permissions, (list, dict)): if not isinstance(requested_permissions, (list, dict)):
requested_permissions = permission_request or {} requested_permissions = permission_request or {}
@@ -121,28 +134,42 @@ class Command(BaseCommand):
codex_run.status = "waiting_approval" codex_run.status = "waiting_approval"
codex_run.result_payload = dict(result_payload) codex_run.result_payload = dict(result_payload)
codex_run.error = "" codex_run.error = ""
codex_run.save(update_fields=["status", "result_payload", "error", "updated_at"]) codex_run.save(
update_fields=["status", "result_payload", "error", "updated_at"]
)
CodexPermissionRequest.objects.update_or_create( CodexPermissionRequest.objects.update_or_create(
approval_key=approval_key, approval_key=approval_key,
defaults={ defaults={
"user": event.user, "user": event.user,
"codex_run": codex_run if codex_run is not None else CodexRun.objects.create( "codex_run": (
codex_run
if codex_run is not None
else CodexRun.objects.create(
user=event.user, user=event.user,
task=event.task, task=event.task,
derived_task_event=event.task_event, derived_task_event=event.task_event,
source_service=str(provider_payload.get("source_service") or ""), source_service=str(
source_channel=str(provider_payload.get("source_channel") or ""), provider_payload.get("source_service") or ""
external_chat_id=str(provider_payload.get("external_chat_id") or ""), ),
source_channel=str(
provider_payload.get("source_channel") or ""
),
external_chat_id=str(
provider_payload.get("external_chat_id") or ""
),
status="waiting_approval", status="waiting_approval",
request_payload=dict(payload or {}), request_payload=dict(payload or {}),
result_payload=dict(result_payload), result_payload=dict(result_payload),
error="", error="",
)
), ),
"external_sync_event": event, "external_sync_event": event,
"summary": summary, "summary": summary,
"requested_permissions": requested_permissions if isinstance(requested_permissions, dict) else { "requested_permissions": (
"items": list(requested_permissions or []) requested_permissions
}, if isinstance(requested_permissions, dict)
else {"items": list(requested_permissions or [])}
),
"resume_payload": dict(resume_payload or {}), "resume_payload": dict(resume_payload or {}),
"status": "pending", "status": "pending",
"resolved_at": None, "resolved_at": None,
@@ -150,9 +177,17 @@ class Command(BaseCommand):
"resolution_note": "", "resolution_note": "",
}, },
) )
approver_service = str((cfg.settings or {}).get("approver_service") or "").strip().lower() approver_service = (
approver_identifier = str((cfg.settings or {}).get("approver_identifier") or "").strip() str((cfg.settings or {}).get("approver_service") or "").strip().lower()
requested_text = result_payload.get("permission_request") or result_payload.get("requested_permissions") or {} )
approver_identifier = str(
(cfg.settings or {}).get("approver_identifier") or ""
).strip()
requested_text = (
result_payload.get("permission_request")
or result_payload.get("requested_permissions")
or {}
)
if approver_service and approver_identifier: if approver_service and approver_identifier:
try: try:
async_to_sync(send_message_raw)( async_to_sync(send_message_raw)(
@@ -168,10 +203,17 @@ class Command(BaseCommand):
metadata={"origin_tag": f"codex-approval:{approval_key}"}, metadata={"origin_tag": f"codex-approval:{approval_key}"},
) )
except Exception: except Exception:
log.exception("failed to notify approver channel for approval_key=%s", approval_key) log.exception(
"failed to notify approver channel for approval_key=%s",
approval_key,
)
else: else:
source_service = str(provider_payload.get("source_service") or "").strip().lower() source_service = (
source_channel = str(provider_payload.get("source_channel") or "").strip() str(provider_payload.get("source_service") or "").strip().lower()
)
source_channel = str(
provider_payload.get("source_channel") or ""
).strip()
if source_service and source_channel: if source_service and source_channel:
try: try:
async_to_sync(send_message_raw)( async_to_sync(send_message_raw)(
@@ -185,7 +227,9 @@ class Command(BaseCommand):
metadata={"origin_tag": "codex-approval-missing-target"}, metadata={"origin_tag": "codex-approval-missing-target"},
) )
except Exception: except Exception:
log.exception("failed to notify source channel for missing approver target") log.exception(
"failed to notify source channel for missing approver target"
)
return return
event.status = "ok" if result.ok else "failed" event.status = "ok" if result.ok else "failed"
@@ -201,18 +245,24 @@ class Command(BaseCommand):
approval_key = str(provider_payload.get("approval_key") or "").strip() approval_key = str(provider_payload.get("approval_key") or "").strip()
if mode == "approval_response" and approval_key: if mode == "approval_response" and approval_key:
req = ( req = (
CodexPermissionRequest.objects.select_related("external_sync_event", "codex_run") CodexPermissionRequest.objects.select_related(
"external_sync_event", "codex_run"
)
.filter(user=event.user, approval_key=approval_key) .filter(user=event.user, approval_key=approval_key)
.first() .first()
) )
if req and req.external_sync_event_id: if req and req.external_sync_event_id:
if result.ok: if result.ok:
ExternalSyncEvent.objects.filter(id=req.external_sync_event_id).update( ExternalSyncEvent.objects.filter(
id=req.external_sync_event_id
).update(
status="ok", status="ok",
error="", error="",
) )
elif str(event.error or "").strip() == "approval_denied": elif str(event.error or "").strip() == "approval_denied":
ExternalSyncEvent.objects.filter(id=req.external_sync_event_id).update( ExternalSyncEvent.objects.filter(
id=req.external_sync_event_id
).update(
status="failed", status="failed",
error="approval_denied", error="approval_denied",
) )
@@ -220,9 +270,16 @@ class Command(BaseCommand):
codex_run.status = "ok" if result.ok else "failed" codex_run.status = "ok" if result.ok else "failed"
codex_run.error = str(result.error or "") codex_run.error = str(result.error or "")
codex_run.result_payload = result_payload codex_run.result_payload = result_payload
codex_run.save(update_fields=["status", "error", "result_payload", "updated_at"]) codex_run.save(
update_fields=["status", "error", "result_payload", "updated_at"]
)
if result.ok and result.external_key and event.task_id and not str(event.task.external_key or "").strip(): if (
result.ok
and result.external_key
and event.task_id
and not str(event.task.external_key or "").strip()
):
event.task.external_key = str(result.external_key) event.task.external_key = str(result.external_key)
event.task.save(update_fields=["external_key"]) event.task.save(update_fields=["external_key"])
@@ -250,7 +307,11 @@ class Command(BaseCommand):
continue continue
for row_id in claimed_ids: for row_id in claimed_ids:
event = ExternalSyncEvent.objects.filter(id=row_id).select_related("task", "user").first() event = (
ExternalSyncEvent.objects.filter(id=row_id)
.select_related("task", "user")
.first()
)
if event is None: if event is None:
continue continue
try: try:

View File

@@ -85,7 +85,9 @@ class Command(BaseCommand):
compared = shadow_compare_session(session, detail_limit=detail_limit) compared = shadow_compare_session(session, detail_limit=detail_limit)
aggregate["sessions_scanned"] += 1 aggregate["sessions_scanned"] += 1
aggregate["db_message_count"] += int(compared.get("db_message_count") or 0) aggregate["db_message_count"] += int(compared.get("db_message_count") or 0)
aggregate["projected_message_count"] += int(compared.get("projected_message_count") or 0) aggregate["projected_message_count"] += int(
compared.get("projected_message_count") or 0
)
aggregate["mismatch_total"] += int(compared.get("mismatch_total") or 0) aggregate["mismatch_total"] += int(compared.get("mismatch_total") or 0)
for key in aggregate["counters"].keys(): for key in aggregate["counters"].keys():
aggregate["counters"][key] += int( aggregate["counters"][key] += int(

View File

@@ -1,14 +1,11 @@
from __future__ import annotations from __future__ import annotations
from collections import defaultdict
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from core.models import ContactAvailabilityEvent, ContactAvailabilitySpan, Message from core.models import ContactAvailabilityEvent, ContactAvailabilitySpan, Message
from core.presence import AvailabilitySignal, record_native_signal from core.presence import AvailabilitySignal, record_native_signal
from core.presence.inference import now_ms from core.presence.inference import now_ms
_SOURCE_ORDER = { _SOURCE_ORDER = {
"message_in": 10, "message_in": 10,
"message_out": 20, "message_out": 20,
@@ -51,9 +48,14 @@ class Command(BaseCommand):
if not identifier or not person or not user: if not identifier or not person or not user:
continue continue
service = str( service = (
getattr(msg, "source_service", "") or getattr(identifier, "service", "") str(
).strip().lower() getattr(msg, "source_service", "")
or getattr(identifier, "service", "")
)
.strip()
.lower()
)
if not service: if not service:
continue continue
@@ -95,12 +97,16 @@ class Command(BaseCommand):
"origin": "recalculate_contact_availability", "origin": "recalculate_contact_availability",
"message_id": str(msg.id), "message_id": str(msg.id),
"inferred_from": "read_receipt", "inferred_from": "read_receipt",
"read_by": str(getattr(msg, "read_by_identifier", "") or ""), "read_by": str(
getattr(msg, "read_by_identifier", "") or ""
),
}, },
} }
) )
reactions = list((getattr(msg, "receipt_payload", {}) or {}).get("reactions") or []) reactions = list(
(getattr(msg, "receipt_payload", {}) or {}).get("reactions") or []
)
for reaction in reactions: for reaction in reactions:
item = dict(reaction or {}) item = dict(reaction or {})
if bool(item.get("removed")): if bool(item.get("removed")):
@@ -124,7 +130,9 @@ class Command(BaseCommand):
"inferred_from": "reaction", "inferred_from": "reaction",
"emoji": str(item.get("emoji") or ""), "emoji": str(item.get("emoji") or ""),
"actor": str(item.get("actor") or ""), "actor": str(item.get("actor") or ""),
"source_service": str(item.get("source_service") or service), "source_service": str(
item.get("source_service") or service
),
}, },
} }
) )

View File

@@ -67,7 +67,9 @@ def _compute_payload(rows, identifier_values):
pending_out_ts = None pending_out_ts = None
first_ts = int(rows[0]["ts"] or 0) first_ts = int(rows[0]["ts"] or 0)
last_ts = int(rows[-1]["ts"] or 0) last_ts = int(rows[-1]["ts"] or 0)
latest_service = str(rows[-1].get("session__identifier__service") or "").strip().lower() latest_service = (
str(rows[-1].get("session__identifier__service") or "").strip().lower()
)
for row in rows: for row in rows:
ts = int(row.get("ts") or 0) ts = int(row.get("ts") or 0)
@@ -162,18 +164,18 @@ def _compute_payload(rows, identifier_values):
payload = { payload = {
"source_event_ts": last_ts, "source_event_ts": last_ts,
"stability_state": stability_state, "stability_state": stability_state,
"stability_score": float(stability_score_value) "stability_score": (
if stability_score_value is not None float(stability_score_value) if stability_score_value is not None else None
else None, ),
"stability_confidence": round(confidence, 3), "stability_confidence": round(confidence, 3),
"stability_sample_messages": message_count, "stability_sample_messages": message_count,
"stability_sample_days": sample_days, "stability_sample_days": sample_days,
"commitment_inbound_score": float(commitment_in_value) "commitment_inbound_score": (
if commitment_in_value is not None float(commitment_in_value) if commitment_in_value is not None else None
else None, ),
"commitment_outbound_score": float(commitment_out_value) "commitment_outbound_score": (
if commitment_out_value is not None float(commitment_out_value) if commitment_out_value is not None else None
else None, ),
"commitment_confidence": round(confidence, 3), "commitment_confidence": round(confidence, 3),
"inbound_messages": inbound_count, "inbound_messages": inbound_count,
"outbound_messages": outbound_count, "outbound_messages": outbound_count,
@@ -232,15 +234,17 @@ class Command(BaseCommand):
dry_run = bool(options.get("dry_run")) dry_run = bool(options.get("dry_run"))
reset = not bool(options.get("no_reset")) reset = not bool(options.get("no_reset"))
compact_enabled = not bool(options.get("skip_compact")) compact_enabled = not bool(options.get("skip_compact"))
today_start = dj_timezone.now().astimezone(timezone.utc).replace( today_start = (
dj_timezone.now()
.astimezone(timezone.utc)
.replace(
hour=0, hour=0,
minute=0, minute=0,
second=0, second=0,
microsecond=0, microsecond=0,
) )
cutoff_ts = int(
(today_start.timestamp() * 1000) - (days * 24 * 60 * 60 * 1000)
) )
cutoff_ts = int((today_start.timestamp() * 1000) - (days * 24 * 60 * 60 * 1000))
people_qs = Person.objects.all() people_qs = Person.objects.all()
if user_id: if user_id:
@@ -256,14 +260,18 @@ class Command(BaseCommand):
compacted_deleted = 0 compacted_deleted = 0
for person in people: for person in people:
identifiers_qs = PersonIdentifier.objects.filter(user=person.user, person=person) identifiers_qs = PersonIdentifier.objects.filter(
user=person.user, person=person
)
if service: if service:
identifiers_qs = identifiers_qs.filter(service=service) identifiers_qs = identifiers_qs.filter(service=service)
identifiers = list(identifiers_qs) identifiers = list(identifiers_qs)
if not identifiers: if not identifiers:
continue continue
identifier_values = { identifier_values = {
str(row.identifier or "").strip() for row in identifiers if row.identifier str(row.identifier or "").strip()
for row in identifiers
if row.identifier
} }
if not identifier_values: if not identifier_values:
continue continue
@@ -350,7 +358,9 @@ class Command(BaseCommand):
snapshots_created += 1 snapshots_created += 1
if dry_run: if dry_run:
continue continue
WorkspaceMetricSnapshot.objects.create(conversation=conversation, **payload) WorkspaceMetricSnapshot.objects.create(
conversation=conversation, **payload
)
existing_signatures.add(signature) existing_signatures.add(signature)
if not latest_payload: if not latest_payload:
@@ -368,7 +378,9 @@ class Command(BaseCommand):
"updated_at": dj_timezone.now().isoformat(), "updated_at": dj_timezone.now().isoformat(),
} }
if not dry_run: if not dry_run:
conversation.platform_type = latest_service or conversation.platform_type conversation.platform_type = (
latest_service or conversation.platform_type
)
conversation.last_event_ts = latest_payload.get("source_event_ts") conversation.last_event_ts = latest_payload.get("source_event_ts")
conversation.stability_state = str( conversation.stability_state = str(
latest_payload.get("stability_state") latest_payload.get("stability_state")
@@ -416,7 +428,9 @@ class Command(BaseCommand):
) )
if compact_enabled: if compact_enabled:
snapshot_rows = list( snapshot_rows = list(
WorkspaceMetricSnapshot.objects.filter(conversation=conversation) WorkspaceMetricSnapshot.objects.filter(
conversation=conversation
)
.order_by("computed_at", "id") .order_by("computed_at", "id")
.values("id", "computed_at", "source_event_ts") .values("id", "computed_at", "source_event_ts")
) )
@@ -428,7 +442,9 @@ class Command(BaseCommand):
) )
if keep_ids: if keep_ids:
compacted_deleted += ( compacted_deleted += (
WorkspaceMetricSnapshot.objects.filter(conversation=conversation) WorkspaceMetricSnapshot.objects.filter(
conversation=conversation
)
.exclude(id__in=list(keep_ids)) .exclude(id__in=list(keep_ids))
.delete()[0] .delete()[0]
) )

View File

@@ -4,4 +4,6 @@ from core.management.commands.codex_worker import Command as LegacyCodexWorkerCo
class Command(LegacyCodexWorkerCommand): class Command(LegacyCodexWorkerCommand):
help = "Process queued task-sync events for worker-backed providers (Codex + Claude)." help = (
"Process queued task-sync events for worker-backed providers (Codex + Claude)."
)

View File

@@ -123,7 +123,9 @@ def _handle_message(message: dict[str, Any]) -> dict[str, Any] | None:
msg_id, msg_id,
{ {
"isError": True, "isError": True,
"content": [{"type": "text", "text": json.dumps({"error": str(exc)})}], "content": [
{"type": "text", "text": json.dumps({"error": str(exc)})}
],
}, },
) )

View File

@@ -216,7 +216,9 @@ def _next_unique_slug(*, user_id: int, requested_slug: str) -> str:
raise ValueError("slug cannot be empty") raise ValueError("slug cannot be empty")
candidate = base candidate = base
idx = 2 idx = 2
while KnowledgeArticle.objects.filter(user_id=int(user_id), slug=candidate).exists(): while KnowledgeArticle.objects.filter(
user_id=int(user_id), slug=candidate
).exists():
suffix = f"-{idx}" suffix = f"-{idx}"
candidate = f"{base[: max(1, 255 - len(suffix))]}{suffix}" candidate = f"{base[: max(1, 255 - len(suffix))]}{suffix}"
idx += 1 idx += 1
@@ -645,9 +647,7 @@ def tool_wiki_update_article(arguments: dict[str, Any]) -> dict[str, Any]:
) )
if status_marker and status == "archived" and article.status != "archived": if status_marker and status == "archived" and article.status != "archived":
if not approve_archive: if not approve_archive:
raise ValueError( raise ValueError("approve_archive=true is required to archive an article")
"approve_archive=true is required to archive an article"
)
if title: if title:
article.title = title article.title = title
@@ -705,7 +705,9 @@ def tool_wiki_list(arguments: dict[str, Any]) -> dict[str, Any]:
def tool_wiki_get(arguments: dict[str, Any]) -> dict[str, Any]: def tool_wiki_get(arguments: dict[str, Any]) -> dict[str, Any]:
article = _get_article_for_user(arguments) article = _get_article_for_user(arguments)
include_revisions = bool(arguments.get("include_revisions")) include_revisions = bool(arguments.get("include_revisions"))
revision_limit = _safe_limit(arguments.get("revision_limit"), default=20, low=1, high=200) revision_limit = _safe_limit(
arguments.get("revision_limit"), default=20, low=1, high=200
)
payload = {"article": _article_payload(article)} payload = {"article": _article_payload(article)}
if include_revisions: if include_revisions:
revisions = article.revisions.order_by("-revision")[:revision_limit] revisions = article.revisions.order_by("-revision")[:revision_limit]
@@ -714,7 +716,9 @@ def tool_wiki_get(arguments: dict[str, Any]) -> dict[str, Any]:
def tool_project_get_guidelines(arguments: dict[str, Any]) -> dict[str, Any]: def tool_project_get_guidelines(arguments: dict[str, Any]) -> dict[str, Any]:
max_chars = _safe_limit(arguments.get("max_chars"), default=16000, low=500, high=50000) max_chars = _safe_limit(
arguments.get("max_chars"), default=16000, low=500, high=50000
)
base = Path(settings.BASE_DIR).resolve() base = Path(settings.BASE_DIR).resolve()
file_names = ["AGENTS.md", "LLM_CODING_STANDARDS.md", "INSTALL.md", "README.md"] file_names = ["AGENTS.md", "LLM_CODING_STANDARDS.md", "INSTALL.md", "README.md"]
payload = [] payload = []
@@ -734,7 +738,9 @@ def tool_project_get_guidelines(arguments: dict[str, Any]) -> dict[str, Any]:
def tool_project_get_layout(arguments: dict[str, Any]) -> dict[str, Any]: def tool_project_get_layout(arguments: dict[str, Any]) -> dict[str, Any]:
max_entries = _safe_limit(arguments.get("max_entries"), default=300, low=50, high=4000) max_entries = _safe_limit(
arguments.get("max_entries"), default=300, low=50, high=4000
)
base = Path(settings.BASE_DIR).resolve() base = Path(settings.BASE_DIR).resolve()
roots = ["app", "core", "scripts", "utilities", "artifacts"] roots = ["app", "core", "scripts", "utilities", "artifacts"]
items: list[str] = [] items: list[str] = []
@@ -754,7 +760,9 @@ def tool_project_get_layout(arguments: dict[str, Any]) -> dict[str, Any]:
def tool_project_get_runbook(arguments: dict[str, Any]) -> dict[str, Any]: def tool_project_get_runbook(arguments: dict[str, Any]) -> dict[str, Any]:
max_chars = _safe_limit(arguments.get("max_chars"), default=16000, low=500, high=50000) max_chars = _safe_limit(
arguments.get("max_chars"), default=16000, low=500, high=50000
)
base = Path(settings.BASE_DIR).resolve() base = Path(settings.BASE_DIR).resolve()
file_names = [ file_names = [
"INSTALL.md", "INSTALL.md",
@@ -792,7 +800,11 @@ def tool_docs_append_run_note(arguments: dict[str, Any]) -> dict[str, Any]:
path = Path("/tmp/gia-mcp-run-notes.md") path = Path("/tmp/gia-mcp-run-notes.md")
else: else:
candidate = Path(raw_path) candidate = Path(raw_path)
path = candidate.resolve() if candidate.is_absolute() else (base / candidate).resolve() path = (
candidate.resolve()
if candidate.is_absolute()
else (base / candidate).resolve()
)
allowed_roots = [base, Path("/tmp").resolve()] allowed_roots = [base, Path("/tmp").resolve()]
if not any(str(path).startswith(str(root)) for root in allowed_roots): if not any(str(path).startswith(str(root)) for root in allowed_roots):
raise ValueError("path must be within project root or /tmp") raise ValueError("path must be within project root or /tmp")
@@ -812,7 +824,11 @@ def tool_docs_append_run_note(arguments: dict[str, Any]) -> dict[str, Any]:
TOOL_DEFS: dict[str, dict[str, Any]] = { TOOL_DEFS: dict[str, dict[str, Any]] = {
"manticore.status": { "manticore.status": {
"description": "Report configured memory backend status (django or manticore).", "description": "Report configured memory backend status (django or manticore).",
"inputSchema": {"type": "object", "properties": {}, "additionalProperties": False}, "inputSchema": {
"type": "object",
"properties": {},
"additionalProperties": False,
},
"handler": tool_manticore_status, "handler": tool_manticore_status,
}, },
"manticore.query": { "manticore.query": {

View File

@@ -1,4 +1,4 @@
from .search_backend import get_memory_search_backend
from .retrieval import retrieve_memories_for_prompt from .retrieval import retrieve_memories_for_prompt
from .search_backend import get_memory_search_backend
__all__ = ["get_memory_search_backend", "retrieve_memories_for_prompt"] __all__ = ["get_memory_search_backend", "retrieve_memories_for_prompt"]

View File

@@ -224,7 +224,9 @@ def create_memory_change_request(
person_id=person_id or (str(memory.person_id or "") if memory else "") or None, person_id=person_id or (str(memory.person_id or "") if memory else "") or None,
action=normalized_action, action=normalized_action,
status="pending", status="pending",
proposed_memory_kind=str(memory_kind or (memory.memory_kind if memory else "")).strip(), proposed_memory_kind=str(
memory_kind or (memory.memory_kind if memory else "")
).strip(),
proposed_content=dict(content or {}), proposed_content=dict(content or {}),
proposed_confidence_score=( proposed_confidence_score=(
float(confidence_score) float(confidence_score)
@@ -335,7 +337,9 @@ def review_memory_change_request(
@transaction.atomic @transaction.atomic
def run_memory_hygiene(*, user_id: int | None = None, dry_run: bool = False) -> dict[str, int]: def run_memory_hygiene(
*, user_id: int | None = None, dry_run: bool = False
) -> dict[str, int]:
now = timezone.now() now = timezone.now()
queryset = MemoryItem.objects.filter(status="active") queryset = MemoryItem.objects.filter(status="active")
if user_id is not None: if user_id is not None:
@@ -357,7 +361,9 @@ def run_memory_hygiene(*, user_id: int | None = None, dry_run: bool = False) ->
for item in queryset.select_related("conversation", "person"): for item in queryset.select_related("conversation", "person"):
content = item.content or {} content = item.content or {}
field = str(content.get("field") or content.get("key") or "").strip().lower() field = str(content.get("field") or content.get("key") or "").strip().lower()
text = _clean_value(str(content.get("text") or content.get("value") or "")).lower() text = _clean_value(
str(content.get("text") or content.get("value") or "")
).lower()
if not field or not text: if not field or not text:
continue continue
scope = ( scope = (

View File

@@ -59,7 +59,11 @@ def retrieve_memories_for_prompt(
limit=safe_limit, limit=safe_limit,
include_statuses=statuses, include_statuses=statuses,
) )
ids = [str(hit.memory_id or "").strip() for hit in hits if str(hit.memory_id or "").strip()] ids = [
str(hit.memory_id or "").strip()
for hit in hits
if str(hit.memory_id or "").strip()
]
scoped = _base_queryset( scoped = _base_queryset(
user_id=int(user_id), user_id=int(user_id),
person_id=person_id, person_id=person_id,
@@ -82,11 +86,17 @@ def retrieve_memories_for_prompt(
"content": item.content or {}, "content": item.content or {},
"provenance": item.provenance or {}, "provenance": item.provenance or {},
"confidence_score": float(item.confidence_score or 0.0), "confidence_score": float(item.confidence_score or 0.0),
"expires_at": item.expires_at.isoformat() if item.expires_at else "", "expires_at": (
"last_verified_at": ( item.expires_at.isoformat() if item.expires_at else ""
item.last_verified_at.isoformat() if item.last_verified_at else "" ),
"last_verified_at": (
item.last_verified_at.isoformat()
if item.last_verified_at
else ""
),
"updated_at": (
item.updated_at.isoformat() if item.updated_at else ""
), ),
"updated_at": item.updated_at.isoformat() if item.updated_at else "",
"search_score": float(hit.score or 0.0), "search_score": float(hit.score or 0.0),
"search_summary": str(hit.summary or ""), "search_summary": str(hit.summary or ""),
} }

View File

@@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
import hashlib import hashlib
import json
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
@@ -144,9 +143,10 @@ class ManticoreMemorySearchBackend(BaseMemorySearchBackend):
self.base_url = str( self.base_url = str(
getattr(settings, "MANTICORE_HTTP_URL", "http://localhost:9308") getattr(settings, "MANTICORE_HTTP_URL", "http://localhost:9308")
).rstrip("/") ).rstrip("/")
self.table = str( self.table = (
getattr(settings, "MANTICORE_MEMORY_TABLE", "gia_memory_items") str(getattr(settings, "MANTICORE_MEMORY_TABLE", "gia_memory_items")).strip()
).strip() or "gia_memory_items" or "gia_memory_items"
)
self.timeout_seconds = int(getattr(settings, "MANTICORE_HTTP_TIMEOUT", 5) or 5) self.timeout_seconds = int(getattr(settings, "MANTICORE_HTTP_TIMEOUT", 5) or 5)
self._table_cache_key = f"{self.base_url}|{self.table}" self._table_cache_key = f"{self.base_url}|{self.table}"
@@ -163,7 +163,9 @@ class ManticoreMemorySearchBackend(BaseMemorySearchBackend):
return dict(payload or {}) return dict(payload or {})
def ensure_table(self) -> None: def ensure_table(self) -> None:
last_ready = float(self._table_ready_cache.get(self._table_cache_key, 0.0) or 0.0) last_ready = float(
self._table_ready_cache.get(self._table_cache_key, 0.0) or 0.0
)
if (time.time() - last_ready) <= float(self._table_ready_ttl_seconds): if (time.time() - last_ready) <= float(self._table_ready_ttl_seconds):
return return
self._sql( self._sql(
@@ -254,7 +256,9 @@ class ManticoreMemorySearchBackend(BaseMemorySearchBackend):
try: try:
values.append(self._build_upsert_values_clause(item)) values.append(self._build_upsert_values_clause(item))
except Exception as exc: except Exception as exc:
log.warning("memory-search upsert build failed id=%s err=%s", item.id, exc) log.warning(
"memory-search upsert build failed id=%s err=%s", item.id, exc
)
continue continue
if len(values) >= batch_size: if len(values) >= batch_size:
self._sql( self._sql(
@@ -290,7 +294,11 @@ class ManticoreMemorySearchBackend(BaseMemorySearchBackend):
where_parts = [f"user_id={int(user_id)}", f"MATCH('{self._escape(needle)}')"] where_parts = [f"user_id={int(user_id)}", f"MATCH('{self._escape(needle)}')"]
if conversation_id: if conversation_id:
where_parts.append(f"conversation_id='{self._escape(conversation_id)}'") where_parts.append(f"conversation_id='{self._escape(conversation_id)}'")
statuses = [str(item or "").strip() for item in include_statuses if str(item or "").strip()] statuses = [
str(item or "").strip()
for item in include_statuses
if str(item or "").strip()
]
if statuses: if statuses:
in_clause = ",".join(f"'{self._escape(item)}'" for item in statuses) in_clause = ",".join(f"'{self._escape(item)}'" for item in statuses)
where_parts.append(f"status IN ({in_clause})") where_parts.append(f"status IN ({in_clause})")

View File

@@ -1,12 +1,13 @@
from asgiref.sync import sync_to_async
from django.conf import settings
import time import time
import uuid import uuid
from asgiref.sync import sync_to_async
from django.conf import settings
from core.events.ledger import append_event from core.events.ledger import append_event
from core.messaging.utils import messages_to_string from core.messaging.utils import messages_to_string
from core.observability.tracing import ensure_trace_id
from core.models import ChatSession, Message, QueuedMessage from core.models import ChatSession, Message, QueuedMessage
from core.observability.tracing import ensure_trace_id
from core.util import logs from core.util import logs
log = logs.get_logger("history") log = logs.get_logger("history")
@@ -272,7 +273,9 @@ async def store_own_message(
trace_id=ensure_trace_id(trace_id, message_meta or {}), trace_id=ensure_trace_id(trace_id, message_meta or {}),
) )
except Exception as exc: except Exception as exc:
log.warning("Event ledger append failed for own message=%s: %s", msg.id, exc) log.warning(
"Event ledger append failed for own message=%s: %s", msg.id, exc
)
return msg return msg

View File

@@ -335,8 +335,12 @@ def extract_reply_ref(service: str, raw_payload: dict[str, Any]) -> dict[str, st
svc = _clean(service).lower() svc = _clean(service).lower()
payload = _as_dict(raw_payload) payload = _as_dict(raw_payload)
if svc == "xmpp": if svc == "xmpp":
reply_id = _clean(payload.get("reply_source_message_id") or payload.get("reply_id")) reply_id = _clean(
reply_chat = _clean(payload.get("reply_source_chat_id") or payload.get("reply_chat_id")) payload.get("reply_source_message_id") or payload.get("reply_id")
)
reply_chat = _clean(
payload.get("reply_source_chat_id") or payload.get("reply_chat_id")
)
if reply_id: if reply_id:
return { return {
"reply_source_message_id": reply_id, "reply_source_message_id": reply_id,
@@ -363,7 +367,9 @@ def extract_origin_tag(raw_payload: dict[str, Any] | None) -> str:
return _find_origin_tag(_as_dict(raw_payload)) return _find_origin_tag(_as_dict(raw_payload))
async def resolve_reply_target(user, session, reply_ref: dict[str, str]) -> Message | None: async def resolve_reply_target(
user, session, reply_ref: dict[str, str]
) -> Message | None:
if not reply_ref or session is None: if not reply_ref or session is None:
return None return None
reply_source_message_id = _clean(reply_ref.get("reply_source_message_id")) reply_source_message_id = _clean(reply_ref.get("reply_source_message_id"))

View File

@@ -1,7 +1,8 @@
# Generated by Django 5.2.11 on 2026-03-02 11:55 # Generated by Django 5.2.11 on 2026-03-02 11:55
import django.db.models.deletion
import uuid import uuid
import django.db.models.deletion
from django.conf import settings from django.conf import settings
from django.db import migrations, models from django.db import migrations, models

View File

@@ -1,6 +1,6 @@
import django.db.models.deletion
from django.conf import settings from django.conf import settings
from django.db import migrations, models from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration): class Migration(migrations.Migration):

View File

@@ -1,8 +1,8 @@
import uuid import uuid
import django.db.models.deletion
from django.conf import settings from django.conf import settings
from django.db import migrations, models from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration): class Migration(migrations.Migration):

View File

@@ -1,8 +1,8 @@
import uuid import uuid
import django.db.models.deletion
from django.conf import settings from django.conf import settings
from django.db import migrations, models from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration): class Migration(migrations.Migration):

View File

@@ -1,8 +1,8 @@
# Generated by Django 4.2.19 on 2026-03-07 00:00 # Generated by Django 4.2.19 on 2026-03-07 00:00
import django.db.models.deletion
from django.conf import settings from django.conf import settings
from django.db import migrations, models from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration): class Migration(migrations.Migration):

View File

@@ -0,0 +1,39 @@
# Generated by Django 5.2.7 on 2026-03-07 20:12
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('core', '0041_useraccessibilitysettings'),
]
operations = [
migrations.CreateModel(
name='UserXmppOmemoTrustedKey',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('jid', models.CharField(blank=True, default='', max_length=255)),
('key_type', models.CharField(choices=[('fingerprint', 'Fingerprint'), ('client_key', 'Client key')], default='fingerprint', max_length=32)),
('key_id', models.CharField(max_length=255)),
('trusted', models.BooleanField(default=False)),
('source', models.CharField(blank=True, default='', max_length=64)),
('last_seen_at', models.DateTimeField(blank=True, null=True)),
('created_at', models.DateTimeField(auto_now_add=True)),
('updated_at', models.DateTimeField(auto_now=True)),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='xmpp_omemo_trusted_keys', to=settings.AUTH_USER_MODEL)),
],
options={
'indexes': [
models.Index(fields=['user', 'trusted', 'updated_at'], name='core_userxomemo_trusted_idx'),
models.Index(fields=['user', 'jid', 'updated_at'], name='core_userxomemo_jid_idx'),
],
'constraints': [
models.UniqueConstraint(fields=('user', 'jid', 'key_type', 'key_id'), name='unique_user_xmpp_omemo_trusted_key'),
],
},
),
]

View File

@@ -0,0 +1,18 @@
# Generated by Django 5.2.7 on 2026-03-07 20:23
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("core", "0042_userxmppomemotrustedkey_and_more"),
]
operations = [
migrations.AddField(
model_name="userxmppsecuritysettings",
name="encrypt_contact_messages_with_omemo",
field=models.BooleanField(default=True),
),
]

View File

@@ -241,17 +241,13 @@ class PlatformChatLink(models.Model):
raise ValidationError("Person must belong to the same user.") raise ValidationError("Person must belong to the same user.")
if self.person_identifier_id: if self.person_identifier_id:
if self.person_identifier.user_id != self.user_id: if self.person_identifier.user_id != self.user_id:
raise ValidationError( raise ValidationError("Person identifier must belong to the same user.")
"Person identifier must belong to the same user."
)
if self.person_identifier.person_id != self.person_id: if self.person_identifier.person_id != self.person_id:
raise ValidationError( raise ValidationError(
"Person identifier must belong to the selected person." "Person identifier must belong to the selected person."
) )
if self.person_identifier.service != self.service: if self.person_identifier.service != self.service:
raise ValidationError( raise ValidationError("Chat links cannot be linked across platforms.")
"Chat links cannot be linked across platforms."
)
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
value = str(self.chat_identifier or "").strip() value = str(self.chat_identifier or "").strip()
@@ -1869,9 +1865,7 @@ class PatternArtifactExport(models.Model):
class CommandProfile(models.Model): class CommandProfile(models.Model):
WINDOW_SCOPE_CHOICES = ( WINDOW_SCOPE_CHOICES = (("conversation", "Conversation"),)
("conversation", "Conversation"),
)
VISIBILITY_CHOICES = ( VISIBILITY_CHOICES = (
("status_in_source", "Status In Source"), ("status_in_source", "Status In Source"),
("silent", "Silent"), ("silent", "Silent"),
@@ -2039,7 +2033,9 @@ class BusinessPlanDocument(models.Model):
class Meta: class Meta:
indexes = [ indexes = [
models.Index(fields=["user", "status", "updated_at"]), models.Index(fields=["user", "status", "updated_at"]),
models.Index(fields=["user", "source_service", "source_channel_identifier"]), models.Index(
fields=["user", "source_service", "source_channel_identifier"]
),
] ]
@@ -2243,7 +2239,9 @@ class TranslationEventLog(models.Model):
class AnswerMemory(models.Model): class AnswerMemory(models.Model):
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name="answer_memory") user = models.ForeignKey(
User, on_delete=models.CASCADE, related_name="answer_memory"
)
service = models.CharField(max_length=255, choices=CHANNEL_SERVICE_CHOICES) service = models.CharField(max_length=255, choices=CHANNEL_SERVICE_CHOICES)
channel_identifier = models.CharField(max_length=255) channel_identifier = models.CharField(max_length=255)
question_fingerprint = models.CharField(max_length=128) question_fingerprint = models.CharField(max_length=128)
@@ -2261,7 +2259,9 @@ class AnswerMemory(models.Model):
class Meta: class Meta:
indexes = [ indexes = [
models.Index(fields=["user", "service", "channel_identifier", "created_at"]), models.Index(
fields=["user", "service", "channel_identifier", "created_at"]
),
models.Index(fields=["user", "question_fingerprint", "created_at"]), models.Index(fields=["user", "question_fingerprint", "created_at"]),
] ]
@@ -2284,7 +2284,9 @@ class AnswerSuggestionEvent(models.Model):
on_delete=models.CASCADE, on_delete=models.CASCADE,
related_name="answer_suggestion_events", related_name="answer_suggestion_events",
) )
status = models.CharField(max_length=32, choices=STATUS_CHOICES, default="suggested") status = models.CharField(
max_length=32, choices=STATUS_CHOICES, default="suggested"
)
candidate_answer = models.ForeignKey( candidate_answer = models.ForeignKey(
AnswerMemory, AnswerMemory,
on_delete=models.SET_NULL, on_delete=models.SET_NULL,
@@ -2305,7 +2307,9 @@ class AnswerSuggestionEvent(models.Model):
class TaskProject(models.Model): class TaskProject(models.Model):
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name="task_projects") user = models.ForeignKey(
User, on_delete=models.CASCADE, related_name="task_projects"
)
name = models.CharField(max_length=255) name = models.CharField(max_length=255)
external_key = models.CharField(max_length=255, blank=True, default="") external_key = models.CharField(max_length=255, blank=True, default="")
active = models.BooleanField(default=True) active = models.BooleanField(default=True)
@@ -2349,7 +2353,9 @@ class TaskEpic(models.Model):
class ChatTaskSource(models.Model): class ChatTaskSource(models.Model):
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name="chat_task_sources") user = models.ForeignKey(
User, on_delete=models.CASCADE, related_name="chat_task_sources"
)
service = models.CharField(max_length=255, choices=CHANNEL_SERVICE_CHOICES) service = models.CharField(max_length=255, choices=CHANNEL_SERVICE_CHOICES)
channel_identifier = models.CharField(max_length=255) channel_identifier = models.CharField(max_length=255)
project = models.ForeignKey( project = models.ForeignKey(
@@ -2378,7 +2384,9 @@ class ChatTaskSource(models.Model):
class DerivedTask(models.Model): class DerivedTask(models.Model):
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name="derived_tasks") user = models.ForeignKey(
User, on_delete=models.CASCADE, related_name="derived_tasks"
)
project = models.ForeignKey( project = models.ForeignKey(
TaskProject, TaskProject,
on_delete=models.CASCADE, on_delete=models.CASCADE,
@@ -2574,7 +2582,9 @@ class ExternalSyncEvent(models.Model):
) )
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name="external_sync_events") user = models.ForeignKey(
User, on_delete=models.CASCADE, related_name="external_sync_events"
)
task = models.ForeignKey( task = models.ForeignKey(
DerivedTask, DerivedTask,
on_delete=models.SET_NULL, on_delete=models.SET_NULL,
@@ -2606,7 +2616,9 @@ class ExternalSyncEvent(models.Model):
class TaskProviderConfig(models.Model): class TaskProviderConfig(models.Model):
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name="task_provider_configs") user = models.ForeignKey(
User, on_delete=models.CASCADE, related_name="task_provider_configs"
)
provider = models.CharField(max_length=64, default="mock") provider = models.CharField(max_length=64, default="mock")
enabled = models.BooleanField(default=False) enabled = models.BooleanField(default=False)
settings = models.JSONField(default=dict, blank=True) settings = models.JSONField(default=dict, blank=True)
@@ -2684,7 +2696,9 @@ class CodexRun(models.Model):
class Meta: class Meta:
indexes = [ indexes = [
models.Index(fields=["user", "status", "updated_at"]), models.Index(fields=["user", "status", "updated_at"]),
models.Index(fields=["user", "source_service", "source_channel", "created_at"]), models.Index(
fields=["user", "source_service", "source_channel", "created_at"]
),
] ]
@@ -2697,7 +2711,9 @@ class CodexPermissionRequest(models.Model):
) )
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name="codex_permission_requests") user = models.ForeignKey(
User, on_delete=models.CASCADE, related_name="codex_permission_requests"
)
codex_run = models.ForeignKey( codex_run = models.ForeignKey(
CodexRun, CodexRun,
on_delete=models.CASCADE, on_delete=models.CASCADE,
@@ -2910,7 +2926,49 @@ class UserXmppOmemoState(models.Model):
class Meta: class Meta:
indexes = [ indexes = [
models.Index(fields=["status", "updated_at"], name="core_userxm_status_133ead_idx"), models.Index(
fields=["status", "updated_at"], name="core_userxm_status_133ead_idx"
),
]
class UserXmppOmemoTrustedKey(models.Model):
KEY_TYPE_CHOICES = (
("fingerprint", "Fingerprint"),
("client_key", "Client key"),
)
user = models.ForeignKey(
User,
on_delete=models.CASCADE,
related_name="xmpp_omemo_trusted_keys",
)
jid = models.CharField(max_length=255, blank=True, default="")
key_type = models.CharField(
max_length=32, choices=KEY_TYPE_CHOICES, default="fingerprint"
)
key_id = models.CharField(max_length=255)
trusted = models.BooleanField(default=False)
source = models.CharField(max_length=64, blank=True, default="")
last_seen_at = models.DateTimeField(blank=True, null=True)
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
class Meta:
constraints = [
models.UniqueConstraint(
fields=["user", "jid", "key_type", "key_id"],
name="unique_user_xmpp_omemo_trusted_key",
),
]
indexes = [
models.Index(
fields=["user", "trusted", "updated_at"],
name="core_userxomemo_trusted_idx",
),
models.Index(
fields=["user", "jid", "updated_at"], name="core_userxomemo_jid_idx"
),
] ]
@@ -2921,6 +2979,7 @@ class UserXmppSecuritySettings(models.Model):
related_name="xmpp_security_settings", related_name="xmpp_security_settings",
) )
require_omemo = models.BooleanField(default=False) require_omemo = models.BooleanField(default=False)
encrypt_contact_messages_with_omemo = models.BooleanField(default=True)
created_at = models.DateTimeField(auto_now_add=True) created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True) updated_at = models.DateTimeField(auto_now=True)
@@ -2938,7 +2997,9 @@ class UserAccessibilitySettings(models.Model):
class TaskCompletionPattern(models.Model): class TaskCompletionPattern(models.Model):
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name="task_completion_patterns") user = models.ForeignKey(
User, on_delete=models.CASCADE, related_name="task_completion_patterns"
)
phrase = models.CharField(max_length=64) phrase = models.CharField(max_length=64)
enabled = models.BooleanField(default=True) enabled = models.BooleanField(default=True)
position = models.PositiveIntegerField(default=0) position = models.PositiveIntegerField(default=0)

View File

@@ -4,22 +4,22 @@ import re
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from django.conf import settings from django.conf import settings
from core.assist.engine import process_inbound_assist
from core.clients import transport from core.clients import transport
from core.events import event_ledger_status
from core.clients.instagram import InstagramClient from core.clients.instagram import InstagramClient
from core.clients.signal import SignalClient from core.clients.signal import SignalClient
from core.clients.whatsapp import WhatsAppClient from core.clients.whatsapp import WhatsAppClient
from core.clients.xmpp import XMPPClient from core.clients.xmpp import XMPPClient
from core.assist.engine import process_inbound_assist
from core.commands.base import CommandContext from core.commands.base import CommandContext
from core.commands.engine import process_inbound_message from core.commands.engine import process_inbound_message
from core.events import event_ledger_status
from core.messaging import history from core.messaging import history
from core.models import PersonIdentifier from core.models import PersonIdentifier
from core.observability.tracing import ensure_trace_id
from core.presence import AvailabilitySignal, record_native_signal from core.presence import AvailabilitySignal, record_native_signal
from core.realtime.typing_state import set_person_typing_state from core.realtime.typing_state import set_person_typing_state
from core.translation.engine import process_inbound_translation from core.translation.engine import process_inbound_translation
from core.util import logs from core.util import logs
from core.observability.tracing import ensure_trace_id
class UnifiedRouter(object): class UnifiedRouter(object):
@@ -119,7 +119,9 @@ class UnifiedRouter(object):
return return
identifiers = await self._resolve_identifier_objects(protocol, identifier) identifiers = await self._resolve_identifier_objects(protocol, identifier)
if identifiers: if identifiers:
outgoing = str(getattr(local_message, "custom_author", "") or "").strip().upper() in { outgoing = str(
getattr(local_message, "custom_author", "") or ""
).strip().upper() in {
"USER", "USER",
"BOT", "BOT",
} }
@@ -268,7 +270,9 @@ class UnifiedRouter(object):
ts=int(read_ts or 0), ts=int(read_ts or 0),
payload={ payload={
"origin": "router.message_read", "origin": "router.message_read",
"message_timestamps": [int(v) for v in list(timestamps or []) if str(v).isdigit()], "message_timestamps": [
int(v) for v in list(timestamps or []) if str(v).isdigit()
],
"read_by": str(read_by or row.identifier), "read_by": str(read_by or row.identifier),
}, },
) )

View File

@@ -12,9 +12,15 @@ from core.models import (
PersonIdentifier, PersonIdentifier,
User, User,
) )
from .inference import fade_confidence, now_ms, should_fade from .inference import fade_confidence, now_ms, should_fade
POSITIVE_SOURCE_KINDS = {"native_presence", "read_receipt", "typing_start", "message_in"} POSITIVE_SOURCE_KINDS = {
"native_presence",
"read_receipt",
"typing_start",
"message_in",
}
@dataclass(slots=True) @dataclass(slots=True)
@@ -99,7 +105,8 @@ def record_native_signal(signal: AvailabilitySignal) -> ContactAvailabilityEvent
person_identifier=signal.person_identifier, person_identifier=signal.person_identifier,
service=str(signal.service or "").strip().lower() or "signal", service=str(signal.service or "").strip().lower() or "signal",
source_kind=str(signal.source_kind or "").strip() or "native_presence", source_kind=str(signal.source_kind or "").strip() or "native_presence",
availability_state=str(signal.availability_state or "unknown").strip() or "unknown", availability_state=str(signal.availability_state or "unknown").strip()
or "unknown",
confidence=float(signal.confidence or 0.0), confidence=float(signal.confidence or 0.0),
ts=_normalize_ts(signal.ts), ts=_normalize_ts(signal.ts),
payload=dict(signal.payload or {}), payload=dict(signal.payload or {}),
@@ -109,7 +116,9 @@ def record_native_signal(signal: AvailabilitySignal) -> ContactAvailabilityEvent
return event return event
def record_inferred_signal(signal: AvailabilitySignal) -> ContactAvailabilityEvent | None: def record_inferred_signal(
signal: AvailabilitySignal,
) -> ContactAvailabilityEvent | None:
settings_row = get_settings(signal.user) settings_row = get_settings(signal.user)
if not settings_row.enabled or not settings_row.inference_enabled: if not settings_row.enabled or not settings_row.inference_enabled:
return None return None
@@ -151,7 +160,9 @@ def ensure_fading_state(
return None return None
if latest.source_kind not in POSITIVE_SOURCE_KINDS: if latest.source_kind not in POSITIVE_SOURCE_KINDS:
return None return None
if not should_fade(int(latest.ts or 0), current_ts, settings_row.fade_threshold_seconds): if not should_fade(
int(latest.ts or 0), current_ts, settings_row.fade_threshold_seconds
):
return None return None
elapsed = max(0, current_ts - int(latest.ts or 0)) elapsed = max(0, current_ts - int(latest.ts or 0))

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
from django.db.models import Q from django.db.models import Q
from core.models import ContactAvailabilityEvent, ContactAvailabilitySpan, Person, User from core.models import ContactAvailabilityEvent, ContactAvailabilitySpan, Person, User
from .engine import ensure_fading_state from .engine import ensure_fading_state
from .inference import now_ms from .inference import now_ms
@@ -19,9 +20,7 @@ def spans_for_range(
qs = ContactAvailabilitySpan.objects.filter( qs = ContactAvailabilitySpan.objects.filter(
user=user, user=user,
person=person, person=person,
).filter( ).filter(Q(start_ts__lte=end_ts) & Q(end_ts__gte=start_ts))
Q(start_ts__lte=end_ts) & Q(end_ts__gte=start_ts)
)
if service: if service:
qs = qs.filter(service=str(service).strip().lower()) qs = qs.filter(service=str(service).strip().lower())

View File

@@ -1,2 +1 @@
"""Security helpers shared across transport adapters.""" """Security helpers shared across transport adapters."""

View File

@@ -101,7 +101,9 @@ def validate_attachment_metadata(
raise ValueError(f"blocked_mime_type:{normalized_type}") raise ValueError(f"blocked_mime_type:{normalized_type}")
allow_unmatched = bool(getattr(settings, "ATTACHMENT_ALLOW_UNKNOWN_MIME", False)) allow_unmatched = bool(getattr(settings, "ATTACHMENT_ALLOW_UNKNOWN_MIME", False))
if not any(fnmatch(normalized_type, pattern) for pattern in _allowed_mime_patterns()): if not any(
fnmatch(normalized_type, pattern) for pattern in _allowed_mime_patterns()
):
if not allow_unmatched: if not allow_unmatched:
raise ValueError(f"unsupported_mime_type:{normalized_type}") raise ValueError(f"unsupported_mime_type:{normalized_type}")

View File

@@ -68,15 +68,13 @@ def _omemo_facts(ctx: CommandSecurityContext) -> tuple[str, str]:
message_meta = dict(ctx.message_meta or {}) message_meta = dict(ctx.message_meta or {})
payload = dict(ctx.payload or {}) payload = dict(ctx.payload or {})
xmpp_meta = dict(message_meta.get("xmpp") or {}) xmpp_meta = dict(message_meta.get("xmpp") or {})
status = str( status = (
xmpp_meta.get("omemo_status") str(xmpp_meta.get("omemo_status") or payload.get("omemo_status") or "")
or payload.get("omemo_status") .strip()
or "" .lower()
).strip().lower() )
client_key = str( client_key = str(
xmpp_meta.get("omemo_client_key") xmpp_meta.get("omemo_client_key") or payload.get("omemo_client_key") or ""
or payload.get("omemo_client_key")
or ""
).strip() ).strip()
return status, client_key return status, client_key
@@ -160,7 +158,8 @@ def evaluate_command_policy(
service = _normalize_service(context.service) service = _normalize_service(context.service)
channel = _normalize_channel(context.channel_identifier) channel = _normalize_channel(context.channel_identifier)
allowed_services = [ allowed_services = [
item.lower() for item in _normalize_list(getattr(policy, "allowed_services", [])) item.lower()
for item in _normalize_list(getattr(policy, "allowed_services", []))
] ]
global_allowed_services = [ global_allowed_services = [
item.lower() item.lower()

View File

@@ -83,7 +83,9 @@ def ensure_default_source_for_chat(
message=None, message=None,
): ):
service_key = str(service or "").strip().lower() service_key = str(service or "").strip().lower()
normalized_identifier = normalize_channel_identifier(service_key, channel_identifier) normalized_identifier = normalize_channel_identifier(
service_key, channel_identifier
)
variants = channel_variants(service_key, normalized_identifier) variants = channel_variants(service_key, normalized_identifier)
if not service_key or not variants: if not service_key or not variants:
return None return None

View File

@@ -72,9 +72,13 @@ def queue_codex_event_with_pre_approval(
}, },
) )
cfg = TaskProviderConfig.objects.filter(user=user, provider=provider, enabled=True).first() cfg = TaskProviderConfig.objects.filter(
user=user, provider=provider, enabled=True
).first()
settings_payload = dict(getattr(cfg, "settings", {}) or {}) settings_payload = dict(getattr(cfg, "settings", {}) or {})
approver_service = str(settings_payload.get("approver_service") or "").strip().lower() approver_service = (
str(settings_payload.get("approver_service") or "").strip().lower()
)
approver_identifier = str(settings_payload.get("approver_identifier") or "").strip() approver_identifier = str(settings_payload.get("approver_identifier") or "").strip()
if approver_service and approver_identifier: if approver_service and approver_identifier:
try: try:

View File

@@ -57,7 +57,9 @@ def resolve_external_chat_id(*, user, provider: str, service: str, channel: str)
provider=provider, provider=provider,
enabled=True, enabled=True,
) )
.filter(Q(person_identifier=person_identifier) | Q(person=person_identifier.person)) .filter(
Q(person_identifier=person_identifier) | Q(person=person_identifier.person)
)
.order_by("-updated_at", "-id") .order_by("-updated_at", "-id")
.first() .first()
) )

View File

@@ -22,16 +22,23 @@ from core.models import (
TaskEpic, TaskEpic,
TaskProviderConfig, TaskProviderConfig,
) )
from core.tasks.chat_defaults import ensure_default_source_for_chat, resolve_message_scope
from core.tasks.codex_approval import queue_codex_event_with_pre_approval
from core.tasks.providers import get_provider
from core.tasks.codex_support import resolve_external_chat_id
from core.security.command_policy import CommandSecurityContext, evaluate_command_policy from core.security.command_policy import CommandSecurityContext, evaluate_command_policy
from core.tasks.chat_defaults import (
ensure_default_source_for_chat,
resolve_message_scope,
)
from core.tasks.codex_approval import queue_codex_event_with_pre_approval
from core.tasks.codex_support import resolve_external_chat_id
from core.tasks.providers import get_provider
_TASK_HINT_RE = re.compile(r"\b(todo|task|action|need to|please)\b", re.IGNORECASE) _TASK_HINT_RE = re.compile(r"\b(todo|task|action|need to|please)\b", re.IGNORECASE)
_COMPLETION_RE = re.compile(r"\b(done|completed|fixed)\s*#([A-Za-z0-9_-]+)\b", re.IGNORECASE) _COMPLETION_RE = re.compile(
r"\b(done|completed|fixed)\s*#([A-Za-z0-9_-]+)\b", re.IGNORECASE
)
_BALANCED_HINT_RE = re.compile(r"\b(todo|task|action item|action)\b", re.IGNORECASE) _BALANCED_HINT_RE = re.compile(r"\b(todo|task|action item|action)\b", re.IGNORECASE)
_BROAD_HINT_RE = re.compile(r"\b(todo|task|action|need to|please|reminder)\b", re.IGNORECASE) _BROAD_HINT_RE = re.compile(
r"\b(todo|task|action|need to|please|reminder)\b", re.IGNORECASE
)
_PREFIX_HEAD_TRIM = " \t\r\n`'\"([{<*#-—_>.,:;!/?\\|" _PREFIX_HEAD_TRIM = " \t\r\n`'\"([{<*#-—_>.,:;!/?\\|"
_LIST_TASKS_RE = re.compile( _LIST_TASKS_RE = re.compile(
r"^\s*(?:\.l(?:\s+list(?:\s+tasks?)?)?|\.list(?:\s+tasks?)?)\s*$", r"^\s*(?:\.l(?:\s+list(?:\s+tasks?)?)?|\.list(?:\s+tasks?)?)\s*$",
@@ -151,15 +158,23 @@ async def _resolve_source_mappings(message: Message) -> list[ChatTaskSource]:
lookup_service = str(message.source_service or "").strip().lower() lookup_service = str(message.source_service or "").strip().lower()
variants = _channel_variants(lookup_service, message.source_chat_id or "") variants = _channel_variants(lookup_service, message.source_chat_id or "")
session_identifier = getattr(getattr(message, "session", None), "identifier", None) session_identifier = getattr(getattr(message, "session", None), "identifier", None)
canonical_service = str(getattr(session_identifier, "service", "") or "").strip().lower() canonical_service = (
canonical_identifier = str(getattr(session_identifier, "identifier", "") or "").strip() str(getattr(session_identifier, "service", "") or "").strip().lower()
)
canonical_identifier = str(
getattr(session_identifier, "identifier", "") or ""
).strip()
if lookup_service == "web" and canonical_service and canonical_service != "web": if lookup_service == "web" and canonical_service and canonical_service != "web":
lookup_service = canonical_service lookup_service = canonical_service
variants = _channel_variants(lookup_service, message.source_chat_id or "") variants = _channel_variants(lookup_service, message.source_chat_id or "")
for expanded in _channel_variants(lookup_service, canonical_identifier): for expanded in _channel_variants(lookup_service, canonical_identifier):
if expanded and expanded not in variants: if expanded and expanded not in variants:
variants.append(expanded) variants.append(expanded)
elif canonical_service and canonical_identifier and canonical_service == lookup_service: elif (
canonical_service
and canonical_identifier
and canonical_service == lookup_service
):
for expanded in _channel_variants(canonical_service, canonical_identifier): for expanded in _channel_variants(canonical_service, canonical_identifier):
if expanded and expanded not in variants: if expanded and expanded not in variants:
variants.append(expanded) variants.append(expanded)
@@ -170,10 +185,14 @@ async def _resolve_source_mappings(message: Message) -> list[ChatTaskSource]:
if not signal_value: if not signal_value:
continue continue
companions += await sync_to_async(list)( companions += await sync_to_async(list)(
Chat.objects.filter(source_uuid=signal_value).values_list("source_number", flat=True) Chat.objects.filter(source_uuid=signal_value).values_list(
"source_number", flat=True
)
) )
companions += await sync_to_async(list)( companions += await sync_to_async(list)(
Chat.objects.filter(source_number=signal_value).values_list("source_uuid", flat=True) Chat.objects.filter(source_number=signal_value).values_list(
"source_uuid", flat=True
)
) )
for candidate in companions: for candidate in companions:
for expanded in _channel_variants("signal", str(candidate or "").strip()): for expanded in _channel_variants("signal", str(candidate or "").strip()):
@@ -271,7 +290,8 @@ def _normalize_flags(raw: dict | None) -> dict:
row = dict(raw or {}) row = dict(raw or {})
return { return {
"derive_enabled": _to_bool(row.get("derive_enabled"), True), "derive_enabled": _to_bool(row.get("derive_enabled"), True),
"match_mode": str(row.get("match_mode") or "balanced").strip().lower() or "balanced", "match_mode": str(row.get("match_mode") or "balanced").strip().lower()
or "balanced",
"require_prefix": _to_bool(row.get("require_prefix"), False), "require_prefix": _to_bool(row.get("require_prefix"), False),
"allowed_prefixes": _parse_prefixes(row.get("allowed_prefixes")), "allowed_prefixes": _parse_prefixes(row.get("allowed_prefixes")),
"completion_enabled": _to_bool(row.get("completion_enabled"), True), "completion_enabled": _to_bool(row.get("completion_enabled"), True),
@@ -287,7 +307,9 @@ def _normalize_partial_flags(raw: dict | None) -> dict:
if "derive_enabled" in row: if "derive_enabled" in row:
out["derive_enabled"] = _to_bool(row.get("derive_enabled"), True) out["derive_enabled"] = _to_bool(row.get("derive_enabled"), True)
if "match_mode" in row: if "match_mode" in row:
out["match_mode"] = str(row.get("match_mode") or "balanced").strip().lower() or "balanced" out["match_mode"] = (
str(row.get("match_mode") or "balanced").strip().lower() or "balanced"
)
if "require_prefix" in row: if "require_prefix" in row:
out["require_prefix"] = _to_bool(row.get("require_prefix"), False) out["require_prefix"] = _to_bool(row.get("require_prefix"), False)
if "allowed_prefixes" in row: if "allowed_prefixes" in row:
@@ -304,7 +326,9 @@ def _normalize_partial_flags(raw: dict | None) -> dict:
def _effective_flags(source: ChatTaskSource) -> dict: def _effective_flags(source: ChatTaskSource) -> dict:
project_flags = _normalize_flags(getattr(getattr(source, "project", None), "settings", {}) or {}) project_flags = _normalize_flags(
getattr(getattr(source, "project", None), "settings", {}) or {}
)
source_flags = _normalize_partial_flags(getattr(source, "settings", {}) or {}) source_flags = _normalize_partial_flags(getattr(source, "settings", {}) or {})
merged = dict(project_flags) merged = dict(project_flags)
merged.update(source_flags) merged.update(source_flags)
@@ -360,7 +384,10 @@ async def _derive_title(message: Message) -> str:
{"role": "user", "content": text[:2000]}, {"role": "user", "content": text[:2000]},
] ]
try: try:
title = str(await ai_runner.run_prompt(prompt, ai_obj, operation="task_derive_title") or "").strip() title = str(
await ai_runner.run_prompt(prompt, ai_obj, operation="task_derive_title")
or ""
).strip()
except Exception: except Exception:
title = "" title = ""
return (title or text)[:255] return (title or text)[:255]
@@ -376,9 +403,13 @@ async def _derive_title_with_flags(message: Message, flags: dict) -> str:
return (cleaned or title or "Untitled task")[:255] return (cleaned or title or "Untitled task")[:255]
async def _emit_sync_event(task: DerivedTask, event: DerivedTaskEvent, action: str) -> None: async def _emit_sync_event(
task: DerivedTask, event: DerivedTaskEvent, action: str
) -> None:
cfg = await sync_to_async( cfg = await sync_to_async(
lambda: TaskProviderConfig.objects.filter(user=task.user, enabled=True).order_by("provider").first() lambda: TaskProviderConfig.objects.filter(user=task.user, enabled=True)
.order_by("provider")
.first()
)() )()
provider_name = str(getattr(cfg, "provider", "mock") or "mock") provider_name = str(getattr(cfg, "provider", "mock") or "mock")
provider_settings = dict(getattr(cfg, "settings", {}) or {}) provider_settings = dict(getattr(cfg, "settings", {}) or {})
@@ -416,7 +447,11 @@ async def _emit_sync_event(task: DerivedTask, event: DerivedTaskEvent, action: s
"source_channel": str(task.source_channel or ""), "source_channel": str(task.source_channel or ""),
"external_chat_id": external_chat_id, "external_chat_id": external_chat_id,
"origin_message_id": str(getattr(task, "origin_message_id", "") or ""), "origin_message_id": str(getattr(task, "origin_message_id", "") or ""),
"trigger_message_id": str(getattr(event, "source_message_id", "") or getattr(task, "origin_message_id", "") or ""), "trigger_message_id": str(
getattr(event, "source_message_id", "")
or getattr(task, "origin_message_id", "")
or ""
),
"mode": "default", "mode": "default",
"payload": event.payload, "payload": event.payload,
"memory_context": memory_context, "memory_context": memory_context,
@@ -495,7 +530,9 @@ async def _emit_sync_event(task: DerivedTask, event: DerivedTaskEvent, action: s
codex_run.status = status codex_run.status = status
codex_run.result_payload = dict(result.payload or {}) codex_run.result_payload = dict(result.payload or {})
codex_run.error = str(result.error or "") codex_run.error = str(result.error or "")
await sync_to_async(codex_run.save)(update_fields=["status", "result_payload", "error", "updated_at"]) await sync_to_async(codex_run.save)(
update_fields=["status", "result_payload", "error", "updated_at"]
)
if result.ok and result.external_key and not task.external_key: if result.ok and result.external_key and not task.external_key:
task.external_key = str(result.external_key) task.external_key = str(result.external_key)
await sync_to_async(task.save)(update_fields=["external_key"]) await sync_to_async(task.save)(update_fields=["external_key"])
@@ -503,15 +540,28 @@ async def _emit_sync_event(task: DerivedTask, event: DerivedTaskEvent, action: s
async def _completion_regex(message: Message) -> re.Pattern: async def _completion_regex(message: Message) -> re.Pattern:
patterns = await sync_to_async(list)( patterns = await sync_to_async(list)(
TaskCompletionPattern.objects.filter(user=message.user, enabled=True).order_by("position", "created_at") TaskCompletionPattern.objects.filter(user=message.user, enabled=True).order_by(
"position", "created_at"
) )
phrases = [str(row.phrase or "").strip() for row in patterns if str(row.phrase or "").strip()] )
phrases = [
str(row.phrase or "").strip()
for row in patterns
if str(row.phrase or "").strip()
]
if not phrases: if not phrases:
phrases = ["done", "completed", "fixed"] phrases = ["done", "completed", "fixed"]
return re.compile(r"\\b(?:" + "|".join(re.escape(p) for p in phrases) + r")\\s*#([A-Za-z0-9_-]+)\\b", re.IGNORECASE) return re.compile(
r"\\b(?:"
+ "|".join(re.escape(p) for p in phrases)
+ r")\\s*#([A-Za-z0-9_-]+)\\b",
re.IGNORECASE,
)
async def _send_scope_message(source: ChatTaskSource, message: Message, text: str) -> None: async def _send_scope_message(
source: ChatTaskSource, message: Message, text: str
) -> None:
await send_message_raw( await send_message_raw(
source.service or message.source_service or "web", source.service or message.source_service or "web",
source.channel_identifier or message.source_chat_id or "", source.channel_identifier or message.source_chat_id or "",
@@ -521,7 +571,9 @@ async def _send_scope_message(source: ChatTaskSource, message: Message, text: st
) )
async def _handle_scope_task_commands(message: Message, sources: list[ChatTaskSource], text: str) -> bool: async def _handle_scope_task_commands(
message: Message, sources: list[ChatTaskSource], text: str
) -> bool:
if not sources: if not sources:
return False return False
body = str(text or "").strip() body = str(text or "").strip()
@@ -538,7 +590,9 @@ async def _handle_scope_task_commands(message: Message, sources: list[ChatTaskSo
.order_by("-created_at")[:20] .order_by("-created_at")[:20]
) )
if not open_rows: if not open_rows:
await _send_scope_message(source, message, "[task] no open tasks in this chat.") await _send_scope_message(
source, message, "[task] no open tasks in this chat."
)
return True return True
lines = ["[task] open tasks:"] lines = ["[task] open tasks:"]
for row in open_rows: for row in open_rows:
@@ -573,7 +627,9 @@ async def _handle_scope_task_commands(message: Message, sources: list[ChatTaskSo
.first() .first()
)() )()
if task is None: if task is None:
await _send_scope_message(source, message, "[task] nothing to undo in this chat.") await _send_scope_message(
source, message, "[task] nothing to undo in this chat."
)
return True return True
ref = str(task.reference_code or "") ref = str(task.reference_code or "")
title = str(task.title or "") title = str(task.title or "")
@@ -596,10 +652,16 @@ async def _handle_scope_task_commands(message: Message, sources: list[ChatTaskSo
.first() .first()
)() )()
if task is None: if task is None:
await _send_scope_message(source, message, f"[task] #{reference} not found.") await _send_scope_message(
source, message, f"[task] #{reference} not found."
)
return True return True
due_str = f"\ndue: {task.due_date}" if task.due_date else "" due_str = f"\ndue: {task.due_date}" if task.due_date else ""
assignee_str = f"\nassignee: {task.assignee_identifier}" if task.assignee_identifier else "" assignee_str = (
f"\nassignee: {task.assignee_identifier}"
if task.assignee_identifier
else ""
)
detail = ( detail = (
f"[task] #{task.reference_code}: {task.title}" f"[task] #{task.reference_code}: {task.title}"
f"\nstatus: {task.status_snapshot}" f"\nstatus: {task.status_snapshot}"
@@ -624,7 +686,9 @@ async def _handle_scope_task_commands(message: Message, sources: list[ChatTaskSo
.first() .first()
)() )()
if task is None: if task is None:
await _send_scope_message(source, message, f"[task] #{reference} not found.") await _send_scope_message(
source, message, f"[task] #{reference} not found."
)
return True return True
task.status_snapshot = "completed" task.status_snapshot = "completed"
await sync_to_async(task.save)(update_fields=["status_snapshot"]) await sync_to_async(task.save)(update_fields=["status_snapshot"])
@@ -633,10 +697,16 @@ async def _handle_scope_task_commands(message: Message, sources: list[ChatTaskSo
event_type="completion_marked", event_type="completion_marked",
actor_identifier=str(message.sender_uuid or ""), actor_identifier=str(message.sender_uuid or ""),
source_message=message, source_message=message,
payload={"marker": reference, "command": ".task complete", "via": "chat_command"}, payload={
"marker": reference,
"command": ".task complete",
"via": "chat_command",
},
) )
await _emit_sync_event(task, event, "complete") await _emit_sync_event(task, event, "complete")
await _send_scope_message(source, message, f"[task] completed #{task.reference_code}: {task.title}") await _send_scope_message(
source, message, f"[task] completed #{task.reference_code}: {task.title}"
)
return True return True
return False return False
@@ -656,7 +726,9 @@ def _strip_epic_token(text: str) -> str:
return re.sub(r"\s{2,}", " ", cleaned).strip() return re.sub(r"\s{2,}", " ", cleaned).strip()
async def _handle_epic_create_command(message: Message, sources: list[ChatTaskSource], text: str) -> bool: async def _handle_epic_create_command(
message: Message, sources: list[ChatTaskSource], text: str
) -> bool:
match = _EPIC_CREATE_RE.match(str(text or "")) match = _EPIC_CREATE_RE.match(str(text or ""))
if not match or not sources: if not match or not sources:
return False return False
@@ -766,13 +838,21 @@ async def process_inbound_task_intelligence(message: Message) -> None:
if not submit_decision.allowed: if not submit_decision.allowed:
return return
completion_allowed = any(bool(_effective_flags(source).get("completion_enabled")) for source in sources) completion_allowed = any(
bool(_effective_flags(source).get("completion_enabled")) for source in sources
)
completion_rx = await _completion_regex(message) if completion_allowed else None completion_rx = await _completion_regex(message) if completion_allowed else None
marker_match = (completion_rx.search(text) if completion_rx else None) or (_COMPLETION_RE.search(text) if completion_allowed else None) marker_match = (completion_rx.search(text) if completion_rx else None) or (
_COMPLETION_RE.search(text) if completion_allowed else None
)
if marker_match: if marker_match:
ref_code = str(marker_match.group(marker_match.lastindex or 1) or "").strip() ref_code = str(marker_match.group(marker_match.lastindex or 1) or "").strip()
task = await sync_to_async( task = await sync_to_async(
lambda: DerivedTask.objects.filter(user=message.user, reference_code=ref_code).order_by("-created_at").first() lambda: DerivedTask.objects.filter(
user=message.user, reference_code=ref_code
)
.order_by("-created_at")
.first()
)() )()
if not task: if not task:
# parser warning event attached to a newly derived placeholder in mapped project # parser warning event attached to a newly derived placeholder in mapped project
@@ -848,7 +928,11 @@ async def process_inbound_task_intelligence(message: Message) -> None:
status_snapshot="open", status_snapshot="open",
due_date=parsed_due_date, due_date=parsed_due_date,
assignee_identifier=parsed_assignee, assignee_identifier=parsed_assignee,
immutable_payload={"origin_text": text, "task_text": task_text, "flags": flags}, immutable_payload={
"origin_text": text,
"task_text": task_text,
"flags": flags,
},
) )
event = await sync_to_async(DerivedTaskEvent.objects.create)( event = await sync_to_async(DerivedTaskEvent.objects.create)(
task=task, task=task,

View File

@@ -40,13 +40,21 @@ class ClaudeCLITaskProvider(TaskProvider):
return True return True
if "unrecognized subcommand 'create'" in text and "usage: claude" in text: if "unrecognized subcommand 'create'" in text and "usage: claude" in text:
return True return True
if "unrecognized subcommand 'append_update'" in text and "usage: claude" in text: if (
"unrecognized subcommand 'append_update'" in text
and "usage: claude" in text
):
return True return True
if "unrecognized subcommand 'mark_complete'" in text and "usage: claude" in text: if (
"unrecognized subcommand 'mark_complete'" in text
and "usage: claude" in text
):
return True return True
return False return False
def _builtin_stub_result(self, op: str, payload: dict, stderr: str) -> ProviderResult: def _builtin_stub_result(
self, op: str, payload: dict, stderr: str
) -> ProviderResult:
mode = str(payload.get("mode") or "default").strip().lower() mode = str(payload.get("mode") or "default").strip().lower()
external_key = ( external_key = (
str(payload.get("external_key") or "").strip() str(payload.get("external_key") or "").strip()
@@ -117,7 +125,10 @@ class ClaudeCLITaskProvider(TaskProvider):
cwd=workspace if workspace else None, cwd=workspace if workspace else None,
) )
stderr_probe = str(completed.stderr or "").lower() stderr_probe = str(completed.stderr or "").lower()
if completed.returncode != 0 and "unexpected argument '--op'" in stderr_probe: if (
completed.returncode != 0
and "unexpected argument '--op'" in stderr_probe
):
completed = subprocess.run( completed = subprocess.run(
fallback_cmd, fallback_cmd,
capture_output=True, capture_output=True,
@@ -133,7 +144,9 @@ class ClaudeCLITaskProvider(TaskProvider):
payload={"op": op, "timeout_seconds": command_timeout}, payload={"op": op, "timeout_seconds": command_timeout},
) )
except Exception as exc: except Exception as exc:
return ProviderResult(ok=False, error=f"claude_cli_exec_error:{exc}", payload={"op": op}) return ProviderResult(
ok=False, error=f"claude_cli_exec_error:{exc}", payload={"op": op}
)
stdout = str(completed.stdout or "").strip() stdout = str(completed.stdout or "").strip()
stderr = str(completed.stderr or "").strip() stderr = str(completed.stderr or "").strip()
@@ -172,7 +185,12 @@ class ClaudeCLITaskProvider(TaskProvider):
out_payload.update(parsed) out_payload.update(parsed)
if (not ok) and self._is_task_sync_contract_mismatch(stderr): if (not ok) and self._is_task_sync_contract_mismatch(stderr):
return self._builtin_stub_result(op, dict(payload or {}), stderr) return self._builtin_stub_result(op, dict(payload or {}), stderr)
return ProviderResult(ok=ok, external_key=ext, error=("" if ok else stderr[:4000]), payload=out_payload) return ProviderResult(
ok=ok,
external_key=ext,
error=("" if ok else stderr[:4000]),
payload=out_payload,
)
def healthcheck(self, config: dict) -> ProviderResult: def healthcheck(self, config: dict) -> ProviderResult:
command = self._command(config) command = self._command(config)
@@ -193,7 +211,11 @@ class ClaudeCLITaskProvider(TaskProvider):
"stdout": str(completed.stdout or "").strip()[:1000], "stdout": str(completed.stdout or "").strip()[:1000],
"stderr": str(completed.stderr or "").strip()[:1000], "stderr": str(completed.stderr or "").strip()[:1000],
}, },
error=("" if completed.returncode == 0 else str(completed.stderr or "").strip()[:1000]), error=(
""
if completed.returncode == 0
else str(completed.stderr or "").strip()[:1000]
),
) )
def create_task(self, config: dict, payload: dict) -> ProviderResult: def create_task(self, config: dict, payload: dict) -> ProviderResult:

View File

@@ -46,7 +46,9 @@ class CodexCLITaskProvider(TaskProvider):
return True return True
return False return False
def _builtin_stub_result(self, op: str, payload: dict, stderr: str) -> ProviderResult: def _builtin_stub_result(
self, op: str, payload: dict, stderr: str
) -> ProviderResult:
mode = str(payload.get("mode") or "default").strip().lower() mode = str(payload.get("mode") or "default").strip().lower()
external_key = ( external_key = (
str(payload.get("external_key") or "").strip() str(payload.get("external_key") or "").strip()
@@ -117,7 +119,10 @@ class CodexCLITaskProvider(TaskProvider):
cwd=workspace if workspace else None, cwd=workspace if workspace else None,
) )
stderr_probe = str(completed.stderr or "").lower() stderr_probe = str(completed.stderr or "").lower()
if completed.returncode != 0 and "unexpected argument '--op'" in stderr_probe: if (
completed.returncode != 0
and "unexpected argument '--op'" in stderr_probe
):
completed = subprocess.run( completed = subprocess.run(
fallback_cmd, fallback_cmd,
capture_output=True, capture_output=True,
@@ -133,7 +138,9 @@ class CodexCLITaskProvider(TaskProvider):
payload={"op": op, "timeout_seconds": command_timeout}, payload={"op": op, "timeout_seconds": command_timeout},
) )
except Exception as exc: except Exception as exc:
return ProviderResult(ok=False, error=f"codex_cli_exec_error:{exc}", payload={"op": op}) return ProviderResult(
ok=False, error=f"codex_cli_exec_error:{exc}", payload={"op": op}
)
stdout = str(completed.stdout or "").strip() stdout = str(completed.stdout or "").strip()
stderr = str(completed.stderr or "").strip() stderr = str(completed.stderr or "").strip()
@@ -172,7 +179,12 @@ class CodexCLITaskProvider(TaskProvider):
out_payload.update(parsed) out_payload.update(parsed)
if (not ok) and self._is_task_sync_contract_mismatch(stderr): if (not ok) and self._is_task_sync_contract_mismatch(stderr):
return self._builtin_stub_result(op, dict(payload or {}), stderr) return self._builtin_stub_result(op, dict(payload or {}), stderr)
return ProviderResult(ok=ok, external_key=ext, error=("" if ok else stderr[:4000]), payload=out_payload) return ProviderResult(
ok=ok,
external_key=ext,
error=("" if ok else stderr[:4000]),
payload=out_payload,
)
def healthcheck(self, config: dict) -> ProviderResult: def healthcheck(self, config: dict) -> ProviderResult:
command = self._command(config) command = self._command(config)
@@ -193,7 +205,11 @@ class CodexCLITaskProvider(TaskProvider):
"stdout": str(completed.stdout or "").strip()[:1000], "stdout": str(completed.stdout or "").strip()[:1000],
"stderr": str(completed.stderr or "").strip()[:1000], "stderr": str(completed.stderr or "").strip()[:1000],
}, },
error=("" if completed.returncode == 0 else str(completed.stderr or "").strip()[:1000]), error=(
""
if completed.returncode == 0
else str(completed.stderr or "").strip()[:1000]
),
) )
def create_task(self, config: dict, payload: dict) -> ProviderResult: def create_task(self, config: dict, payload: dict) -> ProviderResult:

View File

@@ -12,14 +12,30 @@ class MockTaskProvider(TaskProvider):
return ProviderResult(ok=True, payload={"provider": self.name}) return ProviderResult(ok=True, payload={"provider": self.name})
def create_task(self, config: dict, payload: dict) -> ProviderResult: def create_task(self, config: dict, payload: dict) -> ProviderResult:
ext = str(payload.get("external_key") or "") or f"mock-{int(time.time() * 1000)}" ext = (
return ProviderResult(ok=True, external_key=ext, payload={"action": "create_task"}) str(payload.get("external_key") or "") or f"mock-{int(time.time() * 1000)}"
)
return ProviderResult(
ok=True, external_key=ext, payload={"action": "create_task"}
)
def append_update(self, config: dict, payload: dict) -> ProviderResult: def append_update(self, config: dict, payload: dict) -> ProviderResult:
return ProviderResult(ok=True, external_key=str(payload.get("external_key") or ""), payload={"action": "append_update"}) return ProviderResult(
ok=True,
external_key=str(payload.get("external_key") or ""),
payload={"action": "append_update"},
)
def mark_complete(self, config: dict, payload: dict) -> ProviderResult: def mark_complete(self, config: dict, payload: dict) -> ProviderResult:
return ProviderResult(ok=True, external_key=str(payload.get("external_key") or ""), payload={"action": "mark_complete"}) return ProviderResult(
ok=True,
external_key=str(payload.get("external_key") or ""),
payload={"action": "mark_complete"},
)
def link_task(self, config: dict, payload: dict) -> ProviderResult: def link_task(self, config: dict, payload: dict) -> ProviderResult:
return ProviderResult(ok=True, external_key=str(payload.get("external_key") or ""), payload={"action": "link_task"}) return ProviderResult(
ok=True,
external_key=str(payload.get("external_key") or ""),
payload={"action": "link_task"},
)

View File

@@ -342,7 +342,7 @@
hx-trigger="click" hx-trigger="click"
hx-swap="innerHTML"> hx-swap="innerHTML">
<span class="icon is-small"><i class="fa-solid fa-paper-plane"></i></span> <span class="icon is-small"><i class="fa-solid fa-paper-plane"></i></span>
<span style="margin-left: 0.35rem;">Message</span> <span style="margin-left: 0.35rem;">Compose</span>
</a> </a>
<div class="navbar-dropdown" id="nav-compose-contacts"> <div class="navbar-dropdown" id="nav-compose-contacts">
<a <a
@@ -350,55 +350,20 @@
hx-get="{% url 'compose_contacts_dropdown' %}?all=1" hx-get="{% url 'compose_contacts_dropdown' %}?all=1"
hx-target="#nav-compose-contacts" hx-target="#nav-compose-contacts"
hx-swap="innerHTML"> hx-swap="innerHTML">
Fetch Contacts Open Contacts
</a> </a>
</div> </div>
</div> </div>
<a class="navbar-item" href="{% url 'tasks_hub' %}"> <a class="navbar-item" href="{% url 'tasks_hub' %}">
Tasks Task Inbox
</a> </a>
<a class="navbar-item" href="{% url 'ai_workspace' %}"> <a class="navbar-item" href="{% url 'ai_workspace' %}">
AI AI
</a> </a>
<div class="navbar-item has-dropdown is-hoverable">
<a class="navbar-link">
Security
</a>
<div class="navbar-dropdown">
<a
class="navbar-item{% if request.resolver_match.url_name == 'encryption_settings' or request.resolver_match.url_name == 'security_settings' %} is-current-route{% endif %}"
href="{% url 'encryption_settings' %}"
>
Encryption
</a>
<a
class="navbar-item{% if request.resolver_match.url_name == 'permission_settings' %} is-current-route{% endif %}"
href="{% url 'permission_settings' %}"
>
Permission
</a>
<a
class="navbar-item{% if request.resolver_match.url_name == 'security_2fa' or request.resolver_match.namespace == 'two_factor' %} is-current-route{% endif %}"
href="{% url 'security_2fa' %}"
>
2FA
</a>
</div>
</div>
<a class="navbar-item" href="{% url 'osint_search' type='page' %}"> <a class="navbar-item" href="{% url 'osint_search' type='page' %}">
Search Search
</a> </a>
<a class="navbar-item" href="{% url 'queues' type='page' %}">
Queue
</a>
<a class="navbar-item" href="{% url 'osint_workspace' %}">
OSINT
</a>
{% endif %} {% endif %}
<a class="navbar-item add-button">
Install
</a>
</div> </div>
<div class="navbar-end"> <div class="navbar-end">
@@ -423,15 +388,15 @@
<div class="navbar-item has-dropdown is-hoverable"> <div class="navbar-item has-dropdown is-hoverable">
<a class="navbar-link"> <a class="navbar-link">
Storage Data
</a> </a>
<div class="navbar-dropdown"> <div class="navbar-dropdown">
<a class="navbar-item" href="{% url 'sessions' type='page' %}"> <a class="navbar-item" href="{% url 'sessions' type='page' %}">
Sessions Sessions
</a> </a>
<a class="navbar-item" href="{% url 'command_routing' %}#bp-documents"> <a class="navbar-item{% if request.resolver_match.url_name == 'business_plan_inbox' or request.resolver_match.url_name == 'business_plan_editor' %} is-current-route{% endif %}" href="{% url 'business_plan_inbox' %}">
Documents Business Plans
</a> </a>
</div> </div>
</div> </div>
@@ -454,6 +419,19 @@
</a> </a>
{% endif %} {% endif %}
<hr class="navbar-divider"> <hr class="navbar-divider">
<div class="navbar-item has-text-weight-semibold is-size-7 has-text-grey">
Security
</div>
<a class="navbar-item{% if request.resolver_match.url_name == 'encryption_settings' or request.resolver_match.url_name == 'security_settings' %} is-current-route{% endif %}" href="{% url 'encryption_settings' %}">
Encryption
</a>
<a class="navbar-item{% if request.resolver_match.url_name == 'permission_settings' %} is-current-route{% endif %}" href="{% url 'permission_settings' %}">
Permissions
</a>
<a class="navbar-item{% if request.resolver_match.url_name == 'security_2fa' or request.resolver_match.namespace == 'two_factor' %} is-current-route{% endif %}" href="{% url 'security_2fa' %}">
2FA
</a>
<hr class="navbar-divider">
<div class="navbar-item has-text-weight-semibold is-size-7 has-text-grey"> <div class="navbar-item has-text-weight-semibold is-size-7 has-text-grey">
AI AI
</div> </div>
@@ -470,8 +448,11 @@
<a class="navbar-item{% if request.resolver_match.url_name == 'command_routing' %} is-current-route{% endif %}" href="{% url 'command_routing' %}"> <a class="navbar-item{% if request.resolver_match.url_name == 'command_routing' %} is-current-route{% endif %}" href="{% url 'command_routing' %}">
Commands Commands
</a> </a>
<a class="navbar-item{% if request.resolver_match.url_name == 'business_plan_inbox' or request.resolver_match.url_name == 'business_plan_editor' %} is-current-route{% endif %}" href="{% url 'business_plan_inbox' %}">
Business Plans
</a>
<a class="navbar-item{% if request.resolver_match.url_name == 'tasks_settings' %} is-current-route{% endif %}" href="{% url 'tasks_settings' %}"> <a class="navbar-item{% if request.resolver_match.url_name == 'tasks_settings' %} is-current-route{% endif %}" href="{% url 'tasks_settings' %}">
Tasks Task Automation
</a> </a>
<a class="navbar-item{% if request.resolver_match.url_name == 'translation_settings' %} is-current-route{% endif %}" href="{% url 'translation_settings' %}"> <a class="navbar-item{% if request.resolver_match.url_name == 'translation_settings' %} is-current-route{% endif %}" href="{% url 'translation_settings' %}">
Translation Translation
@@ -480,6 +461,16 @@
Availability Availability
</a> </a>
<hr class="navbar-divider"> <hr class="navbar-divider">
<div class="navbar-item has-text-weight-semibold is-size-7 has-text-grey">
Automation
</div>
<a class="navbar-item{% if request.resolver_match.url_name == 'queues' %} is-current-route{% endif %}" href="{% url 'queues' type='page' %}">
Approvals Queue
</a>
<a class="navbar-item{% if request.resolver_match.url_name == 'osint_workspace' %} is-current-route{% endif %}" href="{% url 'osint_workspace' %}">
OSINT Workspace
</a>
<hr class="navbar-divider">
<a class="navbar-item{% if request.resolver_match.url_name == 'accessibility_settings' %} is-current-route{% endif %}" href="{% url 'accessibility_settings' %}"> <a class="navbar-item{% if request.resolver_match.url_name == 'accessibility_settings' %} is-current-route{% endif %}" href="{% url 'accessibility_settings' %}">
Accessibility Accessibility
</a> </a>
@@ -499,6 +490,7 @@
{% endif %} {% endif %}
{% if user.is_authenticated %} {% if user.is_authenticated %}
<button class="button is-light add-button" type="button" style="display:none;">Install App</button>
<a class="button is-dark" href="{% url 'logout' %}">Logout</a> <a class="button is-dark" href="{% url 'logout' %}">Logout</a>
{% endif %} {% endif %}
@@ -510,8 +502,13 @@
<script> <script>
let deferredPrompt; let deferredPrompt;
const addBtn = document.querySelector('.add-button'); const addBtn = document.querySelector('.add-button');
if (addBtn) {
addBtn.style.display = 'none'; addBtn.style.display = 'none';
}
window.addEventListener('beforeinstallprompt', (e) => { window.addEventListener('beforeinstallprompt', (e) => {
if (!addBtn) {
return;
}
// Prevent Chrome 67 and earlier from automatically showing the prompt // Prevent Chrome 67 and earlier from automatically showing the prompt
e.preventDefault(); e.preventDefault();
// Stash the event so it can be triggered later. // Stash the event so it can be triggered later.

View File

@@ -28,7 +28,7 @@
<textarea class="textarea" name="content_markdown" rows="18">{{ document.content_markdown }}</textarea> <textarea class="textarea" name="content_markdown" rows="18">{{ document.content_markdown }}</textarea>
<div class="buttons" style="margin-top: 0.75rem;"> <div class="buttons" style="margin-top: 0.75rem;">
<button class="button is-link" type="submit">Save Revision</button> <button class="button is-link" type="submit">Save Revision</button>
<a class="button is-light" href="{% url 'command_routing' %}">Back</a> <a class="button is-light" href="{% url 'business_plan_inbox' %}">Back To Inbox</a>
</div> </div>
</form> </form>
</article> </article>

View File

@@ -0,0 +1,99 @@
{% extends "base.html" %}
{% block content %}
<section class="section">
<div class="container">
<h1 class="title is-4">Business Plan Inbox</h1>
<p class="subtitle is-6">Review, filter, and open generated business plan documents.</p>
<article class="notification is-light">
<div class="tags mb-1">
<span class="tag is-light">Total {{ stats.total|default:0 }}</span>
<span class="tag is-warning is-light">Draft {{ stats.draft|default:0 }}</span>
<span class="tag is-success is-light">Final {{ stats.final|default:0 }}</span>
</div>
</article>
<article class="box">
<h2 class="title is-6">Filters</h2>
<form method="get">
<div class="columns is-multiline">
<div class="column is-4">
<label class="label is-size-7">Search</label>
<input class="input is-small" name="q" value="{{ filters.q }}" placeholder="Title, source channel, command profile">
</div>
<div class="column is-3">
<label class="label is-size-7">Status</label>
<div class="select is-small is-fullwidth">
<select name="status">
<option value="">All</option>
<option value="draft" {% if filters.status == "draft" %}selected{% endif %}>Draft</option>
<option value="final" {% if filters.status == "final" %}selected{% endif %}>Final</option>
</select>
</div>
</div>
<div class="column is-3">
<label class="label is-size-7">Service</label>
<div class="select is-small is-fullwidth">
<select name="service">
<option value="">All</option>
{% for service in service_choices %}
<option value="{{ service }}" {% if filters.service == service %}selected{% endif %}>{{ service }}</option>
{% endfor %}
</select>
</div>
</div>
<div class="column is-2 is-flex is-align-items-flex-end">
<button class="button is-small is-link is-light" type="submit">Apply</button>
</div>
</div>
</form>
</article>
<article class="box">
<div class="is-flex is-justify-content-space-between is-align-items-center mb-2">
<h2 class="title is-6 mb-0">Documents</h2>
<a class="button is-small is-light" href="{% url 'command_routing' %}">Command Routing</a>
</div>
<table class="table is-fullwidth is-striped is-size-7">
<thead>
<tr>
<th>Title</th>
<th>Status</th>
<th>Service</th>
<th>Channel</th>
<th>Profile</th>
<th>Revisions</th>
<th>Updated</th>
<th></th>
</tr>
</thead>
<tbody>
{% for doc in documents %}
<tr>
<td>{{ doc.title }}</td>
<td>
{% if doc.status == "final" %}
<span class="tag is-success is-light">final</span>
{% else %}
<span class="tag is-warning is-light">draft</span>
{% endif %}
</td>
<td>{{ doc.source_service }}</td>
<td><code>{{ doc.source_channel_identifier|default:"-" }}</code></td>
<td>{% if doc.command_profile %}{{ doc.command_profile.name }}{% else %}-{% endif %}</td>
<td>{{ doc.revision_count }}</td>
<td>{{ doc.updated_at }}</td>
<td>
<a class="button is-small is-link is-light" href="{% url 'business_plan_editor' doc_id=doc.id %}">Open</a>
</td>
</tr>
{% empty %}
<tr><td colspan="8">No business plan documents yet.</td></tr>
{% endfor %}
</tbody>
</table>
</article>
</div>
</section>
{% endblock %}

View File

@@ -17,7 +17,7 @@
<p class="help">Healthcheck error: <code>{{ health.error }}</code></p> <p class="help">Healthcheck error: <code>{{ health.error }}</code></p>
{% endif %} {% endif %}
<p class="help">Config snapshot: command=<code>{{ provider_settings.command }}</code>, workspace=<code>{{ provider_settings.workspace_root|default:"-" }}</code>, profile=<code>{{ provider_settings.default_profile|default:"-" }}</code>, instance=<code>{{ provider_settings.instance_label }}</code>, approver=<code>{{ provider_settings.approver_service }} {{ provider_settings.approver_identifier }}</code>.</p> <p class="help">Config snapshot: command=<code>{{ provider_settings.command }}</code>, workspace=<code>{{ provider_settings.workspace_root|default:"-" }}</code>, profile=<code>{{ provider_settings.default_profile|default:"-" }}</code>, instance=<code>{{ provider_settings.instance_label }}</code>, approver=<code>{{ provider_settings.approver_service }} {{ provider_settings.approver_identifier }}</code>.</p>
<p class="help"><a href="{% url 'tasks_settings' %}">Edit in Task Settings</a>.</p> <p class="help"><a href="{% url 'tasks_settings' %}">Edit in Task Automation</a>.</p>
</article> </article>
<article class="box"> <article class="box">
@@ -89,7 +89,7 @@
</article> </article>
<article class="box"> <article class="box">
<h2 class="title is-6">Permission Queue</h2> <h2 class="title is-6">Approvals Queue</h2>
<table class="table is-fullwidth is-size-7 is-striped"> <table class="table is-fullwidth is-size-7 is-striped">
<thead><tr><th>Requested</th><th>Approval Key</th><th>Status</th><th>Summary</th><th>Permissions</th><th>Run</th><th>Task</th><th>Actions</th></tr></thead> <thead><tr><th>Requested</th><th>Approval Key</th><th>Status</th><th>Summary</th><th>Permissions</th><th>Run</th><th>Task</th><th>Actions</th></tr></thead>
<tbody> <tbody>

View File

@@ -393,7 +393,10 @@
{% endfor %} {% endfor %}
<article class="box" id="bp-documents"> <article class="box" id="bp-documents">
<h2 class="title is-6">Business Plan Documents</h2> <div class="is-flex is-justify-content-space-between is-align-items-center mb-2">
<h2 class="title is-6 mb-0">Recent Business Plan Documents</h2>
<a class="button is-small is-link is-light" href="{% url 'business_plan_inbox' %}">Open Business Plan Inbox</a>
</div>
<table class="table is-fullwidth is-striped is-size-7"> <table class="table is-fullwidth is-striped is-size-7">
<thead> <thead>
<tr><th scope="col">Title</th><th scope="col">Status</th><th scope="col">Source</th><th scope="col">Updated</th><th scope="col">Actions</th></tr> <tr><th scope="col">Title</th><th scope="col">Status</th><th scope="col">Source</th><th scope="col">Updated</th><th scope="col">Actions</th></tr>

View File

@@ -37,14 +37,24 @@
<h3 class="title is-7 mt-4 mb-2">Security Policy</h3> <h3 class="title is-7 mt-4 mb-2">Security Policy</h3>
<form method="post"> <form method="post">
{% csrf_token %} {% csrf_token %}
<input type="hidden" name="encryption_settings_submit" value="1">
<input type="hidden" name="require_omemo" value="0">
<div class="field"> <div class="field">
<label class="checkbox"> <label class="checkbox">
<input type="checkbox" name="require_omemo"{% if security_settings.require_omemo %} checked{% endif %}> <input type="checkbox" name="require_omemo" value="1"{% if security_settings.require_omemo %} checked{% endif %}>
Require OMEMO encryption — reject plaintext messages from your XMPP client Require OMEMO encryption — reject plaintext messages from your XMPP client
</label> </label>
<p class="help is-size-7 has-text-grey mt-1">When enabled, any plaintext XMPP message to the gateway is rejected before command routing.</p> <p class="help is-size-7 has-text-grey mt-1">When enabled, any plaintext XMPP message to the gateway is rejected before command routing.</p>
<p class="help is-size-7 has-text-grey">This is separate from command-scope policy checks such as Require Trusted Fingerprint.</p> <p class="help is-size-7 has-text-grey">This is separate from command-scope policy checks such as Require Trusted Fingerprint.</p>
</div> </div>
<input type="hidden" name="encrypt_contact_messages_with_omemo" value="0">
<div class="field mt-3">
<label class="checkbox">
<input type="checkbox" name="encrypt_contact_messages_with_omemo" value="1"{% if security_settings.encrypt_contact_messages_with_omemo %} checked{% endif %}>
Encrypt contact relay messages to your XMPP client with OMEMO
</label>
<p class="help is-size-7 has-text-grey mt-1">When enabled, relay text from contacts is sent with OMEMO when available. If disabled, relay text is sent in plaintext.</p>
</div>
<button class="button is-link is-small" type="submit">Save</button> <button class="button is-link is-small" type="submit">Save</button>
</form> </form>
</div> </div>
@@ -70,8 +80,27 @@
</td> </td>
</tr> </tr>
<tr> <tr>
<th>Contact key</th> <th>Client OMEMO</th>
<td><code>{{ omemo_row.latest_client_key|default:"—" }}</code></td> <td>
{% if omemo_client_fingerprint %}
<div>
<div>Client fingerprint: <code>{{ omemo_client_fingerprint }}</code></div>
<p class="help is-size-7 has-text-grey mt-1 mb-0">This is your client OMEMO fingerprint observed by the gateway.</p>
</div>
{% elif omemo_client_key_info.has_ids %}
<div>
{% if omemo_client_key_info.sid %}
<div>Your device ID: <code>{{ omemo_client_key_info.sid }}</code></div>
{% endif %}
{% if omemo_client_key_info.rid %}
<div>Gateway device ID: <code>{{ omemo_client_key_info.rid }}</code></div>
{% endif %}
<p class="help is-size-7 has-text-grey mt-1 mb-0">These IDs identify the OMEMO devices that participated in the encrypted message.</p>
</div>
{% else %}
<code>{{ omemo_row.latest_client_key|default:"—" }}</code>
{% endif %}
</td>
</tr> </tr>
<tr> <tr>
<th>Contact JID</th> <th>Contact JID</th>
@@ -101,6 +130,64 @@
</div> </div>
</div> </div>
</div> </div>
<div class="box">
<h2 class="title is-6">OMEMO Trust Management</h2>
<p class="is-size-7 has-text-grey mb-3">
Manage trust for discovered OMEMO keys observed by the gateway.
</p>
<p class="help is-size-7 has-text-grey mb-3">
Note: You are responsible for trusting your own other devices from your other XMPP clients.
</p>
{% if discovered_omemo_keys %}
<table class="table is-fullwidth is-size-7">
<thead>
<tr>
<th>JID</th>
<th>Type</th>
<th>Discovered key</th>
<th>Source</th>
<th>Trusted</th>
<th></th>
</tr>
</thead>
<tbody>
{% for item in discovered_omemo_keys %}
<tr>
<td>{{ item.jid }}</td>
<td>{{ item.key_type }}</td>
<td><code>{{ item.key_id }}</code></td>
<td>{{ item.label }}</td>
<td>
{% if item.trusted %}
<span class="tag is-success is-light">trusted</span>
{% else %}
<span class="tag is-light">not trusted</span>
{% endif %}
</td>
<td>
<form method="post" class="is-flex is-align-items-center" style="gap: 0.5rem;">
{% csrf_token %}
<input type="hidden" name="omemo_trust_update" value="1">
<input type="hidden" name="jid" value="{{ item.jid }}">
<input type="hidden" name="key_type" value="{{ item.key_type }}">
<input type="hidden" name="key_id" value="{{ item.key_id }}">
<input type="hidden" name="source" value="{{ item.source }}">
<label class="checkbox">
<input type="checkbox" name="trusted" value="1"{% if item.trusted %} checked{% endif %}>
Trust
</label>
<button class="button is-small is-link is-light" type="submit">Save</button>
</form>
</td>
</tr>
{% endfor %}
</tbody>
</table>
{% else %}
<p class="is-size-7 has-text-grey">No discovered OMEMO keys yet. Send an OMEMO message to populate this list.</p>
{% endif %}
</div>
{% endif %} {% endif %}
{% if show_permission %} {% if show_permission %}

View File

@@ -1,7 +1,7 @@
{% extends "base.html" %} {% extends "base.html" %}
{% block content %} {% block content %}
<section class="section"><div class="container"> <section class="section"><div class="container">
<h1 class="title is-4">Group Tasks: {{ channel_display_name }}</h1> <h1 class="title is-4">Group Task Inbox: {{ channel_display_name }}</h1>
<p class="subtitle is-6">{{ service_label }}</p> <p class="subtitle is-6">{{ service_label }}</p>
<article class="box"> <article class="box">
<h2 class="title is-6">Create Or Map Project</h2> <h2 class="title is-6">Create Or Map Project</h2>
@@ -65,7 +65,7 @@
<div class="content is-size-7"> <div class="content is-size-7">
<p>This group has no derived tasks yet. To start populating this view:</p> <p>This group has no derived tasks yet. To start populating this view:</p>
<ol> <ol>
<li>Open <a href="{% url 'tasks_settings' %}?service={{ service }}&identifier={{ identifier|urlencode }}">Task Settings</a> and confirm this chat is mapped under <strong>Group Mapping</strong>.</li> <li>Open <a href="{% url 'tasks_settings' %}?service={{ service }}&identifier={{ identifier|urlencode }}">Task Automation</a> and confirm this chat is mapped under <strong>Group Mapping</strong>.</li>
<li>Send task-like messages in this group, for example: <code>task: ship v1</code>, <code>todo: write tests</code>, <code>please review PR</code>.</li> <li>Send task-like messages in this group, for example: <code>task: ship v1</code>, <code>todo: write tests</code>, <code>please review PR</code>.</li>
<li>Mark completion explicitly with a phrase + reference, for example: <code>done #12</code>, <code>completed #12</code>, <code>fixed #12</code>.</li> <li>Mark completion explicitly with a phrase + reference, for example: <code>done #12</code>, <code>completed #12</code>, <code>fixed #12</code>.</li>
<li>Refresh this page; new derived tasks and events should appear automatically.</li> <li>Refresh this page; new derived tasks and events should appear automatically.</li>

View File

@@ -2,10 +2,10 @@
{% block content %} {% block content %}
<section class="section"> <section class="section">
<div class="container"> <div class="container">
<h1 class="title is-4">Tasks</h1> <h1 class="title is-4">Task Inbox</h1>
<p class="subtitle is-6">Immutable tasks derived from chat activity.</p> <p class="subtitle is-6">Immutable tasks derived from chat activity.</p>
<div class="buttons" style="margin-bottom: 0.75rem;"> <div class="buttons" style="margin-bottom: 0.75rem;">
<a class="button is-small is-link is-light" href="{% url 'tasks_settings' %}{% if scope.person_id or scope.service or scope.identifier %}?{% if scope.person_id %}person={{ scope.person_id|urlencode }}{% endif %}{% if scope.service %}{% if scope.person_id %}&{% endif %}service={{ scope.service|urlencode }}{% endif %}{% if scope.identifier %}{% if scope.person_id or scope.service %}&{% endif %}identifier={{ scope.identifier|urlencode }}{% endif %}{% endif %}">Task Settings</a> <a class="button is-small is-link is-light" href="{% url 'tasks_settings' %}{% if scope.person_id or scope.service or scope.identifier %}?{% if scope.person_id %}person={{ scope.person_id|urlencode }}{% endif %}{% if scope.service %}{% if scope.person_id %}&{% endif %}service={{ scope.service|urlencode }}{% endif %}{% if scope.identifier %}{% if scope.person_id or scope.service %}&{% endif %}identifier={{ scope.identifier|urlencode }}{% endif %}{% endif %}">Task Automation</a>
</div> </div>
<div class="columns is-variable is-5"> <div class="columns is-variable is-5">
<div class="column is-4"> <div class="column is-4">

View File

@@ -2,7 +2,7 @@
{% block content %} {% block content %}
<section class="section"> <section class="section">
<div class="container tasks-settings-page"> <div class="container tasks-settings-page">
<h1 class="title is-4">Task Settings</h1> <h1 class="title is-4">Task Automation</h1>
<p class="subtitle is-6">Project defaults flow into channel overrides. Use Quick Setup for normal operation; open Advanced Setup for full controls.</p> <p class="subtitle is-6">Project defaults flow into channel overrides. Use Quick Setup for normal operation; open Advanced Setup for full controls.</p>
<div class="notification is-light"> <div class="notification is-light">
@@ -17,7 +17,7 @@
<section class="block box"> <section class="block box">
<h2 class="title is-6">Quick Setup</h2> <h2 class="title is-6">Quick Setup</h2>
<p class="help">Creates or updates project + optional epic + channel mapping in one submission.</p> <p class="help">Creates or updates project + optional epic + channel mapping in one submission.</p>
<p class="help">After setup, view tasks in <a href="{% url 'tasks_hub' %}">Tasks Hub</a>{% if prefill_service and prefill_identifier %} or <a href="{% url 'tasks_group' service=prefill_service identifier=prefill_identifier %}">this group task view</a>{% endif %}.</p> <p class="help">After setup, view tasks in <a href="{% url 'tasks_hub' %}">Task Inbox</a>{% if prefill_service and prefill_identifier %} or <a href="{% url 'tasks_group' service=prefill_service identifier=prefill_identifier %}">this group task view</a>{% endif %}.</p>
<form method="post"> <form method="post">
{% csrf_token %} {% csrf_token %}
<input type="hidden" name="action" value="quick_setup"> <input type="hidden" name="action" value="quick_setup">
@@ -470,7 +470,7 @@
<span class="tag is-success is-light">ok {{ claude_compact_summary.queue_counts.ok }}</span> <span class="tag is-success is-light">ok {{ claude_compact_summary.queue_counts.ok }}</span>
</div> </div>
</article> </article>
<p class="help">Browse all derived tasks in <a href="{% url 'tasks_hub' %}">Tasks Hub</a>.</p> <p class="help">Browse all derived tasks in <a href="{% url 'tasks_hub' %}">Task Inbox</a>.</p>
</section> </section>
</div> </div>
<div class="column is-12"> <div class="column is-12">

View File

@@ -129,7 +129,7 @@
class="button is-small is-info is-light" class="button is-small is-info is-light"
onclick="giaWorkspaceQueueSelectedDraft('{{ person.id }}'); return false;"> onclick="giaWorkspaceQueueSelectedDraft('{{ person.id }}'); return false;">
<span class="icon is-small"><i class="fa-solid fa-inbox-in"></i></span> <span class="icon is-small"><i class="fa-solid fa-inbox-in"></i></span>
<span>Add To Queue</span> <span>Queue For Approval</span>
</button> </button>
</div> </div>
</div> </div>

View File

@@ -475,7 +475,7 @@
</button> </button>
<button type="submit" class="button is-info is-light" onclick="giaEngageSetAction('{{ person.id }}', 'queue');"> <button type="submit" class="button is-info is-light" onclick="giaEngageSetAction('{{ person.id }}', 'queue');">
<span class="icon is-small"><i class="fa-solid fa-inbox-in"></i></span> <span class="icon is-small"><i class="fa-solid fa-inbox-in"></i></span>
<span>Add To Queue</span> <span>Queue For Approval</span>
</button> </button>
</div> </div>
</form> </form>

View File

@@ -12,7 +12,7 @@
<div class="is-flex is-justify-content-space-between is-align-items-center" style="margin-bottom: 0.75rem; gap: 0.5rem; flex-wrap: wrap;"> <div class="is-flex is-justify-content-space-between is-align-items-center" style="margin-bottom: 0.75rem; gap: 0.5rem; flex-wrap: wrap;">
<div> <div>
<h3 class="title is-6" style="margin-bottom: 0.15rem;">Outgoing Queue</h3> <h3 class="title is-6" style="margin-bottom: 0.15rem;">Approvals Queue</h3>
<p class="is-size-7">Review queued drafts and approve or reject each message.</p> <p class="is-size-7">Review queued drafts and approve or reject each message.</p>
</div> </div>
<span class="tag is-dark is-medium">{{ object_list|length }} pending</span> <span class="tag is-dark is-medium">{{ object_list|length }} pending</span>
@@ -57,7 +57,7 @@
</div> </div>
<div class="is-flex is-justify-content-space-between is-align-items-center" style="gap: 0.5rem; flex-wrap: wrap;"> <div class="is-flex is-justify-content-space-between is-align-items-center" style="gap: 0.5rem; flex-wrap: wrap;">
<small class="has-text-grey">Queue ID: {{ item.id }}</small> <small class="has-text-grey">Approval ID: {{ item.id }}</small>
<div class="buttons are-small" style="margin: 0;"> <div class="buttons are-small" style="margin: 0;">
<button <button
hx-headers='{"X-CSRFToken": "{{ csrf_token }}"}' hx-headers='{"X-CSRFToken": "{{ csrf_token }}"}'
@@ -92,7 +92,7 @@
</div> </div>
{% else %} {% else %}
<article class="box" style="padding: 0.8rem; border: 1px dashed rgba(0, 0, 0, 0.25); box-shadow: none;"> <article class="box" style="padding: 0.8rem; border: 1px dashed rgba(0, 0, 0, 0.25); box-shadow: none;">
<p class="is-size-7 has-text-grey">Queue is empty.</p> <p class="is-size-7 has-text-grey">Approvals Queue is empty.</p>
</article> </article>
{% endif %} {% endif %}
</div> </div>

View File

@@ -1,8 +1,9 @@
from __future__ import annotations from __future__ import annotations
from unittest.mock import patch
from asgiref.sync import async_to_sync from asgiref.sync import async_to_sync
from django.test import TestCase from django.test import TestCase
from unittest.mock import patch
from core.messaging.ai import run_prompt from core.messaging.ai import run_prompt
from core.models import AI, AIRunLog, User from core.models import AI, AIRunLog, User

View File

@@ -6,9 +6,9 @@ from django.test import TestCase
from core.models import ( from core.models import (
ChatSession, ChatSession,
ContactAvailabilityEvent, ContactAvailabilityEvent,
Message,
Person, Person,
PersonIdentifier, PersonIdentifier,
Message,
User, User,
) )
from core.presence.inference import now_ms from core.presence.inference import now_ms
@@ -16,7 +16,9 @@ from core.presence.inference import now_ms
class BackfillContactAvailabilityCommandTests(TestCase): class BackfillContactAvailabilityCommandTests(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create_user("backfill-user", "backfill@example.com", "x") self.user = User.objects.create_user(
"backfill-user", "backfill@example.com", "x"
)
self.person = Person.objects.create(user=self.user, name="Backfill Person") self.person = Person.objects.create(user=self.user, name="Backfill Person")
self.identifier = PersonIdentifier.objects.create( self.identifier = PersonIdentifier.objects.create(
user=self.user, user=self.user,
@@ -24,7 +26,9 @@ class BackfillContactAvailabilityCommandTests(TestCase):
service="signal", service="signal",
identifier="+15551234567", identifier="+15551234567",
) )
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier) self.session = ChatSession.objects.create(
user=self.user, identifier=self.identifier
)
def test_backfill_creates_message_and_read_receipt_availability_events(self): def test_backfill_creates_message_and_read_receipt_availability_events(self):
base_ts = now_ms() base_ts = now_ms()
@@ -58,7 +62,9 @@ class BackfillContactAvailabilityCommandTests(TestCase):
) )
events = list( events = list(
ContactAvailabilityEvent.objects.filter(user=self.user).order_by("ts", "source_kind") ContactAvailabilityEvent.objects.filter(user=self.user).order_by(
"ts", "source_kind"
)
) )
self.assertEqual(3, len(events)) self.assertEqual(3, len(events))
self.assertTrue(any(row.source_kind == "message_in" for row in events)) self.assertTrue(any(row.source_kind == "message_in" for row in events))

View File

@@ -123,7 +123,9 @@ class BPFallbackTests(TransactionTestCase):
run = CommandRun.objects.get(trigger_message=trigger, profile=self.profile) run = CommandRun.objects.get(trigger_message=trigger, profile=self.profile)
self.assertEqual("failed", run.status) self.assertEqual("failed", run.status)
self.assertIn("bp_ai_failed", str(run.error)) self.assertIn("bp_ai_failed", str(run.error))
self.assertFalse(BusinessPlanDocument.objects.filter(trigger_message=trigger).exists()) self.assertFalse(
BusinessPlanDocument.objects.filter(trigger_message=trigger).exists()
)
def test_bp_uses_same_ai_selection_order_as_compose(self): def test_bp_uses_same_ai_selection_order_as_compose(self):
AI.objects.create( AI.objects.create(

View File

@@ -35,7 +35,9 @@ class BPSubcommandTests(TransactionTestCase):
service="whatsapp", service="whatsapp",
identifier="120363402761690215", identifier="120363402761690215",
) )
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier) self.session = ChatSession.objects.create(
user=self.user, identifier=self.identifier
)
self.profile = CommandProfile.objects.create( self.profile = CommandProfile.objects.create(
user=self.user, user=self.user,
slug="bp", slug="bp",
@@ -96,13 +98,19 @@ class BPSubcommandTests(TransactionTestCase):
source_service="whatsapp", source_service="whatsapp",
source_chat_id="120363402761690215", source_chat_id="120363402761690215",
) )
with patch("core.commands.handlers.bp.ai_runner.run_prompt", new=AsyncMock()) as mocked_ai: with patch(
result = async_to_sync(BPCommandHandler().execute)(self._ctx(trigger, trigger.text)) "core.commands.handlers.bp.ai_runner.run_prompt", new=AsyncMock()
) as mocked_ai:
result = async_to_sync(BPCommandHandler().execute)(
self._ctx(trigger, trigger.text)
)
self.assertTrue(result.ok) self.assertTrue(result.ok)
mocked_ai.assert_not_awaited() mocked_ai.assert_not_awaited()
doc = BusinessPlanDocument.objects.get(trigger_message=trigger) doc = BusinessPlanDocument.objects.get(trigger_message=trigger)
self.assertEqual("direct body", doc.content_markdown) self.assertEqual("direct body", doc.content_markdown)
self.assertEqual("Generated from 1 message.", doc.structured_payload.get("annotation")) self.assertEqual(
"Generated from 1 message.", doc.structured_payload.get("annotation")
)
def test_set_reply_only_uses_anchor(self): def test_set_reply_only_uses_anchor(self):
anchor = Message.objects.create( anchor = Message.objects.create(
@@ -124,11 +132,15 @@ class BPSubcommandTests(TransactionTestCase):
source_chat_id="120363402761690215", source_chat_id="120363402761690215",
reply_to=anchor, reply_to=anchor,
) )
result = async_to_sync(BPCommandHandler().execute)(self._ctx(trigger, trigger.text)) result = async_to_sync(BPCommandHandler().execute)(
self._ctx(trigger, trigger.text)
)
self.assertTrue(result.ok) self.assertTrue(result.ok)
doc = BusinessPlanDocument.objects.get(trigger_message=trigger) doc = BusinessPlanDocument.objects.get(trigger_message=trigger)
self.assertEqual("anchor body", doc.content_markdown) self.assertEqual("anchor body", doc.content_markdown)
self.assertEqual("Generated from 1 message.", doc.structured_payload.get("annotation")) self.assertEqual(
"Generated from 1 message.", doc.structured_payload.get("annotation")
)
def test_set_reply_plus_addendum_uses_divider(self): def test_set_reply_plus_addendum_uses_divider(self):
anchor = Message.objects.create( anchor = Message.objects.create(
@@ -150,7 +162,9 @@ class BPSubcommandTests(TransactionTestCase):
source_chat_id="120363402761690215", source_chat_id="120363402761690215",
reply_to=anchor, reply_to=anchor,
) )
result = async_to_sync(BPCommandHandler().execute)(self._ctx(trigger, trigger.text)) result = async_to_sync(BPCommandHandler().execute)(
self._ctx(trigger, trigger.text)
)
self.assertTrue(result.ok) self.assertTrue(result.ok)
doc = BusinessPlanDocument.objects.get(trigger_message=trigger) doc = BusinessPlanDocument.objects.get(trigger_message=trigger)
self.assertIn("base body", doc.content_markdown) self.assertIn("base body", doc.content_markdown)
@@ -171,7 +185,9 @@ class BPSubcommandTests(TransactionTestCase):
source_service="whatsapp", source_service="whatsapp",
source_chat_id="120363402761690215", source_chat_id="120363402761690215",
) )
result = async_to_sync(BPCommandHandler().execute)(self._ctx(trigger, trigger.text)) result = async_to_sync(BPCommandHandler().execute)(
self._ctx(trigger, trigger.text)
)
self.assertFalse(result.ok) self.assertFalse(result.ok)
self.assertEqual("failed", result.status) self.assertEqual("failed", result.status)
self.assertEqual("bp_set_range_requires_reply_target", result.error) self.assertEqual("bp_set_range_requires_reply_target", result.error)
@@ -205,8 +221,12 @@ class BPSubcommandTests(TransactionTestCase):
source_chat_id="120363402761690215", source_chat_id="120363402761690215",
reply_to=anchor, reply_to=anchor,
) )
result = async_to_sync(BPCommandHandler().execute)(self._ctx(trigger, trigger.text)) result = async_to_sync(BPCommandHandler().execute)(
self._ctx(trigger, trigger.text)
)
self.assertTrue(result.ok) self.assertTrue(result.ok)
doc = BusinessPlanDocument.objects.get(trigger_message=trigger) doc = BusinessPlanDocument.objects.get(trigger_message=trigger)
self.assertEqual("line 1\n(no text)\n#bp set range#", doc.content_markdown) self.assertEqual("line 1\n(no text)\n#bp set range#", doc.content_markdown)
self.assertEqual("Generated from 3 messages.", doc.structured_payload.get("annotation")) self.assertEqual(
"Generated from 3 messages.", doc.structured_payload.get("annotation")
)

View File

@@ -55,7 +55,9 @@ class ClaudeCLITaskProviderTests(SimpleTestCase):
@patch("core.tasks.providers.claude_cli.subprocess.run") @patch("core.tasks.providers.claude_cli.subprocess.run")
def test_timeout_maps_to_failed_result(self, run_mock): def test_timeout_maps_to_failed_result(self, run_mock):
run_mock.side_effect = TimeoutExpired(cmd=["claude"], timeout=10) run_mock.side_effect = TimeoutExpired(cmd=["claude"], timeout=10)
result = self.provider.append_update({"command": "claude", "timeout_seconds": 10}, {"task_id": "t1"}) result = self.provider.append_update(
{"command": "claude", "timeout_seconds": 10}, {"task_id": "t1"}
)
self.assertFalse(result.ok) self.assertFalse(result.ok)
self.assertIn("timeout", result.error) self.assertIn("timeout", result.error)
@@ -70,7 +72,9 @@ class ClaudeCLITaskProviderTests(SimpleTestCase):
result = self.provider.append_update({"command": "claude"}, {"task_id": "t1"}) result = self.provider.append_update({"command": "claude"}, {"task_id": "t1"})
self.assertTrue(result.ok) self.assertTrue(result.ok)
self.assertTrue(bool((result.payload or {}).get("requires_approval"))) self.assertTrue(bool((result.payload or {}).get("requires_approval")))
self.assertEqual("requires_approval", (result.payload or {}).get("parsed_status")) self.assertEqual(
"requires_approval", (result.payload or {}).get("parsed_status")
)
@patch("core.tasks.providers.claude_cli.subprocess.run") @patch("core.tasks.providers.claude_cli.subprocess.run")
def test_retries_with_positional_op_when_flag_unsupported(self, run_mock): def test_retries_with_positional_op_when_flag_unsupported(self, run_mock):
@@ -99,7 +103,9 @@ class ClaudeCLITaskProviderTests(SimpleTestCase):
self.assertEqual(["claude", "task-sync", "create"], second[:3]) self.assertEqual(["claude", "task-sync", "create"], second[:3])
@patch("core.tasks.providers.claude_cli.subprocess.run") @patch("core.tasks.providers.claude_cli.subprocess.run")
def test_falls_back_to_builtin_approval_stub_when_no_task_sync_contract(self, run_mock): def test_falls_back_to_builtin_approval_stub_when_no_task_sync_contract(
self, run_mock
):
run_mock.side_effect = [ run_mock.side_effect = [
CompletedProcess( CompletedProcess(
args=[], args=[],
@@ -124,8 +130,13 @@ class ClaudeCLITaskProviderTests(SimpleTestCase):
) )
self.assertTrue(result.ok) self.assertTrue(result.ok)
self.assertTrue(bool((result.payload or {}).get("requires_approval"))) self.assertTrue(bool((result.payload or {}).get("requires_approval")))
self.assertEqual("requires_approval", str((result.payload or {}).get("status") or "")) self.assertEqual(
self.assertEqual("builtin_task_sync_stub", str((result.payload or {}).get("fallback_mode") or "")) "requires_approval", str((result.payload or {}).get("status") or "")
)
self.assertEqual(
"builtin_task_sync_stub",
str((result.payload or {}).get("fallback_mode") or ""),
)
@patch("core.tasks.providers.claude_cli.subprocess.run") @patch("core.tasks.providers.claude_cli.subprocess.run")
def test_builtin_stub_approval_response_returns_ok(self, run_mock): def test_builtin_stub_approval_response_returns_ok(self, run_mock):

View File

@@ -8,10 +8,10 @@ from core.commands.engine import process_inbound_message
from core.commands.handlers.claude import parse_claude_command from core.commands.handlers.claude import parse_claude_command
from core.models import ( from core.models import (
ChatSession, ChatSession,
CommandChannelBinding,
CommandProfile,
CodexPermissionRequest, CodexPermissionRequest,
CodexRun, CodexRun,
CommandChannelBinding,
CommandProfile,
DerivedTask, DerivedTask,
ExternalSyncEvent, ExternalSyncEvent,
Message, Message,
@@ -45,7 +45,9 @@ class ClaudeCommandParserTests(TestCase):
class ClaudeCommandExecutionTests(TestCase): class ClaudeCommandExecutionTests(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create_user("claude-cmd-user", "claude-cmd@example.com", "x") self.user = User.objects.create_user(
"claude-cmd-user", "claude-cmd@example.com", "x"
)
self.person = Person.objects.create(user=self.user, name="Claude Cmd") self.person = Person.objects.create(user=self.user, name="Claude Cmd")
self.identifier = PersonIdentifier.objects.create( self.identifier = PersonIdentifier.objects.create(
user=self.user, user=self.user,
@@ -53,7 +55,9 @@ class ClaudeCommandExecutionTests(TestCase):
service="web", service="web",
identifier="web-chan-1", identifier="web-chan-1",
) )
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier) self.session = ChatSession.objects.create(
user=self.user, identifier=self.identifier
)
self.project = TaskProject.objects.create(user=self.user, name="Project A") self.project = TaskProject.objects.create(user=self.user, name="Project A")
self.task = DerivedTask.objects.create( self.task = DerivedTask.objects.create(
user=self.user, user=self.user,
@@ -202,7 +206,9 @@ class ClaudeCommandExecutionTests(TestCase):
channel_identifier="approver-chan", channel_identifier="approver-chan",
enabled=True, enabled=True,
) )
trigger = self._msg("#claude approve cl-ak-123#", source_chat_id="approver-chan") trigger = self._msg(
"#claude approve cl-ak-123#", source_chat_id="approver-chan"
)
results = async_to_sync(process_inbound_message)( results = async_to_sync(process_inbound_message)(
CommandContext( CommandContext(
service="web", service="web",

View File

@@ -55,7 +55,9 @@ class CodexCLITaskProviderTests(SimpleTestCase):
@patch("core.tasks.providers.codex_cli.subprocess.run") @patch("core.tasks.providers.codex_cli.subprocess.run")
def test_timeout_maps_to_failed_result(self, run_mock): def test_timeout_maps_to_failed_result(self, run_mock):
run_mock.side_effect = TimeoutExpired(cmd=["codex"], timeout=10) run_mock.side_effect = TimeoutExpired(cmd=["codex"], timeout=10)
result = self.provider.append_update({"command": "codex", "timeout_seconds": 10}, {"task_id": "t1"}) result = self.provider.append_update(
{"command": "codex", "timeout_seconds": 10}, {"task_id": "t1"}
)
self.assertFalse(result.ok) self.assertFalse(result.ok)
self.assertIn("timeout", result.error) self.assertIn("timeout", result.error)
@@ -70,7 +72,9 @@ class CodexCLITaskProviderTests(SimpleTestCase):
result = self.provider.append_update({"command": "codex"}, {"task_id": "t1"}) result = self.provider.append_update({"command": "codex"}, {"task_id": "t1"})
self.assertTrue(result.ok) self.assertTrue(result.ok)
self.assertTrue(bool((result.payload or {}).get("requires_approval"))) self.assertTrue(bool((result.payload or {}).get("requires_approval")))
self.assertEqual("requires_approval", (result.payload or {}).get("parsed_status")) self.assertEqual(
"requires_approval", (result.payload or {}).get("parsed_status")
)
@patch("core.tasks.providers.codex_cli.subprocess.run") @patch("core.tasks.providers.codex_cli.subprocess.run")
def test_retries_with_positional_op_when_flag_unsupported(self, run_mock): def test_retries_with_positional_op_when_flag_unsupported(self, run_mock):
@@ -99,7 +103,9 @@ class CodexCLITaskProviderTests(SimpleTestCase):
self.assertEqual(["codex", "task-sync", "create"], second[:3]) self.assertEqual(["codex", "task-sync", "create"], second[:3])
@patch("core.tasks.providers.codex_cli.subprocess.run") @patch("core.tasks.providers.codex_cli.subprocess.run")
def test_falls_back_to_builtin_approval_stub_when_no_task_sync_contract(self, run_mock): def test_falls_back_to_builtin_approval_stub_when_no_task_sync_contract(
self, run_mock
):
run_mock.side_effect = [ run_mock.side_effect = [
CompletedProcess( CompletedProcess(
args=[], args=[],
@@ -124,8 +130,13 @@ class CodexCLITaskProviderTests(SimpleTestCase):
) )
self.assertTrue(result.ok) self.assertTrue(result.ok)
self.assertTrue(bool((result.payload or {}).get("requires_approval"))) self.assertTrue(bool((result.payload or {}).get("requires_approval")))
self.assertEqual("requires_approval", str((result.payload or {}).get("status") or "")) self.assertEqual(
self.assertEqual("builtin_task_sync_stub", str((result.payload or {}).get("fallback_mode") or "")) "requires_approval", str((result.payload or {}).get("status") or "")
)
self.assertEqual(
"builtin_task_sync_stub",
str((result.payload or {}).get("fallback_mode") or ""),
)
@patch("core.tasks.providers.codex_cli.subprocess.run") @patch("core.tasks.providers.codex_cli.subprocess.run")
def test_builtin_stub_approval_response_returns_ok(self, run_mock): def test_builtin_stub_approval_response_returns_ok(self, run_mock):

View File

@@ -8,10 +8,10 @@ from core.commands.engine import process_inbound_message
from core.commands.handlers.codex import parse_codex_command from core.commands.handlers.codex import parse_codex_command
from core.models import ( from core.models import (
ChatSession, ChatSession,
CommandChannelBinding,
CommandProfile,
CodexPermissionRequest, CodexPermissionRequest,
CodexRun, CodexRun,
CommandChannelBinding,
CommandProfile,
DerivedTask, DerivedTask,
ExternalSyncEvent, ExternalSyncEvent,
Message, Message,
@@ -41,7 +41,9 @@ class CodexCommandParserTests(TestCase):
class CodexCommandExecutionTests(TestCase): class CodexCommandExecutionTests(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create_user("codex-cmd-user", "codex-cmd@example.com", "x") self.user = User.objects.create_user(
"codex-cmd-user", "codex-cmd@example.com", "x"
)
self.person = Person.objects.create(user=self.user, name="Codex Cmd") self.person = Person.objects.create(user=self.user, name="Codex Cmd")
self.identifier = PersonIdentifier.objects.create( self.identifier = PersonIdentifier.objects.create(
user=self.user, user=self.user,
@@ -49,7 +51,9 @@ class CodexCommandExecutionTests(TestCase):
service="web", service="web",
identifier="web-chan-1", identifier="web-chan-1",
) )
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier) self.session = ChatSession.objects.create(
user=self.user, identifier=self.identifier
)
self.project = TaskProject.objects.create(user=self.user, name="Project A") self.project = TaskProject.objects.create(user=self.user, name="Project A")
self.task = DerivedTask.objects.create( self.task = DerivedTask.objects.create(
user=self.user, user=self.user,
@@ -126,7 +130,10 @@ class CodexCommandExecutionTests(TestCase):
self.assertEqual("waiting_approval", run.status) self.assertEqual("waiting_approval", run.status)
event = ExternalSyncEvent.objects.order_by("-created_at").first() event = ExternalSyncEvent.objects.order_by("-created_at").first()
self.assertEqual("waiting_approval", event.status) self.assertEqual("waiting_approval", event.status)
self.assertEqual("default", str((event.payload or {}).get("provider_payload", {}).get("mode") or "")) self.assertEqual(
"default",
str((event.payload or {}).get("provider_payload", {}).get("mode") or ""),
)
self.assertTrue( self.assertTrue(
CodexPermissionRequest.objects.filter( CodexPermissionRequest.objects.filter(
user=self.user, user=self.user,
@@ -167,7 +174,10 @@ class CodexCommandExecutionTests(TestCase):
source_service="web", source_service="web",
source_channel="web-chan-1", source_channel="web-chan-1",
status="waiting_approval", status="waiting_approval",
request_payload={"action": "append_update", "provider_payload": {"task_id": str(self.task.id)}}, request_payload={
"action": "append_update",
"provider_payload": {"task_id": str(self.task.id)},
},
result_payload={}, result_payload={},
) )
req = CodexPermissionRequest.objects.create( req = CodexPermissionRequest.objects.create(
@@ -207,7 +217,9 @@ class CodexCommandExecutionTests(TestCase):
self.assertEqual("approved_waiting_resume", run.status) self.assertEqual("approved_waiting_resume", run.status)
self.assertEqual("ok", waiting_event.status) self.assertEqual("ok", waiting_event.status)
self.assertTrue( self.assertTrue(
ExternalSyncEvent.objects.filter(idempotency_key="codex_approval:ak-123:approved", status="pending").exists() ExternalSyncEvent.objects.filter(
idempotency_key="codex_approval:ak-123:approved", status="pending"
).exists()
) )
def test_approve_pre_submit_request_queues_original_action(self): def test_approve_pre_submit_request_queues_original_action(self):
@@ -226,7 +238,10 @@ class CodexCommandExecutionTests(TestCase):
source_service="web", source_service="web",
source_channel="web-chan-1", source_channel="web-chan-1",
status="waiting_approval", status="waiting_approval",
request_payload={"action": "append_update", "provider_payload": {"task_id": str(self.task.id)}}, request_payload={
"action": "append_update",
"provider_payload": {"task_id": str(self.task.id)},
},
result_payload={}, result_payload={},
) )
CodexPermissionRequest.objects.create( CodexPermissionRequest.objects.create(
@@ -264,7 +279,11 @@ class CodexCommandExecutionTests(TestCase):
) )
self.assertEqual(1, len(results)) self.assertEqual(1, len(results))
self.assertTrue(results[0].ok) self.assertTrue(results[0].ok)
resume = ExternalSyncEvent.objects.filter(idempotency_key="codex_cmd:resume:1").first() resume = ExternalSyncEvent.objects.filter(
idempotency_key="codex_cmd:resume:1"
).first()
self.assertIsNotNone(resume) self.assertIsNotNone(resume)
self.assertEqual("pending", resume.status) self.assertEqual("pending", resume.status)
self.assertEqual("append_update", str((resume.payload or {}).get("action") or "")) self.assertEqual(
"append_update", str((resume.payload or {}).get("action") or "")
)

View File

@@ -5,13 +5,22 @@ from unittest.mock import patch
from django.test import TestCase from django.test import TestCase
from core.management.commands.codex_worker import Command as CodexWorkerCommand from core.management.commands.codex_worker import Command as CodexWorkerCommand
from core.models import CodexPermissionRequest, CodexRun, ExternalSyncEvent, TaskProject, TaskProviderConfig, User from core.models import (
CodexPermissionRequest,
CodexRun,
ExternalSyncEvent,
TaskProject,
TaskProviderConfig,
User,
)
from core.tasks.providers.base import ProviderResult from core.tasks.providers.base import ProviderResult
class CodexWorkerPhase1Tests(TestCase): class CodexWorkerPhase1Tests(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create_user("codex-worker-user", "codex-worker@example.com", "x") self.user = User.objects.create_user(
"codex-worker-user", "codex-worker@example.com", "x"
)
self.project = TaskProject.objects.create(user=self.user, name="Worker Project") self.project = TaskProject.objects.create(user=self.user, name="Worker Project")
self.cfg = TaskProviderConfig.objects.create( self.cfg = TaskProviderConfig.objects.create(
user=self.user, user=self.user,
@@ -57,7 +66,9 @@ class CodexWorkerPhase1Tests(TestCase):
run_in_worker = True run_in_worker = True
def append_update(self, config, payload): def append_update(self, config, payload):
return ProviderResult(ok=True, payload={"status": "ok", "summary": "done"}) return ProviderResult(
ok=True, payload={"status": "ok", "summary": "done"}
)
create_task = mark_complete = link_task = append_update create_task = mark_complete = link_task = append_update
@@ -71,7 +82,9 @@ class CodexWorkerPhase1Tests(TestCase):
self.assertEqual("done", str(run.result_payload.get("summary") or "")) self.assertEqual("done", str(run.result_payload.get("summary") or ""))
@patch("core.management.commands.codex_worker.get_provider") @patch("core.management.commands.codex_worker.get_provider")
def test_requires_approval_moves_to_waiting_and_creates_permission_request(self, get_provider_mock): def test_requires_approval_moves_to_waiting_and_creates_permission_request(
self, get_provider_mock
):
run = CodexRun.objects.create( run = CodexRun.objects.create(
user=self.user, user=self.user,
project=self.project, project=self.project,
@@ -128,7 +141,10 @@ class CodexWorkerPhase1Tests(TestCase):
user=self.user, user=self.user,
provider="codex_cli", provider="codex_cli",
status="waiting_approval", status="waiting_approval",
payload={"action": "append_update", "provider_payload": {"mode": "default"}}, payload={
"action": "append_update",
"provider_payload": {"mode": "default"},
},
error="", error="",
) )
run = CodexRun.objects.create( run = CodexRun.objects.create(
@@ -169,7 +185,9 @@ class CodexWorkerPhase1Tests(TestCase):
run_in_worker = True run_in_worker = True
def append_update(self, config, payload): def append_update(self, config, payload):
return ProviderResult(ok=True, payload={"status": "ok", "summary": "resumed"}) return ProviderResult(
ok=True, payload={"status": "ok", "summary": "resumed"}
)
create_task = mark_complete = link_task = append_update create_task = mark_complete = link_task = append_update

View File

@@ -89,7 +89,9 @@ class CommandSecurityPolicyTests(TestCase):
) )
self.assertEqual(1, len(results)) self.assertEqual(1, len(results))
self.assertEqual("skipped", results[0].status) self.assertEqual("skipped", results[0].status)
self.assertTrue(str(results[0].error).startswith("policy_denied:service_not_allowed")) self.assertTrue(
str(results[0].error).startswith("policy_denied:service_not_allowed")
)
def test_gateway_scope_can_require_trusted_omemo_key(self): def test_gateway_scope_can_require_trusted_omemo_key(self):
CommandSecurityPolicy.objects.create( CommandSecurityPolicy.objects.create(
@@ -120,7 +122,9 @@ class CommandSecurityPolicyTests(TestCase):
channel_identifier="policy-user@zm.is", channel_identifier="policy-user@zm.is",
sender_identifier="policy-user@zm.is/phone", sender_identifier="policy-user@zm.is/phone",
message_text=".tasks list", message_text=".tasks list",
message_meta={"xmpp": {"omemo_status": "detected", "omemo_client_key": "sid:abc"}}, message_meta={
"xmpp": {"omemo_status": "detected", "omemo_client_key": "sid:abc"}
},
payload={}, payload={},
), ),
routes=[ routes=[

View File

@@ -9,8 +9,8 @@ from core.commands.base import CommandContext
from core.commands.handlers.bp import BPCommandHandler from core.commands.handlers.bp import BPCommandHandler
from core.commands.policies import ensure_variant_policies_for_profile from core.commands.policies import ensure_variant_policies_for_profile
from core.models import ( from core.models import (
BusinessPlanDocument,
AI, AI,
BusinessPlanDocument,
ChatSession, ChatSession,
CommandAction, CommandAction,
CommandChannelBinding, CommandChannelBinding,
@@ -37,7 +37,9 @@ class CommandVariantPolicyTests(TransactionTestCase):
service="whatsapp", service="whatsapp",
identifier="120363402761690215", identifier="120363402761690215",
) )
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier) self.session = ChatSession.objects.create(
user=self.user, identifier=self.identifier
)
self.profile = CommandProfile.objects.create( self.profile = CommandProfile.objects.create(
user=self.user, user=self.user,
slug="bp", slug="bp",
@@ -109,7 +111,9 @@ class CommandVariantPolicyTests(TransactionTestCase):
def test_bp_primary_can_run_in_verbatim_mode_without_ai(self): def test_bp_primary_can_run_in_verbatim_mode_without_ai(self):
ensure_variant_policies_for_profile(self.profile) ensure_variant_policies_for_profile(self.profile)
policy = CommandVariantPolicy.objects.get(profile=self.profile, variant_key="bp") policy = CommandVariantPolicy.objects.get(
profile=self.profile, variant_key="bp"
)
policy.generation_mode = "verbatim" policy.generation_mode = "verbatim"
policy.send_plan_to_egress = False policy.send_plan_to_egress = False
policy.send_status_to_source = False policy.send_status_to_source = False
@@ -143,7 +147,9 @@ class CommandVariantPolicyTests(TransactionTestCase):
def test_bp_set_ai_mode_ignores_template(self): def test_bp_set_ai_mode_ignores_template(self):
ensure_variant_policies_for_profile(self.profile) ensure_variant_policies_for_profile(self.profile)
policy = CommandVariantPolicy.objects.get(profile=self.profile, variant_key="bp_set") policy = CommandVariantPolicy.objects.get(
profile=self.profile, variant_key="bp_set"
)
policy.generation_mode = "ai" policy.generation_mode = "ai"
policy.send_plan_to_egress = False policy.send_plan_to_egress = False
policy.send_status_to_source = False policy.send_status_to_source = False
@@ -222,4 +228,6 @@ class CommandVariantPolicyTests(TransactionTestCase):
self.assertTrue(result.ok) self.assertTrue(result.ok)
source_status.assert_awaited() source_status.assert_awaited()
self.assertEqual(1, binding_send.await_count) self.assertEqual(1, binding_send.await_count)
self.assertFalse(BusinessPlanDocument.objects.filter(trigger_message=trigger).exists()) self.assertFalse(
BusinessPlanDocument.objects.filter(trigger_message=trigger).exists()
)

View File

@@ -13,7 +13,9 @@ class ComposeReactTests(TestCase):
self.user = User.objects.create_user("compose-react", "react@example.com", "pw") self.user = User.objects.create_user("compose-react", "react@example.com", "pw")
self.client.force_login(self.user) self.client.force_login(self.user)
def _build_message(self, *, service: str, identifier: str, source_message_id: str = ""): def _build_message(
self, *, service: str, identifier: str, source_message_id: str = ""
):
person = Person.objects.create(user=self.user, name=f"{service} person") person = Person.objects.create(user=self.user, name=f"{service} person")
person_identifier = PersonIdentifier.objects.create( person_identifier = PersonIdentifier.objects.create(
user=self.user, user=self.user,

View File

@@ -6,6 +6,7 @@ Signal coverage is in test_signal_reply_send.py. This file fills the gaps
for WhatsApp and XMPP, and verifies the shared reply_sync infrastructure for WhatsApp and XMPP, and verifies the shared reply_sync infrastructure
works correctly for both services. works correctly for both services.
""" """
from __future__ import annotations from __future__ import annotations
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
@@ -25,11 +26,11 @@ from core.messaging import history, reply_sync
from core.models import ChatSession, Message, Person, PersonIdentifier, User from core.models import ChatSession, Message, Person, PersonIdentifier, User
from core.presence.inference import now_ms from core.presence.inference import now_ms
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helpers # Helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _fake_stanza(xml_text: str) -> SimpleNamespace: def _fake_stanza(xml_text: str) -> SimpleNamespace:
"""Minimal stanza-like object with an .xml attribute.""" """Minimal stanza-like object with an .xml attribute."""
return SimpleNamespace(xml=ET.fromstring(xml_text)) return SimpleNamespace(xml=ET.fromstring(xml_text))
@@ -39,6 +40,7 @@ def _fake_stanza(xml_text: str) -> SimpleNamespace:
# WhatsApp — reply extraction (pure, no DB) # WhatsApp — reply extraction (pure, no DB)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class WhatsAppReplyExtractionTests(SimpleTestCase): class WhatsAppReplyExtractionTests(SimpleTestCase):
def test_extract_reply_ref_from_contextinfo_stanza_id(self): def test_extract_reply_ref_from_contextinfo_stanza_id(self):
payload = { payload = {
@@ -87,6 +89,7 @@ class WhatsAppReplyExtractionTests(SimpleTestCase):
# WhatsApp — reply resolution (requires DB) # WhatsApp — reply resolution (requires DB)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class WhatsAppReplyResolutionTests(TestCase): class WhatsAppReplyResolutionTests(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create_user( self.user = User.objects.create_user(
@@ -178,7 +181,9 @@ class WhatsAppReplyResolutionTests(TestCase):
) )
self.anchor.refresh_from_db() self.anchor.refresh_from_db()
reactions = list((self.anchor.receipt_payload or {}).get("reactions") or []) reactions = list((self.anchor.receipt_payload or {}).get("reactions") or [])
removed = [r for r in reactions if r.get("emoji") == "👍" and not r.get("removed")] removed = [
r for r in reactions if r.get("emoji") == "👍" and not r.get("removed")
]
self.assertEqual(0, len(removed)) self.assertEqual(0, len(removed))
@@ -186,6 +191,7 @@ class WhatsAppReplyResolutionTests(TestCase):
# WhatsApp — outbound reply metadata # WhatsApp — outbound reply metadata
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class WhatsAppOutboundReplyTests(TestCase): class WhatsAppOutboundReplyTests(TestCase):
def test_transport_passes_reply_metadata_to_whatsapp_api(self): def test_transport_passes_reply_metadata_to_whatsapp_api(self):
mock_client = MagicMock() mock_client = MagicMock()
@@ -222,6 +228,7 @@ class WhatsAppOutboundReplyTests(TestCase):
# XMPP — reaction extraction (pure, no DB) # XMPP — reaction extraction (pure, no DB)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class XMPPReactionExtractionTests(SimpleTestCase): class XMPPReactionExtractionTests(SimpleTestCase):
def test_extract_xep_0444_reaction(self): def test_extract_xep_0444_reaction(self):
stanza = _fake_stanza( stanza = _fake_stanza(
@@ -276,6 +283,7 @@ class XMPPReactionExtractionTests(SimpleTestCase):
# XMPP — reply extraction (pure, no DB) # XMPP — reply extraction (pure, no DB)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class XMPPReplyExtractionTests(SimpleTestCase): class XMPPReplyExtractionTests(SimpleTestCase):
def test_extract_reply_target_id_from_xep_0461_stanza(self): def test_extract_reply_target_id_from_xep_0461_stanza(self):
stanza = _fake_stanza( stanza = _fake_stanza(
@@ -304,7 +312,9 @@ class XMPPReplyExtractionTests(SimpleTestCase):
self.assertEqual("user@zm.is/mobile", ref.get("reply_source_chat_id")) self.assertEqual("user@zm.is/mobile", ref.get("reply_source_chat_id"))
def test_extract_reply_ref_returns_empty_for_missing_id(self): def test_extract_reply_ref_returns_empty_for_missing_id(self):
ref = reply_sync.extract_reply_ref("xmpp", {"reply_source_chat_id": "user@zm.is"}) ref = reply_sync.extract_reply_ref(
"xmpp", {"reply_source_chat_id": "user@zm.is"}
)
self.assertEqual({}, ref) self.assertEqual({}, ref)
@@ -312,6 +322,7 @@ class XMPPReplyExtractionTests(SimpleTestCase):
# XMPP — reply resolution (requires DB) # XMPP — reply resolution (requires DB)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class XMPPReplyResolutionTests(TestCase): class XMPPReplyResolutionTests(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create_user( self.user = User.objects.create_user(

View File

@@ -1,5 +1,5 @@
from io import StringIO
import time import time
from io import StringIO
from django.core.management import call_command from django.core.management import call_command
from django.test import TestCase, override_settings from django.test import TestCase, override_settings
@@ -24,7 +24,9 @@ class EventProjectionShadowTests(TestCase):
service="signal", service="signal",
identifier="+15555550333", identifier="+15555550333",
) )
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier) self.session = ChatSession.objects.create(
user=self.user, identifier=self.identifier
)
def test_shadow_compare_has_zero_mismatch_when_projection_matches(self): def test_shadow_compare_has_zero_mismatch_when_projection_matches(self):
message = Message.objects.create( message = Message.objects.create(

View File

@@ -7,13 +7,13 @@ from django.test import TestCase, override_settings
from core.mcp.tools import execute_tool, tool_specs from core.mcp.tools import execute_tool, tool_specs
from core.models import ( from core.models import (
AIRequest, AIRequest,
DerivedTask,
DerivedTaskEvent,
MCPToolAuditLog, MCPToolAuditLog,
MemoryItem, MemoryItem,
TaskProject, TaskProject,
User, User,
WorkspaceConversation, WorkspaceConversation,
DerivedTask,
DerivedTaskEvent,
) )
@@ -80,9 +80,13 @@ class MCPToolTests(TestCase):
first_hit = (memory_payload.get("hits") or [{}])[0] first_hit = (memory_payload.get("hits") or [{}])[0]
self.assertEqual(str(self.memory.id), str(first_hit.get("memory_id"))) self.assertEqual(str(self.memory.id), str(first_hit.get("memory_id")))
list_payload = execute_tool("tasks.list", {"user_id": self.user.id, "limit": 10}) list_payload = execute_tool(
"tasks.list", {"user_id": self.user.id, "limit": 10}
)
self.assertEqual(1, int(list_payload.get("count") or 0)) self.assertEqual(1, int(list_payload.get("count") or 0))
self.assertEqual(str(self.task.id), str((list_payload.get("items") or [{}])[0].get("id"))) self.assertEqual(
str(self.task.id), str((list_payload.get("items") or [{}])[0].get("id"))
)
search_payload = execute_tool( search_payload = execute_tool(
"tasks.search", "tasks.search",
@@ -90,9 +94,13 @@ class MCPToolTests(TestCase):
) )
self.assertEqual(1, int(search_payload.get("count") or 0)) self.assertEqual(1, int(search_payload.get("count") or 0))
events_payload = execute_tool("tasks.events", {"task_id": str(self.task.id), "limit": 5}) events_payload = execute_tool(
"tasks.events", {"task_id": str(self.task.id), "limit": 5}
)
self.assertEqual(1, int(events_payload.get("count") or 0)) self.assertEqual(1, int(events_payload.get("count") or 0))
self.assertEqual("created", str((events_payload.get("items") or [{}])[0].get("event_type"))) self.assertEqual(
"created", str((events_payload.get("items") or [{}])[0].get("event_type"))
)
def test_memory_proposal_review_flow(self): def test_memory_proposal_review_flow(self):
propose_payload = execute_tool( propose_payload = execute_tool(
@@ -182,7 +190,9 @@ class MCPToolTests(TestCase):
"note": "Implemented wiki tooling.", "note": "Implemented wiki tooling.",
}, },
) )
self.assertEqual("progress", str((note_payload.get("event") or {}).get("event_type"))) self.assertEqual(
"progress", str((note_payload.get("event") or {}).get("event_type"))
)
artifact_payload = execute_tool( artifact_payload = execute_tool(
"tasks.link_artifact", "tasks.link_artifact",

View File

@@ -7,7 +7,13 @@ from django.core.management import call_command
from django.test import TestCase from django.test import TestCase
from django.utils import timezone from django.utils import timezone
from core.models import MemoryChangeRequest, MemoryItem, MessageEvent, User, WorkspaceConversation from core.models import (
MemoryChangeRequest,
MemoryItem,
MessageEvent,
User,
WorkspaceConversation,
)
class MemoryPipelineCommandTests(TestCase): class MemoryPipelineCommandTests(TestCase):
@@ -46,7 +52,9 @@ class MemoryPipelineCommandTests(TestCase):
self.assertIn("memory-suggest-from-messages", rendered) self.assertIn("memory-suggest-from-messages", rendered)
self.assertGreaterEqual(MemoryItem.objects.filter(user=self.user).count(), 1) self.assertGreaterEqual(MemoryItem.objects.filter(user=self.user).count(), 1)
self.assertGreaterEqual( self.assertGreaterEqual(
MemoryChangeRequest.objects.filter(user=self.user, status="pending").count(), MemoryChangeRequest.objects.filter(
user=self.user, status="pending"
).count(),
1, 1,
) )

View File

@@ -6,10 +6,9 @@ from django.test import TestCase
from core.commands.base import CommandContext from core.commands.base import CommandContext
from core.commands.engine import _matches_trigger, process_inbound_message from core.commands.engine import _matches_trigger, process_inbound_message
from core.messaging.reply_sync import extract_reply_ref, resolve_reply_target from core.messaging.reply_sync import extract_reply_ref, resolve_reply_target
from core.views.compose import _command_options_for_channel
from core.models import ( from core.models import (
ChatTaskSource,
ChatSession, ChatSession,
ChatTaskSource,
CommandAction, CommandAction,
CommandChannelBinding, CommandChannelBinding,
CommandProfile, CommandProfile,
@@ -19,6 +18,7 @@ from core.models import (
PersonIdentifier, PersonIdentifier,
User, User,
) )
from core.views.compose import _command_options_for_channel
class Phase1ReplyResolutionTests(TestCase): class Phase1ReplyResolutionTests(TestCase):
@@ -402,7 +402,9 @@ class Phase1CommandEngineTests(TestCase):
if profile is None: if profile is None:
return return
self.assertEqual(3, CommandAction.objects.filter(profile=profile).count()) self.assertEqual(3, CommandAction.objects.filter(profile=profile).count())
self.assertEqual(3, CommandVariantPolicy.objects.filter(profile=profile).count()) self.assertEqual(
3, CommandVariantPolicy.objects.filter(profile=profile).count()
)
self.assertEqual( self.assertEqual(
2, 2,
CommandChannelBinding.objects.filter( CommandChannelBinding.objects.filter(
@@ -436,7 +438,9 @@ class Phase1CommandEngineTests(TestCase):
self.assertEqual(1, len(second_results)) self.assertEqual(1, len(second_results))
self.assertEqual("reply_required", second_results[0].error) self.assertEqual("reply_required", second_results[0].error)
self.assertEqual(3, CommandAction.objects.filter(profile=profile).count()) self.assertEqual(3, CommandAction.objects.filter(profile=profile).count())
self.assertEqual(3, CommandVariantPolicy.objects.filter(profile=profile).count()) self.assertEqual(
3, CommandVariantPolicy.objects.filter(profile=profile).count()
)
self.assertEqual( self.assertEqual(
2, 2,
CommandChannelBinding.objects.filter( CommandChannelBinding.objects.filter(

View File

@@ -21,7 +21,9 @@ from core.presence.inference import now_ms
class PresenceEngineTests(TestCase): class PresenceEngineTests(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create_user("presence-user", "presence@example.com", "x") self.user = User.objects.create_user(
"presence-user", "presence@example.com", "x"
)
self.person = Person.objects.create(user=self.user, name="Presence Person") self.person = Person.objects.create(user=self.user, name="Presence Person")
self.identifier = PersonIdentifier.objects.create( self.identifier = PersonIdentifier.objects.create(
user=self.user, user=self.user,
@@ -57,7 +59,9 @@ class PresenceEngineTests(TestCase):
) )
) )
self.assertIsNotNone(event) self.assertIsNotNone(event)
self.assertEqual(1, ContactAvailabilityEvent.objects.filter(user=self.user).count()) self.assertEqual(
1, ContactAvailabilityEvent.objects.filter(user=self.user).count()
)
self.assertEqual("available", event.availability_state) self.assertEqual("available", event.availability_state)
def test_inactivity_transitions_to_fading(self): def test_inactivity_transitions_to_fading(self):
@@ -106,7 +110,9 @@ class PresenceEngineTests(TestCase):
at_ts=base_ts + 60_000, at_ts=base_ts + 60_000,
) )
self.assertIsNone(fade_event) self.assertIsNone(fade_event)
self.assertEqual(1, ContactAvailabilityEvent.objects.filter(user=self.user).count()) self.assertEqual(
1, ContactAvailabilityEvent.objects.filter(user=self.user).count()
)
def test_adjacent_same_state_events_extend_single_span(self): def test_adjacent_same_state_events_extend_single_span(self):
ts0 = now_ms() ts0 = now_ms()
@@ -134,7 +140,9 @@ class PresenceEngineTests(TestCase):
ts=ts0 + 5_000, ts=ts0 + 5_000,
) )
) )
spans = list(ContactAvailabilitySpan.objects.filter(user=self.user).order_by("start_ts")) spans = list(
ContactAvailabilitySpan.objects.filter(user=self.user).order_by("start_ts")
)
self.assertEqual(1, len(spans)) self.assertEqual(1, len(spans))
self.assertEqual(ts0, spans[0].start_ts) self.assertEqual(ts0, spans[0].start_ts)
self.assertEqual(ts0 + 5_000, spans[0].end_ts) self.assertEqual(ts0 + 5_000, spans[0].end_ts)

View File

@@ -62,12 +62,21 @@ class ReactionNormalizationTests(TestCase):
self.assertEqual(str(exact_message.id), str(updated.id)) self.assertEqual(str(exact_message.id), str(updated.id))
exact_message.refresh_from_db() exact_message.refresh_from_db()
near_message.refresh_from_db() near_message.refresh_from_db()
self.assertEqual(1, len((exact_message.receipt_payload or {}).get("reactions") or [])) self.assertEqual(
1, len((exact_message.receipt_payload or {}).get("reactions") or [])
)
self.assertEqual( self.assertEqual(
"exact_source_message_id_ts", "exact_source_message_id_ts",
str((exact_message.receipt_payload or {}).get("reaction_last_match_strategy") or ""), str(
(exact_message.receipt_payload or {}).get(
"reaction_last_match_strategy"
)
or ""
),
)
self.assertEqual(
0, len((near_message.receipt_payload or {}).get("reactions") or [])
) )
self.assertEqual(0, len((near_message.receipt_payload or {}).get("reactions") or []))
def test_remove_without_emoji_is_audited_not_active(self): def test_remove_without_emoji_is_audited_not_active(self):
message = Message.objects.create( message = Message.objects.create(

View File

@@ -28,7 +28,9 @@ class ReconcileWorkspaceMetricHistoryCommandTests(TestCase):
service="whatsapp", service="whatsapp",
identifier="15551230000@s.whatsapp.net", identifier="15551230000@s.whatsapp.net",
) )
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier) self.session = ChatSession.objects.create(
user=self.user, identifier=self.identifier
)
base_ts = 1_700_000_000_000 base_ts = 1_700_000_000_000
for idx in range(10): for idx in range(10):
inbound = idx % 2 == 0 inbound = idx % 2 == 0

View File

@@ -1,7 +1,7 @@
from unittest.mock import patch from unittest.mock import patch
from django.urls import reverse
from django.test import TestCase from django.test import TestCase
from django.urls import reverse
from core.models import User from core.models import User

View File

@@ -1,8 +1,7 @@
from __future__ import annotations from __future__ import annotations
import json import json
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import Mock
from asgiref.sync import async_to_sync from asgiref.sync import async_to_sync
from django.conf import settings from django.conf import settings
@@ -175,11 +174,15 @@ class SignalInboundReplyLinkTests(TransactionTestCase):
} }
async_to_sync(client._process_raw_inbound_event)(json.dumps(payload)) async_to_sync(client._process_raw_inbound_event)(json.dumps(payload))
created = Message.objects.filter( created = (
Message.objects.filter(
user=self.user, user=self.user,
session=self.session, session=self.session,
text="reply inbound s3", text="reply inbound s3",
).order_by("-ts").first() )
.order_by("-ts")
.first()
)
self.assertIsNotNone(created) self.assertIsNotNone(created)
self.assertEqual(self.anchor.id, created.reply_to_id) self.assertEqual(self.anchor.id, created.reply_to_id)
self.assertEqual("1772545458187", created.reply_source_message_id) self.assertEqual("1772545458187", created.reply_source_message_id)
@@ -222,7 +225,9 @@ class SignalInboundReplyLinkTests(TransactionTestCase):
"Expected Signal heart reaction to be applied to anchor receipt payload.", "Expected Signal heart reaction to be applied to anchor receipt payload.",
) )
def test_process_raw_inbound_event_applies_sync_reaction_using_destination_fallback(self): def test_process_raw_inbound_event_applies_sync_reaction_using_destination_fallback(
self,
):
fake_ur = Mock() fake_ur = Mock()
fake_ur.message_received = AsyncMock(return_value=None) fake_ur.message_received = AsyncMock(return_value=None)
fake_ur.xmpp = Mock() fake_ur.xmpp = Mock()
@@ -253,7 +258,7 @@ class SignalInboundReplyLinkTests(TransactionTestCase):
"emoji": "🔥", "emoji": "🔥",
"targetSentTimestamp": 1772545458187, "targetSentTimestamp": 1772545458187,
} }
} },
} }
}, },
} }
@@ -352,7 +357,9 @@ class SignalRuntimeCommandWritebackTests(TestCase):
service="signal", service="signal",
identifier="+15550003000", identifier="+15550003000",
) )
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier) self.session = ChatSession.objects.create(
user=self.user, identifier=self.identifier
)
self.message = Message.objects.create( self.message = Message.objects.create(
user=self.user, user=self.user,
session=self.session, session=self.session,

View File

@@ -5,7 +5,6 @@ from unittest.mock import AsyncMock, patch
from asgiref.sync import async_to_sync from asgiref.sync import async_to_sync
from django.test import TestCase, override_settings from django.test import TestCase, override_settings
from django.utils import timezone
from core.models import ( from core.models import (
ChatSession, ChatSession,
@@ -88,7 +87,9 @@ class TaskEnginePlan09Tests(TestCase):
service="signal", service="signal",
identifier="+15559001234", identifier="+15559001234",
) )
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier) self.session = ChatSession.objects.create(
user=self.user, identifier=self.identifier
)
self.project = TaskProject.objects.create(user=self.user, name="Plan09 Project") self.project = TaskProject.objects.create(user=self.user, name="Plan09 Project")
ChatTaskSource.objects.create( ChatTaskSource.objects.create(
user=self.user, user=self.user,
@@ -133,7 +134,9 @@ class TaskEnginePlan09Tests(TestCase):
async_to_sync(process_inbound_task_intelligence)(seed) async_to_sync(process_inbound_task_intelligence)(seed)
cmd = self._msg(".task list", ts=1002) cmd = self._msg(".task list", ts=1002)
async_to_sync(process_inbound_task_intelligence)(cmd) async_to_sync(process_inbound_task_intelligence)(cmd)
payloads = [str(call.kwargs.get("text") or "") for call in mocked_send.await_args_list] payloads = [
str(call.kwargs.get("text") or "") for call in mocked_send.await_args_list
]
self.assertTrue(any("open tasks" in row.lower() for row in payloads)) self.assertTrue(any("open tasks" in row.lower() for row in payloads))
@patch("core.tasks.engine.send_message_raw", new_callable=AsyncMock) @patch("core.tasks.engine.send_message_raw", new_callable=AsyncMock)
@@ -143,7 +146,9 @@ class TaskEnginePlan09Tests(TestCase):
task = DerivedTask.objects.get(origin_message=seed) task = DerivedTask.objects.get(origin_message=seed)
cmd = self._msg(f".task show #{task.reference_code}", ts=1004) cmd = self._msg(f".task show #{task.reference_code}", ts=1004)
async_to_sync(process_inbound_task_intelligence)(cmd) async_to_sync(process_inbound_task_intelligence)(cmd)
payloads = [str(call.kwargs.get("text") or "") for call in mocked_send.await_args_list] payloads = [
str(call.kwargs.get("text") or "") for call in mocked_send.await_args_list
]
self.assertTrue(any("deploy new version" in row.lower() for row in payloads)) self.assertTrue(any("deploy new version" in row.lower() for row in payloads))
self.assertTrue(any(str(task.reference_code) in row for row in payloads)) self.assertTrue(any(str(task.reference_code) in row for row in payloads))
@@ -157,9 +162,13 @@ class TaskEnginePlan09Tests(TestCase):
task.refresh_from_db() task.refresh_from_db()
self.assertEqual("completed", task.status_snapshot) self.assertEqual("completed", task.status_snapshot)
self.assertTrue( self.assertTrue(
DerivedTaskEvent.objects.filter(task=task, event_type="completion_marked").exists() DerivedTaskEvent.objects.filter(
task=task, event_type="completion_marked"
).exists()
) )
payloads = [str(call.kwargs.get("text") or "") for call in mocked_send.await_args_list] payloads = [
str(call.kwargs.get("text") or "") for call in mocked_send.await_args_list
]
self.assertTrue(any("completed" in row.lower() for row in payloads)) self.assertTrue(any("completed" in row.lower() for row in payloads))
def test_dot_task_complete_creates_audit_event(self): def test_dot_task_complete_creates_audit_event(self):
@@ -169,7 +178,9 @@ class TaskEnginePlan09Tests(TestCase):
with patch("core.tasks.engine.send_message_raw", new_callable=AsyncMock): with patch("core.tasks.engine.send_message_raw", new_callable=AsyncMock):
cmd = self._msg(f".task complete #{task.reference_code}", ts=1008) cmd = self._msg(f".task complete #{task.reference_code}", ts=1008)
async_to_sync(process_inbound_task_intelligence)(cmd) async_to_sync(process_inbound_task_intelligence)(cmd)
event = DerivedTaskEvent.objects.filter(task=task, event_type="completion_marked").first() event = DerivedTaskEvent.objects.filter(
task=task, event_type="completion_marked"
).first()
self.assertIsNotNone(event) self.assertIsNotNone(event)
self.assertIn("command", str(event.payload or {}).lower()) self.assertIn("command", str(event.payload or {}).lower())
@@ -185,7 +196,9 @@ class TaskEngineMemoryContextTests(TestCase):
service="whatsapp", service="whatsapp",
identifier="447700900001@s.whatsapp.net", identifier="447700900001@s.whatsapp.net",
) )
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier) self.session = ChatSession.objects.create(
user=self.user, identifier=self.identifier
)
self.project = TaskProject.objects.create(user=self.user, name="Mem Project") self.project = TaskProject.objects.create(user=self.user, name="Mem Project")
ChatTaskSource.objects.create( ChatTaskSource.objects.create(
user=self.user, user=self.user,
@@ -218,8 +231,16 @@ class TaskEngineMemoryContextTests(TestCase):
from core.models import CodexRun from core.models import CodexRun
m = self._msg("task: fix authentication bug", ts=2001) m = self._msg("task: fix authentication bug", ts=2001)
fake_memory = [{"id": "mem-1", "memory_kind": "fact", "content": {"text": "prefers short summaries"}}] fake_memory = [
with patch("core.tasks.engine.retrieve_memories_for_prompt", return_value=fake_memory): {
"id": "mem-1",
"memory_kind": "fact",
"content": {"text": "prefers short summaries"},
}
]
with patch(
"core.tasks.engine.retrieve_memories_for_prompt", return_value=fake_memory
):
async_to_sync(process_inbound_task_intelligence)(m) async_to_sync(process_inbound_task_intelligence)(m)
task = DerivedTask.objects.filter(origin_message=m).first() task = DerivedTask.objects.filter(origin_message=m).first()
self.assertIsNotNone(task) self.assertIsNotNone(task)
@@ -227,5 +248,7 @@ class TaskEngineMemoryContextTests(TestCase):
self.assertIsNotNone(run, "Expected CodexRun created for task") self.assertIsNotNone(run, "Expected CodexRun created for task")
provider_payload = (run.request_payload or {}).get("provider_payload") or {} provider_payload = (run.request_payload or {}).get("provider_payload") or {}
memory_context = provider_payload.get("memory_context") memory_context = provider_payload.get("memory_context")
self.assertIsNotNone(memory_context, "Expected memory_context in CodexRun provider payload") self.assertIsNotNone(
memory_context, "Expected memory_context in CodexRun provider payload"
)
self.assertEqual(1, len(memory_context)) self.assertEqual(1, len(memory_context))

View File

@@ -48,7 +48,9 @@ class TasksPagesManagementTests(TestCase):
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
project = TaskProject.objects.get(user=self.user, name="Ops") project = TaskProject.objects.get(user=self.user, name="Ops")
self.assertIsNotNone(project) self.assertIsNotNone(project)
self.assertFalse(ChatTaskSource.objects.filter(user=self.user, project=project).exists()) self.assertFalse(
ChatTaskSource.objects.filter(user=self.user, project=project).exists()
)
def test_tasks_hub_can_map_identifier_to_selected_project(self): def test_tasks_hub_can_map_identifier_to_selected_project(self):
project = TaskProject.objects.create(user=self.user, name="Mapped") project = TaskProject.objects.create(user=self.user, name="Mapped")
@@ -108,7 +110,9 @@ class TasksPagesManagementTests(TestCase):
follow=True, follow=True,
) )
self.assertEqual(200, delete_response.status_code) self.assertEqual(200, delete_response.status_code)
self.assertFalse(TaskEpic.objects.filter(project=project, name="Phase 1").exists()) self.assertFalse(
TaskEpic.objects.filter(project=project, name="Phase 1").exists()
)
def test_project_page_can_assign_and_clear_task_epic(self): def test_project_page_can_assign_and_clear_task_epic(self):
project = TaskProject.objects.create(user=self.user, name="Roadmap") project = TaskProject.objects.create(user=self.user, name="Roadmap")
@@ -179,9 +183,13 @@ class TasksPagesManagementTests(TestCase):
follow=True, follow=True,
) )
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
self.assertTrue(TaskEpic.objects.filter(project=project, name="Phase 2").exists()) self.assertTrue(
TaskEpic.objects.filter(project=project, name="Phase 2").exists()
)
self.assertTrue(mocked_send.await_count >= 1) self.assertTrue(mocked_send.await_count >= 1)
payloads = [str(call.kwargs.get("text") or "") for call in mocked_send.await_args_list] payloads = [
str(call.kwargs.get("text") or "") for call in mocked_send.await_args_list
]
self.assertTrue(any("whatsapp usage" in row.lower() for row in payloads)) self.assertTrue(any("whatsapp usage" in row.lower() for row in payloads))
self.assertTrue(any("add task to epic" in row.lower() for row in payloads)) self.assertTrue(any("add task to epic" in row.lower() for row in payloads))
@@ -266,7 +274,9 @@ class TasksPagesManagementTests(TestCase):
reference_code="2", reference_code="2",
status_snapshot="open", status_snapshot="open",
) )
response = self.client.get(reverse("tasks_project", kwargs={"project_id": str(project.id)})) response = self.client.get(
reverse("tasks_project", kwargs={"project_id": str(project.id)})
)
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
self.assertContains( self.assertContains(
response, response,
@@ -302,7 +312,9 @@ class TasksPagesManagementTests(TestCase):
payload={"source": "signal", "emoji": "❤️", "reason": "heart_reaction"}, payload={"source": "signal", "emoji": "❤️", "reason": "heart_reaction"},
) )
response = self.client.get(reverse("tasks_task", kwargs={"task_id": str(task.id)})) response = self.client.get(
reverse("tasks_task", kwargs={"task_id": str(task.id)})
)
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
self.assertContains(response, "View payload JSON") self.assertContains(response, "View payload JSON")
self.assertContains(response, "<strong>source</strong>: signal", html=True) self.assertContains(response, "<strong>source</strong>: signal", html=True)

View File

@@ -3,8 +3,8 @@ from __future__ import annotations
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
from asgiref.sync import async_to_sync from asgiref.sync import async_to_sync
from django.urls import reverse
from django.test import TestCase, override_settings from django.test import TestCase, override_settings
from django.urls import reverse
from core.models import ( from core.models import (
ChatSession, ChatSession,
@@ -12,24 +12,32 @@ from core.models import (
CodexPermissionRequest, CodexPermissionRequest,
CodexRun, CodexRun,
DerivedTask, DerivedTask,
ExternalSyncEvent,
ExternalChatLink, ExternalChatLink,
ExternalSyncEvent,
Message, Message,
Person, Person,
PersonIdentifier, PersonIdentifier,
TaskCompletionPattern, TaskCompletionPattern,
TaskProviderConfig,
TaskProject, TaskProject,
TaskProviderConfig,
User, User,
) )
from core.tasks.engine import process_inbound_task_intelligence from core.tasks.engine import process_inbound_task_intelligence
from core.views.compose import _command_options_for_channel, _toggle_task_announce_for_channel from core.views.compose import (
from core.views.tasks import _apply_safe_defaults_for_user, _ensure_default_completion_patterns _command_options_for_channel,
_toggle_task_announce_for_channel,
)
from core.views.tasks import (
_apply_safe_defaults_for_user,
_ensure_default_completion_patterns,
)
class TaskSettingsBackfillTests(TestCase): class TaskSettingsBackfillTests(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create_user("defaults-user", "defaults@example.com", "x") self.user = User.objects.create_user(
"defaults-user", "defaults@example.com", "x"
)
self.person = Person.objects.create(user=self.user, name="Defaults Person") self.person = Person.objects.create(user=self.user, name="Defaults Person")
self.identifier = PersonIdentifier.objects.create( self.identifier = PersonIdentifier.objects.create(
user=self.user, user=self.user,
@@ -67,7 +75,9 @@ class TaskSettingsBackfillTests(TestCase):
self.source.refresh_from_db() self.source.refresh_from_db()
self.assertEqual("strict", self.project.settings.get("match_mode")) self.assertEqual("strict", self.project.settings.get("match_mode"))
self.assertTrue(bool(self.project.settings.get("require_prefix"))) self.assertTrue(bool(self.project.settings.get("require_prefix")))
self.assertEqual(["task:", "todo:"], self.project.settings.get("allowed_prefixes")) self.assertEqual(
["task:", "todo:"], self.project.settings.get("allowed_prefixes")
)
self.assertFalse(bool(self.project.settings.get("announce_task_id"))) self.assertFalse(bool(self.project.settings.get("announce_task_id")))
self.assertEqual("strict", self.source.settings.get("match_mode")) self.assertEqual("strict", self.source.settings.get("match_mode"))
self.assertTrue(bool(self.source.settings.get("require_prefix"))) self.assertTrue(bool(self.source.settings.get("require_prefix")))
@@ -75,7 +85,9 @@ class TaskSettingsBackfillTests(TestCase):
def test_default_completion_phrases_seeded(self): def test_default_completion_phrases_seeded(self):
_ensure_default_completion_patterns(self.user) _ensure_default_completion_patterns(self.user)
phrases = set( phrases = set(
TaskCompletionPattern.objects.filter(user=self.user).values_list("phrase", flat=True) TaskCompletionPattern.objects.filter(user=self.user).values_list(
"phrase", flat=True
)
) )
self.assertTrue({"done", "completed", "fixed"}.issubset(phrases)) self.assertTrue({"done", "completed", "fixed"}.issubset(phrases))
@@ -136,8 +148,12 @@ class TaskAnnounceRuntimeTests(TestCase):
service="whatsapp", service="whatsapp",
identifier="120363402761690215@g.us", identifier="120363402761690215@g.us",
) )
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier) self.session = ChatSession.objects.create(
self.project = TaskProject.objects.create(user=self.user, name="Runtime Project") user=self.user, identifier=self.identifier
)
self.project = TaskProject.objects.create(
user=self.user, name="Runtime Project"
)
def _seed_source(self, announce_enabled: bool): def _seed_source(self, announce_enabled: bool):
return ChatTaskSource.objects.create( return ChatTaskSource.objects.create(
@@ -167,22 +183,32 @@ class TaskAnnounceRuntimeTests(TestCase):
def test_no_announce_send_when_disabled(self): def test_no_announce_send_when_disabled(self):
self._seed_source(False) self._seed_source(False)
with patch("core.tasks.engine.send_message_raw", new=AsyncMock()) as mocked_send: with patch(
async_to_sync(process_inbound_task_intelligence)(self._msg("task: rotate secrets")) "core.tasks.engine.send_message_raw", new=AsyncMock()
) as mocked_send:
async_to_sync(process_inbound_task_intelligence)(
self._msg("task: rotate secrets")
)
self.assertTrue(DerivedTask.objects.exists()) self.assertTrue(DerivedTask.objects.exists())
mocked_send.assert_not_awaited() mocked_send.assert_not_awaited()
def test_announce_send_when_enabled(self): def test_announce_send_when_enabled(self):
self._seed_source(True) self._seed_source(True)
with patch("core.tasks.engine.send_message_raw", new=AsyncMock(return_value=True)) as mocked_send: with patch(
async_to_sync(process_inbound_task_intelligence)(self._msg("task: rotate secrets")) "core.tasks.engine.send_message_raw", new=AsyncMock(return_value=True)
) as mocked_send:
async_to_sync(process_inbound_task_intelligence)(
self._msg("task: rotate secrets")
)
self.assertTrue(DerivedTask.objects.exists()) self.assertTrue(DerivedTask.objects.exists())
mocked_send.assert_awaited() mocked_send.assert_awaited()
class TaskSettingsViewActionsTests(TestCase): class TaskSettingsViewActionsTests(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create_user("task-settings-user", "ts@example.com", "x") self.user = User.objects.create_user(
"task-settings-user", "ts@example.com", "x"
)
self.client.force_login(self.user) self.client.force_login(self.user)
self.project = TaskProject.objects.create(user=self.user, name="Project A") self.project = TaskProject.objects.create(user=self.user, name="Project A")
self.source = ChatTaskSource.objects.create( self.source = ChatTaskSource.objects.create(
@@ -214,7 +240,9 @@ class TaskSettingsViewActionsTests(TestCase):
@override_settings(TASK_DERIVATION_USE_AI=False) @override_settings(TASK_DERIVATION_USE_AI=False)
class TaskAutoBootstrapTests(TestCase): class TaskAutoBootstrapTests(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create_user("task-auto-user", "task-auto@example.com", "x") self.user = User.objects.create_user(
"task-auto-user", "task-auto@example.com", "x"
)
self.person = Person.objects.create(user=self.user, name="Bootstrap Chat") self.person = Person.objects.create(user=self.user, name="Bootstrap Chat")
self.identifier = PersonIdentifier.objects.create( self.identifier = PersonIdentifier.objects.create(
user=self.user, user=self.user,
@@ -222,7 +250,9 @@ class TaskAutoBootstrapTests(TestCase):
service="whatsapp", service="whatsapp",
identifier="120363402761690215@g.us", identifier="120363402761690215@g.us",
) )
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier) self.session = ChatSession.objects.create(
user=self.user, identifier=self.identifier
)
def test_task_message_auto_creates_project_and_source(self): def test_task_message_auto_creates_project_and_source(self):
msg = Message.objects.create( msg = Message.objects.create(
@@ -243,13 +273,17 @@ class TaskAutoBootstrapTests(TestCase):
enabled=True, enabled=True,
).first() ).first()
self.assertIsNotNone(source) self.assertIsNotNone(source)
self.assertTrue(TaskProject.objects.filter(user=self.user, id=source.project_id).exists()) self.assertTrue(
TaskProject.objects.filter(user=self.user, id=source.project_id).exists()
)
self.assertEqual(1, DerivedTask.objects.filter(user=self.user).count()) self.assertEqual(1, DerivedTask.objects.filter(user=self.user).count())
class TaskProjectDeleteGuardTests(TestCase): class TaskProjectDeleteGuardTests(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create_user("task-delete-user", "task-delete@example.com", "x") self.user = User.objects.create_user(
"task-delete-user", "task-delete@example.com", "x"
)
self.client.force_login(self.user) self.client.force_login(self.user)
self.project = TaskProject.objects.create(user=self.user, name="Delete Me") self.project = TaskProject.objects.create(user=self.user, name="Delete Me")
self.source = ChatTaskSource.objects.create( self.source = ChatTaskSource.objects.create(
@@ -271,7 +305,9 @@ class TaskProjectDeleteGuardTests(TestCase):
follow=True, follow=True,
) )
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
self.assertTrue(TaskProject.objects.filter(id=self.project.id, user=self.user).exists()) self.assertTrue(
TaskProject.objects.filter(id=self.project.id, user=self.user).exists()
)
def test_project_delete_reseeds_default_mapping(self): def test_project_delete_reseeds_default_mapping(self):
response = self.client.post( response = self.client.post(
@@ -284,7 +320,9 @@ class TaskProjectDeleteGuardTests(TestCase):
follow=True, follow=True,
) )
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
self.assertFalse(TaskProject.objects.filter(id=self.project.id, user=self.user).exists()) self.assertFalse(
TaskProject.objects.filter(id=self.project.id, user=self.user).exists()
)
self.assertTrue( self.assertTrue(
ChatTaskSource.objects.filter( ChatTaskSource.objects.filter(
user=self.user, user=self.user,
@@ -297,7 +335,9 @@ class TaskProjectDeleteGuardTests(TestCase):
class TaskHubEmptyProjectVisibilityTests(TestCase): class TaskHubEmptyProjectVisibilityTests(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create_user("task-hub-user", "task-hub@example.com", "x") self.user = User.objects.create_user(
"task-hub-user", "task-hub@example.com", "x"
)
self.client.force_login(self.user) self.client.force_login(self.user)
self.empty = TaskProject.objects.create(user=self.user, name="Empty") self.empty = TaskProject.objects.create(user=self.user, name="Empty")
self.used = TaskProject.objects.create(user=self.user, name="Used") self.used = TaskProject.objects.create(user=self.user, name="Used")
@@ -326,7 +366,9 @@ class TaskHubEmptyProjectVisibilityTests(TestCase):
class TaskSettingsExternalChatLinkScopeTests(TestCase): class TaskSettingsExternalChatLinkScopeTests(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create_user("task-link-user", "task-link@example.com", "x") self.user = User.objects.create_user(
"task-link-user", "task-link@example.com", "x"
)
self.client.force_login(self.user) self.client.force_login(self.user)
self.group_person = Person.objects.create(user=self.user, name="Scoped Group") self.group_person = Person.objects.create(user=self.user, name="Scoped Group")
self.group_identifier = PersonIdentifier.objects.create( self.group_identifier = PersonIdentifier.objects.create(
@@ -390,7 +432,9 @@ class TaskSettingsExternalChatLinkScopeTests(TestCase):
class CodexSettingsAndSubmitTests(TestCase): class CodexSettingsAndSubmitTests(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create_user("codex-settings-user", "codex-settings@example.com", "x") self.user = User.objects.create_user(
"codex-settings-user", "codex-settings@example.com", "x"
)
self.client.force_login(self.user) self.client.force_login(self.user)
self.project = TaskProject.objects.create(user=self.user, name="Codex Project") self.project = TaskProject.objects.create(user=self.user, name="Codex Project")
self.task = DerivedTask.objects.create( self.task = DerivedTask.objects.create(
@@ -426,7 +470,9 @@ class CodexSettingsAndSubmitTests(TestCase):
self.assertTrue(cfg.enabled) self.assertTrue(cfg.enabled)
self.assertEqual("team-a", str(cfg.settings.get("instance_label") or "")) self.assertEqual("team-a", str(cfg.settings.get("instance_label") or ""))
self.assertEqual("web", str(cfg.settings.get("approver_service") or "")) self.assertEqual("web", str(cfg.settings.get("approver_service") or ""))
self.assertEqual("approver-chan", str(cfg.settings.get("approver_identifier") or "")) self.assertEqual(
"approver-chan", str(cfg.settings.get("approver_identifier") or "")
)
def test_task_submit_endpoint_creates_codex_run_and_event(self): def test_task_submit_endpoint_creates_codex_run_and_event(self):
TaskProviderConfig.objects.create( TaskProviderConfig.objects.create(
@@ -444,10 +490,20 @@ class CodexSettingsAndSubmitTests(TestCase):
follow=True, follow=True,
) )
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
run = CodexRun.objects.filter(user=self.user, task=self.task).order_by("-created_at").first() run = (
CodexRun.objects.filter(user=self.user, task=self.task)
.order_by("-created_at")
.first()
)
self.assertIsNotNone(run) self.assertIsNotNone(run)
self.assertEqual("waiting_approval", str(getattr(run, "status", ""))) self.assertEqual("waiting_approval", str(getattr(run, "status", "")))
event = ExternalSyncEvent.objects.filter(user=self.user, task=self.task, provider="codex_cli").order_by("-created_at").first() event = (
ExternalSyncEvent.objects.filter(
user=self.user, task=self.task, provider="codex_cli"
)
.order_by("-created_at")
.first()
)
self.assertIsNotNone(event) self.assertIsNotNone(event)
self.assertEqual("waiting_approval", str(getattr(event, "status", ""))) self.assertEqual("waiting_approval", str(getattr(event, "status", "")))
self.assertTrue( self.assertTrue(
@@ -474,7 +530,10 @@ class CodexSettingsAndSubmitTests(TestCase):
source_service="web", source_service="web",
source_channel="web-chan-1", source_channel="web-chan-1",
status="waiting_approval", status="waiting_approval",
request_payload={"action": "append_update", "provider_payload": {"task_id": str(self.task.id)}}, request_payload={
"action": "append_update",
"provider_payload": {"task_id": str(self.task.id)},
},
result_payload={}, result_payload={},
) )
req = CodexPermissionRequest.objects.create( req = CodexPermissionRequest.objects.create(

View File

@@ -2,7 +2,11 @@ from asgiref.sync import async_to_sync
from django.test import SimpleTestCase from django.test import SimpleTestCase
from core.clients import transport from core.clients import transport
from core.transports.capabilities import capability_snapshot, supports, unsupported_reason from core.transports.capabilities import (
capability_snapshot,
supports,
unsupported_reason,
)
class TransportCapabilitiesTests(SimpleTestCase): class TransportCapabilitiesTests(SimpleTestCase):
@@ -11,7 +15,10 @@ class TransportCapabilitiesTests(SimpleTestCase):
def test_instagram_reactions_not_supported(self): def test_instagram_reactions_not_supported(self):
self.assertFalse(supports("instagram", "reactions")) self.assertFalse(supports("instagram", "reactions"))
self.assertIn("instagram does not support reactions", unsupported_reason("instagram", "reactions")) self.assertIn(
"instagram does not support reactions",
unsupported_reason("instagram", "reactions"),
)
def test_snapshot_has_schema_version(self): def test_snapshot_has_schema_version(self):
snapshot = capability_snapshot() snapshot = capability_snapshot()

View File

@@ -49,7 +49,9 @@ class WhatsAppReactionHandlingTests(TestCase):
service="whatsapp", service="whatsapp",
identifier="15551234567@s.whatsapp.net", identifier="15551234567@s.whatsapp.net",
) )
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier) self.session = ChatSession.objects.create(
user=self.user, identifier=self.identifier
)
self.base_ts = now_ms() self.base_ts = now_ms()
self.target = Message.objects.create( self.target = Message.objects.create(
user=self.user, user=self.user,
@@ -84,7 +86,9 @@ class WhatsAppReactionHandlingTests(TestCase):
parsed = self.client._extract_reaction_event(message_obj) parsed = self.client._extract_reaction_event(message_obj)
self.assertIsNotNone(parsed) self.assertIsNotNone(parsed)
self.assertEqual("wa-target-1", str(parsed.get("target_message_id") or "")) self.assertEqual("wa-target-1", str(parsed.get("target_message_id") or ""))
before_count = Message.objects.filter(user=self.user, session=self.session).count() before_count = Message.objects.filter(
user=self.user, session=self.session
).count()
async_to_sync(history.apply_reaction)( async_to_sync(history.apply_reaction)(
self.user, self.user,
self.identifier, self.identifier,
@@ -96,7 +100,9 @@ class WhatsAppReactionHandlingTests(TestCase):
remove=False, remove=False,
payload={"event": "reaction"}, payload={"event": "reaction"},
) )
after_count = Message.objects.filter(user=self.user, session=self.session).count() after_count = Message.objects.filter(
user=self.user, session=self.session
).count()
self.assertEqual(before_count, after_count) self.assertEqual(before_count, after_count)
self.target.refresh_from_db() self.target.refresh_from_db()
@@ -127,7 +133,9 @@ class RecalculateContactAvailabilityTests(TestCase):
service="whatsapp", service="whatsapp",
identifier="15557654321@s.whatsapp.net", identifier="15557654321@s.whatsapp.net",
) )
self.session = ChatSession.objects.create(user=self.user, identifier=self.identifier) self.session = ChatSession.objects.create(
user=self.user, identifier=self.identifier
)
self.base_ts = now_ms() self.base_ts = now_ms()
Message.objects.create( Message.objects.create(
@@ -168,19 +176,25 @@ class RecalculateContactAvailabilityTests(TestCase):
return events, spans return events, spans
def test_recalculate_is_deterministic_and_no_skew_on_rerun(self): def test_recalculate_is_deterministic_and_no_skew_on_rerun(self):
call_command("recalculate_contact_availability", "--days", "36500", "--limit", "500") call_command(
"recalculate_contact_availability", "--days", "36500", "--limit", "500"
)
first_events, first_spans = self._projection() first_events, first_spans = self._projection()
self.assertTrue(first_events) self.assertTrue(first_events)
self.assertTrue(first_spans) self.assertTrue(first_spans)
call_command("recalculate_contact_availability", "--days", "36500", "--limit", "500") call_command(
"recalculate_contact_availability", "--days", "36500", "--limit", "500"
)
second_events, second_spans = self._projection() second_events, second_spans = self._projection()
self.assertEqual(first_events, second_events) self.assertEqual(first_events, second_events)
self.assertEqual(first_spans, second_spans) self.assertEqual(first_spans, second_spans)
def test_recalculate_no_reset_does_not_duplicate(self): def test_recalculate_no_reset_does_not_duplicate(self):
call_command("recalculate_contact_availability", "--days", "36500", "--limit", "500") call_command(
"recalculate_contact_availability", "--days", "36500", "--limit", "500"
)
events_before = ContactAvailabilityEvent.objects.filter(user=self.user).count() events_before = ContactAvailabilityEvent.objects.filter(user=self.user).count()
spans_before = ContactAvailabilitySpan.objects.filter(user=self.user).count() spans_before = ContactAvailabilitySpan.objects.filter(user=self.user).count()

View File

@@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock from unittest.mock import MagicMock
from asgiref.sync import async_to_sync from asgiref.sync import async_to_sync
@@ -32,8 +31,12 @@ class _ApprovalProbe:
class XMPPGatewayApprovalCommandTests(TestCase): class XMPPGatewayApprovalCommandTests(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create_user("xmpp-approval-user", "xmpp-approval@example.com", "x") self.user = User.objects.create_user(
self.project = TaskProject.objects.create(user=self.user, name="Approval Project") "xmpp-approval-user", "xmpp-approval@example.com", "x"
)
self.project = TaskProject.objects.create(
user=self.user, name="Approval Project"
)
self.task = DerivedTask.objects.create( self.task = DerivedTask.objects.create(
user=self.user, user=self.user,
project=self.project, project=self.project,
@@ -59,7 +62,10 @@ class XMPPGatewayApprovalCommandTests(TestCase):
source_service="xmpp", source_service="xmpp",
source_channel="jews.zm.is", source_channel="jews.zm.is",
status="waiting_approval", status="waiting_approval",
request_payload={"action": "append_update", "provider_payload": {"task_id": str(self.task.id)}}, request_payload={
"action": "append_update",
"provider_payload": {"task_id": str(self.task.id)},
},
result_payload={}, result_payload={},
) )
self.request = CodexPermissionRequest.objects.create( self.request = CodexPermissionRequest.objects.create(
@@ -124,7 +130,9 @@ class XMPPGatewayApprovalCommandTests(TestCase):
class XMPPGatewayTasksCommandTests(TestCase): class XMPPGatewayTasksCommandTests(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create_user("xmpp-task-user", "xmpp-task@example.com", "x") self.user = User.objects.create_user(
"xmpp-task-user", "xmpp-task@example.com", "x"
)
self.project = TaskProject.objects.create(user=self.user, name="Task Project") self.project = TaskProject.objects.create(user=self.user, name="Task Project")
self.task = DerivedTask.objects.create( self.task = DerivedTask.objects.create(
user=self.user, user=self.user,

View File

@@ -10,6 +10,7 @@ mirroring exactly the flow a phone XMPP client uses:
Tests are skipped automatically when XMPP settings are absent (e.g. in CI Tests are skipped automatically when XMPP settings are absent (e.g. in CI
environments without a running stack). environments without a running stack).
""" """
from __future__ import annotations from __future__ import annotations
import base64 import base64
@@ -26,11 +27,11 @@ import xml.etree.ElementTree as ET
from django.conf import settings from django.conf import settings
from django.test import SimpleTestCase from django.test import SimpleTestCase
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helpers # Helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _xmpp_configured() -> bool: def _xmpp_configured() -> bool:
return bool( return bool(
getattr(settings, "XMPP_JID", None) getattr(settings, "XMPP_JID", None)
@@ -67,10 +68,21 @@ def _xmpp_domain() -> str:
def _prosody_auth_endpoint() -> str: def _prosody_auth_endpoint() -> str:
"""URL of the Django auth bridge that Prosody calls for c2s authentication.""" """URL of the Django auth bridge that Prosody calls for c2s authentication."""
return str(getattr(settings, "PROSODY_AUTH_ENDPOINT", "http://127.0.0.1:8090/internal/prosody/auth/")) return str(
getattr(
settings,
"PROSODY_AUTH_ENDPOINT",
"http://127.0.0.1:8090/internal/prosody/auth/",
)
)
def _recv_until(sock: socket.socket, patterns: list[bytes], timeout: float = 8.0, max_bytes: int = 16384) -> bytes: def _recv_until(
sock: socket.socket,
patterns: list[bytes],
timeout: float = 8.0,
max_bytes: int = 16384,
) -> bytes:
"""Read from sock until one of the byte patterns appears or timeout/max_bytes hit.""" """Read from sock until one of the byte patterns appears or timeout/max_bytes hit."""
buf = b"" buf = b""
deadline = time.monotonic() + timeout deadline = time.monotonic() + timeout
@@ -91,7 +103,9 @@ def _recv_until(sock: socket.socket, patterns: list[bytes], timeout: float = 8.0
return buf return buf
def _component_handshake(address: str, port: int, jid: str, secret: str, timeout: float = 5.0) -> tuple[bool, str]: def _component_handshake(
address: str, port: int, jid: str, secret: str, timeout: float = 5.0
) -> tuple[bool, str]:
""" """
Attempt an XEP-0114 external component handshake. Attempt an XEP-0114 external component handshake.
@@ -123,7 +137,9 @@ def _component_handshake(address: str, port: int, jid: str, secret: str, timeout
token = hashlib.sha1((stream_id + secret).encode()).hexdigest() token = hashlib.sha1((stream_id + secret).encode()).hexdigest()
sock.sendall(f"<handshake>{token}</handshake>".encode()) sock.sendall(f"<handshake>{token}</handshake>".encode())
response = _recv_until(sock, [b"<handshake", b"<stream:error"], timeout=timeout) response = _recv_until(
sock, [b"<handshake", b"<stream:error"], timeout=timeout
)
resp_text = response.decode(errors="replace") resp_text = response.decode(errors="replace")
if "<handshake/>" in resp_text or "<handshake />" in resp_text: if "<handshake/>" in resp_text or "<handshake />" in resp_text:
@@ -146,6 +162,7 @@ def _component_handshake(address: str, port: int, jid: str, secret: str, timeout
class _C2SResult: class _C2SResult:
"""Return value from _c2s_sasl_auth.""" """Return value from _c2s_sasl_auth."""
def __init__(self, success: bool, stage: str, detail: str): def __init__(self, success: bool, stage: str, detail: str):
self.success = success # True = SASL <success/> self.success = success # True = SASL <success/>
self.stage = stage # where we got to: tcp/starttls/tls/features/auth self.stage = stage # where we got to: tcp/starttls/tls/features/auth
@@ -201,19 +218,29 @@ def _c2s_sasl_auth(
raw.sendall(stream_open(domain)) raw.sendall(stream_open(domain))
# --- Receive pre-TLS features (expect <starttls>) --- # --- Receive pre-TLS features (expect <starttls>) ---
buf = _recv_until(raw, [b"</stream:features>", b"<stream:error"], timeout=timeout) buf = _recv_until(
raw, [b"</stream:features>", b"<stream:error"], timeout=timeout
)
text = buf.decode(errors="replace") text = buf.decode(errors="replace")
if "<stream:error" in text: if "<stream:error" in text:
return _C2SResult(False, "starttls", f"Stream error before features: {text[:200]}") return _C2SResult(
False, "starttls", f"Stream error before features: {text[:200]}"
)
if "starttls" not in text.lower(): if "starttls" not in text.lower():
return _C2SResult(False, "starttls", f"No <starttls> in pre-TLS features: {text[:300]}") return _C2SResult(
False, "starttls", f"No <starttls> in pre-TLS features: {text[:300]}"
)
# --- Negotiate STARTTLS --- # --- Negotiate STARTTLS ---
raw.sendall(f"<starttls xmlns='{NS_TLS}'/>".encode()) raw.sendall(f"<starttls xmlns='{NS_TLS}'/>".encode())
buf2 = _recv_until(raw, [b"<proceed", b"<failure"], timeout=timeout) buf2 = _recv_until(raw, [b"<proceed", b"<failure"], timeout=timeout)
text2 = buf2.decode(errors="replace") text2 = buf2.decode(errors="replace")
if "<proceed" not in text2: if "<proceed" not in text2:
return _C2SResult(False, "starttls", f"No <proceed/> after STARTTLS request: {text2[:200]}") return _C2SResult(
False,
"starttls",
f"No <proceed/> after STARTTLS request: {text2[:200]}",
)
# --- Upgrade to TLS --- # --- Upgrade to TLS ---
ctx = ssl.create_default_context() ctx = ssl.create_default_context()
@@ -223,7 +250,9 @@ def _c2s_sasl_auth(
try: try:
tls = ctx.wrap_socket(raw, server_hostname=domain) tls = ctx.wrap_socket(raw, server_hostname=domain)
except ssl.SSLCertVerificationError as exc: except ssl.SSLCertVerificationError as exc:
return _C2SResult(False, "tls", f"TLS cert verification failed for {domain!r}: {exc}") return _C2SResult(
False, "tls", f"TLS cert verification failed for {domain!r}: {exc}"
)
except ssl.SSLError as exc: except ssl.SSLError as exc:
return _C2SResult(False, "tls", f"TLS handshake error: {exc}") return _C2SResult(False, "tls", f"TLS handshake error: {exc}")
@@ -231,23 +260,35 @@ def _c2s_sasl_auth(
# --- Re-open stream over TLS --- # --- Re-open stream over TLS ---
tls.sendall(stream_open(domain)) tls.sendall(stream_open(domain))
buf3 = _recv_until(tls, [b"</stream:features>", b"<stream:error"], timeout=timeout) buf3 = _recv_until(
tls, [b"</stream:features>", b"<stream:error"], timeout=timeout
)
text3 = buf3.decode(errors="replace") text3 = buf3.decode(errors="replace")
if "<stream:error" in text3: if "<stream:error" in text3:
return _C2SResult(False, "features", f"Stream error after TLS: {text3[:200]}") return _C2SResult(
False, "features", f"Stream error after TLS: {text3[:200]}"
)
mechanisms = re.findall(r"<mechanism>([^<]+)</mechanism>", text3, re.IGNORECASE) mechanisms = re.findall(r"<mechanism>([^<]+)</mechanism>", text3, re.IGNORECASE)
if not mechanisms: if not mechanisms:
return _C2SResult(False, "features", f"No SASL mechanisms in post-TLS features: {text3[:300]}") return _C2SResult(
False,
"features",
f"No SASL mechanisms in post-TLS features: {text3[:300]}",
)
if "PLAIN" not in [m.upper() for m in mechanisms]: if "PLAIN" not in [m.upper() for m in mechanisms]:
return _C2SResult(False, "features", f"SASL PLAIN not offered; got: {mechanisms}") return _C2SResult(
False, "features", f"SASL PLAIN not offered; got: {mechanisms}"
)
# --- SASL PLAIN auth --- # --- SASL PLAIN auth ---
credential = base64.b64encode(f"\x00{username}\x00{password}".encode()).decode() credential = base64.b64encode(f"\x00{username}\x00{password}".encode()).decode()
tls.sendall( tls.sendall(
f"<auth xmlns='{NS_SASL}' mechanism='PLAIN'>{credential}</auth>".encode() f"<auth xmlns='{NS_SASL}' mechanism='PLAIN'>{credential}</auth>".encode()
) )
buf4 = _recv_until(tls, [b"<success", b"<failure", b"<stream:error"], timeout=timeout) buf4 = _recv_until(
tls, [b"<success", b"<failure", b"<stream:error"], timeout=timeout
)
text4 = buf4.decode(errors="replace") text4 = buf4.decode(errors="replace")
if "<success" in text4: if "<success" in text4:
@@ -256,7 +297,9 @@ def _c2s_sasl_auth(
# Extract the failure condition element name (e.g. not-authorized) # Extract the failure condition element name (e.g. not-authorized)
m = re.search(r"<failure[^>]*>\s*<([a-z-]+)", text4) m = re.search(r"<failure[^>]*>\s*<([a-z-]+)", text4)
condition = m.group(1) if m else "unknown" condition = m.group(1) if m else "unknown"
return _C2SResult(False, "auth", f"SASL PLAIN rejected: {condition}{text4[:200]}") return _C2SResult(
False, "auth", f"SASL PLAIN rejected: {condition}{text4[:200]}"
)
if "<stream:error" in text4: if "<stream:error" in text4:
return _C2SResult(False, "auth", f"Stream error during auth: {text4[:200]}") return _C2SResult(False, "auth", f"Stream error during auth: {text4[:200]}")
return _C2SResult(False, "auth", f"No auth response received: {text4[:200]}") return _C2SResult(False, "auth", f"No auth response received: {text4[:200]}")
@@ -272,6 +315,7 @@ def _c2s_sasl_auth(
# Component tests (XEP-0114) # Component tests (XEP-0114)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@unittest.skipUnless(_xmpp_configured(), "XMPP settings not configured") @unittest.skipUnless(_xmpp_configured(), "XMPP settings not configured")
class XMPPComponentTests(SimpleTestCase): class XMPPComponentTests(SimpleTestCase):
def test_component_port_reachable(self): def test_component_port_reachable(self):
@@ -309,6 +353,7 @@ class XMPPComponentTests(SimpleTestCase):
# Auth bridge tests (what Prosody calls to validate user passwords) # Auth bridge tests (what Prosody calls to validate user passwords)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@unittest.skipUnless(_xmpp_configured(), "XMPP settings not configured") @unittest.skipUnless(_xmpp_configured(), "XMPP settings not configured")
class XMPPAuthBridgeTests(SimpleTestCase): class XMPPAuthBridgeTests(SimpleTestCase):
""" """
@@ -320,7 +365,12 @@ class XMPPAuthBridgeTests(SimpleTestCase):
def _parse_endpoint(self): def _parse_endpoint(self):
url = _prosody_auth_endpoint() url = _prosody_auth_endpoint()
parsed = urllib.parse.urlparse(url) parsed = urllib.parse.urlparse(url)
return parsed.scheme, parsed.hostname, parsed.port or (443 if parsed.scheme == "https" else 80), parsed.path return (
parsed.scheme,
parsed.hostname,
parsed.port or (443 if parsed.scheme == "https" else 80),
parsed.path,
)
def test_auth_endpoint_tcp_reachable(self): def test_auth_endpoint_tcp_reachable(self):
"""Auth bridge port (8090) is listening inside the pod.""" """Auth bridge port (8090) is listening inside the pod."""
@@ -349,13 +399,19 @@ class XMPPAuthBridgeTests(SimpleTestCase):
except (ConnectionRefusedError, OSError) as exc: except (ConnectionRefusedError, OSError) as exc:
self.fail(f"Could not connect to auth bridge: {exc}") self.fail(f"Could not connect to auth bridge: {exc}")
# Should not return "1" (success) with wrong secret # Should not return "1" (success) with wrong secret
self.assertNotEqual(body, "1", f"Auth bridge accepted a request with wrong secret (body={body!r})") self.assertNotEqual(
body,
"1",
f"Auth bridge accepted a request with wrong secret (body={body!r})",
)
def test_auth_endpoint_isuser_returns_zero_or_one(self): def test_auth_endpoint_isuser_returns_zero_or_one(self):
"""Auth bridge responds with '0' or '1' for an isuser query (not an error page).""" """Auth bridge responds with '0' or '1' for an isuser query (not an error page)."""
secret = getattr(settings, "XMPP_SECRET", "") secret = getattr(settings, "XMPP_SECRET", "")
_, host, port, path = self._parse_endpoint() _, host, port, path = self._parse_endpoint()
query = f"?command=isuser%3Anonexistent%3Azm.is&secret={urllib.parse.quote(secret)}" query = (
f"?command=isuser%3Anonexistent%3Azm.is&secret={urllib.parse.quote(secret)}"
)
try: try:
conn = http.client.HTTPConnection(host, port, timeout=5) conn = http.client.HTTPConnection(host, port, timeout=5)
conn.request("GET", path + query) conn.request("GET", path + query)
@@ -364,13 +420,18 @@ class XMPPAuthBridgeTests(SimpleTestCase):
conn.close() conn.close()
except (ConnectionRefusedError, OSError) as exc: except (ConnectionRefusedError, OSError) as exc:
self.fail(f"Could not connect to auth bridge: {exc}") self.fail(f"Could not connect to auth bridge: {exc}")
self.assertIn(body, ("0", "1"), f"Unexpected auth bridge response {body!r} (expected '0' or '1')") self.assertIn(
body,
("0", "1"),
f"Unexpected auth bridge response {body!r} (expected '0' or '1')",
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# c2s (client-to-server) tests — mirrors the phone's XMPP connection flow # c2s (client-to-server) tests — mirrors the phone's XMPP connection flow
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@unittest.skipUnless(_xmpp_configured(), "XMPP settings not configured") @unittest.skipUnless(_xmpp_configured(), "XMPP settings not configured")
class XMPPClientAuthTests(SimpleTestCase): class XMPPClientAuthTests(SimpleTestCase):
""" """
@@ -400,23 +461,26 @@ class XMPPClientAuthTests(SimpleTestCase):
port = _xmpp_c2s_port() port = _xmpp_c2s_port()
domain = _xmpp_domain() domain = _xmpp_domain()
result = _c2s_sasl_auth( result = _c2s_sasl_auth(
address=addr, port=port, domain=domain, address=addr,
username="certcheck", password="certcheck", port=port,
verify_cert=True, timeout=10.0, domain=domain,
username="certcheck",
password="certcheck",
verify_cert=True,
timeout=10.0,
) )
# We only care that we got past TLS — a SASL failure at stage "auth" is fine. # We only care that we got past TLS — a SASL failure at stage "auth" is fine.
self.assertNotEqual( self.assertNotEqual(
result.stage, "tls", result.stage,
"tls",
f"TLS cert validation failed for domain {domain!r}: {result.detail}\n" f"TLS cert validation failed for domain {domain!r}: {result.detail}\n"
"Phone will see a certificate error — it cannot connect at all." "Phone will see a certificate error — it cannot connect at all.",
) )
self.assertNotEqual( self.assertNotEqual(
result.stage, "tcp", result.stage, "tcp", f"Could not reach c2s port at all: {result.detail}"
f"Could not reach c2s port at all: {result.detail}"
) )
self.assertNotEqual( self.assertNotEqual(
result.stage, "starttls", result.stage, "starttls", f"STARTTLS negotiation failed: {result.detail}"
f"STARTTLS negotiation failed: {result.detail}"
) )
def test_c2s_sasl_plain_offered(self): def test_c2s_sasl_plain_offered(self):
@@ -425,16 +489,21 @@ class XMPPClientAuthTests(SimpleTestCase):
port = _xmpp_c2s_port() port = _xmpp_c2s_port()
domain = _xmpp_domain() domain = _xmpp_domain()
result = _c2s_sasl_auth( result = _c2s_sasl_auth(
address=addr, port=port, domain=domain, address=addr,
username="saslcheck", password="saslcheck", port=port,
verify_cert=False, timeout=10.0, domain=domain,
username="saslcheck",
password="saslcheck",
verify_cert=False,
timeout=10.0,
) )
# We should reach the "auth" stage (SASL PLAIN was offered and we tried it). # We should reach the "auth" stage (SASL PLAIN was offered and we tried it).
# Reaching any earlier stage means SASL PLAIN wasn't offered or something broke. # Reaching any earlier stage means SASL PLAIN wasn't offered or something broke.
self.assertIn( self.assertIn(
result.stage, ("auth",), result.stage,
("auth",),
f"Did not reach SASL auth stage — stopped at {result.stage!r}: {result.detail}\n" f"Did not reach SASL auth stage — stopped at {result.stage!r}: {result.detail}\n"
"Check that allow_unencrypted_plain_auth = true in prosody config." "Check that allow_unencrypted_plain_auth = true in prosody config.",
) )
def test_c2s_invalid_credentials_rejected(self): def test_c2s_invalid_credentials_rejected(self):
@@ -450,23 +519,31 @@ class XMPPClientAuthTests(SimpleTestCase):
port = _xmpp_c2s_port() port = _xmpp_c2s_port()
domain = _xmpp_domain() domain = _xmpp_domain()
result = _c2s_sasl_auth( result = _c2s_sasl_auth(
address=addr, port=port, domain=domain, address=addr,
port=port,
domain=domain,
username="nobody_special", username="nobody_special",
password="definitely-wrong-password-xyz", password="definitely-wrong-password-xyz",
verify_cert=False, timeout=10.0, verify_cert=False,
timeout=10.0,
)
self.assertFalse(
result.success,
f"Expected auth failure for invalid creds but got success: {result}",
) )
self.assertFalse(result.success, f"Expected auth failure for invalid creds but got success: {result}")
self.assertEqual( self.assertEqual(
result.stage, "auth", result.stage,
"auth",
f"Auth failed at stage {result.stage!r} (expected 'auth' / not-authorized).\n" f"Auth failed at stage {result.stage!r} (expected 'auth' / not-authorized).\n"
f"Detail: {result.detail}\n" f"Detail: {result.detail}\n"
"This means Prosody cannot reach the Django auth bridge — " "This means Prosody cannot reach the Django auth bridge — "
"valid credentials would also fail. " "valid credentials would also fail. "
"Check that uWSGI has http-socket=127.0.0.1:8090 and the container is running." "Check that uWSGI has http-socket=127.0.0.1:8090 and the container is running.",
) )
self.assertIn( self.assertIn(
"not-authorized", result.detail, "not-authorized",
f"Expected 'not-authorized' failure, got: {result.detail}" result.detail,
f"Expected 'not-authorized' failure, got: {result.detail}",
) )
@unittest.skipUnless( @unittest.skipUnless(
@@ -479,17 +556,22 @@ class XMPPClientAuthTests(SimpleTestCase):
Skipped unless env vars are set — run manually to verify end-to-end login. Skipped unless env vars are set — run manually to verify end-to-end login.
""" """
import os import os
addr = _xmpp_address() addr = _xmpp_address()
port = _xmpp_c2s_port() port = _xmpp_c2s_port()
domain = _xmpp_domain() domain = _xmpp_domain()
username = os.environ["XMPP_TEST_USER"] username = os.environ["XMPP_TEST_USER"]
password = os.environ.get("XMPP_TEST_PASSWORD", "") password = os.environ.get("XMPP_TEST_PASSWORD", "")
result = _c2s_sasl_auth( result = _c2s_sasl_auth(
address=addr, port=port, domain=domain, address=addr,
username=username, password=password, port=port,
verify_cert=True, timeout=10.0, domain=domain,
username=username,
password=password,
verify_cert=True,
timeout=10.0,
) )
self.assertTrue( self.assertTrue(
result.success, result.success,
f"Login with XMPP_TEST_USER={username!r} failed at stage {result.stage!r}: {result.detail}" f"Login with XMPP_TEST_USER={username!r} failed at stage {result.stage!r}: {result.detail}",
) )

View File

@@ -1,27 +1,13 @@
import asyncio
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import MagicMock, patch from unittest.mock import AsyncMock, MagicMock
from asgiref.sync import async_to_sync from asgiref.sync import async_to_sync
from django.test import SimpleTestCase, TestCase, override_settings from django.test import SimpleTestCase, TestCase, override_settings
from core.clients import transport from core.clients.xmpp import ET, XMPPComponent, _extract_sender_omemo_client_key
from core.clients.xmpp import ET, XMPPClient, XMPPComponent, _extract_sender_omemo_client_key
from core.models import User, UserXmppOmemoState from core.models import User, UserXmppOmemoState
class _FakeComponent:
def __init__(self, *args, **kwargs):
self.plugins = []
self.loop = None
def register_plugin(self, name):
self.plugins.append(str(name))
def connect(self):
return True
@override_settings( @override_settings(
XMPP_JID="jews.zm.is", XMPP_JID="jews.zm.is",
XMPP_SECRET="secret", XMPP_SECRET="secret",
@@ -29,65 +15,12 @@ class _FakeComponent:
XMPP_PORT=8888, XMPP_PORT=8888,
) )
class XMPPOmemoSupportTests(SimpleTestCase): class XMPPOmemoSupportTests(SimpleTestCase):
def test_registers_xep_0384_when_omemo_plugin_available(self): def test_omemo_available_flag_set_correctly(self):
loop = asyncio.new_event_loop() """Test that _OMEMO_AVAILABLE is properly set based on import availability"""
try: from core.clients import xmpp
with patch("core.clients.xmpp.XMPPComponent", _FakeComponent):
with patch("core.clients.xmpp._omemo_plugin_available", return_value=True):
with patch("core.clients.xmpp._omemo_xep_0384_plugin_available", return_value=True):
with patch("core.clients.xmpp._load_omemo_plugin_module", return_value=True):
client = XMPPClient(SimpleNamespace(), loop, "xmpp")
self.assertIn("xep_0384", list(getattr(client.client, "plugins", [])))
self.assertTrue(bool(getattr(client, "_omemo_plugin_registered", False)))
finally:
loop.close()
def test_skips_xep_0384_when_omemo_plugin_unavailable(self): # Just verify the flag exists and is boolean
loop = asyncio.new_event_loop() self.assertIsInstance(xmpp._OMEMO_AVAILABLE, bool)
try:
with patch("core.clients.xmpp.XMPPComponent", _FakeComponent):
with patch("core.clients.xmpp._omemo_plugin_available", return_value=False):
with patch("core.clients.xmpp._omemo_xep_0384_plugin_available", return_value=False):
client = XMPPClient(SimpleNamespace(), loop, "xmpp")
self.assertNotIn("xep_0384", list(getattr(client.client, "plugins", [])))
self.assertFalse(bool(getattr(client, "_omemo_plugin_registered", False)))
finally:
loop.close()
def test_skips_xep_0384_when_only_slixmpp_omemo_package_exists(self):
loop = asyncio.new_event_loop()
try:
with patch("core.clients.xmpp.XMPPComponent", _FakeComponent):
with patch("core.clients.xmpp._omemo_plugin_available", return_value=True):
with patch("core.clients.xmpp._omemo_xep_0384_plugin_available", return_value=False):
client = XMPPClient(SimpleNamespace(), loop, "xmpp")
self.assertNotIn("xep_0384", list(getattr(client.client, "plugins", [])))
self.assertFalse(bool(getattr(client, "_omemo_plugin_registered", False)))
finally:
loop.close()
def test_bootstrap_logs_and_updates_runtime_state_with_fingerprint(self):
class _BootstrapProbe:
_derived_omemo_fingerprint = XMPPComponent._derived_omemo_fingerprint
component = _BootstrapProbe()
component.plugin = {}
component.log = MagicMock()
with patch.object(transport, "update_runtime_state") as update_state:
async_to_sync(XMPPComponent._bootstrap_omemo_for_authentic_channel)(component)
update_state.assert_called_once()
_, kwargs = update_state.call_args
self.assertEqual("jews.zm.is", kwargs.get("omemo_target_jid"))
self.assertEqual(
component._derived_omemo_fingerprint("jews.zm.is"),
kwargs.get("omemo_fingerprint"),
)
self.assertFalse(bool(kwargs.get("omemo_enabled")))
self.assertIn("omemo_status", kwargs)
self.assertIn("omemo_status_reason", kwargs)
self.assertTrue(component.log.info.called)
def test_extract_sender_omemo_client_key_from_encrypted_stanza(self): def test_extract_sender_omemo_client_key_from_encrypted_stanza(self):
stanza_xml = ET.fromstring( stanza_xml = ET.fromstring(
@@ -104,8 +37,10 @@ class XMPPOmemoSupportTests(SimpleTestCase):
class XMPPOmemoObservationPersistenceTests(TestCase): class XMPPOmemoObservationPersistenceTests(TestCase):
def test_records_latest_user_omemo_observation(self): def test_records_latest_user_omemo_observation(self):
user = User.objects.create_user("xmpp-omemo-user", "xmpp-omemo@example.com", "x") user = User.objects.create_user(
probe = SimpleNamespace(log=MagicMock()) "xmpp-omemo-user", "xmpp-omemo@example.com", "x"
)
xmpp_component = SimpleNamespace(log=MagicMock())
stanza_xml = ET.fromstring( stanza_xml = ET.fromstring(
"<message>" "<message>"
"<encrypted xmlns='eu.siacs.conversations.axolotl'>" "<encrypted xmlns='eu.siacs.conversations.axolotl'>"
@@ -114,7 +49,7 @@ class XMPPOmemoObservationPersistenceTests(TestCase):
"</message>" "</message>"
) )
async_to_sync(XMPPComponent._record_sender_omemo_state)( async_to_sync(XMPPComponent._record_sender_omemo_state)(
probe, xmpp_component,
user, user,
sender_jid="xmpp-omemo-user@zm.is/mobile", sender_jid="xmpp-omemo-user@zm.is/mobile",
recipient_jid="jews.zm.is", recipient_jid="jews.zm.is",
@@ -124,3 +59,328 @@ class XMPPOmemoObservationPersistenceTests(TestCase):
self.assertEqual("detected", row.status) self.assertEqual("detected", row.status)
self.assertEqual("sid:321,rid:654", row.latest_client_key) self.assertEqual("sid:321,rid:654", row.latest_client_key)
self.assertEqual("jews.zm.is", row.last_target_jid) self.assertEqual("jews.zm.is", row.last_target_jid)
class XMPPOmemoEnforcementTests(TestCase):
"""Test require_omemo policy enforcement on incoming messages"""
def setUp(self):
from core.models import UserXmppSecuritySettings
self.user = User.objects.create_user("omemo-enforcer", "omemo@example.com", "x")
self.security_settings = UserXmppSecuritySettings.objects.create(
user=self.user, require_omemo=True
)
def test_plaintext_message_rejected_when_omemo_required(self):
"""Test that plaintext messages are rejected when require_omemo=True"""
from core.models import UserXmppSecuritySettings
# Create a plaintext message stanza (no OMEMO encryption)
stanza_xml = ET.fromstring(
"<message from='sender@example.com' to='jews.zm.is'>"
"<body>Hello, world!</body>"
"</message>"
)
# Mock the message handler's sym function
sym_calls = []
def mock_sym(msg):
sym_calls.append(msg)
# Verify that security settings require OMEMO
settings = UserXmppSecuritySettings.objects.get(user=self.user)
self.assertTrue(settings.require_omemo)
# Extract OMEMO observation from plaintext message
omemo_observation = _extract_sender_omemo_client_key(
SimpleNamespace(xml=stanza_xml)
)
# Plaintext message should have "no_omemo" status
self.assertEqual("no_omemo", omemo_observation.get("status"))
# Now test that enforcement would reject this message
# Check condition: if require_omemo is True and status != "detected"
if settings.require_omemo:
omemo_status = str(omemo_observation.get("status") or "")
if omemo_status != "detected":
# This is where the message would be rejected
mock_sym(
"⚠ This gateway requires OMEMO encryption. "
"Your message was not delivered. "
"Please enable OMEMO in your XMPP client."
)
# Verify the rejection message was set
self.assertEqual(1, len(sym_calls))
self.assertIn("This gateway requires OMEMO encryption", sym_calls[0])
self.assertIn("Your message was not delivered", sym_calls[0])
self.assertIn("Please enable OMEMO in your XMPP client", sym_calls[0])
def test_encrypted_message_accepted_when_omemo_required(self):
"""Test that OMEMO-encrypted messages are accepted when require_omemo=True"""
from core.models import UserXmppSecuritySettings
# Create an OMEMO-encrypted message stanza
stanza_xml = ET.fromstring(
"<message from='sender@example.com' to='jews.zm.is'>"
"<encrypted xmlns='eu.siacs.conversations.axolotl'>"
"<header sid='77'><key rid='88'>x</key></header>"
"</encrypted>"
"</message>"
)
# Extract OMEMO observation from encrypted message
omemo_observation = _extract_sender_omemo_client_key(
SimpleNamespace(xml=stanza_xml)
)
# Encrypted message should have "detected" status
self.assertEqual("detected", omemo_observation.get("status"))
# Verify that security settings require OMEMO
settings = UserXmppSecuritySettings.objects.get(user=self.user)
self.assertTrue(settings.require_omemo)
# Test that enforcement accepts this message
if settings.require_omemo:
omemo_status = str(omemo_observation.get("status") or "")
if omemo_status != "detected":
# Message would be rejected, but it's not
self.fail("Encrypted message should not be rejected")
# If we get here, the message was accepted
self.assertTrue(True)
class XMPPOmemoDeviceDiscoveryTests(TestCase):
"""Test OMEMO device discovery as seen by real XMPP clients (Dino, Gajim)."""
def setUp(self):
"""Set up a mock XMPP component with OMEMO support."""
self.user = User.objects.create_user(
"device-discovery-user", "dd@example.com", "x"
)
# Create a mock XMPP component
self.mock_component = MagicMock()
self.mock_component.log = MagicMock()
self.mock_component.jid = "jews.zm.is"
def test_gateway_publishes_device_list_to_pubsub(self):
"""Test that the gateway publishes its device list to PubSub (XEP-0060).
This simulates the device discovery query that real XMPP clients perform.
When a client wants to send an OMEMO message, it:
1. Queries the PubSub node: pubsub.example.com/eu.siacs.conversations.axolotl/devices/jews.zm.is
2. Expects to receive a device list with at least one device
3. Retrieves keys for those devices
4. Encrypts the message
If the device list is empty or missing, the client shows:
- Dino: "This contact does not support OMEMO encryption"
- Gajim: "No devices found to encrypt this message to. Querying for devices now…"
"""
# This test verifies that devices are published
# In a real scenario, the OMEMO plugin should publish devices during session_start
# Mock the OMEMO plugin
mock_omemo_plugin = AsyncMock()
# Create a mock device list response
# Format: list of device objects with device_id and identity_key attributes
mock_own_devices = [
SimpleNamespace(
device_id=1, identity_key=b"mock_identity_key_123456789abcdef"
),
]
# When session manager is obtained, it should provide access to device info
mock_session_manager = AsyncMock()
mock_session_manager.get_own_device_information = AsyncMock(
return_value=mock_own_devices
)
# The plugin's get_session_manager should return the session manager
mock_omemo_plugin.get_session_manager = AsyncMock(
return_value=mock_session_manager
)
# Simulate calling get_session_manager (as done in _bootstrap_omemo_for_authentic_channel)
async_to_sync(mock_omemo_plugin.get_session_manager)()
# Verify the plugin was asked for session manager
mock_omemo_plugin.get_session_manager.assert_called_once()
def test_client_cannot_encrypt_when_no_devices_found(self):
"""Test the error case: client fails to encrypt when gateway has no published devices.
This reproduces the error from real clients:
- Dino error: "This contact does not support OMEMO encryption"
- Gajim error: "No devices found to encrypt this message to. Querying for devices now…"
Root cause: Gateway's device list is not being published to PubSub during bootstrap.
"""
# This is what causes the client error
error_reason = (
"This contact does not support OMEMO encryption (No devices found)"
)
# Verify that this error condition matches what we see in real clients
self.assertIn("does not support OMEMO", error_reason)
self.assertIn("No devices", error_reason)
def test_client_can_encrypt_when_gateway_devices_discovered(self):
"""Test successful encryption: client discovers gateway devices and encrypts.
When the gateway properly publishes devices:
1. Client queries PubSub device node
2. Gets back device IDs and keys
3. Encrypts message to those devices
4. Sends encrypted message
"""
# Simulate successful device discovery
devices_from_pubsub = [
{"device_id": 1, "identity_key": "base64_encoded_key_1"},
]
# Client now has devices to encrypt to
can_encrypt = bool(devices_from_pubsub)
encryption_status = "ready" if devices_from_pubsub else "failed"
self.assertTrue(can_encrypt)
self.assertEqual("ready", encryption_status)
def test_omemo_state_tracks_client_devices(self):
"""Test that gateway tracks which devices clients use for OMEMO.
Once encryption is working, the gateway should observe and record
which client devices are sending encrypted messages.
"""
# Simulate an OMEMO-encrypted message from a client device
client_stanza = ET.fromstring(
"<message from='testuser@example.com/mobile' to='jews.zm.is'>"
"<encrypted xmlns='eu.siacs.conversations.axolotl'>"
"<header sid='12345' schemeVersion='2'>" # Device 12345
"<key rid='67890'>encrypted_payload_1</key>" # To recipient device 67890
"<key rid='67891'>encrypted_payload_2</key>" # To recipient device 67891
"</header>"
"<payload>encrypted_message_body</payload>"
"</encrypted>"
"</message>"
)
# Extract and verify the client key tracking
omemo_observation = _extract_sender_omemo_client_key(
SimpleNamespace(xml=client_stanza)
)
# Should detect OMEMO and extract device info
self.assertEqual("detected", omemo_observation.get("status"))
self.assertIn("sid:12345", omemo_observation.get("client_key", ""))
# In real gateway flow, this would be persisted to UserXmppOmemoState
# so we can track which clients have working OMEMO
def test_device_list_publication_requires_pubsub_node(self):
"""Test that device list publication fails if PubSub is unavailable.
The OMEMO bootstrap must:
1. Initialize the session manager (which auto-creates devices)
2. Publish device list to PubSub at: eu.siacs.conversations.axolotl/devices/jews.zm.is
3. Allow clients to discover and query those devices
If PubSub is slow or unavailable, this times out and prevents
proper device discovery.
"""
# Increased timeout from 15s to 30s to allow PubSub operations
session_manager_init_timeout = 30.0 # seconds
# If the session manager init times out, device list is never published
# With the increased timeout, we have more time for:
# 1. PubSub node creation/access
# 2. Device list publishing
# 3. Subscription setup
self.assertGreater(session_manager_init_timeout, 15.0)
def test_component_jid_device_discovery(self):
"""Test that component JIDs (without user@) can publish OMEMO devices.
A key issue with components: they use JIDs like 'jews.zm.is' instead of
'user@jews.zm.is'. This affects:
1. Device list node path: eu.siacs.conversations.axolotl/devices/jews.zm.is
2. Device identity and trust establishment
3. How clients discover and encrypt to the component
The OMEMO plugin must handle component JIDs correctly.
"""
component_jid = "jews.zm.is"
# Component JID format (no user@ part)
self.assertNotIn("@", component_jid)
# But PubSub device node still follows standard format
pubsub_node = f"eu.siacs.conversations.axolotl/devices/{component_jid}"
self.assertEqual(
"eu.siacs.conversations.axolotl/devices/jews.zm.is", pubsub_node
)
def test_gateway_accepts_presence_subscription_for_omemo(self):
"""Test that gateway auto-accepts presence subscriptions for OMEMO device discovery.
When a client subscribes to the gateway component (jews.zm.is) for OMEMO:
1. Client sends: <presence type="subscribe" from="user@example.com" to="jews.zm.is"/>
2. Gateway should auto-accept and send presence availability
3. This allows the client to add the gateway to its roster
4. Client can then query PubSub for device lists
"""
# Simulate a client sending presence subscription to gateway
client_jid = "testclient@example.com"
gateway_jid = "jews.zm.is"
# Create a mock XMPP component with the subscription handler
mock_component = MagicMock()
mock_component.log = MagicMock()
mock_component.boundjid.bare = gateway_jid
mock_component.send_presence = MagicMock()
# Create mock presence stanza
presence_stanza = MagicMock()
presence_stanza.__getitem__ = lambda self, key: {
"from": client_jid,
"to": gateway_jid,
}.get(key, "")
# Import the handler from the xmpp module
from core.clients.xmpp import XMPPComponent
# Call the handler
handler = XMPPComponent.on_presence_subscribe
# Since it's not an async method, call it directly
handler(mock_component, presence_stanza)
# Verify that gateway sent subscribed response
calls = mock_component.send_presence.call_args_list
self.assertGreater(len(calls), 0, "Gateway should send presence response")
# Find the "subscribed" response
subscribed_calls = [
call
for call in calls
if call.kwargs.get("ptype") == "subscribed"
and call.kwargs.get("pto") == client_jid
]
self.assertEqual(len(subscribed_calls), 1, "Should send subscribed response")
# Find the "available" presence notification
available_calls = [
call
for call in calls
if call.kwargs.get("ptype") == "available"
and call.kwargs.get("pto") == client_jid
]
self.assertEqual(len(available_calls), 1, "Should send presence availability")

View File

@@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
import hashlib import hashlib
import time
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
@@ -56,8 +55,7 @@ async def _translate_text(user, text: str, source_lang: str, target_lang: str) -
}, },
] ]
return str( return str(
await ai_runner.run_prompt(prompt, ai_obj, operation="translation") await ai_runner.run_prompt(prompt, ai_obj, operation="translation") or ""
or ""
).strip() ).strip()
@@ -130,7 +128,9 @@ async def process_inbound_translation(message: Message):
log_row.status = "failed" log_row.status = "failed"
log_row.error = str(exc) log_row.error = str(exc)
log.warning("translation forward failed bridge=%s: %s", bridge.id, exc) log.warning("translation forward failed bridge=%s: %s", bridge.id, exc)
await sync_to_async(log_row.save)(update_fields=["status", "error", "updated_at"]) await sync_to_async(log_row.save)(
update_fields=["status", "error", "updated_at"]
)
def apply_translation_origin(meta: dict | None, origin_tag: str) -> dict: def apply_translation_origin(meta: dict | None, origin_tag: str) -> dict:

View File

@@ -4,6 +4,7 @@ Export Django settings to templates
https://github.com/jakubroztocil/django-settings-export https://github.com/jakubroztocil/django-settings-export
""" """
from django.conf import settings as django_settings from django.conf import settings as django_settings
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured

View File

@@ -1,10 +1,10 @@
from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.auth.mixins import LoginRequiredMixin
from django.urls import reverse from django.urls import reverse
from mixins.views import ObjectCreate, ObjectDelete, ObjectList, ObjectUpdate
from core.forms import AIForm from core.forms import AIForm
from core.models import AI from core.models import AI
from core.util import logs from core.util import logs
from mixins.views import ObjectCreate, ObjectDelete, ObjectList, ObjectUpdate
log = logs.get_logger(__name__) log = logs.get_logger(__name__)

Some files were not shown because too many files have changed in this diff Show More