Files
GIA/core/realtime/compose_ws.py

194 lines
6.0 KiB
Python

import asyncio
import json
import time
from urllib.parse import parse_qs
from asgiref.sync import sync_to_async
from django.core import signing
from core.models import ChatSession, Message, PersonIdentifier, WorkspaceConversation
from core.realtime.typing_state import get_person_typing_state
from core.views.compose import COMPOSE_WS_TOKEN_SALT, _serialize_messages_with_artifacts
def _safe_int(value, default=0):
try:
return int(value)
except (TypeError, ValueError):
return default
def _load_since(user_id, service, identifier, person_id, after_ts, limit):
    """Load the next page of session messages newer than ``after_ts``.

    The counterpart is resolved to a ``PersonIdentifier`` in this order:
    by ``person_id`` on ``service``, by ``person_id`` on any service, then by
    the (``service``, ``identifier``) pair.

    Args:
        user_id: Owner of the records; every query is scoped to it.
        service: Service name used to pick the identifier.
        identifier: Raw identifier string, used when ``person_id`` does not
            resolve.
        person_id: Person primary key (may be empty).
        after_ts: Only messages with ``ts`` strictly greater than this are
            returned; values ``<= 0`` disable the filter.
        limit: Requested page size; clamped to the range 10..200.

    Returns:
        A dict with keys ``messages`` (serialized rows, oldest first),
        ``last_ts`` (cursor to pass as ``after_ts`` on the next call) and
        ``person_id`` (resolved person id, ``0`` when nothing resolved).
    """
    person_identifier = None
    if person_id:
        # Prefer the identifier on the requested service, then fall back to
        # any identifier belonging to the same person.
        person_identifier = (
            PersonIdentifier.objects.filter(
                user_id=user_id,
                person_id=person_id,
                service=service,
            ).first()
            or PersonIdentifier.objects.filter(
                user_id=user_id,
                person_id=person_id,
            ).first()
        )
    if person_identifier is None and identifier:
        person_identifier = PersonIdentifier.objects.filter(
            user_id=user_id,
            service=service,
            identifier=identifier,
        ).first()
    if person_identifier is None:
        return {"messages": [], "last_ts": after_ts, "person_id": 0}

    session = ChatSession.objects.filter(
        user_id=user_id,
        identifier=person_identifier,
    ).first()
    if session is None:
        return {
            "messages": [],
            "last_ts": after_ts,
            "person_id": int(person_identifier.person_id),
        }

    # Querysets are lazy and chaining returns new objects, so one base
    # queryset can safely back every message lookup below.
    session_messages = Message.objects.filter(user_id=user_id, session=session)
    qs = session_messages.order_by("ts")
    seed_previous = None
    if after_ts > 0:
        # Last message at or before the cursor, handed to the serializer as
        # the predecessor of the first returned row.
        seed_previous = (
            session_messages.filter(ts__lte=after_ts).order_by("-ts").first()
        )
        qs = qs.filter(ts__gt=after_ts)
    rows = list(qs[: max(10, min(limit, 200))])

    if rows:
        # Advance the cursor only as far as the rows actually returned. Using
        # the newest timestamp in the session here would skip every message
        # beyond a truncated page (and any message inserted between the two
        # queries) on the next call.
        last_ts = rows[-1].ts or after_ts or 0
    else:
        newest = (
            session_messages.order_by("-ts").values_list("ts", flat=True).first()
        )
        last_ts = newest or after_ts or 0

    conversation = (
        WorkspaceConversation.objects.filter(
            user_id=user_id,
            participants__id=person_identifier.person_id,
        )
        .order_by("-last_event_ts", "-created_at")
        .first()
    )
    # Every non-blank identifier of the counterpart, whitespace-trimmed.
    counterpart_identifiers = {
        str(value or "").strip()
        for value in PersonIdentifier.objects.filter(
            user_id=user_id,
            person_id=person_identifier.person_id,
        ).values_list("identifier", flat=True)
        if str(value or "").strip()
    }
    return {
        "messages": _serialize_messages_with_artifacts(
            rows,
            counterpart_identifiers=counterpart_identifiers,
            conversation=conversation,
            seed_previous=seed_previous,
        ),
        "last_ts": int(last_ts),
        "person_id": int(person_identifier.person_id),
    }
async def compose_ws_application(scope, receive, send):
    """ASGI websocket endpoint that streams new messages and typing state.

    The connection is authenticated by a signed ``token`` query parameter
    (salt ``COMPOSE_WS_TOKEN_SALT``) whose payload carries ``exp`` (expiry,
    epoch seconds), ``u`` (user id), ``s`` (service), ``i`` (identifier) and
    ``p`` (person id). Invalid, expired or incomplete tokens close the socket
    with code 4401.

    After accepting, the handler waits up to 1.2 seconds for a client frame,
    then loads messages newer than its cursor. A client frame
    ``{"kind": "sync", "last_ts": N}`` can move the cursor forward. A JSON
    text frame with ``last_ts`` plus optional ``messages`` and ``typing`` is
    sent whenever there are new messages or the typing state changed.

    Args:
        scope: ASGI connection scope; non-websocket scopes are ignored.
        receive: ASGI receive callable.
        send: ASGI send callable.
    """
    if scope.get("type") != "websocket":
        return

    query_string = (scope.get("query_string") or b"").decode("utf-8", errors="ignore")
    token = (parse_qs(query_string).get("token") or [""])[0]
    try:
        payload = signing.loads(token, salt=COMPOSE_WS_TOKEN_SALT)
    except Exception:
        await send({"type": "websocket.close", "code": 4401})
        return
    # A token that decodes to something other than a mapping is treated the
    # same as an expired one instead of raising on ``.get``.
    if not isinstance(payload, dict) or _safe_int(payload.get("exp")) < int(time.time()):
        await send({"type": "websocket.close", "code": 4401})
        return

    user_id = _safe_int(payload.get("u"))
    service = str(payload.get("s") or "").strip()
    identifier = str(payload.get("i") or "").strip()
    person_id = str(payload.get("p") or "").strip()
    resolved_person_id = _safe_int(person_id)
    if user_id <= 0 or (not identifier and not person_id):
        await send({"type": "websocket.close", "code": 4401})
        return

    await send({"type": "websocket.accept"})

    last_ts = 0
    limit = 100
    last_typing_key = ""
    while True:
        # The timeout turns the receive loop into a poll: no frame within
        # 1.2 s still falls through to the load below.
        try:
            event = await asyncio.wait_for(receive(), timeout=1.2)
        except asyncio.TimeoutError:
            event = None
        event_type = event.get("type") if event else None

        if event_type == "websocket.disconnect":
            break
        if event_type == "websocket.receive":
            try:
                body = json.loads(event.get("text") or "{}")
            except (ValueError, TypeError, RecursionError):
                body = {}
            # Valid JSON that is not an object (e.g. ``[]`` or ``"x"``) has
            # no ``.get``; treat it like an empty frame rather than crashing.
            if not isinstance(body, dict):
                body = {}
            if body.get("kind") == "sync":
                # The cursor only ever moves forward.
                last_ts = max(last_ts, _safe_int(body.get("last_ts"), 0))

        result = await sync_to_async(_load_since)(
            user_id=user_id,
            service=service,
            identifier=identifier,
            person_id=person_id,
            after_ts=last_ts,
            limit=limit,
        )
        messages = result.get("messages") or []
        latest = _safe_int(result.get("last_ts"), last_ts)
        if resolved_person_id <= 0:
            # The token carried no usable person id; adopt the one resolved
            # from the identifier lookup.
            resolved_person_id = _safe_int(result.get("person_id"), 0)

        typing_state = get_person_typing_state(
            user_id=user_id,
            person_id=resolved_person_id,
        )
        # A canonical JSON dump is used as a cheap change-detection key.
        typing_key = json.dumps(typing_state, sort_keys=True)
        typing_changed = typing_key != last_typing_key
        if typing_changed:
            last_typing_key = typing_key

        last_ts = max(last_ts, latest)
        if messages or typing_changed:
            outgoing_payload = {}
            if messages:
                outgoing_payload["messages"] = messages
            outgoing_payload["last_ts"] = last_ts
            if typing_changed:
                outgoing_payload["typing"] = typing_state
            await send(
                {
                    "type": "websocket.send",
                    "text": json.dumps(outgoing_payload),
                }
            )