Increase platform abstraction cohesion
This commit is contained in:
123
core/memory/retrieval.py
Normal file
123
core/memory/retrieval.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from django.db.models import Q
|
||||
from django.utils import timezone
|
||||
|
||||
from core.memory.search_backend import get_memory_search_backend
|
||||
from core.models import MemoryItem
|
||||
|
||||
|
||||
def _coerce_statuses(value: Any, default: tuple[str, ...]) -> tuple[str, ...]:
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
items = [str(item or "").strip().lower() for item in value]
|
||||
else:
|
||||
items = [item.strip().lower() for item in str(value or "").split(",")]
|
||||
cleaned = tuple(item for item in items if item)
|
||||
return cleaned or default
|
||||
|
||||
|
||||
def _base_queryset(
|
||||
*,
|
||||
user_id: int,
|
||||
person_id: str = "",
|
||||
conversation_id: str = "",
|
||||
statuses: tuple[str, ...] = ("active",),
|
||||
):
|
||||
now = timezone.now()
|
||||
queryset = MemoryItem.objects.filter(user_id=int(user_id))
|
||||
if statuses:
|
||||
queryset = queryset.filter(status__in=list(statuses))
|
||||
queryset = queryset.filter(Q(expires_at__isnull=True) | Q(expires_at__gt=now))
|
||||
if person_id:
|
||||
queryset = queryset.filter(person_id=person_id)
|
||||
if conversation_id:
|
||||
queryset = queryset.filter(conversation_id=conversation_id)
|
||||
return queryset
|
||||
|
||||
|
||||
def retrieve_memories_for_prompt(
|
||||
*,
|
||||
user_id: int,
|
||||
query: str = "",
|
||||
person_id: str = "",
|
||||
conversation_id: str = "",
|
||||
statuses: tuple[str, ...] = ("active",),
|
||||
limit: int = 20,
|
||||
) -> list[dict[str, Any]]:
|
||||
statuses = _coerce_statuses(statuses, ("active",))
|
||||
safe_limit = max(1, min(200, int(limit or 20)))
|
||||
search_text = str(query or "").strip()
|
||||
|
||||
if search_text:
|
||||
backend = get_memory_search_backend()
|
||||
hits = backend.search(
|
||||
user_id=int(user_id),
|
||||
query=search_text,
|
||||
conversation_id=conversation_id,
|
||||
limit=safe_limit,
|
||||
include_statuses=statuses,
|
||||
)
|
||||
ids = [str(hit.memory_id or "").strip() for hit in hits if str(hit.memory_id or "").strip()]
|
||||
scoped = _base_queryset(
|
||||
user_id=int(user_id),
|
||||
person_id=person_id,
|
||||
conversation_id=conversation_id,
|
||||
statuses=statuses,
|
||||
).filter(id__in=ids)
|
||||
by_id = {str(item.id): item for item in scoped}
|
||||
rows = []
|
||||
for hit in hits:
|
||||
item = by_id.get(str(hit.memory_id))
|
||||
if not item:
|
||||
continue
|
||||
rows.append(
|
||||
{
|
||||
"id": str(item.id),
|
||||
"memory_kind": str(item.memory_kind or ""),
|
||||
"status": str(item.status or ""),
|
||||
"person_id": str(item.person_id or ""),
|
||||
"conversation_id": str(item.conversation_id or ""),
|
||||
"content": item.content or {},
|
||||
"provenance": item.provenance or {},
|
||||
"confidence_score": float(item.confidence_score or 0.0),
|
||||
"expires_at": item.expires_at.isoformat() if item.expires_at else "",
|
||||
"last_verified_at": (
|
||||
item.last_verified_at.isoformat() if item.last_verified_at else ""
|
||||
),
|
||||
"updated_at": item.updated_at.isoformat() if item.updated_at else "",
|
||||
"search_score": float(hit.score or 0.0),
|
||||
"search_summary": str(hit.summary or ""),
|
||||
}
|
||||
)
|
||||
return rows
|
||||
|
||||
queryset = _base_queryset(
|
||||
user_id=int(user_id),
|
||||
person_id=person_id,
|
||||
conversation_id=conversation_id,
|
||||
statuses=statuses,
|
||||
).order_by("-last_verified_at", "-updated_at")
|
||||
rows = []
|
||||
for item in queryset[:safe_limit]:
|
||||
rows.append(
|
||||
{
|
||||
"id": str(item.id),
|
||||
"memory_kind": str(item.memory_kind or ""),
|
||||
"status": str(item.status or ""),
|
||||
"person_id": str(item.person_id or ""),
|
||||
"conversation_id": str(item.conversation_id or ""),
|
||||
"content": item.content or {},
|
||||
"provenance": item.provenance or {},
|
||||
"confidence_score": float(item.confidence_score or 0.0),
|
||||
"expires_at": item.expires_at.isoformat() if item.expires_at else "",
|
||||
"last_verified_at": (
|
||||
item.last_verified_at.isoformat() if item.last_verified_at else ""
|
||||
),
|
||||
"updated_at": item.updated_at.isoformat() if item.updated_at else "",
|
||||
"search_score": 0.0,
|
||||
"search_summary": "",
|
||||
}
|
||||
)
|
||||
return rows
|
||||
Reference in New Issue
Block a user