From 403ce23d22c59fe31a89cc3b4ed8db31d8cf6f17 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Sat, 25 Apr 2026 14:06:09 +0000 Subject: [PATCH] fix(agent): tighten ask_user CLI handling Made-with: Cursor --- nanobot/agent/loop.py | 100 +++++++---------------------------- nanobot/agent/runner.py | 27 ++++++---- nanobot/agent/tools/ask.py | 86 ++++++++++++++++++++++++++++++ nanobot/cli/commands.py | 4 ++ tests/agent/test_ask_user.py | 19 ++++++- 5 files changed, 142 insertions(+), 94 deletions(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index d87ad1a80..5a4480041 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -20,7 +20,13 @@ from nanobot.agent.memory import Consolidator, Dream from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec from nanobot.agent.skills import BUILTIN_SKILLS_DIR from nanobot.agent.subagent import SubagentManager -from nanobot.agent.tools.ask import AskUserTool +from nanobot.agent.tools.ask import ( + AskUserTool, + ask_user_options_from_messages, + ask_user_outbound, + ask_user_tool_result_messages, + pending_ask_user_id, +) from nanobot.agent.tools.cron import CronTool from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool from nanobot.agent.tools.message import MessageTool @@ -54,7 +60,6 @@ if TYPE_CHECKING: UNIFIED_SESSION_KEY = "unified:default" -BUTTON_CHANNELS = frozenset({"telegram"}) class _LoopHook(AgentHook): @@ -410,69 +415,6 @@ class AgentLoop: return UNIFIED_SESSION_KEY return msg.session_key - @staticmethod - def _tool_call_name(tool_call: dict[str, Any]) -> str: - function = tool_call.get("function") - if isinstance(function, dict) and isinstance(function.get("name"), str): - return function["name"] - name = tool_call.get("name") - return name if isinstance(name, str) else "" - - @staticmethod - def _tool_call_arguments(tool_call: dict[str, Any]) -> dict[str, Any]: - function = tool_call.get("function") - raw = function.get("arguments") if isinstance(function, dict) else tool_call.get("arguments") - if isinstance(raw, dict): - return raw - if isinstance(raw, str): - try: - parsed = json.loads(raw) - except json.JSONDecodeError: - return {} - return parsed if isinstance(parsed, dict) else {} - return {} - - def _pending_ask_user_id(self, history: list[dict[str, Any]]) -> str | None: - pending: dict[str, str] = {} - for message in history: - if message.get("role") == "assistant": - for tool_call in message.get("tool_calls") or []: - if isinstance(tool_call, dict) and isinstance(tool_call.get("id"), str): - pending[tool_call["id"]] = self._tool_call_name(tool_call) - elif message.get("role") == "tool": - tool_call_id = message.get("tool_call_id") - if isinstance(tool_call_id, str): - pending.pop(tool_call_id, None) - for tool_call_id, name in reversed(pending.items()): - if name == "ask_user": - return tool_call_id - return None - - def _ask_user_options_from_messages(self, messages: list[dict[str, Any]]) -> list[str]: - for message in reversed(messages): - if message.get("role") != "assistant": - continue - for tool_call in reversed(message.get("tool_calls") or []): - if not isinstance(tool_call, dict) or self._tool_call_name(tool_call) != "ask_user": - continue - options = self._tool_call_arguments(tool_call).get("options") - if isinstance(options, list): - return [str(option) for option in options if isinstance(option, str)] - return [] - - @staticmethod - def _ask_user_outbound( - content: str | None, - options: list[str], - channel: str, - ) -> tuple[str | None, list[list[str]]]: - if not options: - return content, [] - if channel in BUTTON_CHANNELS: - return content, [options] - option_text = "\n".join(f"{index}. {option}" for index, option in enumerate(options, 1)) - return f"{content}\n\n{option_text}" if content else option_text, [] - async def _run_agent_loop( self, initial_messages: list[dict], @@ -874,8 +816,8 @@ class AgentLoop: self._clear_runtime_checkpoint(session) self.sessions.save(session) self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session)) - options = self._ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else [] - content, buttons = self._ask_user_outbound( + options = ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else [] + content, buttons = ask_user_outbound( final_content or "Background task completed.", options, channel, @@ -923,18 +865,14 @@ class AgentLoop: history = session.get_history(max_messages=0) - pending_ask_id = self._pending_ask_user_id(history) + pending_ask_id = pending_ask_user_id(history) if pending_ask_id: - initial_messages = [ - {"role": "system", "content": self.context.build_system_prompt(channel=msg.channel)}, - *history, - { - "role": "tool", - "tool_call_id": pending_ask_id, - "name": "ask_user", - "content": msg.content, - }, - ] + initial_messages = ask_user_tool_result_messages( + self.context.build_system_prompt(channel=msg.channel), + history, + pending_ask_id, + msg.content, + ) else: initial_messages = self.context.build_messages( history=history, @@ -1030,12 +968,12 @@ class AgentLoop: logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview) meta = dict(msg.metadata or {}) - final_content, buttons = self._ask_user_outbound( + final_content, buttons = ask_user_outbound( final_content, - self._ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else [], + ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else [], msg.channel, ) - if on_stream is not None and stop_reason != "error": + if on_stream is not None and stop_reason not in {"ask_user", "error"}: meta["_streamed"] = True return OutboundMessage( channel=msg.channel, diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 688d38714..be71f6498 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -278,17 +278,22 @@ class AgentRunner: self._accumulate_usage(usage, raw_usage) if response.should_execute_tools: + tool_calls = list(response.tool_calls) + ask_index = next((i for i, tc in enumerate(tool_calls) if tc.name == "ask_user"), None) + if ask_index is not None: + tool_calls = tool_calls[: ask_index + 1] + context.tool_calls = list(tool_calls) if hook.wants_streaming(): await hook.on_stream_end(context, resuming=True) assistant_message = build_assistant_message( response.content or "", - tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls], + tool_calls=[tc.to_openai_tool_call() for tc in tool_calls], reasoning_content=response.reasoning_content, thinking_blocks=response.thinking_blocks, ) messages.append(assistant_message) - tools_used.extend(tc.name for tc in response.tool_calls) + tools_used.extend(tc.name for tc in tool_calls) await self._emit_checkpoint( spec, { @@ -297,7 +302,7 @@ class AgentRunner: "model": spec.model, "assistant_message": assistant_message, "completed_tool_results": [], - "pending_tool_calls": [tc.to_openai_tool_call() for tc in response.tool_calls], + "pending_tool_calls": [tc.to_openai_tool_call() for tc in tool_calls], }, ) @@ -305,14 +310,14 @@ class AgentRunner: results, new_events, fatal_error = await self._execute_tools( spec, - response.tool_calls, + tool_calls, external_lookup_counts, ) tool_events.extend(new_events) context.tool_results = list(results) context.tool_events = list(new_events) completed_tool_results: list[dict[str, Any]] = [] - for tool_call, result in zip(response.tool_calls, results): + for tool_call, result in zip(tool_calls, results): if isinstance(fatal_error, AskUserInterrupt) and tool_call.name == "ask_user": continue tool_message = { @@ -700,7 +705,7 @@ class AgentRunner: tool_call: ToolCallRequest, external_lookup_counts: dict[str, int], ) -> tuple[Any, dict[str, str], BaseException | None]: - _HINT = "\n\n[Analyze the error above and try a different approach.]" + hint = "\n\n[Analyze the error above and try a different approach.]" lookup_error = repeated_external_lookup_error( tool_call.name, tool_call.arguments, @@ -713,8 +718,8 @@ class AgentRunner: "detail": "repeated external lookup blocked", } if spec.fail_on_tool_error: - return lookup_error + _HINT, event, RuntimeError(lookup_error) - return lookup_error + _HINT, event, None + return lookup_error + hint, event, RuntimeError(lookup_error) + return lookup_error + hint, event, None prepare_call = getattr(spec.tools, "prepare_call", None) tool, params, prep_error = None, tool_call.arguments, None if callable(prepare_call): @@ -730,7 +735,7 @@ class AgentRunner: "status": "error", "detail": prep_error.split(": ", 1)[-1][:120], } - return prep_error + _HINT, event, RuntimeError(prep_error) if spec.fail_on_tool_error else None + return prep_error + hint, event, RuntimeError(prep_error) if spec.fail_on_tool_error else None try: if tool is not None: result = await tool.execute(**params) @@ -758,8 +763,8 @@ class AgentRunner: "detail": result.replace("\n", " ").strip()[:120], } if spec.fail_on_tool_error: - return result + _HINT, event, RuntimeError(result) - return result + _HINT, event, None + return result + hint, event, RuntimeError(result) + return result + hint, event, None detail = "" if result is None else str(result) detail = detail.replace("\n", " ").strip() diff --git a/nanobot/agent/tools/ask.py b/nanobot/agent/tools/ask.py index 0ce371ea8..c2aa8e0e8 100644 --- a/nanobot/agent/tools/ask.py +++ b/nanobot/agent/tools/ask.py @@ -1,10 +1,13 @@ """Tool for pausing a turn until the user answers.""" +import json from typing import Any from nanobot.agent.tools.base import Tool, tool_parameters from nanobot.agent.tools.schema import ArraySchema, StringSchema, tool_parameters_schema +BUTTON_CHANNELS = frozenset({"telegram"}) + class AskUserInterrupt(BaseException): """Internal signal: the runner should stop and wait for user input.""" @@ -48,3 +51,86 @@ class AskUserTool(Tool): async def execute(self, question: str, options: list[str] | None = None, **_: Any) -> Any: raise AskUserInterrupt(question=question, options=options) + + +def _tool_call_name(tool_call: dict[str, Any]) -> str: + function = tool_call.get("function") + if isinstance(function, dict) and isinstance(function.get("name"), str): + return function["name"] + name = tool_call.get("name") + return name if isinstance(name, str) else "" + + +def _tool_call_arguments(tool_call: dict[str, Any]) -> dict[str, Any]: + function = tool_call.get("function") + raw = function.get("arguments") if isinstance(function, dict) else tool_call.get("arguments") + if isinstance(raw, dict): + return raw + if isinstance(raw, str): + try: + parsed = json.loads(raw) + except json.JSONDecodeError: + return {} + return parsed if isinstance(parsed, dict) else {} + return {} + + +def pending_ask_user_id(history: list[dict[str, Any]]) -> str | None: + pending: dict[str, str] = {} + for message in history: + if message.get("role") == "assistant": + for tool_call in message.get("tool_calls") or []: + if isinstance(tool_call, dict) and isinstance(tool_call.get("id"), str): + pending[tool_call["id"]] = _tool_call_name(tool_call) + elif message.get("role") == "tool": + tool_call_id = message.get("tool_call_id") + if isinstance(tool_call_id, str): + pending.pop(tool_call_id, None) + for tool_call_id, name in reversed(pending.items()): + if name == "ask_user": + return tool_call_id + return None + + +def ask_user_tool_result_messages( + system_prompt: str, + history: list[dict[str, Any]], + tool_call_id: str, + content: str, +) -> list[dict[str, Any]]: + return [ + {"role": "system", "content": system_prompt}, + *history, + { + "role": "tool", + "tool_call_id": tool_call_id, + "name": "ask_user", + "content": content, + }, + ] + + +def ask_user_options_from_messages(messages: list[dict[str, Any]]) -> list[str]: + for message in reversed(messages): + if message.get("role") != "assistant": + continue + for tool_call in reversed(message.get("tool_calls") or []): + if not isinstance(tool_call, dict) or _tool_call_name(tool_call) != "ask_user": + continue + options = _tool_call_arguments(tool_call).get("options") + if isinstance(options, list): + return [str(option) for option in options if isinstance(option, str)] + return [] + + +def ask_user_outbound( + content: str | None, + options: list[str], + channel: str, +) -> tuple[str | None, list[list[str]]]: + if not options: + return content, [] + if channel in BUTTON_CHANNELS: + return content, [options] + option_text = "\n".join(f"{index}. {option}" for index, option in enumerate(options, 1)) + return f"{content}\n\n{option_text}" if content else option_text, [] diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index d5b17518d..c4cd2b1b4 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -212,12 +212,16 @@ async def _print_interactive_response( def _print_cli_progress_line(text: str, thinking: ThinkingSpinner | None) -> None: """Print a CLI progress line, pausing the spinner if needed.""" + if not text.strip(): + return with thinking.pause() if thinking else nullcontext(): console.print(f" [dim]↳ {text}[/dim]") async def _print_interactive_progress_line(text: str, thinking: ThinkingSpinner | None) -> None: """Print an interactive progress line, pausing the spinner if needed.""" + if not text.strip(): + return with thinking.pause() if thinking else nullcontext(): await _print_interactive_line(text) diff --git a/tests/agent/test_ask_user.py b/tests/agent/test_ask_user.py index bdf49663a..4d5b5be93 100644 --- a/tests/agent/test_ask_user.py +++ b/tests/agent/test_ask_user.py @@ -15,10 +15,15 @@ from nanobot.providers.base import GenerationSettings, LLMResponse, ToolCallRequ def _make_provider(chat_with_retry): + async def chat_stream_with_retry(**kwargs): + kwargs.pop("on_content_delta", None) + return await chat_with_retry(**kwargs) + provider = MagicMock() provider.get_default_model.return_value = "test-model" provider.generation = GenerationSettings() provider.chat_with_retry = chat_with_retry + provider.chat_stream_with_retry = chat_stream_with_retry return provider @@ -88,7 +93,8 @@ async def test_runner_pauses_on_ask_user_without_executing_later_tools(): assert "ask_user" in result.tools_used assert later.called is False assert result.messages[-1]["role"] == "assistant" - assert result.messages[-1]["tool_calls"][0]["function"]["name"] == "ask_user" + tool_calls = result.messages[-1]["tool_calls"] + assert [tool_call["function"]["name"] for tool_call in tool_calls] == ["ask_user"] assert not any(message.get("name") == "ask_user" for message in result.messages) @@ -122,13 +128,22 @@ async def test_ask_user_text_fallback_resumes_with_next_message(tmp_path): model="test-model", ) + async def on_stream(delta: str) -> None: + pass + + async def on_stream_end(**kwargs) -> None: + pass + first = await loop._process_message( - InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="set it up") + InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="set it up"), + on_stream=on_stream, + on_stream_end=on_stream_end, ) assert first is not None assert first.content == "Install the optional package?\n\n1. Install\n2. Skip" assert first.buttons == [] + assert "_streamed" not in first.metadata session = loop.sessions.get_or_create("cli:direct") assert any(message.get("role") == "assistant" and message.get("tool_calls") for message in session.messages)