284 lines
9.4 KiB
Python
284 lines
9.4 KiB
Python
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import json
|
|
import time
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
import requests
|
|
from django.conf import settings
|
|
|
|
from core.models import MemoryItem
|
|
from core.util import logs
|
|
|
|
# Module-level logger; in this file it reports per-item upsert failures
# encountered while reindexing.
log = logs.get_logger("memory-search")
|
|
|
|
|
|
@dataclass
class MemorySearchHit:
    """One search result produced by a memory search backend.

    Instances are plain value containers: the backends build them inside
    their ``search`` methods and return them in a list.
    """

    # Identifier of the matched memory item, carried as a string.
    memory_id: str
    # Relevance score as computed by the backend that produced the hit.
    score: float
    # Short preview text for the matched item.
    summary: str
    # Additional metadata about the hit (for example memory kind, status and
    # conversation id).
    payload: dict[str, Any]
|
|
|
|
|
|
def _flatten_to_text(value: Any) -> str:
|
|
if value is None:
|
|
return ""
|
|
if isinstance(value, dict):
|
|
parts = []
|
|
for key, item in value.items():
|
|
parts.append(str(key))
|
|
parts.append(_flatten_to_text(item))
|
|
return " ".join(part for part in parts if part).strip()
|
|
if isinstance(value, (list, tuple, set)):
|
|
return " ".join(_flatten_to_text(item) for item in value if item).strip()
|
|
return str(value).strip()
|
|
|
|
|
|
class BaseMemorySearchBackend:
    """Common interface for memory search backends, plus a generic reindex.

    Subclasses supply ``upsert``, ``delete`` and ``search``; ``reindex`` is
    implemented here on top of ``upsert``.
    """

    def upsert(self, item: MemoryItem) -> None:
        """Add or refresh *item* in the search index (subclass hook)."""
        raise NotImplementedError

    def delete(self, memory_id: str) -> None:
        """Remove the entry for *memory_id* from the index (subclass hook)."""
        raise NotImplementedError

    def search(
        self,
        *,
        user_id: int,
        query: str,
        conversation_id: str = "",
        limit: int = 20,
        include_statuses: tuple[str, ...] = ("active",),
    ) -> list[MemorySearchHit]:
        """Return hits for *query* among one user's memories (subclass hook).

        Args:
            user_id: Owner whose memories are searched.
            query: Text to look for.
            conversation_id: When non-empty, restricts hits to that
                conversation.
            limit: Maximum number of hits to return.
            include_statuses: Status values a memory must have to qualify.
        """
        raise NotImplementedError

    def reindex(
        self,
        *,
        user_id: int | None = None,
        include_statuses: tuple[str, ...] = ("active",),
        limit: int = 2000,
    ) -> dict[str, int]:
        """Push the most recently updated memory items through ``upsert``.

        Args:
            user_id: When given, only that user's items are reindexed.
            include_statuses: Status values to include; an empty value
                disables the status filter.
            limit: Maximum number of rows to process (at least one row is
                always requested).

        Returns:
            ``{"scanned": <rows visited>, "indexed": <rows upserted without
            error>}``. A failing ``upsert`` is logged and skipped; it does not
            abort the run.
        """
        rows = MemoryItem.objects.all().order_by("-updated_at")
        if user_id is not None:
            rows = rows.filter(user_id=int(user_id))
        if include_statuses:
            rows = rows.filter(status__in=list(include_statuses))

        batch_size = max(1, int(limit))
        counts = {"scanned": 0, "indexed": 0}
        for row in rows[:batch_size]:
            counts["scanned"] += 1
            try:
                self.upsert(row)
            except Exception as exc:
                # Best effort: one bad row must not stop the whole reindex.
                log.warning("memory-search upsert failed id=%s err=%s", row.id, exc)
            else:
                counts["indexed"] += 1
        return counts
|
|
|
|
|
|
class DjangoMemorySearchBackend(BaseMemorySearchBackend):
    """Search backend that scans ``MemoryItem`` rows directly through the ORM.

    There is no separate index to maintain, so ``upsert`` and ``delete`` are
    deliberate no-ops.
    """

    name = "django"

    def upsert(self, item: MemoryItem) -> None:
        """Do nothing: searches read the source-of-truth rows directly."""
        _ = item

    def delete(self, memory_id: str) -> None:
        """Do nothing: there is no separate index entry to remove."""
        _ = memory_id

    def search(
        self,
        *,
        user_id: int,
        query: str,
        conversation_id: str = "",
        limit: int = 20,
        include_statuses: tuple[str, ...] = ("active",),
    ) -> list[MemorySearchHit]:
        """Case-insensitive substring search over a user's recent memories.

        Only the 500 most recently updated rows that pass the filters are
        examined. Hits are returned in that recency order (not sorted by
        score) and the scan stops once *limit* hits have been collected.

        Args:
            user_id: Owner whose memories are searched.
            query: Text to look for; blank queries return no hits.
            conversation_id: When non-empty, restricts hits to that
                conversation.
            limit: Maximum number of hits (at least one is allowed).
            include_statuses: Status values to include; an empty value
                disables the status filter.

        Returns:
            Matching hits. ``score`` is ``1.0`` plus the ratio of query length
            to flattened-content length (capped at ``1.0``); ``summary`` is the
            first 280 characters of the flattened content. The payload carries
            ``updated_at`` (ISO 8601) and ``updated_ts`` (epoch milliseconds).
        """
        needle = str(query or "").strip().lower()
        if not needle:
            return []

        queryset = MemoryItem.objects.filter(user_id=int(user_id))
        if conversation_id:
            queryset = queryset.filter(conversation_id=conversation_id)
        if include_statuses:
            queryset = queryset.filter(status__in=list(include_statuses))

        # Resolve the cap once instead of on every collected hit.
        max_hits = max(1, int(limit))
        hits: list[MemorySearchHit] = []
        for item in queryset.order_by("-updated_at")[:500]:
            # Flatten once per row; the same text feeds both the match test
            # and the summary.
            text = _flatten_to_text(item.content)
            haystack = text.lower()
            if needle not in haystack:
                continue
            score = 1.0 + min(1.0, len(needle) / max(1.0, len(haystack)))
            hits.append(
                MemorySearchHit(
                    memory_id=str(item.id),
                    score=float(score),
                    summary=text[:280],
                    payload={
                        "memory_kind": str(item.memory_kind or ""),
                        "status": str(item.status or ""),
                        "conversation_id": str(item.conversation_id or ""),
                        "updated_at": item.updated_at.isoformat(),
                        # Same epoch-millisecond key the Manticore backend
                        # returns, so callers can rely on it with either
                        # backend.
                        "updated_ts": int(item.updated_at.timestamp() * 1000),
                    },
                )
            )
            if len(hits) >= max_hits:
                break
        return hits
|
|
|
|
|
|
class ManticoreMemorySearchBackend(BaseMemorySearchBackend):
    """Search backend that indexes memory items in a Manticore table.

    All communication goes through Manticore's HTTP ``/sql`` endpoint using
    string-built SQL statements.
    """

    name = "manticore"

    # Characters Manticore's full-text query syntax treats as operators or
    # query constructs. They are backslash-escaped before user text is placed
    # inside MATCH() so the text is matched literally.
    _MATCH_OPERATOR_CHARS = frozenset('\\!"$()-/<@^|~')

    def __init__(self):
        """Read connection settings from Django settings.

        Uses ``MANTICORE_HTTP_URL`` (default ``http://localhost:9308``),
        ``MANTICORE_MEMORY_TABLE`` (default ``gia_memory_items``) and
        ``MANTICORE_HTTP_TIMEOUT`` (default ``5``).
        """
        self.base_url = str(
            getattr(settings, "MANTICORE_HTTP_URL", "http://localhost:9308")
        ).rstrip("/")
        self.table = str(
            getattr(settings, "MANTICORE_MEMORY_TABLE", "gia_memory_items")
        ).strip() or "gia_memory_items"
        self.timeout_seconds = int(getattr(settings, "MANTICORE_HTTP_TIMEOUT", 5) or 5)
        # Set after the first successful CREATE TABLE IF NOT EXISTS issued on
        # behalf of upsert/delete/search by this instance.
        self._table_ready = False

    def _sql(self, query: str) -> dict[str, Any]:
        """POST *query* to ``/sql`` in raw mode and return the decoded result.

        Returns:
            The first element when the response body is a JSON list (``{}``
            for an empty list); otherwise the body converted to a dict.

        Raises:
            requests.HTTPError: On a non-success HTTP status.
        """
        response = requests.post(
            f"{self.base_url}/sql",
            data={"mode": "raw", "query": query},
            timeout=self.timeout_seconds,
        )
        response.raise_for_status()
        payload = response.json()
        if isinstance(payload, list):
            return payload[0] if payload else {}
        return dict(payload or {})

    def ensure_table(self) -> None:
        """Create the memory table if it does not exist yet.

        Always sends the statement, so it can also be used to probe whether
        the server is reachable.
        """
        self._sql(
            (
                f"CREATE TABLE IF NOT EXISTS {self.table} ("
                "id BIGINT,"
                "memory_uuid STRING,"
                "user_id BIGINT,"
                "conversation_id STRING,"
                "memory_kind STRING,"
                "status STRING,"
                "updated_ts BIGINT,"
                "summary TEXT,"
                "body TEXT"
                ")"
            )
        )

    def _ensure_table_once(self) -> None:
        """Run ``ensure_table`` at most once successfully per instance.

        Avoids an extra HTTP round trip before every upsert/delete/search
        (for example during a large reindex). A failed attempt leaves the
        flag unset so the next call retries.
        """
        # getattr keeps this safe for instances created without __init__.
        if getattr(self, "_table_ready", False):
            return
        self.ensure_table()
        self._table_ready = True

    def _doc_id(self, memory_id: str) -> int:
        """Derive a stable, non-zero integer document id from *memory_id*.

        The id is the big-endian unsigned value of an 8-byte BLAKE2b digest
        of the UTF-8 encoded identifier; a zero result is mapped to ``1``.
        """
        digest = hashlib.blake2b(
            str(memory_id or "").encode("utf-8"),
            digest_size=8,
        ).digest()
        value = int.from_bytes(digest, byteorder="big", signed=False)
        return max(1, int(value))

    def _escape(self, value: Any) -> str:
        """Escape *value* for use inside a single-quoted SQL string literal.

        Falsy values become an empty string. Backslashes are doubled first,
        then single quotes are backslash-escaped.
        """
        text = str(value or "")
        text = text.replace("\\", "\\\\").replace("'", "\\'")
        return text

    def _escape_match(self, value: Any) -> str:
        """Escape *value* for literal use inside ``MATCH('...')``.

        Two layers apply: full-text operator characters are backslash-escaped
        so they are treated as plain characters, and the result is then
        escaped as a SQL string literal (which doubles those backslashes).
        """
        text = str(value or "")
        literal = "".join(
            "\\" + char if char in self._MATCH_OPERATOR_CHARS else char
            for char in text
        )
        return self._escape(literal)

    def upsert(self, item: MemoryItem) -> None:
        """Insert or replace the index document for *item*.

        The document stores the flattened content as ``body``, its first 280
        characters as ``summary``, and ``updated_at`` as epoch milliseconds.
        """
        self._ensure_table_once()
        memory_id = str(item.id)
        doc_id = self._doc_id(memory_id)
        # Flatten once; the summary is just a prefix of the body.
        body = _flatten_to_text(item.content)
        summary = body[:280]
        updated_ts = int(item.updated_at.timestamp() * 1000)
        query = (
            f"REPLACE INTO {self.table} "
            "(id,memory_uuid,user_id,conversation_id,memory_kind,status,updated_ts,summary,body) "
            f"VALUES ({doc_id},'{self._escape(memory_id)}',{int(item.user_id)},"
            f"'{self._escape(item.conversation_id)}','{self._escape(item.memory_kind)}',"
            f"'{self._escape(item.status)}',{updated_ts},"
            f"'{self._escape(summary)}','{self._escape(body)}')"
        )
        self._sql(query)

    def delete(self, memory_id: str) -> None:
        """Delete the index document derived from *memory_id*."""
        self._ensure_table_once()
        doc_id = self._doc_id(memory_id)
        self._sql(f"DELETE FROM {self.table} WHERE id={doc_id}")

    def search(
        self,
        *,
        user_id: int,
        query: str,
        conversation_id: str = "",
        limit: int = 20,
        include_statuses: tuple[str, ...] = ("active",),
    ) -> list[MemorySearchHit]:
        """Full-text search over one user's indexed memories.

        Args:
            user_id: Owner whose memories are searched.
            query: Text to look for; it is matched literally (full-text
                operator characters are escaped). Blank queries return no
                hits without contacting the server.
            conversation_id: When non-empty, restricts hits to that
                conversation.
            limit: Maximum number of hits (at least one is requested).
            include_statuses: Status values to include; blank entries are
                ignored and an empty value disables the status filter.

        Returns:
            Hits ordered by descending ``WEIGHT()`` score. The payload carries
            ``memory_kind``, ``status``, ``conversation_id`` and
            ``updated_ts``.

        Raises:
            requests.HTTPError: On a non-success HTTP status from Manticore.
        """
        needle = str(query or "").strip()
        if not needle:
            return []
        self._ensure_table_once()

        where_parts = [
            f"user_id={int(user_id)}",
            f"MATCH('{self._escape_match(needle)}')",
        ]
        if conversation_id:
            where_parts.append(f"conversation_id='{self._escape(conversation_id)}'")
        statuses = [
            status
            for status in (str(raw or "").strip() for raw in (include_statuses or ()))
            if status
        ]
        if statuses:
            in_clause = ",".join(f"'{self._escape(status)}'" for status in statuses)
            where_parts.append(f"status IN ({in_clause})")
        where_sql = " AND ".join(where_parts)
        query_sql = (
            "SELECT memory_uuid,memory_kind,status,conversation_id,updated_ts,summary,WEIGHT() AS score "
            f"FROM {self.table} WHERE {where_sql} ORDER BY score DESC LIMIT {max(1, int(limit))}"
        )
        payload = self._sql(query_sql)

        hits: list[MemorySearchHit] = []
        for raw_row in payload.get("data") or []:
            row = dict(raw_row or {})
            hits.append(
                MemorySearchHit(
                    memory_id=str(row.get("memory_uuid") or ""),
                    score=float(row.get("score") or 0.0),
                    summary=str(row.get("summary") or ""),
                    payload={
                        "memory_kind": str(row.get("memory_kind") or ""),
                        "status": str(row.get("status") or ""),
                        "conversation_id": str(row.get("conversation_id") or ""),
                        "updated_ts": int(row.get("updated_ts") or 0),
                    },
                )
            )
        return hits
|
|
|
|
|
|
def get_memory_search_backend() -> BaseMemorySearchBackend:
    """Instantiate the backend selected by ``settings.MEMORY_SEARCH_BACKEND``.

    The setting is trimmed and compared case-insensitively. ``"manticore"``
    selects the Manticore backend; any other value, or a missing setting,
    yields the Django backend. A new instance is created on every call.
    """
    configured = getattr(settings, "MEMORY_SEARCH_BACKEND", "django")
    normalized = str(configured).strip().lower()
    backend_cls = (
        ManticoreMemorySearchBackend
        if normalized == "manticore"
        else DjangoMemorySearchBackend
    )
    return backend_cls()
|
|
|
|
|
|
def backend_status() -> dict[str, Any]:
    """Report which search backend is configured and whether it responds.

    Returns:
        A dict with ``backend`` (the backend's ``name`` or ``"unknown"``),
        ``ok`` and ``ts`` (current time in epoch milliseconds). For the
        Manticore backend the table is ensured as a probe: on success
        ``manticore_http_url`` and ``manticore_table`` are added; on failure
        ``ok`` becomes ``False`` and ``error`` holds the exception text.
    """
    backend = get_memory_search_backend()
    report: dict[str, Any] = {
        "backend": getattr(backend, "name", "unknown"),
        "ok": True,
        "ts": int(time.time() * 1000),
    }
    # Only the Manticore backend has an external service worth probing.
    if not isinstance(backend, ManticoreMemorySearchBackend):
        return report

    try:
        backend.ensure_table()
        report["manticore_http_url"] = backend.base_url
        report["manticore_table"] = backend.table
    except Exception as exc:
        report.update(ok=False, error=str(exc))
    return report
|