feat(agent): add ask_user tool

Made-with: Cursor
This commit is contained in:
Xubin Ren 2026-04-25 12:34:29 +00:00 committed by Xubin Ren
parent 830211b5d4
commit cfc76ffbbf
4 changed files with 320 additions and 16 deletions

View File

@ -20,14 +20,15 @@ from nanobot.agent.memory import Consolidator, Dream
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
from nanobot.agent.subagent import SubagentManager
from nanobot.agent.tools.ask import AskUserTool
from nanobot.agent.tools.cron import CronTool
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
from nanobot.agent.tools.message import MessageTool
from nanobot.agent.tools.notebook import NotebookEditTool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.search import GlobTool, GrepTool
from nanobot.agent.tools.shell import ExecTool
from nanobot.agent.tools.self import MyTool
from nanobot.agent.tools.shell import ExecTool
from nanobot.agent.tools.spawn import SpawnTool
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
from nanobot.bus.events import InboundMessage, OutboundMessage
@ -287,6 +288,7 @@ class AgentLoop:
self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
)
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
self.tools.register(AskUserTool())
self.tools.register(
ReadFileTool(
workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read
@ -407,6 +409,56 @@ class AgentLoop:
return UNIFIED_SESSION_KEY
return msg.session_key
@staticmethod
def _tool_call_name(tool_call: dict[str, Any]) -> str:
function = tool_call.get("function")
if isinstance(function, dict) and isinstance(function.get("name"), str):
return function["name"]
name = tool_call.get("name")
return name if isinstance(name, str) else ""
@staticmethod
def _tool_call_arguments(tool_call: dict[str, Any]) -> dict[str, Any]:
function = tool_call.get("function")
raw = function.get("arguments") if isinstance(function, dict) else tool_call.get("arguments")
if isinstance(raw, dict):
return raw
if isinstance(raw, str):
try:
parsed = json.loads(raw)
except json.JSONDecodeError:
return {}
return parsed if isinstance(parsed, dict) else {}
return {}
def _pending_ask_user_id(self, history: list[dict[str, Any]]) -> str | None:
pending: dict[str, str] = {}
for message in history:
if message.get("role") == "assistant":
for tool_call in message.get("tool_calls") or []:
if isinstance(tool_call, dict) and isinstance(tool_call.get("id"), str):
pending[tool_call["id"]] = self._tool_call_name(tool_call)
elif message.get("role") == "tool":
tool_call_id = message.get("tool_call_id")
if isinstance(tool_call_id, str):
pending.pop(tool_call_id, None)
for tool_call_id, name in reversed(pending.items()):
if name == "ask_user":
return tool_call_id
return None
def _ask_user_options_from_messages(self, messages: list[dict[str, Any]]) -> list[str]:
for message in reversed(messages):
if message.get("role") != "assistant":
continue
for tool_call in reversed(message.get("tool_calls") or []):
if not isinstance(tool_call, dict) or self._tool_call_name(tool_call) != "ask_user":
continue
options = self._tool_call_arguments(tool_call).get("options")
if isinstance(options, list):
return [str(option) for option in options if isinstance(option, str)]
return []
async def _run_agent_loop(
self,
initial_messages: list[dict],
@ -799,7 +851,7 @@ class AgentLoop:
session_summary=pending,
current_role=current_role,
)
final_content, _, all_msgs, _, _ = await self._run_agent_loop(
final_content, _, all_msgs, stop_reason, _ = await self._run_agent_loop(
messages, session=session, channel=channel, chat_id=chat_id,
message_id=msg.metadata.get("message_id"),
pending_queue=pending_queue,
@ -808,10 +860,12 @@ class AgentLoop:
self._clear_runtime_checkpoint(session)
self.sessions.save(session)
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
options = self._ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else []
return OutboundMessage(
channel=channel,
chat_id=chat_id,
content=final_content or "Background task completed.",
buttons=[options] if options else [],
)
# Extract document text from media at the processing boundary so all
@ -850,6 +904,19 @@ class AgentLoop:
history = session.get_history(max_messages=0)
pending_ask_id = self._pending_ask_user_id(history)
if pending_ask_id:
initial_messages = [
{"role": "system", "content": self.context.build_system_prompt(channel=msg.channel)},
*history,
{
"role": "tool",
"tool_call_id": pending_ask_id,
"name": "ask_user",
"content": msg.content,
},
]
else:
initial_messages = self.context.build_messages(
history=history,
current_message=msg.content,
@ -898,7 +965,7 @@ class AgentLoop:
user_persisted_early = False
media_paths = [p for p in (msg.media or []) if isinstance(p, str) and p]
has_text = isinstance(msg.content, str) and msg.content.strip()
if has_text or media_paths:
if not pending_ask_id and (has_text or media_paths):
extra: dict[str, Any] = {"media": list(media_paths)} if media_paths else {}
text = msg.content if isinstance(msg.content, str) else ""
session.add_message("user", text, **extra)
@ -944,6 +1011,11 @@ class AgentLoop:
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
meta = dict(msg.metadata or {})
buttons: list[list[str]] = []
if stop_reason == "ask_user":
options = self._ask_user_options_from_messages(all_msgs)
if options:
buttons = [options]
if on_stream is not None and stop_reason != "error":
meta["_streamed"] = True
return OutboundMessage(
@ -951,6 +1023,7 @@ class AgentLoop:
chat_id=msg.chat_id,
content=final_content,
metadata=meta,
buttons=buttons,
)
def _sanitize_persisted_blocks(

View File

@ -3,16 +3,16 @@
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
import inspect
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from loguru import logger
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.utils.prompt_templates import render_template
from nanobot.agent.tools.ask import AskUserInterrupt
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.utils.helpers import (
@ -23,6 +23,7 @@ from nanobot.utils.helpers import (
maybe_persist_tool_result,
truncate_text,
)
from nanobot.utils.prompt_templates import render_template
from nanobot.utils.runtime import (
EMPTY_FINAL_RESPONSE_MESSAGE,
build_finalization_retry_message,
@ -312,6 +313,8 @@ class AgentRunner:
context.tool_events = list(new_events)
completed_tool_results: list[dict[str, Any]] = []
for tool_call, result in zip(response.tool_calls, results):
if isinstance(fatal_error, AskUserInterrupt) and tool_call.name == "ask_user":
continue
tool_message = {
"role": "tool",
"tool_call_id": tool_call.id,
@ -326,6 +329,15 @@ class AgentRunner:
messages.append(tool_message)
completed_tool_results.append(tool_message)
if fatal_error is not None:
if isinstance(fatal_error, AskUserInterrupt):
final_content = fatal_error.question
stop_reason = "ask_user"
context.final_content = final_content
context.stop_reason = stop_reason
if hook.wants_streaming():
await hook.on_stream_end(context, resuming=False)
await hook.after_iteration(context)
break
error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
final_content = error
stop_reason = "tool_error"
@ -656,13 +668,21 @@ class AgentRunner:
tool_results: list[tuple[Any, dict[str, str], BaseException | None]] = []
for batch in batches:
if spec.concurrent_tools and len(batch) > 1:
tool_results.extend(await asyncio.gather(*(
batch_results = await asyncio.gather(*(
self._run_tool(spec, tool_call, external_lookup_counts)
for tool_call in batch
)))
))
tool_results.extend(batch_results)
else:
batch_results = []
for tool_call in batch:
tool_results.append(await self._run_tool(spec, tool_call, external_lookup_counts))
result = await self._run_tool(spec, tool_call, external_lookup_counts)
tool_results.append(result)
batch_results.append(result)
if isinstance(result[2], AskUserInterrupt):
break
if any(isinstance(error, AskUserInterrupt) for _, _, error in batch_results):
break
results: list[Any] = []
events: list[dict[str, str]] = []
@ -724,6 +744,9 @@ class AgentRunner:
"status": "error",
"detail": str(exc),
}
if isinstance(exc, AskUserInterrupt):
event["status"] = "waiting"
return "", event, exc
if spec.fail_on_tool_error:
return f"Error: {type(exc).__name__}: {exc}", event, exc
return f"Error: {type(exc).__name__}: {exc}", event, None

View File

@ -0,0 +1,50 @@
"""Tool for pausing a turn until the user answers."""
from typing import Any
from nanobot.agent.tools.base import Tool, tool_parameters
from nanobot.agent.tools.schema import ArraySchema, StringSchema, tool_parameters_schema
class AskUserInterrupt(BaseException):
"""Internal signal: the runner should stop and wait for user input."""
def __init__(self, question: str, options: list[str] | None = None) -> None:
self.question = question
self.options = [str(option) for option in (options or []) if str(option)]
super().__init__(question)
@tool_parameters(
tool_parameters_schema(
question=StringSchema(
"The question to ask before continuing. Use this only when the task needs the user's answer."
),
options=ArraySchema(
StringSchema("A possible answer label"),
description="Optional choices. The user may still reply with free text.",
),
required=["question"],
)
)
class AskUserTool(Tool):
"""Ask the user a blocking question."""
@property
def name(self) -> str:
return "ask_user"
@property
def description(self) -> str:
return (
"Pause and ask the user a question when their answer is required to continue. "
"Use options for likely answers; the user's reply, typed or selected, is returned as the tool result. "
"For non-blocking notifications or buttons, use the message tool instead."
)
@property
def exclusive(self) -> bool:
return True
async def execute(self, question: str, options: list[str] | None = None, **_: Any) -> Any:
raise AskUserInterrupt(question=question, options=options)

View File

@ -0,0 +1,158 @@
import asyncio
from unittest.mock import MagicMock
import pytest
from nanobot.agent.loop import AgentLoop
from nanobot.agent.runner import AgentRunner, AgentRunSpec
from nanobot.agent.tools.ask import AskUserInterrupt, AskUserTool
from nanobot.agent.tools.base import Tool, tool_parameters
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.schema import tool_parameters_schema
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import GenerationSettings, LLMResponse, ToolCallRequest
def _make_provider(chat_with_retry):
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
provider.generation = GenerationSettings()
provider.chat_with_retry = chat_with_retry
return provider
def test_ask_user_tool_schema_and_interrupt():
tool = AskUserTool()
schema = tool.to_schema()["function"]
assert schema["name"] == "ask_user"
assert "question" in schema["parameters"]["required"]
assert schema["parameters"]["properties"]["options"]["type"] == "array"
with pytest.raises(AskUserInterrupt) as exc:
asyncio.run(tool.execute("Continue?", options=["Yes", "No"]))
assert exc.value.question == "Continue?"
assert exc.value.options == ["Yes", "No"]
@pytest.mark.asyncio
async def test_runner_pauses_on_ask_user_without_executing_later_tools():
@tool_parameters(tool_parameters_schema(required=[]))
class LaterTool(Tool):
called = False
@property
def name(self) -> str:
return "later"
@property
def description(self) -> str:
return "Should not run after ask_user pauses the turn."
async def execute(self, **kwargs):
self.called = True
return "later result"
async def chat_with_retry(**kwargs):
return LLMResponse(
content="",
finish_reason="tool_calls",
tool_calls=[
ToolCallRequest(
id="call_ask",
name="ask_user",
arguments={"question": "Install this package?", "options": ["Yes", "No"]},
),
ToolCallRequest(id="call_later", name="later", arguments={}),
],
)
later = LaterTool()
tools = ToolRegistry()
tools.register(AskUserTool())
tools.register(later)
result = await AgentRunner(_make_provider(chat_with_retry)).run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "continue"}],
tools=tools,
model="test-model",
max_iterations=3,
max_tool_result_chars=16_000,
concurrent_tools=True,
))
assert result.stop_reason == "ask_user"
assert result.final_content == "Install this package?"
assert "ask_user" in result.tools_used
assert later.called is False
assert result.messages[-1]["role"] == "assistant"
assert result.messages[-1]["tool_calls"][0]["function"]["name"] == "ask_user"
assert not any(message.get("name") == "ask_user" for message in result.messages)
@pytest.mark.asyncio
async def test_ask_user_sends_buttons_and_resumes_with_next_message(tmp_path):
seen_messages: list[list[dict]] = []
async def chat_with_retry(**kwargs):
seen_messages.append(kwargs["messages"])
if len(seen_messages) == 1:
return LLMResponse(
content="",
finish_reason="tool_calls",
tool_calls=[
ToolCallRequest(
id="call_ask",
name="ask_user",
arguments={
"question": "Install the optional package?",
"options": ["Install", "Skip"],
},
)
],
)
return LLMResponse(content="Skipped install.", usage={})
loop = AgentLoop(
bus=MessageBus(),
provider=_make_provider(chat_with_retry),
workspace=tmp_path,
model="test-model",
)
first = await loop._process_message(
InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="set it up")
)
assert first is not None
assert first.content == "Install the optional package?"
assert first.buttons == [["Install", "Skip"]]
session = loop.sessions.get_or_create("cli:direct")
assert any(message.get("role") == "assistant" and message.get("tool_calls") for message in session.messages)
assert not any(message.get("role") == "tool" and message.get("name") == "ask_user" for message in session.messages)
second = await loop._process_message(
InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="Skip")
)
assert second is not None
assert second.content == "Skipped install."
assert any(
message.get("role") == "tool"
and message.get("name") == "ask_user"
and message.get("content") == "Skip"
for message in seen_messages[-1]
)
assert not any(
message.get("role") == "user" and message.get("content") == "Skip"
for message in session.messages
)
assert any(
message.get("role") == "tool"
and message.get("name") == "ask_user"
and message.get("content") == "Skip"
for message in session.messages
)