from __future__ import annotations import asyncio from pathlib import Path from types import SimpleNamespace import discord import pytest from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.discord import DiscordBotClient, DiscordChannel, DiscordConfig from nanobot.command.builtin import build_help_text # Minimal Discord client test double used to control startup/readiness behavior. class _FakeDiscordClient: instances: list["_FakeDiscordClient"] = [] start_error: Exception | None = None def __init__(self, owner, *, intents) -> None: self.owner = owner self.intents = intents self.closed = False self.ready = True self.channels: dict[int, object] = {} self.user = SimpleNamespace(id=999) self.__class__.instances.append(self) async def start(self, token: str) -> None: self.token = token if self.__class__.start_error is not None: raise self.__class__.start_error async def close(self) -> None: self.closed = True def is_closed(self) -> bool: return self.closed def is_ready(self) -> bool: return self.ready def get_channel(self, channel_id: int): return self.channels.get(channel_id) async def send_outbound(self, msg: OutboundMessage) -> None: channel = self.get_channel(int(msg.chat_id)) if channel is None: return await channel.send(content=msg.content) class _FakeAttachment: # Attachment double that can simulate successful or failing save() calls. def __init__(self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False) -> None: self.id = attachment_id self.filename = filename self.size = size self._fail = fail async def save(self, path: str | Path) -> None: if self._fail: raise RuntimeError("save failed") Path(path).write_bytes(b"attachment") class _FakePartialMessage: # Lightweight stand-in for Discord partial message references used in replies. def __init__(self, message_id: int) -> None: self.id = message_id class _FakeChannel: # Channel double that records outbound payloads and typing activity. def __init__(self, channel_id: int = 123) -> None: self.id = channel_id self.sent_payloads: list[dict] = [] self.trigger_typing_calls = 0 self.typing_enter_hook = None async def send(self, **kwargs) -> None: payload = dict(kwargs) if "file" in payload: payload["file_name"] = payload["file"].filename del payload["file"] self.sent_payloads.append(payload) def get_partial_message(self, message_id: int) -> _FakePartialMessage: return _FakePartialMessage(message_id) def typing(self): channel = self class _TypingContext: async def __aenter__(self): channel.trigger_typing_calls += 1 if channel.typing_enter_hook is not None: await channel.typing_enter_hook() async def __aexit__(self, exc_type, exc, tb): return False return _TypingContext() class _FakeInteractionResponse: def __init__(self) -> None: self.messages: list[dict] = [] self._done = False async def send_message(self, content: str, *, ephemeral: bool = False) -> None: self.messages.append({"content": content, "ephemeral": ephemeral}) self._done = True def is_done(self) -> bool: return self._done def _make_interaction( *, user_id: int = 123, channel_id: int | None = 456, guild_id: int | None = None, interaction_id: int = 999, ): return SimpleNamespace( user=SimpleNamespace(id=user_id), channel_id=channel_id, guild_id=guild_id, id=interaction_id, command=SimpleNamespace(qualified_name="new"), response=_FakeInteractionResponse(), ) def _make_message( *, author_id: int = 123, author_bot: bool = False, channel_id: int = 456, message_id: int = 789, content: str = "hello", guild_id: int | None = None, mentions: list[object] | None = None, attachments: list[object] | None = None, reply_to: int | None = None, ): # Factory for incoming Discord message objects with optional guild/reply/attachments. guild = SimpleNamespace(id=guild_id) if guild_id is not None else None reference = SimpleNamespace(message_id=reply_to) if reply_to is not None else None return SimpleNamespace( author=SimpleNamespace(id=author_id, bot=author_bot), channel=_FakeChannel(channel_id), content=content, guild=guild, mentions=mentions or [], attachments=attachments or [], reference=reference, id=message_id, ) @pytest.mark.asyncio async def test_start_returns_when_token_missing() -> None: # If no token is configured, startup should no-op and leave channel stopped. channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) await channel.start() assert channel.is_running is False assert channel._client is None @pytest.mark.asyncio async def test_start_returns_when_discord_dependency_missing(monkeypatch) -> None: channel = DiscordChannel( DiscordConfig(enabled=True, token="token", allow_from=["*"]), MessageBus(), ) monkeypatch.setattr("nanobot.channels.discord.DISCORD_AVAILABLE", False) await channel.start() assert channel.is_running is False assert channel._client is None @pytest.mark.asyncio async def test_start_handles_client_construction_failure(monkeypatch) -> None: # Construction errors from the Discord client should be swallowed and keep state clean. channel = DiscordChannel( DiscordConfig(enabled=True, token="token", allow_from=["*"]), MessageBus(), ) def _boom(owner, *, intents): raise RuntimeError("bad client") monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _boom) await channel.start() assert channel.is_running is False assert channel._client is None @pytest.mark.asyncio async def test_start_handles_client_start_failure(monkeypatch) -> None: # If client.start fails, the partially created client should be closed and detached. channel = DiscordChannel( DiscordConfig(enabled=True, token="token", allow_from=["*"]), MessageBus(), ) _FakeDiscordClient.instances.clear() _FakeDiscordClient.start_error = RuntimeError("connect failed") monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient) await channel.start() assert channel.is_running is False assert channel._client is None assert _FakeDiscordClient.instances[0].intents.value == channel.config.intents assert _FakeDiscordClient.instances[0].closed is True _FakeDiscordClient.start_error = None @pytest.mark.asyncio async def test_stop_is_safe_after_partial_start(monkeypatch) -> None: # stop() should close/discard the client even when startup was only partially completed. channel = DiscordChannel( DiscordConfig(enabled=True, token="token", allow_from=["*"]), MessageBus(), ) client = _FakeDiscordClient(channel, intents=None) channel._client = client channel._running = True await channel.stop() assert channel.is_running is False assert client.closed is True assert channel._client is None @pytest.mark.asyncio async def test_on_message_ignores_bot_messages() -> None: # Incoming bot-authored messages must be ignored to prevent feedback loops. channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) handled: list[dict] = [] channel._handle_message = lambda **kwargs: handled.append(kwargs) # type: ignore[method-assign] await channel._on_message(_make_message(author_bot=True)) assert handled == [] # If inbound handling raises, typing should be stopped for that channel. async def fail_handle(**kwargs) -> None: raise RuntimeError("boom") channel._handle_message = fail_handle # type: ignore[method-assign] with pytest.raises(RuntimeError, match="boom"): await channel._on_message(_make_message(author_id=123, channel_id=456)) assert channel._typing_tasks == {} @pytest.mark.asyncio async def test_on_message_accepts_allowlisted_dm() -> None: # Allowed direct messages should be forwarded with normalized metadata. channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus()) handled: list[dict] = [] async def capture_handle(**kwargs) -> None: handled.append(kwargs) channel._handle_message = capture_handle # type: ignore[method-assign] await channel._on_message(_make_message(author_id=123, channel_id=456, message_id=789)) assert len(handled) == 1 assert handled[0]["chat_id"] == "456" assert handled[0]["metadata"] == {"message_id": "789", "guild_id": None, "reply_to": None} @pytest.mark.asyncio async def test_on_message_ignores_unmentioned_guild_message() -> None: # With mention-only group policy, guild messages without a bot mention are dropped. channel = DiscordChannel( DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention"), MessageBus(), ) channel._bot_user_id = "999" handled: list[dict] = [] async def capture_handle(**kwargs) -> None: handled.append(kwargs) channel._handle_message = capture_handle # type: ignore[method-assign] await channel._on_message(_make_message(guild_id=1, content="hello everyone")) assert handled == [] @pytest.mark.asyncio async def test_on_message_accepts_mentioned_guild_message() -> None: # Mentioned guild messages should be accepted and preserve reply threading metadata. channel = DiscordChannel( DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention"), MessageBus(), ) channel._bot_user_id = "999" handled: list[dict] = [] async def capture_handle(**kwargs) -> None: handled.append(kwargs) channel._handle_message = capture_handle # type: ignore[method-assign] await channel._on_message( _make_message( guild_id=1, content="<@999> hello", mentions=[SimpleNamespace(id=999)], reply_to=321, ) ) assert len(handled) == 1 assert handled[0]["metadata"]["reply_to"] == "321" @pytest.mark.asyncio async def test_on_message_downloads_attachments(tmp_path, monkeypatch) -> None: # Attachment downloads should be saved and referenced in forwarded content/media. channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) handled: list[dict] = [] async def capture_handle(**kwargs) -> None: handled.append(kwargs) channel._handle_message = capture_handle # type: ignore[method-assign] monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path) await channel._on_message( _make_message( attachments=[_FakeAttachment(12, "photo.png")], content="see file", ) ) assert len(handled) == 1 assert handled[0]["media"] == [str(tmp_path / "12_photo.png")] assert "[attachment:" in handled[0]["content"] @pytest.mark.asyncio async def test_on_message_marks_failed_attachment_download(tmp_path, monkeypatch) -> None: # Failed attachment downloads should emit a readable placeholder and no media path. channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) handled: list[dict] = [] async def capture_handle(**kwargs) -> None: handled.append(kwargs) channel._handle_message = capture_handle # type: ignore[method-assign] monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path) await channel._on_message( _make_message( attachments=[_FakeAttachment(12, "photo.png", fail=True)], content="", ) ) assert len(handled) == 1 assert handled[0]["media"] == [] assert handled[0]["content"] == "[attachment: photo.png - download failed]" @pytest.mark.asyncio async def test_send_warns_when_client_not_ready() -> None: # Sending without a running/ready client should be a safe no-op. channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello")) assert channel._typing_tasks == {} @pytest.mark.asyncio async def test_send_skips_when_channel_not_cached() -> None: # Outbound sends should be skipped when the destination channel is not resolvable. owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) client = DiscordBotClient(owner, intents=discord.Intents.none()) fetch_calls: list[int] = [] async def fetch_channel(channel_id: int): fetch_calls.append(channel_id) raise RuntimeError("not found") client.fetch_channel = fetch_channel # type: ignore[method-assign] await client.send_outbound(OutboundMessage(channel="discord", chat_id="123", content="hello")) assert client.get_channel(123) is None assert fetch_calls == [123] @pytest.mark.asyncio async def test_send_fetches_channel_when_not_cached() -> None: owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) client = DiscordBotClient(owner, intents=discord.Intents.none()) target = _FakeChannel(channel_id=123) async def fetch_channel(channel_id: int): return target if channel_id == 123 else None client.fetch_channel = fetch_channel # type: ignore[method-assign] await client.send_outbound(OutboundMessage(channel="discord", chat_id="123", content="hello")) assert target.sent_payloads == [{"content": "hello"}] @pytest.mark.asyncio async def test_slash_new_forwards_when_user_is_allowlisted() -> None: channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus()) handled: list[dict] = [] async def capture_handle(**kwargs) -> None: handled.append(kwargs) channel._handle_message = capture_handle # type: ignore[method-assign] client = DiscordBotClient(channel, intents=discord.Intents.none()) interaction = _make_interaction(user_id=123, channel_id=456, interaction_id=321) new_cmd = client.tree.get_command("new") assert new_cmd is not None await new_cmd.callback(interaction) assert interaction.response.messages == [ {"content": "Processing /new...", "ephemeral": True} ] assert len(handled) == 1 assert handled[0]["content"] == "/new" assert handled[0]["sender_id"] == "123" assert handled[0]["chat_id"] == "456" assert handled[0]["metadata"]["interaction_id"] == "321" assert handled[0]["metadata"]["is_slash_command"] is True @pytest.mark.asyncio async def test_slash_new_is_blocked_for_disallowed_user() -> None: channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["999"]), MessageBus()) handled: list[dict] = [] async def capture_handle(**kwargs) -> None: handled.append(kwargs) channel._handle_message = capture_handle # type: ignore[method-assign] client = DiscordBotClient(channel, intents=discord.Intents.none()) interaction = _make_interaction(user_id=123, channel_id=456) new_cmd = client.tree.get_command("new") assert new_cmd is not None await new_cmd.callback(interaction) assert interaction.response.messages == [ {"content": "You are not allowed to use this bot.", "ephemeral": True} ] assert handled == [] @pytest.mark.parametrize("slash_name", ["stop", "restart", "status"]) @pytest.mark.asyncio async def test_slash_commands_forward_via_handle_message(slash_name: str) -> None: channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) handled: list[dict] = [] async def capture_handle(**kwargs) -> None: handled.append(kwargs) channel._handle_message = capture_handle # type: ignore[method-assign] client = DiscordBotClient(channel, intents=discord.Intents.none()) interaction = _make_interaction() interaction.command.qualified_name = slash_name cmd = client.tree.get_command(slash_name) assert cmd is not None await cmd.callback(interaction) assert interaction.response.messages == [ {"content": f"Processing /{slash_name}...", "ephemeral": True} ] assert len(handled) == 1 assert handled[0]["content"] == f"/{slash_name}" assert handled[0]["metadata"]["is_slash_command"] is True @pytest.mark.asyncio async def test_slash_help_returns_ephemeral_help_text() -> None: channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) handled: list[dict] = [] async def capture_handle(**kwargs) -> None: handled.append(kwargs) channel._handle_message = capture_handle # type: ignore[method-assign] client = DiscordBotClient(channel, intents=discord.Intents.none()) interaction = _make_interaction() interaction.command.qualified_name = "help" help_cmd = client.tree.get_command("help") assert help_cmd is not None await help_cmd.callback(interaction) assert interaction.response.messages == [ {"content": build_help_text(), "ephemeral": True} ] assert handled == [] @pytest.mark.asyncio async def test_client_send_outbound_chunks_text_replies_and_uploads_files(tmp_path) -> None: # Outbound payloads should upload files, attach reply references, and chunk long text. owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) client = DiscordBotClient(owner, intents=discord.Intents.none()) target = _FakeChannel(channel_id=123) client.get_channel = lambda channel_id: target if channel_id == 123 else None # type: ignore[method-assign] file_path = tmp_path / "demo.txt" file_path.write_text("hi") await client.send_outbound( OutboundMessage( channel="discord", chat_id="123", content="a" * 2100, reply_to="55", media=[str(file_path)], ) ) assert len(target.sent_payloads) == 3 assert target.sent_payloads[0]["file_name"] == "demo.txt" assert target.sent_payloads[0]["reference"].id == 55 assert target.sent_payloads[1]["content"] == "a" * 2000 assert target.sent_payloads[2]["content"] == "a" * 100 @pytest.mark.asyncio async def test_client_send_outbound_reports_failed_attachments_when_no_text(tmp_path) -> None: # If all attachment sends fail and no text exists, emit a failure placeholder message. owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) client = DiscordBotClient(owner, intents=discord.Intents.none()) target = _FakeChannel(channel_id=123) client.get_channel = lambda channel_id: target if channel_id == 123 else None # type: ignore[method-assign] missing_file = tmp_path / "missing.txt" await client.send_outbound( OutboundMessage( channel="discord", chat_id="123", content="", media=[str(missing_file)], ) ) assert target.sent_payloads == [{"content": "[attachment: missing.txt - send failed]"}] @pytest.mark.asyncio async def test_send_stops_typing_after_send() -> None: # Active typing indicators should be cancelled/cleared after a successful send. channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) client = _FakeDiscordClient(channel, intents=None) channel._client = client channel._running = True start = asyncio.Event() release = asyncio.Event() async def slow_typing() -> None: start.set() await release.wait() typing_channel = _FakeChannel(channel_id=123) typing_channel.typing_enter_hook = slow_typing await channel._start_typing(typing_channel) await start.wait() await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello")) release.set() await asyncio.sleep(0) assert channel._typing_tasks == {} # Progress messages should keep typing active until a final (non-progress) send. start = asyncio.Event() release = asyncio.Event() async def slow_typing_progress() -> None: start.set() await release.wait() typing_channel = _FakeChannel(channel_id=123) typing_channel.typing_enter_hook = slow_typing_progress await channel._start_typing(typing_channel) await start.wait() await channel.send( OutboundMessage( channel="discord", chat_id="123", content="progress", metadata={"_progress": True}, ) ) assert "123" in channel._typing_tasks await channel.send(OutboundMessage(channel="discord", chat_id="123", content="final")) release.set() await asyncio.sleep(0) assert channel._typing_tasks == {} @pytest.mark.asyncio async def test_start_typing_uses_typing_context_when_trigger_typing_missing() -> None: channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) channel._running = True entered = asyncio.Event() release = asyncio.Event() class _TypingCtx: async def __aenter__(self): entered.set() async def __aexit__(self, exc_type, exc, tb): return False class _NoTriggerChannel: def __init__(self, channel_id: int = 123) -> None: self.id = channel_id def typing(self): async def _waiter(): await release.wait() # Hold the loop so task remains active until explicitly stopped. class _Ctx(_TypingCtx): async def __aenter__(self): await super().__aenter__() await _waiter() return _Ctx() typing_channel = _NoTriggerChannel(channel_id=123) await channel._start_typing(typing_channel) # type: ignore[arg-type] await entered.wait() assert "123" in channel._typing_tasks await channel._stop_typing("123") release.set() await asyncio.sleep(0) assert channel._typing_tasks == {}