mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-26 03:22:38 +00:00
Merge PR #3397: fix(discord): full thread support with session isolation and allowlist enforcement
fix(discord): full thread support with session isolation and allowlist enforcement
This commit is contained in:
commit
b8932bc041
@ -147,7 +147,7 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso
|
|||||||
> - `"open"` — Respond to all messages
|
> - `"open"` — Respond to all messages
|
||||||
> DMs always respond when the sender is in `allowFrom`.
|
> DMs always respond when the sender is in `allowFrom`.
|
||||||
> - If you set group policy to open create new threads as private threads and then @ the bot into it. Otherwise the thread itself and the channel in which you spawned it will spawn a bot session.
|
> - If you set group policy to open create new threads as private threads and then @ the bot into it. Otherwise the thread itself and the channel in which you spawned it will spawn a bot session.
|
||||||
> `allowChannels` restricts the bot to specific Discord channel IDs. Empty (default) means respond in every channel the bot can see. Example: `["1234567890", "0987654321"]`. The filter applies after `allowFrom`, so both must pass.
|
> `allowChannels` restricts the bot to specific Discord channel IDs. Empty (default) means respond in every channel the bot can see. Example: `["1234567890", "0987654321"]`. The filter applies after `allowFrom`, so both must pass. Discord threads under an allowed parent channel are also allowed; for Forum channels, allowing the parent Forum channel allows all threads/posts in that forum.
|
||||||
> `streaming` defaults to `true`. Disable it only if you explicitly want non-streaming replies.
|
> `streaming` defaults to `true`. Disable it only if you explicitly want non-streaming replies.
|
||||||
|
|
||||||
**5. Invite the bot**
|
**5. Invite the bot**
|
||||||
|
|||||||
@ -434,6 +434,11 @@ class AgentLoop:
|
|||||||
|
|
||||||
return strip_think(text) or None
|
return strip_think(text) or None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _runtime_chat_id(msg: InboundMessage) -> str:
|
||||||
|
"""Return the chat id shown in runtime metadata for the model."""
|
||||||
|
return str(msg.metadata.get("context_chat_id") or msg.chat_id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _tool_hint(tool_calls: list) -> str:
|
def _tool_hint(tool_calls: list) -> str:
|
||||||
"""Format tool calls as concise hints with smart abbreviation."""
|
"""Format tool calls as concise hints with smart abbreviation."""
|
||||||
@ -555,7 +560,7 @@ class AgentLoop:
|
|||||||
user_content = self.context._build_user_content(content, media)
|
user_content = self.context._build_user_content(content, media)
|
||||||
runtime_ctx = self.context._build_runtime_context(
|
runtime_ctx = self.context._build_runtime_context(
|
||||||
pending_msg.channel,
|
pending_msg.channel,
|
||||||
pending_msg.chat_id,
|
self._runtime_chat_id(pending_msg),
|
||||||
self.context.timezone,
|
self.context.timezone,
|
||||||
)
|
)
|
||||||
if isinstance(user_content, str):
|
if isinstance(user_content, str):
|
||||||
@ -986,7 +991,7 @@ class AgentLoop:
|
|||||||
session_summary=pending,
|
session_summary=pending,
|
||||||
media=msg.media if msg.media else None,
|
media=msg.media if msg.media else None,
|
||||||
channel=msg.channel,
|
channel=msg.channel,
|
||||||
chat_id=msg.chat_id,
|
chat_id=self._runtime_chat_id(msg),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _bus_progress(
|
async def _bus_progress(
|
||||||
|
|||||||
@ -95,6 +95,15 @@ if DISCORD_AVAILABLE:
|
|||||||
async def on_message(self, message: discord.Message) -> None:
|
async def on_message(self, message: discord.Message) -> None:
|
||||||
await self._channel._handle_discord_message(message)
|
await self._channel._handle_discord_message(message)
|
||||||
|
|
||||||
|
async def on_thread_delete(self, thread: discord.Thread) -> None:
|
||||||
|
self._channel._forget_channel(thread)
|
||||||
|
|
||||||
|
async def on_thread_update(self, before: discord.Thread, after: discord.Thread) -> None:
|
||||||
|
if getattr(after, "archived", False):
|
||||||
|
self._channel._forget_channel(after)
|
||||||
|
else:
|
||||||
|
self._channel._remember_channel(after)
|
||||||
|
|
||||||
async def _reply_ephemeral(self, interaction: discord.Interaction, text: str) -> bool:
|
async def _reply_ephemeral(self, interaction: discord.Interaction, text: str) -> bool:
|
||||||
"""Send an ephemeral interaction response and report success."""
|
"""Send an ephemeral interaction response and report success."""
|
||||||
try:
|
try:
|
||||||
@ -104,6 +113,37 @@ if DISCORD_AVAILABLE:
|
|||||||
logger.warning("Discord interaction response failed: {}", e)
|
logger.warning("Discord interaction response failed: {}", e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def _resolve_interaction_channel(
|
||||||
|
self,
|
||||||
|
interaction: discord.Interaction,
|
||||||
|
) -> Any | None:
|
||||||
|
channel_id = interaction.channel_id
|
||||||
|
if channel_id is None:
|
||||||
|
return None
|
||||||
|
channel = getattr(interaction, "channel", None) or self.get_channel(channel_id)
|
||||||
|
if channel is None:
|
||||||
|
try:
|
||||||
|
channel = await self.fetch_channel(channel_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Discord interaction channel {} unavailable: {}", channel_id, e)
|
||||||
|
return None
|
||||||
|
self._channel._remember_channel(channel)
|
||||||
|
return channel
|
||||||
|
|
||||||
|
async def _interaction_channel_allowed(
|
||||||
|
self,
|
||||||
|
interaction: discord.Interaction,
|
||||||
|
channel: Any | None,
|
||||||
|
) -> bool:
|
||||||
|
allow_channels = self._channel.config.allow_channels
|
||||||
|
if not allow_channels:
|
||||||
|
return True
|
||||||
|
if channel is None:
|
||||||
|
channel_id = interaction.channel_id
|
||||||
|
return channel_id is not None and str(channel_id) in allow_channels
|
||||||
|
channel_ids = self._channel._channel_allow_keys(channel)
|
||||||
|
return not channel_ids.isdisjoint(allow_channels)
|
||||||
|
|
||||||
async def _forward_slash_command(
|
async def _forward_slash_command(
|
||||||
self,
|
self,
|
||||||
interaction: discord.Interaction,
|
interaction: discord.Interaction,
|
||||||
@ -120,17 +160,33 @@ if DISCORD_AVAILABLE:
|
|||||||
await self._reply_ephemeral(interaction, "You are not allowed to use this bot.")
|
await self._reply_ephemeral(interaction, "You are not allowed to use this bot.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
channel = await self._resolve_interaction_channel(interaction)
|
||||||
|
if not await self._interaction_channel_allowed(interaction, channel):
|
||||||
|
await self._reply_ephemeral(interaction, "This channel is not allowed for this bot.")
|
||||||
|
return
|
||||||
|
|
||||||
await self._reply_ephemeral(interaction, f"Processing {command_text}...")
|
await self._reply_ephemeral(interaction, f"Processing {command_text}...")
|
||||||
|
|
||||||
|
metadata: dict[str, Any] = {
|
||||||
|
"interaction_id": str(interaction.id),
|
||||||
|
"guild_id": str(interaction.guild_id) if interaction.guild_id else None,
|
||||||
|
"is_slash_command": True,
|
||||||
|
}
|
||||||
|
session_key = None
|
||||||
|
if channel is not None:
|
||||||
|
parent_channel_id = self._channel._channel_parent_key(channel)
|
||||||
|
if parent_channel_id is not None:
|
||||||
|
metadata["parent_channel_id"] = parent_channel_id
|
||||||
|
metadata["context_chat_id"] = parent_channel_id
|
||||||
|
metadata["thread_id"] = str(channel_id)
|
||||||
|
session_key = f"{self._channel.name}:{parent_channel_id}:thread:{channel_id}"
|
||||||
|
|
||||||
await self._channel._handle_message(
|
await self._channel._handle_message(
|
||||||
sender_id=sender_id,
|
sender_id=sender_id,
|
||||||
chat_id=str(channel_id),
|
chat_id=str(channel_id),
|
||||||
content=command_text,
|
content=command_text,
|
||||||
metadata={
|
metadata=metadata,
|
||||||
"interaction_id": str(interaction.id),
|
session_key=session_key,
|
||||||
"guild_id": str(interaction.guild_id) if interaction.guild_id else None,
|
|
||||||
"is_slash_command": True,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _register_app_commands(self) -> None:
|
def _register_app_commands(self) -> None:
|
||||||
@ -156,6 +212,10 @@ if DISCORD_AVAILABLE:
|
|||||||
if not self._channel.is_allowed(sender_id):
|
if not self._channel.is_allowed(sender_id):
|
||||||
await self._reply_ephemeral(interaction, "You are not allowed to use this bot.")
|
await self._reply_ephemeral(interaction, "You are not allowed to use this bot.")
|
||||||
return
|
return
|
||||||
|
channel = await self._resolve_interaction_channel(interaction)
|
||||||
|
if not await self._interaction_channel_allowed(interaction, channel):
|
||||||
|
await self._reply_ephemeral(interaction, "This channel is not allowed for this bot.")
|
||||||
|
return
|
||||||
await self._reply_ephemeral(interaction, build_help_text())
|
await self._reply_ephemeral(interaction, build_help_text())
|
||||||
|
|
||||||
@self.tree.error
|
@self.tree.error
|
||||||
@ -176,7 +236,7 @@ if DISCORD_AVAILABLE:
|
|||||||
"""Send a nanobot outbound message using Discord transport rules."""
|
"""Send a nanobot outbound message using Discord transport rules."""
|
||||||
channel_id = int(msg.chat_id)
|
channel_id = int(msg.chat_id)
|
||||||
|
|
||||||
channel = self.get_channel(channel_id)
|
channel = self._channel._known_channels.get(msg.chat_id) or self.get_channel(channel_id)
|
||||||
if channel is None:
|
if channel is None:
|
||||||
try:
|
try:
|
||||||
channel = await self.fetch_channel(channel_id)
|
channel = await self.fetch_channel(channel_id)
|
||||||
@ -282,6 +342,25 @@ class DiscordChannel(BaseChannel):
|
|||||||
channel_id = getattr(channel_or_id, "id", channel_or_id)
|
channel_id = getattr(channel_or_id, "id", channel_or_id)
|
||||||
return str(channel_id)
|
return str(channel_id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _channel_allow_keys(cls, channel: Any) -> set[str]:
|
||||||
|
"""Return channel IDs that can satisfy allow_channels for this channel."""
|
||||||
|
keys = {cls._channel_key(channel)}
|
||||||
|
if parent_key := cls._channel_parent_key(channel):
|
||||||
|
keys.add(parent_key)
|
||||||
|
return keys
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _channel_parent_key(cls, channel: Any) -> str | None:
|
||||||
|
"""Return the parent channel key for a Discord thread-like channel."""
|
||||||
|
parent_id = getattr(channel, "parent_id", None)
|
||||||
|
if parent_id is not None:
|
||||||
|
return cls._channel_key(parent_id)
|
||||||
|
parent = getattr(channel, "parent", None)
|
||||||
|
if parent is not None:
|
||||||
|
return cls._channel_key(parent)
|
||||||
|
return None
|
||||||
|
|
||||||
def __init__(self, config: Any, bus: MessageBus):
|
def __init__(self, config: Any, bus: MessageBus):
|
||||||
if isinstance(config, dict):
|
if isinstance(config, dict):
|
||||||
config = DiscordConfig.model_validate(config)
|
config = DiscordConfig.model_validate(config)
|
||||||
@ -293,6 +372,13 @@ class DiscordChannel(BaseChannel):
|
|||||||
self._pending_reactions: dict[str, Any] = {} # chat_id -> message object
|
self._pending_reactions: dict[str, Any] = {} # chat_id -> message object
|
||||||
self._working_emoji_tasks: dict[str, asyncio.Task[None]] = {}
|
self._working_emoji_tasks: dict[str, asyncio.Task[None]] = {}
|
||||||
self._stream_bufs: dict[str, _StreamBuf] = {}
|
self._stream_bufs: dict[str, _StreamBuf] = {}
|
||||||
|
self._known_channels: dict[str, Any] = {}
|
||||||
|
|
||||||
|
def _remember_channel(self, channel: Any) -> None:
|
||||||
|
self._known_channels[self._channel_key(channel)] = channel
|
||||||
|
|
||||||
|
def _forget_channel(self, channel_or_id: Any) -> None:
|
||||||
|
self._known_channels.pop(self._channel_key(channel_or_id), None)
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the Discord client."""
|
"""Start the Discord client."""
|
||||||
@ -443,9 +529,12 @@ class DiscordChannel(BaseChannel):
|
|||||||
"""
|
"""
|
||||||
if self._bot_user_id is not None and str(message.author.id) == self._bot_user_id:
|
if self._bot_user_id is not None and str(message.author.id) == self._bot_user_id:
|
||||||
return
|
return
|
||||||
|
if self._is_system_message(message):
|
||||||
|
return
|
||||||
|
|
||||||
sender_id = str(message.author.id)
|
sender_id = str(message.author.id)
|
||||||
channel_id = self._channel_key(message.channel)
|
channel_id = self._channel_key(message.channel)
|
||||||
|
self._remember_channel(message.channel)
|
||||||
content = message.content or ""
|
content = message.content or ""
|
||||||
|
|
||||||
if not self._should_accept_inbound(message, sender_id, content):
|
if not self._should_accept_inbound(message, sender_id, content):
|
||||||
@ -454,6 +543,13 @@ class DiscordChannel(BaseChannel):
|
|||||||
media_paths, attachment_markers = await self._download_attachments(message.attachments)
|
media_paths, attachment_markers = await self._download_attachments(message.attachments)
|
||||||
full_content = self._compose_inbound_content(content, attachment_markers)
|
full_content = self._compose_inbound_content(content, attachment_markers)
|
||||||
metadata = self._build_inbound_metadata(message)
|
metadata = self._build_inbound_metadata(message)
|
||||||
|
parent_channel_id = self._channel_parent_key(message.channel)
|
||||||
|
session_key = None
|
||||||
|
if parent_channel_id is not None:
|
||||||
|
metadata["parent_channel_id"] = parent_channel_id
|
||||||
|
metadata["context_chat_id"] = parent_channel_id
|
||||||
|
metadata["thread_id"] = channel_id
|
||||||
|
session_key = f"{self.name}:{parent_channel_id}:thread:{channel_id}"
|
||||||
|
|
||||||
await self._start_typing(message.channel)
|
await self._start_typing(message.channel)
|
||||||
|
|
||||||
@ -481,6 +577,7 @@ class DiscordChannel(BaseChannel):
|
|||||||
content=full_content,
|
content=full_content,
|
||||||
media=media_paths,
|
media=media_paths,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
|
session_key=session_key,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
await self._clear_reactions(channel_id)
|
await self._clear_reactions(channel_id)
|
||||||
@ -496,6 +593,9 @@ class DiscordChannel(BaseChannel):
|
|||||||
client = self._client
|
client = self._client
|
||||||
if client is None or not client.is_ready():
|
if client is None or not client.is_ready():
|
||||||
return None
|
return None
|
||||||
|
channel = self._known_channels.get(chat_id)
|
||||||
|
if channel is not None:
|
||||||
|
return channel
|
||||||
channel_id = int(chat_id)
|
channel_id = int(chat_id)
|
||||||
channel = client.get_channel(channel_id)
|
channel = client.get_channel(channel_id)
|
||||||
if channel is not None:
|
if channel is not None:
|
||||||
@ -544,8 +644,8 @@ class DiscordChannel(BaseChannel):
|
|||||||
# Channel-based filtering: only respond in allowed channels
|
# Channel-based filtering: only respond in allowed channels
|
||||||
allow_channels = self.config.allow_channels
|
allow_channels = self.config.allow_channels
|
||||||
if allow_channels:
|
if allow_channels:
|
||||||
channel_id = self._channel_key(message.channel)
|
channel_ids = self._channel_allow_keys(message.channel)
|
||||||
if channel_id not in allow_channels:
|
if channel_ids.isdisjoint(allow_channels):
|
||||||
return False
|
return False
|
||||||
if message.guild is not None and not self._should_respond_in_group(message, content):
|
if message.guild is not None and not self._should_respond_in_group(message, content):
|
||||||
return False
|
return False
|
||||||
@ -585,6 +685,12 @@ class DiscordChannel(BaseChannel):
|
|||||||
content_parts.extend(attachment_markers)
|
content_parts.extend(attachment_markers)
|
||||||
return "\n".join(part for part in content_parts if part) or "[empty message]"
|
return "\n".join(part for part in content_parts if part) or "[empty message]"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_system_message(message: discord.Message) -> bool:
|
||||||
|
"""Return True for Discord system messages that carry no user prompt."""
|
||||||
|
message_type = getattr(message, "type", discord.MessageType.default)
|
||||||
|
return message_type not in {discord.MessageType.default, discord.MessageType.reply}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_inbound_metadata(message: discord.Message) -> dict[str, str | None]:
|
def _build_inbound_metadata(message: discord.Message) -> dict[str, str | None]:
|
||||||
"""Build metadata for inbound Discord messages."""
|
"""Build metadata for inbound Discord messages."""
|
||||||
@ -606,6 +712,8 @@ class DiscordChannel(BaseChannel):
|
|||||||
|
|
||||||
if self.config.group_policy == "mention":
|
if self.config.group_policy == "mention":
|
||||||
bot_user_id = self._bot_user_id
|
bot_user_id = self._bot_user_id
|
||||||
|
if bot_user_id is None and self._client and self._client.user:
|
||||||
|
bot_user_id = str(self._client.user.id)
|
||||||
if bot_user_id is None:
|
if bot_user_id is None:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Discord message in {} ignored (bot identity unavailable)", message.channel.id
|
"Discord message in {} ignored (bot identity unavailable)", message.channel.id
|
||||||
@ -614,14 +722,30 @@ class DiscordChannel(BaseChannel):
|
|||||||
|
|
||||||
if any(str(user.id) == bot_user_id for user in message.mentions):
|
if any(str(user.id) == bot_user_id for user in message.mentions):
|
||||||
return True
|
return True
|
||||||
|
if bot_user_id in {str(user_id) for user_id in getattr(message, "raw_mentions", [])}:
|
||||||
|
return True
|
||||||
if f"<@{bot_user_id}>" in content or f"<@!{bot_user_id}>" in content:
|
if f"<@{bot_user_id}>" in content or f"<@!{bot_user_id}>" in content:
|
||||||
return True
|
return True
|
||||||
|
if self._references_bot_message(message, bot_user_id):
|
||||||
|
return True
|
||||||
|
|
||||||
logger.debug("Discord message in {} ignored (bot not mentioned)", message.channel.id)
|
logger.debug("Discord message in {} ignored (bot not mentioned)", message.channel.id)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _references_bot_message(message: discord.Message, bot_user_id: str) -> bool:
|
||||||
|
"""Return True when a Discord reply targets a message authored by this bot."""
|
||||||
|
reference = getattr(message, "reference", None)
|
||||||
|
if reference is None:
|
||||||
|
return False
|
||||||
|
referenced_message = getattr(reference, "resolved", None) or getattr(
|
||||||
|
reference, "cached_message", None
|
||||||
|
)
|
||||||
|
author = getattr(referenced_message, "author", None)
|
||||||
|
return str(getattr(author, "id", "")) == bot_user_id
|
||||||
|
|
||||||
async def _start_typing(self, channel: Messageable) -> None:
|
async def _start_typing(self, channel: Messageable) -> None:
|
||||||
"""Start periodic typing indicator for a channel."""
|
"""Start periodic typing indicator for a channel."""
|
||||||
channel_id = self._channel_key(channel)
|
channel_id = self._channel_key(channel)
|
||||||
@ -678,6 +802,7 @@ class DiscordChannel(BaseChannel):
|
|||||||
"""Reset client and typing state."""
|
"""Reset client and typing state."""
|
||||||
await self._cancel_all_typing()
|
await self._cancel_all_typing()
|
||||||
self._stream_bufs.clear()
|
self._stream_bufs.clear()
|
||||||
|
self._known_channels.clear()
|
||||||
if close_client and self._client is not None and not self._client.is_closed():
|
if close_client and self._client is not None and not self._client.is_closed():
|
||||||
try:
|
try:
|
||||||
await self._client.close()
|
await self._client.close()
|
||||||
|
|||||||
@ -348,6 +348,61 @@ async def test_process_message_does_not_duplicate_early_persisted_user_message(t
|
|||||||
assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata
|
assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_process_message_uses_context_chat_id_for_runtime_prompt(tmp_path: Path) -> None:
|
||||||
|
loop = _make_full_loop(tmp_path)
|
||||||
|
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
|
||||||
|
loop.context.build_messages = MagicMock( # type: ignore[method-assign]
|
||||||
|
return_value=[
|
||||||
|
{"role": "system", "content": "system"},
|
||||||
|
{"role": "user", "content": "runtime + hello"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
loop._run_agent_loop = AsyncMock(return_value=( # type: ignore[method-assign]
|
||||||
|
"done",
|
||||||
|
[],
|
||||||
|
[
|
||||||
|
{"role": "system", "content": "system"},
|
||||||
|
{"role": "user", "content": "runtime + hello"},
|
||||||
|
{"role": "assistant", "content": "done"},
|
||||||
|
],
|
||||||
|
"stop",
|
||||||
|
False,
|
||||||
|
))
|
||||||
|
|
||||||
|
result = await loop._process_message(
|
||||||
|
InboundMessage(
|
||||||
|
channel="discord",
|
||||||
|
sender_id="u1",
|
||||||
|
chat_id="thread-777",
|
||||||
|
content="hello",
|
||||||
|
metadata={"context_chat_id": "parent-456"},
|
||||||
|
session_key_override="discord:parent-456:thread:thread-777",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.chat_id == "thread-777"
|
||||||
|
assert loop.context.build_messages.call_args.kwargs["chat_id"] == "parent-456"
|
||||||
|
assert loop._run_agent_loop.call_args.kwargs["chat_id"] == "thread-777"
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_tool_context_uses_effective_key_for_spawn_tool(tmp_path: Path) -> None:
|
||||||
|
loop = _make_full_loop(tmp_path)
|
||||||
|
spawn_tool = loop.tools.get("spawn")
|
||||||
|
assert spawn_tool is not None
|
||||||
|
|
||||||
|
loop._set_tool_context(
|
||||||
|
"discord",
|
||||||
|
"thread-777",
|
||||||
|
session_key="discord:parent-456:thread:thread-777",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert spawn_tool._origin_channel.get() == "discord" # type: ignore[attr-defined]
|
||||||
|
assert spawn_tool._origin_chat_id.get() == "thread-777" # type: ignore[attr-defined]
|
||||||
|
assert spawn_tool._session_key.get() == "discord:parent-456:thread:thread-777" # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_next_turn_after_crash_closes_pending_user_turn_before_new_input(tmp_path: Path) -> None:
|
async def test_next_turn_after_crash_closes_pending_user_turn_before_new_input(tmp_path: Path) -> None:
|
||||||
loop = _make_full_loop(tmp_path)
|
loop = _make_full_loop(tmp_path)
|
||||||
|
|||||||
@ -96,8 +96,15 @@ class _FakeSentMessage:
|
|||||||
|
|
||||||
class _FakeChannel:
|
class _FakeChannel:
|
||||||
# Channel double that records outbound payloads and typing activity.
|
# Channel double that records outbound payloads and typing activity.
|
||||||
def __init__(self, channel_id: int = 123) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
channel_id: int = 123,
|
||||||
|
parent_id: int | None = None,
|
||||||
|
parent: object | None = None,
|
||||||
|
) -> None:
|
||||||
self.id = channel_id
|
self.id = channel_id
|
||||||
|
self.parent_id = parent_id
|
||||||
|
self.parent = parent
|
||||||
self.sent_payloads: list[dict] = []
|
self.sent_payloads: list[dict] = []
|
||||||
self.sent_messages: list[_FakeSentMessage] = []
|
self.sent_messages: list[_FakeSentMessage] = []
|
||||||
self.trigger_typing_calls = 0
|
self.trigger_typing_calls = 0
|
||||||
@ -148,12 +155,14 @@ def _make_interaction(
|
|||||||
*,
|
*,
|
||||||
user_id: int = 123,
|
user_id: int = 123,
|
||||||
channel_id: int | None = 456,
|
channel_id: int | None = 456,
|
||||||
|
channel=None,
|
||||||
guild_id: int | None = None,
|
guild_id: int | None = None,
|
||||||
interaction_id: int = 999,
|
interaction_id: int = 999,
|
||||||
):
|
):
|
||||||
return SimpleNamespace(
|
return SimpleNamespace(
|
||||||
user=SimpleNamespace(id=user_id),
|
user=SimpleNamespace(id=user_id),
|
||||||
channel_id=channel_id,
|
channel_id=channel_id,
|
||||||
|
channel=channel,
|
||||||
guild_id=guild_id,
|
guild_id=guild_id,
|
||||||
id=interaction_id,
|
id=interaction_id,
|
||||||
command=SimpleNamespace(qualified_name="new"),
|
command=SimpleNamespace(qualified_name="new"),
|
||||||
@ -166,25 +175,39 @@ def _make_message(
|
|||||||
author_id: int = 123,
|
author_id: int = 123,
|
||||||
author_bot: bool = False,
|
author_bot: bool = False,
|
||||||
channel_id: int = 456,
|
channel_id: int = 456,
|
||||||
|
parent_channel_id: int | None = None,
|
||||||
message_id: int = 789,
|
message_id: int = 789,
|
||||||
content: str = "hello",
|
content: str = "hello",
|
||||||
guild_id: int | None = None,
|
guild_id: int | None = None,
|
||||||
mentions: list[object] | None = None,
|
mentions: list[object] | None = None,
|
||||||
attachments: list[object] | None = None,
|
attachments: list[object] | None = None,
|
||||||
reply_to: int | None = None,
|
reply_to: int | None = None,
|
||||||
|
reply_author_id: int | None = None,
|
||||||
|
message_type=None,
|
||||||
):
|
):
|
||||||
# Factory for incoming Discord message objects with optional guild/reply/attachments.
|
# Factory for incoming Discord message objects with optional guild/reply/attachments.
|
||||||
guild = SimpleNamespace(id=guild_id) if guild_id is not None else None
|
guild = SimpleNamespace(id=guild_id) if guild_id is not None else None
|
||||||
reference = SimpleNamespace(message_id=reply_to) if reply_to is not None else None
|
referenced_message = (
|
||||||
|
SimpleNamespace(author=SimpleNamespace(id=reply_author_id))
|
||||||
|
if reply_author_id is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
reference = (
|
||||||
|
SimpleNamespace(message_id=reply_to, resolved=referenced_message)
|
||||||
|
if reply_to is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
return SimpleNamespace(
|
return SimpleNamespace(
|
||||||
author=SimpleNamespace(id=author_id, bot=author_bot),
|
author=SimpleNamespace(id=author_id, bot=author_bot),
|
||||||
channel=_FakeChannel(channel_id),
|
channel=_FakeChannel(channel_id, parent_channel_id),
|
||||||
content=content,
|
content=content,
|
||||||
guild=guild,
|
guild=guild,
|
||||||
mentions=mentions or [],
|
mentions=mentions or [],
|
||||||
|
raw_mentions=[],
|
||||||
attachments=attachments or [],
|
attachments=attachments or [],
|
||||||
reference=reference,
|
reference=reference,
|
||||||
id=message_id,
|
id=message_id,
|
||||||
|
type=message_type or discord.MessageType.default,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -357,6 +380,147 @@ async def test_on_message_accepts_when_channel_in_allow_channels() -> None:
|
|||||||
assert handled[0]["chat_id"] == "456"
|
assert handled[0]["chat_id"] == "456"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_message_accepts_thread_when_parent_channel_in_allow_channels() -> None:
|
||||||
|
# Discord threads have independent channel IDs, but inherit allowlist access
|
||||||
|
# from their parent channel.
|
||||||
|
channel = DiscordChannel(
|
||||||
|
DiscordConfig(
|
||||||
|
enabled=True,
|
||||||
|
allow_from=["*"],
|
||||||
|
allow_channels=["456"],
|
||||||
|
group_policy="mention",
|
||||||
|
),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
channel._bot_user_id = "999"
|
||||||
|
handled: list[dict] = []
|
||||||
|
|
||||||
|
async def capture_handle(**kwargs) -> None:
|
||||||
|
handled.append(kwargs)
|
||||||
|
|
||||||
|
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||||
|
|
||||||
|
await channel._on_message(
|
||||||
|
_make_message(
|
||||||
|
channel_id=777,
|
||||||
|
parent_channel_id=456,
|
||||||
|
guild_id=1,
|
||||||
|
mentions=[SimpleNamespace(id=999)],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(handled) == 1
|
||||||
|
assert handled[0]["chat_id"] == "777"
|
||||||
|
assert handled[0]["metadata"]["context_chat_id"] == "456"
|
||||||
|
assert handled[0]["metadata"]["thread_id"] == "777"
|
||||||
|
assert handled[0]["session_key"] == "discord:456:thread:777"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_message_accepts_thread_reply_to_bot_under_allowed_parent() -> None:
|
||||||
|
channel = DiscordChannel(
|
||||||
|
DiscordConfig(
|
||||||
|
enabled=True,
|
||||||
|
allow_from=["*"],
|
||||||
|
allow_channels=["456"],
|
||||||
|
group_policy="mention",
|
||||||
|
),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
channel._bot_user_id = "999"
|
||||||
|
handled: list[dict] = []
|
||||||
|
|
||||||
|
async def capture_handle(**kwargs) -> None:
|
||||||
|
handled.append(kwargs)
|
||||||
|
|
||||||
|
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||||
|
|
||||||
|
await channel._on_message(
|
||||||
|
_make_message(
|
||||||
|
channel_id=777,
|
||||||
|
parent_channel_id=456,
|
||||||
|
guild_id=1,
|
||||||
|
content="follow up",
|
||||||
|
reply_to=111,
|
||||||
|
reply_author_id=999,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(handled) == 1
|
||||||
|
assert handled[0]["chat_id"] == "777"
|
||||||
|
assert handled[0]["metadata"]["reply_to"] == "111"
|
||||||
|
assert handled[0]["metadata"]["context_chat_id"] == "456"
|
||||||
|
assert handled[0]["session_key"] == "discord:456:thread:777"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_message_ignores_thread_lifecycle_messages() -> None:
|
||||||
|
channel = DiscordChannel(
|
||||||
|
DiscordConfig(
|
||||||
|
enabled=True,
|
||||||
|
allow_from=["*"],
|
||||||
|
allow_channels=["456"],
|
||||||
|
group_policy="open",
|
||||||
|
),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
handled: list[dict] = []
|
||||||
|
|
||||||
|
async def capture_handle(**kwargs) -> None:
|
||||||
|
handled.append(kwargs)
|
||||||
|
|
||||||
|
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||||
|
|
||||||
|
await channel._on_message(
|
||||||
|
_make_message(
|
||||||
|
channel_id=777,
|
||||||
|
parent_channel_id=456,
|
||||||
|
guild_id=1,
|
||||||
|
content="",
|
||||||
|
message_type=discord.MessageType.thread_created,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await channel._on_message(
|
||||||
|
_make_message(
|
||||||
|
channel_id=777,
|
||||||
|
parent_channel_id=456,
|
||||||
|
guild_id=1,
|
||||||
|
content="",
|
||||||
|
message_type=discord.MessageType.thread_starter_message,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await channel._on_message(
|
||||||
|
_make_message(
|
||||||
|
channel_id=777,
|
||||||
|
parent_channel_id=456,
|
||||||
|
guild_id=1,
|
||||||
|
content="",
|
||||||
|
message_type=discord.MessageType.pins_add,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert handled == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_message_drops_thread_when_neither_thread_nor_parent_allowed() -> None:
|
||||||
|
channel = DiscordChannel(
|
||||||
|
DiscordConfig(enabled=True, allow_from=["*"], allow_channels=["999"]),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
handled: list[dict] = []
|
||||||
|
|
||||||
|
async def capture_handle(**kwargs) -> None:
|
||||||
|
handled.append(kwargs)
|
||||||
|
|
||||||
|
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||||
|
|
||||||
|
await channel._on_message(_make_message(channel_id=777, parent_channel_id=456))
|
||||||
|
|
||||||
|
assert handled == []
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_on_message_drops_when_channel_not_in_allow_channels() -> None:
|
async def test_on_message_drops_when_channel_not_in_allow_channels() -> None:
|
||||||
# When allow_channels is set and incoming channel is not listed, drop silently.
|
# When allow_channels is set and incoming channel is not listed, drop silently.
|
||||||
@ -517,6 +681,24 @@ async def test_send_fetches_channel_when_not_cached() -> None:
|
|||||||
assert target.sent_payloads == [{"content": "hello"}]
|
assert target.sent_payloads == [{"content": "hello"}]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_uses_seen_thread_channel_when_client_cannot_resolve_it() -> None:
|
||||||
|
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||||
|
client = DiscordBotClient(owner, intents=discord.Intents.none())
|
||||||
|
target = _FakeChannel(channel_id=777, parent_id=456)
|
||||||
|
owner._known_channels["777"] = target
|
||||||
|
client.get_channel = lambda channel_id: None # type: ignore[method-assign]
|
||||||
|
|
||||||
|
async def fetch_channel(channel_id: int):
|
||||||
|
raise RuntimeError("not found")
|
||||||
|
|
||||||
|
client.fetch_channel = fetch_channel # type: ignore[method-assign]
|
||||||
|
|
||||||
|
await client.send_outbound(OutboundMessage(channel="discord", chat_id="777", content="hello"))
|
||||||
|
|
||||||
|
assert target.sent_payloads == [{"content": "hello"}]
|
||||||
|
|
||||||
|
|
||||||
def test_supports_streaming_enabled_by_default() -> None:
|
def test_supports_streaming_enabled_by_default() -> None:
|
||||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||||
|
|
||||||
@ -596,6 +778,71 @@ async def test_slash_new_forwards_when_user_is_allowlisted() -> None:
|
|||||||
assert handled[0]["metadata"]["is_slash_command"] is True
|
assert handled[0]["metadata"]["is_slash_command"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_slash_new_accepts_thread_when_parent_channel_in_allow_channels() -> None:
|
||||||
|
channel = DiscordChannel(
|
||||||
|
DiscordConfig(enabled=True, allow_from=["*"], allow_channels=["456"]),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
handled: list[dict] = []
|
||||||
|
|
||||||
|
async def capture_handle(**kwargs) -> None:
|
||||||
|
handled.append(kwargs)
|
||||||
|
|
||||||
|
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||||
|
client = DiscordBotClient(channel, intents=discord.Intents.none())
|
||||||
|
thread = _FakeChannel(channel_id=777, parent_id=456)
|
||||||
|
interaction = _make_interaction(
|
||||||
|
user_id=123,
|
||||||
|
channel_id=777,
|
||||||
|
channel=thread,
|
||||||
|
guild_id=1,
|
||||||
|
interaction_id=321,
|
||||||
|
)
|
||||||
|
|
||||||
|
new_cmd = client.tree.get_command("new")
|
||||||
|
assert new_cmd is not None
|
||||||
|
await new_cmd.callback(interaction)
|
||||||
|
|
||||||
|
assert interaction.response.messages == [{"content": "Processing /new...", "ephemeral": True}]
|
||||||
|
assert len(handled) == 1
|
||||||
|
assert handled[0]["chat_id"] == "777"
|
||||||
|
assert handled[0]["metadata"]["context_chat_id"] == "456"
|
||||||
|
assert handled[0]["metadata"]["thread_id"] == "777"
|
||||||
|
assert handled[0]["session_key"] == "discord:456:thread:777"
|
||||||
|
assert channel._known_channels["777"] is thread
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_slash_new_blocks_channel_not_in_allow_channels() -> None:
|
||||||
|
channel = DiscordChannel(
|
||||||
|
DiscordConfig(enabled=True, allow_from=["*"], allow_channels=["999"]),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
handled: list[dict] = []
|
||||||
|
|
||||||
|
async def capture_handle(**kwargs) -> None:
|
||||||
|
handled.append(kwargs)
|
||||||
|
|
||||||
|
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||||
|
client = DiscordBotClient(channel, intents=discord.Intents.none())
|
||||||
|
interaction = _make_interaction(
|
||||||
|
user_id=123,
|
||||||
|
channel_id=777,
|
||||||
|
channel=_FakeChannel(channel_id=777, parent_id=456),
|
||||||
|
guild_id=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
new_cmd = client.tree.get_command("new")
|
||||||
|
assert new_cmd is not None
|
||||||
|
await new_cmd.callback(interaction)
|
||||||
|
|
||||||
|
assert interaction.response.messages == [
|
||||||
|
{"content": "This channel is not allowed for this bot.", "ephemeral": True}
|
||||||
|
]
|
||||||
|
assert handled == []
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_slash_new_is_blocked_for_disallowed_user() -> None:
|
async def test_slash_new_is_blocked_for_disallowed_user() -> None:
|
||||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["999"]), MessageBus())
|
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["999"]), MessageBus())
|
||||||
@ -665,6 +912,45 @@ async def test_slash_help_returns_ephemeral_help_text() -> None:
|
|||||||
assert handled == []
|
assert handled == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_slash_help_respects_allow_channels() -> None:
|
||||||
|
channel = DiscordChannel(
|
||||||
|
DiscordConfig(enabled=True, allow_from=["*"], allow_channels=["999"]),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
client = DiscordBotClient(channel, intents=discord.Intents.none())
|
||||||
|
interaction = _make_interaction(
|
||||||
|
channel_id=777,
|
||||||
|
channel=_FakeChannel(channel_id=777, parent_id=456),
|
||||||
|
guild_id=1,
|
||||||
|
)
|
||||||
|
interaction.command.qualified_name = "help"
|
||||||
|
|
||||||
|
help_cmd = client.tree.get_command("help")
|
||||||
|
assert help_cmd is not None
|
||||||
|
await help_cmd.callback(interaction)
|
||||||
|
|
||||||
|
assert interaction.response.messages == [
|
||||||
|
{"content": "This channel is not allowed for this bot.", "ephemeral": True}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_thread_delete_and_archive_remove_known_channel() -> None:
|
||||||
|
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||||
|
client = DiscordBotClient(channel, intents=discord.Intents.none())
|
||||||
|
thread = _FakeChannel(channel_id=777, parent_id=456)
|
||||||
|
|
||||||
|
channel._remember_channel(thread)
|
||||||
|
await client.on_thread_delete(thread)
|
||||||
|
assert "777" not in channel._known_channels
|
||||||
|
|
||||||
|
channel._remember_channel(thread)
|
||||||
|
archived_thread = SimpleNamespace(id=777, parent_id=456, archived=True)
|
||||||
|
await client.on_thread_update(thread, archived_thread)
|
||||||
|
assert "777" not in channel._known_channels
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_client_send_outbound_chunks_text_replies_and_uploads_files(tmp_path) -> None:
|
async def test_client_send_outbound_chunks_text_replies_and_uploads_files(tmp_path) -> None:
|
||||||
# Outbound payloads should upload files, attach reply references, and chunk long text.
|
# Outbound payloads should upload files, attach reply references, and chunk long text.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user