# GIA/core/memory/search_backend.py
from __future__ import annotations
import hashlib
import json
import time
from dataclasses import dataclass
from typing import Any
import requests
from django.conf import settings
from core.models import MemoryItem
from core.util import logs
log = logs.get_logger("memory-search")
@dataclass
class MemorySearchHit:
    """One search result row returned by a memory search backend."""

    # String form of the matching MemoryItem primary key / UUID.
    memory_id: str
    # Backend-specific relevance score; higher means more relevant.
    score: float
    # Human-readable excerpt of the flattened memory content
    # (backends in this module truncate it to 280 characters).
    summary: str
    # Backend-specific metadata (kind, status, conversation id, timestamp).
    payload: dict[str, Any]
def _flatten_to_text(value: Any) -> str:
if value is None:
return ""
if isinstance(value, dict):
parts = []
for key, item in value.items():
parts.append(str(key))
parts.append(_flatten_to_text(item))
return " ".join(part for part in parts if part).strip()
if isinstance(value, (list, tuple, set)):
return " ".join(_flatten_to_text(item) for item in value if item).strip()
return str(value).strip()
class BaseMemorySearchBackend:
    """Abstract interface for memory search backends.

    Concrete backends implement upsert/delete/search. reindex() is a shared
    default that replays database rows through upsert() one at a time.
    """

    def upsert(self, item: MemoryItem) -> None:
        """Index (insert or replace) one memory item."""
        raise NotImplementedError

    def delete(self, memory_id: str) -> None:
        """Remove one memory item from the index."""
        raise NotImplementedError

    def search(
        self,
        *,
        user_id: int,
        query: str,
        conversation_id: str = "",
        limit: int = 20,
        include_statuses: tuple[str, ...] = ("active",),
    ) -> list[MemorySearchHit]:
        """Return ranked hits for *query* scoped to one user."""
        raise NotImplementedError

    def reindex(
        self,
        *,
        user_id: int | None = None,
        include_statuses: tuple[str, ...] = ("active",),
        limit: int = 2000,
    ) -> dict[str, int]:
        """Re-feed up to *limit* rows (newest first) through upsert().

        Returns counters: ``scanned`` rows examined and ``indexed`` rows
        successfully upserted.
        """
        rows = MemoryItem.objects.all().order_by("-updated_at")
        if user_id is not None:
            rows = rows.filter(user_id=int(user_id))
        if include_statuses:
            rows = rows.filter(status__in=list(include_statuses))
        scanned = 0
        indexed = 0
        for row in rows[: max(1, int(limit))]:
            scanned += 1
            try:
                self.upsert(row)
            except Exception as exc:
                # Best effort: one bad row must not abort the whole pass.
                log.warning("memory-search upsert failed id=%s err=%s", row.id, exc)
            else:
                indexed += 1
        return {"scanned": scanned, "indexed": indexed}
class DjangoMemorySearchBackend(BaseMemorySearchBackend):
    """Fallback backend that substring-scans MemoryItem rows via the ORM.

    No external index is maintained: upsert/delete are no-ops and search
    reads the source-of-truth rows directly.
    """

    name = "django"

    def upsert(self, item: MemoryItem) -> None:
        # No-op because Django backend queries source-of-truth rows directly.
        _ = item

    def delete(self, memory_id: str) -> None:
        # No-op: nothing is mirrored, so there is nothing to remove.
        _ = memory_id

    def search(
        self,
        *,
        user_id: int,
        query: str,
        conversation_id: str = "",
        limit: int = 20,
        include_statuses: tuple[str, ...] = ("active",),
    ) -> list[MemorySearchHit]:
        """Case-insensitive substring search over the newest 500 candidate rows.

        Returns at most *limit* hits, newest first. Empty/whitespace queries
        return no hits.
        """
        needle = str(query or "").strip().lower()
        if not needle:
            return []
        queryset = MemoryItem.objects.filter(user_id=int(user_id))
        if conversation_id:
            queryset = queryset.filter(conversation_id=conversation_id)
        if include_statuses:
            queryset = queryset.filter(status__in=list(include_statuses))
        max_hits = max(1, int(limit))  # hoisted: loop-invariant
        hits: list[MemorySearchHit] = []
        for item in queryset.order_by("-updated_at")[:500]:
            # Flatten once per row; the original flattened twice (lowered blob
            # for matching plus a raw copy for the summary).
            raw_text = _flatten_to_text(item.content)
            text_blob = raw_text.lower()
            if needle not in text_blob:
                continue
            # Crude relevance: a needle that covers more of the document
            # scores higher; always >= 1.0 so hits sort above misses.
            score = 1.0 + min(1.0, len(needle) / max(1.0, len(text_blob)))
            hits.append(
                MemorySearchHit(
                    memory_id=str(item.id),
                    score=float(score),
                    summary=raw_text[:280],
                    payload={
                        "memory_kind": str(item.memory_kind or ""),
                        "status": str(item.status or ""),
                        "conversation_id": str(item.conversation_id or ""),
                        "updated_at": item.updated_at.isoformat(),
                    },
                )
            )
            if len(hits) >= max_hits:
                break
        return hits
class ManticoreMemorySearchBackend(BaseMemorySearchBackend):
    """Backend that mirrors MemoryItem rows into a Manticore full-text table.

    All statements go through Manticore's HTTP SQL endpoint (``/sql`` with
    ``mode=raw``). Each MemoryItem maps to one document whose integer id is a
    64-bit hash of the memory UUID.
    """

    name = "manticore"

    # Column list used by every REPLACE statement; must stay in sync with the
    # schema created in ensure_table(). (Previously written out three times.)
    _REPLACE_COLUMNS = (
        "(id,memory_uuid,user_id,conversation_id,memory_kind,status,"
        "updated_ts,summary,body)"
    )

    # Process-wide "table already ensured" timestamps, keyed by base_url|table
    # so every instance shares the same TTL window.
    _table_ready_cache: dict[str, float] = {}
    _table_ready_ttl_seconds = 30.0

    def __init__(self):
        # Connection settings come from Django settings with safe fallbacks.
        self.base_url = str(
            getattr(settings, "MANTICORE_HTTP_URL", "http://localhost:9308")
        ).rstrip("/")
        self.table = str(
            getattr(settings, "MANTICORE_MEMORY_TABLE", "gia_memory_items")
        ).strip() or "gia_memory_items"
        self.timeout_seconds = int(getattr(settings, "MANTICORE_HTTP_TIMEOUT", 5) or 5)
        self._table_cache_key = f"{self.base_url}|{self.table}"

    def _sql(self, query: str) -> dict[str, Any]:
        """POST one SQL statement to Manticore and return the first result object.

        Raises ``requests.HTTPError`` on non-2xx responses.
        """
        response = requests.post(
            f"{self.base_url}/sql",
            data={"mode": "raw", "query": query},
            timeout=self.timeout_seconds,
        )
        response.raise_for_status()
        payload = response.json()
        # mode=raw responses arrive as a list of per-statement result objects.
        if isinstance(payload, list):
            return payload[0] if payload else {}
        return dict(payload or {})

    def ensure_table(self) -> None:
        """Create the index table if needed, at most once per TTL per endpoint."""
        last_ready = float(self._table_ready_cache.get(self._table_cache_key, 0.0) or 0.0)
        if (time.time() - last_ready) <= float(self._table_ready_ttl_seconds):
            return
        self._sql(
            (
                f"CREATE TABLE IF NOT EXISTS {self.table} ("
                "id BIGINT,"
                "memory_uuid STRING,"
                "user_id BIGINT,"
                "conversation_id STRING,"
                "memory_kind STRING,"
                "status STRING,"
                "updated_ts BIGINT,"
                "summary TEXT,"
                "body TEXT"
                ")"
            )
        )
        self._table_ready_cache[self._table_cache_key] = time.time()

    def _doc_id(self, memory_id: str) -> int:
        """Derive a stable positive document id from the memory UUID."""
        digest = hashlib.blake2b(
            str(memory_id or "").encode("utf-8"),
            digest_size=8,
        ).digest()
        value = int.from_bytes(digest, byteorder="big", signed=False)
        # NOTE(review): an unsigned 64-bit hash can exceed the signed BIGINT
        # range; confirm the Manticore build accepts ids above 2**63 - 1.
        return max(1, int(value))

    def _escape(self, value: Any) -> str:
        """Escape backslashes and single quotes for SQL string literals.

        Only these two characters are handled; see the MATCH note in search().
        """
        text = str(value or "")
        text = text.replace("\\", "\\\\").replace("'", "\\'")
        return text

    def upsert(self, item: MemoryItem) -> None:
        """Insert-or-replace a single MemoryItem document."""
        self.ensure_table()
        # Delegate row rendering to the shared builder so the single-row path
        # can never drift from the batched reindex() path (it was duplicated).
        clause = self._build_upsert_values_clause(item)
        self._sql(
            f"REPLACE INTO {self.table} "
            f"{self._REPLACE_COLUMNS} "
            f"VALUES {clause}"
        )

    def _build_upsert_values_clause(self, item: MemoryItem) -> str:
        """Render one parenthesized VALUES tuple for *item*.

        Field order must match ``_REPLACE_COLUMNS``.
        """
        memory_id = str(item.id)
        doc_id = self._doc_id(memory_id)
        summary = _flatten_to_text(item.content)[:280]
        body = _flatten_to_text(item.content)
        updated_ts = int(item.updated_at.timestamp() * 1000)  # epoch millis
        return (
            f"({doc_id},'{self._escape(memory_id)}',{int(item.user_id)},"
            f"'{self._escape(item.conversation_id)}','{self._escape(item.memory_kind)}',"
            f"'{self._escape(item.status)}',{updated_ts},"
            f"'{self._escape(summary)}','{self._escape(body)}')"
        )

    def delete(self, memory_id: str) -> None:
        """Remove the document for *memory_id* (no error if it is absent)."""
        self.ensure_table()
        doc_id = self._doc_id(memory_id)
        self._sql(f"DELETE FROM {self.table} WHERE id={doc_id}")

    def _flush_batch(self, values: list[str]) -> int:
        """Issue one multi-row REPLACE for *values*; return the row count sent."""
        self._sql(
            f"REPLACE INTO {self.table} "
            f"{self._REPLACE_COLUMNS} "
            f"VALUES {','.join(values)}"
        )
        return len(values)

    def reindex(
        self,
        *,
        user_id: int | None = None,
        include_statuses: tuple[str, ...] = ("active",),
        limit: int = 2000,
    ) -> dict[str, int]:
        """Bulk re-mirror up to *limit* rows (newest first) in batches of 100.

        Returns counters: ``scanned`` rows examined and ``indexed`` rows sent.
        """
        self.ensure_table()
        queryset = MemoryItem.objects.all().order_by("-updated_at")
        if user_id is not None:
            queryset = queryset.filter(user_id=int(user_id))
        if include_statuses:
            queryset = queryset.filter(status__in=list(include_statuses))
        scanned = 0
        indexed = 0
        batch_size = 100
        values: list[str] = []
        for item in queryset[: max(1, int(limit))]:
            scanned += 1
            try:
                values.append(self._build_upsert_values_clause(item))
            except Exception as exc:
                # Best effort: a single bad row must not abort the pass.
                log.warning("memory-search upsert build failed id=%s err=%s", item.id, exc)
                continue
            if len(values) >= batch_size:
                indexed += self._flush_batch(values)
                values = []
        if values:
            indexed += self._flush_batch(values)
        return {"scanned": scanned, "indexed": indexed}

    def search(
        self,
        *,
        user_id: int,
        query: str,
        conversation_id: str = "",
        limit: int = 20,
        include_statuses: tuple[str, ...] = ("active",),
    ) -> list[MemorySearchHit]:
        """Full-text MATCH search scoped to one user, ranked by WEIGHT().

        Empty/whitespace queries return no hits.
        """
        self.ensure_table()
        needle = str(query or "").strip()
        if not needle:
            return []
        # NOTE(review): _escape only neutralizes quotes/backslashes; Manticore
        # MATCH operators in user text (e.g. @field, "", |) pass through and can
        # alter match semantics — confirm whether that is acceptable here.
        where_parts = [f"user_id={int(user_id)}", f"MATCH('{self._escape(needle)}')"]
        if conversation_id:
            where_parts.append(f"conversation_id='{self._escape(conversation_id)}'")
        statuses = [str(item or "").strip() for item in include_statuses if str(item or "").strip()]
        if statuses:
            in_clause = ",".join(f"'{self._escape(item)}'" for item in statuses)
            where_parts.append(f"status IN ({in_clause})")
        where_sql = " AND ".join(where_parts)
        query_sql = (
            f"SELECT memory_uuid,memory_kind,status,conversation_id,updated_ts,summary,WEIGHT() AS score "
            f"FROM {self.table} WHERE {where_sql} ORDER BY score DESC LIMIT {max(1, int(limit))}"
        )
        payload = self._sql(query_sql)
        rows = list(payload.get("data") or [])
        hits: list[MemorySearchHit] = []
        for row in rows:
            item = dict(row or {})
            hits.append(
                MemorySearchHit(
                    memory_id=str(item.get("memory_uuid") or ""),
                    score=float(item.get("score") or 0.0),
                    summary=str(item.get("summary") or ""),
                    payload={
                        "memory_kind": str(item.get("memory_kind") or ""),
                        "status": str(item.get("status") or ""),
                        "conversation_id": str(item.get("conversation_id") or ""),
                        "updated_ts": int(item.get("updated_ts") or 0),
                    },
                )
            )
        return hits
def get_memory_search_backend() -> BaseMemorySearchBackend:
    """Instantiate the backend selected by ``settings.MEMORY_SEARCH_BACKEND``.

    Any value other than "manticore" (including missing/blank) falls back to
    the Django ORM backend.
    """
    wants_manticore = (
        str(getattr(settings, "MEMORY_SEARCH_BACKEND", "django")).strip().lower()
        == "manticore"
    )
    return ManticoreMemorySearchBackend() if wants_manticore else DjangoMemorySearchBackend()
def backend_status() -> dict[str, Any]:
    """Describe the active backend and, for Manticore, probe its reachability.

    The report always carries backend name, an ``ok`` flag, and a millisecond
    timestamp; a failed Manticore probe flips ``ok`` and records the error.
    """
    active = get_memory_search_backend()
    report: dict[str, Any] = {
        "backend": getattr(active, "name", "unknown"),
        "ok": True,
        "ts": int(time.time() * 1000),
    }
    if isinstance(active, ManticoreMemorySearchBackend):
        try:
            active.ensure_table()
        except Exception as exc:
            report["ok"] = False
            report["error"] = str(exc)
        else:
            report["manticore_http_url"] = active.base_url
            report["manticore_table"] = active.table
    return report