diff --git a/core/commands/delivery.py b/core/commands/delivery.py index a42002a..8ea7162 100644 --- a/core/commands/delivery.py +++ b/core/commands/delivery.py @@ -7,7 +7,7 @@ from asgiref.sync import sync_to_async from core.clients import transport from core.models import ChatSession, Message -STATUS_VISIBLE_SOURCE_SERVICES = {"web", "xmpp"} +STATUS_VISIBLE_SOURCE_SERVICES = {"web", "xmpp", "signal", "whatsapp"} def chunk_for_transport(text: str, limit: int = 3000) -> list[str]: diff --git a/core/commands/engine.py b/core/commands/engine.py index 8643808..7e7d0dc 100644 --- a/core/commands/engine.py +++ b/core/commands/engine.py @@ -15,7 +15,13 @@ from core.commands.policies import ensure_variant_policies_for_profile from core.commands.registry import get as get_handler from core.commands.registry import register from core.messaging.reply_sync import is_mirrored_origin -from core.models import CommandAction, CommandChannelBinding, CommandProfile, Message +from core.models import ( + CommandAction, + CommandChannelBinding, + CommandProfile, + Message, + PersonIdentifier, +) from core.security.command_policy import CommandSecurityContext, evaluate_command_policy from core.tasks.chat_defaults import ensure_default_source_for_chat from core.util import logs @@ -50,6 +56,65 @@ def _canonical_channel_identifier(service: str, channel_identifier: str) -> str: return value +def _signal_identifier_rank(identifier_value: str) -> int: + identifier_text = str(identifier_value or "").strip() + if not identifier_text: + return 99 + if identifier_text.startswith("group."): + return 0 + if identifier_text.startswith("+"): + return 1 + return 2 + + +def _expand_service_channel_variants( + user_id: int, + service: str, + identifiers: list[str], +) -> list[str]: + variants: list[str] = [] + for identifier in identifiers: + for value in _channel_variants(service, identifier): + if value and value not in variants: + variants.append(value) + if str(service or "").strip().lower() != "signal" or not variants: + return variants + person_ids = list( + PersonIdentifier.objects.filter( + user_id=user_id, + service="signal", + identifier__in=variants, + ) + .values_list("person_id", flat=True) + .distinct() + ) + if not person_ids: + return variants + alias_rows = list( + PersonIdentifier.objects.filter( + user_id=user_id, + service="signal", + person_id__in=person_ids, + ).values_list("identifier", flat=True) + ) + for value in alias_rows: + cleaned = str(value or "").strip() + if cleaned and cleaned not in variants: + variants.append(cleaned) + variants.sort(key=lambda value: (_signal_identifier_rank(value), value)) + return variants + + +def _preferred_channel_identifier(service: str, identifiers: list[str]) -> str: + cleaned = [str(value or "").strip() for value in identifiers if str(value or "").strip()] + if not cleaned: + return "" + if str(service or "").strip().lower() == "signal": + cleaned.sort(key=lambda value: (_signal_identifier_rank(value), value)) + return cleaned[0] + return cleaned[0] + + def _effective_bootstrap_scope( ctx: CommandContext, trigger_message: Message, @@ -192,7 +257,8 @@ def _auto_setup_profile_bindings_for_first_command( service, identifier = _effective_bootstrap_scope(ctx, trigger_message) service = str(service or "").strip().lower() canonical = _canonical_channel_identifier(service, identifier) - variants = _channel_variants(service, canonical) + variants = _expand_service_channel_variants(ctx.user_id, service, [canonical]) + canonical = _preferred_channel_identifier(service, variants) or canonical if not service or not variants: return for slug in slugs: @@ -252,9 +318,17 @@ async def _eligible_profiles(ctx: CommandContext) -> list[CommandProfile]: .filter(id=ctx.message_id, user_id=ctx.user_id) .first() ) - direct_variants = _channel_variants(ctx.service, ctx.channel_identifier) + direct_variants = _expand_service_channel_variants( + ctx.user_id, + ctx.service, + [ctx.channel_identifier], + ) source_channel = str(getattr(trigger, "source_chat_id", "") or "").strip() - for expanded in _channel_variants(ctx.service, source_channel): + for expanded in _expand_service_channel_variants( + ctx.user_id, + ctx.service, + [source_channel], + ): if expanded and expanded not in direct_variants: direct_variants.append(expanded) if not direct_variants: @@ -278,8 +352,16 @@ async def _eligible_profiles(ctx: CommandContext) -> list[CommandProfile]: identifier = getattr(getattr(trigger, "session", None), "identifier", None) fallback_service = str(getattr(identifier, "service", "") or "").strip().lower() fallback_identifier = str(getattr(identifier, "identifier", "") or "").strip() - fallback_variants = _channel_variants(fallback_service, fallback_identifier) - for expanded in _channel_variants(fallback_service, source_channel): + fallback_variants = _expand_service_channel_variants( + ctx.user_id, + fallback_service, + [fallback_identifier], + ) + for expanded in _expand_service_channel_variants( + ctx.user_id, + fallback_service, + [source_channel], + ): if expanded and expanded not in fallback_variants: fallback_variants.append(expanded) if not fallback_service or not fallback_variants: diff --git a/core/tasks/chat_defaults.py b/core/tasks/chat_defaults.py index f84440d..691661a 100644 --- a/core/tasks/chat_defaults.py +++ b/core/tasks/chat_defaults.py @@ -16,6 +16,50 @@ SAFE_TASK_FLAGS_DEFAULTS = { "min_chars": 3, } +WHATSAPP_GROUP_ID_RE = re.compile(r"^\d+@g\.us$") +WHATSAPP_DIRECT_ID_RE = re.compile(r"^\d+@s\.whatsapp\.net$") +WHATSAPP_BARE_ID_RE = re.compile(r"^\d+$") +SIGNAL_GROUP_ID_RE = re.compile(r"^group\.[A-Za-z0-9+/=]+$") +SIGNAL_UUID_RE = re.compile( + r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", + re.IGNORECASE, +) +SIGNAL_PHONE_RE = re.compile(r"^\+\d+$") +SIGNAL_INTERNAL_ID_RE = re.compile(r"^[A-Za-z0-9+/=]+$") + + +def _normalize_whatsapp_identifier(identifier: str) -> str: + value = str(identifier or "").strip() + if not value: + return "" + if "/" in value or "?" in value or "#" in value: + return "" + if WHATSAPP_GROUP_ID_RE.fullmatch(value): + return value + if WHATSAPP_DIRECT_ID_RE.fullmatch(value): + return value + bare = value.split("@", 1)[0].strip() + if not WHATSAPP_BARE_ID_RE.fullmatch(bare): + return "" + if value.endswith("@s.whatsapp.net"): + return f"{bare}@s.whatsapp.net" + return f"{bare}@g.us" + + +def _normalize_signal_identifier(identifier: str) -> str: + value = str(identifier or "").strip() + if not value: + return "" + if SIGNAL_GROUP_ID_RE.fullmatch(value): + return value + if SIGNAL_UUID_RE.fullmatch(value): + return value.lower() + if SIGNAL_PHONE_RE.fullmatch(value): + return value + if SIGNAL_INTERNAL_ID_RE.fullmatch(value): + return value + return "" + def normalize_channel_identifier(service: str, identifier: str) -> str: service_key = str(service or "").strip().lower() @@ -23,12 +67,9 @@ def normalize_channel_identifier(service: str, identifier: str) -> str: if not value: return "" if service_key == "whatsapp": - bare = value.split("@", 1)[0].strip() - if not bare: - return value - if value.endswith("@s.whatsapp.net"): - return f"{bare}@s.whatsapp.net" - return f"{bare}@g.us" + return _normalize_whatsapp_identifier(value) + if service_key == "signal": + return _normalize_signal_identifier(value) return value diff --git a/core/tests/test_phase1_command_reply.py b/core/tests/test_phase1_command_reply.py index 6bfb70b..ded9c55 100644 --- a/core/tests/test_phase1_command_reply.py +++ b/core/tests/test_phase1_command_reply.py @@ -298,6 +298,53 @@ class Phase1CommandEngineTests(TestCase): self.assertEqual("skipped", results[0].status) self.assertEqual("reply_required", results[0].error) + def test_eligible_profile_matches_signal_group_alias_variants(self): + self.profile.channel_bindings.all().delete() + canonical = PersonIdentifier.objects.create( + user=self.user, + person=self.person, + service="signal", + identifier="group.canonical-signal-group", + ) + self.session.identifier = canonical + self.session.save(update_fields=["identifier"]) + CommandChannelBinding.objects.create( + profile=self.profile, + direction="ingress", + service="signal", + channel_identifier="group.canonical-signal-group", + enabled=True, + ) + msg = Message.objects.create( + user=self.user, + session=self.session, + sender_uuid="", + text="#bp#", + ts=5500, + source_service="signal", + source_chat_id="signal-internal-group-id", + message_meta={}, + ) + PersonIdentifier.objects.create( + user=self.user, + person=self.person, + service="signal", + identifier="signal-internal-group-id", + ) + results = async_to_sync(process_inbound_message)( + CommandContext( + service="signal", + channel_identifier="signal-internal-group-id", + message_id=str(msg.id), + user_id=self.user.id, + message_text="#bp#", + payload={}, + ) + ) + self.assertEqual(1, len(results)) + self.assertEqual("skipped", results[0].status) + self.assertEqual("reply_required", results[0].error) + def test_compose_command_options_show_bp_subcommands(self): self.profile.channel_bindings.all().delete() CommandChannelBinding.objects.create( diff --git a/core/tests/test_tasks_pages_management.py b/core/tests/test_tasks_pages_management.py index 37f8c77..087f4fb 100644 --- a/core/tests/test_tasks_pages_management.py +++ b/core/tests/test_tasks_pages_management.py @@ -223,6 +223,44 @@ class TasksPagesManagementTests(TestCase): ).exists() ) + def test_group_page_does_not_seed_source_for_malformed_whatsapp_asset_path(self): + response = self.client.get( + reverse( + "tasks_group", + kwargs={ + "service": "whatsapp", + "identifier": "447777695114/static/js/template_profiler.js", + }, + ) + ) + self.assertEqual(200, response.status_code) + self.assertFalse( + ChatTaskSource.objects.filter( + user=self.user, + service="whatsapp", + channel_identifier__icontains="template_profiler.js", + ).exists() + ) + + def test_group_page_does_not_seed_source_for_malformed_signal_asset_path(self): + response = self.client.get( + reverse( + "tasks_group", + kwargs={ + "service": "signal", + "identifier": "group.c0VHQTlGMEhRL2V5/static/js/template_profiler.js", + }, + ) + ) + self.assertEqual(200, response.status_code) + self.assertFalse( + ChatTaskSource.objects.filter( + user=self.user, + service="signal", + channel_identifier__icontains="template_profiler.js", + ).exists() + ) + def test_tasks_hub_shows_human_creator_label(self): project = TaskProject.objects.create(user=self.user, name="Creator Test") session = ChatSession.objects.create(user=self.user, identifier=self.pid_signal) diff --git a/core/tests/test_tasks_settings_and_toggle.py b/core/tests/test_tasks_settings_and_toggle.py index e9774df..f94646b 100644 --- a/core/tests/test_tasks_settings_and_toggle.py +++ b/core/tests/test_tasks_settings_and_toggle.py @@ -236,6 +236,44 @@ class TaskSettingsViewActionsTests(TestCase): ChatTaskSource.objects.filter(id=self.source.id, user=self.user).exists() ) + def test_source_create_updates_existing_signal_mapping_instead_of_duplicating(self): + signal_project = TaskProject.objects.create(user=self.user, name="Signal A") + signal_source = ChatTaskSource.objects.create( + user=self.user, + service="signal", + channel_identifier="group.c0VHQTlGMEhRL2V5TGdtdkt4MjNoaGE5VnA3bURSaHBxMjMvcm9WU1piST0=", + project=signal_project, + settings={"match_mode": "strict"}, + enabled=True, + ) + target_project = TaskProject.objects.create(user=self.user, name="Signal B") + response = self.client.post( + reverse("tasks_settings"), + { + "action": "source_create", + "service": "signal", + "channel_identifier": "group.c0VHQTlGMEhRL2V5TGdtdkt4MjNoaGE5VnA3bURSaHBxMjMvcm9WU1piST0=", + "project_id": str(target_project.id), + "source_match_mode": "strict", + "source_require_prefix": "1", + "source_derive_enabled": "1", + "source_completion_enabled": "1", + "source_ai_title_enabled": "1", + }, + follow=True, + ) + self.assertEqual(200, response.status_code) + self.assertEqual( + 1, + ChatTaskSource.objects.filter( + user=self.user, + service="signal", + channel_identifier="group.c0VHQTlGMEhRL2V5TGdtdkt4MjNoaGE5VnA3bURSaHBxMjMvcm9WU1piST0=", + ).count(), + ) + signal_source.refresh_from_db() + self.assertEqual(target_project.id, signal_source.project_id) + @override_settings(TASK_DERIVATION_USE_AI=False) class TaskAutoBootstrapTests(TestCase): diff --git a/core/views/tasks.py b/core/views/tasks.py index 4fc8dce..5c31980 100644 --- a/core/views/tasks.py +++ b/core/views/tasks.py @@ -45,6 +45,50 @@ from core.tasks.engine import create_task_record_and_sync from core.tasks.providers import get_provider +def _upsert_task_source( + *, + user, + service: str, + channel_identifier: str, + project, + epic=None, + enabled: bool = True, + settings: dict | None = None, +): + service_key = str(service or "").strip().lower() + normalized_identifier = normalize_channel_identifier(service_key, channel_identifier) + if not service_key or not normalized_identifier: + return None, False + source, created = ChatTaskSource.objects.get_or_create( + user=user, + service=service_key, + channel_identifier=normalized_identifier, + defaults={ + "project": project, + "epic": epic, + "enabled": bool(enabled), + "settings": dict(settings or {}), + }, + ) + changed_fields = [] + if source.project_id != getattr(project, "id", None): + source.project = project + changed_fields.append("project") + if source.epic_id != getattr(epic, "id", None): + source.epic = epic + changed_fields.append("epic") + if bool(source.enabled) != bool(enabled): + source.enabled = bool(enabled) + changed_fields.append("enabled") + settings_payload = dict(settings or {}) + if dict(source.settings or {}) != settings_payload: + source.settings = settings_payload + changed_fields.append("settings") + if changed_fields: + source.save(update_fields=changed_fields + ["updated_at"]) + return source, created + + def _to_bool(raw, default=False) -> bool: if raw is None: return bool(default) @@ -1562,7 +1606,7 @@ class TaskSettings(LoginRequiredMixin, View): epic = get_object_or_404( TaskEpic, id=epic_id, project__user=request.user ) - ChatTaskSource.objects.create( + source, _ = _upsert_task_source( user=request.user, service=str(request.POST.get("service") or "web").strip(), channel_identifier=str( @@ -1573,6 +1617,8 @@ class TaskSettings(LoginRequiredMixin, View): enabled=bool(request.POST.get("enabled") or "1"), settings=_flags_from_post(request, prefix="source_"), ) + if source is None: + messages.error(request, "Invalid channel identifier.") return _settings_redirect(request) if action == "quick_setup": @@ -1598,31 +1644,17 @@ class TaskSettings(LoginRequiredMixin, View): project=project, name=epic_name ) if channel_identifier: - source, created = ChatTaskSource.objects.get_or_create( + source, created = _upsert_task_source( user=request.user, service=service, channel_identifier=channel_identifier, project=project, - defaults={ - "epic": epic, - "enabled": True, - "settings": _flags_from_post(request, prefix="source_"), - }, + epic=epic, + enabled=True, + settings=_flags_from_post(request, prefix="source_"), ) - if not created: - source.project = project - source.epic = epic - source.enabled = True - source.settings = _flags_from_post(request, prefix="source_") - source.save( - update_fields=[ - "project", - "epic", - "enabled", - "settings", - "updated_at", - ] - ) + if source is None: + messages.error(request, "Invalid channel identifier.") return _settings_redirect(request) if action == "project_flags_update":