from __future__ import annotations

from collections.abc import Mapping
from dataclasses import dataclass

from core.models import ChatSession, ConversationEvent, Message


@dataclass
class _ProjectedMessage:
    """Mutable accumulator for one message while its events are replayed."""

    message_id: str
    ts: int = 0
    text: str = ""
    delivered_ts: int | None = None
    read_ts: int | None = None
    # Keyed by (source_service, actor, emoji); a later event for the same key
    # overwrites the earlier entry.
    reactions: dict[tuple[str, str, str], dict] | None = None

    def __post_init__(self):
        # Give every instance its own dict instead of sharing a mutable default.
        if self.reactions is None:
            self.reactions = {}


def _safe_int(value, default=0) -> int:
    """Return ``int(value)``, falling back to ``int(default)``, then to 0.

    Never raises for values ``int()`` rejects with TypeError, ValueError or
    OverflowError, including when ``default`` itself is not convertible.
    """
    for candidate in (value, default):
        try:
            return int(candidate)
        except (TypeError, ValueError, OverflowError):
            continue
    return 0


def _as_dict(value) -> dict:
    """Return a shallow dict copy of ``value`` if it is a mapping, else ``{}``.

    Guards against payloads that hold something other than an object
    (``None``, a list, a bare string, ...).
    """
    return dict(value) if isinstance(value, Mapping) else {}


def _optional_ts_differs(left, right) -> bool:
    """Return True when two optional timestamps disagree.

    They disagree when exactly one is ``None``, or when both are set and
    their integer values differ.
    """
    if left is None or right is None:
        return (left is None) != (right is None)
    return int(left) != int(right)


def _reaction_key(row: dict) -> tuple[str, str, str]:
    """Build the ``(source_service, actor, emoji)`` identity of a reaction row.

    ``source_service`` is stripped and lower-cased; ``actor`` and ``emoji``
    are only stripped. Missing or falsy fields become empty strings.
    """
    item = _as_dict(row)
    return (
        str(item.get("source_service") or "").strip().lower(),
        str(item.get("actor") or "").strip(),
        str(item.get("emoji") or "").strip(),
    )


def _normalize_reactions(rows: list[dict] | None) -> list[dict]:
    """Deduplicate and canonically order reaction rows.

    Rows sharing a ``(source_service, actor, emoji)`` key collapse to the
    last occurrence. Rows that are not mappings, or whose key is entirely
    empty, are skipped. The result is sorted so that two reaction sets can
    be compared with ``==``.
    """
    merged: dict[tuple[str, str, str], dict] = {}
    for row in rows or []:
        item = _as_dict(row)
        key = _reaction_key(item)
        if not any(key):
            continue
        merged[key] = {
            "source_service": key[0],
            "actor": key[1],
            "emoji": key[2],
            "removed": bool(item.get("removed")),
        }
    return sorted(
        merged.values(),
        key=lambda entry: (
            entry["source_service"],
            entry["actor"],
            entry["emoji"],
            entry["removed"],
        ),
    )


def _project_events(events) -> list[dict]:
    """Replay already-ordered events into per-message state dicts.

    ``events`` is any iterable of objects exposing ``payload``,
    ``event_type``, ``ts``, ``origin_message_id``, ``origin_transport`` and
    ``actor_identifier``. Events that reference a message for which no
    ``message_created`` event has been seen yet are ignored.
    """
    # Insertion order of this dict is the order messages were first created.
    projected: dict[str, _ProjectedMessage] = {}

    for event in events:
        payload = _as_dict(event.payload)
        event_type = str(event.event_type or "").strip().lower()

        if event_type == "message_created":
            message_id = str(
                payload.get("message_id") or event.origin_message_id or ""
            ).strip()
            if not message_id:
                continue
            state = projected.get(message_id)
            if state is None:
                state = _ProjectedMessage(message_id=message_id)
                projected[message_id] = state
            event_ts = _safe_int(event.ts)
            state.ts = _safe_int(payload.get("message_ts"), event_ts)
            # A repeated creation event without text keeps the earlier text.
            state.text = str(payload.get("text") or state.text or "")
            if state.delivered_ts is None:
                delivered = _safe_int(payload.get("delivered_ts"), event_ts)
                state.delivered_ts = delivered or None
            continue

        message_id = str(
            payload.get("message_id") or payload.get("target_message_id") or ""
        ).strip()
        # An empty id is never stored, so one lookup covers both the
        # "no id" and the "unknown message" cases.
        state = projected.get(message_id)
        if state is None:
            continue

        if event_type == "read_receipt":
            read_ts = _safe_int(payload.get("read_ts"), _safe_int(event.ts))
            if read_ts > 0:
                if state.read_ts is None:
                    state.read_ts = read_ts
                else:
                    state.read_ts = max(int(state.read_ts or 0), read_ts)
                # A read message with no delivery time recorded takes the
                # read time as its delivery time.
                if state.delivered_ts is None:
                    state.delivered_ts = read_ts
            continue

        if event_type in {"reaction_added", "reaction_removed"}:
            source_service = (
                str(payload.get("source_service") or event.origin_transport or "")
                .strip()
                .lower()
            )
            actor = str(
                payload.get("actor") or event.actor_identifier or ""
            ).strip()
            emoji = str(payload.get("emoji") or "").strip()
            if not (source_service or actor or emoji):
                continue
            # NOTE(review): the event payload flag is read as "remove" while
            # normalized rows use "removed" - confirm against the event writer.
            state.reactions[(source_service, actor, emoji)] = {
                "source_service": source_service,
                "actor": actor,
                "emoji": emoji,
                "removed": bool(
                    event_type == "reaction_removed" or payload.get("remove")
                ),
            }

    return [
        {
            "message_id": str(state.message_id),
            "ts": int(state.ts or 0),
            "text": str(state.text or ""),
            "delivered_ts": (
                int(state.delivered_ts) if state.delivered_ts is not None else None
            ),
            "read_ts": int(state.read_ts) if state.read_ts is not None else None,
            "reactions": _normalize_reactions(
                list((state.reactions or {}).values())
            ),
        }
        for state in projected.values()
    ]


def project_session_from_events(session: ChatSession) -> list[dict]:
    """Rebuild the message list of ``session`` purely from its events.

    Loads the session's ``ConversationEvent`` rows ordered by ``ts`` then
    ``created_at`` and replays them.

    Returns:
        One dict per message, in first-creation order, with the keys
        ``message_id``, ``ts``, ``text``, ``delivered_ts``, ``read_ts`` and
        ``reactions`` (a normalized, sorted list).
    """
    events = ConversationEvent.objects.filter(
        user=session.user,
        session=session,
    ).order_by("ts", "created_at")
    return _project_events(events)


def shadow_compare_session(session: ChatSession, detail_limit: int = 50) -> dict:
    """Compare the event projection of ``session`` with its ``Message`` rows.

    Args:
        session: The chat session to audit.
        detail_limit: Maximum number of mismatch rows kept in ``details``.
            Negative values are treated as 0. At most ``min(5, detail_limit)``
            rows are also kept per cause in ``cause_samples``.

    Returns:
        A report dict with ``session_id``, ``db_message_count``,
        ``projected_message_count``, ``mismatch_total`` (sum of all
        counters), ``counters`` (per issue), ``cause_counts`` and
        ``cause_samples`` (per cause) and ``details``.
    """
    projected_rows = project_session_from_events(session)
    projected_by_id = {
        str(row.get("message_id") or ""): row for row in projected_rows
    }
    db_rows = list(
        Message.objects.filter(user=session.user, session=session)
        .order_by("ts", "id")
        .values(
            "id",
            "ts",
            "text",
            "delivered_ts",
            "read_ts",
            "receipt_payload",
        )
    )
    db_by_id = {str(row.get("id")): dict(row) for row in db_rows}

    counters = {
        "missing_in_projection": 0,
        "missing_in_db": 0,
        "text_mismatch": 0,
        "ts_mismatch": 0,
        "delivered_ts_mismatch": 0,
        "read_ts_mismatch": 0,
        "reactions_mismatch": 0,
    }
    details = []
    cause_counts = {
        "missing_event_write": 0,
        "ambiguous_reaction_target": 0,
        "payload_normalization_gap": 0,
    }
    cause_samples = {key: [] for key in cause_counts}
    # Computed once rather than on every recorded mismatch.
    max_details = max(0, int(detail_limit))
    cause_sample_limit = min(5, max_details)

    def _record_detail(
        message_id: str, issue: str, cause: str, extra: dict | None = None
    ):
        """Count ``cause`` and keep the row while the limits allow it."""
        if cause in cause_counts:
            cause_counts[cause] += 1
        row = {"message_id": message_id, "issue": issue, "cause": cause}
        if extra:
            row.update(dict(extra))
        if len(details) < max_details:
            details.append(row)
        if cause in cause_samples and len(cause_samples[cause]) < cause_sample_limit:
            cause_samples[cause].append(row)

    for message_id, db_row in db_by_id.items():
        projected = projected_by_id.get(message_id)
        if projected is None:
            counters["missing_in_projection"] += 1
            _record_detail(message_id, "missing_in_projection", "missing_event_write")
            continue

        db_text = str(db_row.get("text") or "")
        projected_text = str(projected.get("text") or "")
        if db_text != projected_text:
            counters["text_mismatch"] += 1
            _record_detail(
                message_id,
                "text_mismatch",
                "payload_normalization_gap",
                {"db": db_text, "projected": projected_text},
            )

        db_ts = _safe_int(db_row.get("ts"), 0)
        projected_ts = _safe_int(projected.get("ts"), 0)
        if db_ts != projected_ts:
            counters["ts_mismatch"] += 1
            _record_detail(
                message_id,
                "ts_mismatch",
                "payload_normalization_gap",
                {"db": db_ts, "projected": projected_ts},
            )

        db_delivered_ts = db_row.get("delivered_ts")
        projected_delivered_ts = projected.get("delivered_ts")
        if _optional_ts_differs(db_delivered_ts, projected_delivered_ts):
            counters["delivered_ts_mismatch"] += 1
            _record_detail(
                message_id,
                "delivered_ts_mismatch",
                "payload_normalization_gap",
                {
                    "db": db_delivered_ts,
                    "projected": projected_delivered_ts,
                },
            )

        db_read_ts = db_row.get("read_ts")
        projected_read_ts = projected.get("read_ts")
        if _optional_ts_differs(db_read_ts, projected_read_ts):
            counters["read_ts_mismatch"] += 1
            _record_detail(
                message_id,
                "read_ts_mismatch",
                "payload_normalization_gap",
                {"db": db_read_ts, "projected": projected_read_ts},
            )

        # A receipt payload that is not an object, or whose "reactions" entry
        # is not a list, is treated as having no reactions instead of
        # aborting the whole comparison.
        receipt_payload = _as_dict(db_row.get("receipt_payload"))
        raw_db_reactions = receipt_payload.get("reactions")
        if not isinstance(raw_db_reactions, (list, tuple)):
            raw_db_reactions = []
        db_reactions = _normalize_reactions(list(raw_db_reactions))
        projected_reactions = _normalize_reactions(
            list(projected.get("reactions") or [])
        )
        if db_reactions != projected_reactions:
            counters["reactions_mismatch"] += 1
            strategy = str(
                receipt_payload.get("reaction_last_match_strategy") or ""
            ).strip()
            cause = (
                "ambiguous_reaction_target"
                if strategy == "nearest_ts_window"
                else "payload_normalization_gap"
            )
            _record_detail(
                message_id,
                "reactions_mismatch",
                cause,
                {"db": db_reactions, "projected": projected_reactions},
            )

    for message_id in projected_by_id:
        if message_id not in db_by_id:
            counters["missing_in_db"] += 1
            _record_detail(message_id, "missing_in_db", "payload_normalization_gap")

    return {
        "session_id": str(session.id),
        "db_message_count": len(db_rows),
        "projected_message_count": len(projected_rows),
        "mismatch_total": int(sum(int(value or 0) for value in counters.values())),
        "counters": counters,
        "cause_counts": cause_counts,
        "cause_samples": cause_samples,
        "details": details,
    }