from __future__ import annotations import asyncio from pathlib import Path from types import SimpleNamespace import pytest discord = pytest.importorskip("discord") from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.discord import DiscordBotClient, DiscordChannel, DiscordConfig from nanobot.command.builtin import build_help_text # Minimal Discord client test double used to control startup/readiness behavior. class _FakeDiscordClient: instances: list["_FakeDiscordClient"] = [] start_error: Exception | None = None def __init__(self, owner, *, intents) -> None: self.owner = owner self.intents = intents self.closed = False self.ready = True self.channels: dict[int, object] = {} self.user = SimpleNamespace(id=999) self.__class__.instances.append(self) async def start(self, token: str) -> None: self.token = token if self.__class__.start_error is not None: raise self.__class__.start_error async def close(self) -> None: self.closed = True def is_closed(self) -> bool: return self.closed def is_ready(self) -> bool: return self.ready def get_channel(self, channel_id: int): return self.channels.get(channel_id) async def send_outbound(self, msg: OutboundMessage) -> None: channel = self.get_channel(int(msg.chat_id)) if channel is None: return await channel.send(content=msg.content) class _FakeAttachment: # Attachment double that can simulate successful or failing save() calls. def __init__(self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False) -> None: self.id = attachment_id self.filename = filename self.size = size self._fail = fail async def save(self, path: str | Path) -> None: if self._fail: raise RuntimeError("save failed") Path(path).write_bytes(b"attachment") class _FakePartialMessage: # Lightweight stand-in for Discord partial message references used in replies. def __init__(self, message_id: int) -> None: self.id = message_id class _FakeChannel: # Channel double that records outbound payloads and typing activity. def __init__(self, channel_id: int = 123) -> None: self.id = channel_id self.sent_payloads: list[dict] = [] self.trigger_typing_calls = 0 self.typing_enter_hook = None async def send(self, **kwargs) -> None: payload = dict(kwargs) if "file" in payload: payload["file_name"] = payload["file"].filename del payload["file"] self.sent_payloads.append(payload) def get_partial_message(self, message_id: int) -> _FakePartialMessage: return _FakePartialMessage(message_id) def typing(self): channel = self class _TypingContext: async def __aenter__(self): channel.trigger_typing_calls += 1 if channel.typing_enter_hook is not None: await channel.typing_enter_hook() async def __aexit__(self, exc_type, exc, tb): return False return _TypingContext() class _FakeInteractionResponse: def __init__(self) -> None: self.messages: list[dict] = [] self._done = False async def send_message(self, content: str, *, ephemeral: bool = False) -> None: self.messages.append({"content": content, "ephemeral": ephemeral}) self._done = True def is_done(self) -> bool: return self._done def _make_interaction( *, user_id: int = 123, channel_id: int | None = 456, guild_id: int | None = None, interaction_id: int = 999, ): return SimpleNamespace( user=SimpleNamespace(id=user_id), channel_id=channel_id, guild_id=guild_id, id=interaction_id, command=SimpleNamespace(qualified_name="new"), response=_FakeInteractionResponse(), ) def _make_message( *, author_id: int = 123, author_bot: bool = False, channel_id: int = 456, message_id: int = 789, content: str = "hello", guild_id: int | None = None, mentions: list[object] | None = None, attachments: list[object] | None = None, reply_to: int | None = None, ): # Factory for incoming Discord message objects with optional guild/reply/attachments. guild = SimpleNamespace(id=guild_id) if guild_id is not None else None reference = SimpleNamespace(message_id=reply_to) if reply_to is not None else None return SimpleNamespace( author=SimpleNamespace(id=author_id, bot=author_bot), channel=_FakeChannel(channel_id), content=content, guild=guild, mentions=mentions or [], attachments=attachments or [], reference=reference, id=message_id, ) @pytest.mark.asyncio async def test_start_returns_when_token_missing() -> None: # If no token is configured, startup should no-op and leave channel stopped. channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) await channel.start() assert channel.is_running is False assert channel._client is None @pytest.mark.asyncio async def test_start_returns_when_discord_dependency_missing(monkeypatch) -> None: channel = DiscordChannel( DiscordConfig(enabled=True, token="token", allow_from=["*"]), MessageBus(), ) monkeypatch.setattr("nanobot.channels.discord.DISCORD_AVAILABLE", False) await channel.start() assert channel.is_running is False assert channel._client is None @pytest.mark.asyncio async def test_start_handles_client_construction_failure(monkeypatch) -> None: # Construction errors from the Discord client should be swallowed and keep state clean. channel = DiscordChannel( DiscordConfig(enabled=True, token="token", allow_from=["*"]), MessageBus(), ) def _boom(owner, *, intents): raise RuntimeError("bad client") monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _boom) await channel.start() assert channel.is_running is False assert channel._client is None @pytest.mark.asyncio async def test_start_handles_client_start_failure(monkeypatch) -> None: # If client.start fails, the partially created client should be closed and detached. channel = DiscordChannel( DiscordConfig(enabled=True, token="token", allow_from=["*"]), MessageBus(), ) _FakeDiscordClient.instances.clear() _FakeDiscordClient.start_error = RuntimeError("connect failed") monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient) await channel.start() assert channel.is_running is False assert channel._client is None assert _FakeDiscordClient.instances[0].intents.value == channel.config.intents assert _FakeDiscordClient.instances[0].closed is True _FakeDiscordClient.start_error = None @pytest.mark.asyncio async def test_stop_is_safe_after_partial_start(monkeypatch) -> None: # stop() should close/discard the client even when startup was only partially completed. channel = DiscordChannel( DiscordConfig(enabled=True, token="token", allow_from=["*"]), MessageBus(), ) client = _FakeDiscordClient(channel, intents=None) channel._client = client channel._running = True await channel.stop() assert channel.is_running is False assert client.closed is True assert channel._client is None @pytest.mark.asyncio async def test_on_message_ignores_bot_messages() -> None: # Incoming bot-authored messages must be ignored to prevent feedback loops. channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) handled: list[dict] = [] channel._handle_message = lambda **kwargs: handled.append(kwargs) # type: ignore[method-assign] await channel._on_message(_make_message(author_bot=True)) assert handled == [] # If inbound handling raises, typing should be stopped for that channel. async def fail_handle(**kwargs) -> None: raise RuntimeError("boom") channel._handle_message = fail_handle # type: ignore[method-assign] with pytest.raises(RuntimeError, match="boom"): await channel._on_message(_make_message(author_id=123, channel_id=456)) assert channel._typing_tasks == {} @pytest.mark.asyncio async def test_on_message_accepts_allowlisted_dm() -> None: # Allowed direct messages should be forwarded with normalized metadata. channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus()) handled: list[dict] = [] async def capture_handle(**kwargs) -> None: handled.append(kwargs) channel._handle_message = capture_handle # type: ignore[method-assign] await channel._on_message(_make_message(author_id=123, channel_id=456, message_id=789)) assert len(handled) == 1 assert handled[0]["chat_id"] == "456" assert handled[0]["metadata"] == {"message_id": "789", "guild_id": None, "reply_to": None} @pytest.mark.asyncio async def test_on_message_ignores_unmentioned_guild_message() -> None: # With mention-only group policy, guild messages without a bot mention are dropped. channel = DiscordChannel( DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention"), MessageBus(), ) channel._bot_user_id = "999" handled: list[dict] = [] async def capture_handle(**kwargs) -> None: handled.append(kwargs) channel._handle_message = capture_handle # type: ignore[method-assign] await channel._on_message(_make_message(guild_id=1, content="hello everyone")) assert handled == [] @pytest.mark.asyncio async def test_on_message_accepts_mentioned_guild_message() -> None: # Mentioned guild messages should be accepted and preserve reply threading metadata. channel = DiscordChannel( DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention"), MessageBus(), ) channel._bot_user_id = "999" handled: list[dict] = [] async def capture_handle(**kwargs) -> None: handled.append(kwargs) channel._handle_message = capture_handle # type: ignore[method-assign] await channel._on_message( _make_message( guild_id=1, content="<@999> hello", mentions=[SimpleNamespace(id=999)], reply_to=321, ) ) assert len(handled) == 1 assert handled[0]["metadata"]["reply_to"] == "321" @pytest.mark.asyncio async def test_on_message_downloads_attachments(tmp_path, monkeypatch) -> None: # Attachment downloads should be saved and referenced in forwarded content/media. channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) handled: list[dict] = [] async def capture_handle(**kwargs) -> None: handled.append(kwargs) channel._handle_message = capture_handle # type: ignore[method-assign] monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path) await channel._on_message( _make_message( attachments=[_FakeAttachment(12, "photo.png")], content="see file", ) ) assert len(handled) == 1 assert handled[0]["media"] == [str(tmp_path / "12_photo.png")] assert "[attachment:" in handled[0]["content"] @pytest.mark.asyncio async def test_on_message_marks_failed_attachment_download(tmp_path, monkeypatch) -> None: # Failed attachment downloads should emit a readable placeholder and no media path. channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) handled: list[dict] = [] async def capture_handle(**kwargs) -> None: handled.append(kwargs) channel._handle_message = capture_handle # type: ignore[method-assign] monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path) await channel._on_message( _make_message( attachments=[_FakeAttachment(12, "photo.png", fail=True)], content="", ) ) assert len(handled) == 1 assert handled[0]["media"] == [] assert handled[0]["content"] == "[attachment: photo.png - download failed]" @pytest.mark.asyncio async def test_send_warns_when_client_not_ready() -> None: # Sending without a running/ready client should be a safe no-op. channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello")) assert channel._typing_tasks == {} @pytest.mark.asyncio async def test_send_skips_when_channel_not_cached() -> None: # Outbound sends should be skipped when the destination channel is not resolvable. owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) client = DiscordBotClient(owner, intents=discord.Intents.none()) fetch_calls: list[int] = [] async def fetch_channel(channel_id: int): fetch_calls.append(channel_id) raise RuntimeError("not found") client.fetch_channel = fetch_channel # type: ignore[method-assign] await client.send_outbound(OutboundMessage(channel="discord", chat_id="123", content="hello")) assert client.get_channel(123) is None assert fetch_calls == [123] @pytest.mark.asyncio async def test_send_fetches_channel_when_not_cached() -> None: owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) client = DiscordBotClient(owner, intents=discord.Intents.none()) target = _FakeChannel(channel_id=123) async def fetch_channel(channel_id: int): return target if channel_id == 123 else None client.fetch_channel = fetch_channel # type: ignore[method-assign] await client.send_outbound(OutboundMessage(channel="discord", chat_id="123", content="hello")) assert target.sent_payloads == [{"content": "hello"}] @pytest.mark.asyncio async def test_slash_new_forwards_when_user_is_allowlisted() -> None: channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus()) handled: list[dict] = [] async def capture_handle(**kwargs) -> None: handled.append(kwargs) channel._handle_message = capture_handle # type: ignore[method-assign] client = DiscordBotClient(channel, intents=discord.Intents.none()) interaction = _make_interaction(user_id=123, channel_id=456, interaction_id=321) new_cmd = client.tree.get_command("new") assert new_cmd is not None await new_cmd.callback(interaction) assert interaction.response.messages == [ {"content": "Processing /new...", "ephemeral": True} ] assert len(handled) == 1 assert handled[0]["content"] == "/new" assert handled[0]["sender_id"] == "123" assert handled[0]["chat_id"] == "456" assert handled[0]["metadata"]["interaction_id"] == "321" assert handled[0]["metadata"]["is_slash_command"] is True @pytest.mark.asyncio async def test_slash_new_is_blocked_for_disallowed_user() -> None: channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["999"]), MessageBus()) handled: list[dict] = [] async def capture_handle(**kwargs) -> None: handled.append(kwargs) channel._handle_message = capture_handle # type: ignore[method-assign] client = DiscordBotClient(channel, intents=discord.Intents.none()) interaction = _make_interaction(user_id=123, channel_id=456) new_cmd = client.tree.get_command("new") assert new_cmd is not None await new_cmd.callback(interaction) assert interaction.response.messages == [ {"content": "You are not allowed to use this bot.", "ephemeral": True} ] assert handled == [] @pytest.mark.parametrize("slash_name", ["stop", "restart", "status"]) @pytest.mark.asyncio async def test_slash_commands_forward_via_handle_message(slash_name: str) -> None: channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) handled: list[dict] = [] async def capture_handle(**kwargs) -> None: handled.append(kwargs) channel._handle_message = capture_handle # type: ignore[method-assign] client = DiscordBotClient(channel, intents=discord.Intents.none()) interaction = _make_interaction() interaction.command.qualified_name = slash_name cmd = client.tree.get_command(slash_name) assert cmd is not None await cmd.callback(interaction) assert interaction.response.messages == [ {"content": f"Processing /{slash_name}...", "ephemeral": True} ] assert len(handled) == 1 assert handled[0]["content"] == f"/{slash_name}" assert handled[0]["metadata"]["is_slash_command"] is True @pytest.mark.asyncio async def test_slash_help_returns_ephemeral_help_text() -> None: channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) handled: list[dict] = [] async def capture_handle(**kwargs) -> None: handled.append(kwargs) channel._handle_message = capture_handle # type: ignore[method-assign] client = DiscordBotClient(channel, intents=discord.Intents.none()) interaction = _make_interaction() interaction.command.qualified_name = "help" help_cmd = client.tree.get_command("help") assert help_cmd is not None await help_cmd.callback(interaction) assert interaction.response.messages == [ {"content": build_help_text(), "ephemeral": True} ] assert handled == [] @pytest.mark.asyncio async def test_client_send_outbound_chunks_text_replies_and_uploads_files(tmp_path) -> None: # Outbound payloads should upload files, attach reply references, and chunk long text. owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) client = DiscordBotClient(owner, intents=discord.Intents.none()) target = _FakeChannel(channel_id=123) client.get_channel = lambda channel_id: target if channel_id == 123 else None # type: ignore[method-assign] file_path = tmp_path / "demo.txt" file_path.write_text("hi") await client.send_outbound( OutboundMessage( channel="discord", chat_id="123", content="a" * 2100, reply_to="55", media=[str(file_path)], ) ) assert len(target.sent_payloads) == 3 assert target.sent_payloads[0]["file_name"] == "demo.txt" assert target.sent_payloads[0]["reference"].id == 55 assert target.sent_payloads[1]["content"] == "a" * 2000 assert target.sent_payloads[2]["content"] == "a" * 100 @pytest.mark.asyncio async def test_client_send_outbound_reports_failed_attachments_when_no_text(tmp_path) -> None: # If all attachment sends fail and no text exists, emit a failure placeholder message. owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) client = DiscordBotClient(owner, intents=discord.Intents.none()) target = _FakeChannel(channel_id=123) client.get_channel = lambda channel_id: target if channel_id == 123 else None # type: ignore[method-assign] missing_file = tmp_path / "missing.txt" await client.send_outbound( OutboundMessage( channel="discord", chat_id="123", content="", media=[str(missing_file)], ) ) assert target.sent_payloads == [{"content": "[attachment: missing.txt - send failed]"}] @pytest.mark.asyncio async def test_send_stops_typing_after_send() -> None: # Active typing indicators should be cancelled/cleared after a successful send. channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) client = _FakeDiscordClient(channel, intents=None) channel._client = client channel._running = True start = asyncio.Event() release = asyncio.Event() async def slow_typing() -> None: start.set() await release.wait() typing_channel = _FakeChannel(channel_id=123) typing_channel.typing_enter_hook = slow_typing await channel._start_typing(typing_channel) await asyncio.wait_for(start.wait(), timeout=1.0) await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello")) release.set() await asyncio.sleep(0) assert channel._typing_tasks == {} # Progress messages should keep typing active until a final (non-progress) send. start = asyncio.Event() release = asyncio.Event() async def slow_typing_progress() -> None: start.set() await release.wait() typing_channel = _FakeChannel(channel_id=123) typing_channel.typing_enter_hook = slow_typing_progress await channel._start_typing(typing_channel) await asyncio.wait_for(start.wait(), timeout=1.0) await channel.send( OutboundMessage( channel="discord", chat_id="123", content="progress", metadata={"_progress": True}, ) ) assert "123" in channel._typing_tasks await channel.send(OutboundMessage(channel="discord", chat_id="123", content="final")) release.set() await asyncio.sleep(0) assert channel._typing_tasks == {} @pytest.mark.asyncio async def test_start_typing_uses_typing_context_when_trigger_typing_missing() -> None: channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) channel._running = True entered = asyncio.Event() release = asyncio.Event() class _TypingCtx: async def __aenter__(self): entered.set() async def __aexit__(self, exc_type, exc, tb): return False class _NoTriggerChannel: def __init__(self, channel_id: int = 123) -> None: self.id = channel_id def typing(self): async def _waiter(): await release.wait() # Hold the loop so task remains active until explicitly stopped. class _Ctx(_TypingCtx): async def __aenter__(self): await super().__aenter__() await _waiter() return _Ctx() typing_channel = _NoTriggerChannel(channel_id=123) await channel._start_typing(typing_channel) # type: ignore[arg-type] await asyncio.wait_for(entered.wait(), timeout=1.0) assert "123" in channel._typing_tasks await channel._stop_typing("123") release.set() await asyncio.sleep(0) assert channel._typing_tasks == {}