Continue AI features and improve protocol support
This commit is contained in:
627
core/clients/whatsapp.py
Normal file
627
core/clients/whatsapp.py
Normal file
@@ -0,0 +1,627 @@
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
|
||||
import aiohttp
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.conf import settings
|
||||
|
||||
from core.clients import ClientBase, transport
|
||||
from core.messaging import history, media_bridge
|
||||
from core.models import PersonIdentifier
|
||||
|
||||
|
||||
class WhatsAppClient(ClientBase):
|
||||
"""
|
||||
Async WhatsApp transport backed by Neonize.
|
||||
|
||||
Design notes:
|
||||
- Runs in UR process.
|
||||
- Publishes runtime state to shared cache via transport.
|
||||
- Degrades gracefully when Neonize/session is unavailable.
|
||||
"""
|
||||
|
||||
def __init__(self, ur, loop, service="whatsapp"):
|
||||
super().__init__(ur, loop, service)
|
||||
self._task = None
|
||||
self._stopping = False
|
||||
self._client = None
|
||||
self._build_jid = None
|
||||
self._connected = False
|
||||
self._last_qr_payload = ""
|
||||
self._accounts = []
|
||||
|
||||
self.enabled = bool(
|
||||
str(getattr(settings, "WHATSAPP_ENABLED", "false")).lower()
|
||||
in {"1", "true", "yes", "on"}
|
||||
)
|
||||
self.client_name = str(
|
||||
getattr(settings, "WHATSAPP_CLIENT_NAME", "gia_whatsapp")
|
||||
).strip() or "gia_whatsapp"
|
||||
self.database_url = str(
|
||||
getattr(settings, "WHATSAPP_DATABASE_URL", "")
|
||||
).strip()
|
||||
|
||||
transport.register_runtime_client(self.service, self)
|
||||
self._publish_state(
|
||||
connected=False,
|
||||
warning=(
|
||||
"WhatsApp runtime is disabled by settings."
|
||||
if not self.enabled
|
||||
else ""
|
||||
),
|
||||
accounts=[],
|
||||
)
|
||||
|
||||
def _publish_state(self, **updates):
|
||||
state = transport.update_runtime_state(self.service, **updates)
|
||||
accounts = state.get("accounts")
|
||||
if isinstance(accounts, list):
|
||||
self._accounts = accounts
|
||||
|
||||
def start(self):
|
||||
if not self.enabled:
|
||||
self.log.info("whatsapp client disabled by settings")
|
||||
return
|
||||
if self._task is None:
|
||||
self.log.info("whatsapp neonize client starting")
|
||||
self._task = self.loop.create_task(self._run())
|
||||
|
||||
async def _run(self):
|
||||
try:
|
||||
from neonize.aioze.client import NewAClient
|
||||
from neonize.aioze import events as wa_events
|
||||
try:
|
||||
from neonize.utils import build_jid as wa_build_jid
|
||||
except Exception:
|
||||
wa_build_jid = None
|
||||
except Exception as exc:
|
||||
self._publish_state(
|
||||
connected=False,
|
||||
warning=f"Neonize not available: {exc}",
|
||||
accounts=[],
|
||||
)
|
||||
self.log.warning("whatsapp neonize import failed: %s", exc)
|
||||
return
|
||||
|
||||
self._build_jid = wa_build_jid
|
||||
self._client = self._build_client(NewAClient)
|
||||
if self._client is None:
|
||||
self._publish_state(
|
||||
connected=False,
|
||||
warning="Failed to initialize Neonize client.",
|
||||
accounts=[],
|
||||
)
|
||||
return
|
||||
|
||||
self._register_event_handlers(wa_events)
|
||||
|
||||
try:
|
||||
await self._maybe_await(self._client.connect())
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
self._publish_state(
|
||||
connected=False,
|
||||
warning=f"WhatsApp connect failed: {exc}",
|
||||
accounts=[],
|
||||
)
|
||||
self.log.warning("whatsapp connect failed: %s", exc)
|
||||
return
|
||||
|
||||
# Keep task alive so state/callbacks remain active.
|
||||
while not self._stopping:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
def _build_client(self, cls):
|
||||
candidates = []
|
||||
if self.database_url:
|
||||
candidates.append((self.client_name, self.database_url))
|
||||
candidates.append((self.client_name,))
|
||||
for args in candidates:
|
||||
try:
|
||||
return cls(*args)
|
||||
except TypeError:
|
||||
continue
|
||||
except Exception as exc:
|
||||
self.log.warning("whatsapp client init failed for args %s: %s", args, exc)
|
||||
try:
|
||||
if self.database_url:
|
||||
return cls(name=self.client_name, database=self.database_url)
|
||||
return cls(name=self.client_name)
|
||||
except Exception as exc:
|
||||
self.log.warning("whatsapp client init failed: %s", exc)
|
||||
return None
|
||||
|
||||
def _register_event_handlers(self, wa_events):
|
||||
connected_ev = getattr(wa_events, "ConnectedEv", None)
|
||||
message_ev = getattr(wa_events, "MessageEv", None)
|
||||
receipt_ev = getattr(wa_events, "ReceiptEv", None)
|
||||
presence_ev = getattr(wa_events, "PresenceEv", None)
|
||||
pair_ev = getattr(wa_events, "PairStatusEv", None)
|
||||
|
||||
if connected_ev is not None:
|
||||
|
||||
async def on_connected(client, event: connected_ev):
|
||||
self._connected = True
|
||||
account = await self._resolve_account_identifier()
|
||||
self._publish_state(
|
||||
connected=True,
|
||||
warning="",
|
||||
accounts=[account] if account else [self.client_name],
|
||||
)
|
||||
|
||||
self._client.event(on_connected)
|
||||
|
||||
if message_ev is not None:
|
||||
|
||||
async def on_message(client, event: message_ev):
|
||||
await self._handle_message_event(event)
|
||||
|
||||
self._client.event(on_message)
|
||||
|
||||
if receipt_ev is not None:
|
||||
|
||||
async def on_receipt(client, event: receipt_ev):
|
||||
await self._handle_receipt_event(event)
|
||||
|
||||
self._client.event(on_receipt)
|
||||
|
||||
if presence_ev is not None:
|
||||
|
||||
async def on_presence(client, event: presence_ev):
|
||||
await self._handle_presence_event(event)
|
||||
|
||||
self._client.event(on_presence)
|
||||
|
||||
if pair_ev is not None:
|
||||
|
||||
async def on_pair_status(client, event: pair_ev):
|
||||
qr_payload = self._extract_pair_qr(event)
|
||||
if qr_payload:
|
||||
self._last_qr_payload = qr_payload
|
||||
self._publish_state(
|
||||
pair_qr=qr_payload,
|
||||
warning="Scan QR to pair WhatsApp account.",
|
||||
)
|
||||
|
||||
self._client.event(on_pair_status)
|
||||
|
||||
async def _maybe_await(self, value):
|
||||
if asyncio.iscoroutine(value):
|
||||
return await value
|
||||
return value
|
||||
|
||||
async def _resolve_account_identifier(self):
|
||||
if self._client is None:
|
||||
return ""
|
||||
if not hasattr(self._client, "get_me"):
|
||||
return self.client_name
|
||||
try:
|
||||
me = await self._maybe_await(self._client.get_me())
|
||||
except Exception:
|
||||
return self.client_name
|
||||
# Support both dict-like and object-like payloads.
|
||||
for path in (
|
||||
("JID", "User"),
|
||||
("jid",),
|
||||
("user",),
|
||||
("ID",),
|
||||
):
|
||||
value = self._pluck(me, *path)
|
||||
if value:
|
||||
return str(value)
|
||||
return self.client_name
|
||||
|
||||
def _pluck(self, obj, *path):
|
||||
current = obj
|
||||
for key in path:
|
||||
if current is None:
|
||||
return None
|
||||
if isinstance(current, dict):
|
||||
current = current.get(key)
|
||||
continue
|
||||
if hasattr(current, key):
|
||||
current = getattr(current, key)
|
||||
continue
|
||||
return None
|
||||
return current
|
||||
|
||||
def _normalize_timestamp(self, raw_value):
|
||||
if raw_value is None:
|
||||
return int(time.time() * 1000)
|
||||
try:
|
||||
value = int(raw_value)
|
||||
except Exception:
|
||||
return int(time.time() * 1000)
|
||||
# WhatsApp libs often emit seconds. Promote to ms.
|
||||
if value < 10**12:
|
||||
return value * 1000
|
||||
return value
|
||||
|
||||
def _normalize_identifier_candidates(self, *values):
|
||||
out = set()
|
||||
for value in values:
|
||||
raw = str(value or "").strip()
|
||||
if not raw:
|
||||
continue
|
||||
out.add(raw)
|
||||
if "@" in raw:
|
||||
out.add(raw.split("@", 1)[0])
|
||||
digits = re.sub(r"[^0-9]", "", raw)
|
||||
if digits:
|
||||
out.add(digits)
|
||||
if not digits.startswith("+"):
|
||||
out.add(f"+{digits}")
|
||||
return out
|
||||
|
||||
def _is_media_message(self, message_obj):
|
||||
media_fields = (
|
||||
"imageMessage",
|
||||
"videoMessage",
|
||||
"audioMessage",
|
||||
"documentMessage",
|
||||
"stickerMessage",
|
||||
"image_message",
|
||||
"video_message",
|
||||
"audio_message",
|
||||
"document_message",
|
||||
"sticker_message",
|
||||
)
|
||||
for field in media_fields:
|
||||
value = self._pluck(message_obj, field)
|
||||
if value:
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _download_event_media(self, event):
|
||||
if not self._client:
|
||||
return []
|
||||
msg_obj = self._pluck(event, "message")
|
||||
if msg_obj is None or not self._is_media_message(msg_obj):
|
||||
return []
|
||||
if not hasattr(self._client, "download_any"):
|
||||
return []
|
||||
|
||||
try:
|
||||
payload = await self._maybe_await(self._client.download_any(msg_obj))
|
||||
except Exception as exc:
|
||||
self.log.warning("whatsapp media download failed: %s", exc)
|
||||
return []
|
||||
|
||||
if isinstance(payload, memoryview):
|
||||
payload = payload.tobytes()
|
||||
if not isinstance(payload, (bytes, bytearray)):
|
||||
return []
|
||||
|
||||
filename = (
|
||||
self._pluck(msg_obj, "documentMessage", "fileName")
|
||||
or self._pluck(msg_obj, "document_message", "file_name")
|
||||
or f"wa-{int(time.time())}.bin"
|
||||
)
|
||||
content_type = (
|
||||
self._pluck(msg_obj, "documentMessage", "mimetype")
|
||||
or self._pluck(msg_obj, "document_message", "mimetype")
|
||||
or self._pluck(msg_obj, "imageMessage", "mimetype")
|
||||
or self._pluck(msg_obj, "image_message", "mimetype")
|
||||
or "application/octet-stream"
|
||||
)
|
||||
blob_key = media_bridge.put_blob(
|
||||
service="whatsapp",
|
||||
content=bytes(payload),
|
||||
filename=filename,
|
||||
content_type=content_type,
|
||||
)
|
||||
if not blob_key:
|
||||
return []
|
||||
return [
|
||||
{
|
||||
"blob_key": blob_key,
|
||||
"filename": filename,
|
||||
"content_type": content_type,
|
||||
"size": len(payload),
|
||||
}
|
||||
]
|
||||
|
||||
async def _handle_message_event(self, event):
|
||||
msg_obj = self._pluck(event, "message")
|
||||
text = (
|
||||
self._pluck(msg_obj, "conversation")
|
||||
or self._pluck(msg_obj, "extendedTextMessage", "text")
|
||||
or self._pluck(msg_obj, "extended_text_message", "text")
|
||||
or ""
|
||||
)
|
||||
|
||||
sender = (
|
||||
self._pluck(event, "info", "message_source", "sender")
|
||||
or self._pluck(event, "info", "messageSource", "sender")
|
||||
or ""
|
||||
)
|
||||
chat = (
|
||||
self._pluck(event, "info", "message_source", "chat")
|
||||
or self._pluck(event, "info", "messageSource", "chat")
|
||||
or ""
|
||||
)
|
||||
raw_ts = (
|
||||
self._pluck(event, "info", "timestamp")
|
||||
or self._pluck(event, "info", "message_timestamp")
|
||||
or self._pluck(event, "timestamp")
|
||||
)
|
||||
ts = self._normalize_timestamp(raw_ts)
|
||||
|
||||
identifier_values = self._normalize_identifier_candidates(sender, chat)
|
||||
if not identifier_values:
|
||||
return
|
||||
|
||||
identifiers = await sync_to_async(list)(
|
||||
PersonIdentifier.objects.filter(
|
||||
service="whatsapp",
|
||||
identifier__in=list(identifier_values),
|
||||
)
|
||||
)
|
||||
if not identifiers:
|
||||
return
|
||||
|
||||
attachments = await self._download_event_media(event)
|
||||
xmpp_attachments = []
|
||||
if attachments:
|
||||
fetched = await asyncio.gather(
|
||||
*[transport.fetch_attachment(self.service, att) for att in attachments]
|
||||
)
|
||||
xmpp_attachments = [row for row in fetched if row]
|
||||
|
||||
payload = {
|
||||
"sender": str(sender or ""),
|
||||
"chat": str(chat or ""),
|
||||
"raw_event": str(type(event).__name__),
|
||||
}
|
||||
|
||||
for identifier in identifiers:
|
||||
session = await history.get_chat_session(identifier.user, identifier)
|
||||
await history.store_message(
|
||||
session=session,
|
||||
sender=str(sender or chat or ""),
|
||||
text=text,
|
||||
ts=ts,
|
||||
outgoing=False,
|
||||
)
|
||||
await self.ur.xmpp.client.send_from_external(
|
||||
identifier.user,
|
||||
identifier,
|
||||
text,
|
||||
is_outgoing_message=False,
|
||||
attachments=xmpp_attachments,
|
||||
)
|
||||
await self.ur.message_received(
|
||||
self.service,
|
||||
identifier=identifier,
|
||||
text=text,
|
||||
ts=ts,
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
async def _handle_receipt_event(self, event):
|
||||
sender = (
|
||||
self._pluck(event, "info", "message_source", "sender")
|
||||
or self._pluck(event, "info", "messageSource", "sender")
|
||||
or ""
|
||||
)
|
||||
chat = (
|
||||
self._pluck(event, "info", "message_source", "chat")
|
||||
or self._pluck(event, "info", "messageSource", "chat")
|
||||
or ""
|
||||
)
|
||||
timestamps = []
|
||||
raw_ids = self._pluck(event, "message_ids") or []
|
||||
if isinstance(raw_ids, list):
|
||||
for item in raw_ids:
|
||||
try:
|
||||
value = int(item)
|
||||
timestamps.append(value * 1000 if value < 10**12 else value)
|
||||
except Exception:
|
||||
continue
|
||||
read_ts = self._normalize_timestamp(self._pluck(event, "timestamp") or int(time.time() * 1000))
|
||||
|
||||
for candidate in self._normalize_identifier_candidates(sender, chat):
|
||||
await self.ur.message_read(
|
||||
self.service,
|
||||
identifier=candidate,
|
||||
message_timestamps=timestamps,
|
||||
read_ts=read_ts,
|
||||
read_by=sender or chat,
|
||||
payload={"event": "receipt", "sender": str(sender), "chat": str(chat)},
|
||||
)
|
||||
|
||||
async def _handle_presence_event(self, event):
|
||||
sender = (
|
||||
self._pluck(event, "message_source", "sender")
|
||||
or self._pluck(event, "info", "message_source", "sender")
|
||||
or ""
|
||||
)
|
||||
chat = (
|
||||
self._pluck(event, "message_source", "chat")
|
||||
or self._pluck(event, "info", "message_source", "chat")
|
||||
or ""
|
||||
)
|
||||
presence = str(self._pluck(event, "presence") or "").strip().lower()
|
||||
|
||||
for candidate in self._normalize_identifier_candidates(sender, chat):
|
||||
if presence in {"composing", "typing", "recording"}:
|
||||
await self.ur.started_typing(
|
||||
self.service,
|
||||
identifier=candidate,
|
||||
payload={"presence": presence, "sender": str(sender), "chat": str(chat)},
|
||||
)
|
||||
elif presence:
|
||||
await self.ur.stopped_typing(
|
||||
self.service,
|
||||
identifier=candidate,
|
||||
payload={"presence": presence, "sender": str(sender), "chat": str(chat)},
|
||||
)
|
||||
|
||||
def _extract_pair_qr(self, event):
|
||||
for path in (
|
||||
("qr",),
|
||||
("qr_code",),
|
||||
("code",),
|
||||
("pair_code",),
|
||||
("pairCode",),
|
||||
("url",),
|
||||
):
|
||||
value = self._pluck(event, *path)
|
||||
if value:
|
||||
return str(value)
|
||||
return ""
|
||||
|
||||
def _to_jid(self, recipient):
|
||||
raw = str(recipient or "").strip()
|
||||
if not raw:
|
||||
return ""
|
||||
if self._build_jid is not None:
|
||||
try:
|
||||
return self._build_jid(raw)
|
||||
except Exception:
|
||||
pass
|
||||
if "@" in raw:
|
||||
return raw
|
||||
digits = re.sub(r"[^0-9]", "", raw)
|
||||
if digits:
|
||||
return f"{digits}@s.whatsapp.net"
|
||||
return raw
|
||||
|
||||
async def _fetch_attachment_payload(self, attachment):
|
||||
blob_key = (attachment or {}).get("blob_key")
|
||||
if blob_key:
|
||||
row = media_bridge.get_blob(blob_key)
|
||||
if row:
|
||||
return row
|
||||
|
||||
content = (attachment or {}).get("content")
|
||||
if isinstance(content, memoryview):
|
||||
content = content.tobytes()
|
||||
if isinstance(content, bytes):
|
||||
return {
|
||||
"content": content,
|
||||
"filename": (attachment or {}).get("filename") or "attachment.bin",
|
||||
"content_type": (attachment or {}).get("content_type")
|
||||
or "application/octet-stream",
|
||||
"size": len(content),
|
||||
}
|
||||
|
||||
url = (attachment or {}).get("url")
|
||||
if url:
|
||||
timeout = aiohttp.ClientTimeout(total=20)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.get(url) as response:
|
||||
if response.status != 200:
|
||||
return None
|
||||
payload = await response.read()
|
||||
return {
|
||||
"content": payload,
|
||||
"filename": (attachment or {}).get("filename")
|
||||
or url.rstrip("/").split("/")[-1]
|
||||
or "attachment.bin",
|
||||
"content_type": (attachment or {}).get("content_type")
|
||||
or response.headers.get(
|
||||
"Content-Type", "application/octet-stream"
|
||||
),
|
||||
"size": len(payload),
|
||||
}
|
||||
return None
|
||||
|
||||
async def send_message_raw(self, recipient, text=None, attachments=None):
|
||||
if not self._client:
|
||||
return False
|
||||
jid = self._to_jid(recipient)
|
||||
if not jid:
|
||||
return False
|
||||
|
||||
sent_any = False
|
||||
for attachment in attachments or []:
|
||||
payload = await self._fetch_attachment_payload(attachment)
|
||||
if not payload:
|
||||
continue
|
||||
mime = str(payload.get("content_type") or "application/octet-stream").lower()
|
||||
data = payload.get("content") or b""
|
||||
filename = payload.get("filename") or "attachment.bin"
|
||||
|
||||
try:
|
||||
if mime.startswith("image/") and hasattr(self._client, "send_image"):
|
||||
await self._maybe_await(self._client.send_image(jid, data, caption=""))
|
||||
elif mime.startswith("video/") and hasattr(self._client, "send_video"):
|
||||
await self._maybe_await(self._client.send_video(jid, data, caption=""))
|
||||
elif mime.startswith("audio/") and hasattr(self._client, "send_audio"):
|
||||
await self._maybe_await(self._client.send_audio(jid, data))
|
||||
elif hasattr(self._client, "send_document"):
|
||||
await self._maybe_await(
|
||||
self._client.send_document(
|
||||
jid,
|
||||
data,
|
||||
filename=filename,
|
||||
mimetype=mime,
|
||||
caption="",
|
||||
)
|
||||
)
|
||||
sent_any = True
|
||||
except Exception as exc:
|
||||
self.log.warning("whatsapp attachment send failed: %s", exc)
|
||||
|
||||
if text:
|
||||
try:
|
||||
await self._maybe_await(self._client.send_message(jid, text))
|
||||
sent_any = True
|
||||
except TypeError:
|
||||
await self._maybe_await(self._client.send_message(jid, message=text))
|
||||
sent_any = True
|
||||
except Exception as exc:
|
||||
self.log.warning("whatsapp text send failed: %s", exc)
|
||||
return False
|
||||
|
||||
return int(time.time() * 1000) if sent_any else False
|
||||
|
||||
async def start_typing(self, identifier):
|
||||
if not self._client:
|
||||
return False
|
||||
jid = self._to_jid(identifier)
|
||||
if not jid:
|
||||
return False
|
||||
for method_name in ("send_chat_presence", "set_chat_presence"):
|
||||
if hasattr(self._client, method_name):
|
||||
method = getattr(self._client, method_name)
|
||||
try:
|
||||
await self._maybe_await(method(jid, "composing"))
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
return False
|
||||
|
||||
async def stop_typing(self, identifier):
|
||||
if not self._client:
|
||||
return False
|
||||
jid = self._to_jid(identifier)
|
||||
if not jid:
|
||||
return False
|
||||
for method_name in ("send_chat_presence", "set_chat_presence"):
|
||||
if hasattr(self._client, method_name):
|
||||
method = getattr(self._client, method_name)
|
||||
try:
|
||||
await self._maybe_await(method(jid, "paused"))
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
return False
|
||||
|
||||
async def fetch_attachment(self, attachment_ref):
|
||||
blob_key = (attachment_ref or {}).get("blob_key")
|
||||
if blob_key:
|
||||
return media_bridge.get_blob(blob_key)
|
||||
return None
|
||||
|
||||
def get_link_qr_png(self, device_name):
|
||||
_ = (device_name or "").strip()
|
||||
if not self._last_qr_payload:
|
||||
return None
|
||||
try:
|
||||
return transport._as_qr_png(self._last_qr_payload)
|
||||
except Exception:
|
||||
return None
|
||||
Reference in New Issue
Block a user