Implement AI workspace and mitigation workflow

This commit is contained in:
2026-02-15 04:27:28 +00:00
parent de2b9a9bbb
commit 2d3b8fdac6
64 changed files with 7669 additions and 769 deletions

View File

@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from core.util import logs

View File

@@ -8,7 +8,6 @@ from django.urls import reverse
from signalbot import Command, Context, SignalBot
from core.clients import ClientBase, signalapi
from core.lib.prompts.functions import delete_messages, truncate_and_summarize
from core.messaging import ai, history, natural, replies, utils
from core.models import Chat, Manipulation, PersonIdentifier, QueuedMessage
from core.util import logs
@@ -25,11 +24,90 @@ SIGNAL_PORT = 8080
SIGNAL_URL = f"{SIGNAL_HOST}:{SIGNAL_PORT}"
def _get_nested(payload, path):
current = payload
for key in path:
if not isinstance(current, dict):
return None
current = current.get(key)
return current
def _looks_like_signal_attachment(entry):
return isinstance(entry, dict) and (
"id" in entry or "attachmentId" in entry or "contentType" in entry
)
def _normalize_attachment(entry):
attachment_id = entry.get("id") or entry.get("attachmentId")
if attachment_id is None:
return None
return {
"id": attachment_id,
"content_type": entry.get("contentType", "application/octet-stream"),
"filename": entry.get("filename") or str(attachment_id),
"size": entry.get("size") or 0,
"width": entry.get("width"),
"height": entry.get("height"),
}
def _extract_attachments(raw_payload):
envelope = raw_payload.get("envelope", {})
candidate_paths = [
("dataMessage", "attachments"),
("syncMessage", "sentMessage", "attachments"),
("syncMessage", "editMessage", "dataMessage", "attachments"),
]
results = []
seen = set()
for path in candidate_paths:
found = _get_nested(envelope, path)
if not isinstance(found, list):
continue
for entry in found:
normalized = _normalize_attachment(entry)
if not normalized:
continue
key = str(normalized["id"])
if key in seen:
continue
seen.add(key)
results.append(normalized)
# Fallback: scan for attachment-shaped lists under envelope.
if not results:
stack = [envelope]
while stack:
node = stack.pop()
if isinstance(node, dict):
for value in node.values():
stack.append(value)
elif isinstance(node, list):
if node and all(_looks_like_signal_attachment(item) for item in node):
for entry in node:
normalized = _normalize_attachment(entry)
if not normalized:
continue
key = str(normalized["id"])
if key in seen:
continue
seen.add(key)
results.append(normalized)
else:
for value in node:
stack.append(value)
return results
class NewSignalBot(SignalBot):
def __init__(self, ur, service, config):
self.ur = ur
self.service = service
self.signal_rest = config["signal_service"] # keep your own copy
self.signal_rest = config["signal_service"] # keep your own copy
self.phone_number = config["phone_number"]
super().__init__(config)
self.log = logs.get_logger("signalI")
@@ -46,7 +124,9 @@ class NewSignalBot(SignalBot):
try:
resp = await session.get(uri)
if resp.status != 200:
self.log.error(f"contacts lookup failed: {resp.status} {await resp.text()}")
self.log.error(
f"contacts lookup failed: {resp.status} {await resp.text()}"
)
return None
contacts_data = await resp.json()
@@ -95,6 +175,7 @@ class HandleMessage(Command):
self.ur = ur
self.service = service
return super().__init__(*args, **kwargs)
async def handle(self, c: Context):
msg = {
"source": c.message.source,
@@ -106,10 +187,15 @@ class HandleMessage(Command):
"group": c.message.group,
"reaction": c.message.reaction,
"mentions": c.message.mentions,
"raw_message": c.message.raw_message
"raw_message": c.message.raw_message,
}
raw = json.loads(c.message.raw_message)
dest = raw.get("envelope", {}).get("syncMessage", {}).get("sentMessage", {}).get("destinationUuid")
dest = (
raw.get("envelope", {})
.get("syncMessage", {})
.get("sentMessage", {})
.get("destinationUuid")
)
account = raw.get("account", "")
source_name = raw.get("envelope", {}).get("sourceName", "")
@@ -125,9 +211,9 @@ class HandleMessage(Command):
is_from_bot = source_uuid == c.bot.bot_uuid
is_to_bot = dest == c.bot.bot_uuid or dest is None
reply_to_self = same_recipient and is_from_bot # Reply
reply_to_others = is_to_bot and not same_recipient # Reply
is_outgoing_message = is_from_bot and not is_to_bot # Do not reply
reply_to_self = same_recipient and is_from_bot # Reply
reply_to_others = is_to_bot and not same_recipient # Reply
is_outgoing_message = is_from_bot and not is_to_bot # Do not reply
# Determine the identifier to use
identifier_uuid = dest if is_from_bot else source_uuid
@@ -135,20 +221,8 @@ class HandleMessage(Command):
log.warning("No Signal identifier available for message routing.")
return
# Handle attachments
attachments = raw.get("envelope", {}).get("syncMessage", {}).get("sentMessage", {}).get("attachments", [])
if not attachments:
attachments = raw.get("envelope", {}).get("dataMessage", {}).get("attachments", [])
attachment_list = []
for attachment in attachments:
attachment_list.append({
"id": attachment["id"],
"content_type": attachment["contentType"],
"filename": attachment["filename"],
"size": attachment["size"],
"width": attachment.get("width"),
"height": attachment.get("height"),
})
# Handle attachments across multiple Signal payload variants.
attachment_list = _extract_attachments(raw)
# Get users/person identifiers for this Signal sender/recipient.
identifiers = await sync_to_async(list)(
@@ -160,9 +234,16 @@ class HandleMessage(Command):
xmpp_attachments = []
# Asynchronously fetch all attachments
tasks = [signalapi.fetch_signal_attachment(att["id"]) for att in attachment_list]
fetched_attachments = await asyncio.gather(*tasks)
log.info(f"ATTACHMENT LIST {attachment_list}")
if attachment_list:
tasks = [
signalapi.fetch_signal_attachment(att["id"]) for att in attachment_list
]
fetched_attachments = await asyncio.gather(*tasks)
else:
envelope = raw.get("envelope", {})
log.info(f"No attachments found. Envelope keys: {list(envelope.keys())}")
fetched_attachments = []
for fetched, att in zip(fetched_attachments, attachment_list):
if not fetched:
@@ -170,12 +251,14 @@ class HandleMessage(Command):
continue
# Attach fetched file to XMPP
xmpp_attachments.append({
"content": fetched["content"],
"content_type": fetched["content_type"],
"filename": fetched["filename"],
"size": fetched["size"],
})
xmpp_attachments.append(
{
"content": fetched["content"],
"content_type": fetched["content_type"],
"filename": fetched["filename"],
"size": fetched["size"],
}
)
# Forward incoming Signal messages to XMPP and apply mutate rules.
for identifier in identifiers:
@@ -200,7 +283,9 @@ class HandleMessage(Command):
)
log.info("Running Signal mutate prompt")
result = await ai.run_prompt(prompt, manip.ai)
log.info(f"Sending {len(xmpp_attachments)} attachments from Signal to XMPP.")
log.info(
f"Sending {len(xmpp_attachments)} attachments from Signal to XMPP."
)
await self.ur.xmpp.client.send_from_external(
user,
identifier,
@@ -209,7 +294,9 @@ class HandleMessage(Command):
attachments=xmpp_attachments,
)
else:
log.info(f"Sending {len(xmpp_attachments)} attachments from Signal to XMPP.")
log.info(
f"Sending {len(xmpp_attachments)} attachments from Signal to XMPP."
)
await self.ur.xmpp.client.send_from_external(
user,
identifier,
@@ -219,9 +306,7 @@ class HandleMessage(Command):
)
# TODO: Permission checks
manips = await sync_to_async(list)(
Manipulation.objects.filter(enabled=True)
)
manips = await sync_to_async(list)(Manipulation.objects.filter(enabled=True))
session_cache = {}
stored_messages = set()
for manip in manips:
@@ -233,7 +318,9 @@ class HandleMessage(Command):
person__in=manip.group.people.all(),
)
except PersonIdentifier.DoesNotExist:
log.warning(f"{manip.name}: Message from unknown identifier {identifier_uuid}.")
log.warning(
f"{manip.name}: Message from unknown identifier {identifier_uuid}."
)
continue
# Find/create ChatSession once per user/person.
@@ -241,7 +328,9 @@ class HandleMessage(Command):
if session_key in session_cache:
chat_session = session_cache[session_key]
else:
chat_session = await history.get_chat_session(manip.user, person_identifier)
chat_session = await history.get_chat_session(
manip.user, person_identifier
)
session_cache[session_key] = chat_session
# Store each incoming/outgoing event once per session.
@@ -270,10 +359,7 @@ class HandleMessage(Command):
elif manip.mode in ["active", "notify", "instant"]:
await utils.update_last_interaction(chat_session)
prompt = replies.generate_reply_prompt(
msg,
person_identifier.person,
manip,
chat_history
msg, person_identifier.person, manip, chat_history
)
log.info("Running context prompt")
@@ -307,14 +393,13 @@ class HandleMessage(Command):
custom_author="BOT",
)
await delete_messages(existing_queue)
await history.delete_queryset(existing_queue)
qm = await history.store_own_message(
session=chat_session,
text=result,
ts=ts + 1,
manip=manip,
queue=True,
)
accept = reverse(
"message_accept_api", kwargs={"message_id": qm.id}
@@ -333,9 +418,6 @@ class HandleMessage(Command):
else:
log.error(f"Mode {manip.mode} is not implemented")
# Manage truncation & summarization
await truncate_and_summarize(chat_session, manip.ai)
await sync_to_async(Chat.objects.update_or_create)(
source_uuid=source_uuid,
defaults={
@@ -353,9 +435,10 @@ class SignalClient(ClientBase):
ur,
self.service,
{
"signal_service": SIGNAL_URL,
"phone_number": "+447490296227",
})
"signal_service": SIGNAL_URL,
"phone_number": "+447490296227",
},
)
self.client.register(HandleMessage(self.ur, self.service))

View File

@@ -1,12 +1,12 @@
from rest_framework import status
import requests
from requests.exceptions import RequestException
import orjson
from django.conf import settings
import aiohttp
import base64
import asyncio
import base64
import aiohttp
import orjson
import requests
from django.conf import settings
from requests.exceptions import RequestException
from rest_framework import status
async def start_typing(uuid):
@@ -18,6 +18,7 @@ async def start_typing(uuid):
async with session.put(url, json=data) as response:
return await response.text() # Optional: Return response content
async def stop_typing(uuid):
base = getattr(settings, "SIGNAL_HTTP_URL", "http://signal:8080").rstrip("/")
url = f"{base}/v1/typing_indicator/{settings.SIGNAL_NUMBER}"
@@ -27,6 +28,7 @@ async def stop_typing(uuid):
async with session.delete(url, json=data) as response:
return await response.text() # Optional: Return response content
async def download_and_encode_base64(file_url, filename, content_type):
"""
Downloads a file from a given URL asynchronously, converts it to Base64,
@@ -51,12 +53,15 @@ async def download_and_encode_base64(file_url, filename, content_type):
base64_encoded = base64.b64encode(file_data).decode("utf-8")
# Format according to Signal's expected structure
return f"data:{content_type};filename={filename};base64,{base64_encoded}"
return (
f"data:{content_type};filename={filename};base64,{base64_encoded}"
)
except aiohttp.ClientError as e:
# log.error(f"Failed to download file: {file_url}, error: {e}")
return None
async def send_message_raw(recipient_uuid, text=None, attachments=[]):
"""
Sends a message using the Signal REST API, ensuring attachment links are not included in the text body.
@@ -75,11 +80,14 @@ async def send_message_raw(recipient_uuid, text=None, attachments=[]):
data = {
"recipients": [recipient_uuid],
"number": settings.SIGNAL_NUMBER,
"base64_attachments": []
"base64_attachments": [],
}
# Asynchronously download and encode all attachments
tasks = [download_and_encode_base64(att["url"], att["filename"], att["content_type"]) for att in attachments]
tasks = [
download_and_encode_base64(att["url"], att["filename"], att["content_type"])
for att in attachments
]
encoded_attachments = await asyncio.gather(*tasks)
# Filter out failed downloads (None values)
@@ -87,7 +95,7 @@ async def send_message_raw(recipient_uuid, text=None, attachments=[]):
# Remove the message body if it only contains an attachment link
if text and (text.strip() in [att["url"] for att in attachments]):
#log.info("Removing message body since it only contains an attachment link.")
# log.info("Removing message body since it only contains an attachment link.")
text = None # Don't send the link as text
if text:
@@ -103,6 +111,7 @@ async def send_message_raw(recipient_uuid, text=None, attachments=[]):
return ts if ts else False
return False
async def fetch_signal_attachment(attachment_id):
"""
Asynchronously fetches an attachment from Signal.
@@ -111,7 +120,7 @@ async def fetch_signal_attachment(attachment_id):
attachment_id (str): The Signal attachment ID.
Returns:
dict | None:
dict | None:
{
"content": <binary file data>,
"content_type": <MIME type>,
@@ -128,7 +137,9 @@ async def fetch_signal_attachment(attachment_id):
if response.status != 200:
return None # Failed request
content_type = response.headers.get("Content-Type", "application/octet-stream")
content_type = response.headers.get(
"Content-Type", "application/octet-stream"
)
content = await response.read()
size = int(response.headers.get("Content-Length", len(content)))
@@ -150,7 +161,6 @@ async def fetch_signal_attachment(attachment_id):
return None # Network error
def download_and_encode_base64_sync(file_url, filename, content_type):
"""
Downloads a file from a given URL, converts it to Base64, and returns it in Signal's expected format.
@@ -173,7 +183,7 @@ def download_and_encode_base64_sync(file_url, filename, content_type):
# Format according to Signal's expected structure
return f"data:{content_type};filename={filename};base64,{base64_encoded}"
except requests.RequestException as e:
#log.error(f"Failed to download file: {file_url}, error: {e}")
# log.error(f"Failed to download file: {file_url}, error: {e}")
return None
@@ -193,18 +203,20 @@ def send_message_raw_sync(recipient_uuid, text=None, attachments=[]):
data = {
"recipients": [recipient_uuid],
"number": settings.SIGNAL_NUMBER,
"base64_attachments": []
"base64_attachments": [],
}
# Convert attachments to Base64
for att in attachments:
base64_data = download_and_encode_base64_sync(att["url"], att["filename"], att["content_type"])
base64_data = download_and_encode_base64_sync(
att["url"], att["filename"], att["content_type"]
)
if base64_data:
data["base64_attachments"].append(base64_data)
# Remove the message body if it only contains an attachment link
if text and (text.strip() in [att["url"] for att in attachments]):
#log.info("Removing message body since it only contains an attachment link.")
# log.info("Removing message body since it only contains an attachment link.")
text = None # Don't send the link as text
if text:
@@ -214,10 +226,12 @@ def send_message_raw_sync(recipient_uuid, text=None, attachments=[]):
response = requests.post(url, json=data, timeout=10)
response.raise_for_status()
except requests.RequestException as e:
#log.error(f"Failed to send Signal message: {e}")
# log.error(f"Failed to send Signal message: {e}")
return False
if response.status_code == status.HTTP_201_CREATED: # Signal server returns 201 on success
if (
response.status_code == status.HTTP_201_CREATED
): # Signal server returns 201 on success
try:
ts = orjson.loads(response.text).get("timestamp", None)
return ts if ts else False

View File

@@ -1,20 +1,30 @@
from core.clients import ClientBase
from django.conf import settings
from slixmpp.componentxmpp import ComponentXMPP
from django.conf import settings
from core.models import User, Person, PersonIdentifier, ChatSession, Manipulation
from asgiref.sync import sync_to_async
from django.utils.timezone import now
import asyncio
from core.clients import signalapi
from slixmpp.xmlstream import register_stanza_plugin
from slixmpp.plugins.xep_0085.stanza import Active, Composing, Paused, Inactive, Gone
from slixmpp.stanza import Message
from slixmpp.xmlstream.stanzabase import ET
import aiohttp
from core.messaging import history
from asgiref.sync import sync_to_async
from django.conf import settings
from django.utils.timezone import now
from slixmpp.componentxmpp import ComponentXMPP
from slixmpp.plugins.xep_0085.stanza import Active, Composing, Gone, Inactive, Paused
from slixmpp.stanza import Message
from slixmpp.xmlstream import register_stanza_plugin
from slixmpp.xmlstream.stanzabase import ET
from core.clients import ClientBase, signalapi
from core.messaging import ai, history, replies, utils
from core.models import (
ChatSession,
Manipulation,
PatternMitigationAutoSettings,
PatternMitigationGame,
PatternMitigationPlan,
PatternMitigationRule,
Person,
PersonIdentifier,
User,
WorkspaceConversation,
)
from core.util import logs
from core.messaging import replies, utils, ai
class XMPPComponent(ComponentXMPP):
@@ -51,7 +61,9 @@ class XMPPComponent(ComponentXMPP):
self.add_event_handler("presence_subscribed", self.on_presence_subscribed)
self.add_event_handler("presence_unsubscribe", self.on_presence_unsubscribe)
self.add_event_handler("presence_unsubscribed", self.on_presence_unsubscribed)
self.add_event_handler("roster_subscription_request", self.on_roster_subscription_request)
self.add_event_handler(
"roster_subscription_request", self.on_roster_subscription_request
)
# Chat state handlers
self.add_event_handler("chatstate_active", self.on_chatstate_active)
@@ -73,13 +85,15 @@ class XMPPComponent(ComponentXMPP):
def get_identifier(self, msg):
# Extract sender JID (full format: user@domain/resource)
sender_jid = str(msg["from"])
# Split into username@domain and optional resource
sender_parts = sender_jid.split("/", 1)
sender_bare_jid = sender_parts[0] # Always present: user@domain
sender_username, sender_domain = sender_bare_jid.split("@", 1)
sender_resource = sender_parts[1] if len(sender_parts) > 1 else None # Extract resource if present
sender_resource = (
sender_parts[1] if len(sender_parts) > 1 else None
) # Extract resource if present
# Extract recipient JID (should match component JID format)
recipient_jid = str(msg["to"])
@@ -100,7 +114,6 @@ class XMPPComponent(ComponentXMPP):
person_name = recipient_username.title()
service = None
try:
# Lookup user in Django
self.log.info(f"User {sender_username}")
@@ -112,22 +125,255 @@ class XMPPComponent(ComponentXMPP):
# Ensure a PersonIdentifier exists for this user, person, and service
self.log.info(f"Identifier {service}")
identifier = PersonIdentifier.objects.get(user=user, person=person, service=service)
identifier = PersonIdentifier.objects.get(
user=user, person=person, service=service
)
return identifier
except (User.DoesNotExist, Person.DoesNotExist, PersonIdentifier.DoesNotExist):
# If any lookup fails, reject the subscription
except Exception as e:
self.log.error(f"Failed to resolve identifier from XMPP message: {e}")
return None
def _get_workspace_conversation(self, user, person):
conversation, _ = WorkspaceConversation.objects.get_or_create(
user=user,
platform_type="signal",
title=f"{person.name} Workspace",
defaults={"platform_thread_id": str(person.id)},
)
conversation.participants.add(person)
return conversation
def _get_or_create_plan(self, user, person):
conversation = self._get_workspace_conversation(user, person)
plan = conversation.mitigation_plans.order_by("-updated_at").first()
if plan is None:
plan = PatternMitigationPlan.objects.create(
user=user,
conversation=conversation,
title=f"{person.name} Pattern Mitigation",
objective="Mitigate repeated friction loops.",
fundamental_items=[],
creation_mode="guided",
status="draft",
)
PatternMitigationRule.objects.create(
user=user,
plan=plan,
title="Safety Before Analysis",
content="Prioritize de-escalation before analysis.",
enabled=True,
)
PatternMitigationGame.objects.create(
user=user,
plan=plan,
title="Two-Turn Pause",
instructions="Use two short turns then pause with a return time.",
enabled=True,
)
return plan
async def _handle_mitigation_command(self, sender_user, body, sym):
def parse_parts(raw):
return [part.strip() for part in raw.split("|")]
command = body.strip()
if command == ".mitigation help":
sym(
"Mitigation commands: "
".mitigation list | "
".mitigation show <person> | "
".mitigation rule-add <person>|<title>|<content> | "
".mitigation rule-del <person>|<title> | "
".mitigation game-add <person>|<title>|<instructions> | "
".mitigation game-del <person>|<title> | "
".mitigation auto <person>|on|off | "
".mitigation auto-status <person>"
)
return True
if command == ".mitigation list":
plans = await sync_to_async(list)(
PatternMitigationPlan.objects.filter(user=sender_user)
.select_related("conversation")
.order_by("-updated_at")[:15]
)
if not plans:
sym("No mitigation plans found.")
return True
rows = []
for plan in plans:
person_name = (
plan.conversation.participants.order_by("name").first().name
if plan.conversation.participants.exists()
else "Unknown"
)
rows.append(f"{person_name}: {plan.title}")
sym("Plans: " + " | ".join(rows))
return True
if command.startswith(".mitigation show "):
person_name = command.replace(".mitigation show ", "", 1).strip().title()
person = await sync_to_async(
lambda: Person.objects.filter(user=sender_user, name__iexact=person_name).first()
)()
if not person:
sym("Unknown person.")
return True
plan = await sync_to_async(self._get_or_create_plan)(sender_user, person)
rule_count = await sync_to_async(plan.rules.count)()
game_count = await sync_to_async(plan.games.count)()
sym(f"{person.name}: {plan.title} | rules={rule_count} games={game_count}")
return True
if command.startswith(".mitigation rule-add "):
payload = command.replace(".mitigation rule-add ", "", 1)
parts = parse_parts(payload)
if len(parts) < 3:
sym("Usage: .mitigation rule-add <person>|<title>|<content>")
return True
person_name, title, content = parts[0].title(), parts[1], "|".join(parts[2:])
person = await sync_to_async(
lambda: Person.objects.filter(user=sender_user, name__iexact=person_name).first()
)()
if not person:
sym("Unknown person.")
return True
plan = await sync_to_async(self._get_or_create_plan)(sender_user, person)
await sync_to_async(PatternMitigationRule.objects.create)(
user=sender_user,
plan=plan,
title=title[:255],
content=content,
enabled=True,
)
sym("Rule added.")
return True
if command.startswith(".mitigation rule-del "):
payload = command.replace(".mitigation rule-del ", "", 1)
parts = parse_parts(payload)
if len(parts) < 2:
sym("Usage: .mitigation rule-del <person>|<title>")
return True
person_name, title = parts[0].title(), "|".join(parts[1:])
person = await sync_to_async(
lambda: Person.objects.filter(user=sender_user, name__iexact=person_name).first()
)()
if not person:
sym("Unknown person.")
return True
plan = await sync_to_async(self._get_or_create_plan)(sender_user, person)
deleted, _ = await sync_to_async(
lambda: PatternMitigationRule.objects.filter(
user=sender_user,
plan=plan,
title__iexact=title,
).delete()
)()
sym("Rule deleted." if deleted else "Rule not found.")
return True
if command.startswith(".mitigation game-add "):
payload = command.replace(".mitigation game-add ", "", 1)
parts = parse_parts(payload)
if len(parts) < 3:
sym("Usage: .mitigation game-add <person>|<title>|<instructions>")
return True
person_name, title, content = parts[0].title(), parts[1], "|".join(parts[2:])
person = await sync_to_async(
lambda: Person.objects.filter(user=sender_user, name__iexact=person_name).first()
)()
if not person:
sym("Unknown person.")
return True
plan = await sync_to_async(self._get_or_create_plan)(sender_user, person)
await sync_to_async(PatternMitigationGame.objects.create)(
user=sender_user,
plan=plan,
title=title[:255],
instructions=content,
enabled=True,
)
sym("Game added.")
return True
if command.startswith(".mitigation game-del "):
payload = command.replace(".mitigation game-del ", "", 1)
parts = parse_parts(payload)
if len(parts) < 2:
sym("Usage: .mitigation game-del <person>|<title>")
return True
person_name, title = parts[0].title(), "|".join(parts[1:])
person = await sync_to_async(
lambda: Person.objects.filter(user=sender_user, name__iexact=person_name).first()
)()
if not person:
sym("Unknown person.")
return True
plan = await sync_to_async(self._get_or_create_plan)(sender_user, person)
deleted, _ = await sync_to_async(
lambda: PatternMitigationGame.objects.filter(
user=sender_user,
plan=plan,
title__iexact=title,
).delete()
)()
sym("Game deleted." if deleted else "Game not found.")
return True
if command.startswith(".mitigation auto "):
payload = command.replace(".mitigation auto ", "", 1)
parts = parse_parts(payload)
if len(parts) < 2:
sym("Usage: .mitigation auto <person>|on|off")
return True
person_name, state = parts[0].title(), parts[1].lower()
person = await sync_to_async(
lambda: Person.objects.filter(user=sender_user, name__iexact=person_name).first()
)()
if not person:
sym("Unknown person.")
return True
conversation = await sync_to_async(self._get_workspace_conversation)(sender_user, person)
auto_obj, _ = await sync_to_async(PatternMitigationAutoSettings.objects.get_or_create)(
user=sender_user,
conversation=conversation,
)
auto_obj.enabled = state in {"on", "true", "1", "yes"}
await sync_to_async(auto_obj.save)(update_fields=["enabled", "updated_at"])
sym(f"Automation {'enabled' if auto_obj.enabled else 'disabled'} for {person.name}.")
return True
if command.startswith(".mitigation auto-status "):
person_name = command.replace(".mitigation auto-status ", "", 1).strip().title()
person = await sync_to_async(
lambda: Person.objects.filter(user=sender_user, name__iexact=person_name).first()
)()
if not person:
sym("Unknown person.")
return True
conversation = await sync_to_async(self._get_workspace_conversation)(sender_user, person)
auto_obj, _ = await sync_to_async(PatternMitigationAutoSettings.objects.get_or_create)(
user=sender_user,
conversation=conversation,
)
sym(
f"{person.name}: auto={'on' if auto_obj.enabled else 'off'}, "
f"pattern={'on' if auto_obj.auto_pattern_recognition else 'off'}, "
f"corrections={'on' if auto_obj.auto_create_corrections else 'off'}"
)
return True
return False
def update_roster(self, jid, name=None):
"""
Adds or updates a user in the roster.
"""
iq = self.Iq()
iq['type'] = 'set'
iq['roster']['items'] = {jid: {'name': name or jid}}
iq["type"] = "set"
iq["roster"]["items"] = {jid: {"name": name or jid}}
iq.send()
self.log.info(f"Updated roster: Added {jid} ({name})")
@@ -171,7 +417,6 @@ class XMPPComponent(ComponentXMPP):
identifier = self.get_identifier(msg)
def on_presence_available(self, pres):
"""
Handle when a user becomes available.
@@ -214,10 +459,12 @@ class XMPPComponent(ComponentXMPP):
Accept only if the recipient has a contact matching the sender.
"""
sender_jid = str(pres['from']).split('/')[0] # Bare JID (user@domain)
recipient_jid = str(pres['to']).split('/')[0]
sender_jid = str(pres["from"]).split("/")[0] # Bare JID (user@domain)
recipient_jid = str(pres["to"]).split("/")[0]
self.log.info(f"Received subscription request from {sender_jid} to {recipient_jid}")
self.log.info(
f"Received subscription request from {sender_jid} to {recipient_jid}"
)
try:
# Extract sender and recipient usernames
@@ -248,7 +495,9 @@ class XMPPComponent(ComponentXMPP):
# Accept the subscription
self.send_presence(ptype="subscribed", pto=sender_jid, pfrom=component_jid)
self.log.info(f"Accepted subscription from {sender_jid}, sent from {component_jid}")
self.log.info(
f"Accepted subscription from {sender_jid}, sent from {component_jid}"
)
# Send a presence request **from the recipient to the sender** (ASKS THEM TO ACCEPT BACK)
# self.send_presence(ptype="subscribe", pto=sender_jid, pfrom=component_jid)
@@ -262,16 +511,16 @@ class XMPPComponent(ComponentXMPP):
self.send_presence(ptype="available", pto=sender_jid, pfrom=component_jid)
self.log.info(f"Sent presence update from {component_jid} to {sender_jid}")
except (User.DoesNotExist, Person.DoesNotExist, PersonIdentifier.DoesNotExist):
# If any lookup fails, reject the subscription
self.log.warning(f"Subscription request from {sender_jid} rejected (recipient does not have this contact).")
self.log.warning(
f"Subscription request from {sender_jid} rejected (recipient does not have this contact)."
)
self.send_presence(ptype="unsubscribed", pto=sender_jid)
except ValueError:
return
def on_presence_subscribed(self, pres):
"""
Handle successful subscription confirmations.
@@ -325,16 +574,16 @@ class XMPPComponent(ComponentXMPP):
# self.log.error("No XEP-0363 upload service found.")
# return None
#self.log.info(f"Upload service: {upload_service}")
# self.log.info(f"Upload service: {upload_service}")
upload_service_jid = "share.zm.is"
try:
slot = await self['xep_0363'].request_slot(
slot = await self["xep_0363"].request_slot(
jid=upload_service_jid,
filename=filename,
content_type=content_type,
size=size
size=size,
)
if slot is None:
@@ -350,8 +599,12 @@ class XMPPComponent(ComponentXMPP):
put_url = put_element.attrib.get("url")
# Extract the Authorization header correctly
header_element = put_element.find(f"./{namespace}header[@name='Authorization']")
auth_header = header_element.text.strip() if header_element is not None else None
header_element = put_element.find(
f"./{namespace}header[@name='Authorization']"
)
auth_header = (
header_element.text.strip() if header_element is not None else None
)
if not get_url or not put_url:
self.log.error(f"Missing URLs in upload slot: {slot}")
@@ -363,7 +616,6 @@ class XMPPComponent(ComponentXMPP):
self.log.error(f"Exception while requesting upload slot: {e}")
return None
async def message(self, msg):
"""
Process incoming XMPP messages.
@@ -374,13 +626,15 @@ class XMPPComponent(ComponentXMPP):
# Extract sender JID (full format: user@domain/resource)
sender_jid = str(msg["from"])
# Split into username@domain and optional resource
sender_parts = sender_jid.split("/", 1)
sender_bare_jid = sender_parts[0] # Always present: user@domain
sender_username, sender_domain = sender_bare_jid.split("@", 1)
sender_resource = sender_parts[1] if len(sender_parts) > 1 else None # Extract resource if present
sender_resource = (
sender_parts[1] if len(sender_parts) > 1 else None
) # Extract resource if present
# Extract recipient JID (should match component JID format)
recipient_jid = str(msg["to"])
@@ -399,19 +653,23 @@ class XMPPComponent(ComponentXMPP):
# Extract attachments from standard XMPP <attachments> (if present)
for att in msg.xml.findall(".//{urn:xmpp:attachments}attachment"):
attachments.append({
"url": att.attrib.get("url"),
"filename": att.attrib.get("filename"),
"content_type": att.attrib.get("content_type"),
})
attachments.append(
{
"url": att.attrib.get("url"),
"filename": att.attrib.get("filename"),
"content_type": att.attrib.get("content_type"),
}
)
# Extract attachments from XEP-0066 <x><url> format (Out of Band Data)
for oob in msg.xml.findall(".//{jabber:x:oob}x/{jabber:x:oob}url"):
attachments.append({
"url": oob.text,
"filename": oob.text.split("/")[-1], # Extract filename from URL
"content_type": "application/octet-stream", # Generic content-type
})
attachments.append(
{
"url": oob.text,
"filename": oob.text.split("/")[-1], # Extract filename from URL
"content_type": "application/octet-stream", # Generic content-type
}
)
self.log.info(f"Extracted {len(attachments)} attachments from XMPP message.")
# Log extracted information with variable name annotations
@@ -426,7 +684,9 @@ class XMPPComponent(ComponentXMPP):
# Ensure recipient domain matches our configured component
expected_domain = settings.XMPP_JID # 'jews.zm.is' in your config
if recipient_domain != expected_domain:
self.log.warning(f"Invalid recipient domain: {recipient_domain}, expected {expected_domain}")
self.log.warning(
f"Invalid recipient domain: {recipient_domain}, expected {expected_domain}"
)
return
# Lookup sender in Django's User model
@@ -452,6 +712,16 @@ class XMPPComponent(ComponentXMPP):
contact_names = [person.name for person in persons]
response_text = f"Contacts: " + ", ".join(contact_names)
sym(response_text)
elif body == ".help":
sym("Commands: .contacts, .whoami, .mitigation help")
elif body.startswith(".mitigation"):
handled = await self._handle_mitigation_command(
sender_user,
body,
sym,
)
if not handled:
sym("Unknown mitigation command. Try .mitigation help")
elif body == ".whoami":
sym(str(sender_user.__dict__))
else:
@@ -468,7 +738,7 @@ class XMPPComponent(ComponentXMPP):
recipient_service = None
recipient_name = recipient_name.title()
try:
person = Person.objects.get(user=sender_user, name=recipient_name)
except Person.DoesNotExist:
@@ -476,21 +746,22 @@ class XMPPComponent(ComponentXMPP):
if recipient_service:
try:
identifier = PersonIdentifier.objects.get(user=sender_user,
person=person,
service=recipient_service)
identifier = PersonIdentifier.objects.get(
user=sender_user, person=person, service=recipient_service
)
except PersonIdentifier.DoesNotExist:
sym("This service identifier does not exist.")
else:
# Get a random identifier
identifier = PersonIdentifier.objects.filter(user=sender_user,
person=person).first()
identifier = PersonIdentifier.objects.filter(
user=sender_user, person=person
).first()
recipient_service = identifier.service
# sym(str(person.__dict__))
# sym(f"Service: {recipient_service}")
#tss = await identifier.send(body, attachments=attachments)
# tss = await identifier.send(body, attachments=attachments)
# AM FIXING https://git.zm.is/XF/GIA/issues/5
session, _ = await sync_to_async(ChatSession.objects.get_or_create)(
identifier=identifier,
@@ -502,7 +773,7 @@ class XMPPComponent(ComponentXMPP):
sender="XMPP",
text=body,
ts=int(now().timestamp() * 1000),
#outgoing=detail.is_outgoing_message, ????????? TODO:
# outgoing=detail.is_outgoing_message, ????????? TODO:
)
self.log.info("Stored a message sent from XMPP in the history.")
@@ -526,11 +797,11 @@ class XMPPComponent(ComponentXMPP):
chat_history = await history.get_chat_history(session)
await utils.update_last_interaction(session)
prompt = replies.generate_mutate_reply_prompt(
body,
identifier.person,
manip,
chat_history,
)
body,
identifier.person,
manip,
chat_history,
)
self.log.info("Running XMPP context prompt")
result = await ai.run_prompt(prompt, manip.ai)
self.log.info(f"RESULT {result}")
@@ -546,17 +817,21 @@ class XMPPComponent(ComponentXMPP):
)
self.log.info(f"Message sent with modifications")
async def request_upload_slots(self, recipient_jid, attachments):
"""Requests upload slots for multiple attachments concurrently."""
upload_tasks = [
self.request_upload_slot(recipient_jid, att["filename"], att["content_type"], att["size"])
self.request_upload_slot(
recipient_jid, att["filename"], att["content_type"], att["size"]
)
for att in attachments
]
upload_slots = await asyncio.gather(*upload_tasks)
return [(att, slot) for att, slot in zip(attachments, upload_slots) if slot is not None]
return [
(att, slot)
for att, slot in zip(attachments, upload_slots)
if slot is not None
]
async def upload_and_send(self, att, upload_slot, recipient_jid, sender_jid):
"""Uploads a file and immediately sends the corresponding XMPP message."""
@@ -567,19 +842,29 @@ class XMPPComponent(ComponentXMPP):
async with aiohttp.ClientSession() as session:
try:
async with session.put(put_url, data=att["content"], headers=headers) as response:
async with session.put(
put_url, data=att["content"], headers=headers
) as response:
if response.status not in (200, 201):
self.log.error(f"Upload failed: {response.status} {await response.text()}")
self.log.error(
f"Upload failed: {response.status} {await response.text()}"
)
return
self.log.info(f"Successfully uploaded {att['filename']} to {upload_url}")
self.log.info(
f"Successfully uploaded {att['filename']} to {upload_url}"
)
# Send XMPP message immediately after successful upload
await self.send_xmpp_message(recipient_jid, sender_jid, upload_url, attachment_url=upload_url)
await self.send_xmpp_message(
recipient_jid, sender_jid, upload_url, attachment_url=upload_url
)
except Exception as e:
self.log.error(f"Error uploading {att['filename']} to XMPP: {e}")
async def send_xmpp_message(self, recipient_jid, sender_jid, body_text, attachment_url=None):
async def send_xmpp_message(
self, recipient_jid, sender_jid, body_text, attachment_url=None
):
"""Sends an XMPP message with either text or an attachment URL."""
msg = self.make_message(mto=recipient_jid, mfrom=sender_jid, mtype="chat")
msg["body"] = body_text # Body must contain only text or the URL
@@ -594,7 +879,9 @@ class XMPPComponent(ComponentXMPP):
self.log.info(f"Sending XMPP message: {msg.xml}")
msg.send()
async def send_from_external(self, user, person_identifier, text, is_outgoing_message, attachments=[]):
async def send_from_external(
self, user, person_identifier, text, is_outgoing_message, attachments=[]
):
"""Handles sending XMPP messages with text and attachments."""
sender_jid = f"{person_identifier.person.name.lower()}|{person_identifier.service}@{settings.XMPP_JID}"
@@ -614,11 +901,12 @@ class XMPPComponent(ComponentXMPP):
self.log.info(f"Got upload slots")
if not valid_uploads:
self.log.warning("No valid upload slots obtained.")
#return
# return
# Step 3: Upload each file and send its message immediately after upload
upload_tasks = [
self.upload_and_send(att, slot, recipient_jid, sender_jid) for att, slot in valid_uploads
self.upload_and_send(att, slot, recipient_jid, sender_jid)
for att, slot in valid_uploads
]
await asyncio.gather(*upload_tasks) # Upload files concurrently
@@ -634,12 +922,12 @@ class XMPPClient(ClientBase):
port=settings.XMPP_PORT,
)
self.client.register_plugin('xep_0030') # Service Discovery
self.client.register_plugin('xep_0004') # Data Forms
self.client.register_plugin('xep_0060') # PubSub
self.client.register_plugin('xep_0199') # XMPP Ping
self.client.register_plugin("xep_0030") # Service Discovery
self.client.register_plugin("xep_0004") # Data Forms
self.client.register_plugin("xep_0060") # PubSub
self.client.register_plugin("xep_0199") # XMPP Ping
self.client.register_plugin("xep_0085") # Chat State Notifications
self.client.register_plugin('xep_0363') # HTTP File Upload
self.client.register_plugin("xep_0363") # HTTP File Upload
def start(self):
self.log.info("XMPP client starting...")
@@ -648,4 +936,4 @@ class XMPPClient(ClientBase):
self.client.loop = self.loop
self.client.connect()
#self.client.process()
# self.client.process()