Begin adding AI memory
This commit is contained in:
3
core/memory/__init__.py
Normal file
3
core/memory/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .search_backend import get_memory_search_backend
|
||||
|
||||
__all__ = ["get_memory_search_backend"]
|
||||
283
core/memory/search_backend.py
Normal file
283
core/memory/search_backend.py
Normal file
@@ -0,0 +1,283 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from django.conf import settings
|
||||
|
||||
from core.models import MemoryItem
|
||||
from core.util import logs
|
||||
|
||||
log = logs.get_logger("memory-search")
|
||||
|
||||
|
||||
@dataclass
class MemorySearchHit:
    """A single result row returned by a memory search backend."""

    # Identifier of the matching MemoryItem, always stringified.
    memory_id: str
    # Backend-specific relevance score (higher is more relevant).
    score: float
    # Short human-readable excerpt of the flattened item content.
    summary: str
    # Backend metadata about the hit (kind, status, conversation, timestamp).
    payload: dict[str, Any]
|
||||
|
||||
|
||||
def _flatten_to_text(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, dict):
|
||||
parts = []
|
||||
for key, item in value.items():
|
||||
parts.append(str(key))
|
||||
parts.append(_flatten_to_text(item))
|
||||
return " ".join(part for part in parts if part).strip()
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
return " ".join(_flatten_to_text(item) for item in value if item).strip()
|
||||
return str(value).strip()
|
||||
|
||||
|
||||
class BaseMemorySearchBackend:
    """Contract shared by all memory search backends.

    Subclasses must implement ``upsert``, ``delete`` and ``search``;
    ``reindex`` is a concrete helper that replays database rows through
    ``upsert`` to rebuild the backend's index.
    """

    def upsert(self, item: MemoryItem) -> None:
        """Index (or re-index) a single memory item."""
        raise NotImplementedError

    def delete(self, memory_id: str) -> None:
        """Remove one memory item from the index by its identifier."""
        raise NotImplementedError

    def search(
        self,
        *,
        user_id: int,
        query: str,
        conversation_id: str = "",
        limit: int = 20,
        include_statuses: tuple[str, ...] = ("active",),
    ) -> list[MemorySearchHit]:
        """Return up to ``limit`` hits for ``query`` scoped to one user."""
        raise NotImplementedError

    def reindex(
        self,
        *,
        user_id: int | None = None,
        include_statuses: tuple[str, ...] = ("active",),
        limit: int = 2000,
    ) -> dict[str, int]:
        """Re-upsert the most recently updated memory rows into the index.

        Best-effort: a failing upsert is logged and skipped so one bad row
        cannot abort the whole pass.

        Returns:
            Counters ``{"scanned": ..., "indexed": ...}``.
        """
        rows = MemoryItem.objects.all().order_by("-updated_at")
        if user_id is not None:
            rows = rows.filter(user_id=int(user_id))
        if include_statuses:
            rows = rows.filter(status__in=list(include_statuses))

        cap = max(1, int(limit))
        counters = {"scanned": 0, "indexed": 0}
        for row in rows[:cap]:
            counters["scanned"] += 1
            try:
                self.upsert(row)
            except Exception as exc:
                log.warning("memory-search upsert failed id=%s err=%s", row.id, exc)
            else:
                counters["indexed"] += 1
        return counters
|
||||
|
||||
|
||||
class DjangoMemorySearchBackend(BaseMemorySearchBackend):
    """Dependency-free fallback backend that scans MemoryItem rows in Python.

    There is no external index: ``upsert``/``delete`` are no-ops and
    ``search`` does a case-insensitive substring match over the flattened
    content of (at most) the 500 most recently updated rows.
    """

    name = "django"

    def upsert(self, item: MemoryItem) -> None:
        """No-op: this backend queries source-of-truth rows directly."""
        _ = item

    def delete(self, memory_id: str) -> None:
        """No-op: there is no external index to remove anything from."""
        _ = memory_id

    def search(
        self,
        *,
        user_id: int,
        query: str,
        conversation_id: str = "",
        limit: int = 20,
        include_statuses: tuple[str, ...] = ("active",),
    ) -> list[MemorySearchHit]:
        """Substring-match ``query`` against a user's recent memory items.

        Returns at most ``limit`` hits, newest first. An empty/blank query
        returns no hits.
        """
        needle = str(query or "").strip().lower()
        if not needle:
            return []

        queryset = MemoryItem.objects.filter(user_id=int(user_id))
        if conversation_id:
            queryset = queryset.filter(conversation_id=conversation_id)
        if include_statuses:
            queryset = queryset.filter(status__in=list(include_statuses))

        hits: list[MemorySearchHit] = []
        for item in queryset.order_by("-updated_at")[:500]:
            # Flatten once and reuse for both matching and the summary —
            # the original flattened every matching row a second time.
            raw_summary = _flatten_to_text(item.content)
            text_blob = raw_summary.lower()
            if needle not in text_blob:
                continue
            summary = raw_summary[:280]
            # Crude relevance: longer needles relative to the document get a
            # slightly higher score; all matches score within (1.0, 2.0].
            score = 1.0 + min(1.0, len(needle) / max(1.0, len(text_blob)))
            hits.append(
                MemorySearchHit(
                    memory_id=str(item.id),
                    score=float(score),
                    summary=summary,
                    payload={
                        "memory_kind": str(item.memory_kind or ""),
                        "status": str(item.status or ""),
                        "conversation_id": str(item.conversation_id or ""),
                        "updated_at": item.updated_at.isoformat(),
                    },
                )
            )
            if len(hits) >= max(1, int(limit)):
                break
        return hits
|
||||
|
||||
|
||||
class ManticoreMemorySearchBackend(BaseMemorySearchBackend):
    """Full-text backend that mirrors memory items into a Manticore table.

    Talks to Manticore's HTTP ``/sql`` endpoint (raw mode) and keeps one row
    per MemoryItem in ``self.table``. SQL statements are assembled as strings
    with a minimal escaper; see the NOTE(review) comments below.
    """

    name = "manticore"

    def __init__(self):
        # Base URL of the Manticore HTTP listener; trailing slash stripped
        # so f"{base_url}/sql" never doubles the separator.
        self.base_url = str(
            getattr(settings, "MANTICORE_HTTP_URL", "http://127.0.0.1:9308")
        ).rstrip("/")
        # Table name falls back to the default when the setting is blank.
        self.table = str(
            getattr(settings, "MANTICORE_MEMORY_TABLE", "gia_memory_items")
        ).strip() or "gia_memory_items"
        # Per-request timeout in seconds; 0/None settings fall back to 5.
        self.timeout_seconds = int(getattr(settings, "MANTICORE_HTTP_TIMEOUT", 5) or 5)

    def _sql(self, query: str) -> dict[str, Any]:
        """POST one SQL statement to Manticore and return the first result.

        Raises requests.HTTPError on non-2xx responses. Manticore may answer
        with either a JSON object or a list of objects; a list is unwrapped
        to its first element (or ``{}`` when empty).
        """
        response = requests.post(
            f"{self.base_url}/sql",
            data={"mode": "raw", "query": query},
            timeout=self.timeout_seconds,
        )
        response.raise_for_status()
        payload = response.json()
        if isinstance(payload, list):
            return payload[0] if payload else {}
        return dict(payload or {})

    def ensure_table(self) -> None:
        """Create the memory table if it does not exist yet.

        NOTE(review): called before every upsert/delete/search, so each
        operation costs an extra HTTP round-trip; consider caching a
        "table ready" flag per instance if this shows up in profiles.
        """
        self._sql(
            (
                f"CREATE TABLE IF NOT EXISTS {self.table} ("
                "id BIGINT,"
                "memory_uuid STRING,"
                "user_id BIGINT,"
                "conversation_id STRING,"
                "memory_kind STRING,"
                "status STRING,"
                "updated_ts BIGINT,"
                "summary TEXT,"
                "body TEXT"
                ")"
            )
        )

    def _doc_id(self, memory_id: str) -> int:
        """Derive a stable positive 64-bit document id from a memory id.

        Manticore ids must be non-zero integers, so the UUID/string key is
        hashed with blake2b (8-byte digest) and clamped to >= 1.
        """
        digest = hashlib.blake2b(
            str(memory_id or "").encode("utf-8"),
            digest_size=8,
        ).digest()
        value = int.from_bytes(digest, byteorder="big", signed=False)
        return max(1, int(value))

    def _escape(self, value: Any) -> str:
        """Escape backslashes and single quotes for SQL string literals.

        NOTE(review): this only protects the SQL quoting layer. Inside
        MATCH('...') Manticore also interprets its own query syntax
        (operators like @, ", !, etc.), which is NOT escaped here — an
        exotic user query could still break the MATCH clause; confirm
        whether that is acceptable upstream.
        """
        text = str(value or "")
        text = text.replace("\\", "\\\\").replace("'", "\\'")
        return text

    def upsert(self, item: MemoryItem) -> None:
        """REPLACE the row for ``item`` (insert-or-overwrite by doc id)."""
        self.ensure_table()
        memory_id = str(item.id)
        doc_id = self._doc_id(memory_id)
        # Summary is a truncated copy of the flattened content; body holds
        # the full flattened text for MATCH().
        summary = _flatten_to_text(item.content)[:280]
        body = _flatten_to_text(item.content)
        # Milliseconds since epoch, stored as BIGINT.
        updated_ts = int(item.updated_at.timestamp() * 1000)
        query = (
            f"REPLACE INTO {self.table} "
            "(id,memory_uuid,user_id,conversation_id,memory_kind,status,updated_ts,summary,body) "
            f"VALUES ({doc_id},'{self._escape(memory_id)}',{int(item.user_id)},"
            f"'{self._escape(item.conversation_id)}','{self._escape(item.memory_kind)}',"
            f"'{self._escape(item.status)}',{updated_ts},"
            f"'{self._escape(summary)}','{self._escape(body)}')"
        )
        self._sql(query)

    def delete(self, memory_id: str) -> None:
        """Delete the indexed row for ``memory_id`` (no-op if absent)."""
        self.ensure_table()
        doc_id = self._doc_id(memory_id)
        self._sql(f"DELETE FROM {self.table} WHERE id={doc_id}")

    def search(
        self,
        *,
        user_id: int,
        query: str,
        conversation_id: str = "",
        limit: int = 20,
        include_statuses: tuple[str, ...] = ("active",),
    ) -> list[MemorySearchHit]:
        """Full-text search via MATCH(), ranked by Manticore's WEIGHT().

        Returns at most ``limit`` hits; a blank query returns no hits.
        Results are filtered to one user and optionally one conversation
        and a set of statuses.
        """
        self.ensure_table()
        needle = str(query or "").strip()
        if not needle:
            return []

        where_parts = [f"user_id={int(user_id)}", f"MATCH('{self._escape(needle)}')"]
        if conversation_id:
            where_parts.append(f"conversation_id='{self._escape(conversation_id)}'")
        # Drop blank statuses before building the IN (...) clause.
        statuses = [str(item or "").strip() for item in include_statuses if str(item or "").strip()]
        if statuses:
            in_clause = ",".join(f"'{self._escape(item)}'" for item in statuses)
            where_parts.append(f"status IN ({in_clause})")
        where_sql = " AND ".join(where_parts)
        query_sql = (
            f"SELECT memory_uuid,memory_kind,status,conversation_id,updated_ts,summary,WEIGHT() AS score "
            f"FROM {self.table} WHERE {where_sql} ORDER BY score DESC LIMIT {max(1, int(limit))}"
        )
        payload = self._sql(query_sql)
        # Raw-mode responses carry rows under "data"; tolerate absence.
        rows = list(payload.get("data") or [])
        hits = []
        for row in rows:
            item = dict(row or {})
            hits.append(
                MemorySearchHit(
                    memory_id=str(item.get("memory_uuid") or ""),
                    score=float(item.get("score") or 0.0),
                    summary=str(item.get("summary") or ""),
                    payload={
                        "memory_kind": str(item.get("memory_kind") or ""),
                        "status": str(item.get("status") or ""),
                        "conversation_id": str(item.get("conversation_id") or ""),
                        "updated_ts": int(item.get("updated_ts") or 0),
                    },
                )
            )
        return hits
|
||||
|
||||
|
||||
def get_memory_search_backend() -> BaseMemorySearchBackend:
    """Return the search backend selected by ``settings.MEMORY_SEARCH_BACKEND``.

    Any value other than "manticore" (compared case-insensitively, after
    stripping whitespace) falls back to the dependency-free Django backend.
    """
    choice = str(getattr(settings, "MEMORY_SEARCH_BACKEND", "django")).strip().lower()
    if choice != "manticore":
        return DjangoMemorySearchBackend()
    return ManticoreMemorySearchBackend()
|
||||
|
||||
|
||||
def backend_status() -> dict[str, Any]:
    """Build a small health/status report for the active search backend.

    Always includes the backend name, an ``ok`` flag and a millisecond
    timestamp. For the Manticore backend it additionally probes connectivity
    by ensuring the table exists; a failure flips ``ok`` to False and records
    the error text instead of raising.
    """
    backend = get_memory_search_backend()
    status: dict[str, Any] = {
        "backend": getattr(backend, "name", "unknown"),
        "ok": True,
        "ts": int(time.time() * 1000),
    }
    if not isinstance(backend, ManticoreMemorySearchBackend):
        return status
    try:
        backend.ensure_table()
    except Exception as exc:
        status["ok"] = False
        status["error"] = str(exc)
    else:
        status["manticore_http_url"] = backend.base_url
        status["manticore_table"] = backend.table
    return status
|
||||
Reference in New Issue
Block a user