Files
GIA/core/memory/retrieval.py

124 lines
4.4 KiB
Python

from __future__ import annotations

from typing import Any

from django.db.models import F, Q
from django.utils import timezone

from core.memory.search_backend import get_memory_search_backend
from core.models import MemoryItem
def _coerce_statuses(value: Any, default: tuple[str, ...]) -> tuple[str, ...]:
if isinstance(value, (list, tuple, set)):
items = [str(item or "").strip().lower() for item in value]
else:
items = [item.strip().lower() for item in str(value or "").split(",")]
cleaned = tuple(item for item in items if item)
return cleaned or default
def _base_queryset(
    *,
    user_id: int,
    person_id: str = "",
    conversation_id: str = "",
    statuses: tuple[str, ...] = ("active",),
):
    """Build the shared ``MemoryItem`` queryset used by every retrieval path.

    Rows are limited to ``user_id`` and to items whose ``expires_at`` is
    unset or later than the current time. ``statuses``, ``person_id`` and
    ``conversation_id`` each narrow the result further, but only when
    truthy; an empty value leaves that dimension unfiltered.
    """
    cutoff = timezone.now()
    rows = MemoryItem.objects.filter(user_id=int(user_id))
    if statuses:
        rows = rows.filter(status__in=list(statuses))
    still_valid = Q(expires_at__isnull=True) | Q(expires_at__gt=cutoff)
    rows = rows.filter(still_valid)
    # Optional exact-match scopes; skipped when the caller passes "".
    for field_name, wanted in (
        ("person_id", person_id),
        ("conversation_id", conversation_id),
    ):
        if wanted:
            rows = rows.filter(**{field_name: wanted})
    return rows
def _serialize_memory_item(
    item: MemoryItem,
    *,
    search_score: Any = 0.0,
    search_summary: Any = "",
) -> dict[str, Any]:
    """Convert one ``MemoryItem`` into the plain-dict row shape returned to callers.

    Missing/falsy fields are replaced with ``""``, ``{}`` or ``0.0`` so every
    row has the same keys and JSON-friendly values; datetimes become ISO
    strings. ``search_score``/``search_summary`` are only meaningful for rows
    produced by the search backend and default to ``0.0``/``""`` otherwise.
    """
    return {
        "id": str(item.id),
        "memory_kind": str(item.memory_kind or ""),
        "status": str(item.status or ""),
        "person_id": str(item.person_id or ""),
        "conversation_id": str(item.conversation_id or ""),
        "content": item.content or {},
        "provenance": item.provenance or {},
        "confidence_score": float(item.confidence_score or 0.0),
        "expires_at": item.expires_at.isoformat() if item.expires_at else "",
        "last_verified_at": (
            item.last_verified_at.isoformat() if item.last_verified_at else ""
        ),
        "updated_at": item.updated_at.isoformat() if item.updated_at else "",
        "search_score": float(search_score or 0.0),
        "search_summary": str(search_summary or ""),
    }


def _search_memory_rows(
    *,
    user_id: int,
    search_text: str,
    person_id: str,
    conversation_id: str,
    statuses: tuple[str, ...],
    limit: int,
) -> list[dict[str, Any]]:
    """Rank memories with the search backend and keep only hits still in scope."""
    backend = get_memory_search_backend()
    # NOTE(review): person_id is not forwarded to the backend (its signature
    # is not visible here), so person scoping happens only after the backend
    # has applied ``limit``; person-scoped calls may therefore return fewer
    # than ``limit`` rows even when more matching memories exist.
    hits = backend.search(
        user_id=user_id,
        query=search_text,
        conversation_id=conversation_id,
        limit=limit,
        include_statuses=statuses,
    )
    # Normalise each hit's id exactly once so the ``id__in`` lookup and the
    # dict lookup below use the same key. Hits without an id are dropped, and
    # repeated ids keep only their first (highest-ranked) occurrence so the
    # prompt never receives the same memory twice.
    ranked: list[tuple[str, Any]] = []
    seen: set[str] = set()
    for hit in hits:
        memory_id = str(hit.memory_id or "").strip()
        if not memory_id or memory_id in seen:
            continue
        seen.add(memory_id)
        ranked.append((memory_id, hit))
    if not ranked:
        return []
    # Re-apply the user/person/conversation/status/expiry scope in the
    # database; hits whose row is out of scope (or missing) are dropped.
    scoped = _base_queryset(
        user_id=user_id,
        person_id=person_id,
        conversation_id=conversation_id,
        statuses=statuses,
    ).filter(id__in=[memory_id for memory_id, _ in ranked])
    by_id = {str(item.id): item for item in scoped}
    return [
        _serialize_memory_item(
            by_id[memory_id],
            search_score=hit.score,
            search_summary=hit.summary,
        )
        for memory_id, hit in ranked
        if memory_id in by_id
    ]


def _recent_memory_rows(
    *,
    user_id: int,
    person_id: str,
    conversation_id: str,
    statuses: tuple[str, ...],
    limit: int,
) -> list[dict[str, Any]]:
    """List in-scope memories, most recently verified/updated first, without search."""
    queryset = _base_queryset(
        user_id=user_id,
        person_id=person_id,
        conversation_id=conversation_id,
        statuses=statuses,
    ).order_by(
        # nulls_last keeps never-verified items behind verified ones on every
        # database; PostgreSQL would otherwise sort NULLs first under DESC.
        F("last_verified_at").desc(nulls_last=True),
        "-updated_at",
    )
    return [_serialize_memory_item(item) for item in queryset[:limit]]


def retrieve_memories_for_prompt(
    *,
    user_id: int,
    query: str = "",
    person_id: str = "",
    conversation_id: str = "",
    statuses: tuple[str, ...] = ("active",),
    limit: int = 20,
) -> list[dict[str, Any]]:
    """Return serialized memories for ``user_id`` suitable for prompt injection.

    Args:
        user_id: Owner of the memories; coerced with ``int``.
        query: Free-text search. When non-blank, rows come from the memory
            search backend in the order it ranked them; otherwise the most
            recently verified (then updated) memories are returned.
        person_id: Optional exact-match scope; ``""`` disables it.
        conversation_id: Optional exact-match scope; ``""`` disables it.
        statuses: Allowed statuses, as a sequence or comma-separated string;
            falls back to ``("active",)`` when empty.
        limit: Maximum number of rows, clamped to ``1..200`` (falsy -> 20).

    Returns:
        A list of dicts with keys ``id``, ``memory_kind``, ``status``,
        ``person_id``, ``conversation_id``, ``content``, ``provenance``,
        ``confidence_score``, ``expires_at``, ``last_verified_at``,
        ``updated_at``, ``search_score`` and ``search_summary``. Expired
        memories are never included.

    Raises:
        ValueError / TypeError: if ``user_id`` or ``limit`` cannot be
            converted with ``int``.
    """
    statuses = _coerce_statuses(statuses, ("active",))
    safe_limit = max(1, min(200, int(limit or 20)))
    search_text = str(query or "").strip()
    if search_text:
        return _search_memory_rows(
            user_id=int(user_id),
            search_text=search_text,
            person_id=person_id,
            conversation_id=conversation_id,
            statuses=statuses,
            limit=safe_limit,
        )
    return _recent_memory_rows(
        user_id=int(user_id),
        person_id=person_id,
        conversation_id=conversation_id,
        statuses=statuses,
        limit=safe_limit,
    )