fix(agent): tighten ask_user CLI handling

Made-with: Cursor
This commit is contained in:
Xubin Ren 2026-04-25 14:06:09 +00:00 committed by Xubin Ren
parent 3b1ea99ee1
commit 403ce23d22
5 changed files with 142 additions and 94 deletions

View File

@ -20,7 +20,13 @@ from nanobot.agent.memory import Consolidator, Dream
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec
from nanobot.agent.skills import BUILTIN_SKILLS_DIR from nanobot.agent.skills import BUILTIN_SKILLS_DIR
from nanobot.agent.subagent import SubagentManager from nanobot.agent.subagent import SubagentManager
from nanobot.agent.tools.ask import AskUserTool from nanobot.agent.tools.ask import (
AskUserTool,
ask_user_options_from_messages,
ask_user_outbound,
ask_user_tool_result_messages,
pending_ask_user_id,
)
from nanobot.agent.tools.cron import CronTool from nanobot.agent.tools.cron import CronTool
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
from nanobot.agent.tools.message import MessageTool from nanobot.agent.tools.message import MessageTool
@ -54,7 +60,6 @@ if TYPE_CHECKING:
UNIFIED_SESSION_KEY = "unified:default" UNIFIED_SESSION_KEY = "unified:default"
BUTTON_CHANNELS = frozenset({"telegram"})
class _LoopHook(AgentHook): class _LoopHook(AgentHook):
@ -410,69 +415,6 @@ class AgentLoop:
return UNIFIED_SESSION_KEY return UNIFIED_SESSION_KEY
return msg.session_key return msg.session_key
@staticmethod
def _tool_call_name(tool_call: dict[str, Any]) -> str:
function = tool_call.get("function")
if isinstance(function, dict) and isinstance(function.get("name"), str):
return function["name"]
name = tool_call.get("name")
return name if isinstance(name, str) else ""
@staticmethod
def _tool_call_arguments(tool_call: dict[str, Any]) -> dict[str, Any]:
function = tool_call.get("function")
raw = function.get("arguments") if isinstance(function, dict) else tool_call.get("arguments")
if isinstance(raw, dict):
return raw
if isinstance(raw, str):
try:
parsed = json.loads(raw)
except json.JSONDecodeError:
return {}
return parsed if isinstance(parsed, dict) else {}
return {}
def _pending_ask_user_id(self, history: list[dict[str, Any]]) -> str | None:
pending: dict[str, str] = {}
for message in history:
if message.get("role") == "assistant":
for tool_call in message.get("tool_calls") or []:
if isinstance(tool_call, dict) and isinstance(tool_call.get("id"), str):
pending[tool_call["id"]] = self._tool_call_name(tool_call)
elif message.get("role") == "tool":
tool_call_id = message.get("tool_call_id")
if isinstance(tool_call_id, str):
pending.pop(tool_call_id, None)
for tool_call_id, name in reversed(pending.items()):
if name == "ask_user":
return tool_call_id
return None
def _ask_user_options_from_messages(self, messages: list[dict[str, Any]]) -> list[str]:
for message in reversed(messages):
if message.get("role") != "assistant":
continue
for tool_call in reversed(message.get("tool_calls") or []):
if not isinstance(tool_call, dict) or self._tool_call_name(tool_call) != "ask_user":
continue
options = self._tool_call_arguments(tool_call).get("options")
if isinstance(options, list):
return [str(option) for option in options if isinstance(option, str)]
return []
@staticmethod
def _ask_user_outbound(
content: str | None,
options: list[str],
channel: str,
) -> tuple[str | None, list[list[str]]]:
if not options:
return content, []
if channel in BUTTON_CHANNELS:
return content, [options]
option_text = "\n".join(f"{index}. {option}" for index, option in enumerate(options, 1))
return f"{content}\n\n{option_text}" if content else option_text, []
async def _run_agent_loop( async def _run_agent_loop(
self, self,
initial_messages: list[dict], initial_messages: list[dict],
@ -874,8 +816,8 @@ class AgentLoop:
self._clear_runtime_checkpoint(session) self._clear_runtime_checkpoint(session)
self.sessions.save(session) self.sessions.save(session)
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session)) self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
options = self._ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else [] options = ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else []
content, buttons = self._ask_user_outbound( content, buttons = ask_user_outbound(
final_content or "Background task completed.", final_content or "Background task completed.",
options, options,
channel, channel,
@ -923,18 +865,14 @@ class AgentLoop:
history = session.get_history(max_messages=0) history = session.get_history(max_messages=0)
pending_ask_id = self._pending_ask_user_id(history) pending_ask_id = pending_ask_user_id(history)
if pending_ask_id: if pending_ask_id:
initial_messages = [ initial_messages = ask_user_tool_result_messages(
{"role": "system", "content": self.context.build_system_prompt(channel=msg.channel)}, self.context.build_system_prompt(channel=msg.channel),
*history, history,
{ pending_ask_id,
"role": "tool", msg.content,
"tool_call_id": pending_ask_id, )
"name": "ask_user",
"content": msg.content,
},
]
else: else:
initial_messages = self.context.build_messages( initial_messages = self.context.build_messages(
history=history, history=history,
@ -1030,12 +968,12 @@ class AgentLoop:
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview) logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
meta = dict(msg.metadata or {}) meta = dict(msg.metadata or {})
final_content, buttons = self._ask_user_outbound( final_content, buttons = ask_user_outbound(
final_content, final_content,
self._ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else [], ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else [],
msg.channel, msg.channel,
) )
if on_stream is not None and stop_reason != "error": if on_stream is not None and stop_reason not in {"ask_user", "error"}:
meta["_streamed"] = True meta["_streamed"] = True
return OutboundMessage( return OutboundMessage(
channel=msg.channel, channel=msg.channel,

View File

@ -278,17 +278,22 @@ class AgentRunner:
self._accumulate_usage(usage, raw_usage) self._accumulate_usage(usage, raw_usage)
if response.should_execute_tools: if response.should_execute_tools:
tool_calls = list(response.tool_calls)
ask_index = next((i for i, tc in enumerate(tool_calls) if tc.name == "ask_user"), None)
if ask_index is not None:
tool_calls = tool_calls[: ask_index + 1]
context.tool_calls = list(tool_calls)
if hook.wants_streaming(): if hook.wants_streaming():
await hook.on_stream_end(context, resuming=True) await hook.on_stream_end(context, resuming=True)
assistant_message = build_assistant_message( assistant_message = build_assistant_message(
response.content or "", response.content or "",
tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls], tool_calls=[tc.to_openai_tool_call() for tc in tool_calls],
reasoning_content=response.reasoning_content, reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks, thinking_blocks=response.thinking_blocks,
) )
messages.append(assistant_message) messages.append(assistant_message)
tools_used.extend(tc.name for tc in response.tool_calls) tools_used.extend(tc.name for tc in tool_calls)
await self._emit_checkpoint( await self._emit_checkpoint(
spec, spec,
{ {
@ -297,7 +302,7 @@ class AgentRunner:
"model": spec.model, "model": spec.model,
"assistant_message": assistant_message, "assistant_message": assistant_message,
"completed_tool_results": [], "completed_tool_results": [],
"pending_tool_calls": [tc.to_openai_tool_call() for tc in response.tool_calls], "pending_tool_calls": [tc.to_openai_tool_call() for tc in tool_calls],
}, },
) )
@ -305,14 +310,14 @@ class AgentRunner:
results, new_events, fatal_error = await self._execute_tools( results, new_events, fatal_error = await self._execute_tools(
spec, spec,
response.tool_calls, tool_calls,
external_lookup_counts, external_lookup_counts,
) )
tool_events.extend(new_events) tool_events.extend(new_events)
context.tool_results = list(results) context.tool_results = list(results)
context.tool_events = list(new_events) context.tool_events = list(new_events)
completed_tool_results: list[dict[str, Any]] = [] completed_tool_results: list[dict[str, Any]] = []
for tool_call, result in zip(response.tool_calls, results): for tool_call, result in zip(tool_calls, results):
if isinstance(fatal_error, AskUserInterrupt) and tool_call.name == "ask_user": if isinstance(fatal_error, AskUserInterrupt) and tool_call.name == "ask_user":
continue continue
tool_message = { tool_message = {
@ -700,7 +705,7 @@ class AgentRunner:
tool_call: ToolCallRequest, tool_call: ToolCallRequest,
external_lookup_counts: dict[str, int], external_lookup_counts: dict[str, int],
) -> tuple[Any, dict[str, str], BaseException | None]: ) -> tuple[Any, dict[str, str], BaseException | None]:
_HINT = "\n\n[Analyze the error above and try a different approach.]" hint = "\n\n[Analyze the error above and try a different approach.]"
lookup_error = repeated_external_lookup_error( lookup_error = repeated_external_lookup_error(
tool_call.name, tool_call.name,
tool_call.arguments, tool_call.arguments,
@ -713,8 +718,8 @@ class AgentRunner:
"detail": "repeated external lookup blocked", "detail": "repeated external lookup blocked",
} }
if spec.fail_on_tool_error: if spec.fail_on_tool_error:
return lookup_error + _HINT, event, RuntimeError(lookup_error) return lookup_error + hint, event, RuntimeError(lookup_error)
return lookup_error + _HINT, event, None return lookup_error + hint, event, None
prepare_call = getattr(spec.tools, "prepare_call", None) prepare_call = getattr(spec.tools, "prepare_call", None)
tool, params, prep_error = None, tool_call.arguments, None tool, params, prep_error = None, tool_call.arguments, None
if callable(prepare_call): if callable(prepare_call):
@ -730,7 +735,7 @@ class AgentRunner:
"status": "error", "status": "error",
"detail": prep_error.split(": ", 1)[-1][:120], "detail": prep_error.split(": ", 1)[-1][:120],
} }
return prep_error + _HINT, event, RuntimeError(prep_error) if spec.fail_on_tool_error else None return prep_error + hint, event, RuntimeError(prep_error) if spec.fail_on_tool_error else None
try: try:
if tool is not None: if tool is not None:
result = await tool.execute(**params) result = await tool.execute(**params)
@ -758,8 +763,8 @@ class AgentRunner:
"detail": result.replace("\n", " ").strip()[:120], "detail": result.replace("\n", " ").strip()[:120],
} }
if spec.fail_on_tool_error: if spec.fail_on_tool_error:
return result + _HINT, event, RuntimeError(result) return result + hint, event, RuntimeError(result)
return result + _HINT, event, None return result + hint, event, None
detail = "" if result is None else str(result) detail = "" if result is None else str(result)
detail = detail.replace("\n", " ").strip() detail = detail.replace("\n", " ").strip()

View File

@ -1,10 +1,13 @@
"""Tool for pausing a turn until the user answers.""" """Tool for pausing a turn until the user answers."""
import json
from typing import Any from typing import Any
from nanobot.agent.tools.base import Tool, tool_parameters from nanobot.agent.tools.base import Tool, tool_parameters
from nanobot.agent.tools.schema import ArraySchema, StringSchema, tool_parameters_schema from nanobot.agent.tools.schema import ArraySchema, StringSchema, tool_parameters_schema
BUTTON_CHANNELS = frozenset({"telegram"})
class AskUserInterrupt(BaseException): class AskUserInterrupt(BaseException):
"""Internal signal: the runner should stop and wait for user input.""" """Internal signal: the runner should stop and wait for user input."""
@ -48,3 +51,86 @@ class AskUserTool(Tool):
async def execute(self, question: str, options: list[str] | None = None, **_: Any) -> Any: async def execute(self, question: str, options: list[str] | None = None, **_: Any) -> Any:
raise AskUserInterrupt(question=question, options=options) raise AskUserInterrupt(question=question, options=options)
def _tool_call_name(tool_call: dict[str, Any]) -> str:
function = tool_call.get("function")
if isinstance(function, dict) and isinstance(function.get("name"), str):
return function["name"]
name = tool_call.get("name")
return name if isinstance(name, str) else ""
def _tool_call_arguments(tool_call: dict[str, Any]) -> dict[str, Any]:
function = tool_call.get("function")
raw = function.get("arguments") if isinstance(function, dict) else tool_call.get("arguments")
if isinstance(raw, dict):
return raw
if isinstance(raw, str):
try:
parsed = json.loads(raw)
except json.JSONDecodeError:
return {}
return parsed if isinstance(parsed, dict) else {}
return {}
def pending_ask_user_id(history: list[dict[str, Any]]) -> str | None:
pending: dict[str, str] = {}
for message in history:
if message.get("role") == "assistant":
for tool_call in message.get("tool_calls") or []:
if isinstance(tool_call, dict) and isinstance(tool_call.get("id"), str):
pending[tool_call["id"]] = _tool_call_name(tool_call)
elif message.get("role") == "tool":
tool_call_id = message.get("tool_call_id")
if isinstance(tool_call_id, str):
pending.pop(tool_call_id, None)
for tool_call_id, name in reversed(pending.items()):
if name == "ask_user":
return tool_call_id
return None
def ask_user_tool_result_messages(
system_prompt: str,
history: list[dict[str, Any]],
tool_call_id: str,
content: str,
) -> list[dict[str, Any]]:
return [
{"role": "system", "content": system_prompt},
*history,
{
"role": "tool",
"tool_call_id": tool_call_id,
"name": "ask_user",
"content": content,
},
]
def ask_user_options_from_messages(messages: list[dict[str, Any]]) -> list[str]:
for message in reversed(messages):
if message.get("role") != "assistant":
continue
for tool_call in reversed(message.get("tool_calls") or []):
if not isinstance(tool_call, dict) or _tool_call_name(tool_call) != "ask_user":
continue
options = _tool_call_arguments(tool_call).get("options")
if isinstance(options, list):
return [str(option) for option in options if isinstance(option, str)]
return []
def ask_user_outbound(
content: str | None,
options: list[str],
channel: str,
) -> tuple[str | None, list[list[str]]]:
if not options:
return content, []
if channel in BUTTON_CHANNELS:
return content, [options]
option_text = "\n".join(f"{index}. {option}" for index, option in enumerate(options, 1))
return f"{content}\n\n{option_text}" if content else option_text, []

View File

@ -212,12 +212,16 @@ async def _print_interactive_response(
def _print_cli_progress_line(text: str, thinking: ThinkingSpinner | None) -> None: def _print_cli_progress_line(text: str, thinking: ThinkingSpinner | None) -> None:
"""Print a CLI progress line, pausing the spinner if needed.""" """Print a CLI progress line, pausing the spinner if needed."""
if not text.strip():
return
with thinking.pause() if thinking else nullcontext(): with thinking.pause() if thinking else nullcontext():
console.print(f" [dim]↳ {text}[/dim]") console.print(f" [dim]↳ {text}[/dim]")
async def _print_interactive_progress_line(text: str, thinking: ThinkingSpinner | None) -> None: async def _print_interactive_progress_line(text: str, thinking: ThinkingSpinner | None) -> None:
"""Print an interactive progress line, pausing the spinner if needed.""" """Print an interactive progress line, pausing the spinner if needed."""
if not text.strip():
return
with thinking.pause() if thinking else nullcontext(): with thinking.pause() if thinking else nullcontext():
await _print_interactive_line(text) await _print_interactive_line(text)

View File

@ -15,10 +15,15 @@ from nanobot.providers.base import GenerationSettings, LLMResponse, ToolCallRequ
def _make_provider(chat_with_retry): def _make_provider(chat_with_retry):
async def chat_stream_with_retry(**kwargs):
kwargs.pop("on_content_delta", None)
return await chat_with_retry(**kwargs)
provider = MagicMock() provider = MagicMock()
provider.get_default_model.return_value = "test-model" provider.get_default_model.return_value = "test-model"
provider.generation = GenerationSettings() provider.generation = GenerationSettings()
provider.chat_with_retry = chat_with_retry provider.chat_with_retry = chat_with_retry
provider.chat_stream_with_retry = chat_stream_with_retry
return provider return provider
@ -88,7 +93,8 @@ async def test_runner_pauses_on_ask_user_without_executing_later_tools():
assert "ask_user" in result.tools_used assert "ask_user" in result.tools_used
assert later.called is False assert later.called is False
assert result.messages[-1]["role"] == "assistant" assert result.messages[-1]["role"] == "assistant"
assert result.messages[-1]["tool_calls"][0]["function"]["name"] == "ask_user" tool_calls = result.messages[-1]["tool_calls"]
assert [tool_call["function"]["name"] for tool_call in tool_calls] == ["ask_user"]
assert not any(message.get("name") == "ask_user" for message in result.messages) assert not any(message.get("name") == "ask_user" for message in result.messages)
@ -122,13 +128,22 @@ async def test_ask_user_text_fallback_resumes_with_next_message(tmp_path):
model="test-model", model="test-model",
) )
async def on_stream(delta: str) -> None:
pass
async def on_stream_end(**kwargs) -> None:
pass
first = await loop._process_message( first = await loop._process_message(
InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="set it up") InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="set it up"),
on_stream=on_stream,
on_stream_end=on_stream_end,
) )
assert first is not None assert first is not None
assert first.content == "Install the optional package?\n\n1. Install\n2. Skip" assert first.content == "Install the optional package?\n\n1. Install\n2. Skip"
assert first.buttons == [] assert first.buttons == []
assert "_streamed" not in first.metadata
session = loop.sessions.get_or_create("cli:direct") session = loop.sessions.get_or_create("cli:direct")
assert any(message.get("role") == "assistant" and message.get("tool_calls") for message in session.messages) assert any(message.get("role") == "assistant" and message.get("tool_calls") for message in session.messages)