244 lines
7.7 KiB
Python
244 lines
7.7 KiB
Python
import asyncio
|
|
import json
|
|
import time
|
|
from urllib.parse import parse_qs
|
|
|
|
from asgiref.sync import sync_to_async
|
|
from django.core import signing
|
|
|
|
from core.models import Message, Person, PersonIdentifier, WorkspaceConversation
|
|
from core.realtime.typing_state import get_person_typing_state
|
|
from core.views.compose import (
|
|
COMPOSE_WS_TOKEN_SALT,
|
|
ComposeHistorySync,
|
|
_serialize_messages_with_artifacts,
|
|
)
|
|
|
|
|
|
def _safe_int(value, default=0):
    """Coerce ``value`` to ``int``, returning ``default`` when that fails.

    This is applied to untrusted values (signed-token claims, client JSON
    frames), so it must never raise: ``None``, non-numeric strings, NaN and
    infinities all yield ``default``. Floats are truncated toward zero, as
    ``int()`` does.
    """
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")) -- reachable via JSON such as 1e999.
        return default
|
|
|
|
|
|
def _load_since(user_id, service, identifier, person_id, after_ts, limit):
    """Load serialized compose messages for one conversation scope.

    The scope is resolved from ``person_id`` when it names a ``Person`` owned
    by ``user_id``. Otherwise it is resolved from the explicit
    ``(service, identifier)`` pair. Messages are read from every session that
    ``ComposeHistorySync._session_ids_for_scope`` returns for that scope.

    Args:
        user_id: Owner of the data; every query is filtered by it.
        service: Service name used to choose a ``PersonIdentifier``.
        identifier: Explicit conversation identifier; may be empty.
        person_id: Person id in any form ``_safe_int`` accepts. Values that
            are unparsable or ``<= 0`` mean "no person".
        after_ts: Client cursor compared against ``Message.ts``. ``0``
            requests the latest page. The 5-minute window below is written as
            ``5 * 60 * 1000``, which suggests millisecond timestamps -- confirm.
        limit: Requested page size, clamped to ``[10, 200]``.

    Returns:
        dict with keys:
            ``messages``: serialized rows, oldest first.
            ``last_ts``: newest ``Message.ts`` in scope, falling back to
                ``after_ts``.
            ``person_id``: resolved person id, ``0`` when none.
    """
    person = None
    person_identifier = None

    resolved_person_id = _safe_int(person_id)
    if resolved_person_id > 0:
        person = Person.objects.filter(id=resolved_person_id, user_id=user_id).first()

    if person is not None:
        # Prefer the identifier for the requested service, otherwise fall back
        # to any identifier this person has.
        person_identifier = (
            PersonIdentifier.objects.filter(
                user_id=user_id,
                person_id=person.id,
                service=service,
            ).first()
            or PersonIdentifier.objects.filter(
                user_id=user_id,
                person_id=person.id,
            ).first()
        )
    elif identifier:
        person_identifier = PersonIdentifier.objects.filter(
            user_id=user_id,
            service=service,
            identifier=identifier,
        ).first()

    # Resolve the person once so every return path reports the same id.
    # Previously the empty-scope path ignored an identifier-only match.
    if person is not None:
        effective_person_id = int(person.id)
    elif person_identifier is not None:
        effective_person_id = _safe_int(person_identifier.person_id)
    else:
        effective_person_id = 0

    session_ids = ComposeHistorySync._session_ids_for_scope(
        user=user_id,
        person=person,
        service=service,
        person_identifier=person_identifier,
        explicit_identifier=identifier,
    )
    if not session_ids:
        return {
            "messages": [],
            "last_ts": int(after_ts or 0),
            "person_id": effective_person_id,
        }

    base_queryset = Message.objects.filter(
        user_id=user_id,
        session_id__in=session_ids,
    )

    window_queryset = base_queryset
    seed_previous = None
    if after_ts > 0:
        # Latest message at or before the cursor. It is handed to the
        # serializer, presumably as context for the first returned row -- confirm.
        seed_previous = base_queryset.filter(ts__lte=after_ts).order_by("-ts").first()
        # Use a small rolling window to capture late/out-of-order timestamps.
        # Frontend dedupes by message id, so repeated rows are ignored.
        window_start = max(0, int(after_ts) - 5 * 60 * 1000)
        window_queryset = base_queryset.filter(ts__gte=window_start)

    page_size = max(10, min(limit, 200))
    # Keep the newest `page_size` rows, then flip them to chronological order.
    # NOTE(review): if more rows than `page_size` fall inside the window, the
    # older ones are never returned; confirm that is acceptable.
    rows = list(
        window_queryset.select_related(
            "session",
            "session__identifier",
            "session__identifier__person",
        ).order_by("-ts")[:page_size]
    )
    rows.reverse()

    newest = base_queryset.order_by("-ts").values_list("ts", flat=True).first()

    conversation = None
    counterpart_identifiers = set()
    if effective_person_id > 0:
        conversation = (
            WorkspaceConversation.objects.filter(
                user_id=user_id,
                participants__id=effective_person_id,
            )
            .order_by("-last_event_ts", "-created_at")
            .first()
        )
        # Every non-blank identifier registered for this person.
        counterpart_identifiers = {
            str(value or "").strip()
            for value in PersonIdentifier.objects.filter(
                user_id=user_id,
                person_id=effective_person_id,
            ).values_list("identifier", flat=True)
            if str(value or "").strip()
        }

    return {
        "messages": _serialize_messages_with_artifacts(
            rows,
            counterpart_identifiers=counterpart_identifiers,
            conversation=conversation,
            seed_previous=seed_previous,
        ),
        "last_ts": int(newest or after_ts or 0),
        "person_id": effective_person_id,
    }
|
|
|
|
|
|
def _decode_client_message(text):
    """Parse one client WebSocket text frame into a dict.

    Anything that is not a JSON object (missing text, invalid JSON, ``null``,
    arrays, numbers) becomes ``{}``. The receive loop can therefore always
    call ``.get`` on the result.
    """
    try:
        body = json.loads(text or "{}")
    except (TypeError, ValueError, RecursionError):
        return {}
    return body if isinstance(body, dict) else {}


async def compose_ws_application(scope, receive, send):
    """ASGI app streaming compose messages and typing state over a WebSocket.

    Authentication uses the ``token`` query parameter, verified with
    ``signing.loads`` and ``COMPOSE_WS_TOKEN_SALT``. Claims read from the
    token:
        ``u``: user id.
        ``s``: service.
        ``i``: identifier.
        ``p``: person id.
        ``exp``: expiry, compared against ``time.time()`` (Unix seconds).
    An invalid, expired or incomplete token closes the socket with code 4401
    before it is accepted.

    After accepting, the loop calls ``_load_since`` whenever a client frame
    arrives or 1.2 s pass without one. Client frames:
        ``{"kind": "sync", "last_ts": <int>}`` advances the cursor.
        Any other frame only triggers an immediate poll.
    Server frames are JSON objects with ``last_ts``, plus ``messages`` when
    new messages exist and ``typing`` when the typing state changed. A frame
    is sent only when one of those two is present.
    """
    if scope.get("type") != "websocket":
        return

    query_string = (scope.get("query_string") or b"").decode("utf-8", errors="ignore")
    token = (parse_qs(query_string).get("token") or [""])[0]
    try:
        claims = signing.loads(token, salt=COMPOSE_WS_TOKEN_SALT)
    except Exception:
        # Bad signature / malformed token: reject before accepting.
        await send({"type": "websocket.close", "code": 4401})
        return

    # A validly signed token may still carry a non-object payload; treat that
    # as unauthorized rather than crashing on `.get`.
    if not isinstance(claims, dict) or _safe_int(claims.get("exp")) < int(time.time()):
        await send({"type": "websocket.close", "code": 4401})
        return

    user_id = _safe_int(claims.get("u"))
    service = str(claims.get("s") or "").strip()
    identifier = str(claims.get("i") or "").strip()
    person_id = str(claims.get("p") or "").strip()
    resolved_person_id = _safe_int(person_id)

    if user_id <= 0 or (not identifier and not person_id):
        await send({"type": "websocket.close", "code": 4401})
        return

    await send({"type": "websocket.accept"})

    # TODO(reactions): stream incremental reaction add/remove events over WS
    # instead of relying on message row refresh polling windows.
    # TODO(edits): add edit event envelopes so compose can update message text
    # in place when upstream supports edits.
    # TODO(retractions): add retract/delete event envelopes and tombstone UI.
    # TODO(capability): surface per-service capability notices (e.g. "edited
    # locally but upstream protocol does not support edits").

    last_ts = 0
    limit = 100
    last_typing_key = ""
    # NOTE(review): grows by one entry per delivered message id for the whole
    # connection lifetime; consider bounding it for long-lived sockets.
    sent_message_ids = set()

    while True:
        try:
            event = await asyncio.wait_for(receive(), timeout=1.2)
        except asyncio.TimeoutError:
            event = None

        event_type = event.get("type") if event else None
        if event_type == "websocket.disconnect":
            break
        if event_type == "websocket.receive":
            body = _decode_client_message(event.get("text"))
            if body.get("kind") == "sync":
                last_ts = max(last_ts, _safe_int(body.get("last_ts"), 0))

        result = await sync_to_async(_load_since)(
            user_id=user_id,
            service=service,
            identifier=identifier,
            person_id=person_id,
            after_ts=last_ts,
            limit=limit,
        )

        # The rolling window in _load_since re-reads recent rows on every
        # poll, so skip ids already delivered on this connection.
        messages = []
        for msg in result.get("messages") or []:
            message_id = str((msg or {}).get("id") or "").strip()
            if message_id:
                if message_id in sent_message_ids:
                    continue
                sent_message_ids.add(message_id)
            messages.append(msg)

        last_ts = max(last_ts, _safe_int(result.get("last_ts"), last_ts))

        if resolved_person_id <= 0:
            resolved_person_id = _safe_int(result.get("person_id"), 0)
        typing_state = get_person_typing_state(
            user_id=user_id,
            person_id=resolved_person_id,
        )
        # Compare a canonical JSON form so only real changes are pushed.
        typing_key = json.dumps(typing_state, sort_keys=True)
        typing_changed = typing_key != last_typing_key
        if typing_changed:
            last_typing_key = typing_key

        if not messages and not typing_changed:
            continue

        outgoing = {}
        if messages:
            outgoing["messages"] = messages
        outgoing["last_ts"] = last_ts
        if typing_changed:
            outgoing["typing"] = typing_state

        await send(
            {
                "type": "websocket.send",
                "text": json.dumps(outgoing),
            }
        )
|