mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-25 19:12:43 +00:00
Merge PR #3397: fix(discord): full thread support with session isolation and allowlist enforcement
fix(discord): full thread support with session isolation and allowlist enforcement
This commit is contained in:
commit
b8932bc041
@ -147,7 +147,7 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso
|
||||
> - `"open"` — Respond to all messages
|
||||
> DMs always respond when the sender is in `allowFrom`.
|
||||
> - If you set group policy to open create new threads as private threads and then @ the bot into it. Otherwise the thread itself and the channel in which you spawned it will spawn a bot session.
|
||||
> `allowChannels` restricts the bot to specific Discord channel IDs. Empty (default) means respond in every channel the bot can see. Example: `["1234567890", "0987654321"]`. The filter applies after `allowFrom`, so both must pass.
|
||||
> `allowChannels` restricts the bot to specific Discord channel IDs. Empty (default) means respond in every channel the bot can see. Example: `["1234567890", "0987654321"]`. The filter applies after `allowFrom`, so both must pass. Discord threads under an allowed parent channel are also allowed; for Forum channels, allowing the parent Forum channel allows all threads/posts in that forum.
|
||||
> `streaming` defaults to `true`. Disable it only if you explicitly want non-streaming replies.
|
||||
|
||||
**5. Invite the bot**
|
||||
|
||||
@ -434,6 +434,11 @@ class AgentLoop:
|
||||
|
||||
return strip_think(text) or None
|
||||
|
||||
@staticmethod
|
||||
def _runtime_chat_id(msg: InboundMessage) -> str:
|
||||
"""Return the chat id shown in runtime metadata for the model."""
|
||||
return str(msg.metadata.get("context_chat_id") or msg.chat_id)
|
||||
|
||||
@staticmethod
|
||||
def _tool_hint(tool_calls: list) -> str:
|
||||
"""Format tool calls as concise hints with smart abbreviation."""
|
||||
@ -555,7 +560,7 @@ class AgentLoop:
|
||||
user_content = self.context._build_user_content(content, media)
|
||||
runtime_ctx = self.context._build_runtime_context(
|
||||
pending_msg.channel,
|
||||
pending_msg.chat_id,
|
||||
self._runtime_chat_id(pending_msg),
|
||||
self.context.timezone,
|
||||
)
|
||||
if isinstance(user_content, str):
|
||||
@ -986,7 +991,7 @@ class AgentLoop:
|
||||
session_summary=pending,
|
||||
media=msg.media if msg.media else None,
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
chat_id=self._runtime_chat_id(msg),
|
||||
)
|
||||
|
||||
async def _bus_progress(
|
||||
|
||||
@ -95,6 +95,15 @@ if DISCORD_AVAILABLE:
|
||||
async def on_message(self, message: discord.Message) -> None:
|
||||
await self._channel._handle_discord_message(message)
|
||||
|
||||
async def on_thread_delete(self, thread: discord.Thread) -> None:
|
||||
self._channel._forget_channel(thread)
|
||||
|
||||
async def on_thread_update(self, before: discord.Thread, after: discord.Thread) -> None:
|
||||
if getattr(after, "archived", False):
|
||||
self._channel._forget_channel(after)
|
||||
else:
|
||||
self._channel._remember_channel(after)
|
||||
|
||||
async def _reply_ephemeral(self, interaction: discord.Interaction, text: str) -> bool:
|
||||
"""Send an ephemeral interaction response and report success."""
|
||||
try:
|
||||
@ -104,6 +113,37 @@ if DISCORD_AVAILABLE:
|
||||
logger.warning("Discord interaction response failed: {}", e)
|
||||
return False
|
||||
|
||||
async def _resolve_interaction_channel(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
) -> Any | None:
|
||||
channel_id = interaction.channel_id
|
||||
if channel_id is None:
|
||||
return None
|
||||
channel = getattr(interaction, "channel", None) or self.get_channel(channel_id)
|
||||
if channel is None:
|
||||
try:
|
||||
channel = await self.fetch_channel(channel_id)
|
||||
except Exception as e:
|
||||
logger.warning("Discord interaction channel {} unavailable: {}", channel_id, e)
|
||||
return None
|
||||
self._channel._remember_channel(channel)
|
||||
return channel
|
||||
|
||||
async def _interaction_channel_allowed(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
channel: Any | None,
|
||||
) -> bool:
|
||||
allow_channels = self._channel.config.allow_channels
|
||||
if not allow_channels:
|
||||
return True
|
||||
if channel is None:
|
||||
channel_id = interaction.channel_id
|
||||
return channel_id is not None and str(channel_id) in allow_channels
|
||||
channel_ids = self._channel._channel_allow_keys(channel)
|
||||
return not channel_ids.isdisjoint(allow_channels)
|
||||
|
||||
async def _forward_slash_command(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
@ -120,17 +160,33 @@ if DISCORD_AVAILABLE:
|
||||
await self._reply_ephemeral(interaction, "You are not allowed to use this bot.")
|
||||
return
|
||||
|
||||
channel = await self._resolve_interaction_channel(interaction)
|
||||
if not await self._interaction_channel_allowed(interaction, channel):
|
||||
await self._reply_ephemeral(interaction, "This channel is not allowed for this bot.")
|
||||
return
|
||||
|
||||
await self._reply_ephemeral(interaction, f"Processing {command_text}...")
|
||||
|
||||
metadata: dict[str, Any] = {
|
||||
"interaction_id": str(interaction.id),
|
||||
"guild_id": str(interaction.guild_id) if interaction.guild_id else None,
|
||||
"is_slash_command": True,
|
||||
}
|
||||
session_key = None
|
||||
if channel is not None:
|
||||
parent_channel_id = self._channel._channel_parent_key(channel)
|
||||
if parent_channel_id is not None:
|
||||
metadata["parent_channel_id"] = parent_channel_id
|
||||
metadata["context_chat_id"] = parent_channel_id
|
||||
metadata["thread_id"] = str(channel_id)
|
||||
session_key = f"{self._channel.name}:{parent_channel_id}:thread:{channel_id}"
|
||||
|
||||
await self._channel._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=str(channel_id),
|
||||
content=command_text,
|
||||
metadata={
|
||||
"interaction_id": str(interaction.id),
|
||||
"guild_id": str(interaction.guild_id) if interaction.guild_id else None,
|
||||
"is_slash_command": True,
|
||||
},
|
||||
metadata=metadata,
|
||||
session_key=session_key,
|
||||
)
|
||||
|
||||
def _register_app_commands(self) -> None:
|
||||
@ -156,6 +212,10 @@ if DISCORD_AVAILABLE:
|
||||
if not self._channel.is_allowed(sender_id):
|
||||
await self._reply_ephemeral(interaction, "You are not allowed to use this bot.")
|
||||
return
|
||||
channel = await self._resolve_interaction_channel(interaction)
|
||||
if not await self._interaction_channel_allowed(interaction, channel):
|
||||
await self._reply_ephemeral(interaction, "This channel is not allowed for this bot.")
|
||||
return
|
||||
await self._reply_ephemeral(interaction, build_help_text())
|
||||
|
||||
@self.tree.error
|
||||
@ -176,7 +236,7 @@ if DISCORD_AVAILABLE:
|
||||
"""Send a nanobot outbound message using Discord transport rules."""
|
||||
channel_id = int(msg.chat_id)
|
||||
|
||||
channel = self.get_channel(channel_id)
|
||||
channel = self._channel._known_channels.get(msg.chat_id) or self.get_channel(channel_id)
|
||||
if channel is None:
|
||||
try:
|
||||
channel = await self.fetch_channel(channel_id)
|
||||
@ -282,6 +342,25 @@ class DiscordChannel(BaseChannel):
|
||||
channel_id = getattr(channel_or_id, "id", channel_or_id)
|
||||
return str(channel_id)
|
||||
|
||||
@classmethod
|
||||
def _channel_allow_keys(cls, channel: Any) -> set[str]:
|
||||
"""Return channel IDs that can satisfy allow_channels for this channel."""
|
||||
keys = {cls._channel_key(channel)}
|
||||
if parent_key := cls._channel_parent_key(channel):
|
||||
keys.add(parent_key)
|
||||
return keys
|
||||
|
||||
@classmethod
|
||||
def _channel_parent_key(cls, channel: Any) -> str | None:
|
||||
"""Return the parent channel key for a Discord thread-like channel."""
|
||||
parent_id = getattr(channel, "parent_id", None)
|
||||
if parent_id is not None:
|
||||
return cls._channel_key(parent_id)
|
||||
parent = getattr(channel, "parent", None)
|
||||
if parent is not None:
|
||||
return cls._channel_key(parent)
|
||||
return None
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus):
|
||||
if isinstance(config, dict):
|
||||
config = DiscordConfig.model_validate(config)
|
||||
@ -293,6 +372,13 @@ class DiscordChannel(BaseChannel):
|
||||
self._pending_reactions: dict[str, Any] = {} # chat_id -> message object
|
||||
self._working_emoji_tasks: dict[str, asyncio.Task[None]] = {}
|
||||
self._stream_bufs: dict[str, _StreamBuf] = {}
|
||||
self._known_channels: dict[str, Any] = {}
|
||||
|
||||
def _remember_channel(self, channel: Any) -> None:
|
||||
self._known_channels[self._channel_key(channel)] = channel
|
||||
|
||||
def _forget_channel(self, channel_or_id: Any) -> None:
|
||||
self._known_channels.pop(self._channel_key(channel_or_id), None)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Discord client."""
|
||||
@ -443,9 +529,12 @@ class DiscordChannel(BaseChannel):
|
||||
"""
|
||||
if self._bot_user_id is not None and str(message.author.id) == self._bot_user_id:
|
||||
return
|
||||
if self._is_system_message(message):
|
||||
return
|
||||
|
||||
sender_id = str(message.author.id)
|
||||
channel_id = self._channel_key(message.channel)
|
||||
self._remember_channel(message.channel)
|
||||
content = message.content or ""
|
||||
|
||||
if not self._should_accept_inbound(message, sender_id, content):
|
||||
@ -454,6 +543,13 @@ class DiscordChannel(BaseChannel):
|
||||
media_paths, attachment_markers = await self._download_attachments(message.attachments)
|
||||
full_content = self._compose_inbound_content(content, attachment_markers)
|
||||
metadata = self._build_inbound_metadata(message)
|
||||
parent_channel_id = self._channel_parent_key(message.channel)
|
||||
session_key = None
|
||||
if parent_channel_id is not None:
|
||||
metadata["parent_channel_id"] = parent_channel_id
|
||||
metadata["context_chat_id"] = parent_channel_id
|
||||
metadata["thread_id"] = channel_id
|
||||
session_key = f"{self.name}:{parent_channel_id}:thread:{channel_id}"
|
||||
|
||||
await self._start_typing(message.channel)
|
||||
|
||||
@ -481,6 +577,7 @@ class DiscordChannel(BaseChannel):
|
||||
content=full_content,
|
||||
media=media_paths,
|
||||
metadata=metadata,
|
||||
session_key=session_key,
|
||||
)
|
||||
except Exception:
|
||||
await self._clear_reactions(channel_id)
|
||||
@ -496,6 +593,9 @@ class DiscordChannel(BaseChannel):
|
||||
client = self._client
|
||||
if client is None or not client.is_ready():
|
||||
return None
|
||||
channel = self._known_channels.get(chat_id)
|
||||
if channel is not None:
|
||||
return channel
|
||||
channel_id = int(chat_id)
|
||||
channel = client.get_channel(channel_id)
|
||||
if channel is not None:
|
||||
@ -544,8 +644,8 @@ class DiscordChannel(BaseChannel):
|
||||
# Channel-based filtering: only respond in allowed channels
|
||||
allow_channels = self.config.allow_channels
|
||||
if allow_channels:
|
||||
channel_id = self._channel_key(message.channel)
|
||||
if channel_id not in allow_channels:
|
||||
channel_ids = self._channel_allow_keys(message.channel)
|
||||
if channel_ids.isdisjoint(allow_channels):
|
||||
return False
|
||||
if message.guild is not None and not self._should_respond_in_group(message, content):
|
||||
return False
|
||||
@ -585,6 +685,12 @@ class DiscordChannel(BaseChannel):
|
||||
content_parts.extend(attachment_markers)
|
||||
return "\n".join(part for part in content_parts if part) or "[empty message]"
|
||||
|
||||
@staticmethod
|
||||
def _is_system_message(message: discord.Message) -> bool:
|
||||
"""Return True for Discord system messages that carry no user prompt."""
|
||||
message_type = getattr(message, "type", discord.MessageType.default)
|
||||
return message_type not in {discord.MessageType.default, discord.MessageType.reply}
|
||||
|
||||
@staticmethod
|
||||
def _build_inbound_metadata(message: discord.Message) -> dict[str, str | None]:
|
||||
"""Build metadata for inbound Discord messages."""
|
||||
@ -606,6 +712,8 @@ class DiscordChannel(BaseChannel):
|
||||
|
||||
if self.config.group_policy == "mention":
|
||||
bot_user_id = self._bot_user_id
|
||||
if bot_user_id is None and self._client and self._client.user:
|
||||
bot_user_id = str(self._client.user.id)
|
||||
if bot_user_id is None:
|
||||
logger.debug(
|
||||
"Discord message in {} ignored (bot identity unavailable)", message.channel.id
|
||||
@ -614,14 +722,30 @@ class DiscordChannel(BaseChannel):
|
||||
|
||||
if any(str(user.id) == bot_user_id for user in message.mentions):
|
||||
return True
|
||||
if bot_user_id in {str(user_id) for user_id in getattr(message, "raw_mentions", [])}:
|
||||
return True
|
||||
if f"<@{bot_user_id}>" in content or f"<@!{bot_user_id}>" in content:
|
||||
return True
|
||||
if self._references_bot_message(message, bot_user_id):
|
||||
return True
|
||||
|
||||
logger.debug("Discord message in {} ignored (bot not mentioned)", message.channel.id)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _references_bot_message(message: discord.Message, bot_user_id: str) -> bool:
|
||||
"""Return True when a Discord reply targets a message authored by this bot."""
|
||||
reference = getattr(message, "reference", None)
|
||||
if reference is None:
|
||||
return False
|
||||
referenced_message = getattr(reference, "resolved", None) or getattr(
|
||||
reference, "cached_message", None
|
||||
)
|
||||
author = getattr(referenced_message, "author", None)
|
||||
return str(getattr(author, "id", "")) == bot_user_id
|
||||
|
||||
async def _start_typing(self, channel: Messageable) -> None:
|
||||
"""Start periodic typing indicator for a channel."""
|
||||
channel_id = self._channel_key(channel)
|
||||
@ -678,6 +802,7 @@ class DiscordChannel(BaseChannel):
|
||||
"""Reset client and typing state."""
|
||||
await self._cancel_all_typing()
|
||||
self._stream_bufs.clear()
|
||||
self._known_channels.clear()
|
||||
if close_client and self._client is not None and not self._client.is_closed():
|
||||
try:
|
||||
await self._client.close()
|
||||
|
||||
@ -348,6 +348,61 @@ async def test_process_message_does_not_duplicate_early_persisted_user_message(t
|
||||
assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_uses_context_chat_id_for_runtime_prompt(tmp_path: Path) -> None:
|
||||
loop = _make_full_loop(tmp_path)
|
||||
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
|
||||
loop.context.build_messages = MagicMock( # type: ignore[method-assign]
|
||||
return_value=[
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "runtime + hello"},
|
||||
]
|
||||
)
|
||||
loop._run_agent_loop = AsyncMock(return_value=( # type: ignore[method-assign]
|
||||
"done",
|
||||
[],
|
||||
[
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "runtime + hello"},
|
||||
{"role": "assistant", "content": "done"},
|
||||
],
|
||||
"stop",
|
||||
False,
|
||||
))
|
||||
|
||||
result = await loop._process_message(
|
||||
InboundMessage(
|
||||
channel="discord",
|
||||
sender_id="u1",
|
||||
chat_id="thread-777",
|
||||
content="hello",
|
||||
metadata={"context_chat_id": "parent-456"},
|
||||
session_key_override="discord:parent-456:thread:thread-777",
|
||||
)
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.chat_id == "thread-777"
|
||||
assert loop.context.build_messages.call_args.kwargs["chat_id"] == "parent-456"
|
||||
assert loop._run_agent_loop.call_args.kwargs["chat_id"] == "thread-777"
|
||||
|
||||
|
||||
def test_set_tool_context_uses_effective_key_for_spawn_tool(tmp_path: Path) -> None:
|
||||
loop = _make_full_loop(tmp_path)
|
||||
spawn_tool = loop.tools.get("spawn")
|
||||
assert spawn_tool is not None
|
||||
|
||||
loop._set_tool_context(
|
||||
"discord",
|
||||
"thread-777",
|
||||
session_key="discord:parent-456:thread:thread-777",
|
||||
)
|
||||
|
||||
assert spawn_tool._origin_channel.get() == "discord" # type: ignore[attr-defined]
|
||||
assert spawn_tool._origin_chat_id.get() == "thread-777" # type: ignore[attr-defined]
|
||||
assert spawn_tool._session_key.get() == "discord:parent-456:thread:thread-777" # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_next_turn_after_crash_closes_pending_user_turn_before_new_input(tmp_path: Path) -> None:
|
||||
loop = _make_full_loop(tmp_path)
|
||||
|
||||
@ -96,8 +96,15 @@ class _FakeSentMessage:
|
||||
|
||||
class _FakeChannel:
|
||||
# Channel double that records outbound payloads and typing activity.
|
||||
def __init__(self, channel_id: int = 123) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
channel_id: int = 123,
|
||||
parent_id: int | None = None,
|
||||
parent: object | None = None,
|
||||
) -> None:
|
||||
self.id = channel_id
|
||||
self.parent_id = parent_id
|
||||
self.parent = parent
|
||||
self.sent_payloads: list[dict] = []
|
||||
self.sent_messages: list[_FakeSentMessage] = []
|
||||
self.trigger_typing_calls = 0
|
||||
@ -148,12 +155,14 @@ def _make_interaction(
|
||||
*,
|
||||
user_id: int = 123,
|
||||
channel_id: int | None = 456,
|
||||
channel=None,
|
||||
guild_id: int | None = None,
|
||||
interaction_id: int = 999,
|
||||
):
|
||||
return SimpleNamespace(
|
||||
user=SimpleNamespace(id=user_id),
|
||||
channel_id=channel_id,
|
||||
channel=channel,
|
||||
guild_id=guild_id,
|
||||
id=interaction_id,
|
||||
command=SimpleNamespace(qualified_name="new"),
|
||||
@ -166,25 +175,39 @@ def _make_message(
|
||||
author_id: int = 123,
|
||||
author_bot: bool = False,
|
||||
channel_id: int = 456,
|
||||
parent_channel_id: int | None = None,
|
||||
message_id: int = 789,
|
||||
content: str = "hello",
|
||||
guild_id: int | None = None,
|
||||
mentions: list[object] | None = None,
|
||||
attachments: list[object] | None = None,
|
||||
reply_to: int | None = None,
|
||||
reply_author_id: int | None = None,
|
||||
message_type=None,
|
||||
):
|
||||
# Factory for incoming Discord message objects with optional guild/reply/attachments.
|
||||
guild = SimpleNamespace(id=guild_id) if guild_id is not None else None
|
||||
reference = SimpleNamespace(message_id=reply_to) if reply_to is not None else None
|
||||
referenced_message = (
|
||||
SimpleNamespace(author=SimpleNamespace(id=reply_author_id))
|
||||
if reply_author_id is not None
|
||||
else None
|
||||
)
|
||||
reference = (
|
||||
SimpleNamespace(message_id=reply_to, resolved=referenced_message)
|
||||
if reply_to is not None
|
||||
else None
|
||||
)
|
||||
return SimpleNamespace(
|
||||
author=SimpleNamespace(id=author_id, bot=author_bot),
|
||||
channel=_FakeChannel(channel_id),
|
||||
channel=_FakeChannel(channel_id, parent_channel_id),
|
||||
content=content,
|
||||
guild=guild,
|
||||
mentions=mentions or [],
|
||||
raw_mentions=[],
|
||||
attachments=attachments or [],
|
||||
reference=reference,
|
||||
id=message_id,
|
||||
type=message_type or discord.MessageType.default,
|
||||
)
|
||||
|
||||
|
||||
@ -357,6 +380,147 @@ async def test_on_message_accepts_when_channel_in_allow_channels() -> None:
|
||||
assert handled[0]["chat_id"] == "456"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_accepts_thread_when_parent_channel_in_allow_channels() -> None:
|
||||
# Discord threads have independent channel IDs, but inherit allowlist access
|
||||
# from their parent channel.
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(
|
||||
enabled=True,
|
||||
allow_from=["*"],
|
||||
allow_channels=["456"],
|
||||
group_policy="mention",
|
||||
),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._bot_user_id = "999"
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
|
||||
await channel._on_message(
|
||||
_make_message(
|
||||
channel_id=777,
|
||||
parent_channel_id=456,
|
||||
guild_id=1,
|
||||
mentions=[SimpleNamespace(id=999)],
|
||||
)
|
||||
)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["chat_id"] == "777"
|
||||
assert handled[0]["metadata"]["context_chat_id"] == "456"
|
||||
assert handled[0]["metadata"]["thread_id"] == "777"
|
||||
assert handled[0]["session_key"] == "discord:456:thread:777"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_accepts_thread_reply_to_bot_under_allowed_parent() -> None:
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(
|
||||
enabled=True,
|
||||
allow_from=["*"],
|
||||
allow_channels=["456"],
|
||||
group_policy="mention",
|
||||
),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._bot_user_id = "999"
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
|
||||
await channel._on_message(
|
||||
_make_message(
|
||||
channel_id=777,
|
||||
parent_channel_id=456,
|
||||
guild_id=1,
|
||||
content="follow up",
|
||||
reply_to=111,
|
||||
reply_author_id=999,
|
||||
)
|
||||
)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["chat_id"] == "777"
|
||||
assert handled[0]["metadata"]["reply_to"] == "111"
|
||||
assert handled[0]["metadata"]["context_chat_id"] == "456"
|
||||
assert handled[0]["session_key"] == "discord:456:thread:777"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_ignores_thread_lifecycle_messages() -> None:
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(
|
||||
enabled=True,
|
||||
allow_from=["*"],
|
||||
allow_channels=["456"],
|
||||
group_policy="open",
|
||||
),
|
||||
MessageBus(),
|
||||
)
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
|
||||
await channel._on_message(
|
||||
_make_message(
|
||||
channel_id=777,
|
||||
parent_channel_id=456,
|
||||
guild_id=1,
|
||||
content="",
|
||||
message_type=discord.MessageType.thread_created,
|
||||
)
|
||||
)
|
||||
await channel._on_message(
|
||||
_make_message(
|
||||
channel_id=777,
|
||||
parent_channel_id=456,
|
||||
guild_id=1,
|
||||
content="",
|
||||
message_type=discord.MessageType.thread_starter_message,
|
||||
)
|
||||
)
|
||||
await channel._on_message(
|
||||
_make_message(
|
||||
channel_id=777,
|
||||
parent_channel_id=456,
|
||||
guild_id=1,
|
||||
content="",
|
||||
message_type=discord.MessageType.pins_add,
|
||||
)
|
||||
)
|
||||
|
||||
assert handled == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_drops_thread_when_neither_thread_nor_parent_allowed() -> None:
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, allow_from=["*"], allow_channels=["999"]),
|
||||
MessageBus(),
|
||||
)
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
|
||||
await channel._on_message(_make_message(channel_id=777, parent_channel_id=456))
|
||||
|
||||
assert handled == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_drops_when_channel_not_in_allow_channels() -> None:
|
||||
# When allow_channels is set and incoming channel is not listed, drop silently.
|
||||
@ -517,6 +681,24 @@ async def test_send_fetches_channel_when_not_cached() -> None:
|
||||
assert target.sent_payloads == [{"content": "hello"}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_uses_seen_thread_channel_when_client_cannot_resolve_it() -> None:
|
||||
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
client = DiscordBotClient(owner, intents=discord.Intents.none())
|
||||
target = _FakeChannel(channel_id=777, parent_id=456)
|
||||
owner._known_channels["777"] = target
|
||||
client.get_channel = lambda channel_id: None # type: ignore[method-assign]
|
||||
|
||||
async def fetch_channel(channel_id: int):
|
||||
raise RuntimeError("not found")
|
||||
|
||||
client.fetch_channel = fetch_channel # type: ignore[method-assign]
|
||||
|
||||
await client.send_outbound(OutboundMessage(channel="discord", chat_id="777", content="hello"))
|
||||
|
||||
assert target.sent_payloads == [{"content": "hello"}]
|
||||
|
||||
|
||||
def test_supports_streaming_enabled_by_default() -> None:
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
|
||||
@ -596,6 +778,71 @@ async def test_slash_new_forwards_when_user_is_allowlisted() -> None:
|
||||
assert handled[0]["metadata"]["is_slash_command"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slash_new_accepts_thread_when_parent_channel_in_allow_channels() -> None:
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, allow_from=["*"], allow_channels=["456"]),
|
||||
MessageBus(),
|
||||
)
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
client = DiscordBotClient(channel, intents=discord.Intents.none())
|
||||
thread = _FakeChannel(channel_id=777, parent_id=456)
|
||||
interaction = _make_interaction(
|
||||
user_id=123,
|
||||
channel_id=777,
|
||||
channel=thread,
|
||||
guild_id=1,
|
||||
interaction_id=321,
|
||||
)
|
||||
|
||||
new_cmd = client.tree.get_command("new")
|
||||
assert new_cmd is not None
|
||||
await new_cmd.callback(interaction)
|
||||
|
||||
assert interaction.response.messages == [{"content": "Processing /new...", "ephemeral": True}]
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["chat_id"] == "777"
|
||||
assert handled[0]["metadata"]["context_chat_id"] == "456"
|
||||
assert handled[0]["metadata"]["thread_id"] == "777"
|
||||
assert handled[0]["session_key"] == "discord:456:thread:777"
|
||||
assert channel._known_channels["777"] is thread
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slash_new_blocks_channel_not_in_allow_channels() -> None:
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, allow_from=["*"], allow_channels=["999"]),
|
||||
MessageBus(),
|
||||
)
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
client = DiscordBotClient(channel, intents=discord.Intents.none())
|
||||
interaction = _make_interaction(
|
||||
user_id=123,
|
||||
channel_id=777,
|
||||
channel=_FakeChannel(channel_id=777, parent_id=456),
|
||||
guild_id=1,
|
||||
)
|
||||
|
||||
new_cmd = client.tree.get_command("new")
|
||||
assert new_cmd is not None
|
||||
await new_cmd.callback(interaction)
|
||||
|
||||
assert interaction.response.messages == [
|
||||
{"content": "This channel is not allowed for this bot.", "ephemeral": True}
|
||||
]
|
||||
assert handled == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slash_new_is_blocked_for_disallowed_user() -> None:
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["999"]), MessageBus())
|
||||
@ -665,6 +912,45 @@ async def test_slash_help_returns_ephemeral_help_text() -> None:
|
||||
assert handled == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slash_help_respects_allow_channels() -> None:
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, allow_from=["*"], allow_channels=["999"]),
|
||||
MessageBus(),
|
||||
)
|
||||
client = DiscordBotClient(channel, intents=discord.Intents.none())
|
||||
interaction = _make_interaction(
|
||||
channel_id=777,
|
||||
channel=_FakeChannel(channel_id=777, parent_id=456),
|
||||
guild_id=1,
|
||||
)
|
||||
interaction.command.qualified_name = "help"
|
||||
|
||||
help_cmd = client.tree.get_command("help")
|
||||
assert help_cmd is not None
|
||||
await help_cmd.callback(interaction)
|
||||
|
||||
assert interaction.response.messages == [
|
||||
{"content": "This channel is not allowed for this bot.", "ephemeral": True}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_thread_delete_and_archive_remove_known_channel() -> None:
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
client = DiscordBotClient(channel, intents=discord.Intents.none())
|
||||
thread = _FakeChannel(channel_id=777, parent_id=456)
|
||||
|
||||
channel._remember_channel(thread)
|
||||
await client.on_thread_delete(thread)
|
||||
assert "777" not in channel._known_channels
|
||||
|
||||
channel._remember_channel(thread)
|
||||
archived_thread = SimpleNamespace(id=777, parent_id=456, archived=True)
|
||||
await client.on_thread_update(thread, archived_thread)
|
||||
assert "777" not in channel._known_channels
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_send_outbound_chunks_text_replies_and_uploads_files(tmp_path) -> None:
|
||||
# Outbound payloads should upload files, attach reply references, and chunk long text.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user