552 lines
18 KiB
Python
552 lines
18 KiB
Python
from asgiref.sync import sync_to_async
|
|
from django.conf import settings
|
|
|
|
from core.messaging.utils import messages_to_string
|
|
from core.models import ChatSession, Message, QueuedMessage
|
|
from core.util import logs
|
|
|
|
log = logs.get_logger("history")
|
|
|
|
# Prompt-window controls:
# - Full message history is always persisted in the database.
# - Only the prompt input window is reduced.
# - Max values are hard safety rails; runtime chooses a smaller adaptive subset.
# - Min value prevents overly aggressive clipping on very long average messages.
DEFAULT_PROMPT_HISTORY_MAX_MESSAGES = getattr(
    settings, "PROMPT_HISTORY_MAX_MESSAGES", 120
)
DEFAULT_PROMPT_HISTORY_MAX_CHARS = getattr(
    settings, "PROMPT_HISTORY_MAX_CHARS", 24000
)
DEFAULT_PROMPT_HISTORY_MIN_MESSAGES = getattr(
    settings, "PROMPT_HISTORY_MIN_MESSAGES", 24
)
|
|
|
|
|
|
def _build_recent_history(messages, max_chars):
|
|
"""
|
|
Build the final prompt transcript under a strict character budget.
|
|
|
|
Method:
|
|
1. Iterate messages from newest to oldest so recency is prioritized.
|
|
2. For each message, estimate the rendered line length exactly as it will
|
|
appear in the prompt transcript.
|
|
3. Stop once adding another line would exceed `max_chars`, while still
|
|
guaranteeing at least one message can be included.
|
|
4. Reverse back to chronological order for readability in prompts.
|
|
"""
|
|
if not messages:
|
|
return ""
|
|
|
|
selected = []
|
|
total_chars = 0
|
|
# Recency-first packing, then reorder to chronological output later.
|
|
for msg in reversed(messages):
|
|
author = msg.custom_author or msg.session.identifier.person.name
|
|
line = f"[{msg.ts}] <{author}> {msg.text}"
|
|
line_len = len(line) + 1
|
|
# Keep at least one line even if it alone exceeds max_chars.
|
|
if selected and (total_chars + line_len) > max_chars:
|
|
break
|
|
selected.append(msg)
|
|
total_chars += line_len
|
|
|
|
selected.reverse()
|
|
return messages_to_string(selected)
|
|
|
|
|
|
def _compute_adaptive_message_limit(messages, max_messages, max_chars):
    """
    Estimate how many recent messages should enter char-budget packing.

    No hand-picked threshold buckets: the estimate is derived from the
    budget itself.
    - Sample the most recent messages (up to 80) as representative of the
      current chat style.
    - Measure *rendered* line lengths (timestamp + author + text + newline),
      not raw text lengths.
    - Convert the char budget into a message budget via the average
      rendered length.
    - Clamp between the configured minimum rail and ``max_messages``.

    Two-stage design: this function (Stage A) picks a count from observed
    message density; ``_build_recent_history`` (Stage B) enforces the exact
    character ceiling. That keeps behavior stable while guaranteeing hard
    prompt budget compliance.
    """
    if not messages:
        return DEFAULT_PROMPT_HISTORY_MIN_MESSAGES

    lengths = []
    for entry in messages[-80:]:
        who = entry.custom_author or entry.session.identifier.person.name
        body = entry.text or ""
        # Mirror the exact line shape used by _build_recent_history.
        lengths.append(len(f"[{entry.ts}] <{who}> {body}") + 1)

    # Defensive denominator: never divide by zero, never go below 1 char.
    if lengths:
        mean_len = max(sum(lengths) / len(lengths), 1.0)
    else:
        mean_len = 1.0

    estimated = int(max_chars / mean_len)
    bounded = min(max_messages, max(DEFAULT_PROMPT_HISTORY_MIN_MESSAGES, estimated))
    return max(1, bounded)
|
|
|
|
|
|
async def get_chat_history(
    session,
    max_messages=DEFAULT_PROMPT_HISTORY_MAX_MESSAGES,
    max_chars=DEFAULT_PROMPT_HISTORY_MAX_CHARS,
):
    """
    Return prompt-ready chat history with adaptive windowing and hard budget limits.

    Pipeline:
    1. Fetch a bounded recent slice from DB (performance guard).
    2. Estimate adaptive message count from observed rendered message density.
    3. Keep only the newest `adaptive_limit` messages.
    4. Pack those lines under `max_chars` exactly.

    :param session: chat session whose messages are rendered.
    :param max_messages: hard upper rail on the prompt message count.
    :param max_chars: hard upper rail on the rendered transcript size.
    :returns: string headed ``"Recent Messages:"`` followed by the transcript.
    """
    # Storage remains complete; only prompt context is reduced.
    fetch_limit = max(max_messages * 3, 200)
    fetch_limit = min(fetch_limit, 1000)
    stored_messages = await sync_to_async(list)(
        Message.objects.filter(session=session, user=session.user)
        # Fix: prefetch the author chain in one JOIN. Rendering accesses
        # msg.session.identifier.person.name per message, which otherwise
        # issues N+1 lazy queries.
        # NOTE(review): assumes session -> identifier -> person are FK
        # relations — confirm against the model definitions.
        .select_related("session__identifier__person")
        .order_by("-ts")[:fetch_limit]
    )
    stored_messages.reverse()
    adaptive_limit = _compute_adaptive_message_limit(
        stored_messages,
        max_messages=max_messages,
        max_chars=max_chars,
    )
    selected_messages = stored_messages[-adaptive_limit:]
    recent_chat_history = _build_recent_history(selected_messages, max_chars=max_chars)
    chat_history = f"Recent Messages:\n{recent_chat_history}"

    return chat_history
|
|
|
|
|
|
async def get_chat_session(user, identifier):
    """Fetch (or lazily create) the ChatSession for this user/identifier pair."""
    get_or_create = sync_to_async(ChatSession.objects.get_or_create)
    session, _created = await get_or_create(identifier=identifier, user=user)
    return session
|
|
|
|
|
|
async def store_message(session, sender, text, ts, outgoing=False):
    """Persist one message row for the given session (outgoing ones as USER)."""
    log.debug("Storing message for session=%s outgoing=%s", session.id, outgoing)
    author_override = "USER" if outgoing else None
    create_row = sync_to_async(Message.objects.create)
    return await create_row(
        user=session.user,
        session=session,
        sender_uuid=sender,
        text=text,
        ts=ts,
        delivered_ts=ts,
        custom_author=author_override,
    )
|
|
|
|
|
|
async def store_own_message(session, text, ts, manip=None, queue=False):
    """Persist a BOT-authored message, optionally into the outgoing queue."""
    log.debug("Storing own message for session=%s queue=%s", session.id, queue)
    fields = {
        "user": session.user,
        "session": session,
        "custom_author": "BOT",
        "text": text,
        "ts": ts,
        "delivered_ts": ts,
    }
    if queue:
        # Queued rows additionally carry the manipulation metadata.
        model = QueuedMessage
        fields["manipulation"] = manip
    else:
        model = Message

    return await sync_to_async(model.objects.create)(**fields)
|
|
|
|
|
|
async def delete_queryset(queryset):
    """Delete a queryset from async code on the thread-sensitive executor."""
    run_delete = sync_to_async(queryset.delete, thread_sensitive=True)
    await run_delete()
|
|
|
|
|
|
async def apply_read_receipts(
    user,
    identifier,
    message_timestamps,
    read_ts=None,
    source_service="signal",
    read_by_identifier="",
    payload=None,
):
    """
    Persist delivery/read metadata for one identifier's messages.

    :param user: owner of the messages to update.
    :param identifier: person identifier whose sessions are matched.
    :param message_timestamps: iterable of message ``ts`` values (int-coercible).
    :param read_ts: optional read timestamp; also used as delivered fallback.
    :param source_service: service that produced the receipt.
    :param read_by_identifier: identifier of the reader, if known.
    :param payload: raw receipt payload, stored under the service key.
    :returns: number of message rows that were updated.
    """
    ts_values = []
    for item in message_timestamps or []:
        try:
            ts_values.append(int(item))
        # Fix: narrowed from bare `except Exception` — only coercion failures
        # should be skipped; other errors must surface.
        except (TypeError, ValueError):
            continue
    if not ts_values:
        return 0

    try:
        read_at = int(read_ts) if read_ts else None
    except (TypeError, ValueError):
        read_at = None

    rows = await sync_to_async(list)(
        Message.objects.filter(
            user=user,
            session__identifier=identifier,
            ts__in=ts_values,
        ).select_related("session")
    )
    updated = 0
    for message in rows:
        dirty = []
        if message.delivered_ts is None:
            # A read receipt implies delivery; fall back to the send ts.
            message.delivered_ts = read_at or message.ts
            dirty.append("delivered_ts")
        # Only advance read_ts forward, never backward.
        if read_at and (message.read_ts is None or read_at > message.read_ts):
            message.read_ts = read_at
            dirty.append("read_ts")
        if source_service and message.read_source_service != source_service:
            message.read_source_service = source_service
            dirty.append("read_source_service")
        if read_by_identifier and message.read_by_identifier != read_by_identifier:
            message.read_by_identifier = read_by_identifier
            dirty.append("read_by_identifier")
        if payload:
            # Merge, keyed per service, so receipts from different bridges coexist.
            receipt_data = dict(message.receipt_payload or {})
            receipt_data[str(source_service)] = payload
            message.receipt_payload = receipt_data
            dirty.append("receipt_payload")
        if dirty:
            await sync_to_async(message.save)(update_fields=dirty)
            updated += 1
    return updated
|
|
|
|
|
|
async def apply_reaction(
    user,
    identifier,
    *,
    target_message_id="",
    target_ts=0,
    emoji="",
    source_service="",
    actor="",
    remove=False,
    payload=None,
):
    """
    Record (or clear) a bridged reaction on a locally stored message.

    Target resolution order:
    1. Exact local message id (``target_message_id``), if given.
    2. Closest-timestamp match within a +/-10s window around ``target_ts``.

    Reactions are stored as a list of dicts under
    ``message.receipt_payload["reactions"]``, deduplicated by the
    (source_service, actor, emoji) triple. ``remove=True`` marks the existing
    entry as removed rather than deleting it.

    Returns the updated Message, or None when no target could be resolved.
    """
    log.debug(
        "reaction-bridge history-apply start user=%s person_identifier=%s target_message_id=%s target_ts=%s source=%s actor=%s remove=%s emoji=%s",
        getattr(user, "id", "-"),
        getattr(identifier, "id", "-"),
        str(target_message_id or "") or "-",
        int(target_ts or 0),
        str(source_service or "") or "-",
        str(actor or "") or "-",
        bool(remove),
        str(emoji or "") or "-",
    )
    queryset = Message.objects.filter(
        user=user,
        session__identifier=identifier,
    ).select_related("session")

    target = None
    target_uuid = str(target_message_id or "").strip()
    if target_uuid:
        # Exact-id lookup first; wrapped in a lambda so the whole chain
        # executes on the sync side.
        target = await sync_to_async(
            lambda: queryset.filter(id=target_uuid).order_by("-ts").first()
        )()

    if target is None:
        try:
            ts_value = int(target_ts or 0)
        except Exception:
            ts_value = 0
        if ts_value > 0:
            # Fallback: nearest message within a +/-10s window (capped at 200
            # candidates); exact-distance ties prefer the newer message.
            lower = ts_value - 10_000
            upper = ts_value + 10_000
            window_rows = await sync_to_async(list)(
                queryset.filter(ts__gte=lower, ts__lte=upper).order_by("ts")[:200]
            )
            if window_rows:
                target = min(
                    window_rows,
                    key=lambda row: (
                        abs(int(row.ts or 0) - ts_value),
                        -int(row.ts or 0),
                    ),
                )
                log.debug(
                    "reaction-bridge history-apply ts-match target_ts=%s picked_message_id=%s picked_ts=%s candidates=%s",
                    ts_value,
                    str(target.id),
                    int(target.ts or 0),
                    len(window_rows),
                )

    if target is None:
        log.warning(
            "reaction-bridge history-apply miss user=%s person_identifier=%s target_message_id=%s target_ts=%s",
            getattr(user, "id", "-"),
            getattr(identifier, "id", "-"),
            str(target_message_id or "") or "-",
            int(target_ts or 0),
        )
        return None

    # Existing reactions list; normalized key identifies one actor's reaction
    # from one service with one emoji.
    reactions = list((target.receipt_payload or {}).get("reactions") or [])
    reaction_key = (
        str(source_service or "").strip().lower(),
        str(actor or "").strip(),
        str(emoji or "").strip(),
    )

    merged = []
    replaced = False
    for item in reactions:
        row = dict(item or {})
        row_key = (
            str(row.get("source_service") or "").strip().lower(),
            str(row.get("actor") or "").strip(),
            str(row.get("emoji") or "").strip(),
        )
        if row_key == reaction_key:
            # Update the matching entry in place (preserves list position).
            row["removed"] = bool(remove)
            row["updated_at"] = int(target_ts or target.ts or 0)
            row["payload"] = dict(payload or {})
            merged.append(row)
            replaced = True
            continue
        merged.append(row)

    if not replaced:
        # First time this (service, actor, emoji) reacts: append a fresh row.
        merged.append(
            {
                "emoji": str(emoji or ""),
                "source_service": str(source_service or ""),
                "actor": str(actor or ""),
                "removed": bool(remove),
                "updated_at": int(target_ts or target.ts or 0),
                "payload": dict(payload or {}),
            }
        )

    # Rebuild receipt_payload as a new dict so the save sees a fresh object.
    receipt_payload = dict(target.receipt_payload or {})
    receipt_payload["reactions"] = merged
    target.receipt_payload = receipt_payload
    await sync_to_async(target.save)(update_fields=["receipt_payload"])
    log.debug(
        "reaction-bridge history-apply ok message_id=%s reactions=%s",
        str(target.id),
        len(merged),
    )
    return target
|
|
|
|
|
|
def _iter_bridge_refs(receipt_payload, source_service):
|
|
payload = dict(receipt_payload or {})
|
|
refs = payload.get("bridge_refs") or {}
|
|
rows = refs.get(str(source_service or "").strip().lower()) or []
|
|
return [dict(row or {}) for row in rows if isinstance(row, dict)]
|
|
|
|
|
|
def _set_bridge_refs(receipt_payload, source_service, rows):
|
|
payload = dict(receipt_payload or {})
|
|
refs = dict(payload.get("bridge_refs") or {})
|
|
refs[str(source_service or "").strip().lower()] = list(rows or [])
|
|
payload["bridge_refs"] = refs
|
|
return payload
|
|
|
|
|
|
async def save_bridge_ref(
    user,
    identifier,
    *,
    source_service,
    local_message_id="",
    local_ts=0,
    xmpp_message_id="",
    upstream_message_id="",
    upstream_author="",
    upstream_ts=0,
):
    """
    Attach an upstream/XMPP message reference to a locally stored message.

    Target resolution mirrors ``apply_reaction``: exact local message id
    first, then the nearest-timestamp message within a +/-10s window.
    Refs are stored per-service under ``receipt_payload["bridge_refs"]``,
    deduplicated by xmpp/upstream message id and capped at 100 per service.

    :returns: dict describing the stored ref (including local id/ts), or
        None when the source service is blank or no local target resolves.
    """
    # TODO(edit-sync): persist upstream edit identifiers/version vectors here so
    # edit operations can target exact upstream message revisions.
    # TODO(delete-sync): persist upstream deletion tombstone metadata here and
    # keep bridge refs resolvable even after local message redaction.
    source_key = str(source_service or "").strip().lower()
    if not source_key:
        return None

    queryset = Message.objects.filter(
        user=user,
        session__identifier=identifier,
    ).select_related("session")

    target = None
    message_id = str(local_message_id or "").strip()
    if message_id:
        target = await sync_to_async(
            lambda: queryset.filter(id=message_id).order_by("-ts").first()
        )()

    if target is None:
        try:
            ts_value = int(local_ts or 0)
        # Fix: narrowed from bare `except Exception` — only coercion failures
        # should be treated as "no usable timestamp".
        except (TypeError, ValueError):
            ts_value = 0
        if ts_value > 0:
            lower = ts_value - 10_000
            upper = ts_value + 10_000
            rows = await sync_to_async(list)(
                queryset.filter(ts__gte=lower, ts__lte=upper).order_by("-ts")[:200]
            )
            if rows:
                # Closest timestamp wins; exact-distance ties prefer newer.
                target = min(
                    rows,
                    key=lambda row: (
                        abs(int(row.ts or 0) - ts_value),
                        -int(row.ts or 0),
                    ),
                )

    if target is None:
        return None

    row = {
        "xmpp_message_id": str(xmpp_message_id or "").strip(),
        "upstream_message_id": str(upstream_message_id or "").strip(),
        "upstream_author": str(upstream_author or "").strip(),
        "upstream_ts": int(upstream_ts or 0),
        "updated_at": int(local_ts or target.ts or 0),
    }

    # Drop any existing ref matching either identifier, then append the fresh
    # row so the newest data wins.
    existing = _iter_bridge_refs(target.receipt_payload or {}, source_key)
    merged = []
    for item in existing:
        same_xmpp = row["xmpp_message_id"] and (
            str(item.get("xmpp_message_id") or "").strip() == row["xmpp_message_id"]
        )
        same_upstream = row["upstream_message_id"] and (
            str(item.get("upstream_message_id") or "").strip()
            == row["upstream_message_id"]
        )
        if same_xmpp or same_upstream:
            continue
        merged.append(item)
    merged.append(row)
    # Cap stored refs per service to bound receipt_payload growth.
    if len(merged) > 100:
        merged = merged[-100:]

    target.receipt_payload = _set_bridge_refs(
        target.receipt_payload or {},
        source_key,
        merged,
    )
    await sync_to_async(target.save)(update_fields=["receipt_payload"])
    return {
        "local_message_id": str(target.id),
        "local_ts": int(target.ts or 0),
        **row,
    }
|
|
|
|
|
|
async def resolve_bridge_ref(
    user,
    identifier,
    *,
    source_service,
    xmpp_message_id="",
    upstream_message_id="",
    upstream_author="",
    upstream_ts=0,
):
    """
    Find the locally stored bridge ref matching an upstream/XMPP message.

    Searches the newest 500 messages for the identifier, in two passes:
    1. Exact ``xmpp_message_id`` or ``upstream_message_id`` match.
    2. Timestamp proximity (gap <= 15s) with author agreement and ref
       freshness as tie-breakers.

    :returns: dict with ``local_message_id``/``local_ts`` plus the stored
        ref fields, or None when nothing matches.
    """
    source_key = str(source_service or "").strip().lower()
    if not source_key:
        return None

    rows = await sync_to_async(list)(
        Message.objects.filter(
            user=user,
            session__identifier=identifier,
        )
        .order_by("-ts")
        .only("id", "ts", "receipt_payload")[:500]
    )

    xmpp_id = str(xmpp_message_id or "").strip()
    upstream_id = str(upstream_message_id or "").strip()
    author = str(upstream_author or "").strip()
    try:
        target_ts = int(upstream_ts or 0)
    # Fix: narrowed from bare `except Exception` — only coercion failures
    # should disable the timestamp-proximity pass.
    except (TypeError, ValueError):
        target_ts = 0

    # 1) exact IDs first
    for message in rows:
        refs = _iter_bridge_refs(message.receipt_payload or {}, source_key)
        for ref in refs:
            if xmpp_id and str(ref.get("xmpp_message_id") or "").strip() == xmpp_id:
                return {
                    "local_message_id": str(message.id),
                    "local_ts": int(message.ts or 0),
                    **dict(ref or {}),
                }
            if upstream_id and (
                str(ref.get("upstream_message_id") or "").strip() == upstream_id
            ):
                return {
                    "local_message_id": str(message.id),
                    "local_ts": int(message.ts or 0),
                    **dict(ref or {}),
                }

    # 2) timestamp proximity with optional author tie-break
    best = None
    best_key = None
    if target_ts > 0:
        for message in rows:
            refs = _iter_bridge_refs(message.receipt_payload or {}, source_key)
            for ref in refs:
                row_ts = int(ref.get("upstream_ts") or 0)
                if row_ts <= 0:
                    continue
                gap = abs(row_ts - target_ts)
                if gap > 15_000:
                    continue
                row_author = str(ref.get("upstream_author") or "").strip()
                # Missing author means "don't penalize"; a mismatch ranks lower.
                author_penalty = 0 if (not author or author == row_author) else 1
                freshness = int(ref.get("updated_at") or 0)
                # Smaller key wins: nearer ts, matching author, fresher ref.
                key = (gap, author_penalty, -freshness)
                if best is None or key < best_key:
                    best = {
                        "local_message_id": str(message.id),
                        "local_ts": int(message.ts or 0),
                        **dict(ref or {}),
                    }
                    best_key = key
    return best
|