diff --git a/docs/chat-commands.md b/docs/chat-commands.md index 72707e764..816292e74 100644 --- a/docs/chat-commands.md +++ b/docs/chat-commands.md @@ -4,7 +4,7 @@ These commands work inside chat channels and interactive agent sessions: | Command | Description | |---------|-------------| -| `/new` | Start a new conversation | +| `/new` | Stop current task and start a new conversation | | `/stop` | Stop the current task | | `/restart` | Restart the bot | | `/status` | Show bot status | diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 116868bb0..25af137c8 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -345,6 +345,36 @@ class AgentLoop: return format_tool_hints(tool_calls) + async def _dispatch_command_inline( + self, + msg: InboundMessage, + key: str, + raw: str, + dispatch_fn: Callable[[CommandContext], Awaitable[OutboundMessage | None]], + ) -> None: + """Dispatch a command directly from the run() loop and publish the result.""" + ctx = CommandContext(msg=msg, session=None, key=key, raw=raw, loop=self) + result = await dispatch_fn(ctx) + if result: + await self.bus.publish_outbound(result) + else: + logger.warning("Command '{}' matched but dispatch returned None", raw) + + async def _cancel_active_tasks(self, key: str) -> int: + """Cancel and await all active tasks and subagents for *key*. + + Returns the total number of cancelled tasks + subagents. + """ + tasks = self._active_tasks.pop(key, []) + cancelled = sum(1 for t in tasks if not t.done() and t.cancel()) + for t in tasks: + try: + await t + except (asyncio.CancelledError, Exception): + pass + sub_cancelled = await self.subagents.cancel_by_session(key) + return cancelled + sub_cancelled + def _effective_session_key(self, msg: InboundMessage) -> str: """Return the session key used for task routing and mid-turn injections.""" if self._unified_session and not msg.session_key_override: @@ -478,16 +508,24 @@ class AgentLoop: raw = msg.content.strip() if self.commands.is_priority(raw): - ctx = CommandContext(msg=msg, session=None, key=msg.session_key, raw=raw, loop=self) - result = await self.commands.dispatch_priority(ctx) - if result: - await self.bus.publish_outbound(result) + await self._dispatch_command_inline( + msg, msg.session_key, raw, + self.commands.dispatch_priority, + ) continue effective_key = self._effective_session_key(msg) # If this session already has an active pending queue (i.e. a task # is processing this session), route the message there for mid-turn # injection instead of creating a competing task. if effective_key in self._pending_queues: + # Non-priority commands must not be queued for injection; + # dispatch them directly (same pattern as priority commands). + if self.commands.is_dispatchable_command(raw): + await self._dispatch_command_inline( + msg, effective_key, raw, + self.commands.dispatch, + ) + continue pending_msg = msg if effective_key != msg.session_key: pending_msg = dataclasses.replace( diff --git a/nanobot/channels/discord.py b/nanobot/channels/discord.py index 9710c5efc..97fa30bd0 100644 --- a/nanobot/channels/discord.py +++ b/nanobot/channels/discord.py @@ -135,7 +135,7 @@ if DISCORD_AVAILABLE: def _register_app_commands(self) -> None: commands = ( - ("new", "Start a new conversation", "/new"), + ("new", "Stop current task and start a new conversation", "/new"), ("stop", "Stop the current task", "/stop"), ("restart", "Restart the bot", "/restart"), ("status", "Show bot status", "/status"), diff --git a/nanobot/command/builtin.py b/nanobot/command/builtin.py index 94ee0646a..87d4bf640 100644 --- a/nanobot/command/builtin.py +++ b/nanobot/command/builtin.py @@ -17,15 +17,7 @@ async def cmd_stop(ctx: CommandContext) -> OutboundMessage: """Cancel all active tasks and subagents for the session.""" loop = ctx.loop msg = ctx.msg - tasks = loop._active_tasks.pop(msg.session_key, []) - cancelled = sum(1 for t in tasks if not t.done() and t.cancel()) - for t in tasks: - try: - await t - except (asyncio.CancelledError, Exception): - pass - sub_cancelled = await loop.subagents.cancel_by_session(msg.session_key) - total = cancelled + sub_cancelled + total = await loop._cancel_active_tasks(msg.session_key) content = f"Stopped {total} task(s)." if total else "No active task to stop." return OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, content=content, @@ -100,8 +92,9 @@ async def cmd_status(ctx: CommandContext) -> OutboundMessage: async def cmd_new(ctx: CommandContext) -> OutboundMessage: - """Start a fresh session.""" + """Stop active task and start a fresh session.""" loop = ctx.loop + await loop._cancel_active_tasks(ctx.key) session = ctx.session or loop.sessions.get_or_create(ctx.key) snapshot = session.messages[session.last_consolidated:] session.clear() @@ -327,7 +320,7 @@ def build_help_text() -> str: """Build canonical help text shared across channels.""" lines = [ "🐈 nanobot commands:", - "/new — Start a new conversation", + "/new — Stop current task and start a new conversation", "/stop — Stop the current task", "/restart — Restart the bot", "/status — Show bot status", diff --git a/nanobot/command/router.py b/nanobot/command/router.py index 35a475453..98f938b17 100644 --- a/nanobot/command/router.py +++ b/nanobot/command/router.py @@ -57,6 +57,20 @@ class CommandRouter: def is_priority(self, text: str) -> bool: return text.strip().lower() in self._priority + def is_dispatchable_command(self, text: str) -> bool: + """Check whether *text* matches any non-priority command tier (exact or prefix). + + Does NOT check priority or interceptor tiers. + If this returns True, ``dispatch()`` is guaranteed to match a handler. + """ + cmd = text.strip().lower() + if cmd in self._exact: + return True + for pfx, _ in self._prefix: + if cmd.startswith(pfx): + return True + return False + async def dispatch_priority(self, ctx: CommandContext) -> OutboundMessage | None: """Dispatch a priority command. Called from run() without the lock.""" handler = self._priority.get(ctx.raw.lower()) diff --git a/tests/command/test_router_dispatchable.py b/tests/command/test_router_dispatchable.py new file mode 100644 index 000000000..3be684072 --- /dev/null +++ b/tests/command/test_router_dispatchable.py @@ -0,0 +1,143 @@ +"""Tests for CommandRouter.is_dispatchable_command and mid-turn command interception.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from nanobot.command.builtin import register_builtin_commands +from nanobot.command.router import CommandContext, CommandRouter + + +class TestIsDispatchableCommand: + """Unit tests for the is_dispatchable_command() predicate.""" + + @pytest.fixture() + def router(self) -> CommandRouter: + r = CommandRouter() + register_builtin_commands(r) + return r + + def test_exact_commands_match(self, router: CommandRouter) -> None: + assert router.is_dispatchable_command("/new") + assert router.is_dispatchable_command("/help") + assert router.is_dispatchable_command("/dream") + assert router.is_dispatchable_command("/dream-log") + assert router.is_dispatchable_command("/dream-restore") + + def test_prefix_commands_match(self, router: CommandRouter) -> None: + assert router.is_dispatchable_command("/dream-log abc123") + assert router.is_dispatchable_command("/dream-restore def456") + + def test_priority_commands_not_matched(self, router: CommandRouter) -> None: + # Priority commands are NOT in the dispatchable tiers — they are + # handled by is_priority() separately. + assert not router.is_dispatchable_command("/stop") + assert not router.is_dispatchable_command("/restart") + + def test_regular_text_not_matched(self, router: CommandRouter) -> None: + assert not router.is_dispatchable_command("hello") + assert not router.is_dispatchable_command("what is 2+2?") + assert not router.is_dispatchable_command("") + + def test_case_insensitive(self, router: CommandRouter) -> None: + assert router.is_dispatchable_command("/NEW") + assert router.is_dispatchable_command("/Help") + + def test_strips_whitespace(self, router: CommandRouter) -> None: + assert router.is_dispatchable_command(" /new ") + + def test_unknown_slash_command_not_matched(self, router: CommandRouter) -> None: + assert not router.is_dispatchable_command("/unknown") + assert not router.is_dispatchable_command("/foo bar") + + +class TestMidTurnCommandDispatchedDirectly: + """Verify that commands matching is_dispatchable_command() are dispatched + correctly when session=None (the mid-turn path).""" + + @pytest.fixture() + def router(self) -> CommandRouter: + r = CommandRouter() + register_builtin_commands(r) + return r + + @pytest.fixture() + def fake_loop(self) -> MagicMock: + loop = MagicMock() + loop.sessions = MagicMock() + loop.sessions.get_or_create = MagicMock(return_value=MagicMock( + messages=[], last_consolidated=0, clear=MagicMock(), + )) + loop.sessions.save = MagicMock() + loop.sessions.invalidate = MagicMock() + loop._schedule_background = MagicMock() + loop._cancel_active_tasks = AsyncMock(return_value=0) + return loop + + @pytest.fixture() + def fake_msg(self) -> MagicMock: + msg = MagicMock() + msg.channel = "test" + msg.chat_id = "chat1" + msg.content = "/new" + msg.metadata = {} + return msg + + @pytest.mark.asyncio + async def test_new_dispatched_with_session_none( + self, router: CommandRouter, fake_loop: MagicMock, fake_msg: MagicMock, + ) -> None: + """cmd_new works when session=None (mid-turn dispatch path).""" + ctx = CommandContext( + msg=fake_msg, session=None, + key="test:chat1", raw="/new", loop=fake_loop, + ) + result = await router.dispatch(ctx) + assert result is not None + assert "New session" in result.content + fake_loop.sessions.get_or_create.assert_called_once_with("test:chat1") + + @pytest.mark.asyncio + async def test_help_dispatched_with_session_none( + self, router: CommandRouter, fake_loop: MagicMock, fake_msg: MagicMock, + ) -> None: + ctx = CommandContext( + msg=fake_msg, session=None, + key="test:chat1", raw="/help", loop=fake_loop, + ) + result = await router.dispatch(ctx) + assert result is not None + + @pytest.mark.asyncio + async def test_prefix_command_args_populated(self, router: CommandRouter) -> None: + """Prefix commands have args populated correctly in mid-turn path.""" + # Use a custom prefix handler to avoid needing full mock setup. + custom = CommandRouter() + captured_args = [] + + async def fake_handler(ctx: CommandContext) -> None: + captured_args.append(ctx.args) + return None + + custom.prefix("/test ", fake_handler) + + ctx = CommandContext( + msg=MagicMock(channel="test", chat_id="c1", metadata={}), + session=None, key="test:c1", raw="/test hello world", loop=MagicMock(), + ) + await custom.dispatch(ctx) + assert captured_args == ["hello world"] + + @pytest.mark.asyncio + async def test_non_command_returns_none( + self, router: CommandRouter, fake_loop: MagicMock, fake_msg: MagicMock, + ) -> None: + """Regular text returns None from dispatch (not a command).""" + ctx = CommandContext( + msg=fake_msg, session=None, + key="test:chat1", raw="hello world", loop=fake_loop, + ) + result = await router.dispatch(ctx) + assert result is None