mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 09:22:36 +00:00
feat(discord): Use discord.py for stable discord channel (#2486)
Co-authored-by: Pares Mathur <paresh.2047@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
b94d4c0509
commit
0506e6c1c1
@ -1,25 +1,37 @@
|
||||
"""Discord channel implementation using Discord Gateway websocket."""
|
||||
"""Discord channel implementation using discord.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import httpx
|
||||
from pydantic import Field
|
||||
import websockets
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.command.builtin import build_help_text
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import Base
|
||||
from nanobot.utils.helpers import split_message
|
||||
from nanobot.utils.helpers import safe_filename, split_message
|
||||
|
||||
DISCORD_AVAILABLE = importlib.util.find_spec("discord") is not None
|
||||
if TYPE_CHECKING:
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.abc import Messageable
|
||||
|
||||
if DISCORD_AVAILABLE:
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.abc import Messageable
|
||||
|
||||
DISCORD_API_BASE = "https://discord.com/api/v10"
|
||||
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
|
||||
MAX_MESSAGE_LEN = 2000 # Discord message character limit
|
||||
TYPING_INTERVAL_S = 8
|
||||
|
||||
|
||||
class DiscordConfig(Base):
|
||||
@ -28,13 +40,202 @@ class DiscordConfig(Base):
|
||||
enabled: bool = False
|
||||
token: str = ""
|
||||
allow_from: list[str] = Field(default_factory=list)
|
||||
gateway_url: str = "wss://gateway.discord.gg/?v=10&encoding=json"
|
||||
intents: int = 37377
|
||||
group_policy: Literal["mention", "open"] = "mention"
|
||||
|
||||
|
||||
if DISCORD_AVAILABLE:
|
||||
|
||||
class DiscordBotClient(discord.Client):
|
||||
"""discord.py client that forwards events to the channel."""
|
||||
|
||||
def __init__(self, channel: DiscordChannel, *, intents: discord.Intents) -> None:
|
||||
super().__init__(intents=intents)
|
||||
self._channel = channel
|
||||
self.tree = app_commands.CommandTree(self)
|
||||
self._register_app_commands()
|
||||
|
||||
async def on_ready(self) -> None:
|
||||
self._channel._bot_user_id = str(self.user.id) if self.user else None
|
||||
logger.info("Discord bot connected as user {}", self._channel._bot_user_id)
|
||||
try:
|
||||
synced = await self.tree.sync()
|
||||
logger.info("Discord app commands synced: {}", len(synced))
|
||||
except Exception as e:
|
||||
logger.warning("Discord app command sync failed: {}", e)
|
||||
|
||||
async def on_message(self, message: discord.Message) -> None:
|
||||
await self._channel._handle_discord_message(message)
|
||||
|
||||
async def _reply_ephemeral(self, interaction: discord.Interaction, text: str) -> bool:
|
||||
"""Send an ephemeral interaction response and report success."""
|
||||
try:
|
||||
await interaction.response.send_message(text, ephemeral=True)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning("Discord interaction response failed: {}", e)
|
||||
return False
|
||||
|
||||
async def _forward_slash_command(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
command_text: str,
|
||||
) -> None:
|
||||
sender_id = str(interaction.user.id)
|
||||
channel_id = interaction.channel_id
|
||||
|
||||
if channel_id is None:
|
||||
logger.warning("Discord slash command missing channel_id: {}", command_text)
|
||||
return
|
||||
|
||||
if not self._channel.is_allowed(sender_id):
|
||||
await self._reply_ephemeral(interaction, "You are not allowed to use this bot.")
|
||||
return
|
||||
|
||||
await self._reply_ephemeral(interaction, f"Processing {command_text}...")
|
||||
|
||||
await self._channel._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=str(channel_id),
|
||||
content=command_text,
|
||||
metadata={
|
||||
"interaction_id": str(interaction.id),
|
||||
"guild_id": str(interaction.guild_id) if interaction.guild_id else None,
|
||||
"is_slash_command": True,
|
||||
},
|
||||
)
|
||||
|
||||
def _register_app_commands(self) -> None:
|
||||
commands = (
|
||||
("new", "Start a new conversation", "/new"),
|
||||
("stop", "Stop the current task", "/stop"),
|
||||
("restart", "Restart the bot", "/restart"),
|
||||
("status", "Show bot status", "/status"),
|
||||
)
|
||||
|
||||
for name, description, command_text in commands:
|
||||
@self.tree.command(name=name, description=description)
|
||||
async def command_handler(
|
||||
interaction: discord.Interaction,
|
||||
_command_text: str = command_text,
|
||||
) -> None:
|
||||
await self._forward_slash_command(interaction, _command_text)
|
||||
|
||||
@self.tree.command(name="help", description="Show available commands")
|
||||
async def help_command(interaction: discord.Interaction) -> None:
|
||||
sender_id = str(interaction.user.id)
|
||||
if not self._channel.is_allowed(sender_id):
|
||||
await self._reply_ephemeral(interaction, "You are not allowed to use this bot.")
|
||||
return
|
||||
await self._reply_ephemeral(interaction, build_help_text())
|
||||
|
||||
@self.tree.error
|
||||
async def on_app_command_error(
|
||||
interaction: discord.Interaction,
|
||||
error: app_commands.AppCommandError,
|
||||
) -> None:
|
||||
command_name = interaction.command.qualified_name if interaction.command else "?"
|
||||
logger.warning(
|
||||
"Discord app command failed user={} channel={} cmd={} error={}",
|
||||
interaction.user.id,
|
||||
interaction.channel_id,
|
||||
command_name,
|
||||
error,
|
||||
)
|
||||
|
||||
async def send_outbound(self, msg: OutboundMessage) -> None:
|
||||
"""Send a nanobot outbound message using Discord transport rules."""
|
||||
channel_id = int(msg.chat_id)
|
||||
|
||||
channel = self.get_channel(channel_id)
|
||||
if channel is None:
|
||||
try:
|
||||
channel = await self.fetch_channel(channel_id)
|
||||
except Exception as e:
|
||||
logger.warning("Discord channel {} unavailable: {}", msg.chat_id, e)
|
||||
return
|
||||
|
||||
reference, mention_settings = self._build_reply_context(channel, msg.reply_to)
|
||||
sent_media = False
|
||||
failed_media: list[str] = []
|
||||
|
||||
for index, media_path in enumerate(msg.media or []):
|
||||
if await self._send_file(
|
||||
channel,
|
||||
media_path,
|
||||
reference=reference if index == 0 else None,
|
||||
mention_settings=mention_settings,
|
||||
):
|
||||
sent_media = True
|
||||
else:
|
||||
failed_media.append(Path(media_path).name)
|
||||
|
||||
for index, chunk in enumerate(self._build_chunks(msg.content or "", failed_media, sent_media)):
|
||||
kwargs: dict[str, Any] = {"content": chunk}
|
||||
if index == 0 and reference is not None and not sent_media:
|
||||
kwargs["reference"] = reference
|
||||
kwargs["allowed_mentions"] = mention_settings
|
||||
await channel.send(**kwargs)
|
||||
|
||||
async def _send_file(
|
||||
self,
|
||||
channel: Messageable,
|
||||
file_path: str,
|
||||
*,
|
||||
reference: discord.PartialMessage | None,
|
||||
mention_settings: discord.AllowedMentions,
|
||||
) -> bool:
|
||||
"""Send a file attachment via discord.py."""
|
||||
path = Path(file_path)
|
||||
if not path.is_file():
|
||||
logger.warning("Discord file not found, skipping: {}", file_path)
|
||||
return False
|
||||
|
||||
if path.stat().st_size > MAX_ATTACHMENT_BYTES:
|
||||
logger.warning("Discord file too large (>20MB), skipping: {}", path.name)
|
||||
return False
|
||||
|
||||
try:
|
||||
kwargs: dict[str, Any] = {"file": discord.File(path)}
|
||||
if reference is not None:
|
||||
kwargs["reference"] = reference
|
||||
kwargs["allowed_mentions"] = mention_settings
|
||||
await channel.send(**kwargs)
|
||||
logger.info("Discord file sent: {}", path.name)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Error sending Discord file {}: {}", path.name, e)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _build_chunks(content: str, failed_media: list[str], sent_media: bool) -> list[str]:
|
||||
"""Build outbound text chunks, including attachment-failure fallback text."""
|
||||
chunks = split_message(content, MAX_MESSAGE_LEN)
|
||||
if chunks or not failed_media or sent_media:
|
||||
return chunks
|
||||
fallback = "\n".join(f"[attachment: {name} - send failed]" for name in failed_media)
|
||||
return split_message(fallback, MAX_MESSAGE_LEN)
|
||||
|
||||
@staticmethod
|
||||
def _build_reply_context(
|
||||
channel: Messageable,
|
||||
reply_to: str | None,
|
||||
) -> tuple[discord.PartialMessage | None, discord.AllowedMentions]:
|
||||
"""Build reply context for outbound messages."""
|
||||
mention_settings = discord.AllowedMentions(replied_user=False)
|
||||
if not reply_to:
|
||||
return None, mention_settings
|
||||
try:
|
||||
message_id = int(reply_to)
|
||||
except ValueError:
|
||||
logger.warning("Invalid Discord reply target: {}", reply_to)
|
||||
return None, mention_settings
|
||||
|
||||
return channel.get_partial_message(message_id), mention_settings
|
||||
|
||||
|
||||
class DiscordChannel(BaseChannel):
|
||||
"""Discord channel using Gateway websocket."""
|
||||
"""Discord channel using discord.py."""
|
||||
|
||||
name = "discord"
|
||||
display_name = "Discord"
|
||||
@ -43,353 +244,229 @@ class DiscordChannel(BaseChannel):
|
||||
def default_config(cls) -> dict[str, Any]:
|
||||
return DiscordConfig().model_dump(by_alias=True)
|
||||
|
||||
@staticmethod
|
||||
def _channel_key(channel_or_id: Any) -> str:
|
||||
"""Normalize channel-like objects and ids to a stable string key."""
|
||||
channel_id = getattr(channel_or_id, "id", channel_or_id)
|
||||
return str(channel_id)
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus):
|
||||
if isinstance(config, dict):
|
||||
config = DiscordConfig.model_validate(config)
|
||||
super().__init__(config, bus)
|
||||
self.config: DiscordConfig = config
|
||||
self._ws: websockets.WebSocketClientProtocol | None = None
|
||||
self._seq: int | None = None
|
||||
self._heartbeat_task: asyncio.Task | None = None
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._http: httpx.AsyncClient | None = None
|
||||
self._client: DiscordBotClient | None = None
|
||||
self._typing_tasks: dict[str, asyncio.Task[None]] = {}
|
||||
self._bot_user_id: str | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Discord gateway connection."""
|
||||
"""Start the Discord client."""
|
||||
if not DISCORD_AVAILABLE:
|
||||
logger.error("discord.py not installed. Run: pip install nanobot-ai[discord]")
|
||||
return
|
||||
|
||||
if not self.config.token:
|
||||
logger.error("Discord bot token not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._http = httpx.AsyncClient(timeout=30.0)
|
||||
try:
|
||||
intents = discord.Intents.none()
|
||||
intents.value = self.config.intents
|
||||
self._client = DiscordBotClient(self, intents=intents)
|
||||
except Exception as e:
|
||||
logger.error("Failed to initialize Discord client: {}", e)
|
||||
self._client = None
|
||||
self._running = False
|
||||
return
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
logger.info("Connecting to Discord gateway...")
|
||||
async with websockets.connect(self.config.gateway_url) as ws:
|
||||
self._ws = ws
|
||||
await self._gateway_loop()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning("Discord gateway error: {}", e)
|
||||
if self._running:
|
||||
logger.info("Reconnecting to Discord gateway in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
self._running = True
|
||||
logger.info("Starting Discord client via discord.py...")
|
||||
|
||||
try:
|
||||
await self._client.start(self.config.token)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Discord client startup failed: {}", e)
|
||||
finally:
|
||||
self._running = False
|
||||
await self._reset_runtime_state(close_client=True)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Discord channel."""
|
||||
self._running = False
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
self._heartbeat_task = None
|
||||
for task in self._typing_tasks.values():
|
||||
task.cancel()
|
||||
self._typing_tasks.clear()
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
if self._http:
|
||||
await self._http.aclose()
|
||||
self._http = None
|
||||
await self._reset_runtime_state(close_client=True)
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through Discord REST API, including file attachments."""
|
||||
if not self._http:
|
||||
logger.warning("Discord HTTP client not initialized")
|
||||
"""Send a message through Discord using discord.py."""
|
||||
client = self._client
|
||||
if client is None or not client.is_ready():
|
||||
logger.warning("Discord client not ready; dropping outbound message")
|
||||
return
|
||||
|
||||
url = f"{DISCORD_API_BASE}/channels/{msg.chat_id}/messages"
|
||||
headers = {"Authorization": f"Bot {self.config.token}"}
|
||||
is_progress = bool((msg.metadata or {}).get("_progress"))
|
||||
try:
|
||||
await client.send_outbound(msg)
|
||||
except Exception as e:
|
||||
logger.error("Error sending Discord message: {}", e)
|
||||
finally:
|
||||
if not is_progress:
|
||||
await self._stop_typing(msg.chat_id)
|
||||
|
||||
async def _handle_discord_message(self, message: discord.Message) -> None:
|
||||
"""Handle incoming Discord messages from discord.py."""
|
||||
if message.author.bot:
|
||||
return
|
||||
|
||||
sender_id = str(message.author.id)
|
||||
channel_id = self._channel_key(message.channel)
|
||||
content = message.content or ""
|
||||
|
||||
if not self._should_accept_inbound(message, sender_id, content):
|
||||
return
|
||||
|
||||
media_paths, attachment_markers = await self._download_attachments(message.attachments)
|
||||
full_content = self._compose_inbound_content(content, attachment_markers)
|
||||
metadata = self._build_inbound_metadata(message)
|
||||
|
||||
await self._start_typing(message.channel)
|
||||
|
||||
try:
|
||||
sent_media = False
|
||||
failed_media: list[str] = []
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=channel_id,
|
||||
content=full_content,
|
||||
media=media_paths,
|
||||
metadata=metadata,
|
||||
)
|
||||
except Exception:
|
||||
await self._stop_typing(channel_id)
|
||||
raise
|
||||
|
||||
# Send file attachments first
|
||||
for media_path in msg.media or []:
|
||||
if await self._send_file(url, headers, media_path, reply_to=msg.reply_to):
|
||||
sent_media = True
|
||||
else:
|
||||
failed_media.append(Path(media_path).name)
|
||||
async def _on_message(self, message: discord.Message) -> None:
|
||||
"""Backward-compatible alias for legacy tests/callers."""
|
||||
await self._handle_discord_message(message)
|
||||
|
||||
# Send text content
|
||||
chunks = split_message(msg.content or "", MAX_MESSAGE_LEN)
|
||||
if not chunks and failed_media and not sent_media:
|
||||
chunks = split_message(
|
||||
"\n".join(f"[attachment: {name} - send failed]" for name in failed_media),
|
||||
MAX_MESSAGE_LEN,
|
||||
)
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
payload: dict[str, Any] = {"content": chunk}
|
||||
|
||||
# Let the first successful attachment carry the reply if present.
|
||||
if i == 0 and msg.reply_to and not sent_media:
|
||||
payload["message_reference"] = {"message_id": msg.reply_to}
|
||||
payload["allowed_mentions"] = {"replied_user": False}
|
||||
|
||||
if not await self._send_payload(url, headers, payload):
|
||||
break # Abort remaining chunks on failure
|
||||
finally:
|
||||
await self._stop_typing(msg.chat_id)
|
||||
|
||||
async def _send_payload(
|
||||
self, url: str, headers: dict[str, str], payload: dict[str, Any]
|
||||
) -> bool:
|
||||
"""Send a single Discord API payload with retry on rate-limit. Returns True on success."""
|
||||
for attempt in range(3):
|
||||
try:
|
||||
response = await self._http.post(url, headers=headers, json=payload)
|
||||
if response.status_code == 429:
|
||||
data = response.json()
|
||||
retry_after = float(data.get("retry_after", 1.0))
|
||||
logger.warning("Discord rate limited, retrying in {}s", retry_after)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
response.raise_for_status()
|
||||
return True
|
||||
except Exception as e:
|
||||
if attempt == 2:
|
||||
logger.error("Error sending Discord message: {}", e)
|
||||
else:
|
||||
await asyncio.sleep(1)
|
||||
return False
|
||||
|
||||
async def _send_file(
|
||||
def _should_accept_inbound(
|
||||
self,
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
file_path: str,
|
||||
reply_to: str | None = None,
|
||||
message: discord.Message,
|
||||
sender_id: str,
|
||||
content: str,
|
||||
) -> bool:
|
||||
"""Send a file attachment via Discord REST API using multipart/form-data."""
|
||||
path = Path(file_path)
|
||||
if not path.is_file():
|
||||
logger.warning("Discord file not found, skipping: {}", file_path)
|
||||
return False
|
||||
|
||||
if path.stat().st_size > MAX_ATTACHMENT_BYTES:
|
||||
logger.warning("Discord file too large (>20MB), skipping: {}", path.name)
|
||||
return False
|
||||
|
||||
payload_json: dict[str, Any] = {}
|
||||
if reply_to:
|
||||
payload_json["message_reference"] = {"message_id": reply_to}
|
||||
payload_json["allowed_mentions"] = {"replied_user": False}
|
||||
|
||||
for attempt in range(3):
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
files = {"files[0]": (path.name, f, "application/octet-stream")}
|
||||
data: dict[str, Any] = {}
|
||||
if payload_json:
|
||||
data["payload_json"] = json.dumps(payload_json)
|
||||
response = await self._http.post(
|
||||
url, headers=headers, files=files, data=data
|
||||
)
|
||||
if response.status_code == 429:
|
||||
resp_data = response.json()
|
||||
retry_after = float(resp_data.get("retry_after", 1.0))
|
||||
logger.warning("Discord rate limited, retrying in {}s", retry_after)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
response.raise_for_status()
|
||||
logger.info("Discord file sent: {}", path.name)
|
||||
return True
|
||||
except Exception as e:
|
||||
if attempt == 2:
|
||||
logger.error("Error sending Discord file {}: {}", path.name, e)
|
||||
else:
|
||||
await asyncio.sleep(1)
|
||||
return False
|
||||
|
||||
async def _gateway_loop(self) -> None:
|
||||
"""Main gateway loop: identify, heartbeat, dispatch events."""
|
||||
if not self._ws:
|
||||
return
|
||||
|
||||
async for raw in self._ws:
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid JSON from Discord gateway: {}", raw[:100])
|
||||
continue
|
||||
|
||||
op = data.get("op")
|
||||
event_type = data.get("t")
|
||||
seq = data.get("s")
|
||||
payload = data.get("d")
|
||||
|
||||
if seq is not None:
|
||||
self._seq = seq
|
||||
|
||||
if op == 10:
|
||||
# HELLO: start heartbeat and identify
|
||||
interval_ms = payload.get("heartbeat_interval", 45000)
|
||||
await self._start_heartbeat(interval_ms / 1000)
|
||||
await self._identify()
|
||||
elif op == 0 and event_type == "READY":
|
||||
logger.info("Discord gateway READY")
|
||||
# Capture bot user ID for mention detection
|
||||
user_data = payload.get("user") or {}
|
||||
self._bot_user_id = user_data.get("id")
|
||||
logger.info("Discord bot connected as user {}", self._bot_user_id)
|
||||
elif op == 0 and event_type == "MESSAGE_CREATE":
|
||||
await self._handle_message_create(payload)
|
||||
elif op == 7:
|
||||
# RECONNECT: exit loop to reconnect
|
||||
logger.info("Discord gateway requested reconnect")
|
||||
break
|
||||
elif op == 9:
|
||||
# INVALID_SESSION: reconnect
|
||||
logger.warning("Discord gateway invalid session")
|
||||
break
|
||||
|
||||
async def _identify(self) -> None:
|
||||
"""Send IDENTIFY payload."""
|
||||
if not self._ws:
|
||||
return
|
||||
|
||||
identify = {
|
||||
"op": 2,
|
||||
"d": {
|
||||
"token": self.config.token,
|
||||
"intents": self.config.intents,
|
||||
"properties": {
|
||||
"os": "nanobot",
|
||||
"browser": "nanobot",
|
||||
"device": "nanobot",
|
||||
},
|
||||
},
|
||||
}
|
||||
await self._ws.send(json.dumps(identify))
|
||||
|
||||
async def _start_heartbeat(self, interval_s: float) -> None:
|
||||
"""Start or restart the heartbeat loop."""
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
|
||||
async def heartbeat_loop() -> None:
|
||||
while self._running and self._ws:
|
||||
payload = {"op": 1, "d": self._seq}
|
||||
try:
|
||||
await self._ws.send(json.dumps(payload))
|
||||
except Exception as e:
|
||||
logger.warning("Discord heartbeat failed: {}", e)
|
||||
break
|
||||
await asyncio.sleep(interval_s)
|
||||
|
||||
self._heartbeat_task = asyncio.create_task(heartbeat_loop())
|
||||
|
||||
async def _handle_message_create(self, payload: dict[str, Any]) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
author = payload.get("author") or {}
|
||||
if author.get("bot"):
|
||||
return
|
||||
|
||||
sender_id = str(author.get("id", ""))
|
||||
channel_id = str(payload.get("channel_id", ""))
|
||||
content = payload.get("content") or ""
|
||||
guild_id = payload.get("guild_id")
|
||||
|
||||
if not sender_id or not channel_id:
|
||||
return
|
||||
|
||||
"""Check if inbound Discord message should be processed."""
|
||||
if not self.is_allowed(sender_id):
|
||||
return
|
||||
return False
|
||||
if message.guild is not None and not self._should_respond_in_group(message, content):
|
||||
return False
|
||||
return True
|
||||
|
||||
# Check group channel policy (DMs always respond if is_allowed passes)
|
||||
if guild_id is not None:
|
||||
if not self._should_respond_in_group(payload, content):
|
||||
return
|
||||
|
||||
content_parts = [content] if content else []
|
||||
async def _download_attachments(
|
||||
self,
|
||||
attachments: list[discord.Attachment],
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Download supported attachments and return paths + display markers."""
|
||||
media_paths: list[str] = []
|
||||
markers: list[str] = []
|
||||
media_dir = get_media_dir("discord")
|
||||
|
||||
for attachment in payload.get("attachments") or []:
|
||||
url = attachment.get("url")
|
||||
filename = attachment.get("filename") or "attachment"
|
||||
size = attachment.get("size") or 0
|
||||
if not url or not self._http:
|
||||
continue
|
||||
if size and size > MAX_ATTACHMENT_BYTES:
|
||||
content_parts.append(f"[attachment: {filename} - too large]")
|
||||
for attachment in attachments:
|
||||
filename = attachment.filename or "attachment"
|
||||
if attachment.size and attachment.size > MAX_ATTACHMENT_BYTES:
|
||||
markers.append(f"[attachment: {filename} - too large]")
|
||||
continue
|
||||
try:
|
||||
media_dir.mkdir(parents=True, exist_ok=True)
|
||||
file_path = media_dir / f"{attachment.get('id', 'file')}_{filename.replace('/', '_')}"
|
||||
resp = await self._http.get(url)
|
||||
resp.raise_for_status()
|
||||
file_path.write_bytes(resp.content)
|
||||
safe_name = safe_filename(filename)
|
||||
file_path = media_dir / f"{attachment.id}_{safe_name}"
|
||||
await attachment.save(file_path)
|
||||
media_paths.append(str(file_path))
|
||||
content_parts.append(f"[attachment: {file_path}]")
|
||||
markers.append(f"[attachment: {file_path.name}]")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to download Discord attachment: {}", e)
|
||||
content_parts.append(f"[attachment: {filename} - download failed]")
|
||||
markers.append(f"[attachment: {filename} - download failed]")
|
||||
|
||||
reply_to = (payload.get("referenced_message") or {}).get("id")
|
||||
return media_paths, markers
|
||||
|
||||
await self._start_typing(channel_id)
|
||||
@staticmethod
|
||||
def _compose_inbound_content(content: str, attachment_markers: list[str]) -> str:
|
||||
"""Combine message text with attachment markers."""
|
||||
content_parts = [content] if content else []
|
||||
content_parts.extend(attachment_markers)
|
||||
return "\n".join(part for part in content_parts if part) or "[empty message]"
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=channel_id,
|
||||
content="\n".join(p for p in content_parts if p) or "[empty message]",
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"message_id": str(payload.get("id", "")),
|
||||
"guild_id": guild_id,
|
||||
"reply_to": reply_to,
|
||||
},
|
||||
)
|
||||
@staticmethod
|
||||
def _build_inbound_metadata(message: discord.Message) -> dict[str, str | None]:
|
||||
"""Build metadata for inbound Discord messages."""
|
||||
reply_to = str(message.reference.message_id) if message.reference and message.reference.message_id else None
|
||||
return {
|
||||
"message_id": str(message.id),
|
||||
"guild_id": str(message.guild.id) if message.guild else None,
|
||||
"reply_to": reply_to,
|
||||
}
|
||||
|
||||
def _should_respond_in_group(self, payload: dict[str, Any], content: str) -> bool:
|
||||
"""Check if bot should respond in a group channel based on policy."""
|
||||
def _should_respond_in_group(self, message: discord.Message, content: str) -> bool:
|
||||
"""Check if the bot should respond in a guild channel based on policy."""
|
||||
if self.config.group_policy == "open":
|
||||
return True
|
||||
|
||||
if self.config.group_policy == "mention":
|
||||
# Check if bot was mentioned in the message
|
||||
if self._bot_user_id:
|
||||
# Check mentions array
|
||||
mentions = payload.get("mentions") or []
|
||||
for mention in mentions:
|
||||
if str(mention.get("id")) == self._bot_user_id:
|
||||
return True
|
||||
# Also check content for mention format <@USER_ID>
|
||||
if f"<@{self._bot_user_id}>" in content or f"<@!{self._bot_user_id}>" in content:
|
||||
return True
|
||||
logger.debug("Discord message in {} ignored (bot not mentioned)", payload.get("channel_id"))
|
||||
bot_user_id = self._bot_user_id
|
||||
if bot_user_id is None:
|
||||
logger.debug("Discord message in {} ignored (bot identity unavailable)", message.channel.id)
|
||||
return False
|
||||
|
||||
if any(str(user.id) == bot_user_id for user in message.mentions):
|
||||
return True
|
||||
if f"<@{bot_user_id}>" in content or f"<@!{bot_user_id}>" in content:
|
||||
return True
|
||||
|
||||
logger.debug("Discord message in {} ignored (bot not mentioned)", message.channel.id)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _start_typing(self, channel_id: str) -> None:
|
||||
async def _start_typing(self, channel: Messageable) -> None:
|
||||
"""Start periodic typing indicator for a channel."""
|
||||
channel_id = self._channel_key(channel)
|
||||
await self._stop_typing(channel_id)
|
||||
|
||||
async def typing_loop() -> None:
|
||||
url = f"{DISCORD_API_BASE}/channels/{channel_id}/typing"
|
||||
headers = {"Authorization": f"Bot {self.config.token}"}
|
||||
while self._running:
|
||||
try:
|
||||
await self._http.post(url, headers=headers)
|
||||
async with channel.typing():
|
||||
await asyncio.sleep(TYPING_INTERVAL_S)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug("Discord typing indicator failed for {}: {}", channel_id, e)
|
||||
return
|
||||
await asyncio.sleep(8)
|
||||
|
||||
self._typing_tasks[channel_id] = asyncio.create_task(typing_loop())
|
||||
|
||||
async def _stop_typing(self, channel_id: str) -> None:
|
||||
"""Stop typing indicator for a channel."""
|
||||
task = self._typing_tasks.pop(channel_id, None)
|
||||
if task:
|
||||
task.cancel()
|
||||
task = self._typing_tasks.pop(self._channel_key(channel_id), None)
|
||||
if task is None:
|
||||
return
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _cancel_all_typing(self) -> None:
|
||||
"""Stop all typing tasks."""
|
||||
channel_ids = list(self._typing_tasks)
|
||||
for channel_id in channel_ids:
|
||||
await self._stop_typing(channel_id)
|
||||
|
||||
async def _reset_runtime_state(self, close_client: bool) -> None:
|
||||
"""Reset client and typing state."""
|
||||
await self._cancel_all_typing()
|
||||
if close_client and self._client is not None and not self._client.is_closed():
|
||||
try:
|
||||
await self._client.close()
|
||||
except Exception as e:
|
||||
logger.warning("Discord client close failed: {}", e)
|
||||
self._client = None
|
||||
self._bot_user_id = None
|
||||
|
||||
@ -84,6 +84,16 @@ async def cmd_new(ctx: CommandContext) -> OutboundMessage:
|
||||
|
||||
async def cmd_help(ctx: CommandContext) -> OutboundMessage:
|
||||
"""Return available slash commands."""
|
||||
return OutboundMessage(
|
||||
channel=ctx.msg.channel,
|
||||
chat_id=ctx.msg.chat_id,
|
||||
content=build_help_text(),
|
||||
metadata={"render_as": "text"},
|
||||
)
|
||||
|
||||
|
||||
def build_help_text() -> str:
|
||||
"""Build canonical help text shared across channels."""
|
||||
lines = [
|
||||
"🐈 nanobot commands:",
|
||||
"/new — Start a new conversation",
|
||||
@ -92,12 +102,7 @@ async def cmd_help(ctx: CommandContext) -> OutboundMessage:
|
||||
"/status — Show bot status",
|
||||
"/help — Show available commands",
|
||||
]
|
||||
return OutboundMessage(
|
||||
channel=ctx.msg.channel,
|
||||
chat_id=ctx.msg.chat_id,
|
||||
content="\n".join(lines),
|
||||
metadata={"render_as": "text"},
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def register_builtin_commands(router: CommandRouter) -> None:
|
||||
|
||||
@ -67,6 +67,9 @@ matrix = [
|
||||
"mistune>=3.0.0,<4.0.0",
|
||||
"nh3>=0.2.17,<1.0.0",
|
||||
]
|
||||
discord = [
|
||||
"discord.py>=2.5.2,<3.0.0",
|
||||
]
|
||||
langsmith = [
|
||||
"langsmith>=0.1.0",
|
||||
]
|
||||
|
||||
676
tests/channels/test_discord_channel.py
Normal file
676
tests/channels/test_discord_channel.py
Normal file
@ -0,0 +1,676 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import discord
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.discord import DiscordBotClient, DiscordChannel, DiscordConfig
|
||||
from nanobot.command.builtin import build_help_text
|
||||
|
||||
|
||||
# Minimal Discord client test double used to control startup/readiness behavior.
|
||||
class _FakeDiscordClient:
|
||||
instances: list["_FakeDiscordClient"] = []
|
||||
start_error: Exception | None = None
|
||||
|
||||
def __init__(self, owner, *, intents) -> None:
|
||||
self.owner = owner
|
||||
self.intents = intents
|
||||
self.closed = False
|
||||
self.ready = True
|
||||
self.channels: dict[int, object] = {}
|
||||
self.user = SimpleNamespace(id=999)
|
||||
self.__class__.instances.append(self)
|
||||
|
||||
async def start(self, token: str) -> None:
|
||||
self.token = token
|
||||
if self.__class__.start_error is not None:
|
||||
raise self.__class__.start_error
|
||||
|
||||
async def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
return self.closed
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
return self.ready
|
||||
|
||||
def get_channel(self, channel_id: int):
|
||||
return self.channels.get(channel_id)
|
||||
|
||||
async def send_outbound(self, msg: OutboundMessage) -> None:
|
||||
channel = self.get_channel(int(msg.chat_id))
|
||||
if channel is None:
|
||||
return
|
||||
await channel.send(content=msg.content)
|
||||
|
||||
|
||||
class _FakeAttachment:
|
||||
# Attachment double that can simulate successful or failing save() calls.
|
||||
def __init__(self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False) -> None:
|
||||
self.id = attachment_id
|
||||
self.filename = filename
|
||||
self.size = size
|
||||
self._fail = fail
|
||||
|
||||
async def save(self, path: str | Path) -> None:
|
||||
if self._fail:
|
||||
raise RuntimeError("save failed")
|
||||
Path(path).write_bytes(b"attachment")
|
||||
|
||||
|
||||
class _FakePartialMessage:
|
||||
# Lightweight stand-in for Discord partial message references used in replies.
|
||||
def __init__(self, message_id: int) -> None:
|
||||
self.id = message_id
|
||||
|
||||
|
||||
class _FakeChannel:
|
||||
# Channel double that records outbound payloads and typing activity.
|
||||
def __init__(self, channel_id: int = 123) -> None:
|
||||
self.id = channel_id
|
||||
self.sent_payloads: list[dict] = []
|
||||
self.trigger_typing_calls = 0
|
||||
self.typing_enter_hook = None
|
||||
|
||||
async def send(self, **kwargs) -> None:
|
||||
payload = dict(kwargs)
|
||||
if "file" in payload:
|
||||
payload["file_name"] = payload["file"].filename
|
||||
del payload["file"]
|
||||
self.sent_payloads.append(payload)
|
||||
|
||||
def get_partial_message(self, message_id: int) -> _FakePartialMessage:
|
||||
return _FakePartialMessage(message_id)
|
||||
|
||||
def typing(self):
|
||||
channel = self
|
||||
|
||||
class _TypingContext:
|
||||
async def __aenter__(self):
|
||||
channel.trigger_typing_calls += 1
|
||||
if channel.typing_enter_hook is not None:
|
||||
await channel.typing_enter_hook()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
return _TypingContext()
|
||||
|
||||
|
||||
class _FakeInteractionResponse:
|
||||
def __init__(self) -> None:
|
||||
self.messages: list[dict] = []
|
||||
self._done = False
|
||||
|
||||
async def send_message(self, content: str, *, ephemeral: bool = False) -> None:
|
||||
self.messages.append({"content": content, "ephemeral": ephemeral})
|
||||
self._done = True
|
||||
|
||||
def is_done(self) -> bool:
|
||||
return self._done
|
||||
|
||||
|
||||
def _make_interaction(
|
||||
*,
|
||||
user_id: int = 123,
|
||||
channel_id: int | None = 456,
|
||||
guild_id: int | None = None,
|
||||
interaction_id: int = 999,
|
||||
):
|
||||
return SimpleNamespace(
|
||||
user=SimpleNamespace(id=user_id),
|
||||
channel_id=channel_id,
|
||||
guild_id=guild_id,
|
||||
id=interaction_id,
|
||||
command=SimpleNamespace(qualified_name="new"),
|
||||
response=_FakeInteractionResponse(),
|
||||
)
|
||||
|
||||
|
||||
def _make_message(
|
||||
*,
|
||||
author_id: int = 123,
|
||||
author_bot: bool = False,
|
||||
channel_id: int = 456,
|
||||
message_id: int = 789,
|
||||
content: str = "hello",
|
||||
guild_id: int | None = None,
|
||||
mentions: list[object] | None = None,
|
||||
attachments: list[object] | None = None,
|
||||
reply_to: int | None = None,
|
||||
):
|
||||
# Factory for incoming Discord message objects with optional guild/reply/attachments.
|
||||
guild = SimpleNamespace(id=guild_id) if guild_id is not None else None
|
||||
reference = SimpleNamespace(message_id=reply_to) if reply_to is not None else None
|
||||
return SimpleNamespace(
|
||||
author=SimpleNamespace(id=author_id, bot=author_bot),
|
||||
channel=_FakeChannel(channel_id),
|
||||
content=content,
|
||||
guild=guild,
|
||||
mentions=mentions or [],
|
||||
attachments=attachments or [],
|
||||
reference=reference,
|
||||
id=message_id,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_returns_when_token_missing() -> None:
|
||||
# If no token is configured, startup should no-op and leave channel stopped.
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
|
||||
await channel.start()
|
||||
|
||||
assert channel.is_running is False
|
||||
assert channel._client is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_returns_when_discord_dependency_missing(monkeypatch) -> None:
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, token="token", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.channels.discord.DISCORD_AVAILABLE", False)
|
||||
|
||||
await channel.start()
|
||||
|
||||
assert channel.is_running is False
|
||||
assert channel._client is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_handles_client_construction_failure(monkeypatch) -> None:
|
||||
# Construction errors from the Discord client should be swallowed and keep state clean.
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, token="token", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
|
||||
def _boom(owner, *, intents):
|
||||
raise RuntimeError("bad client")
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _boom)
|
||||
|
||||
await channel.start()
|
||||
|
||||
assert channel.is_running is False
|
||||
assert channel._client is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_handles_client_start_failure(monkeypatch) -> None:
|
||||
# If client.start fails, the partially created client should be closed and detached.
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, token="token", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
|
||||
_FakeDiscordClient.instances.clear()
|
||||
_FakeDiscordClient.start_error = RuntimeError("connect failed")
|
||||
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
|
||||
|
||||
await channel.start()
|
||||
|
||||
assert channel.is_running is False
|
||||
assert channel._client is None
|
||||
assert _FakeDiscordClient.instances[0].intents.value == channel.config.intents
|
||||
assert _FakeDiscordClient.instances[0].closed is True
|
||||
|
||||
_FakeDiscordClient.start_error = None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_is_safe_after_partial_start(monkeypatch) -> None:
|
||||
# stop() should close/discard the client even when startup was only partially completed.
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, token="token", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
client = _FakeDiscordClient(channel, intents=None)
|
||||
channel._client = client
|
||||
channel._running = True
|
||||
|
||||
await channel.stop()
|
||||
|
||||
assert channel.is_running is False
|
||||
assert client.closed is True
|
||||
assert channel._client is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_ignores_bot_messages() -> None:
|
||||
# Incoming bot-authored messages must be ignored to prevent feedback loops.
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
channel._handle_message = lambda **kwargs: handled.append(kwargs) # type: ignore[method-assign]
|
||||
|
||||
await channel._on_message(_make_message(author_bot=True))
|
||||
|
||||
assert handled == []
|
||||
|
||||
# If inbound handling raises, typing should be stopped for that channel.
|
||||
async def fail_handle(**kwargs) -> None:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
channel._handle_message = fail_handle # type: ignore[method-assign]
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
await channel._on_message(_make_message(author_id=123, channel_id=456))
|
||||
|
||||
assert channel._typing_tasks == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_accepts_allowlisted_dm() -> None:
|
||||
# Allowed direct messages should be forwarded with normalized metadata.
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
|
||||
await channel._on_message(_make_message(author_id=123, channel_id=456, message_id=789))
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["chat_id"] == "456"
|
||||
assert handled[0]["metadata"] == {"message_id": "789", "guild_id": None, "reply_to": None}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_ignores_unmentioned_guild_message() -> None:
|
||||
# With mention-only group policy, guild messages without a bot mention are dropped.
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._bot_user_id = "999"
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
|
||||
await channel._on_message(_make_message(guild_id=1, content="hello everyone"))
|
||||
|
||||
assert handled == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_accepts_mentioned_guild_message() -> None:
|
||||
# Mentioned guild messages should be accepted and preserve reply threading metadata.
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._bot_user_id = "999"
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
|
||||
await channel._on_message(
|
||||
_make_message(
|
||||
guild_id=1,
|
||||
content="<@999> hello",
|
||||
mentions=[SimpleNamespace(id=999)],
|
||||
reply_to=321,
|
||||
)
|
||||
)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["metadata"]["reply_to"] == "321"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_downloads_attachments(tmp_path, monkeypatch) -> None:
|
||||
# Attachment downloads should be saved and referenced in forwarded content/media.
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path)
|
||||
|
||||
await channel._on_message(
|
||||
_make_message(
|
||||
attachments=[_FakeAttachment(12, "photo.png")],
|
||||
content="see file",
|
||||
)
|
||||
)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["media"] == [str(tmp_path / "12_photo.png")]
|
||||
assert "[attachment:" in handled[0]["content"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_marks_failed_attachment_download(tmp_path, monkeypatch) -> None:
|
||||
# Failed attachment downloads should emit a readable placeholder and no media path.
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path)
|
||||
|
||||
await channel._on_message(
|
||||
_make_message(
|
||||
attachments=[_FakeAttachment(12, "photo.png", fail=True)],
|
||||
content="",
|
||||
)
|
||||
)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["media"] == []
|
||||
assert handled[0]["content"] == "[attachment: photo.png - download failed]"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_warns_when_client_not_ready() -> None:
|
||||
# Sending without a running/ready client should be a safe no-op.
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
|
||||
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
|
||||
|
||||
assert channel._typing_tasks == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_skips_when_channel_not_cached() -> None:
|
||||
# Outbound sends should be skipped when the destination channel is not resolvable.
|
||||
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
client = DiscordBotClient(owner, intents=discord.Intents.none())
|
||||
fetch_calls: list[int] = []
|
||||
|
||||
async def fetch_channel(channel_id: int):
|
||||
fetch_calls.append(channel_id)
|
||||
raise RuntimeError("not found")
|
||||
|
||||
client.fetch_channel = fetch_channel # type: ignore[method-assign]
|
||||
|
||||
await client.send_outbound(OutboundMessage(channel="discord", chat_id="123", content="hello"))
|
||||
|
||||
assert client.get_channel(123) is None
|
||||
assert fetch_calls == [123]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_fetches_channel_when_not_cached() -> None:
|
||||
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
client = DiscordBotClient(owner, intents=discord.Intents.none())
|
||||
target = _FakeChannel(channel_id=123)
|
||||
|
||||
async def fetch_channel(channel_id: int):
|
||||
return target if channel_id == 123 else None
|
||||
|
||||
client.fetch_channel = fetch_channel # type: ignore[method-assign]
|
||||
|
||||
await client.send_outbound(OutboundMessage(channel="discord", chat_id="123", content="hello"))
|
||||
|
||||
assert target.sent_payloads == [{"content": "hello"}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slash_new_forwards_when_user_is_allowlisted() -> None:
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
client = DiscordBotClient(channel, intents=discord.Intents.none())
|
||||
interaction = _make_interaction(user_id=123, channel_id=456, interaction_id=321)
|
||||
|
||||
new_cmd = client.tree.get_command("new")
|
||||
assert new_cmd is not None
|
||||
await new_cmd.callback(interaction)
|
||||
|
||||
assert interaction.response.messages == [
|
||||
{"content": "Processing /new...", "ephemeral": True}
|
||||
]
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["content"] == "/new"
|
||||
assert handled[0]["sender_id"] == "123"
|
||||
assert handled[0]["chat_id"] == "456"
|
||||
assert handled[0]["metadata"]["interaction_id"] == "321"
|
||||
assert handled[0]["metadata"]["is_slash_command"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slash_new_is_blocked_for_disallowed_user() -> None:
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["999"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
client = DiscordBotClient(channel, intents=discord.Intents.none())
|
||||
interaction = _make_interaction(user_id=123, channel_id=456)
|
||||
|
||||
new_cmd = client.tree.get_command("new")
|
||||
assert new_cmd is not None
|
||||
await new_cmd.callback(interaction)
|
||||
|
||||
assert interaction.response.messages == [
|
||||
{"content": "You are not allowed to use this bot.", "ephemeral": True}
|
||||
]
|
||||
assert handled == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize("slash_name", ["stop", "restart", "status"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_slash_commands_forward_via_handle_message(slash_name: str) -> None:
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
client = DiscordBotClient(channel, intents=discord.Intents.none())
|
||||
interaction = _make_interaction()
|
||||
interaction.command.qualified_name = slash_name
|
||||
|
||||
cmd = client.tree.get_command(slash_name)
|
||||
assert cmd is not None
|
||||
await cmd.callback(interaction)
|
||||
|
||||
assert interaction.response.messages == [
|
||||
{"content": f"Processing /{slash_name}...", "ephemeral": True}
|
||||
]
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["content"] == f"/{slash_name}"
|
||||
assert handled[0]["metadata"]["is_slash_command"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slash_help_returns_ephemeral_help_text() -> None:
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
client = DiscordBotClient(channel, intents=discord.Intents.none())
|
||||
interaction = _make_interaction()
|
||||
interaction.command.qualified_name = "help"
|
||||
|
||||
help_cmd = client.tree.get_command("help")
|
||||
assert help_cmd is not None
|
||||
await help_cmd.callback(interaction)
|
||||
|
||||
assert interaction.response.messages == [
|
||||
{"content": build_help_text(), "ephemeral": True}
|
||||
]
|
||||
assert handled == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_send_outbound_chunks_text_replies_and_uploads_files(tmp_path) -> None:
|
||||
# Outbound payloads should upload files, attach reply references, and chunk long text.
|
||||
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
client = DiscordBotClient(owner, intents=discord.Intents.none())
|
||||
target = _FakeChannel(channel_id=123)
|
||||
client.get_channel = lambda channel_id: target if channel_id == 123 else None # type: ignore[method-assign]
|
||||
|
||||
file_path = tmp_path / "demo.txt"
|
||||
file_path.write_text("hi")
|
||||
|
||||
await client.send_outbound(
|
||||
OutboundMessage(
|
||||
channel="discord",
|
||||
chat_id="123",
|
||||
content="a" * 2100,
|
||||
reply_to="55",
|
||||
media=[str(file_path)],
|
||||
)
|
||||
)
|
||||
|
||||
assert len(target.sent_payloads) == 3
|
||||
assert target.sent_payloads[0]["file_name"] == "demo.txt"
|
||||
assert target.sent_payloads[0]["reference"].id == 55
|
||||
assert target.sent_payloads[1]["content"] == "a" * 2000
|
||||
assert target.sent_payloads[2]["content"] == "a" * 100
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_send_outbound_reports_failed_attachments_when_no_text(tmp_path) -> None:
|
||||
# If all attachment sends fail and no text exists, emit a failure placeholder message.
|
||||
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
client = DiscordBotClient(owner, intents=discord.Intents.none())
|
||||
target = _FakeChannel(channel_id=123)
|
||||
client.get_channel = lambda channel_id: target if channel_id == 123 else None # type: ignore[method-assign]
|
||||
|
||||
missing_file = tmp_path / "missing.txt"
|
||||
|
||||
await client.send_outbound(
|
||||
OutboundMessage(
|
||||
channel="discord",
|
||||
chat_id="123",
|
||||
content="",
|
||||
media=[str(missing_file)],
|
||||
)
|
||||
)
|
||||
|
||||
assert target.sent_payloads == [{"content": "[attachment: missing.txt - send failed]"}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_stops_typing_after_send() -> None:
|
||||
# Active typing indicators should be cancelled/cleared after a successful send.
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
client = _FakeDiscordClient(channel, intents=None)
|
||||
channel._client = client
|
||||
channel._running = True
|
||||
|
||||
start = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def slow_typing() -> None:
|
||||
start.set()
|
||||
await release.wait()
|
||||
|
||||
typing_channel = _FakeChannel(channel_id=123)
|
||||
typing_channel.typing_enter_hook = slow_typing
|
||||
|
||||
await channel._start_typing(typing_channel)
|
||||
await start.wait()
|
||||
|
||||
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
|
||||
release.set()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert channel._typing_tasks == {}
|
||||
|
||||
# Progress messages should keep typing active until a final (non-progress) send.
|
||||
start = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def slow_typing_progress() -> None:
|
||||
start.set()
|
||||
await release.wait()
|
||||
|
||||
typing_channel = _FakeChannel(channel_id=123)
|
||||
typing_channel.typing_enter_hook = slow_typing_progress
|
||||
|
||||
await channel._start_typing(typing_channel)
|
||||
await start.wait()
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="discord",
|
||||
chat_id="123",
|
||||
content="progress",
|
||||
metadata={"_progress": True},
|
||||
)
|
||||
)
|
||||
|
||||
assert "123" in channel._typing_tasks
|
||||
|
||||
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="final"))
|
||||
release.set()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert channel._typing_tasks == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_typing_uses_typing_context_when_trigger_typing_missing() -> None:
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
channel._running = True
|
||||
|
||||
entered = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
|
||||
class _TypingCtx:
|
||||
async def __aenter__(self):
|
||||
entered.set()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
class _NoTriggerChannel:
|
||||
def __init__(self, channel_id: int = 123) -> None:
|
||||
self.id = channel_id
|
||||
|
||||
def typing(self):
|
||||
async def _waiter():
|
||||
await release.wait()
|
||||
# Hold the loop so task remains active until explicitly stopped.
|
||||
class _Ctx(_TypingCtx):
|
||||
async def __aenter__(self):
|
||||
await super().__aenter__()
|
||||
await _waiter()
|
||||
return _Ctx()
|
||||
|
||||
typing_channel = _NoTriggerChannel(channel_id=123)
|
||||
await channel._start_typing(typing_channel) # type: ignore[arg-type]
|
||||
await entered.wait()
|
||||
|
||||
assert "123" in channel._typing_tasks
|
||||
|
||||
await channel._stop_typing("123")
|
||||
release.set()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert channel._typing_tasks == {}
|
||||
Loading…
x
Reference in New Issue
Block a user