diff --git a/tests/channels/test_signal_channel.py b/tests/channels/test_signal_channel.py index 53c8a2aa6..9444e2218 100644 --- a/tests/channels/test_signal_channel.py +++ b/tests/channels/test_signal_channel.py @@ -5,7 +5,7 @@ from __future__ import annotations import asyncio from contextlib import asynccontextmanager from pathlib import Path -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock import pytest @@ -62,6 +62,24 @@ class _FakeHTTPClient: # --------------------------------------------------------------------------- +def _make_channel_with_capture(**overrides) -> tuple[SignalChannel, list[dict]]: + """Build a SignalChannel with _handle_message captured into a list and a + no-op _start_typing, used by every receive-flow test class. + """ + ch = _make_channel(**overrides) + handled: list[dict] = [] + + async def capture(**kwargs): + handled.append(kwargs) + + async def noop_typing(chat_id): + pass + + ch._handle_message = capture # type: ignore[method-assign] + ch._start_typing = noop_typing # type: ignore[method-assign] + return ch, handled + + def _make_channel( *, phone_number: str = "+10000000000", @@ -590,19 +608,9 @@ class TestAttachmentsDir: class TestHandleDataMessageDM: def _make_dm_channel(self, policy="open", allow_from=None) -> tuple[SignalChannel, list]: - ch = _make_channel(dm_enabled=True, dm_policy=policy, dm_allow_from=allow_from or []) - handled: list[dict] = [] - - async def capture(**kwargs): - handled.append(kwargs) - - ch._handle_message = capture # type: ignore[method-assign] - - async def noop_typing(chat_id): - pass - - ch._start_typing = noop_typing # type: ignore[method-assign] - return ch, handled + return _make_channel_with_capture( + dm_enabled=True, dm_policy=policy, dm_allow_from=allow_from or [] + ) @pytest.mark.asyncio async def test_dm_open_policy_accepted(self): @@ -777,24 +785,12 @@ class TestHandleDataMessageGroup: allow_from=None, require_mention=True, ) -> tuple[SignalChannel, list]: - ch = _make_channel( + return _make_channel_with_capture( group_enabled=True, group_policy=policy, group_allow_from=allow_from or [], require_mention=require_mention, ) - handled: list[dict] = [] - - async def capture(**kwargs): - handled.append(kwargs) - - ch._handle_message = capture # type: ignore[method-assign] - - async def noop_typing(chat_id): - pass - - ch._start_typing = noop_typing # type: ignore[method-assign] - return ch, handled @pytest.mark.asyncio async def test_group_disabled_rejected(self): @@ -1007,55 +1003,31 @@ class TestCommandHandling: @pytest.mark.asyncio async def test_dm_command_forwarded_to_bus(self): """Slash commands in DMs are forwarded to the bus for AgentLoop to handle.""" - ch = _make_channel(dm_enabled=True, dm_policy="open") - forwarded: list[dict] = [] - - async def capture(**kw): - forwarded.append(kw) - - ch._handle_message = capture # type: ignore[method-assign] - ch._start_typing = AsyncMock() - + ch, forwarded = _make_channel_with_capture(dm_enabled=True, dm_policy="open") params = _dm_envelope(source_number="+19995550001", message="/reset") await ch._handle_receive_notification(params) - assert len(forwarded) == 1 assert forwarded[0]["content"].strip() == "/reset" @pytest.mark.asyncio async def test_group_command_bypasses_mention_requirement(self): """Slash commands in groups bypass the mention requirement and reach the bus.""" - ch = _make_channel( + ch, forwarded = _make_channel_with_capture( group_enabled=True, group_policy="open", require_mention=True ) - forwarded: list[dict] = [] - - async def capture(**kw): - forwarded.append(kw) - - ch._handle_message = capture # type: ignore[method-assign] - ch._start_typing = AsyncMock() - - params = _group_envelope(source_number="+19995550001", group_id="grp==", message="/reset") + params = _group_envelope( + source_number="+19995550001", group_id="grp==", message="/reset" + ) await ch._handle_receive_notification(params) - assert len(forwarded) == 1 assert "/reset" in forwarded[0]["content"] @pytest.mark.asyncio async def test_command_denied_for_disallowed_dm_sender(self): """Commands from senders not on the DM allowlist are dropped.""" - ch = _make_channel(dm_enabled=False) - forwarded: list[dict] = [] - - async def capture(**kw): - forwarded.append(kw) - - ch._handle_message = capture # type: ignore[method-assign] - + ch, forwarded = _make_channel_with_capture(dm_enabled=False) params = _dm_envelope(source_number="+19995550001", message="/reset") await ch._handle_receive_notification(params) - assert forwarded == []