"""Tests for the Signal channel implementation.""" from __future__ import annotations import asyncio from contextlib import asynccontextmanager from pathlib import Path from unittest.mock import MagicMock import pytest from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.signal import ( SignalChannel, SignalConfig, SignalDMConfig, SignalGroupConfig, ) # --------------------------------------------------------------------------- # Fake HTTP client # --------------------------------------------------------------------------- class _FakeResponse: def __init__(self, status_code: int = 200, body: dict | None = None) -> None: self.status_code = status_code self._body = body or {} def raise_for_status(self) -> None: if self.status_code >= 400: raise RuntimeError(f"HTTP {self.status_code}") def json(self) -> dict: return self._body class _FakeHTTPClient: """Minimal httpx.AsyncClient stand-in that records requests.""" def __init__(self, *, default_response: dict | None = None) -> None: self.posts: list[dict] = [] self.gets: list[str] = [] self._response = _FakeResponse(body=default_response or {"result": {"timestamp": 123}}) self.closed = False async def get(self, path: str) -> _FakeResponse: self.gets.append(path) return self._response async def post(self, path: str, *, json: dict) -> _FakeResponse: self.posts.append({"path": path, "json": json}) return self._response async def aclose(self) -> None: self.closed = True # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _make_channel_with_capture(**overrides) -> tuple[SignalChannel, list[dict]]: """Build a SignalChannel with _handle_message captured into a list and a no-op _start_typing, used by every receive-flow test class. 
""" ch = _make_channel(**overrides) handled: list[dict] = [] async def capture(**kwargs): handled.append(kwargs) async def noop_typing(chat_id): pass ch._handle_message = capture # type: ignore[method-assign] ch._start_typing = noop_typing # type: ignore[method-assign] return ch, handled def _make_channel( *, phone_number: str = "+10000000000", dm_enabled: bool = True, dm_policy: str = "open", dm_allow_from: list[str] | None = None, group_enabled: bool = False, group_policy: str = "open", group_allow_from: list[str] | None = None, require_mention: bool = True, group_buffer_size: int = 20, attachments_dir: str | None = None, ) -> SignalChannel: config = SignalConfig( enabled=True, phone_number=phone_number, dm=SignalDMConfig( enabled=dm_enabled, policy=dm_policy, allow_from=dm_allow_from or [], ), group=SignalGroupConfig( enabled=group_enabled, policy=group_policy, allow_from=group_allow_from or [], require_mention=require_mention, ), group_message_buffer_size=group_buffer_size, attachments_dir=attachments_dir, ) return SignalChannel(config, MessageBus()) def _dm_envelope( *, source_number: str = "+19995550001", source_uuid: str | None = None, source_name: str | None = "Alice", message: str = "hello", attachments: list | None = None, reaction: dict | None = None, timestamp: int = 1000, ) -> dict: data_message: dict = {"message": message, "timestamp": timestamp} if attachments is not None: data_message["attachments"] = attachments if reaction is not None: data_message["reaction"] = reaction envelope: dict = { "sourceNumber": source_number, "sourceName": source_name, "dataMessage": data_message, } if source_uuid: envelope["sourceUuid"] = source_uuid return {"envelope": envelope} def _group_envelope( *, source_number: str = "+19995550001", source_name: str = "Bob", group_id: str = "group123==", message: str = "hey group", mentions: list | None = None, timestamp: int = 2000, use_v2: bool = False, ) -> dict: group_obj = {"groupId": group_id} key = "groupV2" if use_v2 
else "groupInfo" data_message: dict = { "message": message, "timestamp": timestamp, key: group_obj, "mentions": mentions or [], } return { "envelope": { "sourceNumber": source_number, "sourceName": source_name, "dataMessage": data_message, } } # --------------------------------------------------------------------------- # Static utility tests # --------------------------------------------------------------------------- class TestNormalizeSignalId: def test_phone_number_kept_and_stripped(self): result = SignalChannel._normalize_signal_id("+12345678901") assert "+12345678901" in result assert "12345678901" in result def test_digits_only_gets_plus_prefix(self): result = SignalChannel._normalize_signal_id("12345678901") assert "+12345678901" in result def test_lowercase_variant_added(self): result = SignalChannel._normalize_signal_id("SOME-UUID") assert "some-uuid" in result def test_empty_string_returns_empty(self): assert SignalChannel._normalize_signal_id("") == [] def test_whitespace_stripped(self): result = SignalChannel._normalize_signal_id(" +1234 ") assert "+1234" in result class TestCollectSenderIdParts: def test_collects_source_number(self): env = {"sourceNumber": "+15551234567"} parts = SignalChannel._collect_sender_id_parts(env) assert "+15551234567" in parts def test_collects_multiple_keys(self): env = {"sourceNumber": "+15551234567", "sourceUuid": "uuid-abc"} parts = SignalChannel._collect_sender_id_parts(env) assert "+15551234567" in parts assert "uuid-abc" in parts def test_deduplicates(self): env = {"sourceNumber": "+15551234567", "source": "+15551234567"} parts = SignalChannel._collect_sender_id_parts(env) assert parts.count("+15551234567") == 1 def test_ignores_non_string_values(self): env = {"sourceNumber": 12345, "sourceUuid": None} parts = SignalChannel._collect_sender_id_parts(env) assert parts == [] def test_empty_envelope_returns_empty(self): assert SignalChannel._collect_sender_id_parts({}) == [] class TestPrimarySenderId: def 
test_prefers_phone_number(self): assert SignalChannel._primary_sender_id(["+1234", "uuid-abc"]) == "+1234" def test_accepts_digit_only(self): assert SignalChannel._primary_sender_id(["1234567890", "uuid-abc"]) == "1234567890" def test_falls_back_to_first_part(self): assert SignalChannel._primary_sender_id(["uuid-abc", "other"]) == "uuid-abc" def test_empty_list_returns_empty(self): assert SignalChannel._primary_sender_id([]) == "" class TestExtractGroupId: def test_extracts_from_group_info(self): gid = SignalChannel._extract_group_id({"groupId": "abc=="}, None) assert gid == "abc==" def test_extracts_from_group_v2(self): gid = SignalChannel._extract_group_id(None, {"id": "xyz=="}) assert gid == "xyz==" def test_prefers_group_info_over_v2(self): gid = SignalChannel._extract_group_id({"groupId": "first"}, {"groupId": "second"}) assert gid == "first" def test_returns_none_when_both_none(self): assert SignalChannel._extract_group_id(None, None) is None def test_returns_none_when_not_dicts(self): assert SignalChannel._extract_group_id("bad", 123) is None class TestIsGroupChatId: def test_base64_with_equals_is_group(self): assert SignalChannel._is_group_chat_id("abc==") is True def test_long_id_without_dash_is_group(self): long_id = "a" * 41 assert SignalChannel._is_group_chat_id(long_id) is True def test_phone_number_is_not_group(self): assert SignalChannel._is_group_chat_id("+12345678901") is False def test_uuid_with_dashes_is_not_group(self): assert SignalChannel._is_group_chat_id("550e8400-e29b-41d4-a716-446655440000") is False class TestRecipientParams: def test_group_chat_uses_group_id(self): ch = _make_channel() params = ch._recipient_params("abc==") assert params == {"groupId": "abc=="} def test_dm_uses_recipient_list(self): ch = _make_channel() params = ch._recipient_params("+12345678901") assert params == {"recipient": ["+12345678901"]} class TestMentionHelpers: def test_mention_id_candidates_extracts_number(self): mention = {"number": "+1234567890"} ids = 
SignalChannel._mention_id_candidates(mention) assert "+1234567890" in ids def test_mention_id_candidates_extracts_uuid(self): mention = {"uuid": "some-uuid"} ids = SignalChannel._mention_id_candidates(mention) assert "some-uuid" in ids def test_mention_span_valid(self): assert SignalChannel._mention_span({"start": 0, "length": 5}) == (0, 5) def test_mention_span_negative_start(self): assert SignalChannel._mention_span({"start": -1, "length": 5}) is None def test_mention_span_zero_length(self): assert SignalChannel._mention_span({"start": 0, "length": 0}) is None def test_mention_span_missing_keys(self): assert SignalChannel._mention_span({}) is None def test_leading_placeholder_ufffc(self): span = SignalChannel._leading_placeholder_span(" hello") assert span == (0, 1) def test_leading_placeholder_not_at_start(self): assert SignalChannel._leading_placeholder_span("hello ") is None def test_leading_placeholder_empty_string(self): assert SignalChannel._leading_placeholder_span("") is None def test_leading_placeholder_plain_text(self): assert SignalChannel._leading_placeholder_span("hello") is None # --------------------------------------------------------------------------- # Account ID alias / mention matching # --------------------------------------------------------------------------- class TestAccountIdAliases: def test_phone_number_alias_registered_on_init(self): ch = _make_channel(phone_number="+10000000000") assert ch._id_matches_account("+10000000000") def test_digit_only_variant_matches(self): ch = _make_channel(phone_number="+10000000000") assert ch._id_matches_account("10000000000") def test_remember_alias_adds_uuid(self): ch = _make_channel() ch._remember_account_id_alias("some-uuid-abc") assert ch._id_matches_account("some-uuid-abc") def test_non_matching_id_returns_false(self): ch = _make_channel(phone_number="+10000000000") assert not ch._id_matches_account("+19999999999") def test_none_and_non_string_return_false(self): ch = _make_channel() assert not 
ch._id_matches_account(None) # --------------------------------------------------------------------------- # _should_respond_in_group # --------------------------------------------------------------------------- class TestShouldRespondInGroup: def _make_group_channel(self, require_mention: bool = True) -> SignalChannel: return _make_channel( phone_number="+10000000000", group_enabled=True, require_mention=require_mention, ) def test_no_require_mention_always_responds(self): ch = self._make_group_channel(require_mention=False) assert ch._should_respond_in_group("anything", []) is True def test_require_mention_with_no_mentions_returns_false(self): ch = self._make_group_channel(require_mention=True) assert ch._should_respond_in_group("hello", []) is False def test_require_mention_with_bot_number_mention(self): ch = self._make_group_channel(require_mention=True) mentions = [{"number": "+10000000000", "start": 0, "length": 12}] assert ch._should_respond_in_group(" hello", mentions) is True def test_require_mention_with_uuid_mention(self): ch = self._make_group_channel(require_mention=True) ch._remember_account_id_alias("bot-uuid-123") mentions = [{"uuid": "bot-uuid-123", "start": 0, "length": 8}] assert ch._should_respond_in_group(" hello", mentions) is True def test_identifier_less_leading_mention_accepted(self): ch = self._make_group_channel(require_mention=True) # Mention with no IDs but leading span — treated as bot mention mentions = [{"start": 0, "length": 1}] assert ch._should_respond_in_group(" hello", mentions) is True def test_identifier_less_non_leading_mention_rejected(self): ch = self._make_group_channel(require_mention=True) mentions = [{"start": 5, "length": 1}] assert ch._should_respond_in_group("hello ", mentions) is False def test_leading_placeholder_without_mentions_metadata(self): ch = self._make_group_channel(require_mention=True) assert ch._should_respond_in_group(" hello", []) is True def test_phone_number_in_text_triggers_response(self): ch = 
self._make_group_channel(require_mention=True) assert ch._should_respond_in_group("hey +10000000000 help", []) is True # --------------------------------------------------------------------------- # _strip_bot_mention # --------------------------------------------------------------------------- class TestStripBotMention: def _make_channel_with_number(self) -> SignalChannel: return _make_channel(phone_number="+10000000000") def test_strips_mention_by_phone(self): ch = self._make_channel_with_number() text = " hello" mentions = [{"number": "+10000000000", "start": 0, "length": 1}] result = ch._strip_bot_mention(text, mentions) assert result == "hello" def test_strips_identifier_less_leading_mention(self): ch = self._make_channel_with_number() text = " hello" mentions = [{"start": 0, "length": 1}] result = ch._strip_bot_mention(text, mentions) assert result == "hello" def test_strips_leading_placeholder_without_mention_metadata(self): ch = self._make_channel_with_number() text = " hello" result = ch._strip_bot_mention(text, []) assert result == "hello" def test_non_bot_mention_mid_text_not_stripped(self): # A non-bot mention that is NOT a leading placeholder leaves the text alone. 
ch = self._make_channel_with_number() text = "hello  world" mentions = [{"number": "+19999999999", "start": 6, "length": 1}] result = ch._strip_bot_mention(text, mentions) # Mid-text placeholder from a non-bot mention should be untouched assert "" in result def test_empty_text_returned_unchanged(self): ch = self._make_channel_with_number() assert ch._strip_bot_mention("", []) == "" # --------------------------------------------------------------------------- # Group message buffer # --------------------------------------------------------------------------- class TestGroupBuffer: def test_add_and_get_context(self): ch = _make_channel(group_buffer_size=5) ch._add_to_group_buffer("g1", "Alice", "+1111", "first msg", 1000) ch._add_to_group_buffer("g1", "Bob", "+2222", "second msg", 2000) # Only messages before the latest are returned as context ctx = ch._get_group_buffer_context("g1") assert "first msg" in ctx # The last message is not included (it's the "current" one) assert "second msg" not in ctx def test_empty_context_when_only_one_message(self): ch = _make_channel(group_buffer_size=5) ch._add_to_group_buffer("g1", "Alice", "+1111", "only msg", 1000) assert ch._get_group_buffer_context("g1") == "" def test_empty_context_when_group_unknown(self): ch = _make_channel() assert ch._get_group_buffer_context("unknown") == "" def test_buffer_respects_max_size(self): ch = _make_channel(group_buffer_size=3) for i in range(10): ch._add_to_group_buffer("g1", "Alice", "+1111", f"msg{i}", i) assert len(ch._group_buffers["g1"]) == 3 def test_zero_buffer_size_rejected_by_validator(self): with pytest.raises(ValueError, match="group_message_buffer_size"): _make_channel(group_buffer_size=0) def test_negative_buffer_size_rejected_by_validator(self): with pytest.raises(ValueError, match="group_message_buffer_size"): _make_channel(group_buffer_size=-1) def test_context_limits_message_length(self): ch = _make_channel(group_buffer_size=5) long_msg = "x" * 500 
ch._add_to_group_buffer("g1", "Alice", "+1111", long_msg, 1000) ch._add_to_group_buffer("g1", "Bob", "+2222", "short", 2000) ctx = ch._get_group_buffer_context("g1") # Context is capped at 200 chars per message assert len(ctx.split("Alice: ", 1)[1]) <= 200 # --------------------------------------------------------------------------- # _handle_data_message — DM routing # --------------------------------------------------------------------------- class TestIsAllowed: """The base-channel allowlist gate is overridden to understand Signal's pipe-joined composite sender_ids and the +/no-+ phone variants. """ def test_denies_when_allowlist_empty(self): ch = _make_channel(dm_enabled=True, dm_policy="open") # open -> no entries assert ch.is_allowed("+19995550001") is False def test_allows_wildcard(self): ch = _make_channel(dm_policy="allowlist", dm_allow_from=["*"]) assert ch.is_allowed("+19995550001|some-uuid") is True def test_allows_composite_sender_against_split_allowlist(self): """Composite sender_id, single-id allow_from — must match either part.""" ch = _make_channel( dm_policy="allowlist", dm_allow_from=["+19995550001"], ) assert ch.is_allowed("+19995550001|1872ba20-uuid") is True def test_allows_composite_sender_against_composite_allowlist_entry(self): """Backward compat: pipe-joined composite allowlist entries still match.""" composite = "+19995550001|1872ba20-uuid" ch = _make_channel(dm_policy="allowlist", dm_allow_from=[composite]) assert ch.is_allowed(composite) is True def test_allows_when_only_uuid_part_is_listed(self): ch = _make_channel(dm_policy="allowlist", dm_allow_from=["1872ba20-uuid"]) assert ch.is_allowed("+19995550001|1872ba20-uuid") is True def test_denies_when_no_part_matches(self): ch = _make_channel(dm_policy="allowlist", dm_allow_from=["+12223334444"]) assert ch.is_allowed("+19995550001|1872ba20-uuid") is False def test_allowlist_union_includes_group_ids(self): """allow_from is the union of dm.allow_from and group.allow_from.""" ch = 
_make_channel( group_enabled=True, group_policy="allowlist", group_allow_from=["group-id-base64=="], ) assert "group-id-base64==" in ch.config.allow_from class TestCheckInboundPolicy: """Direct tests for the policy gate that _handle_data_message now delegates to.""" def _call( self, ch: SignalChannel, *, sender_id: str = "+19995550001", sender_number: str = "+19995550001", group_id: str | None = None, is_group_message: bool = False, message_text: str = "hi", mentions: list | None = None, sender_name: str | None = "Alice", timestamp: int | None = 1000, ) -> tuple[bool, str]: return ch._check_inbound_policy( sender_id=sender_id, sender_number=sender_number, group_id=group_id, is_group_message=is_group_message, message_text=message_text, mentions=mentions or [], sender_name=sender_name, timestamp=timestamp, ) def test_dm_open_allows(self): ch = _make_channel(dm_enabled=True, dm_policy="open") allowed, chat_id = self._call(ch) assert allowed is True assert chat_id == "+19995550001" def test_dm_disabled_blocks(self): ch = _make_channel(dm_enabled=False) allowed, _ = self._call(ch) assert allowed is False def test_dm_allowlist_blocks_unknown_sender(self): ch = _make_channel(dm_policy="allowlist", dm_allow_from=["+12223334444"]) allowed, _ = self._call(ch, sender_id="+19995550001") assert allowed is False def test_dm_allowlist_allows_known_sender(self): ch = _make_channel(dm_policy="allowlist", dm_allow_from=["+19995550001"]) allowed, _ = self._call(ch, sender_id="+19995550001") assert allowed is True def test_group_disabled_blocks(self): ch = _make_channel(group_enabled=False) allowed, _ = self._call(ch, is_group_message=True, group_id="g1") assert allowed is False def test_group_open_with_mention_allows(self): ch = _make_channel( group_enabled=True, group_policy="open", phone_number="+10000000000", require_mention=True, ) allowed, chat_id = self._call( ch, is_group_message=True, group_id="g1", message_text="hello @bot", mentions=[{"number": "+10000000000", "start": 6, 
"length": 4}], ) assert allowed is True assert chat_id == "g1" def test_group_open_without_mention_blocks(self): ch = _make_channel(group_enabled=True, group_policy="open", require_mention=True) allowed, _ = self._call(ch, is_group_message=True, group_id="g1", message_text="plain talk") assert allowed is False def test_group_command_bypasses_mention_requirement(self): ch = _make_channel(group_enabled=True, group_policy="open", require_mention=True) allowed, _ = self._call(ch, is_group_message=True, group_id="g1", message_text="/help") assert allowed is True def test_allowed_group_appends_to_buffer(self): """Side effect: when a group message is allowed, it lands in the buffer.""" ch = _make_channel(group_enabled=True, group_policy="open", require_mention=False) self._call(ch, is_group_message=True, group_id="g1", message_text="first") self._call(ch, is_group_message=True, group_id="g1", message_text="second") assert len(ch._group_buffers["g1"]) == 2 def test_blocked_group_does_not_append_to_buffer(self): """Side effect: when a group is disabled, the buffer must not change.""" ch = _make_channel(group_enabled=False) self._call(ch, is_group_message=True, group_id="g1", message_text="hi") assert "g1" not in ch._group_buffers class TestAttachmentsDir: def test_default_attachments_dir(self): ch = _make_channel() expected = Path.home() / ".local/share/signal-cli/attachments" assert ch._signal_attachments_dir() == expected def test_configured_attachments_dir(self, tmp_path): ch = _make_channel(attachments_dir=str(tmp_path / "custom")) assert ch._signal_attachments_dir() == tmp_path / "custom" def test_attachments_dir_expands_user(self): ch = _make_channel(attachments_dir="~/signal-attachments") assert ch._signal_attachments_dir() == Path.home() / "signal-attachments" class TestHandleDataMessageDM: def _make_dm_channel(self, policy="open", allow_from=None) -> tuple[SignalChannel, list]: return _make_channel_with_capture( dm_enabled=True, dm_policy=policy, 
dm_allow_from=allow_from or [] ) @pytest.mark.asyncio async def test_dm_open_policy_accepted(self): ch, handled = self._make_dm_channel(policy="open") params = _dm_envelope(source_number="+19995550001", message="hi") await ch._handle_receive_notification(params) assert len(handled) == 1 assert handled[0]["chat_id"] == "+19995550001" assert handled[0]["content"] == "hi" @pytest.mark.asyncio async def test_dm_allowlist_accepted(self): ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["+19995550001"]) params = _dm_envelope(source_number="+19995550001") await ch._handle_receive_notification(params) assert len(handled) == 1 @pytest.mark.asyncio async def test_dm_allowlist_rejected(self): ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["+10000000001"]) params = _dm_envelope(source_number="+19995550002") await ch._handle_receive_notification(params) assert handled == [] @pytest.mark.asyncio async def test_dm_allowlist_matches_without_plus_prefix(self): """An allowlist entry without '+' must match a sender that carries '+'.""" ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["19995550001"]) params = _dm_envelope(source_number="+19995550001") await ch._handle_receive_notification(params) assert len(handled) == 1 @pytest.mark.asyncio async def test_dm_allowlist_matches_with_plus_prefix(self): """An allowlist entry with '+' must match a sender without '+'.""" ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["+19995550001"]) params = _dm_envelope(source_number="+19995550001", source_uuid=None) # Replace envelope's sourceNumber with the non-prefixed form by editing # the constructed dict directly so _collect_sender_id_parts sees it. 
params["envelope"]["sourceNumber"] = "19995550001" await ch._handle_receive_notification(params) assert len(handled) == 1 @pytest.mark.asyncio async def test_dm_allowlist_matches_uuid_case_insensitive(self): """UUID matching must be case-insensitive.""" uuid = "ABCDEF12-3456-7890-ABCD-EF1234567890" ch, handled = self._make_dm_channel(policy="allowlist", allow_from=[uuid.lower()]) params = _dm_envelope(source_number="+19995550001", source_uuid=uuid) await ch._handle_receive_notification(params) assert len(handled) == 1 @pytest.mark.asyncio async def test_dm_allowlist_matches_pipe_joined_composite_entry(self): """Allowlist entries written as ``phone|uuid`` composites still work. Some configs pre-date the per-part splitting and store the full sender_id composite as a single allow_from entry. Keep matching it. """ composite = "+19995550001|1872ba20-f52a-4bad-b434-bf7f808c8b22" ch, handled = self._make_dm_channel(policy="allowlist", allow_from=[composite]) params = _dm_envelope( source_number="+19995550001", source_uuid="1872ba20-f52a-4bad-b434-bf7f808c8b22", ) await ch._handle_receive_notification(params) assert len(handled) == 1 @pytest.mark.asyncio async def test_dm_disabled_rejected(self): ch = _make_channel(dm_enabled=False) handled: list[dict] = [] async def capture(**kwargs): handled.append(kwargs) ch._handle_message = capture # type: ignore[method-assign] async def noop_typing(chat_id): pass ch._start_typing = noop_typing # type: ignore[method-assign] params = _dm_envelope(source_number="+19995550001") await ch._handle_receive_notification(params) assert handled == [] @pytest.mark.asyncio async def test_reaction_message_ignored(self): ch, handled = self._make_dm_channel() params = _dm_envelope(reaction={"emoji": "👍", "targetTimestamp": 999}) await ch._handle_receive_notification(params) assert handled == [] @pytest.mark.asyncio async def test_empty_message_ignored(self): ch, handled = self._make_dm_channel() params = _dm_envelope(message="") await 
ch._handle_receive_notification(params) assert handled == [] @pytest.mark.asyncio async def test_receipt_message_ignored(self): ch, handled = self._make_dm_channel() notification = { "envelope": { "sourceNumber": "+19995550001", "receiptMessage": {"when": 1234}, } } await ch._handle_receive_notification(notification) assert handled == [] @pytest.mark.asyncio async def test_typing_indicator_ignored(self): ch, handled = self._make_dm_channel() notification = { "envelope": { "sourceNumber": "+19995550001", "typingMessage": {"action": "STARTED"}, } } await ch._handle_receive_notification(notification) assert handled == [] @pytest.mark.asyncio async def test_missing_envelope_ignored(self): ch, handled = self._make_dm_channel() await ch._handle_receive_notification({}) assert handled == [] @pytest.mark.asyncio async def test_metadata_passed_to_handle(self): ch, handled = self._make_dm_channel() params = _dm_envelope(source_number="+19995550001", source_name="Alice", timestamp=9999) await ch._handle_receive_notification(params) meta = handled[0]["metadata"] assert meta["sender_name"] == "Alice" assert meta["timestamp"] == 9999 assert meta["is_group"] is False @pytest.mark.asyncio async def test_sender_id_with_uuid_variant(self): ch, handled = self._make_dm_channel() params = _dm_envelope(source_number="+19995550001", source_uuid="uuid-abc") await ch._handle_receive_notification(params) assert len(handled) == 1 # sender_id combines both parts assert "+19995550001" in handled[0]["sender_id"] assert "uuid-abc" in handled[0]["sender_id"] @pytest.mark.asyncio async def test_stop_typing_called_on_handle_error(self): ch = _make_channel(dm_enabled=True, dm_policy="open") typing_stopped: list[str] = [] async def fail_handle(**kwargs): raise RuntimeError("boom") async def noop_typing(chat_id): pass async def record_stop(chat_id, **kwargs): typing_stopped.append(chat_id) ch._handle_message = fail_handle # type: ignore[method-assign] ch._start_typing = noop_typing # type: 
ignore[method-assign] ch._stop_typing = record_stop # type: ignore[method-assign] # _handle_receive_notification swallows exceptions; the typing stop # still fires from _handle_data_message's except clause. params = _dm_envelope(source_number="+19995550001") await ch._handle_receive_notification(params) assert "+19995550001" in typing_stopped # --------------------------------------------------------------------------- # _handle_data_message — group routing # --------------------------------------------------------------------------- class TestHandleDataMessageGroup: def _make_group_channel( self, policy="open", allow_from=None, require_mention=True, ) -> tuple[SignalChannel, list]: return _make_channel_with_capture( group_enabled=True, group_policy=policy, group_allow_from=allow_from or [], require_mention=require_mention, ) @pytest.mark.asyncio async def test_group_disabled_rejected(self): ch = _make_channel(group_enabled=False) handled: list[dict] = [] ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign] params = _group_envelope(group_id="grp==", message="hi") await ch._handle_receive_notification(params) assert handled == [] @pytest.mark.asyncio async def test_group_open_policy_no_mention_blocked_when_required(self): ch, handled = self._make_group_channel(require_mention=True) params = _group_envelope(group_id="grp==", message="hey everyone") await ch._handle_receive_notification(params) assert handled == [] @pytest.mark.asyncio async def test_group_open_policy_no_mention_required(self): ch, handled = self._make_group_channel(require_mention=False) params = _group_envelope(group_id="grp==", message="hey everyone") await ch._handle_receive_notification(params) assert len(handled) == 1 assert handled[0]["chat_id"] == "grp==" @pytest.mark.asyncio async def test_group_allowlist_accepted(self): ch, handled = self._make_group_channel( policy="allowlist", allow_from=["grp=="], require_mention=False ) params = 
_group_envelope(group_id="grp==", message="hi") await ch._handle_receive_notification(params) assert len(handled) == 1 @pytest.mark.asyncio async def test_group_allowlist_rejected(self): ch, handled = self._make_group_channel(policy="allowlist", allow_from=["other=="]) params = _group_envelope(group_id="grp==", message="hi") await ch._handle_receive_notification(params) assert handled == [] @pytest.mark.asyncio async def test_group_mention_triggers_response(self): ch, handled = self._make_group_channel(require_mention=True) ch._remember_account_id_alias("+10000000000") mentions = [{"number": "+10000000000", "start": 0, "length": 1}] params = _group_envelope(group_id="grp==", message=" hello", mentions=mentions) await ch._handle_receive_notification(params) assert len(handled) == 1 @pytest.mark.asyncio async def test_group_v2_id_extracted(self): ch, handled = self._make_group_channel(require_mention=False) params = _group_envelope(group_id="grpV2==", message="hi", use_v2=True) await ch._handle_receive_notification(params) assert len(handled) == 1 assert handled[0]["chat_id"] == "grpV2==" @pytest.mark.asyncio async def test_group_message_includes_sender_prefix(self): ch, handled = self._make_group_channel(require_mention=False) params = _group_envelope(group_id="grp==", source_name="Bob", message="hello") await ch._handle_receive_notification(params) assert "[Bob]:" in handled[0]["content"] @pytest.mark.asyncio async def test_group_message_context_prepended(self): ch, handled = self._make_group_channel(require_mention=False) # First message — adds to buffer but no context yet params1 = _group_envelope(group_id="grp==", source_name="Alice", message="msg1") await ch._handle_receive_notification(params1) # Second message — should include context from first params2 = _group_envelope(group_id="grp==", source_name="Bob", message="msg2") await ch._handle_receive_notification(params2) assert "[Recent group messages for context:]" in handled[1]["content"] assert "msg1" in 
handled[1]["content"] @pytest.mark.asyncio async def test_group_metadata_marks_is_group(self): ch, handled = self._make_group_channel(require_mention=False) params = _group_envelope(group_id="grp==", message="hi") await ch._handle_receive_notification(params) assert handled[0]["metadata"]["is_group"] is True assert handled[0]["metadata"]["group_id"] == "grp==" @pytest.mark.asyncio async def test_bot_account_alias_learned_from_incoming(self): ch, handled = self._make_group_channel(require_mention=False) # If the bot's own UUID appears in an envelope we learn it params = _dm_envelope(source_number="+10000000000", source_uuid="new-bot-uuid") # DMs from self are processed (learning alias), but DM policy is open ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign] ch._start_typing = lambda chat_id: None # type: ignore[method-assign] await ch._handle_receive_notification(params) assert ch._id_matches_account("new-bot-uuid") # --------------------------------------------------------------------------- # Lifecycle / SSE # --------------------------------------------------------------------------- class _FakeSSEResponse: """Minimal stand-in for httpx Response under stream().""" def __init__(self, lines: list[str], status_code: int = 200) -> None: self.status_code = status_code self._lines = lines async def aiter_lines(self): for line in self._lines: yield line def _fake_streaming_client(lines: list[str], *, status_code: int = 200) -> MagicMock: """Return an httpx.AsyncClient stand-in whose .stream() yields a FakeSSEResponse.""" response = _FakeSSEResponse(lines, status_code=status_code) @asynccontextmanager async def _ctx(*_args, **_kwargs): yield response http = MagicMock() http.stream = lambda *a, **kw: _ctx(*a, **kw) return http class TestLifecycle: @pytest.mark.asyncio async def test_start_returns_early_when_phone_missing(self): """start() with an empty phone number must not enter the HTTP loop.""" ch = _make_channel(phone_number="") await 
ch.start() assert ch._running is False assert ch._http is None assert ch._sse_task is None class TestSSEReceiveLoop: @pytest.mark.asyncio async def test_dispatches_valid_envelope(self): ch = _make_channel() ch._running = True captured: list[dict] = [] async def capture(params): captured.append(params) ch._handle_receive_notification = capture # type: ignore[method-assign] ch._http = _fake_streaming_client( ['data: {"envelope":{"sourceNumber":"+19995550001"}}', ""] ) # Loop ends when lines exhaust; the surrounding _start_http_mode would # treat that as a disconnect, but the loop itself raises ConnectionError # when the stream closes while still running. with pytest.raises(ConnectionError): await ch._sse_receive_loop() assert captured == [{"envelope": {"sourceNumber": "+19995550001"}}] @pytest.mark.asyncio async def test_handles_invalid_json_frame(self): """An unparseable SSE frame is logged and skipped without crashing.""" ch = _make_channel() ch._running = True captured: list[dict] = [] async def capture(params): captured.append(params) ch._handle_receive_notification = capture # type: ignore[method-assign] ch._http = _fake_streaming_client( [ "data: this-is-not-json", "", # event boundary triggers parse attempt 'data: {"envelope":{"sourceNumber":"+1"}}', "", ] ) with pytest.raises(ConnectionError): await ch._sse_receive_loop() # Bad frame skipped; good frame still dispatched. 
assert captured == [{"envelope": {"sourceNumber": "+1"}}] @pytest.mark.asyncio async def test_non_200_status_raises(self): ch = _make_channel() ch._running = True ch._http = _fake_streaming_client([], status_code=503) with pytest.raises(ConnectionError, match="status 503"): await ch._sse_receive_loop() @pytest.mark.asyncio async def test_no_http_client_raises(self): ch = _make_channel() ch._http = None with pytest.raises(RuntimeError, match="HTTP client not initialized"): await ch._sse_receive_loop() # --------------------------------------------------------------------------- # Command handling # --------------------------------------------------------------------------- class TestCommandHandling: @pytest.mark.asyncio async def test_dm_command_forwarded_to_bus(self): """Slash commands in DMs are forwarded to the bus for AgentLoop to handle.""" ch, forwarded = _make_channel_with_capture(dm_enabled=True, dm_policy="open") params = _dm_envelope(source_number="+19995550001", message="/reset") await ch._handle_receive_notification(params) assert len(forwarded) == 1 assert forwarded[0]["content"].strip() == "/reset" @pytest.mark.asyncio async def test_group_command_bypasses_mention_requirement(self): """Slash commands in groups bypass the mention requirement and reach the bus.""" ch, forwarded = _make_channel_with_capture( group_enabled=True, group_policy="open", require_mention=True ) params = _group_envelope(source_number="+19995550001", group_id="grp==", message="/reset") await ch._handle_receive_notification(params) assert len(forwarded) == 1 assert "/reset" in forwarded[0]["content"] @pytest.mark.asyncio async def test_command_denied_for_disallowed_dm_sender(self): """Commands from senders not on the DM allowlist are dropped.""" ch, forwarded = _make_channel_with_capture(dm_enabled=False) params = _dm_envelope(source_number="+19995550001", message="/reset") await ch._handle_receive_notification(params) assert forwarded == [] # 
# ---------------------------------------------------------------------------
# send() — outbound messages
# ---------------------------------------------------------------------------


class TestSend:
    """Tests for send(): the JSON-RPC payloads posted for outbound messages."""

    def _make_send_channel(self) -> tuple[SignalChannel, _FakeHTTPClient]:
        """Build a channel wired to a recording fake HTTP client."""
        ch = _make_channel()
        client = _FakeHTTPClient()
        ch._http = client  # type: ignore[assignment]
        return ch, client

    @pytest.mark.asyncio
    async def test_send_plain_text_posts_rpc(self):
        """Plain text produces exactly one ``send`` RPC carrying the message."""
        ch, client = self._make_send_channel()
        msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="hello")
        await ch.send(msg)
        assert len(client.posts) == 1
        payload = client.posts[0]["json"]
        assert payload["method"] == "send"
        assert payload["params"]["message"] == "hello"

    @pytest.mark.asyncio
    async def test_send_with_markdown_includes_text_styles(self):
        """Markdown bold yields a ``textStyle`` entry mentioning BOLD."""
        ch, client = self._make_send_channel()
        msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="**bold**")
        await ch.send(msg)
        params = client.posts[0]["json"]["params"]
        assert "textStyle" in params
        assert any("BOLD" in s for s in params["textStyle"])

    @pytest.mark.asyncio
    async def test_send_split_message_redistributes_text_styles(self):
        """Long message split across chunks: each chunk gets its own textStyle
        with offsets rebased to that chunk."""
        ch, client = self._make_send_channel()
        # Shrink the limit on this instance so a short message is split.
        ch._MAX_MESSAGE_LEN = 12  # type: ignore[attr-defined]
        msg = OutboundMessage(
            channel="signal",
            chat_id="+19995550001",
            content="**head** middle and **tail**",
        )
        await ch.send(msg)
        assert len(client.posts) >= 2
        # Chunk 0 has BOLD for "head"; chunk 1+ must also carry BOLD for "tail".
        bold_chunks = [
            p["json"]["params"]
            for p in client.posts
            if any("BOLD" in s for s in p["json"]["params"].get("textStyle", []))
        ]
        assert len(bold_chunks) >= 2, (
            "expected BOLD ranges in more than one chunk; got "
            f"{[p['json']['params'] for p in client.posts]}"
        )
        # Each emitted range must point inside its own chunk's text.
        for params in bold_chunks:
            chunk_text = params["message"]
            # Chunk length measured in UTF-16 code units (2 bytes each).
            chunk_units = len(chunk_text.encode("utf-16-le")) // 2
            for entry in params["textStyle"]:
                start_field, length_field, _ = entry.split(":", 2)
                end_units = int(start_field) + int(length_field)
                assert end_units <= chunk_units

    @pytest.mark.asyncio
    async def test_send_empty_content_skips_rpc(self):
        """Empty content results in no POST at all."""
        ch, client = self._make_send_channel()
        msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="")
        await ch.send(msg)
        assert client.posts == []

    @pytest.mark.asyncio
    async def test_send_to_group_uses_group_id(self):
        """A group chat id is sent as ``groupId`` with no ``recipient``."""
        ch, client = self._make_send_channel()
        msg = OutboundMessage(channel="signal", chat_id="grp==", content="hi group")
        await ch.send(msg)
        params = client.posts[0]["json"]["params"]
        assert "groupId" in params
        assert "recipient" not in params

    @pytest.mark.asyncio
    async def test_send_to_dm_uses_recipient(self):
        """A phone-number chat id is sent with a ``recipient`` param."""
        ch, client = self._make_send_channel()
        msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="hi")
        await ch.send(msg)
        params = client.posts[0]["json"]["params"]
        assert "recipient" in params

    @pytest.mark.asyncio
    async def test_send_with_media_includes_attachments(self):
        """Media paths on the message are passed through as ``attachments``."""
        ch, client = self._make_send_channel()
        msg = OutboundMessage(
            channel="signal",
            chat_id="+19995550001",
            content="see attachment",
            media=["/tmp/file.jpg"],
        )
        await ch.send(msg)
        params = client.posts[0]["json"]["params"]
        assert params.get("attachments") == ["/tmp/file.jpg"]

    @pytest.mark.asyncio
    async def test_send_progress_message_does_not_stop_typing(self):
        """Messages flagged ``_progress`` leave the typing indicator running."""
        ch, _ = self._make_send_channel()
        stopped: list[str] = []

        async def record_stop(chat_id, **kwargs):
            stopped.append(chat_id)

        ch._stop_typing = record_stop  # type: ignore[method-assign]
        msg = OutboundMessage(
            channel="signal",
            chat_id="+19995550001",
            content="working...",
            metadata={"_progress": True},
        )
        await ch.send(msg)
        # Progress messages should NOT stop the typing indicator
        assert stopped == []

    @pytest.mark.asyncio
    async def test_send_final_message_stops_typing(self):
        """A normal (non-progress) message stops typing for that chat."""
        ch, _ = self._make_send_channel()
        stopped: list[str] = []

        async def record_stop(chat_id, send_stop=True):
            stopped.append(chat_id)

        ch._stop_typing = record_stop  # type: ignore[method-assign]
        msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="done")
        await ch.send(msg)
        assert "+19995550001" in stopped

    @pytest.mark.asyncio
    async def test_send_logs_daemon_error_without_raising(self):
        """An ``error`` object in the RPC response body does not make send() raise."""
        ch = _make_channel()
        # The daemon returns {"error": {...}} in the JSON body — this is not a Python
        # exception; send() logs it but does not raise (only HTTP-level exceptions raise).
        ch._http = _FakeHTTPClient(default_response={"error": {"message": "fail"}})
        msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="hello")
        await ch.send(msg)  # must not raise


# ---------------------------------------------------------------------------
# stop()
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_stop_cancels_sse_task() -> None:
    """stop() cancels a pending SSE task and clears the running flag."""
    ch = _make_channel()
    cancelled = False

    async def long_running():
        nonlocal cancelled
        try:
            await asyncio.sleep(9999)
        except asyncio.CancelledError:
            cancelled = True
            raise

    ch._sse_task = asyncio.create_task(long_running())
    # Yield so the task can enter its body (reach the first await) before cancel.
    await asyncio.sleep(0)
    ch._running = True
    await ch.stop()
    assert cancelled
    assert ch._running is False


@pytest.mark.asyncio
async def test_stop_closes_http_client() -> None:
    """stop() closes the HTTP client attached to the channel."""
    ch = _make_channel()
    client = _FakeHTTPClient()
    ch._http = client  # type: ignore[assignment]
    ch._running = True
    await ch.stop()
    assert client.closed


@pytest.mark.asyncio
async def test_stop_safe_when_no_sse_task() -> None:
    """stop() succeeds when no SSE task was ever created."""
    ch = _make_channel()
    ch._running = True
    # Should not raise even with no _sse_task
    await ch.stop()
    assert ch._running is False


# ---------------------------------------------------------------------------
# _send_request / _send_http_request
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_send_request_increments_id() -> None:
    """Consecutive requests are posted with ids 1, 2, ..."""
    ch = _make_channel()
    client = _FakeHTTPClient()
    ch._http = client  # type: ignore[assignment]
    await ch._send_request("testMethod", {"key": "val"})
    await ch._send_request("testMethod", {"key": "val"})
    ids = [p["json"]["id"] for p in client.posts]
    assert ids == [1, 2]


@pytest.mark.asyncio
async def test_send_request_raises_when_not_connected() -> None:
    """Sending a request without an HTTP client raises RuntimeError."""
    ch = _make_channel()
    # _http is None by default
    with pytest.raises(RuntimeError, match="Not connected"):
        await ch._send_request("testMethod")


# ---------------------------------------------------------------------------
# _handle_receive_notification — envelope shapes
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_handle_notification_sync_message_does_not_forward() -> None:
    """An envelope carrying only a ``syncMessage`` is not forwarded."""
    # Use the shared capture helper so the stubbed hooks are coroutine
    # functions; a synchronous lambda would not be awaitable if it were reached.
    ch, handled = _make_channel_with_capture(dm_enabled=True, dm_policy="open")
    notification = {
        "envelope": {
            "sourceNumber": "+19995550001",
            "syncMessage": {
                "sentMessage": {
                    "destination": "+19990000000",
                    "message": "sent from other device",
                }
            },
        }
    }
    await ch._handle_receive_notification(notification)
    assert handled == []
@pytest.mark.asyncio
async def test_handle_notification_no_source_skipped() -> None:
    """An envelope with no source number or UUID is not forwarded."""
    # Use the shared capture helper so the stubbed hooks are coroutine
    # functions; a synchronous lambda would not be awaitable if it were reached.
    ch, handled = _make_channel_with_capture(dm_enabled=True, dm_policy="open")
    notification = {"envelope": {"dataMessage": {"message": "ghost"}}}
    await ch._handle_receive_notification(notification)
    assert handled == []


# ---------------------------------------------------------------------------
# Config: allow_from property aggregation
# ---------------------------------------------------------------------------


def test_config_allow_from_aggregates_dm_and_group() -> None:
    """``allow_from`` merges the DM and group lists without duplicates."""
    config = SignalConfig(
        enabled=True,
        phone_number="+10000000000",
        dm=SignalDMConfig(enabled=True, policy="allowlist", allow_from=["+1111", "+2222"]),
        group=SignalGroupConfig(enabled=True, policy="allowlist", allow_from=["+3333", "+1111"]),
    )
    combined = config.allow_from
    assert "+1111" in combined
    assert "+2222" in combined
    assert "+3333" in combined
    # Duplicates removed
    assert combined.count("+1111") == 1


def test_config_allow_from_wildcard_propagates() -> None:
    """A ``*`` entry in the DM list appears in the aggregated ``allow_from``."""
    config = SignalConfig(
        enabled=True,
        phone_number="+10000000000",
        dm=SignalDMConfig(enabled=True, policy="open", allow_from=["*"]),
        group=SignalGroupConfig(enabled=True, policy="open", allow_from=[]),
    )
    assert "*" in config.allow_from