nanobot/tests/channels/test_discord_channel.py
Paresh Mathur 3e25a853aa
feat(discord): Use discord.py for stable discord channel (#2486)
Co-authored-by: Pares Mathur <paresh.2047@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-27 09:51:45 +08:00

677 lines
22 KiB
Python

from __future__ import annotations
import asyncio
from pathlib import Path
from types import SimpleNamespace
import discord
import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.discord import DiscordBotClient, DiscordChannel, DiscordConfig
from nanobot.command.builtin import build_help_text
class _FakeDiscordClient:
    """Stand-in for ``DiscordBotClient`` that lets tests control startup and readiness.

    Every constructed instance is appended to the class-level ``instances`` list, and a
    non-``None`` class-level ``start_error`` is raised from ``start()``.
    """

    instances: list["_FakeDiscordClient"] = []
    start_error: Exception | None = None

    def __init__(self, owner, *, intents) -> None:
        self.owner = owner
        self.intents = intents
        self.closed = False
        self.ready = True
        self.channels: dict[int, object] = {}
        self.user = SimpleNamespace(id=999)
        type(self).instances.append(self)

    async def start(self, token: str) -> None:
        # The token is recorded before the (optional) simulated connection failure.
        self.token = token
        failure = type(self).start_error
        if failure is not None:
            raise failure

    async def close(self) -> None:
        self.closed = True

    def is_closed(self) -> bool:
        return self.closed

    def is_ready(self) -> bool:
        return self.ready

    def get_channel(self, channel_id: int):
        return self.channels.get(channel_id)

    async def send_outbound(self, msg: OutboundMessage) -> None:
        """Send ``msg.content`` to the cached channel for ``msg.chat_id``; no-op if unknown."""
        target = self.get_channel(int(msg.chat_id))
        if target is not None:
            await target.send(content=msg.content)
class _FakeAttachment:
    """Discord attachment double whose ``save()`` either writes a stub payload or fails."""

    def __init__(self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False) -> None:
        self.id = attachment_id
        self.filename = filename
        self.size = size
        self._fail = fail

    async def save(self, path: str | Path) -> None:
        """Write ``b"attachment"`` to *path*, or raise ``RuntimeError`` when built with ``fail=True``."""
        if not self._fail:
            Path(path).write_bytes(b"attachment")
            return
        raise RuntimeError("save failed")
class _FakePartialMessage:
    """Minimal stand-in for a Discord partial message reference; it only carries an id."""

    def __init__(self, message_id: int) -> None:
        # Tests inspect reply references in outbound payloads through ``.id`` only.
        self.id = message_id
class _FakeChannel:
    """Text-channel double that records sent payloads and counts typing-context entries."""

    def __init__(self, channel_id: int = 123) -> None:
        self.id = channel_id
        self.sent_payloads: list[dict] = []
        # Incremented every time a ``typing()`` context is entered.
        self.trigger_typing_calls = 0
        # Optional coroutine function awaited on typing entry, letting tests stall it.
        self.typing_enter_hook = None

    async def send(self, **kwargs) -> None:
        """Record the send kwargs, replacing a ``file`` object with its ``file_name``."""
        payload = dict(kwargs)
        if "file" in payload:
            payload["file_name"] = payload.pop("file").filename
        self.sent_payloads.append(payload)

    def get_partial_message(self, message_id: int) -> _FakePartialMessage:
        return _FakePartialMessage(message_id)

    def typing(self):
        """Return an async context manager that counts entries and awaits the optional hook."""
        owner = self

        class _TypingContext:
            async def __aenter__(self):
                owner.trigger_typing_calls += 1
                hook = owner.typing_enter_hook
                if hook is not None:
                    await hook()

            async def __aexit__(self, exc_type, exc, tb):
                return False

        return _TypingContext()
class _FakeInteractionResponse:
    """Interaction-response double that records sent messages and whether it has responded."""

    def __init__(self) -> None:
        self.messages: list[dict] = []
        self._done = False

    async def send_message(self, content: str, *, ephemeral: bool = False) -> None:
        entry = {"content": content, "ephemeral": ephemeral}
        self.messages.append(entry)
        # Once any message has been sent, the response reports itself as done.
        self._done = True

    def is_done(self) -> bool:
        return self._done
def _make_interaction(
    *,
    user_id: int = 123,
    channel_id: int | None = 456,
    guild_id: int | None = None,
    interaction_id: int = 999,
    command_name: str = "new",
):
    """Build a slash-command interaction double.

    Args:
        user_id: id of the invoking user.
        channel_id: id of the channel the command was invoked in, or ``None``.
        guild_id: id of the guild, or ``None``.
        interaction_id: id of the interaction itself.
        command_name: value exposed as ``interaction.command.qualified_name``
            (defaults to ``"new"`` for backward compatibility).

    Returns:
        A ``SimpleNamespace`` with ``user``, ``channel_id``, ``guild_id``, ``id``,
        ``command`` and a fresh ``_FakeInteractionResponse`` as ``response``.
    """
    return SimpleNamespace(
        user=SimpleNamespace(id=user_id),
        channel_id=channel_id,
        guild_id=guild_id,
        id=interaction_id,
        command=SimpleNamespace(qualified_name=command_name),
        response=_FakeInteractionResponse(),
    )
def _make_message(
    *,
    author_id: int = 123,
    author_bot: bool = False,
    channel_id: int = 456,
    message_id: int = 789,
    content: str = "hello",
    guild_id: int | None = None,
    mentions: list[object] | None = None,
    attachments: list[object] | None = None,
    reply_to: int | None = None,
):
    """Build an inbound Discord message double.

    ``guild`` and ``reference`` stay ``None`` unless ``guild_id`` / ``reply_to`` are
    given; ``mentions`` and ``attachments`` fall back to fresh empty lists. The message
    lives in a new ``_FakeChannel`` with id ``channel_id``.
    """
    guild = None
    if guild_id is not None:
        guild = SimpleNamespace(id=guild_id)

    reference = None
    if reply_to is not None:
        reference = SimpleNamespace(message_id=reply_to)

    author = SimpleNamespace(id=author_id, bot=author_bot)
    return SimpleNamespace(
        author=author,
        channel=_FakeChannel(channel_id),
        content=content,
        guild=guild,
        mentions=mentions or [],
        attachments=attachments or [],
        reference=reference,
        id=message_id,
    )
@pytest.mark.asyncio
async def test_start_returns_when_token_missing() -> None:
    """Without a configured token, ``start()`` is a no-op that leaves the channel stopped."""
    config = DiscordConfig(enabled=True, allow_from=["*"])
    channel = DiscordChannel(config, MessageBus())

    await channel.start()

    assert channel.is_running is False
    assert channel._client is None
@pytest.mark.asyncio
async def test_start_returns_when_discord_dependency_missing(monkeypatch) -> None:
    """If discord.py is flagged unavailable, ``start()`` leaves the channel stopped with no client."""
    config = DiscordConfig(enabled=True, token="token", allow_from=["*"])
    channel = DiscordChannel(config, MessageBus())
    monkeypatch.setattr("nanobot.channels.discord.DISCORD_AVAILABLE", False)

    await channel.start()

    assert channel.is_running is False
    assert channel._client is None
@pytest.mark.asyncio
async def test_start_handles_client_construction_failure(monkeypatch) -> None:
    """Errors raised while constructing the Discord client are swallowed and leave clean state."""
    config = DiscordConfig(enabled=True, token="token", allow_from=["*"])
    channel = DiscordChannel(config, MessageBus())

    def _raising_client_factory(owner, *, intents):
        raise RuntimeError("bad client")

    monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _raising_client_factory)

    await channel.start()

    assert channel.is_running is False
    assert channel._client is None
@pytest.mark.asyncio
async def test_start_handles_client_start_failure(monkeypatch) -> None:
    """If ``client.start`` fails, the partially created client is closed and detached.

    The fake's class-level ``instances`` and ``start_error`` are patched through
    ``monkeypatch`` so they are restored even when an assertion fails. Resetting
    them by hand at the end of the test would leak ``start_error`` into later tests
    on failure.
    """
    channel = DiscordChannel(
        DiscordConfig(enabled=True, token="token", allow_from=["*"]),
        MessageBus(),
    )
    monkeypatch.setattr(_FakeDiscordClient, "instances", [])
    monkeypatch.setattr(_FakeDiscordClient, "start_error", RuntimeError("connect failed"))
    monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)

    await channel.start()

    assert channel.is_running is False
    assert channel._client is None
    created = _FakeDiscordClient.instances[0]
    assert created.intents.value == channel.config.intents
    assert created.closed is True
@pytest.mark.asyncio
async def test_stop_is_safe_after_partial_start(monkeypatch) -> None:
    """``stop()`` closes and discards the client even if startup only partially completed."""
    config = DiscordConfig(enabled=True, token="token", allow_from=["*"])
    channel = DiscordChannel(config, MessageBus())
    fake_client = _FakeDiscordClient(channel, intents=None)
    # Simulate the state an interrupted start() would leave behind.
    channel._client, channel._running = fake_client, True

    await channel.stop()

    assert channel.is_running is False
    assert fake_client.closed is True
    assert channel._client is None
@pytest.mark.asyncio
async def test_on_message_ignores_bot_messages() -> None:
    """Bot-authored inbound messages are ignored to prevent feedback loops."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    handled: list[dict] = []

    # A real coroutine (not a sync lambda) means that, should the handler ever be
    # invoked and awaited, the test fails on the assertion rather than on awaiting None.
    async def capture_handle(**kwargs) -> None:
        handled.append(kwargs)

    channel._handle_message = capture_handle  # type: ignore[method-assign]

    await channel._on_message(_make_message(author_bot=True))

    assert handled == []


@pytest.mark.asyncio
async def test_on_message_stops_typing_when_handle_message_fails() -> None:
    """If inbound handling raises, the typing indicator for that channel is stopped."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())

    async def fail_handle(**kwargs) -> None:
        raise RuntimeError("boom")

    channel._handle_message = fail_handle  # type: ignore[method-assign]

    with pytest.raises(RuntimeError, match="boom"):
        await channel._on_message(_make_message(author_id=123, channel_id=456))

    assert channel._typing_tasks == {}
@pytest.mark.asyncio
async def test_on_message_accepts_allowlisted_dm() -> None:
    """A DM from an allow-listed user is forwarded with normalized string metadata."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus())
    forwarded: list[dict] = []

    async def record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = record  # type: ignore[method-assign]
    incoming = _make_message(author_id=123, channel_id=456, message_id=789)

    await channel._on_message(incoming)

    assert len(forwarded) == 1
    call = forwarded[0]
    assert call["chat_id"] == "456"
    assert call["metadata"] == {"message_id": "789", "guild_id": None, "reply_to": None}
@pytest.mark.asyncio
async def test_on_message_ignores_unmentioned_guild_message() -> None:
    """Under the mention-only group policy, guild messages not mentioning the bot are dropped."""
    config = DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention")
    channel = DiscordChannel(config, MessageBus())
    channel._bot_user_id = "999"
    forwarded: list[dict] = []

    async def record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = record  # type: ignore[method-assign]

    await channel._on_message(_make_message(guild_id=1, content="hello everyone"))

    assert not forwarded
@pytest.mark.asyncio
async def test_on_message_accepts_mentioned_guild_message() -> None:
    """A guild message that mentions the bot is accepted and keeps its reply reference."""
    config = DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention")
    channel = DiscordChannel(config, MessageBus())
    channel._bot_user_id = "999"
    forwarded: list[dict] = []

    async def record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = record  # type: ignore[method-assign]
    bot_mention = SimpleNamespace(id=999)
    incoming = _make_message(
        guild_id=1,
        content="<@999> hello",
        mentions=[bot_mention],
        reply_to=321,
    )

    await channel._on_message(incoming)

    assert len(forwarded) == 1
    assert forwarded[0]["metadata"]["reply_to"] == "321"
@pytest.mark.asyncio
async def test_on_message_downloads_attachments(tmp_path, monkeypatch) -> None:
    """Attachments are saved to the media dir and referenced in forwarded content/media.

    Besides the reported media path, the test checks the saved file itself and the
    original message text. This catches a channel that reports a path it never wrote
    to, or that drops the user's text when attachments are present.
    """
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    handled: list[dict] = []

    async def capture_handle(**kwargs) -> None:
        handled.append(kwargs)

    channel._handle_message = capture_handle  # type: ignore[method-assign]
    monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path)

    await channel._on_message(
        _make_message(
            attachments=[_FakeAttachment(12, "photo.png")],
            content="see file",
        )
    )

    expected_path = tmp_path / "12_photo.png"
    assert len(handled) == 1
    assert handled[0]["media"] == [str(expected_path)]
    # _FakeAttachment.save writes b"attachment"; confirm it landed where it was reported.
    assert expected_path.read_bytes() == b"attachment"
    assert "see file" in handled[0]["content"]
    assert "[attachment:" in handled[0]["content"]
@pytest.mark.asyncio
async def test_on_message_marks_failed_attachment_download(tmp_path, monkeypatch) -> None:
    """A failed attachment download yields a readable placeholder and no media path."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    forwarded: list[dict] = []

    async def record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = record  # type: ignore[method-assign]
    monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path)
    broken = _FakeAttachment(12, "photo.png", fail=True)

    await channel._on_message(_make_message(attachments=[broken], content=""))

    assert len(forwarded) == 1
    call = forwarded[0]
    assert call["media"] == []
    assert call["content"] == "[attachment: photo.png - download failed]"
@pytest.mark.asyncio
async def test_send_warns_when_client_not_ready() -> None:
    """Sending with no running/ready client is a safe no-op that starts no typing tasks."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    outbound = OutboundMessage(channel="discord", chat_id="123", content="hello")

    await channel.send(outbound)

    assert channel._typing_tasks == {}
@pytest.mark.asyncio
async def test_send_skips_when_channel_not_cached() -> None:
    """If the target channel is neither cached nor fetchable, the send is skipped without raising."""
    owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    client = DiscordBotClient(owner, intents=discord.Intents.none())
    requested_ids: list[int] = []

    async def failing_fetch(channel_id: int):
        requested_ids.append(channel_id)
        raise RuntimeError("not found")

    client.fetch_channel = failing_fetch  # type: ignore[method-assign]
    outbound = OutboundMessage(channel="discord", chat_id="123", content="hello")

    await client.send_outbound(outbound)

    assert client.get_channel(123) is None
    assert requested_ids == [123]
@pytest.mark.asyncio
async def test_send_fetches_channel_when_not_cached() -> None:
    """An uncached destination is fetched on demand and the message is delivered to it."""
    owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    client = DiscordBotClient(owner, intents=discord.Intents.none())
    target = _FakeChannel(channel_id=123)

    async def fetch_channel(channel_id: int):
        if channel_id == 123:
            return target
        return None

    client.fetch_channel = fetch_channel  # type: ignore[method-assign]
    outbound = OutboundMessage(channel="discord", chat_id="123", content="hello")

    await client.send_outbound(outbound)

    assert target.sent_payloads == [{"content": "hello"}]
@pytest.mark.asyncio
async def test_slash_new_forwards_when_user_is_allowlisted() -> None:
    """``/new`` from an allow-listed user is acknowledged ephemerally and forwarded."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus())
    forwarded: list[dict] = []

    async def record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = record  # type: ignore[method-assign]
    client = DiscordBotClient(channel, intents=discord.Intents.none())
    interaction = _make_interaction(user_id=123, channel_id=456, interaction_id=321)

    command = client.tree.get_command("new")
    assert command is not None
    await command.callback(interaction)

    assert interaction.response.messages == [
        {"content": "Processing /new...", "ephemeral": True}
    ]
    assert len(forwarded) == 1
    call = forwarded[0]
    assert call["content"] == "/new"
    assert call["sender_id"] == "123"
    assert call["chat_id"] == "456"
    metadata = call["metadata"]
    assert metadata["interaction_id"] == "321"
    assert metadata["is_slash_command"] is True
@pytest.mark.asyncio
async def test_slash_new_is_blocked_for_disallowed_user() -> None:
    """``/new`` from a user outside the allow-list is refused ephemerally and not forwarded."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["999"]), MessageBus())
    forwarded: list[dict] = []

    async def record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = record  # type: ignore[method-assign]
    client = DiscordBotClient(channel, intents=discord.Intents.none())
    interaction = _make_interaction(user_id=123, channel_id=456)

    command = client.tree.get_command("new")
    assert command is not None
    await command.callback(interaction)

    assert interaction.response.messages == [
        {"content": "You are not allowed to use this bot.", "ephemeral": True}
    ]
    assert not forwarded
@pytest.mark.parametrize("slash_name", ["stop", "restart", "status"])
@pytest.mark.asyncio
async def test_slash_commands_forward_via_handle_message(slash_name: str) -> None:
    """``/stop``, ``/restart`` and ``/status`` are acknowledged and forwarded as text commands."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    forwarded: list[dict] = []

    async def record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = record  # type: ignore[method-assign]
    client = DiscordBotClient(channel, intents=discord.Intents.none())
    interaction = _make_interaction()
    interaction.command.qualified_name = slash_name

    command = client.tree.get_command(slash_name)
    assert command is not None
    await command.callback(interaction)

    expected_ack = {"content": f"Processing /{slash_name}...", "ephemeral": True}
    assert interaction.response.messages == [expected_ack]
    assert len(forwarded) == 1
    call = forwarded[0]
    assert call["content"] == f"/{slash_name}"
    assert call["metadata"]["is_slash_command"] is True
@pytest.mark.asyncio
async def test_slash_help_returns_ephemeral_help_text() -> None:
    """``/help`` answers directly with the shared help text and forwards nothing."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    forwarded: list[dict] = []

    async def record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = record  # type: ignore[method-assign]
    client = DiscordBotClient(channel, intents=discord.Intents.none())
    interaction = _make_interaction()
    interaction.command.qualified_name = "help"

    command = client.tree.get_command("help")
    assert command is not None
    await command.callback(interaction)

    assert interaction.response.messages == [
        {"content": build_help_text(), "ephemeral": True}
    ]
    assert not forwarded
@pytest.mark.asyncio
async def test_client_send_outbound_chunks_text_replies_and_uploads_files(tmp_path) -> None:
    """Outbound sends upload files first with the reply reference, then chunk long text."""
    owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    client = DiscordBotClient(owner, intents=discord.Intents.none())
    target = _FakeChannel(channel_id=123)

    def lookup(channel_id):
        return target if channel_id == 123 else None

    client.get_channel = lookup  # type: ignore[method-assign]
    upload = tmp_path / "demo.txt"
    upload.write_text("hi")
    outbound = OutboundMessage(
        channel="discord",
        chat_id="123",
        content="a" * 2100,
        reply_to="55",
        media=[str(upload)],
    )

    await client.send_outbound(outbound)

    sent = target.sent_payloads
    assert len(sent) == 3
    file_payload, first_chunk, second_chunk = sent
    assert file_payload["file_name"] == "demo.txt"
    assert file_payload["reference"].id == 55
    assert first_chunk["content"] == "a" * 2000
    assert second_chunk["content"] == "a" * 100
@pytest.mark.asyncio
async def test_client_send_outbound_reports_failed_attachments_when_no_text(tmp_path) -> None:
    """When all attachments fail and there is no text, a failure placeholder is sent instead."""
    owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    client = DiscordBotClient(owner, intents=discord.Intents.none())
    target = _FakeChannel(channel_id=123)

    def lookup(channel_id):
        return target if channel_id == 123 else None

    client.get_channel = lookup  # type: ignore[method-assign]
    # Never created on disk, so sending it as an attachment fails.
    nonexistent = tmp_path / "missing.txt"
    outbound = OutboundMessage(
        channel="discord",
        chat_id="123",
        content="",
        media=[str(nonexistent)],
    )

    await client.send_outbound(outbound)

    assert target.sent_payloads == [{"content": "[attachment: missing.txt - send failed]"}]
def _make_running_channel() -> DiscordChannel:
    """Return a channel wired to a ready ``_FakeDiscordClient`` and marked as running."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    channel._client = _FakeDiscordClient(channel, intents=None)
    channel._running = True
    return channel


async def _start_stalled_typing(channel: DiscordChannel, release: asyncio.Event) -> None:
    """Start typing on fake channel id 123 and wait until its typing context is entered.

    The typing context then blocks on *release*, so the typing task stays active until
    the test sets *release* or the task is cancelled.
    """
    entered = asyncio.Event()

    async def hook() -> None:
        entered.set()
        await release.wait()

    typing_channel = _FakeChannel(channel_id=123)
    typing_channel.typing_enter_hook = hook
    await channel._start_typing(typing_channel)
    await entered.wait()


@pytest.mark.asyncio
async def test_send_stops_typing_after_send() -> None:
    """Active typing indicators are cancelled and cleared after a successful final send."""
    channel = _make_running_channel()
    release = asyncio.Event()
    await _start_stalled_typing(channel, release)

    await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
    release.set()
    await asyncio.sleep(0)

    assert channel._typing_tasks == {}


@pytest.mark.asyncio
async def test_send_keeps_typing_during_progress_messages() -> None:
    """Progress messages keep typing active until a final (non-progress) send.

    Kept separate from ``test_send_stops_typing_after_send`` so this scenario starts
    from a fresh channel instead of depending on that test's cleanup.
    """
    channel = _make_running_channel()
    release = asyncio.Event()
    await _start_stalled_typing(channel, release)

    await channel.send(
        OutboundMessage(
            channel="discord",
            chat_id="123",
            content="progress",
            metadata={"_progress": True},
        )
    )
    assert "123" in channel._typing_tasks

    await channel.send(OutboundMessage(channel="discord", chat_id="123", content="final"))
    release.set()
    await asyncio.sleep(0)

    assert channel._typing_tasks == {}
@pytest.mark.asyncio
async def test_start_typing_uses_typing_context_when_trigger_typing_missing() -> None:
    """Typing falls back to ``channel.typing()`` when the channel has no ``trigger_typing``."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    channel._running = True
    entered = asyncio.Event()
    release = asyncio.Event()

    class _BlockingTypingCtx:
        # Signals entry, then holds until released so the typing task remains
        # active until it is explicitly stopped.
        async def __aenter__(self):
            entered.set()
            await release.wait()

        async def __aexit__(self, exc_type, exc, tb):
            return False

    class _NoTriggerChannel:
        # Deliberately exposes only ``id`` and ``typing()``; there is no ``trigger_typing``.
        def __init__(self, channel_id: int = 123) -> None:
            self.id = channel_id

        def typing(self):
            return _BlockingTypingCtx()

    await channel._start_typing(_NoTriggerChannel(channel_id=123))  # type: ignore[arg-type]
    await entered.wait()
    assert "123" in channel._typing_tasks

    await channel._stop_typing("123")
    release.set()
    await asyncio.sleep(0)
    assert channel._typing_tasks == {}