# File: GIA/core/memory/search_backend.py
from __future__ import annotations
import hashlib
import time
from dataclasses import dataclass
from typing import Any
import requests
from django.conf import settings
from core.models import MemoryItem
from core.util import logs
# Module-level logger shared by every backend in this file.
log = logs.get_logger("memory-search")
@dataclass
class MemorySearchHit:
    """A single search result returned by a memory search backend."""

    # UUID/PK of the matching MemoryItem, as a string.
    memory_id: str
    # Backend-specific relevance score (higher ranks first).
    score: float
    # Short human-readable preview of the memory content (<= 280 chars).
    summary: str
    # Backend-provided metadata: kind, status, conversation id, timestamp.
    payload: dict[str, Any]
def _flatten_to_text(value: Any) -> str:
if value is None:
return ""
if isinstance(value, dict):
parts = []
for key, item in value.items():
parts.append(str(key))
parts.append(_flatten_to_text(item))
return " ".join(part for part in parts if part).strip()
if isinstance(value, (list, tuple, set)):
return " ".join(_flatten_to_text(item) for item in value if item).strip()
return str(value).strip()
class BaseMemorySearchBackend:
    """Abstract interface for pluggable memory-search backends.

    Subclasses implement upsert/delete/search; reindex() is a shared
    best-effort driver that replays recent rows through upsert().
    """

    def upsert(self, item: MemoryItem) -> None:
        """Index (or re-index) a single memory item."""
        raise NotImplementedError

    def delete(self, memory_id: str) -> None:
        """Remove a memory item from the index."""
        raise NotImplementedError

    def search(
        self,
        *,
        user_id: int,
        query: str,
        conversation_id: str = "",
        limit: int = 20,
        include_statuses: tuple[str, ...] = ("active",),
    ) -> list[MemorySearchHit]:
        """Return scored hits for ``query``, scoped to one user."""
        raise NotImplementedError

    def reindex(
        self,
        *,
        user_id: int | None = None,
        include_statuses: tuple[str, ...] = ("active",),
        limit: int = 2000,
    ) -> dict[str, int]:
        """Replay up to ``limit`` recent rows through upsert().

        Returns {"scanned": <rows visited>, "indexed": <rows upserted>}.
        Failures on individual rows are logged and skipped.
        """
        rows = MemoryItem.objects.all().order_by("-updated_at")
        if user_id is not None:
            rows = rows.filter(user_id=int(user_id))
        if include_statuses:
            rows = rows.filter(status__in=list(include_statuses))
        counters = {"scanned": 0, "indexed": 0}
        for row in rows[: max(1, int(limit))]:
            counters["scanned"] += 1
            try:
                self.upsert(row)
            except Exception as exc:  # best-effort: keep going on bad rows
                log.warning("memory-search upsert failed id=%s err=%s", row.id, exc)
            else:
                counters["indexed"] += 1
        return counters
class DjangoMemorySearchBackend(BaseMemorySearchBackend):
    """Fallback backend that scans MemoryItem rows directly via the ORM.

    Nothing is mirrored into an external index, so upsert/delete are no-ops
    and search() performs a case-insensitive substring scan over up to 500
    of the most recently updated rows.
    """

    name = "django"

    def upsert(self, item: MemoryItem) -> None:
        # No-op because Django backend queries source-of-truth rows directly.
        _ = item

    def delete(self, memory_id: str) -> None:
        _ = memory_id

    def search(
        self,
        *,
        user_id: int,
        query: str,
        conversation_id: str = "",
        limit: int = 20,
        include_statuses: tuple[str, ...] = ("active",),
    ) -> list[MemorySearchHit]:
        """Substring search; returns at most ``limit`` hits (scan cap: 500 rows)."""
        needle = str(query or "").strip().lower()
        if not needle:
            return []
        queryset = MemoryItem.objects.filter(user_id=int(user_id))
        if conversation_id:
            queryset = queryset.filter(conversation_id=conversation_id)
        if include_statuses:
            queryset = queryset.filter(status__in=list(include_statuses))
        hits: list[MemorySearchHit] = []
        for item in queryset.order_by("-updated_at")[:500]:
            # Flatten once per row; previous version flattened twice per match.
            raw_text = _flatten_to_text(item.content)
            text_blob = raw_text.lower()
            if needle not in text_blob:
                continue
            # Crude relevance: longer needles relative to the document rank
            # slightly higher; score is always in (1.0, 2.0].
            score = 1.0 + min(1.0, len(needle) / max(1.0, len(text_blob)))
            hits.append(
                MemorySearchHit(
                    memory_id=str(item.id),
                    score=float(score),
                    summary=raw_text[:280],
                    payload={
                        "memory_kind": str(item.memory_kind or ""),
                        "status": str(item.status or ""),
                        "conversation_id": str(item.conversation_id or ""),
                        "updated_at": item.updated_at.isoformat(),
                    },
                )
            )
            if len(hits) >= max(1, int(limit)):
                break
        return hits
class ManticoreMemorySearchBackend(BaseMemorySearchBackend):
    """Backend that mirrors MemoryItem rows into a Manticore full-text table.

    Documents are addressed by a stable 64-bit blake2b hash of the memory
    UUID, so REPLACE/DELETE are idempotent per memory item.
    """

    name = "manticore"

    # Single source of truth for the REPLACE column list; order must match
    # the tuple produced by _build_upsert_values_clause(). Previously this
    # string was duplicated at three call sites.
    _COLUMNS_SQL = (
        "(id,memory_uuid,user_id,conversation_id,memory_kind,status,"
        "updated_ts,summary,body)"
    )

    # Class-level (shared across instances) record of when each
    # (base_url, table) pair was last confirmed to exist, so ensure_table()
    # can skip the DDL round-trip within the TTL window.
    _table_ready_cache: dict[str, float] = {}
    _table_ready_ttl_seconds = 30.0

    def __init__(self):
        self.base_url = str(
            getattr(settings, "MANTICORE_HTTP_URL", "http://localhost:9308")
        ).rstrip("/")
        self.table = (
            str(getattr(settings, "MANTICORE_MEMORY_TABLE", "gia_memory_items")).strip()
            or "gia_memory_items"
        )
        self.timeout_seconds = int(getattr(settings, "MANTICORE_HTTP_TIMEOUT", 5) or 5)
        self._table_cache_key = f"{self.base_url}|{self.table}"

    def _sql(self, query: str) -> dict[str, Any]:
        """POST one SQL statement to Manticore's /sql endpoint.

        Returns the first result-set dict (raw mode responds with a list of
        result sets). Raises requests.HTTPError on non-2xx responses.
        """
        response = requests.post(
            f"{self.base_url}/sql",
            data={"mode": "raw", "query": query},
            timeout=self.timeout_seconds,
        )
        response.raise_for_status()
        payload = response.json()
        if isinstance(payload, list):
            return payload[0] if payload else {}
        return dict(payload or {})

    def ensure_table(self) -> None:
        """Create the memory table if needed, at most once per TTL window."""
        last_ready = float(
            self._table_ready_cache.get(self._table_cache_key, 0.0) or 0.0
        )
        if (time.time() - last_ready) <= float(self._table_ready_ttl_seconds):
            return
        self._sql(
            (
                f"CREATE TABLE IF NOT EXISTS {self.table} ("
                "id BIGINT,"
                "memory_uuid STRING,"
                "user_id BIGINT,"
                "conversation_id STRING,"
                "memory_kind STRING,"
                "status STRING,"
                "updated_ts BIGINT,"
                "summary TEXT,"
                "body TEXT"
                ")"
            )
        )
        self._table_ready_cache[self._table_cache_key] = time.time()

    def _doc_id(self, memory_id: str) -> int:
        """Map a memory UUID to a stable positive 64-bit Manticore doc id."""
        digest = hashlib.blake2b(
            str(memory_id or "").encode("utf-8"),
            digest_size=8,
        ).digest()
        value = int.from_bytes(digest, byteorder="big", signed=False)
        # Manticore requires doc ids >= 1.
        return max(1, int(value))

    def _escape(self, value: Any) -> str:
        """Escape a value for embedding in a single-quoted SQL string literal."""
        text = str(value or "")
        text = text.replace("\\", "\\\\").replace("'", "\\'")
        return text

    def _build_upsert_values_clause(self, item: MemoryItem) -> str:
        """Build the parenthesized VALUES tuple for one memory item.

        Column order must match _COLUMNS_SQL.
        """
        memory_id = str(item.id)
        doc_id = self._doc_id(memory_id)
        summary = _flatten_to_text(item.content)[:280]
        body = _flatten_to_text(item.content)
        updated_ts = int(item.updated_at.timestamp() * 1000)
        return (
            f"({doc_id},'{self._escape(memory_id)}',{int(item.user_id)},"
            f"'{self._escape(item.conversation_id)}','{self._escape(item.memory_kind)}',"
            f"'{self._escape(item.status)}',{updated_ts},"
            f"'{self._escape(summary)}','{self._escape(body)}')"
        )

    def _replace_rows(self, values: list[str]) -> None:
        """Issue one batched REPLACE for the given VALUES clauses."""
        self._sql(
            f"REPLACE INTO {self.table} {self._COLUMNS_SQL} "
            f"VALUES {','.join(values)}"
        )

    def upsert(self, item: MemoryItem) -> None:
        """Insert or overwrite the Manticore document for one memory item."""
        self.ensure_table()
        # Delegate to the shared clause builder; the previous version
        # duplicated its string-building logic inline.
        self._replace_rows([self._build_upsert_values_clause(item)])

    def delete(self, memory_id: str) -> None:
        """Remove the document for a memory item (no-op if absent)."""
        self.ensure_table()
        doc_id = self._doc_id(memory_id)
        self._sql(f"DELETE FROM {self.table} WHERE id={doc_id}")

    def reindex(
        self,
        *,
        user_id: int | None = None,
        include_statuses: tuple[str, ...] = ("active",),
        limit: int = 2000,
    ) -> dict[str, int]:
        """Bulk-load up to ``limit`` recent rows in batches of 100.

        Returns {"scanned": ..., "indexed": ...}; rows whose VALUES clause
        cannot be built are logged and skipped.
        """
        self.ensure_table()
        queryset = MemoryItem.objects.all().order_by("-updated_at")
        if user_id is not None:
            queryset = queryset.filter(user_id=int(user_id))
        if include_statuses:
            queryset = queryset.filter(status__in=list(include_statuses))
        scanned = 0
        indexed = 0
        batch_size = 100
        values: list[str] = []
        for item in queryset[: max(1, int(limit))]:
            scanned += 1
            try:
                values.append(self._build_upsert_values_clause(item))
            except Exception as exc:
                log.warning(
                    "memory-search upsert build failed id=%s err=%s", item.id, exc
                )
                continue
            if len(values) >= batch_size:
                self._replace_rows(values)
                indexed += len(values)
                values = []
        if values:
            # Flush the final partial batch.
            self._replace_rows(values)
            indexed += len(values)
        return {"scanned": scanned, "indexed": indexed}

    def search(
        self,
        *,
        user_id: int,
        query: str,
        conversation_id: str = "",
        limit: int = 20,
        include_statuses: tuple[str, ...] = ("active",),
    ) -> list[MemorySearchHit]:
        """Full-text search scoped to one user, ranked by Manticore WEIGHT()."""
        self.ensure_table()
        needle = str(query or "").strip()
        if not needle:
            return []
        # NOTE(review): _escape() guards the quoting, but Manticore MATCH()
        # query operators (e.g. @field, |, ") inside the needle stay active —
        # confirm this is acceptable for user-supplied queries.
        where_parts = [f"user_id={int(user_id)}", f"MATCH('{self._escape(needle)}')"]
        if conversation_id:
            where_parts.append(f"conversation_id='{self._escape(conversation_id)}'")
        statuses = [
            str(item or "").strip()
            for item in include_statuses
            if str(item or "").strip()
        ]
        if statuses:
            in_clause = ",".join(f"'{self._escape(item)}'" for item in statuses)
            where_parts.append(f"status IN ({in_clause})")
        where_sql = " AND ".join(where_parts)
        query_sql = (
            f"SELECT memory_uuid,memory_kind,status,conversation_id,updated_ts,summary,WEIGHT() AS score "
            f"FROM {self.table} WHERE {where_sql} ORDER BY score DESC LIMIT {max(1, int(limit))}"
        )
        payload = self._sql(query_sql)
        rows = list(payload.get("data") or [])
        hits = []
        for row in rows:
            item = dict(row or {})
            hits.append(
                MemorySearchHit(
                    memory_id=str(item.get("memory_uuid") or ""),
                    score=float(item.get("score") or 0.0),
                    summary=str(item.get("summary") or ""),
                    payload={
                        "memory_kind": str(item.get("memory_kind") or ""),
                        "status": str(item.get("status") or ""),
                        "conversation_id": str(item.get("conversation_id") or ""),
                        "updated_ts": int(item.get("updated_ts") or 0),
                    },
                )
            )
        return hits
def get_memory_search_backend() -> BaseMemorySearchBackend:
    """Instantiate the backend selected by settings.MEMORY_SEARCH_BACKEND."""
    choice = str(getattr(settings, "MEMORY_SEARCH_BACKEND", "django")).strip().lower()
    if choice == "manticore":
        return ManticoreMemorySearchBackend()
    # Any other (or unknown) value falls back to the ORM-scan backend.
    return DjangoMemorySearchBackend()
def backend_status() -> dict[str, Any]:
    """Report which search backend is active and whether it is reachable.

    Always includes "backend", "ok", and a millisecond "ts"; for the
    Manticore backend adds its URL/table on success or "error" on failure.
    """
    backend = get_memory_search_backend()
    status: dict[str, Any] = {
        "backend": getattr(backend, "name", "unknown"),
        "ok": True,
        "ts": int(time.time() * 1000),
    }
    if isinstance(backend, ManticoreMemorySearchBackend):
        try:
            # Probe connectivity by (re)creating the table if needed.
            backend.ensure_table()
        except Exception as exc:
            status["ok"] = False
            status["error"] = str(exc)
        else:
            status["manticore_http_url"] = backend.base_url
            status["manticore_table"] = backend.table
    return status