mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-26 19:42:41 +00:00
feat(agent): add ask_user tool
Made-with: Cursor
This commit is contained in:
parent
830211b5d4
commit
cfc76ffbbf
@ -20,14 +20,15 @@ from nanobot.agent.memory import Consolidator, Dream
|
||||
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec
|
||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.agent.tools.ask import AskUserTool
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.agent.tools.notebook import NotebookEditTool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.search import GlobTool, GrepTool
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
from nanobot.agent.tools.self import MyTool
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
from nanobot.agent.tools.spawn import SpawnTool
|
||||
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
@ -287,6 +288,7 @@ class AgentLoop:
|
||||
self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
|
||||
)
|
||||
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
|
||||
self.tools.register(AskUserTool())
|
||||
self.tools.register(
|
||||
ReadFileTool(
|
||||
workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read
|
||||
@ -407,6 +409,56 @@ class AgentLoop:
|
||||
return UNIFIED_SESSION_KEY
|
||||
return msg.session_key
|
||||
|
||||
@staticmethod
|
||||
def _tool_call_name(tool_call: dict[str, Any]) -> str:
|
||||
function = tool_call.get("function")
|
||||
if isinstance(function, dict) and isinstance(function.get("name"), str):
|
||||
return function["name"]
|
||||
name = tool_call.get("name")
|
||||
return name if isinstance(name, str) else ""
|
||||
|
||||
@staticmethod
|
||||
def _tool_call_arguments(tool_call: dict[str, Any]) -> dict[str, Any]:
|
||||
function = tool_call.get("function")
|
||||
raw = function.get("arguments") if isinstance(function, dict) else tool_call.get("arguments")
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
return {}
|
||||
|
||||
def _pending_ask_user_id(self, history: list[dict[str, Any]]) -> str | None:
|
||||
pending: dict[str, str] = {}
|
||||
for message in history:
|
||||
if message.get("role") == "assistant":
|
||||
for tool_call in message.get("tool_calls") or []:
|
||||
if isinstance(tool_call, dict) and isinstance(tool_call.get("id"), str):
|
||||
pending[tool_call["id"]] = self._tool_call_name(tool_call)
|
||||
elif message.get("role") == "tool":
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
if isinstance(tool_call_id, str):
|
||||
pending.pop(tool_call_id, None)
|
||||
for tool_call_id, name in reversed(pending.items()):
|
||||
if name == "ask_user":
|
||||
return tool_call_id
|
||||
return None
|
||||
|
||||
def _ask_user_options_from_messages(self, messages: list[dict[str, Any]]) -> list[str]:
|
||||
for message in reversed(messages):
|
||||
if message.get("role") != "assistant":
|
||||
continue
|
||||
for tool_call in reversed(message.get("tool_calls") or []):
|
||||
if not isinstance(tool_call, dict) or self._tool_call_name(tool_call) != "ask_user":
|
||||
continue
|
||||
options = self._tool_call_arguments(tool_call).get("options")
|
||||
if isinstance(options, list):
|
||||
return [str(option) for option in options if isinstance(option, str)]
|
||||
return []
|
||||
|
||||
async def _run_agent_loop(
|
||||
self,
|
||||
initial_messages: list[dict],
|
||||
@ -799,7 +851,7 @@ class AgentLoop:
|
||||
session_summary=pending,
|
||||
current_role=current_role,
|
||||
)
|
||||
final_content, _, all_msgs, _, _ = await self._run_agent_loop(
|
||||
final_content, _, all_msgs, stop_reason, _ = await self._run_agent_loop(
|
||||
messages, session=session, channel=channel, chat_id=chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
pending_queue=pending_queue,
|
||||
@ -808,10 +860,12 @@ class AgentLoop:
|
||||
self._clear_runtime_checkpoint(session)
|
||||
self.sessions.save(session)
|
||||
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
||||
options = self._ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else []
|
||||
return OutboundMessage(
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
content=final_content or "Background task completed.",
|
||||
buttons=[options] if options else [],
|
||||
)
|
||||
|
||||
# Extract document text from media at the processing boundary so all
|
||||
@ -850,14 +904,27 @@ class AgentLoop:
|
||||
|
||||
history = session.get_history(max_messages=0)
|
||||
|
||||
initial_messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=msg.content,
|
||||
session_summary=pending,
|
||||
media=msg.media if msg.media else None,
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
)
|
||||
pending_ask_id = self._pending_ask_user_id(history)
|
||||
if pending_ask_id:
|
||||
initial_messages = [
|
||||
{"role": "system", "content": self.context.build_system_prompt(channel=msg.channel)},
|
||||
*history,
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": pending_ask_id,
|
||||
"name": "ask_user",
|
||||
"content": msg.content,
|
||||
},
|
||||
]
|
||||
else:
|
||||
initial_messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=msg.content,
|
||||
session_summary=pending,
|
||||
media=msg.media if msg.media else None,
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
)
|
||||
|
||||
async def _bus_progress(
|
||||
content: str,
|
||||
@ -898,7 +965,7 @@ class AgentLoop:
|
||||
user_persisted_early = False
|
||||
media_paths = [p for p in (msg.media or []) if isinstance(p, str) and p]
|
||||
has_text = isinstance(msg.content, str) and msg.content.strip()
|
||||
if has_text or media_paths:
|
||||
if not pending_ask_id and (has_text or media_paths):
|
||||
extra: dict[str, Any] = {"media": list(media_paths)} if media_paths else {}
|
||||
text = msg.content if isinstance(msg.content, str) else ""
|
||||
session.add_message("user", text, **extra)
|
||||
@ -944,6 +1011,11 @@ class AgentLoop:
|
||||
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||
|
||||
meta = dict(msg.metadata or {})
|
||||
buttons: list[list[str]] = []
|
||||
if stop_reason == "ask_user":
|
||||
options = self._ask_user_options_from_messages(all_msgs)
|
||||
if options:
|
||||
buttons = [options]
|
||||
if on_stream is not None and stop_reason != "error":
|
||||
meta["_streamed"] = True
|
||||
return OutboundMessage(
|
||||
@ -951,6 +1023,7 @@ class AgentLoop:
|
||||
chat_id=msg.chat_id,
|
||||
content=final_content,
|
||||
metadata=meta,
|
||||
buttons=buttons,
|
||||
)
|
||||
|
||||
def _sanitize_persisted_blocks(
|
||||
|
||||
@ -3,16 +3,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
import inspect
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.utils.prompt_templates import render_template
|
||||
from nanobot.agent.tools.ask import AskUserInterrupt
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.utils.helpers import (
|
||||
@ -23,6 +23,7 @@ from nanobot.utils.helpers import (
|
||||
maybe_persist_tool_result,
|
||||
truncate_text,
|
||||
)
|
||||
from nanobot.utils.prompt_templates import render_template
|
||||
from nanobot.utils.runtime import (
|
||||
EMPTY_FINAL_RESPONSE_MESSAGE,
|
||||
build_finalization_retry_message,
|
||||
@ -312,6 +313,8 @@ class AgentRunner:
|
||||
context.tool_events = list(new_events)
|
||||
completed_tool_results: list[dict[str, Any]] = []
|
||||
for tool_call, result in zip(response.tool_calls, results):
|
||||
if isinstance(fatal_error, AskUserInterrupt) and tool_call.name == "ask_user":
|
||||
continue
|
||||
tool_message = {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
@ -326,6 +329,15 @@ class AgentRunner:
|
||||
messages.append(tool_message)
|
||||
completed_tool_results.append(tool_message)
|
||||
if fatal_error is not None:
|
||||
if isinstance(fatal_error, AskUserInterrupt):
|
||||
final_content = fatal_error.question
|
||||
stop_reason = "ask_user"
|
||||
context.final_content = final_content
|
||||
context.stop_reason = stop_reason
|
||||
if hook.wants_streaming():
|
||||
await hook.on_stream_end(context, resuming=False)
|
||||
await hook.after_iteration(context)
|
||||
break
|
||||
error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
|
||||
final_content = error
|
||||
stop_reason = "tool_error"
|
||||
@ -656,13 +668,21 @@ class AgentRunner:
|
||||
tool_results: list[tuple[Any, dict[str, str], BaseException | None]] = []
|
||||
for batch in batches:
|
||||
if spec.concurrent_tools and len(batch) > 1:
|
||||
tool_results.extend(await asyncio.gather(*(
|
||||
batch_results = await asyncio.gather(*(
|
||||
self._run_tool(spec, tool_call, external_lookup_counts)
|
||||
for tool_call in batch
|
||||
)))
|
||||
))
|
||||
tool_results.extend(batch_results)
|
||||
else:
|
||||
batch_results = []
|
||||
for tool_call in batch:
|
||||
tool_results.append(await self._run_tool(spec, tool_call, external_lookup_counts))
|
||||
result = await self._run_tool(spec, tool_call, external_lookup_counts)
|
||||
tool_results.append(result)
|
||||
batch_results.append(result)
|
||||
if isinstance(result[2], AskUserInterrupt):
|
||||
break
|
||||
if any(isinstance(error, AskUserInterrupt) for _, _, error in batch_results):
|
||||
break
|
||||
|
||||
results: list[Any] = []
|
||||
events: list[dict[str, str]] = []
|
||||
@ -724,6 +744,9 @@ class AgentRunner:
|
||||
"status": "error",
|
||||
"detail": str(exc),
|
||||
}
|
||||
if isinstance(exc, AskUserInterrupt):
|
||||
event["status"] = "waiting"
|
||||
return "", event, exc
|
||||
if spec.fail_on_tool_error:
|
||||
return f"Error: {type(exc).__name__}: {exc}", event, exc
|
||||
return f"Error: {type(exc).__name__}: {exc}", event, None
|
||||
|
||||
50
nanobot/agent/tools/ask.py
Normal file
50
nanobot/agent/tools/ask.py
Normal file
@ -0,0 +1,50 @@
|
||||
"""Tool for pausing a turn until the user answers."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||
from nanobot.agent.tools.schema import ArraySchema, StringSchema, tool_parameters_schema
|
||||
|
||||
|
||||
class AskUserInterrupt(BaseException):
|
||||
"""Internal signal: the runner should stop and wait for user input."""
|
||||
|
||||
def __init__(self, question: str, options: list[str] | None = None) -> None:
|
||||
self.question = question
|
||||
self.options = [str(option) for option in (options or []) if str(option)]
|
||||
super().__init__(question)
|
||||
|
||||
|
||||
@tool_parameters(
|
||||
tool_parameters_schema(
|
||||
question=StringSchema(
|
||||
"The question to ask before continuing. Use this only when the task needs the user's answer."
|
||||
),
|
||||
options=ArraySchema(
|
||||
StringSchema("A possible answer label"),
|
||||
description="Optional choices. The user may still reply with free text.",
|
||||
),
|
||||
required=["question"],
|
||||
)
|
||||
)
|
||||
class AskUserTool(Tool):
|
||||
"""Ask the user a blocking question."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "ask_user"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Pause and ask the user a question when their answer is required to continue. "
|
||||
"Use options for likely answers; the user's reply, typed or selected, is returned as the tool result. "
|
||||
"For non-blocking notifications or buttons, use the message tool instead."
|
||||
)
|
||||
|
||||
@property
|
||||
def exclusive(self) -> bool:
|
||||
return True
|
||||
|
||||
async def execute(self, question: str, options: list[str] | None = None, **_: Any) -> Any:
|
||||
raise AskUserInterrupt(question=question, options=options)
|
||||
158
tests/agent/test_ask_user.py
Normal file
158
tests/agent/test_ask_user.py
Normal file
@ -0,0 +1,158 @@
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||
from nanobot.agent.tools.ask import AskUserInterrupt, AskUserTool
|
||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.schema import tool_parameters_schema
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import GenerationSettings, LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
def _make_provider(chat_with_retry):
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.generation = GenerationSettings()
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
return provider
|
||||
|
||||
|
||||
def test_ask_user_tool_schema_and_interrupt():
|
||||
tool = AskUserTool()
|
||||
schema = tool.to_schema()["function"]
|
||||
|
||||
assert schema["name"] == "ask_user"
|
||||
assert "question" in schema["parameters"]["required"]
|
||||
assert schema["parameters"]["properties"]["options"]["type"] == "array"
|
||||
|
||||
with pytest.raises(AskUserInterrupt) as exc:
|
||||
asyncio.run(tool.execute("Continue?", options=["Yes", "No"]))
|
||||
|
||||
assert exc.value.question == "Continue?"
|
||||
assert exc.value.options == ["Yes", "No"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_pauses_on_ask_user_without_executing_later_tools():
|
||||
@tool_parameters(tool_parameters_schema(required=[]))
|
||||
class LaterTool(Tool):
|
||||
called = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "later"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Should not run after ask_user pauses the turn."
|
||||
|
||||
async def execute(self, **kwargs):
|
||||
self.called = True
|
||||
return "later result"
|
||||
|
||||
async def chat_with_retry(**kwargs):
|
||||
return LLMResponse(
|
||||
content="",
|
||||
finish_reason="tool_calls",
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_ask",
|
||||
name="ask_user",
|
||||
arguments={"question": "Install this package?", "options": ["Yes", "No"]},
|
||||
),
|
||||
ToolCallRequest(id="call_later", name="later", arguments={}),
|
||||
],
|
||||
)
|
||||
|
||||
later = LaterTool()
|
||||
tools = ToolRegistry()
|
||||
tools.register(AskUserTool())
|
||||
tools.register(later)
|
||||
|
||||
result = await AgentRunner(_make_provider(chat_with_retry)).run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "continue"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=3,
|
||||
max_tool_result_chars=16_000,
|
||||
concurrent_tools=True,
|
||||
))
|
||||
|
||||
assert result.stop_reason == "ask_user"
|
||||
assert result.final_content == "Install this package?"
|
||||
assert "ask_user" in result.tools_used
|
||||
assert later.called is False
|
||||
assert result.messages[-1]["role"] == "assistant"
|
||||
assert result.messages[-1]["tool_calls"][0]["function"]["name"] == "ask_user"
|
||||
assert not any(message.get("name") == "ask_user" for message in result.messages)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ask_user_sends_buttons_and_resumes_with_next_message(tmp_path):
|
||||
seen_messages: list[list[dict]] = []
|
||||
|
||||
async def chat_with_retry(**kwargs):
|
||||
seen_messages.append(kwargs["messages"])
|
||||
if len(seen_messages) == 1:
|
||||
return LLMResponse(
|
||||
content="",
|
||||
finish_reason="tool_calls",
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_ask",
|
||||
name="ask_user",
|
||||
arguments={
|
||||
"question": "Install the optional package?",
|
||||
"options": ["Install", "Skip"],
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
return LLMResponse(content="Skipped install.", usage={})
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=MessageBus(),
|
||||
provider=_make_provider(chat_with_retry),
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
first = await loop._process_message(
|
||||
InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="set it up")
|
||||
)
|
||||
|
||||
assert first is not None
|
||||
assert first.content == "Install the optional package?"
|
||||
assert first.buttons == [["Install", "Skip"]]
|
||||
|
||||
session = loop.sessions.get_or_create("cli:direct")
|
||||
assert any(message.get("role") == "assistant" and message.get("tool_calls") for message in session.messages)
|
||||
assert not any(message.get("role") == "tool" and message.get("name") == "ask_user" for message in session.messages)
|
||||
|
||||
second = await loop._process_message(
|
||||
InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="Skip")
|
||||
)
|
||||
|
||||
assert second is not None
|
||||
assert second.content == "Skipped install."
|
||||
assert any(
|
||||
message.get("role") == "tool"
|
||||
and message.get("name") == "ask_user"
|
||||
and message.get("content") == "Skip"
|
||||
for message in seen_messages[-1]
|
||||
)
|
||||
assert not any(
|
||||
message.get("role") == "user" and message.get("content") == "Skip"
|
||||
for message in session.messages
|
||||
)
|
||||
assert any(
|
||||
message.get("role") == "tool"
|
||||
and message.get("name") == "ask_user"
|
||||
and message.get("content") == "Skip"
|
||||
for message in session.messages
|
||||
)
|
||||
Loading…
x
Reference in New Issue
Block a user