154 lines
4.2 KiB
Python
154 lines
4.2 KiB
Python
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import re
|
|
from dataclasses import dataclass
|
|
|
|
from asgiref.sync import sync_to_async
|
|
from django.utils import timezone
|
|
|
|
from core.models import AnswerMemory, AnswerSuggestionEvent, Message
|
|
|
|
# Despite the name, this matches what is *removed*: runs of characters other
# than the letters a-z, digits and whitespace (case-insensitive).
# _normalize_question replaces each such run with a single space.
_WORD_RE = re.compile(r"[^a-z0-9\s]+", re.IGNORECASE)
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class RepeatAnswerSuggestion:
|
|
answer_memory_id: str
|
|
answer_text: str
|
|
score: float
|
|
|
|
|
|
def _normalize_question(text: str) -> str:
    """Return a canonical form of *text* suitable for fingerprinting.

    The input is stripped and lower-cased, every run matched by ``_WORD_RE``
    is replaced with a space, and whitespace is then collapsed to single
    spaces with none at either end. Falsy input yields ``""``.
    """
    lowered = str(text or "").strip().lower()
    spaced = _WORD_RE.sub(" ", lowered)
    # split()/join collapses every whitespace run and trims both ends at once.
    return " ".join(spaced.split())
|
|
|
|
|
|
def _fingerprint(text: str) -> str:
    """Return the SHA-1 hex digest of the normalized *text*.

    Returns ``""`` when the text normalizes to an empty string, so callers
    can treat a falsy fingerprint as "nothing to match on".
    """
    canonical = _normalize_question(text)
    return hashlib.sha1(canonical.encode("utf-8")).hexdigest() if canonical else ""
|
|
|
|
|
|
def _is_question(text: str) -> bool:
|
|
body = str(text or "").strip()
|
|
if not body:
|
|
return False
|
|
low = body.lower()
|
|
return body.endswith("?") or low.startswith(
|
|
(
|
|
"what",
|
|
"why",
|
|
"how",
|
|
"when",
|
|
"where",
|
|
"who",
|
|
"can ",
|
|
"do ",
|
|
"did ",
|
|
"is ",
|
|
"are ",
|
|
)
|
|
)
|
|
|
|
|
|
def _is_group_channel(message: Message) -> bool:
|
|
channel = str(getattr(message, "source_chat_id", "") or "").strip().lower()
|
|
if channel.endswith("@g.us"):
|
|
return True
|
|
return (
|
|
str(getattr(message, "source_service", "") or "").strip().lower() == "xmpp"
|
|
and "conference." in channel
|
|
)
|
|
|
|
|
|
def _store_reply_pair(message: Message, answer_text: str) -> None:
    """Persist a question/answer pair when *message* replies to a question.

    Runs synchronously on purpose: reading ``message.reply_to`` and
    ``message.user`` may lazy-load related rows, which Django only permits
    outside a running event loop.
    """
    if not (message.reply_to_id and message.reply_to):
        return
    question_text = str(message.reply_to.text or "").strip()
    if not _is_question(question_text):
        return
    fingerprint = _fingerprint(question_text)
    if not fingerprint:
        return
    AnswerMemory.objects.create(
        user=message.user,
        service=message.source_service or "web",
        channel_identifier=message.source_chat_id or "",
        question_fingerprint=fingerprint,
        question_text=question_text,
        answer_message=message,
        answer_text=answer_text,
        confidence_meta={"source": "reply_pair"},
    )


async def learn_from_message(message: Message) -> None:
    """Record *message* as an answer if it replies to a question.

    Builds memory by linking obvious reply answers to prior questions.
    Nothing is stored when the message is ``None``, has no text after
    stripping, carries a truthy ``origin_tag`` in ``message_meta``, is not a
    reply, or replies to text that ``_is_question`` rejects or that has an
    empty fingerprint.

    All related-object access and the database write happen in a single
    ``sync_to_async`` call, so an un-prefetched ``reply_to`` or ``user``
    relation cannot raise ``SynchronousOnlyOperation`` in the event loop.
    """
    if message is None:
        return
    text = str(message.text or "").strip()
    if not text:
        return
    if dict(message.message_meta or {}).get("origin_tag"):
        return
    # reply_to_id is a plain column value, so this check needs no thread hop.
    if not message.reply_to_id:
        return
    await sync_to_async(_store_reply_pair)(message, text)
|
|
|
|
|
|
async def find_repeat_answer(user, message: Message) -> RepeatAnswerSuggestion | None:
    """Suggest a previously stored answer for a repeated group-chat question.

    Returns ``None`` unless every one of these holds:

    * *message* is not ``None`` and ``_is_group_channel`` accepts it;
    * ``message_meta`` carries no truthy ``origin_tag``;
    * the stripped text passes ``_is_question`` and has a non-empty
      fingerprint;
    * no ``AnswerSuggestionEvent`` with status ``"suggested"`` exists for
      *user* in the same service/chat within the last 3 minutes;
    * an ``AnswerMemory`` row for *user*, service, chat and fingerprint
      exists and its newest entry has non-empty answer text.

    On success an ``AnswerSuggestionEvent`` with status ``"suggested"`` is
    created and the matching suggestion is returned.
    """
    # Stdlib timedelta, imported locally: django.utils.timezone only exposes
    # the name as an incidental, undocumented re-export.
    from datetime import timedelta

    if message is None:
        return None
    if not _is_group_channel(message):
        return None
    if dict(message.message_meta or {}).get("origin_tag"):
        return None
    text = str(message.text or "").strip()
    if not _is_question(text):
        return None

    fp = _fingerprint(text)
    if not fp:
        return None

    # Channel cooldown for repeated suggestions in short windows.
    cooldown_cutoff = timezone.now() - timedelta(minutes=3)

    def _recently_suggested() -> bool:
        # NOTE(review): this filter uses the raw source_service, while the
        # memory lookup below falls back to "web" - confirm that is intended.
        return AnswerSuggestionEvent.objects.filter(
            user=user,
            message__source_service=message.source_service,
            message__source_chat_id=message.source_chat_id,
            status="suggested",
            created_at__gte=cooldown_cutoff,
        ).exists()

    if await sync_to_async(_recently_suggested)():
        return None

    def _latest_memory():
        # Newest stored answer for this exact question in this channel.
        return (
            AnswerMemory.objects.filter(
                user=user,
                service=message.source_service or "web",
                channel_identifier=message.source_chat_id or "",
                question_fingerprint=fp,
            )
            .order_by("-created_at")
            .first()
        )

    memory = await sync_to_async(_latest_memory)()
    if not memory:
        return None

    answer = str(memory.answer_text or "").strip()
    if not answer:
        return None

    # Fingerprints match exactly, so a fixed near-certain score is recorded.
    score = 0.99
    await sync_to_async(AnswerSuggestionEvent.objects.create)(
        user=user,
        message=message,
        status="suggested",
        candidate_answer=memory,
        score=score,
    )
    return RepeatAnswerSuggestion(
        answer_memory_id=str(memory.id),
        answer_text=answer,
        score=score,
    )
|