import asyncio
import json
import time
from urllib.parse import parse_qs

from asgiref.sync import sync_to_async
from django.core import signing

from core.models import Message, Person, PersonIdentifier, WorkspaceConversation
from core.realtime.typing_state import get_person_typing_state
from core.views.compose import (
    COMPOSE_WS_TOKEN_SALT,
    ComposeHistorySync,
    _serialize_messages_with_artifacts,
)

# Close code sent for a missing/invalid/expired token or an unusable scope.
_WS_CLOSE_UNAUTHORIZED = 4401
# How long the socket loop waits for a client frame before polling anyway.
_POLL_INTERVAL_SECONDS = 1.2
# Each poll re-reads this far behind ``after_ts`` so rows stored with
# late/out-of-order timestamps are still picked up. The value is 5 minutes
# assuming ``Message.ts`` is in milliseconds -- NOTE(review): confirm unit.
_LATE_WINDOW_MS = 5 * 60 * 1000
# Clamp applied to the caller-supplied row limit in ``_load_since``.
_MIN_ROWS = 10
_MAX_ROWS = 200


def _safe_int(value, default=0):
    """Convert *value* to ``int``, returning *default* if that is impossible.

    ``OverflowError`` is caught as well because ``json.loads`` accepts
    ``Infinity`` and ``int(float("inf"))`` raises it; values read from
    client frames are untrusted.
    """
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return default


def _load_since(user_id, service, identifier, person_id, after_ts, limit):
    """Load the recent messages of one compose conversation.

    Performs synchronous ORM queries; the websocket handler calls it via
    ``sync_to_async``.

    The conversation scope is resolved from *person_id* when it parses to a
    positive integer (a ``Person`` owned by *user_id*, plus one of its
    identifiers, preferring one on *service*). Otherwise it is resolved from
    the (*service*, *identifier*) pair.

    Args:
        user_id: Owner of every row that is read.
        service: Service name used when picking the person identifier.
        identifier: Explicit identifier from the token; may be empty.
        person_id: Person id as a string; non-numeric or <= 0 means "none".
        after_ts: Newest ``Message.ts`` the client already has (0 at start).
        limit: Requested row count, clamped to [_MIN_ROWS, _MAX_ROWS].

    Returns:
        dict with ``"messages"`` (serialized rows, oldest first),
        ``"last_ts"`` (newest ``Message.ts`` in scope, else *after_ts*,
        else 0) and ``"person_id"`` (resolved person id, 0 when unknown).
    """
    person = None
    person_identifier = None
    resolved_person_id = _safe_int(person_id)
    if resolved_person_id > 0:
        person = Person.objects.filter(id=resolved_person_id, user_id=user_id).first()
        if person is not None:
            # Prefer the identifier on the requested service, else any one.
            person_identifier = (
                PersonIdentifier.objects.filter(
                    user_id=user_id,
                    person_id=person.id,
                    service=service,
                ).first()
                or PersonIdentifier.objects.filter(
                    user_id=user_id,
                    person_id=person.id,
                ).first()
            )
    elif identifier:
        person_identifier = PersonIdentifier.objects.filter(
            user_id=user_id,
            service=service,
            identifier=identifier,
        ).first()

    # Same fallback order for every return path: explicit person, then the
    # person owning the resolved identifier, else 0.
    if person is not None:
        effective_person_id = int(person.id)
    elif person_identifier is not None:
        effective_person_id = int(person_identifier.person_id)
    else:
        effective_person_id = 0

    session_ids = ComposeHistorySync._session_ids_for_scope(
        user=user_id,
        person=person,
        service=service,
        person_identifier=person_identifier,
        explicit_identifier=identifier,
    )
    if not session_ids:
        return {
            "messages": [],
            "last_ts": int(after_ts or 0),
            "person_id": effective_person_id,
        }

    base_queryset = Message.objects.filter(
        user_id=user_id,
        session_id__in=session_ids,
    )

    # NOTE(review): this is the newest row at/before ``after_ts``, which
    # usually lies inside the re-read window below rather than before its
    # first row -- confirm that matches what the serializer expects.
    seed_previous = None
    if after_ts > 0:
        seed_previous = base_queryset.filter(ts__lte=after_ts).order_by("-ts").first()

    # Re-read a rolling window behind after_ts to capture late/out-of-order
    # timestamps; repeated rows are dropped by message id downstream (the
    # websocket loop and the frontend both dedupe).
    window_start = max(0, int(after_ts) - _LATE_WINDOW_MS)
    row_limit = max(_MIN_ROWS, min(limit, _MAX_ROWS))
    rows = list(
        base_queryset.filter(ts__gte=window_start)
        .select_related(
            "session",
            "session__identifier",
            "session__identifier__person",
        )
        .order_by("-ts")[:row_limit]
    )
    rows.reverse()  # the newest N were fetched; serialize them oldest-first

    newest = base_queryset.order_by("-ts").values_list("ts", flat=True).first()

    conversation = None
    counterpart_identifiers = set()
    if effective_person_id > 0:
        conversation = (
            WorkspaceConversation.objects.filter(
                user_id=user_id,
                participants__id=effective_person_id,
            )
            .order_by("-last_event_ts", "-created_at")
            .first()
        )
        raw_identifiers = PersonIdentifier.objects.filter(
            user_id=user_id,
            person_id=effective_person_id,
        ).values_list("identifier", flat=True)
        counterpart_identifiers = {
            cleaned
            for cleaned in (str(value or "").strip() for value in raw_identifiers)
            if cleaned
        }

    return {
        "messages": _serialize_messages_with_artifacts(
            rows,
            counterpart_identifiers=counterpart_identifiers,
            conversation=conversation,
            seed_previous=seed_previous,
        ),
        "last_ts": int(newest or after_ts or 0),
        "person_id": effective_person_id,
    }


def _decode_ws_claims(scope):
    """Return the verified, unexpired token claims for *scope*, or ``None``.

    The token is read from the ``token`` query-string parameter and verified
    with ``COMPOSE_WS_TOKEN_SALT``. Claims that are not a JSON object, or
    whose ``exp`` is missing or in the past, are rejected.
    """
    query_string = (scope.get("query_string") or b"").decode("utf-8", errors="ignore")
    token = (parse_qs(query_string).get("token") or [""])[0]
    try:
        claims = signing.loads(token, salt=COMPOSE_WS_TOKEN_SALT)
    except Exception:
        # Any failure to verify/decode the token is an authentication failure.
        return None
    if not isinstance(claims, dict):
        return None
    if _safe_int(claims.get("exp")) < int(time.time()):
        return None
    return claims


def _parse_client_frame(event):
    """Decode a ``websocket.receive`` text frame into a dict.

    Malformed JSON, binary-only frames and JSON that is not an object all
    yield ``{}``; client input is untrusted and must not crash the loop.
    """
    try:
        body = json.loads(event.get("text") or "{}")
    except Exception:
        return {}
    return body if isinstance(body, dict) else {}


async def compose_ws_application(scope, receive, send):
    """ASGI websocket app that streams compose messages and typing state.

    The connection is rejected with close code 4401 unless the query-string
    token verifies, is unexpired, names a positive user id ``u`` and scopes
    the conversation via an identifier ``i`` or a positive person id ``p``.
    Once accepted, the loop polls ``_load_since`` roughly every
    ``_POLL_INTERVAL_SECONDS`` (or immediately on any incoming event). It
    sends a JSON frame only when there are unsent messages or the typing
    state changed. A client frame ``{"kind": "sync", "last_ts": N}`` can
    advance the cursor; it never moves backwards.
    """
    if scope.get("type") != "websocket":
        return

    claims = _decode_ws_claims(scope)
    if claims is None:
        await send({"type": "websocket.close", "code": _WS_CLOSE_UNAUTHORIZED})
        return

    user_id = _safe_int(claims.get("u"))
    service = str(claims.get("s") or "").strip()
    identifier = str(claims.get("i") or "").strip()
    person_id = str(claims.get("p") or "").strip()
    resolved_person_id = _safe_int(person_id)
    if user_id <= 0 or (not identifier and resolved_person_id <= 0):
        await send({"type": "websocket.close", "code": _WS_CLOSE_UNAUTHORIZED})
        return

    await send({"type": "websocket.accept"})

    last_ts = 0
    limit = 100
    last_typing_key = ""
    # Ids already delivered that can still show up in the poll window.
    sent_message_ids = set()

    while True:
        try:
            event = await asyncio.wait_for(receive(), timeout=_POLL_INTERVAL_SECONDS)
        except asyncio.TimeoutError:
            event = None  # no client traffic: fall through and poll anyway

        event_type = (event or {}).get("type")
        if event_type == "websocket.disconnect":
            break
        if event_type == "websocket.receive":
            body = _parse_client_frame(event)
            if body.get("kind") == "sync":
                last_ts = max(last_ts, _safe_int(body.get("last_ts"), 0))

        snapshot = await sync_to_async(_load_since)(
            user_id=user_id,
            service=service,
            identifier=identifier,
            person_id=person_id,
            after_ts=last_ts,
            limit=limit,
        )

        messages = []
        window_ids = set()
        for msg in snapshot.get("messages") or []:
            message_id = str((msg or {}).get("id") or "").strip()
            if message_id:
                already_sent = message_id in sent_message_ids or message_id in window_ids
                window_ids.add(message_id)
                if already_sent:
                    continue
            messages.append(msg)
        # The poll window only moves forward, so ids that dropped out of it
        # cannot come back; keeping just the current window bounds memory.
        sent_message_ids = window_ids

        last_ts = max(last_ts, _safe_int(snapshot.get("last_ts"), last_ts))
        if resolved_person_id <= 0:
            resolved_person_id = _safe_int(snapshot.get("person_id"), 0)

        # NOTE(review): called directly on the event loop; presumably cheap
        # and non-blocking -- confirm if it ever hits a remote cache.
        typing_state = get_person_typing_state(
            user_id=user_id,
            person_id=resolved_person_id,
        )
        typing_key = json.dumps(typing_state, sort_keys=True)
        typing_changed = typing_key != last_typing_key
        if typing_changed:
            last_typing_key = typing_key

        if not messages and not typing_changed:
            continue

        outgoing_payload = {}
        if messages:
            outgoing_payload["messages"] = messages
        outgoing_payload["last_ts"] = last_ts
        if typing_changed:
            outgoing_payload["typing"] = typing_state
        await send(
            {
                "type": "websocket.send",
                "text": json.dumps(outgoing_payload),
            }
        )