Files
GIA/core/messaging/reply_sync.py

463 lines
17 KiB
Python

from __future__ import annotations
import re
from typing import Any
from asgiref.sync import sync_to_async
from core.messaging import history
from core.models import Message
def _as_dict(value: Any) -> dict[str, Any]:
    """Return a shallow copy of *value* if it is a dict, otherwise ``{}``."""
    if not isinstance(value, dict):
        return {}
    return dict(value)
def _pluck(data: Any, *path: str):
    """Walk *path* through nested dicts and/or objects.

    Each step is looked up with ``dict.get`` when the current node is a dict,
    otherwise as an attribute.  Returns ``None`` as soon as a non-dict node
    lacks the requested attribute; a missing dict key yields ``None`` for
    that step.  With an empty *path*, *data* itself is returned.
    """
    node = data
    for step in path:
        if isinstance(node, dict):
            node = node.get(step)
        elif hasattr(node, step):
            node = getattr(node, step)
        else:
            return None
    return node
def _clean(value: Any) -> str:
    """Return *value* as a stripped string; any falsy value becomes ``""``."""
    if not value:
        return ""
    return str(value).strip()
def _find_origin_tag(value: Any, depth: int = 0) -> str:
    """Search *value* for a non-empty ``origin_tag`` string.

    Dicts are checked for a direct ``origin_tag`` first, then the well-known
    metadata containers, then every value.  Lists are searched item by item.
    Any other type, or recursion deeper than four levels, yields ``""``.
    """
    if depth > 4:
        return ""
    if isinstance(value, dict):
        own_tag = _clean(value.get("origin_tag"))
        if own_tag:
            return own_tag
        preferred_keys = (
            "metadata",
            "meta",
            "message_meta",
            "contextInfo",
            "context_info",
        )
        # Preferred containers are visited before the generic sweep.
        children = [value.get(name) for name in preferred_keys]
        children.extend(value.values())
    elif isinstance(value, list):
        children = value
    else:
        return ""
    for child in children:
        found = _find_origin_tag(child, depth + 1)
        if found:
            return found
    return ""
def _first_scalar_text(mapping: dict[str, Any], keys: tuple[str, ...]) -> str:
    """Return the first non-empty stripped scalar stored under one of *keys*.

    Dict and list values are skipped so that a nested object is never
    stringified into a bogus identifier.
    """
    for key in keys:
        value = mapping.get(key)
        if isinstance(value, (dict, list)):
            continue
        text = _clean(value)
        if text:
            return text
    return ""


def _extract_signal_reply(raw_payload: dict[str, Any]) -> dict[str, str]:
    """Extract the quoted-message reference from a Signal payload.

    Quote candidates are gathered from ``envelope.dataMessage``,
    ``envelope.syncMessage.sentMessage`` (and its ``message``), a top-level
    ``dataMessage`` and the payload itself: first a direct ``quote``/``Quote``
    entry, then every nested dict whose key contains "quote" or "reply".

    Returns a dict with ``reply_source_message_id``, ``reply_source_service``
    (always ``"signal"``) and ``reply_source_chat_id`` for the first candidate
    that yields an ID, or ``{}`` when none does.  A non-dict *raw_payload* is
    treated as empty.
    """
    payload = _as_dict(raw_payload)
    envelope = _as_dict(payload.get("envelope"))
    sync_message = _as_dict(envelope.get("syncMessage"))
    sent_message = _as_dict(sync_message.get("sentMessage"))
    data_candidates = [
        _as_dict(envelope.get("dataMessage")),
        _as_dict(sent_message.get("message")),
        sent_message,
        _as_dict(payload.get("dataMessage")),
        payload,
    ]
    quote_key_candidates = (
        "id",
        "targetSentTimestamp",
        "targetTimestamp",
        "quotedMessageId",
        "quoted_message_id",
        "quotedMessageID",
        "messageId",
        "message_id",
        "timestamp",
    )
    quote_author_candidates = (
        "author",
        "authorUuid",
        "authorAci",
        "authorNumber",
        "source",
        "sourceNumber",
        "sourceUuid",
    )

    quote_candidates: list[dict[str, Any]] = []
    for data_message in data_candidates:
        if not data_message:
            continue
        # A direct quote is queued first so it takes priority over anything
        # discovered by the generic traversal below.
        direct_quote = _as_dict(data_message.get("quote") or data_message.get("Quote"))
        if direct_quote:
            quote_candidates.append(direct_quote)
        stack = [data_message]
        while stack:
            current = stack.pop()
            for key, value in current.items():
                if isinstance(value, dict):
                    key_text = str(key or "").strip().lower()
                    if "quote" in key_text or "reply" in key_text:
                        quote_candidates.append(value)
                    stack.append(value)
                elif isinstance(value, list):
                    stack.extend(item for item in value if isinstance(item, dict))

    for quote in quote_candidates:
        quote_id = _first_scalar_text(quote, quote_key_candidates)
        if not quote_id:
            # The ID may itself be an object (``{"id": {...}}``); look inside.
            # Previously the dict was stringified and returned as the ID.
            quote_id = _first_scalar_text(
                _as_dict(quote.get("id")), quote_key_candidates
            )
        if not quote_id:
            continue
        return {
            "reply_source_message_id": quote_id,
            "reply_source_service": "signal",
            "reply_source_chat_id": _first_scalar_text(
                quote, quote_author_candidates
            ),
        }
    return {}
# Known locations of the reply context inside WhatsApp payloads (bare,
# extended text, media, ephemeral and view-once wrappers).  Shared by the
# extractor and its debug counterpart so the two cannot drift apart.  Order
# matters: the first context that yields a stanza ID wins.
_WHATSAPP_CONTEXT_PATHS = (
    ("contextInfo",),
    ("ContextInfo",),
    ("messageContextInfo",),
    ("MessageContextInfo",),
    ("extendedTextMessage", "contextInfo"),
    ("ExtendedTextMessage", "ContextInfo"),
    ("imageMessage", "contextInfo"),
    ("ImageMessage", "ContextInfo"),
    ("videoMessage", "contextInfo"),
    ("VideoMessage", "ContextInfo"),
    ("documentMessage", "contextInfo"),
    ("DocumentMessage", "ContextInfo"),
    ("ephemeralMessage", "message", "contextInfo"),
    ("ephemeralMessage", "message", "extendedTextMessage", "contextInfo"),
    ("viewOnceMessage", "message", "contextInfo"),
    ("viewOnceMessage", "message", "extendedTextMessage", "contextInfo"),
    ("viewOnceMessageV2", "message", "contextInfo"),
    ("viewOnceMessageV2", "message", "extendedTextMessage", "contextInfo"),
    ("viewOnceMessageV2Extension", "message", "contextInfo"),
    ("viewOnceMessageV2Extension", "message", "extendedTextMessage", "contextInfo"),
    # snake_case protobuf dict variants
    ("context_info",),
    ("message_context_info",),
    ("extended_text_message", "context_info"),
    ("image_message", "context_info"),
    ("video_message", "context_info"),
    ("document_message", "context_info"),
    ("ephemeral_message", "message", "context_info"),
    ("ephemeral_message", "message", "extended_text_message", "context_info"),
    ("view_once_message", "message", "context_info"),
    ("view_once_message", "message", "extended_text_message", "context_info"),
    ("view_once_message_v2", "message", "context_info"),
    ("view_once_message_v2", "message", "extended_text_message", "context_info"),
    ("view_once_message_v2_extension", "message", "context_info"),
    (
        "view_once_message_v2_extension",
        "message",
        "extended_text_message",
        "context_info",
    ),
)

# Keys under which a context dict may appear at any nesting level; used by the
# recursive fallback for unknown wrapper shapes.
_WHATSAPP_CONTEXT_KEYS = (
    "contextInfo",
    "ContextInfo",
    "messageContextInfo",
    "MessageContextInfo",
    "context_info",
    "message_context_info",
)

# Flat keys, then nested paths, that may carry the quoted message ID.
_WHATSAPP_STANZA_KEYS = (
    "stanzaId",
    "stanzaID",
    "stanza_id",
    "StanzaId",
    "StanzaID",
    "quotedMessageID",
    "quotedMessageId",
    "QuotedMessageID",
    "QuotedMessageId",
)
_WHATSAPP_STANZA_PATHS = (
    ("quotedMessageKey", "id"),
    ("quoted_message_key", "id"),
    ("quotedMessage", "key", "id"),
    ("quoted_message", "key", "id"),
)

# Keys that may carry the chat/participant of the quoted message.
_WHATSAPP_PARTICIPANT_KEYS = (
    "participant",
    "remoteJid",
    "chat",
    "Participant",
    "RemoteJid",
    "RemoteJID",
    "Chat",
)


def _whatsapp_stanza_id(context: dict[str, Any]) -> str:
    """Return the cleaned quoted-message ID from *context*, or ``""``.

    The first truthy value wins (flat keys before nested paths) and is then
    stripped; later keys are not consulted once a truthy value is found.
    """
    for key in _WHATSAPP_STANZA_KEYS:
        value = context.get(key)
        if value:
            return _clean(value)
    for path in _WHATSAPP_STANZA_PATHS:
        value = _pluck(context, *path)
        if value:
            return _clean(value)
    return ""


def _whatsapp_participant(context: dict[str, Any]) -> str:
    """Return the cleaned participant/chat value from *context*, or ``""``."""
    for key in _WHATSAPP_PARTICIPANT_KEYS:
        value = context.get(key)
        if value:
            return _clean(value)
    return ""


def _extract_whatsapp_reply(raw_payload: dict[str, Any]) -> dict[str, str]:
    """Extract the quoted-message reference from a WhatsApp payload.

    Contexts are collected from the known paths first, then from a recursive
    sweep of every nested dict/list for unknown wrapper shapes.  The first
    context carrying a stanza ID produces the result.

    Returns a dict with ``reply_source_message_id``, ``reply_source_service``
    (always ``"whatsapp"``) and ``reply_source_chat_id``, or ``{}`` when no
    context yields a stanza ID.
    """
    contexts = []
    for path in _WHATSAPP_CONTEXT_PATHS:
        context = _as_dict(_pluck(raw_payload, *path))
        if context:
            contexts.append(context)

    # Recursive fallback for unknown wrapper shapes.
    stack = [_as_dict(raw_payload)]
    while stack:
        current = stack.pop()
        if not isinstance(current, dict):
            continue
        for key in _WHATSAPP_CONTEXT_KEYS:
            found = current.get(key)
            if isinstance(found, dict):
                contexts.append(_as_dict(found))
        for value in current.values():
            if isinstance(value, dict):
                stack.append(value)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict):
                        stack.append(item)

    for context in contexts:
        stanza_id = _whatsapp_stanza_id(context)
        if not stanza_id:
            continue
        return {
            "reply_source_message_id": stanza_id,
            "reply_source_service": "whatsapp",
            "reply_source_chat_id": _whatsapp_participant(context),
        }
    return {}


def extract_whatsapp_reply_debug(raw_payload: dict[str, Any]) -> dict[str, Any]:
    """Describe which known context paths are present in a WhatsApp payload.

    Only the fixed candidate paths are probed; the recursive fallback used by
    ``_extract_whatsapp_reply`` is not reflected here.

    Returns ``{"candidate_count": int, "candidates": [...]}`` where each
    candidate lists its dotted ``path``, up to 40 sorted ``keys``, and the
    ``stanzaId`` / ``participant`` values found (possibly empty).  At most
    20 candidates are included; ``candidate_count`` is the untruncated total.
    """
    payload = _as_dict(raw_payload)
    rows = []
    for path in _WHATSAPP_CONTEXT_PATHS:
        context = _as_dict(_pluck(payload, *path))
        if not context:
            continue
        rows.append(
            {
                "path": ".".join(path),
                "keys": sorted([str(key) for key in context.keys()])[:40],
                "stanzaId": _whatsapp_stanza_id(context),
                "participant": _whatsapp_participant(context),
            }
        )
    return {
        "candidate_count": len(rows),
        "candidates": rows[:20],
    }
def extract_reply_ref(service: str, raw_payload: dict[str, Any]) -> dict[str, str]:
    """Build a reply reference for *service* from *raw_payload*.

    ``signal`` and ``whatsapp`` delegate to their dedicated extractors.
    ``xmpp`` and ``web`` read the reply ID and chat ID from flat payload keys.
    Any other service, or a payload with no reply ID, yields ``{}``.
    """
    payload = _as_dict(raw_payload)
    channel = _clean(service).lower()

    if channel == "signal":
        return _extract_signal_reply(payload)
    if channel == "whatsapp":
        return _extract_whatsapp_reply(payload)

    if channel == "xmpp":
        message_id = _clean(
            payload.get("reply_source_message_id") or payload.get("reply_id")
        )
        chat_id = _clean(
            payload.get("reply_source_chat_id") or payload.get("reply_chat_id")
        )
    elif channel == "web":
        message_id = _clean(payload.get("reply_to_message_id"))
        chat_id = _clean(payload.get("reply_source_chat_id"))
    else:
        return {}

    if not message_id:
        return {}
    return {
        "reply_source_message_id": message_id,
        "reply_source_service": channel,
        "reply_source_chat_id": chat_id,
    }
def extract_origin_tag(raw_payload: dict[str, Any] | None) -> str:
    """Return the first ``origin_tag`` found in *raw_payload*, or ``""``.

    A missing or non-dict payload is treated as empty.
    """
    payload = _as_dict(raw_payload)
    return _find_origin_tag(payload)
async def resolve_reply_target(
    user, session, reply_ref: dict[str, str]
) -> Message | None:
    """Find the local ``Message`` that *reply_ref* points at.

    Every lookup is scoped to *user* and *session* and tried in this order:

    1. primary key, only when the referenced ID is UUID-shaped;
    2. ``source_service`` + ``source_message_id``, newest ``ts`` first;
    3. bridge mappings via ``history.resolve_bridge_ref`` when the session
       exposes an ``identifier``;
    4. messages whose own ``reply_source_message_id`` equals the referenced
       ID, newest ``ts`` first.

    Returns ``None`` when *session* is ``None``, the reference is empty or
    has no message ID, or nothing matches.
    """
    if session is None or not reply_ref:
        return None
    target_id = _clean(reply_ref.get("reply_source_message_id"))
    if not target_id:
        return None

    def _by_pk(message_id):
        return Message.objects.filter(
            user=user,
            session=session,
            id=message_id,
        ).first()

    def _latest(**criteria):
        return (
            Message.objects.filter(user=user, session=session, **criteria)
            .order_by("-ts")
            .first()
        )

    # Direct local UUID fallback (web compose references local Message IDs).
    looks_like_uuid = re.fullmatch(
        r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}",
        target_id,
    )
    if looks_like_uuid:
        direct = await sync_to_async(_by_pk)(target_id)
        if direct is not None:
            return direct

    source_service = _clean(reply_ref.get("reply_source_service"))
    by_source = await sync_to_async(_latest)(
        source_service=source_service or None,
        source_message_id=target_id,
    )
    if by_source is not None:
        return by_source

    # Bridge ref fallback: resolve replies against bridge mappings persisted
    # in message receipt payloads.
    identifier = getattr(session, "identifier", None)
    if identifier is not None:
        services = [source_service] if source_service else []
        # XMPP replies can target bridged messages from any external service.
        if source_service == "xmpp":
            services += ["signal", "whatsapp", "instagram"]
        for service_name in services:
            bridge = await history.resolve_bridge_ref(
                user=user,
                identifier=identifier,
                source_service=service_name,
                xmpp_message_id=target_id,
                upstream_message_id=target_id,
            )
            local_id = _clean((bridge or {}).get("local_message_id"))
            if local_id:
                bridged = await sync_to_async(_by_pk)(local_id)
                if bridged is not None:
                    return bridged

    # NOTE(review): this matches messages that themselves reply to the same
    # ID rather than the replied-to message - confirm this is intended.
    return await sync_to_async(_latest)(reply_source_message_id=target_id)
def apply_sync_origin(message_meta: dict | None, origin_tag: str) -> dict:
    """Return a copy of *message_meta* stamped with *origin_tag*.

    The input mapping is never mutated.  When the cleaned tag is empty the
    copy is returned unchanged.
    """
    meta = dict(message_meta or {})
    tag = _clean(origin_tag)
    if tag:
        meta["origin_tag"] = tag
    return meta
def is_mirrored_origin(message_meta: dict | None) -> bool:
    """Return True when *message_meta* carries a non-empty ``origin_tag``."""
    meta = dict(message_meta or {})
    tag = _clean(meta.get("origin_tag"))
    return tag != ""