feat(discord): Use discord.py for stable discord channel (#2486)

Co-authored-by: Pares Mathur <paresh.2047@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Paresh Mathur 2026-03-27 02:51:45 +01:00 committed by chengyongru
parent 5118ea824a
commit 80403d352b
4 changed files with 1061 additions and 300 deletions

View File

@ -1,25 +1,37 @@
"""Discord channel implementation using Discord Gateway websocket."""
"""Discord channel implementation using discord.py."""
from __future__ import annotations
import asyncio
import json
import importlib.util
from pathlib import Path
from typing import Any, Literal
from typing import TYPE_CHECKING, Any, Literal
import httpx
from pydantic import Field
import websockets
from loguru import logger
from pydantic import Field
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.command.builtin import build_help_text
from nanobot.config.paths import get_media_dir
from nanobot.config.schema import Base
from nanobot.utils.helpers import split_message
from nanobot.utils.helpers import safe_filename, split_message
DISCORD_AVAILABLE = importlib.util.find_spec("discord") is not None
if TYPE_CHECKING:
import discord
from discord import app_commands
from discord.abc import Messageable
if DISCORD_AVAILABLE:
import discord
from discord import app_commands
from discord.abc import Messageable
DISCORD_API_BASE = "https://discord.com/api/v10"
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
MAX_MESSAGE_LEN = 2000 # Discord message character limit
TYPING_INTERVAL_S = 8
class DiscordConfig(Base):
@ -28,13 +40,202 @@ class DiscordConfig(Base):
enabled: bool = False
token: str = ""
allow_from: list[str] = Field(default_factory=list)
gateway_url: str = "wss://gateway.discord.gg/?v=10&encoding=json"
intents: int = 37377
group_policy: Literal["mention", "open"] = "mention"
if DISCORD_AVAILABLE:
class DiscordBotClient(discord.Client):
"""discord.py client that forwards events to the channel."""
def __init__(self, channel: DiscordChannel, *, intents: discord.Intents) -> None:
super().__init__(intents=intents)
self._channel = channel
self.tree = app_commands.CommandTree(self)
self._register_app_commands()
async def on_ready(self) -> None:
self._channel._bot_user_id = str(self.user.id) if self.user else None
logger.info("Discord bot connected as user {}", self._channel._bot_user_id)
try:
synced = await self.tree.sync()
logger.info("Discord app commands synced: {}", len(synced))
except Exception as e:
logger.warning("Discord app command sync failed: {}", e)
async def on_message(self, message: discord.Message) -> None:
await self._channel._handle_discord_message(message)
async def _reply_ephemeral(self, interaction: discord.Interaction, text: str) -> bool:
"""Send an ephemeral interaction response and report success."""
try:
await interaction.response.send_message(text, ephemeral=True)
return True
except Exception as e:
logger.warning("Discord interaction response failed: {}", e)
return False
async def _forward_slash_command(
self,
interaction: discord.Interaction,
command_text: str,
) -> None:
sender_id = str(interaction.user.id)
channel_id = interaction.channel_id
if channel_id is None:
logger.warning("Discord slash command missing channel_id: {}", command_text)
return
if not self._channel.is_allowed(sender_id):
await self._reply_ephemeral(interaction, "You are not allowed to use this bot.")
return
await self._reply_ephemeral(interaction, f"Processing {command_text}...")
await self._channel._handle_message(
sender_id=sender_id,
chat_id=str(channel_id),
content=command_text,
metadata={
"interaction_id": str(interaction.id),
"guild_id": str(interaction.guild_id) if interaction.guild_id else None,
"is_slash_command": True,
},
)
def _register_app_commands(self) -> None:
commands = (
("new", "Start a new conversation", "/new"),
("stop", "Stop the current task", "/stop"),
("restart", "Restart the bot", "/restart"),
("status", "Show bot status", "/status"),
)
for name, description, command_text in commands:
@self.tree.command(name=name, description=description)
async def command_handler(
interaction: discord.Interaction,
_command_text: str = command_text,
) -> None:
await self._forward_slash_command(interaction, _command_text)
@self.tree.command(name="help", description="Show available commands")
async def help_command(interaction: discord.Interaction) -> None:
sender_id = str(interaction.user.id)
if not self._channel.is_allowed(sender_id):
await self._reply_ephemeral(interaction, "You are not allowed to use this bot.")
return
await self._reply_ephemeral(interaction, build_help_text())
@self.tree.error
async def on_app_command_error(
interaction: discord.Interaction,
error: app_commands.AppCommandError,
) -> None:
command_name = interaction.command.qualified_name if interaction.command else "?"
logger.warning(
"Discord app command failed user={} channel={} cmd={} error={}",
interaction.user.id,
interaction.channel_id,
command_name,
error,
)
async def send_outbound(self, msg: OutboundMessage) -> None:
"""Send a nanobot outbound message using Discord transport rules."""
channel_id = int(msg.chat_id)
channel = self.get_channel(channel_id)
if channel is None:
try:
channel = await self.fetch_channel(channel_id)
except Exception as e:
logger.warning("Discord channel {} unavailable: {}", msg.chat_id, e)
return
reference, mention_settings = self._build_reply_context(channel, msg.reply_to)
sent_media = False
failed_media: list[str] = []
for index, media_path in enumerate(msg.media or []):
if await self._send_file(
channel,
media_path,
reference=reference if index == 0 else None,
mention_settings=mention_settings,
):
sent_media = True
else:
failed_media.append(Path(media_path).name)
for index, chunk in enumerate(self._build_chunks(msg.content or "", failed_media, sent_media)):
kwargs: dict[str, Any] = {"content": chunk}
if index == 0 and reference is not None and not sent_media:
kwargs["reference"] = reference
kwargs["allowed_mentions"] = mention_settings
await channel.send(**kwargs)
async def _send_file(
self,
channel: Messageable,
file_path: str,
*,
reference: discord.PartialMessage | None,
mention_settings: discord.AllowedMentions,
) -> bool:
"""Send a file attachment via discord.py."""
path = Path(file_path)
if not path.is_file():
logger.warning("Discord file not found, skipping: {}", file_path)
return False
if path.stat().st_size > MAX_ATTACHMENT_BYTES:
logger.warning("Discord file too large (>20MB), skipping: {}", path.name)
return False
try:
kwargs: dict[str, Any] = {"file": discord.File(path)}
if reference is not None:
kwargs["reference"] = reference
kwargs["allowed_mentions"] = mention_settings
await channel.send(**kwargs)
logger.info("Discord file sent: {}", path.name)
return True
except Exception as e:
logger.error("Error sending Discord file {}: {}", path.name, e)
return False
@staticmethod
def _build_chunks(content: str, failed_media: list[str], sent_media: bool) -> list[str]:
"""Build outbound text chunks, including attachment-failure fallback text."""
chunks = split_message(content, MAX_MESSAGE_LEN)
if chunks or not failed_media or sent_media:
return chunks
fallback = "\n".join(f"[attachment: {name} - send failed]" for name in failed_media)
return split_message(fallback, MAX_MESSAGE_LEN)
@staticmethod
def _build_reply_context(
channel: Messageable,
reply_to: str | None,
) -> tuple[discord.PartialMessage | None, discord.AllowedMentions]:
"""Build reply context for outbound messages."""
mention_settings = discord.AllowedMentions(replied_user=False)
if not reply_to:
return None, mention_settings
try:
message_id = int(reply_to)
except ValueError:
logger.warning("Invalid Discord reply target: {}", reply_to)
return None, mention_settings
return channel.get_partial_message(message_id), mention_settings
class DiscordChannel(BaseChannel):
"""Discord channel using Gateway websocket."""
"""Discord channel using discord.py."""
name = "discord"
display_name = "Discord"
@ -43,353 +244,229 @@ class DiscordChannel(BaseChannel):
def default_config(cls) -> dict[str, Any]:
return DiscordConfig().model_dump(by_alias=True)
@staticmethod
def _channel_key(channel_or_id: Any) -> str:
"""Normalize channel-like objects and ids to a stable string key."""
channel_id = getattr(channel_or_id, "id", channel_or_id)
return str(channel_id)
def __init__(self, config: Any, bus: MessageBus):
if isinstance(config, dict):
config = DiscordConfig.model_validate(config)
super().__init__(config, bus)
self.config: DiscordConfig = config
self._ws: websockets.WebSocketClientProtocol | None = None
self._seq: int | None = None
self._heartbeat_task: asyncio.Task | None = None
self._typing_tasks: dict[str, asyncio.Task] = {}
self._http: httpx.AsyncClient | None = None
self._client: DiscordBotClient | None = None
self._typing_tasks: dict[str, asyncio.Task[None]] = {}
self._bot_user_id: str | None = None
async def start(self) -> None:
"""Start the Discord gateway connection."""
"""Start the Discord client."""
if not DISCORD_AVAILABLE:
logger.error("discord.py not installed. Run: pip install nanobot-ai[discord]")
return
if not self.config.token:
logger.error("Discord bot token not configured")
return
self._running = True
self._http = httpx.AsyncClient(timeout=30.0)
try:
intents = discord.Intents.none()
intents.value = self.config.intents
self._client = DiscordBotClient(self, intents=intents)
except Exception as e:
logger.error("Failed to initialize Discord client: {}", e)
self._client = None
self._running = False
return
while self._running:
try:
logger.info("Connecting to Discord gateway...")
async with websockets.connect(self.config.gateway_url) as ws:
self._ws = ws
await self._gateway_loop()
except asyncio.CancelledError:
break
except Exception as e:
logger.warning("Discord gateway error: {}", e)
if self._running:
logger.info("Reconnecting to Discord gateway in 5 seconds...")
await asyncio.sleep(5)
self._running = True
logger.info("Starting Discord client via discord.py...")
try:
await self._client.start(self.config.token)
except asyncio.CancelledError:
raise
except Exception as e:
logger.error("Discord client startup failed: {}", e)
finally:
self._running = False
await self._reset_runtime_state(close_client=True)
async def stop(self) -> None:
"""Stop the Discord channel."""
self._running = False
if self._heartbeat_task:
self._heartbeat_task.cancel()
self._heartbeat_task = None
for task in self._typing_tasks.values():
task.cancel()
self._typing_tasks.clear()
if self._ws:
await self._ws.close()
self._ws = None
if self._http:
await self._http.aclose()
self._http = None
await self._reset_runtime_state(close_client=True)
async def send(self, msg: OutboundMessage) -> None:
"""Send a message through Discord REST API, including file attachments."""
if not self._http:
logger.warning("Discord HTTP client not initialized")
"""Send a message through Discord using discord.py."""
client = self._client
if client is None or not client.is_ready():
logger.warning("Discord client not ready; dropping outbound message")
return
url = f"{DISCORD_API_BASE}/channels/{msg.chat_id}/messages"
headers = {"Authorization": f"Bot {self.config.token}"}
is_progress = bool((msg.metadata or {}).get("_progress"))
try:
await client.send_outbound(msg)
except Exception as e:
logger.error("Error sending Discord message: {}", e)
finally:
if not is_progress:
await self._stop_typing(msg.chat_id)
async def _handle_discord_message(self, message: discord.Message) -> None:
"""Handle incoming Discord messages from discord.py."""
if message.author.bot:
return
sender_id = str(message.author.id)
channel_id = self._channel_key(message.channel)
content = message.content or ""
if not self._should_accept_inbound(message, sender_id, content):
return
media_paths, attachment_markers = await self._download_attachments(message.attachments)
full_content = self._compose_inbound_content(content, attachment_markers)
metadata = self._build_inbound_metadata(message)
await self._start_typing(message.channel)
try:
sent_media = False
failed_media: list[str] = []
await self._handle_message(
sender_id=sender_id,
chat_id=channel_id,
content=full_content,
media=media_paths,
metadata=metadata,
)
except Exception:
await self._stop_typing(channel_id)
raise
# Send file attachments first
for media_path in msg.media or []:
if await self._send_file(url, headers, media_path, reply_to=msg.reply_to):
sent_media = True
else:
failed_media.append(Path(media_path).name)
async def _on_message(self, message: discord.Message) -> None:
"""Backward-compatible alias for legacy tests/callers."""
await self._handle_discord_message(message)
# Send text content
chunks = split_message(msg.content or "", MAX_MESSAGE_LEN)
if not chunks and failed_media and not sent_media:
chunks = split_message(
"\n".join(f"[attachment: {name} - send failed]" for name in failed_media),
MAX_MESSAGE_LEN,
)
if not chunks:
return
for i, chunk in enumerate(chunks):
payload: dict[str, Any] = {"content": chunk}
# Let the first successful attachment carry the reply if present.
if i == 0 and msg.reply_to and not sent_media:
payload["message_reference"] = {"message_id": msg.reply_to}
payload["allowed_mentions"] = {"replied_user": False}
if not await self._send_payload(url, headers, payload):
break # Abort remaining chunks on failure
finally:
await self._stop_typing(msg.chat_id)
async def _send_payload(
self, url: str, headers: dict[str, str], payload: dict[str, Any]
) -> bool:
"""Send a single Discord API payload with retry on rate-limit. Returns True on success."""
for attempt in range(3):
try:
response = await self._http.post(url, headers=headers, json=payload)
if response.status_code == 429:
data = response.json()
retry_after = float(data.get("retry_after", 1.0))
logger.warning("Discord rate limited, retrying in {}s", retry_after)
await asyncio.sleep(retry_after)
continue
response.raise_for_status()
return True
except Exception as e:
if attempt == 2:
logger.error("Error sending Discord message: {}", e)
else:
await asyncio.sleep(1)
return False
async def _send_file(
def _should_accept_inbound(
self,
url: str,
headers: dict[str, str],
file_path: str,
reply_to: str | None = None,
message: discord.Message,
sender_id: str,
content: str,
) -> bool:
"""Send a file attachment via Discord REST API using multipart/form-data."""
path = Path(file_path)
if not path.is_file():
logger.warning("Discord file not found, skipping: {}", file_path)
return False
if path.stat().st_size > MAX_ATTACHMENT_BYTES:
logger.warning("Discord file too large (>20MB), skipping: {}", path.name)
return False
payload_json: dict[str, Any] = {}
if reply_to:
payload_json["message_reference"] = {"message_id": reply_to}
payload_json["allowed_mentions"] = {"replied_user": False}
for attempt in range(3):
try:
with open(path, "rb") as f:
files = {"files[0]": (path.name, f, "application/octet-stream")}
data: dict[str, Any] = {}
if payload_json:
data["payload_json"] = json.dumps(payload_json)
response = await self._http.post(
url, headers=headers, files=files, data=data
)
if response.status_code == 429:
resp_data = response.json()
retry_after = float(resp_data.get("retry_after", 1.0))
logger.warning("Discord rate limited, retrying in {}s", retry_after)
await asyncio.sleep(retry_after)
continue
response.raise_for_status()
logger.info("Discord file sent: {}", path.name)
return True
except Exception as e:
if attempt == 2:
logger.error("Error sending Discord file {}: {}", path.name, e)
else:
await asyncio.sleep(1)
return False
async def _gateway_loop(self) -> None:
"""Main gateway loop: identify, heartbeat, dispatch events."""
if not self._ws:
return
async for raw in self._ws:
try:
data = json.loads(raw)
except json.JSONDecodeError:
logger.warning("Invalid JSON from Discord gateway: {}", raw[:100])
continue
op = data.get("op")
event_type = data.get("t")
seq = data.get("s")
payload = data.get("d")
if seq is not None:
self._seq = seq
if op == 10:
# HELLO: start heartbeat and identify
interval_ms = payload.get("heartbeat_interval", 45000)
await self._start_heartbeat(interval_ms / 1000)
await self._identify()
elif op == 0 and event_type == "READY":
logger.info("Discord gateway READY")
# Capture bot user ID for mention detection
user_data = payload.get("user") or {}
self._bot_user_id = user_data.get("id")
logger.info("Discord bot connected as user {}", self._bot_user_id)
elif op == 0 and event_type == "MESSAGE_CREATE":
await self._handle_message_create(payload)
elif op == 7:
# RECONNECT: exit loop to reconnect
logger.info("Discord gateway requested reconnect")
break
elif op == 9:
# INVALID_SESSION: reconnect
logger.warning("Discord gateway invalid session")
break
async def _identify(self) -> None:
"""Send IDENTIFY payload."""
if not self._ws:
return
identify = {
"op": 2,
"d": {
"token": self.config.token,
"intents": self.config.intents,
"properties": {
"os": "nanobot",
"browser": "nanobot",
"device": "nanobot",
},
},
}
await self._ws.send(json.dumps(identify))
async def _start_heartbeat(self, interval_s: float) -> None:
"""Start or restart the heartbeat loop."""
if self._heartbeat_task:
self._heartbeat_task.cancel()
async def heartbeat_loop() -> None:
while self._running and self._ws:
payload = {"op": 1, "d": self._seq}
try:
await self._ws.send(json.dumps(payload))
except Exception as e:
logger.warning("Discord heartbeat failed: {}", e)
break
await asyncio.sleep(interval_s)
self._heartbeat_task = asyncio.create_task(heartbeat_loop())
async def _handle_message_create(self, payload: dict[str, Any]) -> None:
"""Handle incoming Discord messages."""
author = payload.get("author") or {}
if author.get("bot"):
return
sender_id = str(author.get("id", ""))
channel_id = str(payload.get("channel_id", ""))
content = payload.get("content") or ""
guild_id = payload.get("guild_id")
if not sender_id or not channel_id:
return
"""Check if inbound Discord message should be processed."""
if not self.is_allowed(sender_id):
return
return False
if message.guild is not None and not self._should_respond_in_group(message, content):
return False
return True
# Check group channel policy (DMs always respond if is_allowed passes)
if guild_id is not None:
if not self._should_respond_in_group(payload, content):
return
content_parts = [content] if content else []
async def _download_attachments(
self,
attachments: list[discord.Attachment],
) -> tuple[list[str], list[str]]:
"""Download supported attachments and return paths + display markers."""
media_paths: list[str] = []
markers: list[str] = []
media_dir = get_media_dir("discord")
for attachment in payload.get("attachments") or []:
url = attachment.get("url")
filename = attachment.get("filename") or "attachment"
size = attachment.get("size") or 0
if not url or not self._http:
continue
if size and size > MAX_ATTACHMENT_BYTES:
content_parts.append(f"[attachment: {filename} - too large]")
for attachment in attachments:
filename = attachment.filename or "attachment"
if attachment.size and attachment.size > MAX_ATTACHMENT_BYTES:
markers.append(f"[attachment: {filename} - too large]")
continue
try:
media_dir.mkdir(parents=True, exist_ok=True)
file_path = media_dir / f"{attachment.get('id', 'file')}_{filename.replace('/', '_')}"
resp = await self._http.get(url)
resp.raise_for_status()
file_path.write_bytes(resp.content)
safe_name = safe_filename(filename)
file_path = media_dir / f"{attachment.id}_{safe_name}"
await attachment.save(file_path)
media_paths.append(str(file_path))
content_parts.append(f"[attachment: {file_path}]")
markers.append(f"[attachment: {file_path.name}]")
except Exception as e:
logger.warning("Failed to download Discord attachment: {}", e)
content_parts.append(f"[attachment: {filename} - download failed]")
markers.append(f"[attachment: {filename} - download failed]")
reply_to = (payload.get("referenced_message") or {}).get("id")
return media_paths, markers
await self._start_typing(channel_id)
@staticmethod
def _compose_inbound_content(content: str, attachment_markers: list[str]) -> str:
"""Combine message text with attachment markers."""
content_parts = [content] if content else []
content_parts.extend(attachment_markers)
return "\n".join(part for part in content_parts if part) or "[empty message]"
await self._handle_message(
sender_id=sender_id,
chat_id=channel_id,
content="\n".join(p for p in content_parts if p) or "[empty message]",
media=media_paths,
metadata={
"message_id": str(payload.get("id", "")),
"guild_id": guild_id,
"reply_to": reply_to,
},
)
@staticmethod
def _build_inbound_metadata(message: discord.Message) -> dict[str, str | None]:
"""Build metadata for inbound Discord messages."""
reply_to = str(message.reference.message_id) if message.reference and message.reference.message_id else None
return {
"message_id": str(message.id),
"guild_id": str(message.guild.id) if message.guild else None,
"reply_to": reply_to,
}
def _should_respond_in_group(self, payload: dict[str, Any], content: str) -> bool:
"""Check if bot should respond in a group channel based on policy."""
def _should_respond_in_group(self, message: discord.Message, content: str) -> bool:
"""Check if the bot should respond in a guild channel based on policy."""
if self.config.group_policy == "open":
return True
if self.config.group_policy == "mention":
# Check if bot was mentioned in the message
if self._bot_user_id:
# Check mentions array
mentions = payload.get("mentions") or []
for mention in mentions:
if str(mention.get("id")) == self._bot_user_id:
return True
# Also check content for mention format <@USER_ID>
if f"<@{self._bot_user_id}>" in content or f"<@!{self._bot_user_id}>" in content:
return True
logger.debug("Discord message in {} ignored (bot not mentioned)", payload.get("channel_id"))
bot_user_id = self._bot_user_id
if bot_user_id is None:
logger.debug("Discord message in {} ignored (bot identity unavailable)", message.channel.id)
return False
if any(str(user.id) == bot_user_id for user in message.mentions):
return True
if f"<@{bot_user_id}>" in content or f"<@!{bot_user_id}>" in content:
return True
logger.debug("Discord message in {} ignored (bot not mentioned)", message.channel.id)
return False
return True
async def _start_typing(self, channel_id: str) -> None:
async def _start_typing(self, channel: Messageable) -> None:
"""Start periodic typing indicator for a channel."""
channel_id = self._channel_key(channel)
await self._stop_typing(channel_id)
async def typing_loop() -> None:
url = f"{DISCORD_API_BASE}/channels/{channel_id}/typing"
headers = {"Authorization": f"Bot {self.config.token}"}
while self._running:
try:
await self._http.post(url, headers=headers)
async with channel.typing():
await asyncio.sleep(TYPING_INTERVAL_S)
except asyncio.CancelledError:
return
except Exception as e:
logger.debug("Discord typing indicator failed for {}: {}", channel_id, e)
return
await asyncio.sleep(8)
self._typing_tasks[channel_id] = asyncio.create_task(typing_loop())
async def _stop_typing(self, channel_id: str) -> None:
"""Stop typing indicator for a channel."""
task = self._typing_tasks.pop(channel_id, None)
if task:
task.cancel()
task = self._typing_tasks.pop(self._channel_key(channel_id), None)
if task is None:
return
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
async def _cancel_all_typing(self) -> None:
"""Stop all typing tasks."""
channel_ids = list(self._typing_tasks)
for channel_id in channel_ids:
await self._stop_typing(channel_id)
async def _reset_runtime_state(self, close_client: bool) -> None:
"""Reset client and typing state."""
await self._cancel_all_typing()
if close_client and self._client is not None and not self._client.is_closed():
try:
await self._client.close()
except Exception as e:
logger.warning("Discord client close failed: {}", e)
self._client = None
self._bot_user_id = None

View File

@ -84,6 +84,16 @@ async def cmd_new(ctx: CommandContext) -> OutboundMessage:
async def cmd_help(ctx: CommandContext) -> OutboundMessage:
"""Return available slash commands."""
return OutboundMessage(
channel=ctx.msg.channel,
chat_id=ctx.msg.chat_id,
content=build_help_text(),
metadata={"render_as": "text"},
)
def build_help_text() -> str:
"""Build canonical help text shared across channels."""
lines = [
"🐈 nanobot commands:",
"/new — Start a new conversation",
@ -92,12 +102,7 @@ async def cmd_help(ctx: CommandContext) -> OutboundMessage:
"/status — Show bot status",
"/help — Show available commands",
]
return OutboundMessage(
channel=ctx.msg.channel,
chat_id=ctx.msg.chat_id,
content="\n".join(lines),
metadata={"render_as": "text"},
)
return "\n".join(lines)
def register_builtin_commands(router: CommandRouter) -> None:

View File

@ -64,6 +64,9 @@ matrix = [
"mistune>=3.0.0,<4.0.0",
"nh3>=0.2.17,<1.0.0",
]
discord = [
"discord.py>=2.5.2,<3.0.0",
]
langsmith = [
"langsmith>=0.1.0",
]

View File

@ -0,0 +1,676 @@
from __future__ import annotations
import asyncio
from pathlib import Path
from types import SimpleNamespace
import discord
import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.discord import DiscordBotClient, DiscordChannel, DiscordConfig
from nanobot.command.builtin import build_help_text
# Minimal Discord client test double used to control startup/readiness behavior.
class _FakeDiscordClient:
instances: list["_FakeDiscordClient"] = []
start_error: Exception | None = None
def __init__(self, owner, *, intents) -> None:
self.owner = owner
self.intents = intents
self.closed = False
self.ready = True
self.channels: dict[int, object] = {}
self.user = SimpleNamespace(id=999)
self.__class__.instances.append(self)
async def start(self, token: str) -> None:
self.token = token
if self.__class__.start_error is not None:
raise self.__class__.start_error
async def close(self) -> None:
self.closed = True
def is_closed(self) -> bool:
return self.closed
def is_ready(self) -> bool:
return self.ready
def get_channel(self, channel_id: int):
return self.channels.get(channel_id)
async def send_outbound(self, msg: OutboundMessage) -> None:
channel = self.get_channel(int(msg.chat_id))
if channel is None:
return
await channel.send(content=msg.content)
class _FakeAttachment:
# Attachment double that can simulate successful or failing save() calls.
def __init__(self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False) -> None:
self.id = attachment_id
self.filename = filename
self.size = size
self._fail = fail
async def save(self, path: str | Path) -> None:
if self._fail:
raise RuntimeError("save failed")
Path(path).write_bytes(b"attachment")
class _FakePartialMessage:
# Lightweight stand-in for Discord partial message references used in replies.
def __init__(self, message_id: int) -> None:
self.id = message_id
class _FakeChannel:
# Channel double that records outbound payloads and typing activity.
def __init__(self, channel_id: int = 123) -> None:
self.id = channel_id
self.sent_payloads: list[dict] = []
self.trigger_typing_calls = 0
self.typing_enter_hook = None
async def send(self, **kwargs) -> None:
payload = dict(kwargs)
if "file" in payload:
payload["file_name"] = payload["file"].filename
del payload["file"]
self.sent_payloads.append(payload)
def get_partial_message(self, message_id: int) -> _FakePartialMessage:
return _FakePartialMessage(message_id)
def typing(self):
channel = self
class _TypingContext:
async def __aenter__(self):
channel.trigger_typing_calls += 1
if channel.typing_enter_hook is not None:
await channel.typing_enter_hook()
async def __aexit__(self, exc_type, exc, tb):
return False
return _TypingContext()
class _FakeInteractionResponse:
def __init__(self) -> None:
self.messages: list[dict] = []
self._done = False
async def send_message(self, content: str, *, ephemeral: bool = False) -> None:
self.messages.append({"content": content, "ephemeral": ephemeral})
self._done = True
def is_done(self) -> bool:
return self._done
def _make_interaction(
*,
user_id: int = 123,
channel_id: int | None = 456,
guild_id: int | None = None,
interaction_id: int = 999,
):
return SimpleNamespace(
user=SimpleNamespace(id=user_id),
channel_id=channel_id,
guild_id=guild_id,
id=interaction_id,
command=SimpleNamespace(qualified_name="new"),
response=_FakeInteractionResponse(),
)
def _make_message(
*,
author_id: int = 123,
author_bot: bool = False,
channel_id: int = 456,
message_id: int = 789,
content: str = "hello",
guild_id: int | None = None,
mentions: list[object] | None = None,
attachments: list[object] | None = None,
reply_to: int | None = None,
):
# Factory for incoming Discord message objects with optional guild/reply/attachments.
guild = SimpleNamespace(id=guild_id) if guild_id is not None else None
reference = SimpleNamespace(message_id=reply_to) if reply_to is not None else None
return SimpleNamespace(
author=SimpleNamespace(id=author_id, bot=author_bot),
channel=_FakeChannel(channel_id),
content=content,
guild=guild,
mentions=mentions or [],
attachments=attachments or [],
reference=reference,
id=message_id,
)
@pytest.mark.asyncio
async def test_start_returns_when_token_missing() -> None:
# If no token is configured, startup should no-op and leave channel stopped.
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
await channel.start()
assert channel.is_running is False
assert channel._client is None
@pytest.mark.asyncio
async def test_start_returns_when_discord_dependency_missing(monkeypatch) -> None:
channel = DiscordChannel(
DiscordConfig(enabled=True, token="token", allow_from=["*"]),
MessageBus(),
)
monkeypatch.setattr("nanobot.channels.discord.DISCORD_AVAILABLE", False)
await channel.start()
assert channel.is_running is False
assert channel._client is None
@pytest.mark.asyncio
async def test_start_handles_client_construction_failure(monkeypatch) -> None:
# Construction errors from the Discord client should be swallowed and keep state clean.
channel = DiscordChannel(
DiscordConfig(enabled=True, token="token", allow_from=["*"]),
MessageBus(),
)
def _boom(owner, *, intents):
raise RuntimeError("bad client")
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _boom)
await channel.start()
assert channel.is_running is False
assert channel._client is None
@pytest.mark.asyncio
async def test_start_handles_client_start_failure(monkeypatch) -> None:
# If client.start fails, the partially created client should be closed and detached.
channel = DiscordChannel(
DiscordConfig(enabled=True, token="token", allow_from=["*"]),
MessageBus(),
)
_FakeDiscordClient.instances.clear()
_FakeDiscordClient.start_error = RuntimeError("connect failed")
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
await channel.start()
assert channel.is_running is False
assert channel._client is None
assert _FakeDiscordClient.instances[0].intents.value == channel.config.intents
assert _FakeDiscordClient.instances[0].closed is True
_FakeDiscordClient.start_error = None
@pytest.mark.asyncio
async def test_stop_is_safe_after_partial_start(monkeypatch) -> None:
# stop() should close/discard the client even when startup was only partially completed.
channel = DiscordChannel(
DiscordConfig(enabled=True, token="token", allow_from=["*"]),
MessageBus(),
)
client = _FakeDiscordClient(channel, intents=None)
channel._client = client
channel._running = True
await channel.stop()
assert channel.is_running is False
assert client.closed is True
assert channel._client is None
@pytest.mark.asyncio
async def test_on_message_ignores_bot_messages() -> None:
# Incoming bot-authored messages must be ignored to prevent feedback loops.
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
handled: list[dict] = []
channel._handle_message = lambda **kwargs: handled.append(kwargs) # type: ignore[method-assign]
await channel._on_message(_make_message(author_bot=True))
assert handled == []
# If inbound handling raises, typing should be stopped for that channel.
async def fail_handle(**kwargs) -> None:
raise RuntimeError("boom")
channel._handle_message = fail_handle # type: ignore[method-assign]
with pytest.raises(RuntimeError, match="boom"):
await channel._on_message(_make_message(author_id=123, channel_id=456))
assert channel._typing_tasks == {}
@pytest.mark.asyncio
async def test_on_message_accepts_allowlisted_dm() -> None:
# Allowed direct messages should be forwarded with normalized metadata.
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus())
handled: list[dict] = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle # type: ignore[method-assign]
await channel._on_message(_make_message(author_id=123, channel_id=456, message_id=789))
assert len(handled) == 1
assert handled[0]["chat_id"] == "456"
assert handled[0]["metadata"] == {"message_id": "789", "guild_id": None, "reply_to": None}
@pytest.mark.asyncio
async def test_on_message_ignores_unmentioned_guild_message() -> None:
# With mention-only group policy, guild messages without a bot mention are dropped.
channel = DiscordChannel(
DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention"),
MessageBus(),
)
channel._bot_user_id = "999"
handled: list[dict] = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle # type: ignore[method-assign]
await channel._on_message(_make_message(guild_id=1, content="hello everyone"))
assert handled == []
@pytest.mark.asyncio
async def test_on_message_accepts_mentioned_guild_message() -> None:
# Mentioned guild messages should be accepted and preserve reply threading metadata.
channel = DiscordChannel(
DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention"),
MessageBus(),
)
channel._bot_user_id = "999"
handled: list[dict] = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle # type: ignore[method-assign]
await channel._on_message(
_make_message(
guild_id=1,
content="<@999> hello",
mentions=[SimpleNamespace(id=999)],
reply_to=321,
)
)
assert len(handled) == 1
assert handled[0]["metadata"]["reply_to"] == "321"
@pytest.mark.asyncio
async def test_on_message_downloads_attachments(tmp_path, monkeypatch) -> None:
# Attachment downloads should be saved and referenced in forwarded content/media.
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
handled: list[dict] = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle # type: ignore[method-assign]
monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path)
await channel._on_message(
_make_message(
attachments=[_FakeAttachment(12, "photo.png")],
content="see file",
)
)
assert len(handled) == 1
assert handled[0]["media"] == [str(tmp_path / "12_photo.png")]
assert "[attachment:" in handled[0]["content"]
@pytest.mark.asyncio
async def test_on_message_marks_failed_attachment_download(tmp_path, monkeypatch) -> None:
# Failed attachment downloads should emit a readable placeholder and no media path.
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
handled: list[dict] = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle # type: ignore[method-assign]
monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path)
await channel._on_message(
_make_message(
attachments=[_FakeAttachment(12, "photo.png", fail=True)],
content="",
)
)
assert len(handled) == 1
assert handled[0]["media"] == []
assert handled[0]["content"] == "[attachment: photo.png - download failed]"
@pytest.mark.asyncio
async def test_send_warns_when_client_not_ready() -> None:
# Sending without a running/ready client should be a safe no-op.
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
assert channel._typing_tasks == {}
@pytest.mark.asyncio
async def test_send_skips_when_channel_not_cached() -> None:
# Outbound sends should be skipped when the destination channel is not resolvable.
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
client = DiscordBotClient(owner, intents=discord.Intents.none())
fetch_calls: list[int] = []
async def fetch_channel(channel_id: int):
fetch_calls.append(channel_id)
raise RuntimeError("not found")
client.fetch_channel = fetch_channel # type: ignore[method-assign]
await client.send_outbound(OutboundMessage(channel="discord", chat_id="123", content="hello"))
assert client.get_channel(123) is None
assert fetch_calls == [123]
@pytest.mark.asyncio
async def test_send_fetches_channel_when_not_cached() -> None:
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
client = DiscordBotClient(owner, intents=discord.Intents.none())
target = _FakeChannel(channel_id=123)
async def fetch_channel(channel_id: int):
return target if channel_id == 123 else None
client.fetch_channel = fetch_channel # type: ignore[method-assign]
await client.send_outbound(OutboundMessage(channel="discord", chat_id="123", content="hello"))
assert target.sent_payloads == [{"content": "hello"}]
@pytest.mark.asyncio
async def test_slash_new_forwards_when_user_is_allowlisted() -> None:
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus())
handled: list[dict] = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle # type: ignore[method-assign]
client = DiscordBotClient(channel, intents=discord.Intents.none())
interaction = _make_interaction(user_id=123, channel_id=456, interaction_id=321)
new_cmd = client.tree.get_command("new")
assert new_cmd is not None
await new_cmd.callback(interaction)
assert interaction.response.messages == [
{"content": "Processing /new...", "ephemeral": True}
]
assert len(handled) == 1
assert handled[0]["content"] == "/new"
assert handled[0]["sender_id"] == "123"
assert handled[0]["chat_id"] == "456"
assert handled[0]["metadata"]["interaction_id"] == "321"
assert handled[0]["metadata"]["is_slash_command"] is True
@pytest.mark.asyncio
async def test_slash_new_is_blocked_for_disallowed_user() -> None:
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["999"]), MessageBus())
handled: list[dict] = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle # type: ignore[method-assign]
client = DiscordBotClient(channel, intents=discord.Intents.none())
interaction = _make_interaction(user_id=123, channel_id=456)
new_cmd = client.tree.get_command("new")
assert new_cmd is not None
await new_cmd.callback(interaction)
assert interaction.response.messages == [
{"content": "You are not allowed to use this bot.", "ephemeral": True}
]
assert handled == []
@pytest.mark.parametrize("slash_name", ["stop", "restart", "status"])
@pytest.mark.asyncio
async def test_slash_commands_forward_via_handle_message(slash_name: str) -> None:
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
handled: list[dict] = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle # type: ignore[method-assign]
client = DiscordBotClient(channel, intents=discord.Intents.none())
interaction = _make_interaction()
interaction.command.qualified_name = slash_name
cmd = client.tree.get_command(slash_name)
assert cmd is not None
await cmd.callback(interaction)
assert interaction.response.messages == [
{"content": f"Processing /{slash_name}...", "ephemeral": True}
]
assert len(handled) == 1
assert handled[0]["content"] == f"/{slash_name}"
assert handled[0]["metadata"]["is_slash_command"] is True
@pytest.mark.asyncio
async def test_slash_help_returns_ephemeral_help_text() -> None:
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
handled: list[dict] = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle # type: ignore[method-assign]
client = DiscordBotClient(channel, intents=discord.Intents.none())
interaction = _make_interaction()
interaction.command.qualified_name = "help"
help_cmd = client.tree.get_command("help")
assert help_cmd is not None
await help_cmd.callback(interaction)
assert interaction.response.messages == [
{"content": build_help_text(), "ephemeral": True}
]
assert handled == []
@pytest.mark.asyncio
async def test_client_send_outbound_chunks_text_replies_and_uploads_files(tmp_path) -> None:
# Outbound payloads should upload files, attach reply references, and chunk long text.
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
client = DiscordBotClient(owner, intents=discord.Intents.none())
target = _FakeChannel(channel_id=123)
client.get_channel = lambda channel_id: target if channel_id == 123 else None # type: ignore[method-assign]
file_path = tmp_path / "demo.txt"
file_path.write_text("hi")
await client.send_outbound(
OutboundMessage(
channel="discord",
chat_id="123",
content="a" * 2100,
reply_to="55",
media=[str(file_path)],
)
)
assert len(target.sent_payloads) == 3
assert target.sent_payloads[0]["file_name"] == "demo.txt"
assert target.sent_payloads[0]["reference"].id == 55
assert target.sent_payloads[1]["content"] == "a" * 2000
assert target.sent_payloads[2]["content"] == "a" * 100
@pytest.mark.asyncio
async def test_client_send_outbound_reports_failed_attachments_when_no_text(tmp_path) -> None:
# If all attachment sends fail and no text exists, emit a failure placeholder message.
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
client = DiscordBotClient(owner, intents=discord.Intents.none())
target = _FakeChannel(channel_id=123)
client.get_channel = lambda channel_id: target if channel_id == 123 else None # type: ignore[method-assign]
missing_file = tmp_path / "missing.txt"
await client.send_outbound(
OutboundMessage(
channel="discord",
chat_id="123",
content="",
media=[str(missing_file)],
)
)
assert target.sent_payloads == [{"content": "[attachment: missing.txt - send failed]"}]
@pytest.mark.asyncio
async def test_send_stops_typing_after_send() -> None:
# Active typing indicators should be cancelled/cleared after a successful send.
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
client = _FakeDiscordClient(channel, intents=None)
channel._client = client
channel._running = True
start = asyncio.Event()
release = asyncio.Event()
async def slow_typing() -> None:
start.set()
await release.wait()
typing_channel = _FakeChannel(channel_id=123)
typing_channel.typing_enter_hook = slow_typing
await channel._start_typing(typing_channel)
await start.wait()
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
release.set()
await asyncio.sleep(0)
assert channel._typing_tasks == {}
# Progress messages should keep typing active until a final (non-progress) send.
start = asyncio.Event()
release = asyncio.Event()
async def slow_typing_progress() -> None:
start.set()
await release.wait()
typing_channel = _FakeChannel(channel_id=123)
typing_channel.typing_enter_hook = slow_typing_progress
await channel._start_typing(typing_channel)
await start.wait()
await channel.send(
OutboundMessage(
channel="discord",
chat_id="123",
content="progress",
metadata={"_progress": True},
)
)
assert "123" in channel._typing_tasks
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="final"))
release.set()
await asyncio.sleep(0)
assert channel._typing_tasks == {}
@pytest.mark.asyncio
async def test_start_typing_uses_typing_context_when_trigger_typing_missing() -> None:
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
channel._running = True
entered = asyncio.Event()
release = asyncio.Event()
class _TypingCtx:
async def __aenter__(self):
entered.set()
async def __aexit__(self, exc_type, exc, tb):
return False
class _NoTriggerChannel:
def __init__(self, channel_id: int = 123) -> None:
self.id = channel_id
def typing(self):
async def _waiter():
await release.wait()
# Hold the loop so task remains active until explicitly stopped.
class _Ctx(_TypingCtx):
async def __aenter__(self):
await super().__aenter__()
await _waiter()
return _Ctx()
typing_channel = _NoTriggerChannel(channel_id=123)
await channel._start_typing(typing_channel) # type: ignore[arg-type]
await entered.wait()
assert "123" in channel._typing_tasks
await channel._stop_typing("123")
release.set()
await asyncio.sleep(0)
assert channel._typing_tasks == {}