Continue AI features and improve protocol support
This commit is contained in:
417
core/clients/transport.py
Normal file
417
core/clients/transport.py
Normal file
@@ -0,0 +1,417 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import secrets
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
import orjson
|
||||
import qrcode
|
||||
from django.conf import settings
|
||||
from django.core.cache import cache
|
||||
|
||||
from core.clients import signalapi
|
||||
from core.messaging import media_bridge
|
||||
from core.util import logs
|
||||
|
||||
log = logs.get_logger("transport")
|
||||
|
||||
_RUNTIME_STATE_TTL = 60 * 60 * 24
|
||||
_RUNTIME_CLIENTS: dict[str, Any] = {}
|
||||
|
||||
|
||||
def _service_key(service: str) -> str:
|
||||
return str(service or "").strip().lower()
|
||||
|
||||
|
||||
def _runtime_key(service: str) -> str:
|
||||
return f"gia:service:runtime:{_service_key(service)}"
|
||||
|
||||
|
||||
def _gateway_base(service: str) -> str:
|
||||
key = f"{service.upper()}_HTTP_URL"
|
||||
default = f"http://{service}:8080"
|
||||
return str(getattr(settings, key, default)).rstrip("/")
|
||||
|
||||
|
||||
def _as_qr_png(data: str) -> bytes:
|
||||
image = qrcode.make(data)
|
||||
stream = io.BytesIO()
|
||||
image.save(stream, format="PNG")
|
||||
return stream.getvalue()
|
||||
|
||||
|
||||
def _parse_timestamp(data: Any):
|
||||
if isinstance(data, dict):
|
||||
ts = data.get("timestamp")
|
||||
if ts:
|
||||
return ts
|
||||
return None
|
||||
|
||||
|
||||
def register_runtime_client(service: str, client: Any):
|
||||
"""
|
||||
Register an in-process runtime client (UR process).
|
||||
"""
|
||||
_RUNTIME_CLIENTS[_service_key(service)] = client
|
||||
|
||||
|
||||
def get_runtime_client(service: str):
|
||||
return _RUNTIME_CLIENTS.get(_service_key(service))
|
||||
|
||||
|
||||
def get_runtime_state(service: str) -> dict[str, Any]:
|
||||
return dict(cache.get(_runtime_key(service)) or {})
|
||||
|
||||
|
||||
def update_runtime_state(service: str, **updates):
|
||||
"""
|
||||
Persist runtime state to shared cache so web/UI process can read it.
|
||||
"""
|
||||
key = _runtime_key(service)
|
||||
state = dict(cache.get(key) or {})
|
||||
state.update(updates)
|
||||
state["updated_at"] = int(time.time())
|
||||
cache.set(key, state, timeout=_RUNTIME_STATE_TTL)
|
||||
return state
|
||||
|
||||
|
||||
def list_accounts(service: str):
|
||||
"""
|
||||
Return account identifiers for service UI list.
|
||||
"""
|
||||
service_key = _service_key(service)
|
||||
if service_key == "signal":
|
||||
import requests
|
||||
|
||||
base = str(getattr(settings, "SIGNAL_HTTP_URL", "http://signal:8080")).rstrip("/")
|
||||
try:
|
||||
response = requests.get(f"{base}/v1/accounts", timeout=20)
|
||||
if not response.ok:
|
||||
return []
|
||||
payload = orjson.loads(response.text or "[]")
|
||||
if isinstance(payload, list):
|
||||
return payload
|
||||
except Exception:
|
||||
return []
|
||||
return []
|
||||
|
||||
state = get_runtime_state(service_key)
|
||||
accounts = state.get("accounts") or []
|
||||
if isinstance(accounts, list):
|
||||
return accounts
|
||||
return []
|
||||
|
||||
|
||||
def get_service_warning(service: str) -> str:
|
||||
service_key = _service_key(service)
|
||||
if service_key == "signal":
|
||||
return ""
|
||||
|
||||
state = get_runtime_state(service_key)
|
||||
warning = str(state.get("warning") or "").strip()
|
||||
if warning:
|
||||
return warning
|
||||
|
||||
if not state.get("connected"):
|
||||
return (
|
||||
f"{service_key.title()} runtime is not connected yet. "
|
||||
"Start UR with the service enabled, open Services -> "
|
||||
f"{service_key.title()} -> Add Account, then scan the QR from "
|
||||
"WhatsApp Linked Devices."
|
||||
)
|
||||
return ""
|
||||
|
||||
|
||||
async def _gateway_json(method: str, url: str, payload=None):
|
||||
timeout = aiohttp.ClientTimeout(total=20)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
request = getattr(session, method.lower())
|
||||
async with request(url, json=payload) as response:
|
||||
body = await response.read()
|
||||
if not body:
|
||||
return response.status, None
|
||||
try:
|
||||
return response.status, orjson.loads(body)
|
||||
except Exception:
|
||||
return response.status, None
|
||||
|
||||
|
||||
async def _normalize_gateway_attachment(service: str, row: dict, session):
|
||||
normalized = dict(row or {})
|
||||
content = normalized.get("content")
|
||||
if isinstance(content, memoryview):
|
||||
content = content.tobytes()
|
||||
if isinstance(content, bytes):
|
||||
blob_key = media_bridge.put_blob(
|
||||
service=service,
|
||||
content=content,
|
||||
filename=normalized.get("filename") or "attachment.bin",
|
||||
content_type=normalized.get("content_type") or "application/octet-stream",
|
||||
)
|
||||
return {
|
||||
"blob_key": blob_key,
|
||||
"filename": normalized.get("filename") or "attachment.bin",
|
||||
"content_type": normalized.get("content_type")
|
||||
or "application/octet-stream",
|
||||
"size": normalized.get("size") or len(content),
|
||||
}
|
||||
|
||||
if normalized.get("blob_key"):
|
||||
return normalized
|
||||
|
||||
source_url = normalized.get("url")
|
||||
if source_url:
|
||||
try:
|
||||
async with session.get(source_url) as response:
|
||||
if response.status == 200:
|
||||
payload = await response.read()
|
||||
blob_key = media_bridge.put_blob(
|
||||
service=service,
|
||||
content=payload,
|
||||
filename=normalized.get("filename")
|
||||
or source_url.rstrip("/").split("/")[-1]
|
||||
or "attachment.bin",
|
||||
content_type=normalized.get("content_type")
|
||||
or response.headers.get(
|
||||
"Content-Type", "application/octet-stream"
|
||||
),
|
||||
)
|
||||
return {
|
||||
"blob_key": blob_key,
|
||||
"filename": normalized.get("filename")
|
||||
or source_url.rstrip("/").split("/")[-1]
|
||||
or "attachment.bin",
|
||||
"content_type": normalized.get("content_type")
|
||||
or response.headers.get(
|
||||
"Content-Type", "application/octet-stream"
|
||||
),
|
||||
"size": normalized.get("size") or len(payload),
|
||||
}
|
||||
except Exception:
|
||||
log.warning("%s attachment fetch failed for %s", service, source_url)
|
||||
return normalized
|
||||
|
||||
|
||||
async def _gateway_send(service: str, recipient: str, text=None, attachments=None):
|
||||
base = _gateway_base(service)
|
||||
url = f"{base}/v1/send"
|
||||
timeout = aiohttp.ClientTimeout(total=20)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as media_session:
|
||||
normalized_attachments = await asyncio.gather(
|
||||
*[
|
||||
_normalize_gateway_attachment(service, dict(att or {}), media_session)
|
||||
for att in (attachments or [])
|
||||
]
|
||||
)
|
||||
|
||||
data = {
|
||||
"recipient": recipient,
|
||||
"text": text or "",
|
||||
"attachments": normalized_attachments,
|
||||
}
|
||||
status, payload = await _gateway_json("post", url, data)
|
||||
if 200 <= status < 300:
|
||||
ts = _parse_timestamp(payload)
|
||||
return ts if ts else True
|
||||
log.warning("%s gateway send failed (%s): %s", service, status, payload)
|
||||
return False
|
||||
|
||||
|
||||
async def _gateway_typing(service: str, recipient: str, started: bool):
|
||||
base = _gateway_base(service)
|
||||
action = "start" if started else "stop"
|
||||
url = f"{base}/v1/typing/{action}"
|
||||
payload = {"recipient": recipient}
|
||||
status, _ = await _gateway_json("post", url, payload)
|
||||
if 200 <= status < 300:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def send_message_raw(service: str, recipient: str, text=None, attachments=None):
|
||||
"""
|
||||
Unified outbound send path used by models/views/UR.
|
||||
"""
|
||||
service_key = _service_key(service)
|
||||
if service_key == "signal":
|
||||
return await signalapi.send_message_raw(recipient, text, attachments or [])
|
||||
|
||||
if service_key in {"whatsapp", "instagram"}:
|
||||
runtime_client = get_runtime_client(service_key)
|
||||
if runtime_client and hasattr(runtime_client, "send_message_raw"):
|
||||
try:
|
||||
runtime_result = await runtime_client.send_message_raw(
|
||||
recipient,
|
||||
text=text,
|
||||
attachments=attachments or [],
|
||||
)
|
||||
if runtime_result is not False and runtime_result is not None:
|
||||
return runtime_result
|
||||
except Exception as exc:
|
||||
log.warning("%s runtime send failed: %s", service_key, exc)
|
||||
return await _gateway_send(
|
||||
service_key,
|
||||
recipient,
|
||||
text=text,
|
||||
attachments=attachments or [],
|
||||
)
|
||||
|
||||
if service_key == "xmpp":
|
||||
raise NotImplementedError("Direct XMPP send is handled by the XMPP client.")
|
||||
raise NotImplementedError(f"Unsupported service: {service}")
|
||||
|
||||
|
||||
async def start_typing(service: str, recipient: str):
|
||||
service_key = _service_key(service)
|
||||
if service_key == "signal":
|
||||
await signalapi.start_typing(recipient)
|
||||
return True
|
||||
|
||||
if service_key in {"whatsapp", "instagram"}:
|
||||
runtime_client = get_runtime_client(service_key)
|
||||
if runtime_client and hasattr(runtime_client, "start_typing"):
|
||||
try:
|
||||
result = await runtime_client.start_typing(recipient)
|
||||
if result:
|
||||
return True
|
||||
except Exception as exc:
|
||||
log.warning("%s runtime start_typing failed: %s", service_key, exc)
|
||||
return await _gateway_typing(service_key, recipient, started=True)
|
||||
return False
|
||||
|
||||
|
||||
async def stop_typing(service: str, recipient: str):
|
||||
service_key = _service_key(service)
|
||||
if service_key == "signal":
|
||||
await signalapi.stop_typing(recipient)
|
||||
return True
|
||||
|
||||
if service_key in {"whatsapp", "instagram"}:
|
||||
runtime_client = get_runtime_client(service_key)
|
||||
if runtime_client and hasattr(runtime_client, "stop_typing"):
|
||||
try:
|
||||
result = await runtime_client.stop_typing(recipient)
|
||||
if result:
|
||||
return True
|
||||
except Exception as exc:
|
||||
log.warning("%s runtime stop_typing failed: %s", service_key, exc)
|
||||
return await _gateway_typing(service_key, recipient, started=False)
|
||||
return False
|
||||
|
||||
|
||||
async def fetch_attachment(service: str, attachment_ref: dict):
|
||||
"""
|
||||
Fetch attachment bytes from a source service or URL.
|
||||
"""
|
||||
service_key = _service_key(service)
|
||||
if service_key == "signal":
|
||||
attachment_id = attachment_ref.get("id") or attachment_ref.get("attachment_id")
|
||||
if not attachment_id:
|
||||
return None
|
||||
return await signalapi.fetch_signal_attachment(attachment_id)
|
||||
|
||||
runtime_client = get_runtime_client(service_key)
|
||||
if runtime_client and hasattr(runtime_client, "fetch_attachment"):
|
||||
try:
|
||||
from_runtime = await runtime_client.fetch_attachment(attachment_ref)
|
||||
if from_runtime:
|
||||
return from_runtime
|
||||
except Exception as exc:
|
||||
log.warning("%s runtime attachment fetch failed: %s", service_key, exc)
|
||||
|
||||
direct_url = attachment_ref.get("url")
|
||||
blob_key = attachment_ref.get("blob_key")
|
||||
if blob_key:
|
||||
return media_bridge.get_blob(blob_key)
|
||||
if direct_url:
|
||||
timeout = aiohttp.ClientTimeout(total=20)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.get(direct_url) as response:
|
||||
if response.status != 200:
|
||||
return None
|
||||
content = await response.read()
|
||||
return {
|
||||
"content": content,
|
||||
"content_type": response.headers.get(
|
||||
"Content-Type",
|
||||
attachment_ref.get("content_type", "application/octet-stream"),
|
||||
),
|
||||
"filename": attachment_ref.get("filename")
|
||||
or direct_url.rstrip("/").split("/")[-1]
|
||||
or "attachment.bin",
|
||||
"size": len(content),
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def _qr_from_runtime_state(service: str) -> bytes | None:
|
||||
state = get_runtime_state(service)
|
||||
qr_payload = str(state.get("pair_qr") or "").strip()
|
||||
if not qr_payload:
|
||||
return None
|
||||
if qr_payload.startswith("data:image/") and "," in qr_payload:
|
||||
_, b64_data = qr_payload.split(",", 1)
|
||||
try:
|
||||
return base64.b64decode(b64_data)
|
||||
except Exception:
|
||||
return None
|
||||
return _as_qr_png(qr_payload)
|
||||
|
||||
|
||||
def get_link_qr(service: str, device_name: str):
|
||||
"""
|
||||
Returns PNG bytes for account-linking QR.
|
||||
|
||||
- Signal: uses signal-cli REST endpoint.
|
||||
- WhatsApp/Instagram: runtime QR from shared state when available.
|
||||
Falls back to local pairing token QR in development.
|
||||
"""
|
||||
service_key = _service_key(service)
|
||||
device = (device_name or "GIA Device").strip()
|
||||
|
||||
if service_key == "signal":
|
||||
import requests
|
||||
|
||||
base = str(getattr(settings, "SIGNAL_HTTP_URL", "http://signal:8080")).rstrip("/")
|
||||
response = requests.get(
|
||||
f"{base}/v1/qrcodelink",
|
||||
params={"device_name": device},
|
||||
timeout=20,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
if service_key in {"whatsapp", "instagram"}:
|
||||
runtime_client = get_runtime_client(service_key)
|
||||
if runtime_client and hasattr(runtime_client, "get_link_qr_png"):
|
||||
try:
|
||||
image_bytes = runtime_client.get_link_qr_png(device)
|
||||
if image_bytes:
|
||||
return image_bytes
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cached = _qr_from_runtime_state(service_key)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
token = secrets.token_urlsafe(24)
|
||||
uri = f"gia://{service_key}/link?device={device}&token={token}"
|
||||
update_runtime_state(
|
||||
service_key,
|
||||
pair_device=device,
|
||||
pair_requested_at=int(time.time()),
|
||||
warning=(
|
||||
"Waiting for runtime pairing QR. "
|
||||
"If this persists, check UR logs and Neonize session state."
|
||||
),
|
||||
)
|
||||
return _as_qr_png(uri)
|
||||
|
||||
raise NotImplementedError(f"Unsupported service for QR linking: {service}")
|
||||
|
||||
|
||||
def image_bytes_to_base64(image_bytes: bytes) -> str:
|
||||
return base64.b64encode(image_bytes).decode("utf-8")
|
||||
Reference in New Issue
Block a user