mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-01 07:15:52 +00:00
fix(commands): intercept non-priority commands during active turn
Non-priority slash commands (e.g. /new, /help, /dream-log) arriving while a session has an active LLM turn were silently queued into the pending injection buffer and later injected as raw user messages into the LLM conversation. This caused the model to respond to "/new" as plain text instead of executing the command. Root cause: the run() loop only checked priority commands (/stop, /restart, /status) before routing messages to the pending queue. All other command tiers (exact, prefix) bypassed command dispatch entirely. Changes: - Add CommandRouter.is_dispatchable_command() to match exact/prefix tiers, mirroring the existing is_priority() pattern. - In run(), intercept dispatchable commands before pending queue insertion and dispatch them directly via _dispatch_command_inline(). - Extract _cancel_active_tasks() from cmd_stop for reuse; cmd_new now cancels active tasks before clearing the session to prevent shared mutable state corruption from concurrent asyncio coroutines. - Update /new semantics: stops active task first, then clears session. - Update documentation in help text, docs, and Discord command list.
This commit is contained in:
parent
f8a023218d
commit
d4e34f8c67
@ -4,7 +4,7 @@ These commands work inside chat channels and interactive agent sessions:
|
|||||||
|
|
||||||
| Command | Description |
|
| Command | Description |
|
||||||
|---------|-------------|
|
|---------|-------------|
|
||||||
| `/new` | Start a new conversation |
|
| `/new` | Stop current task and start a new conversation |
|
||||||
| `/stop` | Stop the current task |
|
| `/stop` | Stop the current task |
|
||||||
| `/restart` | Restart the bot |
|
| `/restart` | Restart the bot |
|
||||||
| `/status` | Show bot status |
|
| `/status` | Show bot status |
|
||||||
|
|||||||
@ -345,6 +345,36 @@ class AgentLoop:
|
|||||||
|
|
||||||
return format_tool_hints(tool_calls)
|
return format_tool_hints(tool_calls)
|
||||||
|
|
||||||
|
async def _dispatch_command_inline(
|
||||||
|
self,
|
||||||
|
msg: InboundMessage,
|
||||||
|
key: str,
|
||||||
|
raw: str,
|
||||||
|
dispatch_fn: Callable[[CommandContext], Awaitable[OutboundMessage | None]],
|
||||||
|
) -> None:
|
||||||
|
"""Dispatch a command directly from the run() loop and publish the result."""
|
||||||
|
ctx = CommandContext(msg=msg, session=None, key=key, raw=raw, loop=self)
|
||||||
|
result = await dispatch_fn(ctx)
|
||||||
|
if result:
|
||||||
|
await self.bus.publish_outbound(result)
|
||||||
|
else:
|
||||||
|
logger.warning("Command '{}' matched but dispatch returned None", raw)
|
||||||
|
|
||||||
|
async def _cancel_active_tasks(self, key: str) -> int:
|
||||||
|
"""Cancel and await all active tasks and subagents for *key*.
|
||||||
|
|
||||||
|
Returns the total number of cancelled tasks + subagents.
|
||||||
|
"""
|
||||||
|
tasks = self._active_tasks.pop(key, [])
|
||||||
|
cancelled = sum(1 for t in tasks if not t.done() and t.cancel())
|
||||||
|
for t in tasks:
|
||||||
|
try:
|
||||||
|
await t
|
||||||
|
except (asyncio.CancelledError, Exception):
|
||||||
|
pass
|
||||||
|
sub_cancelled = await self.subagents.cancel_by_session(key)
|
||||||
|
return cancelled + sub_cancelled
|
||||||
|
|
||||||
def _effective_session_key(self, msg: InboundMessage) -> str:
|
def _effective_session_key(self, msg: InboundMessage) -> str:
|
||||||
"""Return the session key used for task routing and mid-turn injections."""
|
"""Return the session key used for task routing and mid-turn injections."""
|
||||||
if self._unified_session and not msg.session_key_override:
|
if self._unified_session and not msg.session_key_override:
|
||||||
@ -478,16 +508,24 @@ class AgentLoop:
|
|||||||
|
|
||||||
raw = msg.content.strip()
|
raw = msg.content.strip()
|
||||||
if self.commands.is_priority(raw):
|
if self.commands.is_priority(raw):
|
||||||
ctx = CommandContext(msg=msg, session=None, key=msg.session_key, raw=raw, loop=self)
|
await self._dispatch_command_inline(
|
||||||
result = await self.commands.dispatch_priority(ctx)
|
msg, msg.session_key, raw,
|
||||||
if result:
|
self.commands.dispatch_priority,
|
||||||
await self.bus.publish_outbound(result)
|
)
|
||||||
continue
|
continue
|
||||||
effective_key = self._effective_session_key(msg)
|
effective_key = self._effective_session_key(msg)
|
||||||
# If this session already has an active pending queue (i.e. a task
|
# If this session already has an active pending queue (i.e. a task
|
||||||
# is processing this session), route the message there for mid-turn
|
# is processing this session), route the message there for mid-turn
|
||||||
# injection instead of creating a competing task.
|
# injection instead of creating a competing task.
|
||||||
if effective_key in self._pending_queues:
|
if effective_key in self._pending_queues:
|
||||||
|
# Non-priority commands must not be queued for injection;
|
||||||
|
# dispatch them directly (same pattern as priority commands).
|
||||||
|
if self.commands.is_dispatchable_command(raw):
|
||||||
|
await self._dispatch_command_inline(
|
||||||
|
msg, effective_key, raw,
|
||||||
|
self.commands.dispatch,
|
||||||
|
)
|
||||||
|
continue
|
||||||
pending_msg = msg
|
pending_msg = msg
|
||||||
if effective_key != msg.session_key:
|
if effective_key != msg.session_key:
|
||||||
pending_msg = dataclasses.replace(
|
pending_msg = dataclasses.replace(
|
||||||
|
|||||||
@ -135,7 +135,7 @@ if DISCORD_AVAILABLE:
|
|||||||
|
|
||||||
def _register_app_commands(self) -> None:
|
def _register_app_commands(self) -> None:
|
||||||
commands = (
|
commands = (
|
||||||
("new", "Start a new conversation", "/new"),
|
("new", "Stop current task and start a new conversation", "/new"),
|
||||||
("stop", "Stop the current task", "/stop"),
|
("stop", "Stop the current task", "/stop"),
|
||||||
("restart", "Restart the bot", "/restart"),
|
("restart", "Restart the bot", "/restart"),
|
||||||
("status", "Show bot status", "/status"),
|
("status", "Show bot status", "/status"),
|
||||||
|
|||||||
@ -17,15 +17,7 @@ async def cmd_stop(ctx: CommandContext) -> OutboundMessage:
|
|||||||
"""Cancel all active tasks and subagents for the session."""
|
"""Cancel all active tasks and subagents for the session."""
|
||||||
loop = ctx.loop
|
loop = ctx.loop
|
||||||
msg = ctx.msg
|
msg = ctx.msg
|
||||||
tasks = loop._active_tasks.pop(msg.session_key, [])
|
total = await loop._cancel_active_tasks(msg.session_key)
|
||||||
cancelled = sum(1 for t in tasks if not t.done() and t.cancel())
|
|
||||||
for t in tasks:
|
|
||||||
try:
|
|
||||||
await t
|
|
||||||
except (asyncio.CancelledError, Exception):
|
|
||||||
pass
|
|
||||||
sub_cancelled = await loop.subagents.cancel_by_session(msg.session_key)
|
|
||||||
total = cancelled + sub_cancelled
|
|
||||||
content = f"Stopped {total} task(s)." if total else "No active task to stop."
|
content = f"Stopped {total} task(s)." if total else "No active task to stop."
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id, content=content,
|
channel=msg.channel, chat_id=msg.chat_id, content=content,
|
||||||
@ -100,8 +92,9 @@ async def cmd_status(ctx: CommandContext) -> OutboundMessage:
|
|||||||
|
|
||||||
|
|
||||||
async def cmd_new(ctx: CommandContext) -> OutboundMessage:
|
async def cmd_new(ctx: CommandContext) -> OutboundMessage:
|
||||||
"""Start a fresh session."""
|
"""Stop active task and start a fresh session."""
|
||||||
loop = ctx.loop
|
loop = ctx.loop
|
||||||
|
await loop._cancel_active_tasks(ctx.key)
|
||||||
session = ctx.session or loop.sessions.get_or_create(ctx.key)
|
session = ctx.session or loop.sessions.get_or_create(ctx.key)
|
||||||
snapshot = session.messages[session.last_consolidated:]
|
snapshot = session.messages[session.last_consolidated:]
|
||||||
session.clear()
|
session.clear()
|
||||||
@ -327,7 +320,7 @@ def build_help_text() -> str:
|
|||||||
"""Build canonical help text shared across channels."""
|
"""Build canonical help text shared across channels."""
|
||||||
lines = [
|
lines = [
|
||||||
"🐈 nanobot commands:",
|
"🐈 nanobot commands:",
|
||||||
"/new — Start a new conversation",
|
"/new — Stop current task and start a new conversation",
|
||||||
"/stop — Stop the current task",
|
"/stop — Stop the current task",
|
||||||
"/restart — Restart the bot",
|
"/restart — Restart the bot",
|
||||||
"/status — Show bot status",
|
"/status — Show bot status",
|
||||||
|
|||||||
@ -57,6 +57,20 @@ class CommandRouter:
|
|||||||
def is_priority(self, text: str) -> bool:
|
def is_priority(self, text: str) -> bool:
|
||||||
return text.strip().lower() in self._priority
|
return text.strip().lower() in self._priority
|
||||||
|
|
||||||
|
def is_dispatchable_command(self, text: str) -> bool:
|
||||||
|
"""Check whether *text* matches any non-priority command tier (exact or prefix).
|
||||||
|
|
||||||
|
Does NOT check priority or interceptor tiers.
|
||||||
|
If this returns True, ``dispatch()`` is guaranteed to match a handler.
|
||||||
|
"""
|
||||||
|
cmd = text.strip().lower()
|
||||||
|
if cmd in self._exact:
|
||||||
|
return True
|
||||||
|
for pfx, _ in self._prefix:
|
||||||
|
if cmd.startswith(pfx):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
async def dispatch_priority(self, ctx: CommandContext) -> OutboundMessage | None:
|
async def dispatch_priority(self, ctx: CommandContext) -> OutboundMessage | None:
|
||||||
"""Dispatch a priority command. Called from run() without the lock."""
|
"""Dispatch a priority command. Called from run() without the lock."""
|
||||||
handler = self._priority.get(ctx.raw.lower())
|
handler = self._priority.get(ctx.raw.lower())
|
||||||
|
|||||||
143
tests/command/test_router_dispatchable.py
Normal file
143
tests/command/test_router_dispatchable.py
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
"""Tests for CommandRouter.is_dispatchable_command and mid-turn command interception."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.command.builtin import register_builtin_commands
|
||||||
|
from nanobot.command.router import CommandContext, CommandRouter
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsDispatchableCommand:
|
||||||
|
"""Unit tests for the is_dispatchable_command() predicate."""
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def router(self) -> CommandRouter:
|
||||||
|
r = CommandRouter()
|
||||||
|
register_builtin_commands(r)
|
||||||
|
return r
|
||||||
|
|
||||||
|
def test_exact_commands_match(self, router: CommandRouter) -> None:
|
||||||
|
assert router.is_dispatchable_command("/new")
|
||||||
|
assert router.is_dispatchable_command("/help")
|
||||||
|
assert router.is_dispatchable_command("/dream")
|
||||||
|
assert router.is_dispatchable_command("/dream-log")
|
||||||
|
assert router.is_dispatchable_command("/dream-restore")
|
||||||
|
|
||||||
|
def test_prefix_commands_match(self, router: CommandRouter) -> None:
|
||||||
|
assert router.is_dispatchable_command("/dream-log abc123")
|
||||||
|
assert router.is_dispatchable_command("/dream-restore def456")
|
||||||
|
|
||||||
|
def test_priority_commands_not_matched(self, router: CommandRouter) -> None:
|
||||||
|
# Priority commands are NOT in the dispatchable tiers — they are
|
||||||
|
# handled by is_priority() separately.
|
||||||
|
assert not router.is_dispatchable_command("/stop")
|
||||||
|
assert not router.is_dispatchable_command("/restart")
|
||||||
|
|
||||||
|
def test_regular_text_not_matched(self, router: CommandRouter) -> None:
|
||||||
|
assert not router.is_dispatchable_command("hello")
|
||||||
|
assert not router.is_dispatchable_command("what is 2+2?")
|
||||||
|
assert not router.is_dispatchable_command("")
|
||||||
|
|
||||||
|
def test_case_insensitive(self, router: CommandRouter) -> None:
|
||||||
|
assert router.is_dispatchable_command("/NEW")
|
||||||
|
assert router.is_dispatchable_command("/Help")
|
||||||
|
|
||||||
|
def test_strips_whitespace(self, router: CommandRouter) -> None:
|
||||||
|
assert router.is_dispatchable_command(" /new ")
|
||||||
|
|
||||||
|
def test_unknown_slash_command_not_matched(self, router: CommandRouter) -> None:
|
||||||
|
assert not router.is_dispatchable_command("/unknown")
|
||||||
|
assert not router.is_dispatchable_command("/foo bar")
|
||||||
|
|
||||||
|
|
||||||
|
class TestMidTurnCommandDispatchedDirectly:
|
||||||
|
"""Verify that commands matching is_dispatchable_command() are dispatched
|
||||||
|
correctly when session=None (the mid-turn path)."""
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def router(self) -> CommandRouter:
|
||||||
|
r = CommandRouter()
|
||||||
|
register_builtin_commands(r)
|
||||||
|
return r
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def fake_loop(self) -> MagicMock:
|
||||||
|
loop = MagicMock()
|
||||||
|
loop.sessions = MagicMock()
|
||||||
|
loop.sessions.get_or_create = MagicMock(return_value=MagicMock(
|
||||||
|
messages=[], last_consolidated=0, clear=MagicMock(),
|
||||||
|
))
|
||||||
|
loop.sessions.save = MagicMock()
|
||||||
|
loop.sessions.invalidate = MagicMock()
|
||||||
|
loop._schedule_background = MagicMock()
|
||||||
|
loop._cancel_active_tasks = AsyncMock(return_value=0)
|
||||||
|
return loop
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def fake_msg(self) -> MagicMock:
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.channel = "test"
|
||||||
|
msg.chat_id = "chat1"
|
||||||
|
msg.content = "/new"
|
||||||
|
msg.metadata = {}
|
||||||
|
return msg
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_new_dispatched_with_session_none(
|
||||||
|
self, router: CommandRouter, fake_loop: MagicMock, fake_msg: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""cmd_new works when session=None (mid-turn dispatch path)."""
|
||||||
|
ctx = CommandContext(
|
||||||
|
msg=fake_msg, session=None,
|
||||||
|
key="test:chat1", raw="/new", loop=fake_loop,
|
||||||
|
)
|
||||||
|
result = await router.dispatch(ctx)
|
||||||
|
assert result is not None
|
||||||
|
assert "New session" in result.content
|
||||||
|
fake_loop.sessions.get_or_create.assert_called_once_with("test:chat1")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_help_dispatched_with_session_none(
|
||||||
|
self, router: CommandRouter, fake_loop: MagicMock, fake_msg: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
ctx = CommandContext(
|
||||||
|
msg=fake_msg, session=None,
|
||||||
|
key="test:chat1", raw="/help", loop=fake_loop,
|
||||||
|
)
|
||||||
|
result = await router.dispatch(ctx)
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prefix_command_args_populated(self, router: CommandRouter) -> None:
|
||||||
|
"""Prefix commands have args populated correctly in mid-turn path."""
|
||||||
|
# Use a custom prefix handler to avoid needing full mock setup.
|
||||||
|
custom = CommandRouter()
|
||||||
|
captured_args = []
|
||||||
|
|
||||||
|
async def fake_handler(ctx: CommandContext) -> None:
|
||||||
|
captured_args.append(ctx.args)
|
||||||
|
return None
|
||||||
|
|
||||||
|
custom.prefix("/test ", fake_handler)
|
||||||
|
|
||||||
|
ctx = CommandContext(
|
||||||
|
msg=MagicMock(channel="test", chat_id="c1", metadata={}),
|
||||||
|
session=None, key="test:c1", raw="/test hello world", loop=MagicMock(),
|
||||||
|
)
|
||||||
|
await custom.dispatch(ctx)
|
||||||
|
assert captured_args == ["hello world"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_non_command_returns_none(
|
||||||
|
self, router: CommandRouter, fake_loop: MagicMock, fake_msg: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Regular text returns None from dispatch (not a command)."""
|
||||||
|
ctx = CommandContext(
|
||||||
|
msg=fake_msg, session=None,
|
||||||
|
key="test:chat1", raw="hello world", loop=fake_loop,
|
||||||
|
)
|
||||||
|
result = await router.dispatch(ctx)
|
||||||
|
assert result is None
|
||||||
Loading…
x
Reference in New Issue
Block a user