nanobot/tests/channels/test_discord_channel.py

677 lines
22 KiB
Python

from __future__ import annotations
import asyncio
from pathlib import Path
from types import SimpleNamespace
import pytest
discord = pytest.importorskip("discord")
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.discord import DiscordBotClient, DiscordChannel, DiscordConfig
from nanobot.command.builtin import build_help_text
class _FakeDiscordClient:
    """Minimal Discord client test double used to control startup/readiness behavior.

    Every constructed instance is appended to the class-level ``instances`` list so
    tests can inspect clients that ``DiscordChannel.start()`` creates indirectly.
    When the class-level ``start_error`` is set, ``start()`` raises it after
    recording the token.
    """

    instances: list["_FakeDiscordClient"] = []
    start_error: Exception | None = None

    def __init__(self, owner, *, intents) -> None:
        self.owner = owner
        self.intents = intents
        self.closed = False
        self.ready = True
        # Channel cache consulted by get_channel(); tests populate it directly.
        self.channels: dict[int, object] = {}
        self.user = SimpleNamespace(id=999)
        type(self).instances.append(self)

    async def start(self, token: str) -> None:
        """Remember *token*, then raise the configured ``start_error`` if any."""
        self.token = token
        error = type(self).start_error
        if error is not None:
            raise error

    async def close(self) -> None:
        self.closed = True

    def is_closed(self) -> bool:
        return self.closed

    def is_ready(self) -> bool:
        return self.ready

    def get_channel(self, channel_id: int):
        return self.channels.get(channel_id)

    async def send_outbound(self, msg: OutboundMessage) -> None:
        """Forward ``msg.content`` to the cached channel; drop messages for unknown chats."""
        target = self.get_channel(int(msg.chat_id))
        if target is not None:
            await target.send(content=msg.content)
class _FakeAttachment:
    """Attachment double whose ``save()`` either writes placeholder bytes or fails."""

    def __init__(self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False) -> None:
        self.id, self.filename, self.size = attachment_id, filename, size
        self._fail = fail

    async def save(self, path: str | Path) -> None:
        """Write ``b"attachment"`` to *path*, or raise ``RuntimeError`` when built with ``fail=True``."""
        if not self._fail:
            Path(path).write_bytes(b"attachment")
            return
        raise RuntimeError("save failed")
class _FakePartialMessage:
    """Stand-in for a Discord partial message; only carries the referenced message id."""

    def __init__(self, message_id: int) -> None:
        # Reply references only need the id of the message being replied to.
        self.id = message_id
class _FakeChannel:
    """Channel double that records outbound payloads and typing activity."""

    class _TypingContext:
        """Async context returned by ``typing()``; counts entries and runs the optional hook."""

        def __init__(self, owner: "_FakeChannel") -> None:
            self._owner = owner

        async def __aenter__(self):
            owner = self._owner
            owner.trigger_typing_calls += 1
            # The hook is read at entry time so tests can install it after creating the channel.
            hook = owner.typing_enter_hook
            if hook is not None:
                await hook()

        async def __aexit__(self, exc_type, exc, tb):
            return False

    def __init__(self, channel_id: int = 123) -> None:
        self.id = channel_id
        self.sent_payloads: list[dict] = []
        self.trigger_typing_calls = 0
        # Optional zero-argument coroutine function awaited when typing() is entered.
        self.typing_enter_hook = None

    async def send(self, **kwargs) -> None:
        """Record the send kwargs, replacing a ``file`` object with its ``file_name``."""
        payload = dict(kwargs)
        if "file" in payload:
            payload["file_name"] = payload.pop("file").filename
        self.sent_payloads.append(payload)

    def get_partial_message(self, message_id: int) -> _FakePartialMessage:
        return _FakePartialMessage(message_id)

    def typing(self):
        return self._TypingContext(self)
class _FakeInteractionResponse:
    """Interaction-response double that records each ``send_message`` call."""

    def __init__(self) -> None:
        self.messages: list[dict] = []
        self._done = False

    async def send_message(self, content: str, *, ephemeral: bool = False) -> None:
        """Store the reply and mark the response as completed."""
        record = {"content": content, "ephemeral": ephemeral}
        self.messages.append(record)
        self._done = True

    def is_done(self) -> bool:
        """Report whether ``send_message`` has been called at least once."""
        return self._done
def _make_interaction(
    *,
    user_id: int = 123,
    channel_id: int | None = 456,
    guild_id: int | None = None,
    interaction_id: int = 999,
):
    """Build a fake slash-command interaction, defaulting to the ``/new`` command.

    Tests that exercise other commands overwrite ``command.qualified_name``.
    """
    user = SimpleNamespace(id=user_id)
    return SimpleNamespace(
        user=user,
        channel_id=channel_id,
        guild_id=guild_id,
        id=interaction_id,
        command=SimpleNamespace(qualified_name="new"),
        response=_FakeInteractionResponse(),
    )
def _make_message(
    *,
    author_id: int = 123,
    author_bot: bool = False,
    channel_id: int = 456,
    message_id: int = 789,
    content: str = "hello",
    guild_id: int | None = None,
    mentions: list[object] | None = None,
    attachments: list[object] | None = None,
    reply_to: int | None = None,
):
    """Build a fake incoming Discord message.

    ``guild`` and ``reference`` stay ``None`` unless ``guild_id`` / ``reply_to``
    are given. ``mentions`` and ``attachments`` fall back to fresh empty lists.
    """
    guild = None if guild_id is None else SimpleNamespace(id=guild_id)
    reference = None if reply_to is None else SimpleNamespace(message_id=reply_to)
    author = SimpleNamespace(id=author_id, bot=author_bot)
    return SimpleNamespace(
        author=author,
        channel=_FakeChannel(channel_id),
        content=content,
        guild=guild,
        mentions=mentions or [],
        attachments=attachments or [],
        reference=reference,
        id=message_id,
    )
@pytest.mark.asyncio
async def test_start_returns_when_token_missing() -> None:
    """Without a configured token, start() is a no-op and the channel stays stopped."""
    config = DiscordConfig(enabled=True, allow_from=["*"])
    channel = DiscordChannel(config, MessageBus())

    await channel.start()

    assert channel.is_running is False
    assert channel._client is None
@pytest.mark.asyncio
async def test_start_returns_when_discord_dependency_missing(monkeypatch) -> None:
    """start() bails out cleanly when the discord dependency is flagged unavailable."""
    config = DiscordConfig(enabled=True, token="token", allow_from=["*"])
    channel = DiscordChannel(config, MessageBus())
    # Pretend the optional discord.py import failed.
    monkeypatch.setattr("nanobot.channels.discord.DISCORD_AVAILABLE", False)

    await channel.start()

    assert channel.is_running is False
    assert channel._client is None
@pytest.mark.asyncio
async def test_start_handles_client_construction_failure(monkeypatch) -> None:
    """A client constructor that raises is swallowed and leaves no client attached."""
    config = DiscordConfig(enabled=True, token="token", allow_from=["*"])
    channel = DiscordChannel(config, MessageBus())

    def _exploding_client(owner, *, intents):
        raise RuntimeError("bad client")

    monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _exploding_client)
    await channel.start()

    assert channel.is_running is False
    assert channel._client is None
@pytest.mark.asyncio
async def test_start_handles_client_start_failure(monkeypatch) -> None:
    """If client.start() fails, the partially created client is closed and detached.

    The fake's class-level ``instances`` and ``start_error`` are patched through
    ``monkeypatch`` so they are restored even when an assertion below fails. A
    manual reset at the end of the test would be skipped on failure, leaking the
    injected error into every later test that constructs ``_FakeDiscordClient``.
    """
    channel = DiscordChannel(
        DiscordConfig(enabled=True, token="token", allow_from=["*"]),
        MessageBus(),
    )
    monkeypatch.setattr(_FakeDiscordClient, "instances", [])
    monkeypatch.setattr(_FakeDiscordClient, "start_error", RuntimeError("connect failed"))
    monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)

    await channel.start()

    assert channel.is_running is False
    assert channel._client is None
    # A client must have been constructed before start() failed.
    assert _FakeDiscordClient.instances
    created = _FakeDiscordClient.instances[0]
    assert created.intents.value == channel.config.intents
    assert created.closed is True
@pytest.mark.asyncio
async def test_stop_is_safe_after_partial_start(monkeypatch) -> None:
    """stop() closes and discards the client even if start() only partly completed."""
    channel = DiscordChannel(
        DiscordConfig(enabled=True, token="token", allow_from=["*"]),
        MessageBus(),
    )
    # Simulate a half-started channel: a client is attached and the running flag is set.
    fake_client = _FakeDiscordClient(channel, intents=None)
    channel._client, channel._running = fake_client, True

    await channel.stop()

    assert channel.is_running is False
    assert fake_client.closed is True
    assert channel._client is None
@pytest.mark.asyncio
async def test_on_message_ignores_bot_messages() -> None:
    """Bot-authored messages must be ignored to prevent feedback loops.

    The stub handler is async, like the stubs in every other test here, which
    suggests ``_on_message`` awaits it. With a sync lambda, a regression that
    forwards bot messages would surface as a confusing ``TypeError`` from
    awaiting ``None`` rather than as this assertion failing.
    """
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    handled: list[dict] = []

    async def capture_handle(**kwargs) -> None:
        handled.append(kwargs)

    channel._handle_message = capture_handle  # type: ignore[method-assign]

    await channel._on_message(_make_message(author_bot=True))

    assert handled == []


@pytest.mark.asyncio
async def test_on_message_clears_typing_when_handler_raises() -> None:
    """If inbound handling raises, the error propagates and typing is stopped for that channel."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())

    async def fail_handle(**kwargs) -> None:
        raise RuntimeError("boom")

    channel._handle_message = fail_handle  # type: ignore[method-assign]

    with pytest.raises(RuntimeError, match="boom"):
        await channel._on_message(_make_message(author_id=123, channel_id=456))

    assert channel._typing_tasks == {}
@pytest.mark.asyncio
async def test_on_message_accepts_allowlisted_dm() -> None:
    """An allowlisted direct message is forwarded with its chat id and normalized metadata."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus())
    forwarded: list[dict] = []

    async def _record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = _record  # type: ignore[method-assign]
    dm = _make_message(author_id=123, channel_id=456, message_id=789)
    await channel._on_message(dm)

    assert len(forwarded) == 1
    (call,) = forwarded
    assert call["chat_id"] == "456"
    assert call["metadata"] == {"message_id": "789", "guild_id": None, "reply_to": None}
@pytest.mark.asyncio
async def test_on_message_ignores_unmentioned_guild_message() -> None:
    """With the "mention" group policy, guild messages that don't mention the bot are dropped."""
    config = DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention")
    channel = DiscordChannel(config, MessageBus())
    channel._bot_user_id = "999"
    forwarded: list[dict] = []

    async def _record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = _record  # type: ignore[method-assign]
    await channel._on_message(_make_message(guild_id=1, content="hello everyone"))

    assert not forwarded
@pytest.mark.asyncio
async def test_on_message_accepts_mentioned_guild_message() -> None:
    """A guild message that mentions the bot is accepted and keeps its reply-to id."""
    config = DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention")
    channel = DiscordChannel(config, MessageBus())
    channel._bot_user_id = "999"
    forwarded: list[dict] = []

    async def _record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = _record  # type: ignore[method-assign]
    mention = _make_message(
        guild_id=1,
        content="<@999> hello",
        mentions=[SimpleNamespace(id=999)],
        reply_to=321,
    )
    await channel._on_message(mention)

    assert len(forwarded) == 1
    assert forwarded[0]["metadata"]["reply_to"] == "321"
@pytest.mark.asyncio
async def test_on_message_downloads_attachments(tmp_path, monkeypatch) -> None:
    """Attachments are saved into the media dir and referenced in both content and media."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    forwarded: list[dict] = []

    async def _record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = _record  # type: ignore[method-assign]
    monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path)

    incoming = _make_message(
        attachments=[_FakeAttachment(12, "photo.png")],
        content="see file",
    )
    await channel._on_message(incoming)

    assert len(forwarded) == 1
    saved_path = tmp_path / "12_photo.png"
    assert forwarded[0]["media"] == [str(saved_path)]
    assert "[attachment:" in forwarded[0]["content"]
@pytest.mark.asyncio
async def test_on_message_marks_failed_attachment_download(tmp_path, monkeypatch) -> None:
    """A failed attachment download yields a readable placeholder and no media path."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    forwarded: list[dict] = []

    async def _record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = _record  # type: ignore[method-assign]
    monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path)

    broken = _FakeAttachment(12, "photo.png", fail=True)
    await channel._on_message(_make_message(attachments=[broken], content=""))

    assert len(forwarded) == 1
    call = forwarded[0]
    assert call["media"] == []
    assert call["content"] == "[attachment: photo.png - download failed]"
@pytest.mark.asyncio
async def test_send_warns_when_client_not_ready() -> None:
    """send() without a running/ready client is a safe no-op.

    Only the absence of typing bookkeeping is asserted; the warning itself is not captured.
    """
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    outbound = OutboundMessage(channel="discord", chat_id="123", content="hello")

    await channel.send(outbound)

    assert channel._typing_tasks == {}
@pytest.mark.asyncio
async def test_send_skips_when_channel_not_cached() -> None:
    """If the target channel is neither cached nor fetchable, send_outbound gives up quietly."""
    owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    client = DiscordBotClient(owner, intents=discord.Intents.none())
    attempted: list[int] = []

    async def _failing_fetch(channel_id: int):
        attempted.append(channel_id)
        raise RuntimeError("not found")

    client.fetch_channel = _failing_fetch  # type: ignore[method-assign]
    outbound = OutboundMessage(channel="discord", chat_id="123", content="hello")
    await client.send_outbound(outbound)

    assert client.get_channel(123) is None
    assert attempted == [123]
@pytest.mark.asyncio
async def test_send_fetches_channel_when_not_cached() -> None:
    """A channel missing from the cache is fetched and the message delivered to it."""
    owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    client = DiscordBotClient(owner, intents=discord.Intents.none())
    destination = _FakeChannel(channel_id=123)

    async def _fetch(channel_id: int):
        if channel_id == 123:
            return destination
        return None

    client.fetch_channel = _fetch  # type: ignore[method-assign]
    outbound = OutboundMessage(channel="discord", chat_id="123", content="hello")
    await client.send_outbound(outbound)

    assert destination.sent_payloads == [{"content": "hello"}]
@pytest.mark.asyncio
async def test_slash_new_forwards_when_user_is_allowlisted() -> None:
    """/new from an allowlisted user is acknowledged ephemerally and forwarded as a message."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus())
    forwarded: list[dict] = []

    async def _record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = _record  # type: ignore[method-assign]
    client = DiscordBotClient(channel, intents=discord.Intents.none())
    interaction = _make_interaction(user_id=123, channel_id=456, interaction_id=321)

    command = client.tree.get_command("new")
    assert command is not None
    await command.callback(interaction)

    expected_ack = {"content": "Processing /new...", "ephemeral": True}
    assert interaction.response.messages == [expected_ack]
    assert len(forwarded) == 1
    call = forwarded[0]
    assert call["content"] == "/new"
    assert call["sender_id"] == "123"
    assert call["chat_id"] == "456"
    metadata = call["metadata"]
    assert metadata["interaction_id"] == "321"
    assert metadata["is_slash_command"] is True
@pytest.mark.asyncio
async def test_slash_new_is_blocked_for_disallowed_user() -> None:
    """/new from a user outside the allowlist is rejected ephemerally and never forwarded."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["999"]), MessageBus())
    forwarded: list[dict] = []

    async def _record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = _record  # type: ignore[method-assign]
    client = DiscordBotClient(channel, intents=discord.Intents.none())
    interaction = _make_interaction(user_id=123, channel_id=456)

    command = client.tree.get_command("new")
    assert command is not None
    await command.callback(interaction)

    rejection = {"content": "You are not allowed to use this bot.", "ephemeral": True}
    assert interaction.response.messages == [rejection]
    assert not forwarded
@pytest.mark.parametrize("slash_name", ["stop", "restart", "status"])
@pytest.mark.asyncio
async def test_slash_commands_forward_via_handle_message(slash_name: str) -> None:
    """Each forwarding slash command is acknowledged and handed to _handle_message."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    forwarded: list[dict] = []

    async def _record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = _record  # type: ignore[method-assign]
    client = DiscordBotClient(channel, intents=discord.Intents.none())
    interaction = _make_interaction()
    interaction.command.qualified_name = slash_name

    command = client.tree.get_command(slash_name)
    assert command is not None
    await command.callback(interaction)

    ack = {"content": f"Processing /{slash_name}...", "ephemeral": True}
    assert interaction.response.messages == [ack]
    assert len(forwarded) == 1
    assert forwarded[0]["content"] == f"/{slash_name}"
    assert forwarded[0]["metadata"]["is_slash_command"] is True
@pytest.mark.asyncio
async def test_slash_help_returns_ephemeral_help_text() -> None:
    """/help answers directly with the built-in help text and is not forwarded."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    forwarded: list[dict] = []

    async def _record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = _record  # type: ignore[method-assign]
    client = DiscordBotClient(channel, intents=discord.Intents.none())
    interaction = _make_interaction()
    interaction.command.qualified_name = "help"

    command = client.tree.get_command("help")
    assert command is not None
    await command.callback(interaction)

    assert interaction.response.messages == [{"content": build_help_text(), "ephemeral": True}]
    assert not forwarded
@pytest.mark.asyncio
async def test_client_send_outbound_chunks_text_replies_and_uploads_files(tmp_path) -> None:
    """Outbound sends upload files first (with the reply reference), then chunk long text."""
    owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    client = DiscordBotClient(owner, intents=discord.Intents.none())
    target = _FakeChannel(channel_id=123)

    def _lookup(channel_id):
        return target if channel_id == 123 else None

    client.get_channel = _lookup  # type: ignore[method-assign]
    upload = tmp_path / "demo.txt"
    upload.write_text("hi")

    message = OutboundMessage(
        channel="discord",
        chat_id="123",
        content="a" * 2100,
        reply_to="55",
        media=[str(upload)],
    )
    await client.send_outbound(message)

    assert len(target.sent_payloads) == 3
    file_payload, first_chunk, second_chunk = target.sent_payloads
    assert file_payload["file_name"] == "demo.txt"
    assert file_payload["reference"].id == 55
    assert first_chunk["content"] == "a" * 2000
    assert second_chunk["content"] == "a" * 100
@pytest.mark.asyncio
async def test_client_send_outbound_reports_failed_attachments_when_no_text(tmp_path) -> None:
    """When every attachment fails and there is no text, a failure placeholder is sent."""
    owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    client = DiscordBotClient(owner, intents=discord.Intents.none())
    target = _FakeChannel(channel_id=123)

    def _lookup(channel_id):
        return target if channel_id == 123 else None

    client.get_channel = _lookup  # type: ignore[method-assign]
    # Deliberately never created on disk so the upload fails.
    nonexistent = tmp_path / "missing.txt"

    message = OutboundMessage(
        channel="discord",
        chat_id="123",
        content="",
        media=[str(nonexistent)],
    )
    await client.send_outbound(message)

    assert target.sent_payloads == [{"content": "[attachment: missing.txt - send failed]"}]
async def _start_blocked_typing(channel: DiscordChannel, channel_id: int = 123) -> asyncio.Event:
    """Start typing on a fake channel whose typing context blocks until released.

    Returns only after the typing context has been entered, so callers can make
    deterministic assertions about ``channel._typing_tasks``. The caller must set
    the returned event to unblock the context.
    """
    entered = asyncio.Event()
    release = asyncio.Event()

    async def _block_until_released() -> None:
        entered.set()
        await release.wait()

    typing_channel = _FakeChannel(channel_id=channel_id)
    typing_channel.typing_enter_hook = _block_until_released
    await channel._start_typing(typing_channel)
    await asyncio.wait_for(entered.wait(), timeout=1.0)
    return release


@pytest.mark.asyncio
async def test_send_stops_typing_after_send() -> None:
    """A final (non-progress) send cancels and clears the active typing indicator."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    channel._client = _FakeDiscordClient(channel, intents=None)
    channel._running = True

    release = await _start_blocked_typing(channel)
    try:
        await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
    finally:
        # Always unblock the typing context so a failing send can't leave a hung task.
        release.set()
    await asyncio.sleep(0)

    assert channel._typing_tasks == {}


@pytest.mark.asyncio
async def test_send_keeps_typing_for_progress_until_final_send() -> None:
    """Progress messages keep typing active; only the following final send clears it."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    channel._client = _FakeDiscordClient(channel, intents=None)
    channel._running = True

    release = await _start_blocked_typing(channel)
    try:
        await channel.send(
            OutboundMessage(
                channel="discord",
                chat_id="123",
                content="progress",
                metadata={"_progress": True},
            )
        )
        assert "123" in channel._typing_tasks
        await channel.send(OutboundMessage(channel="discord", chat_id="123", content="final"))
    finally:
        # Unblock even if an assertion above fails, so no typing task is left waiting forever.
        release.set()
    await asyncio.sleep(0)

    assert channel._typing_tasks == {}
@pytest.mark.asyncio
async def test_start_typing_uses_typing_context_when_trigger_typing_missing() -> None:
    """Channels without trigger_typing() still get a typing loop via their ``typing()`` context."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    channel._running = True
    entered = asyncio.Event()
    release = asyncio.Event()

    class _BlockingTypingContext:
        # Signals entry, then blocks until released, so the typing task stays alive
        # until the test stops it explicitly.
        async def __aenter__(self):
            entered.set()
            await release.wait()

        async def __aexit__(self, exc_type, exc, tb):
            return False

    class _NoTriggerChannel:
        # Exposes only ``id`` and ``typing()``; deliberately has no ``trigger_typing``.
        def __init__(self, channel_id: int = 123) -> None:
            self.id = channel_id

        def typing(self):
            return _BlockingTypingContext()

    await channel._start_typing(_NoTriggerChannel(channel_id=123))  # type: ignore[arg-type]
    await asyncio.wait_for(entered.wait(), timeout=1.0)
    assert "123" in channel._typing_tasks

    await channel._stop_typing("123")
    release.set()
    await asyncio.sleep(0)
    assert channel._typing_tasks == {}