diff --git a/nanobot/channels/signal.py b/nanobot/channels/signal.py index 50d36c3af..3fd23c780 100644 --- a/nanobot/channels/signal.py +++ b/nanobot/channels/signal.py @@ -695,10 +695,7 @@ class SignalChannel(BaseChannel): self.logger.debug(f"Ignoring DM from {sender_id} (DMs disabled)") return if self.config.dm.policy == "allowlist": - allow_list = self.config.dm.allow_from - sender_str = str(sender_id) - parts = [sender_str] + (sender_str.split("|") if "|" in sender_str else []) - if not any(p for p in parts if p in allow_list): + if not self._sender_matches_allowlist(sender_id, self.config.dm.allow_from): self.logger.debug(f"Ignoring DM from {sender_id} (policy: {self.config.dm.policy})") return @@ -864,6 +861,28 @@ class SignalChannel(BaseChannel): normalized.append(f"+{raw}") return list(dict.fromkeys(normalized)) + @classmethod + def _sender_matches_allowlist(cls, sender_id: str, allow_list: list[str]) -> bool: + """Return True if any normalized variant of sender_id is on allow_list. + + sender_id is the pipe-joined identifier string built by + _collect_sender_id_parts. Each part and each allow_list entry is run + through _normalize_signal_id so an allowlist entry like ``1234567890`` + matches a sender ``+1234567890`` (and vice versa), and case-only + differences in UUIDs/ACIs match too. + """ + if not allow_list: + return False + sender_variants: set[str] = set() + for part in str(sender_id).split("|"): + sender_variants.update(cls._normalize_signal_id(part)) + if not sender_variants: + return False + allow_variants: set[str] = set() + for entry in allow_list: + allow_variants.update(cls._normalize_signal_id(entry)) + return bool(sender_variants & allow_variants) + def _remember_account_id_alias(self, value: str | None) -> None: """Remember known bot identifiers for mention matching.""" if not value: diff --git a/tests/channels/test_signal_channel.py b/tests/channels/test_signal_channel.py index 27b8b1e91..d6308b803 100644 --- a/tests/channels/test_signal_channel.py +++ b/tests/channels/test_signal_channel.py @@ -510,6 +510,36 @@ class TestHandleDataMessageDM: await ch._handle_receive_notification(params) assert handled == [] + @pytest.mark.asyncio + async def test_dm_allowlist_matches_without_plus_prefix(self): + """An allowlist entry without '+' must match a sender that carries '+'.""" + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["19995550001"]) + params = _dm_envelope(source_number="+19995550001") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_dm_allowlist_matches_with_plus_prefix(self): + """An allowlist entry with '+' must match a sender without '+'.""" + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["+19995550001"]) + params = _dm_envelope(source_number="+19995550001", source_uuid=None) + # Replace envelope's sourceNumber with the non-prefixed form by editing + # the constructed dict directly so _collect_sender_id_parts sees it. + params["envelope"]["sourceNumber"] = "19995550001" + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_dm_allowlist_matches_uuid_case_insensitive(self): + """UUID matching must be case-insensitive.""" + uuid = "ABCDEF12-3456-7890-ABCD-EF1234567890" + ch, handled = self._make_dm_channel( + policy="allowlist", allow_from=[uuid.lower()] + ) + params = _dm_envelope(source_number="+19995550001", source_uuid=uuid) + await ch._handle_receive_notification(params) + assert len(handled) == 1 + @pytest.mark.asyncio async def test_dm_disabled_rejected(self): ch = _make_channel(dm_enabled=False)