diff --git a/nanobot/channels/discord.py b/nanobot/channels/discord.py index 82eafcc00..ef7d41d77 100644 --- a/nanobot/channels/discord.py +++ b/nanobot/channels/discord.py @@ -1,25 +1,37 @@ -"""Discord channel implementation using Discord Gateway websocket.""" +"""Discord channel implementation using discord.py.""" + +from __future__ import annotations import asyncio -import json +import importlib.util from pathlib import Path -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal -import httpx -from pydantic import Field -import websockets from loguru import logger +from pydantic import Field from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel +from nanobot.command.builtin import build_help_text from nanobot.config.paths import get_media_dir from nanobot.config.schema import Base -from nanobot.utils.helpers import split_message +from nanobot.utils.helpers import safe_filename, split_message + +DISCORD_AVAILABLE = importlib.util.find_spec("discord") is not None +if TYPE_CHECKING: + import discord + from discord import app_commands + from discord.abc import Messageable + +if DISCORD_AVAILABLE: + import discord + from discord import app_commands + from discord.abc import Messageable -DISCORD_API_BASE = "https://discord.com/api/v10" MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB MAX_MESSAGE_LEN = 2000 # Discord message character limit +TYPING_INTERVAL_S = 8 class DiscordConfig(Base): @@ -28,13 +40,202 @@ class DiscordConfig(Base): enabled: bool = False token: str = "" allow_from: list[str] = Field(default_factory=list) - gateway_url: str = "wss://gateway.discord.gg/?v=10&encoding=json" intents: int = 37377 group_policy: Literal["mention", "open"] = "mention" +if DISCORD_AVAILABLE: + + class DiscordBotClient(discord.Client): + """discord.py client that forwards events to the channel.""" + + def __init__(self, channel: DiscordChannel, *, intents: discord.Intents) -> None: + super().__init__(intents=intents) + self._channel = channel + self.tree = app_commands.CommandTree(self) + self._register_app_commands() + + async def on_ready(self) -> None: + self._channel._bot_user_id = str(self.user.id) if self.user else None + logger.info("Discord bot connected as user {}", self._channel._bot_user_id) + try: + synced = await self.tree.sync() + logger.info("Discord app commands synced: {}", len(synced)) + except Exception as e: + logger.warning("Discord app command sync failed: {}", e) + + async def on_message(self, message: discord.Message) -> None: + await self._channel._handle_discord_message(message) + + async def _reply_ephemeral(self, interaction: discord.Interaction, text: str) -> bool: + """Send an ephemeral interaction response and report success.""" + try: + await interaction.response.send_message(text, ephemeral=True) + return True + except Exception as e: + logger.warning("Discord interaction response failed: {}", e) + return False + + async def _forward_slash_command( + self, + interaction: discord.Interaction, + command_text: str, + ) -> None: + sender_id = str(interaction.user.id) + channel_id = interaction.channel_id + + if channel_id is None: + logger.warning("Discord slash command missing channel_id: {}", command_text) + return + + if not self._channel.is_allowed(sender_id): + await self._reply_ephemeral(interaction, "You are not allowed to use this bot.") + return + + await self._reply_ephemeral(interaction, f"Processing {command_text}...") + + await self._channel._handle_message( + sender_id=sender_id, + chat_id=str(channel_id), + content=command_text, + metadata={ + "interaction_id": str(interaction.id), + "guild_id": str(interaction.guild_id) if interaction.guild_id else None, + "is_slash_command": True, + }, + ) + + def _register_app_commands(self) -> None: + commands = ( + ("new", "Start a new conversation", "/new"), + ("stop", "Stop the current task", "/stop"), + ("restart", "Restart the bot", "/restart"), + ("status", "Show bot status", "/status"), + ) + + for name, description, command_text in commands: + @self.tree.command(name=name, description=description) + async def command_handler( + interaction: discord.Interaction, + _command_text: str = command_text, + ) -> None: + await self._forward_slash_command(interaction, _command_text) + + @self.tree.command(name="help", description="Show available commands") + async def help_command(interaction: discord.Interaction) -> None: + sender_id = str(interaction.user.id) + if not self._channel.is_allowed(sender_id): + await self._reply_ephemeral(interaction, "You are not allowed to use this bot.") + return + await self._reply_ephemeral(interaction, build_help_text()) + + @self.tree.error + async def on_app_command_error( + interaction: discord.Interaction, + error: app_commands.AppCommandError, + ) -> None: + command_name = interaction.command.qualified_name if interaction.command else "?" + logger.warning( + "Discord app command failed user={} channel={} cmd={} error={}", + interaction.user.id, + interaction.channel_id, + command_name, + error, + ) + + async def send_outbound(self, msg: OutboundMessage) -> None: + """Send a nanobot outbound message using Discord transport rules.""" + channel_id = int(msg.chat_id) + + channel = self.get_channel(channel_id) + if channel is None: + try: + channel = await self.fetch_channel(channel_id) + except Exception as e: + logger.warning("Discord channel {} unavailable: {}", msg.chat_id, e) + return + + reference, mention_settings = self._build_reply_context(channel, msg.reply_to) + sent_media = False + failed_media: list[str] = [] + + for index, media_path in enumerate(msg.media or []): + if await self._send_file( + channel, + media_path, + reference=reference if index == 0 else None, + mention_settings=mention_settings, + ): + sent_media = True + else: + failed_media.append(Path(media_path).name) + + for index, chunk in enumerate(self._build_chunks(msg.content or "", failed_media, sent_media)): + kwargs: dict[str, Any] = {"content": chunk} + if index == 0 and reference is not None and not sent_media: + kwargs["reference"] = reference + kwargs["allowed_mentions"] = mention_settings + await channel.send(**kwargs) + + async def _send_file( + self, + channel: Messageable, + file_path: str, + *, + reference: discord.PartialMessage | None, + mention_settings: discord.AllowedMentions, + ) -> bool: + """Send a file attachment via discord.py.""" + path = Path(file_path) + if not path.is_file(): + logger.warning("Discord file not found, skipping: {}", file_path) + return False + + if path.stat().st_size > MAX_ATTACHMENT_BYTES: + logger.warning("Discord file too large (>20MB), skipping: {}", path.name) + return False + + try: + kwargs: dict[str, Any] = {"file": discord.File(path)} + if reference is not None: + kwargs["reference"] = reference + kwargs["allowed_mentions"] = mention_settings + await channel.send(**kwargs) + logger.info("Discord file sent: {}", path.name) + return True + except Exception as e: + logger.error("Error sending Discord file {}: {}", path.name, e) + return False + + @staticmethod + def _build_chunks(content: str, failed_media: list[str], sent_media: bool) -> list[str]: + """Build outbound text chunks, including attachment-failure fallback text.""" + chunks = split_message(content, MAX_MESSAGE_LEN) + if chunks or not failed_media or sent_media: + return chunks + fallback = "\n".join(f"[attachment: {name} - send failed]" for name in failed_media) + return split_message(fallback, MAX_MESSAGE_LEN) + + @staticmethod + def _build_reply_context( + channel: Messageable, + reply_to: str | None, + ) -> tuple[discord.PartialMessage | None, discord.AllowedMentions]: + """Build reply context for outbound messages.""" + mention_settings = discord.AllowedMentions(replied_user=False) + if not reply_to: + return None, mention_settings + try: + message_id = int(reply_to) + except ValueError: + logger.warning("Invalid Discord reply target: {}", reply_to) + return None, mention_settings + + return channel.get_partial_message(message_id), mention_settings + + class DiscordChannel(BaseChannel): - """Discord channel using Gateway websocket.""" + """Discord channel using discord.py.""" name = "discord" display_name = "Discord" @@ -43,353 +244,229 @@ class DiscordChannel(BaseChannel): def default_config(cls) -> dict[str, Any]: return DiscordConfig().model_dump(by_alias=True) + @staticmethod + def _channel_key(channel_or_id: Any) -> str: + """Normalize channel-like objects and ids to a stable string key.""" + channel_id = getattr(channel_or_id, "id", channel_or_id) + return str(channel_id) + def __init__(self, config: Any, bus: MessageBus): if isinstance(config, dict): config = DiscordConfig.model_validate(config) super().__init__(config, bus) self.config: DiscordConfig = config - self._ws: websockets.WebSocketClientProtocol | None = None - self._seq: int | None = None - self._heartbeat_task: asyncio.Task | None = None - self._typing_tasks: dict[str, asyncio.Task] = {} - self._http: httpx.AsyncClient | None = None + self._client: DiscordBotClient | None = None + self._typing_tasks: dict[str, asyncio.Task[None]] = {} self._bot_user_id: str | None = None async def start(self) -> None: - """Start the Discord gateway connection.""" + """Start the Discord client.""" + if not DISCORD_AVAILABLE: + logger.error("discord.py not installed. Run: pip install nanobot-ai[discord]") + return + if not self.config.token: logger.error("Discord bot token not configured") return - self._running = True - self._http = httpx.AsyncClient(timeout=30.0) + try: + intents = discord.Intents.none() + intents.value = self.config.intents + self._client = DiscordBotClient(self, intents=intents) + except Exception as e: + logger.error("Failed to initialize Discord client: {}", e) + self._client = None + self._running = False + return - while self._running: - try: - logger.info("Connecting to Discord gateway...") - async with websockets.connect(self.config.gateway_url) as ws: - self._ws = ws - await self._gateway_loop() - except asyncio.CancelledError: - break - except Exception as e: - logger.warning("Discord gateway error: {}", e) - if self._running: - logger.info("Reconnecting to Discord gateway in 5 seconds...") - await asyncio.sleep(5) + self._running = True + logger.info("Starting Discord client via discord.py...") + + try: + await self._client.start(self.config.token) + except asyncio.CancelledError: + raise + except Exception as e: + logger.error("Discord client startup failed: {}", e) + finally: + self._running = False + await self._reset_runtime_state(close_client=True) async def stop(self) -> None: """Stop the Discord channel.""" self._running = False - if self._heartbeat_task: - self._heartbeat_task.cancel() - self._heartbeat_task = None - for task in self._typing_tasks.values(): - task.cancel() - self._typing_tasks.clear() - if self._ws: - await self._ws.close() - self._ws = None - if self._http: - await self._http.aclose() - self._http = None + await self._reset_runtime_state(close_client=True) async def send(self, msg: OutboundMessage) -> None: - """Send a message through Discord REST API, including file attachments.""" - if not self._http: - logger.warning("Discord HTTP client not initialized") + """Send a message through Discord using discord.py.""" + client = self._client + if client is None or not client.is_ready(): + logger.warning("Discord client not ready; dropping outbound message") return - url = f"{DISCORD_API_BASE}/channels/{msg.chat_id}/messages" - headers = {"Authorization": f"Bot {self.config.token}"} + is_progress = bool((msg.metadata or {}).get("_progress")) + try: + await client.send_outbound(msg) + except Exception as e: + logger.error("Error sending Discord message: {}", e) + finally: + if not is_progress: + await self._stop_typing(msg.chat_id) + + async def _handle_discord_message(self, message: discord.Message) -> None: + """Handle incoming Discord messages from discord.py.""" + if message.author.bot: + return + + sender_id = str(message.author.id) + channel_id = self._channel_key(message.channel) + content = message.content or "" + + if not self._should_accept_inbound(message, sender_id, content): + return + + media_paths, attachment_markers = await self._download_attachments(message.attachments) + full_content = self._compose_inbound_content(content, attachment_markers) + metadata = self._build_inbound_metadata(message) + + await self._start_typing(message.channel) try: - sent_media = False - failed_media: list[str] = [] + await self._handle_message( + sender_id=sender_id, + chat_id=channel_id, + content=full_content, + media=media_paths, + metadata=metadata, + ) + except Exception: + await self._stop_typing(channel_id) + raise - # Send file attachments first - for media_path in msg.media or []: - if await self._send_file(url, headers, media_path, reply_to=msg.reply_to): - sent_media = True - else: - failed_media.append(Path(media_path).name) + async def _on_message(self, message: discord.Message) -> None: + """Backward-compatible alias for legacy tests/callers.""" + await self._handle_discord_message(message) - # Send text content - chunks = split_message(msg.content or "", MAX_MESSAGE_LEN) - if not chunks and failed_media and not sent_media: - chunks = split_message( - "\n".join(f"[attachment: {name} - send failed]" for name in failed_media), - MAX_MESSAGE_LEN, - ) - if not chunks: - return - - for i, chunk in enumerate(chunks): - payload: dict[str, Any] = {"content": chunk} - - # Let the first successful attachment carry the reply if present. - if i == 0 and msg.reply_to and not sent_media: - payload["message_reference"] = {"message_id": msg.reply_to} - payload["allowed_mentions"] = {"replied_user": False} - - if not await self._send_payload(url, headers, payload): - break # Abort remaining chunks on failure - finally: - await self._stop_typing(msg.chat_id) - - async def _send_payload( - self, url: str, headers: dict[str, str], payload: dict[str, Any] - ) -> bool: - """Send a single Discord API payload with retry on rate-limit. Returns True on success.""" - for attempt in range(3): - try: - response = await self._http.post(url, headers=headers, json=payload) - if response.status_code == 429: - data = response.json() - retry_after = float(data.get("retry_after", 1.0)) - logger.warning("Discord rate limited, retrying in {}s", retry_after) - await asyncio.sleep(retry_after) - continue - response.raise_for_status() - return True - except Exception as e: - if attempt == 2: - logger.error("Error sending Discord message: {}", e) - else: - await asyncio.sleep(1) - return False - - async def _send_file( + def _should_accept_inbound( self, - url: str, - headers: dict[str, str], - file_path: str, - reply_to: str | None = None, + message: discord.Message, + sender_id: str, + content: str, ) -> bool: - """Send a file attachment via Discord REST API using multipart/form-data.""" - path = Path(file_path) - if not path.is_file(): - logger.warning("Discord file not found, skipping: {}", file_path) - return False - - if path.stat().st_size > MAX_ATTACHMENT_BYTES: - logger.warning("Discord file too large (>20MB), skipping: {}", path.name) - return False - - payload_json: dict[str, Any] = {} - if reply_to: - payload_json["message_reference"] = {"message_id": reply_to} - payload_json["allowed_mentions"] = {"replied_user": False} - - for attempt in range(3): - try: - with open(path, "rb") as f: - files = {"files[0]": (path.name, f, "application/octet-stream")} - data: dict[str, Any] = {} - if payload_json: - data["payload_json"] = json.dumps(payload_json) - response = await self._http.post( - url, headers=headers, files=files, data=data - ) - if response.status_code == 429: - resp_data = response.json() - retry_after = float(resp_data.get("retry_after", 1.0)) - logger.warning("Discord rate limited, retrying in {}s", retry_after) - await asyncio.sleep(retry_after) - continue - response.raise_for_status() - logger.info("Discord file sent: {}", path.name) - return True - except Exception as e: - if attempt == 2: - logger.error("Error sending Discord file {}: {}", path.name, e) - else: - await asyncio.sleep(1) - return False - - async def _gateway_loop(self) -> None: - """Main gateway loop: identify, heartbeat, dispatch events.""" - if not self._ws: - return - - async for raw in self._ws: - try: - data = json.loads(raw) - except json.JSONDecodeError: - logger.warning("Invalid JSON from Discord gateway: {}", raw[:100]) - continue - - op = data.get("op") - event_type = data.get("t") - seq = data.get("s") - payload = data.get("d") - - if seq is not None: - self._seq = seq - - if op == 10: - # HELLO: start heartbeat and identify - interval_ms = payload.get("heartbeat_interval", 45000) - await self._start_heartbeat(interval_ms / 1000) - await self._identify() - elif op == 0 and event_type == "READY": - logger.info("Discord gateway READY") - # Capture bot user ID for mention detection - user_data = payload.get("user") or {} - self._bot_user_id = user_data.get("id") - logger.info("Discord bot connected as user {}", self._bot_user_id) - elif op == 0 and event_type == "MESSAGE_CREATE": - await self._handle_message_create(payload) - elif op == 7: - # RECONNECT: exit loop to reconnect - logger.info("Discord gateway requested reconnect") - break - elif op == 9: - # INVALID_SESSION: reconnect - logger.warning("Discord gateway invalid session") - break - - async def _identify(self) -> None: - """Send IDENTIFY payload.""" - if not self._ws: - return - - identify = { - "op": 2, - "d": { - "token": self.config.token, - "intents": self.config.intents, - "properties": { - "os": "nanobot", - "browser": "nanobot", - "device": "nanobot", - }, - }, - } - await self._ws.send(json.dumps(identify)) - - async def _start_heartbeat(self, interval_s: float) -> None: - """Start or restart the heartbeat loop.""" - if self._heartbeat_task: - self._heartbeat_task.cancel() - - async def heartbeat_loop() -> None: - while self._running and self._ws: - payload = {"op": 1, "d": self._seq} - try: - await self._ws.send(json.dumps(payload)) - except Exception as e: - logger.warning("Discord heartbeat failed: {}", e) - break - await asyncio.sleep(interval_s) - - self._heartbeat_task = asyncio.create_task(heartbeat_loop()) - - async def _handle_message_create(self, payload: dict[str, Any]) -> None: - """Handle incoming Discord messages.""" - author = payload.get("author") or {} - if author.get("bot"): - return - - sender_id = str(author.get("id", "")) - channel_id = str(payload.get("channel_id", "")) - content = payload.get("content") or "" - guild_id = payload.get("guild_id") - - if not sender_id or not channel_id: - return - + """Check if inbound Discord message should be processed.""" if not self.is_allowed(sender_id): - return + return False + if message.guild is not None and not self._should_respond_in_group(message, content): + return False + return True - # Check group channel policy (DMs always respond if is_allowed passes) - if guild_id is not None: - if not self._should_respond_in_group(payload, content): - return - - content_parts = [content] if content else [] + async def _download_attachments( + self, + attachments: list[discord.Attachment], + ) -> tuple[list[str], list[str]]: + """Download supported attachments and return paths + display markers.""" media_paths: list[str] = [] + markers: list[str] = [] media_dir = get_media_dir("discord") - for attachment in payload.get("attachments") or []: - url = attachment.get("url") - filename = attachment.get("filename") or "attachment" - size = attachment.get("size") or 0 - if not url or not self._http: - continue - if size and size > MAX_ATTACHMENT_BYTES: - content_parts.append(f"[attachment: {filename} - too large]") + for attachment in attachments: + filename = attachment.filename or "attachment" + if attachment.size and attachment.size > MAX_ATTACHMENT_BYTES: + markers.append(f"[attachment: {filename} - too large]") continue try: media_dir.mkdir(parents=True, exist_ok=True) - file_path = media_dir / f"{attachment.get('id', 'file')}_{filename.replace('/', '_')}" - resp = await self._http.get(url) - resp.raise_for_status() - file_path.write_bytes(resp.content) + safe_name = safe_filename(filename) + file_path = media_dir / f"{attachment.id}_{safe_name}" + await attachment.save(file_path) media_paths.append(str(file_path)) - content_parts.append(f"[attachment: {file_path}]") + markers.append(f"[attachment: {file_path.name}]") except Exception as e: logger.warning("Failed to download Discord attachment: {}", e) - content_parts.append(f"[attachment: {filename} - download failed]") + markers.append(f"[attachment: {filename} - download failed]") - reply_to = (payload.get("referenced_message") or {}).get("id") + return media_paths, markers - await self._start_typing(channel_id) + @staticmethod + def _compose_inbound_content(content: str, attachment_markers: list[str]) -> str: + """Combine message text with attachment markers.""" + content_parts = [content] if content else [] + content_parts.extend(attachment_markers) + return "\n".join(part for part in content_parts if part) or "[empty message]" - await self._handle_message( - sender_id=sender_id, - chat_id=channel_id, - content="\n".join(p for p in content_parts if p) or "[empty message]", - media=media_paths, - metadata={ - "message_id": str(payload.get("id", "")), - "guild_id": guild_id, - "reply_to": reply_to, - }, - ) + @staticmethod + def _build_inbound_metadata(message: discord.Message) -> dict[str, str | None]: + """Build metadata for inbound Discord messages.""" + reply_to = str(message.reference.message_id) if message.reference and message.reference.message_id else None + return { + "message_id": str(message.id), + "guild_id": str(message.guild.id) if message.guild else None, + "reply_to": reply_to, + } - def _should_respond_in_group(self, payload: dict[str, Any], content: str) -> bool: - """Check if bot should respond in a group channel based on policy.""" + def _should_respond_in_group(self, message: discord.Message, content: str) -> bool: + """Check if the bot should respond in a guild channel based on policy.""" if self.config.group_policy == "open": return True if self.config.group_policy == "mention": - # Check if bot was mentioned in the message - if self._bot_user_id: - # Check mentions array - mentions = payload.get("mentions") or [] - for mention in mentions: - if str(mention.get("id")) == self._bot_user_id: - return True - # Also check content for mention format <@USER_ID> - if f"<@{self._bot_user_id}>" in content or f"<@!{self._bot_user_id}>" in content: - return True - logger.debug("Discord message in {} ignored (bot not mentioned)", payload.get("channel_id")) + bot_user_id = self._bot_user_id + if bot_user_id is None: + logger.debug("Discord message in {} ignored (bot identity unavailable)", message.channel.id) + return False + + if any(str(user.id) == bot_user_id for user in message.mentions): + return True + if f"<@{bot_user_id}>" in content or f"<@!{bot_user_id}>" in content: + return True + + logger.debug("Discord message in {} ignored (bot not mentioned)", message.channel.id) return False return True - async def _start_typing(self, channel_id: str) -> None: + async def _start_typing(self, channel: Messageable) -> None: """Start periodic typing indicator for a channel.""" + channel_id = self._channel_key(channel) await self._stop_typing(channel_id) async def typing_loop() -> None: - url = f"{DISCORD_API_BASE}/channels/{channel_id}/typing" - headers = {"Authorization": f"Bot {self.config.token}"} while self._running: try: - await self._http.post(url, headers=headers) + async with channel.typing(): + await asyncio.sleep(TYPING_INTERVAL_S) except asyncio.CancelledError: return except Exception as e: logger.debug("Discord typing indicator failed for {}: {}", channel_id, e) return - await asyncio.sleep(8) self._typing_tasks[channel_id] = asyncio.create_task(typing_loop()) async def _stop_typing(self, channel_id: str) -> None: """Stop typing indicator for a channel.""" - task = self._typing_tasks.pop(channel_id, None) - if task: - task.cancel() + task = self._typing_tasks.pop(self._channel_key(channel_id), None) + if task is None: + return + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + async def _cancel_all_typing(self) -> None: + """Stop all typing tasks.""" + channel_ids = list(self._typing_tasks) + for channel_id in channel_ids: + await self._stop_typing(channel_id) + + async def _reset_runtime_state(self, close_client: bool) -> None: + """Reset client and typing state.""" + await self._cancel_all_typing() + if close_client and self._client is not None and not self._client.is_closed(): + try: + await self._client.close() + except Exception as e: + logger.warning("Discord client close failed: {}", e) + self._client = None + self._bot_user_id = None diff --git a/nanobot/command/builtin.py b/nanobot/command/builtin.py index 0a9af3cb9..643397057 100644 --- a/nanobot/command/builtin.py +++ b/nanobot/command/builtin.py @@ -84,6 +84,16 @@ async def cmd_new(ctx: CommandContext) -> OutboundMessage: async def cmd_help(ctx: CommandContext) -> OutboundMessage: """Return available slash commands.""" + return OutboundMessage( + channel=ctx.msg.channel, + chat_id=ctx.msg.chat_id, + content=build_help_text(), + metadata={"render_as": "text"}, + ) + + +def build_help_text() -> str: + """Build canonical help text shared across channels.""" lines = [ "🐈 nanobot commands:", "/new — Start a new conversation", @@ -92,12 +102,7 @@ async def cmd_help(ctx: CommandContext) -> OutboundMessage: "/status — Show bot status", "/help — Show available commands", ] - return OutboundMessage( - channel=ctx.msg.channel, - chat_id=ctx.msg.chat_id, - content="\n".join(lines), - metadata={"render_as": "text"}, - ) + return "\n".join(lines) def register_builtin_commands(router: CommandRouter) -> None: diff --git a/pyproject.toml b/pyproject.toml index d2952b039..4fbfefadf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,9 @@ matrix = [ "mistune>=3.0.0,<4.0.0", "nh3>=0.2.17,<1.0.0", ] +discord = [ + "discord.py>=2.5.2,<3.0.0", +] langsmith = [ "langsmith>=0.1.0", ] diff --git a/tests/channels/test_discord_channel.py b/tests/channels/test_discord_channel.py new file mode 100644 index 000000000..3f1f996fc --- /dev/null +++ b/tests/channels/test_discord_channel.py @@ -0,0 +1,676 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path +from types import SimpleNamespace + +import discord +import pytest + +from nanobot.bus.events import OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.discord import DiscordBotClient, DiscordChannel, DiscordConfig +from nanobot.command.builtin import build_help_text + + +# Minimal Discord client test double used to control startup/readiness behavior. +class _FakeDiscordClient: + instances: list["_FakeDiscordClient"] = [] + start_error: Exception | None = None + + def __init__(self, owner, *, intents) -> None: + self.owner = owner + self.intents = intents + self.closed = False + self.ready = True + self.channels: dict[int, object] = {} + self.user = SimpleNamespace(id=999) + self.__class__.instances.append(self) + + async def start(self, token: str) -> None: + self.token = token + if self.__class__.start_error is not None: + raise self.__class__.start_error + + async def close(self) -> None: + self.closed = True + + def is_closed(self) -> bool: + return self.closed + + def is_ready(self) -> bool: + return self.ready + + def get_channel(self, channel_id: int): + return self.channels.get(channel_id) + + async def send_outbound(self, msg: OutboundMessage) -> None: + channel = self.get_channel(int(msg.chat_id)) + if channel is None: + return + await channel.send(content=msg.content) + + +class _FakeAttachment: + # Attachment double that can simulate successful or failing save() calls. + def __init__(self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False) -> None: + self.id = attachment_id + self.filename = filename + self.size = size + self._fail = fail + + async def save(self, path: str | Path) -> None: + if self._fail: + raise RuntimeError("save failed") + Path(path).write_bytes(b"attachment") + + +class _FakePartialMessage: + # Lightweight stand-in for Discord partial message references used in replies. + def __init__(self, message_id: int) -> None: + self.id = message_id + + +class _FakeChannel: + # Channel double that records outbound payloads and typing activity. + def __init__(self, channel_id: int = 123) -> None: + self.id = channel_id + self.sent_payloads: list[dict] = [] + self.trigger_typing_calls = 0 + self.typing_enter_hook = None + + async def send(self, **kwargs) -> None: + payload = dict(kwargs) + if "file" in payload: + payload["file_name"] = payload["file"].filename + del payload["file"] + self.sent_payloads.append(payload) + + def get_partial_message(self, message_id: int) -> _FakePartialMessage: + return _FakePartialMessage(message_id) + + def typing(self): + channel = self + + class _TypingContext: + async def __aenter__(self): + channel.trigger_typing_calls += 1 + if channel.typing_enter_hook is not None: + await channel.typing_enter_hook() + + async def __aexit__(self, exc_type, exc, tb): + return False + + return _TypingContext() + + +class _FakeInteractionResponse: + def __init__(self) -> None: + self.messages: list[dict] = [] + self._done = False + + async def send_message(self, content: str, *, ephemeral: bool = False) -> None: + self.messages.append({"content": content, "ephemeral": ephemeral}) + self._done = True + + def is_done(self) -> bool: + return self._done + + +def _make_interaction( + *, + user_id: int = 123, + channel_id: int | None = 456, + guild_id: int | None = None, + interaction_id: int = 999, +): + return SimpleNamespace( + user=SimpleNamespace(id=user_id), + channel_id=channel_id, + guild_id=guild_id, + id=interaction_id, + command=SimpleNamespace(qualified_name="new"), + response=_FakeInteractionResponse(), + ) + + +def _make_message( + *, + author_id: int = 123, + author_bot: bool = False, + channel_id: int = 456, + message_id: int = 789, + content: str = "hello", + guild_id: int | None = None, + mentions: list[object] | None = None, + attachments: list[object] | None = None, + reply_to: int | None = None, +): + # Factory for incoming Discord message objects with optional guild/reply/attachments. + guild = SimpleNamespace(id=guild_id) if guild_id is not None else None + reference = SimpleNamespace(message_id=reply_to) if reply_to is not None else None + return SimpleNamespace( + author=SimpleNamespace(id=author_id, bot=author_bot), + channel=_FakeChannel(channel_id), + content=content, + guild=guild, + mentions=mentions or [], + attachments=attachments or [], + reference=reference, + id=message_id, + ) + + +@pytest.mark.asyncio +async def test_start_returns_when_token_missing() -> None: + # If no token is configured, startup should no-op and leave channel stopped. + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + + await channel.start() + + assert channel.is_running is False + assert channel._client is None + + +@pytest.mark.asyncio +async def test_start_returns_when_discord_dependency_missing(monkeypatch) -> None: + channel = DiscordChannel( + DiscordConfig(enabled=True, token="token", allow_from=["*"]), + MessageBus(), + ) + monkeypatch.setattr("nanobot.channels.discord.DISCORD_AVAILABLE", False) + + await channel.start() + + assert channel.is_running is False + assert channel._client is None + + +@pytest.mark.asyncio +async def test_start_handles_client_construction_failure(monkeypatch) -> None: + # Construction errors from the Discord client should be swallowed and keep state clean. + channel = DiscordChannel( + DiscordConfig(enabled=True, token="token", allow_from=["*"]), + MessageBus(), + ) + + def _boom(owner, *, intents): + raise RuntimeError("bad client") + + monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _boom) + + await channel.start() + + assert channel.is_running is False + assert channel._client is None + + +@pytest.mark.asyncio +async def test_start_handles_client_start_failure(monkeypatch) -> None: + # If client.start fails, the partially created client should be closed and detached. + channel = DiscordChannel( + DiscordConfig(enabled=True, token="token", allow_from=["*"]), + MessageBus(), + ) + + _FakeDiscordClient.instances.clear() + _FakeDiscordClient.start_error = RuntimeError("connect failed") + monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient) + + await channel.start() + + assert channel.is_running is False + assert channel._client is None + assert _FakeDiscordClient.instances[0].intents.value == channel.config.intents + assert _FakeDiscordClient.instances[0].closed is True + + _FakeDiscordClient.start_error = None + + +@pytest.mark.asyncio +async def test_stop_is_safe_after_partial_start(monkeypatch) -> None: + # stop() should close/discard the client even when startup was only partially completed. + channel = DiscordChannel( + DiscordConfig(enabled=True, token="token", allow_from=["*"]), + MessageBus(), + ) + client = _FakeDiscordClient(channel, intents=None) + channel._client = client + channel._running = True + + await channel.stop() + + assert channel.is_running is False + assert client.closed is True + assert channel._client is None + + +@pytest.mark.asyncio +async def test_on_message_ignores_bot_messages() -> None: + # Incoming bot-authored messages must be ignored to prevent feedback loops. + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + handled: list[dict] = [] + channel._handle_message = lambda **kwargs: handled.append(kwargs) # type: ignore[method-assign] + + await channel._on_message(_make_message(author_bot=True)) + + assert handled == [] + + # If inbound handling raises, typing should be stopped for that channel. + async def fail_handle(**kwargs) -> None: + raise RuntimeError("boom") + + channel._handle_message = fail_handle # type: ignore[method-assign] + + with pytest.raises(RuntimeError, match="boom"): + await channel._on_message(_make_message(author_id=123, channel_id=456)) + + assert channel._typing_tasks == {} + + +@pytest.mark.asyncio +async def test_on_message_accepts_allowlisted_dm() -> None: + # Allowed direct messages should be forwarded with normalized metadata. + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus()) + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + + await channel._on_message(_make_message(author_id=123, channel_id=456, message_id=789)) + + assert len(handled) == 1 + assert handled[0]["chat_id"] == "456" + assert handled[0]["metadata"] == {"message_id": "789", "guild_id": None, "reply_to": None} + + +@pytest.mark.asyncio +async def test_on_message_ignores_unmentioned_guild_message() -> None: + # With mention-only group policy, guild messages without a bot mention are dropped. + channel = DiscordChannel( + DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention"), + MessageBus(), + ) + channel._bot_user_id = "999" + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + + await channel._on_message(_make_message(guild_id=1, content="hello everyone")) + + assert handled == [] + + +@pytest.mark.asyncio +async def test_on_message_accepts_mentioned_guild_message() -> None: + # Mentioned guild messages should be accepted and preserve reply threading metadata. + channel = DiscordChannel( + DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention"), + MessageBus(), + ) + channel._bot_user_id = "999" + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + + await channel._on_message( + _make_message( + guild_id=1, + content="<@999> hello", + mentions=[SimpleNamespace(id=999)], + reply_to=321, + ) + ) + + assert len(handled) == 1 + assert handled[0]["metadata"]["reply_to"] == "321" + + +@pytest.mark.asyncio +async def test_on_message_downloads_attachments(tmp_path, monkeypatch) -> None: + # Attachment downloads should be saved and referenced in forwarded content/media. + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path) + + await channel._on_message( + _make_message( + attachments=[_FakeAttachment(12, "photo.png")], + content="see file", + ) + ) + + assert len(handled) == 1 + assert handled[0]["media"] == [str(tmp_path / "12_photo.png")] + assert "[attachment:" in handled[0]["content"] + + +@pytest.mark.asyncio +async def test_on_message_marks_failed_attachment_download(tmp_path, monkeypatch) -> None: + # Failed attachment downloads should emit a readable placeholder and no media path. + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path) + + await channel._on_message( + _make_message( + attachments=[_FakeAttachment(12, "photo.png", fail=True)], + content="", + ) + ) + + assert len(handled) == 1 + assert handled[0]["media"] == [] + assert handled[0]["content"] == "[attachment: photo.png - download failed]" + + +@pytest.mark.asyncio +async def test_send_warns_when_client_not_ready() -> None: + # Sending without a running/ready client should be a safe no-op. + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + + await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello")) + + assert channel._typing_tasks == {} + + +@pytest.mark.asyncio +async def test_send_skips_when_channel_not_cached() -> None: + # Outbound sends should be skipped when the destination channel is not resolvable. + owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = DiscordBotClient(owner, intents=discord.Intents.none()) + fetch_calls: list[int] = [] + + async def fetch_channel(channel_id: int): + fetch_calls.append(channel_id) + raise RuntimeError("not found") + + client.fetch_channel = fetch_channel # type: ignore[method-assign] + + await client.send_outbound(OutboundMessage(channel="discord", chat_id="123", content="hello")) + + assert client.get_channel(123) is None + assert fetch_calls == [123] + + +@pytest.mark.asyncio +async def test_send_fetches_channel_when_not_cached() -> None: + owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = DiscordBotClient(owner, intents=discord.Intents.none()) + target = _FakeChannel(channel_id=123) + + async def fetch_channel(channel_id: int): + return target if channel_id == 123 else None + + client.fetch_channel = fetch_channel # type: ignore[method-assign] + + await client.send_outbound(OutboundMessage(channel="discord", chat_id="123", content="hello")) + + assert target.sent_payloads == [{"content": "hello"}] + + +@pytest.mark.asyncio +async def test_slash_new_forwards_when_user_is_allowlisted() -> None: + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus()) + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + client = DiscordBotClient(channel, intents=discord.Intents.none()) + interaction = _make_interaction(user_id=123, channel_id=456, interaction_id=321) + + new_cmd = client.tree.get_command("new") + assert new_cmd is not None + await new_cmd.callback(interaction) + + assert interaction.response.messages == [ + {"content": "Processing /new...", "ephemeral": True} + ] + assert len(handled) == 1 + assert handled[0]["content"] == "/new" + assert handled[0]["sender_id"] == "123" + assert handled[0]["chat_id"] == "456" + assert handled[0]["metadata"]["interaction_id"] == "321" + assert handled[0]["metadata"]["is_slash_command"] is True + + +@pytest.mark.asyncio +async def test_slash_new_is_blocked_for_disallowed_user() -> None: + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["999"]), MessageBus()) + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + client = DiscordBotClient(channel, intents=discord.Intents.none()) + interaction = _make_interaction(user_id=123, channel_id=456) + + new_cmd = client.tree.get_command("new") + assert new_cmd is not None + await new_cmd.callback(interaction) + + assert interaction.response.messages == [ + {"content": "You are not allowed to use this bot.", "ephemeral": True} + ] + assert handled == [] + + +@pytest.mark.parametrize("slash_name", ["stop", "restart", "status"]) +@pytest.mark.asyncio +async def test_slash_commands_forward_via_handle_message(slash_name: str) -> None: + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + client = DiscordBotClient(channel, intents=discord.Intents.none()) + interaction = _make_interaction() + interaction.command.qualified_name = slash_name + + cmd = client.tree.get_command(slash_name) + assert cmd is not None + await cmd.callback(interaction) + + assert interaction.response.messages == [ + {"content": f"Processing /{slash_name}...", "ephemeral": True} + ] + assert len(handled) == 1 + assert handled[0]["content"] == f"/{slash_name}" + assert handled[0]["metadata"]["is_slash_command"] is True + + +@pytest.mark.asyncio +async def test_slash_help_returns_ephemeral_help_text() -> None: + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + client = DiscordBotClient(channel, intents=discord.Intents.none()) + interaction = _make_interaction() + interaction.command.qualified_name = "help" + + help_cmd = client.tree.get_command("help") + assert help_cmd is not None + await help_cmd.callback(interaction) + + assert interaction.response.messages == [ + {"content": build_help_text(), "ephemeral": True} + ] + assert handled == [] + + +@pytest.mark.asyncio +async def test_client_send_outbound_chunks_text_replies_and_uploads_files(tmp_path) -> None: + # Outbound payloads should upload files, attach reply references, and chunk long text. + owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = DiscordBotClient(owner, intents=discord.Intents.none()) + target = _FakeChannel(channel_id=123) + client.get_channel = lambda channel_id: target if channel_id == 123 else None # type: ignore[method-assign] + + file_path = tmp_path / "demo.txt" + file_path.write_text("hi") + + await client.send_outbound( + OutboundMessage( + channel="discord", + chat_id="123", + content="a" * 2100, + reply_to="55", + media=[str(file_path)], + ) + ) + + assert len(target.sent_payloads) == 3 + assert target.sent_payloads[0]["file_name"] == "demo.txt" + assert target.sent_payloads[0]["reference"].id == 55 + assert target.sent_payloads[1]["content"] == "a" * 2000 + assert target.sent_payloads[2]["content"] == "a" * 100 + + +@pytest.mark.asyncio +async def test_client_send_outbound_reports_failed_attachments_when_no_text(tmp_path) -> None: + # If all attachment sends fail and no text exists, emit a failure placeholder message. + owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = DiscordBotClient(owner, intents=discord.Intents.none()) + target = _FakeChannel(channel_id=123) + client.get_channel = lambda channel_id: target if channel_id == 123 else None # type: ignore[method-assign] + + missing_file = tmp_path / "missing.txt" + + await client.send_outbound( + OutboundMessage( + channel="discord", + chat_id="123", + content="", + media=[str(missing_file)], + ) + ) + + assert target.sent_payloads == [{"content": "[attachment: missing.txt - send failed]"}] + + +@pytest.mark.asyncio +async def test_send_stops_typing_after_send() -> None: + # Active typing indicators should be cancelled/cleared after a successful send. + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = _FakeDiscordClient(channel, intents=None) + channel._client = client + channel._running = True + + start = asyncio.Event() + release = asyncio.Event() + + async def slow_typing() -> None: + start.set() + await release.wait() + + typing_channel = _FakeChannel(channel_id=123) + typing_channel.typing_enter_hook = slow_typing + + await channel._start_typing(typing_channel) + await start.wait() + + await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello")) + release.set() + await asyncio.sleep(0) + + assert channel._typing_tasks == {} + + # Progress messages should keep typing active until a final (non-progress) send. + start = asyncio.Event() + release = asyncio.Event() + + async def slow_typing_progress() -> None: + start.set() + await release.wait() + + typing_channel = _FakeChannel(channel_id=123) + typing_channel.typing_enter_hook = slow_typing_progress + + await channel._start_typing(typing_channel) + await start.wait() + + await channel.send( + OutboundMessage( + channel="discord", + chat_id="123", + content="progress", + metadata={"_progress": True}, + ) + ) + + assert "123" in channel._typing_tasks + + await channel.send(OutboundMessage(channel="discord", chat_id="123", content="final")) + release.set() + await asyncio.sleep(0) + + assert channel._typing_tasks == {} + + +@pytest.mark.asyncio +async def test_start_typing_uses_typing_context_when_trigger_typing_missing() -> None: + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + channel._running = True + + entered = asyncio.Event() + release = asyncio.Event() + + class _TypingCtx: + async def __aenter__(self): + entered.set() + + async def __aexit__(self, exc_type, exc, tb): + return False + + class _NoTriggerChannel: + def __init__(self, channel_id: int = 123) -> None: + self.id = channel_id + + def typing(self): + async def _waiter(): + await release.wait() + # Hold the loop so task remains active until explicitly stopped. + class _Ctx(_TypingCtx): + async def __aenter__(self): + await super().__aenter__() + await _waiter() + return _Ctx() + + typing_channel = _NoTriggerChannel(channel_id=123) + await channel._start_typing(typing_channel) # type: ignore[arg-type] + await entered.wait() + + assert "123" in channel._typing_tasks + + await channel._stop_typing("123") + release.set() + await asyncio.sleep(0) + + assert channel._typing_tasks == {}