mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-18 23:52:23 +00:00
fix(agent): tighten ask_user CLI handling
Made-with: Cursor
This commit is contained in:
parent
3b1ea99ee1
commit
403ce23d22
@ -20,7 +20,13 @@ from nanobot.agent.memory import Consolidator, Dream
|
||||
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec
|
||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.agent.tools.ask import AskUserTool
|
||||
from nanobot.agent.tools.ask import (
|
||||
AskUserTool,
|
||||
ask_user_options_from_messages,
|
||||
ask_user_outbound,
|
||||
ask_user_tool_result_messages,
|
||||
pending_ask_user_id,
|
||||
)
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
@ -54,7 +60,6 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
UNIFIED_SESSION_KEY = "unified:default"
|
||||
BUTTON_CHANNELS = frozenset({"telegram"})
|
||||
|
||||
|
||||
class _LoopHook(AgentHook):
|
||||
@ -410,69 +415,6 @@ class AgentLoop:
|
||||
return UNIFIED_SESSION_KEY
|
||||
return msg.session_key
|
||||
|
||||
@staticmethod
|
||||
def _tool_call_name(tool_call: dict[str, Any]) -> str:
|
||||
function = tool_call.get("function")
|
||||
if isinstance(function, dict) and isinstance(function.get("name"), str):
|
||||
return function["name"]
|
||||
name = tool_call.get("name")
|
||||
return name if isinstance(name, str) else ""
|
||||
|
||||
@staticmethod
|
||||
def _tool_call_arguments(tool_call: dict[str, Any]) -> dict[str, Any]:
|
||||
function = tool_call.get("function")
|
||||
raw = function.get("arguments") if isinstance(function, dict) else tool_call.get("arguments")
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
return {}
|
||||
|
||||
def _pending_ask_user_id(self, history: list[dict[str, Any]]) -> str | None:
|
||||
pending: dict[str, str] = {}
|
||||
for message in history:
|
||||
if message.get("role") == "assistant":
|
||||
for tool_call in message.get("tool_calls") or []:
|
||||
if isinstance(tool_call, dict) and isinstance(tool_call.get("id"), str):
|
||||
pending[tool_call["id"]] = self._tool_call_name(tool_call)
|
||||
elif message.get("role") == "tool":
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
if isinstance(tool_call_id, str):
|
||||
pending.pop(tool_call_id, None)
|
||||
for tool_call_id, name in reversed(pending.items()):
|
||||
if name == "ask_user":
|
||||
return tool_call_id
|
||||
return None
|
||||
|
||||
def _ask_user_options_from_messages(self, messages: list[dict[str, Any]]) -> list[str]:
|
||||
for message in reversed(messages):
|
||||
if message.get("role") != "assistant":
|
||||
continue
|
||||
for tool_call in reversed(message.get("tool_calls") or []):
|
||||
if not isinstance(tool_call, dict) or self._tool_call_name(tool_call) != "ask_user":
|
||||
continue
|
||||
options = self._tool_call_arguments(tool_call).get("options")
|
||||
if isinstance(options, list):
|
||||
return [str(option) for option in options if isinstance(option, str)]
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _ask_user_outbound(
|
||||
content: str | None,
|
||||
options: list[str],
|
||||
channel: str,
|
||||
) -> tuple[str | None, list[list[str]]]:
|
||||
if not options:
|
||||
return content, []
|
||||
if channel in BUTTON_CHANNELS:
|
||||
return content, [options]
|
||||
option_text = "\n".join(f"{index}. {option}" for index, option in enumerate(options, 1))
|
||||
return f"{content}\n\n{option_text}" if content else option_text, []
|
||||
|
||||
async def _run_agent_loop(
|
||||
self,
|
||||
initial_messages: list[dict],
|
||||
@ -874,8 +816,8 @@ class AgentLoop:
|
||||
self._clear_runtime_checkpoint(session)
|
||||
self.sessions.save(session)
|
||||
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
||||
options = self._ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else []
|
||||
content, buttons = self._ask_user_outbound(
|
||||
options = ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else []
|
||||
content, buttons = ask_user_outbound(
|
||||
final_content or "Background task completed.",
|
||||
options,
|
||||
channel,
|
||||
@ -923,18 +865,14 @@ class AgentLoop:
|
||||
|
||||
history = session.get_history(max_messages=0)
|
||||
|
||||
pending_ask_id = self._pending_ask_user_id(history)
|
||||
pending_ask_id = pending_ask_user_id(history)
|
||||
if pending_ask_id:
|
||||
initial_messages = [
|
||||
{"role": "system", "content": self.context.build_system_prompt(channel=msg.channel)},
|
||||
*history,
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": pending_ask_id,
|
||||
"name": "ask_user",
|
||||
"content": msg.content,
|
||||
},
|
||||
]
|
||||
initial_messages = ask_user_tool_result_messages(
|
||||
self.context.build_system_prompt(channel=msg.channel),
|
||||
history,
|
||||
pending_ask_id,
|
||||
msg.content,
|
||||
)
|
||||
else:
|
||||
initial_messages = self.context.build_messages(
|
||||
history=history,
|
||||
@ -1030,12 +968,12 @@ class AgentLoop:
|
||||
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||
|
||||
meta = dict(msg.metadata or {})
|
||||
final_content, buttons = self._ask_user_outbound(
|
||||
final_content, buttons = ask_user_outbound(
|
||||
final_content,
|
||||
self._ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else [],
|
||||
ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else [],
|
||||
msg.channel,
|
||||
)
|
||||
if on_stream is not None and stop_reason != "error":
|
||||
if on_stream is not None and stop_reason not in {"ask_user", "error"}:
|
||||
meta["_streamed"] = True
|
||||
return OutboundMessage(
|
||||
channel=msg.channel,
|
||||
|
||||
@ -278,17 +278,22 @@ class AgentRunner:
|
||||
self._accumulate_usage(usage, raw_usage)
|
||||
|
||||
if response.should_execute_tools:
|
||||
tool_calls = list(response.tool_calls)
|
||||
ask_index = next((i for i, tc in enumerate(tool_calls) if tc.name == "ask_user"), None)
|
||||
if ask_index is not None:
|
||||
tool_calls = tool_calls[: ask_index + 1]
|
||||
context.tool_calls = list(tool_calls)
|
||||
if hook.wants_streaming():
|
||||
await hook.on_stream_end(context, resuming=True)
|
||||
|
||||
assistant_message = build_assistant_message(
|
||||
response.content or "",
|
||||
tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls],
|
||||
tool_calls=[tc.to_openai_tool_call() for tc in tool_calls],
|
||||
reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
)
|
||||
messages.append(assistant_message)
|
||||
tools_used.extend(tc.name for tc in response.tool_calls)
|
||||
tools_used.extend(tc.name for tc in tool_calls)
|
||||
await self._emit_checkpoint(
|
||||
spec,
|
||||
{
|
||||
@ -297,7 +302,7 @@ class AgentRunner:
|
||||
"model": spec.model,
|
||||
"assistant_message": assistant_message,
|
||||
"completed_tool_results": [],
|
||||
"pending_tool_calls": [tc.to_openai_tool_call() for tc in response.tool_calls],
|
||||
"pending_tool_calls": [tc.to_openai_tool_call() for tc in tool_calls],
|
||||
},
|
||||
)
|
||||
|
||||
@ -305,14 +310,14 @@ class AgentRunner:
|
||||
|
||||
results, new_events, fatal_error = await self._execute_tools(
|
||||
spec,
|
||||
response.tool_calls,
|
||||
tool_calls,
|
||||
external_lookup_counts,
|
||||
)
|
||||
tool_events.extend(new_events)
|
||||
context.tool_results = list(results)
|
||||
context.tool_events = list(new_events)
|
||||
completed_tool_results: list[dict[str, Any]] = []
|
||||
for tool_call, result in zip(response.tool_calls, results):
|
||||
for tool_call, result in zip(tool_calls, results):
|
||||
if isinstance(fatal_error, AskUserInterrupt) and tool_call.name == "ask_user":
|
||||
continue
|
||||
tool_message = {
|
||||
@ -700,7 +705,7 @@ class AgentRunner:
|
||||
tool_call: ToolCallRequest,
|
||||
external_lookup_counts: dict[str, int],
|
||||
) -> tuple[Any, dict[str, str], BaseException | None]:
|
||||
_HINT = "\n\n[Analyze the error above and try a different approach.]"
|
||||
hint = "\n\n[Analyze the error above and try a different approach.]"
|
||||
lookup_error = repeated_external_lookup_error(
|
||||
tool_call.name,
|
||||
tool_call.arguments,
|
||||
@ -713,8 +718,8 @@ class AgentRunner:
|
||||
"detail": "repeated external lookup blocked",
|
||||
}
|
||||
if spec.fail_on_tool_error:
|
||||
return lookup_error + _HINT, event, RuntimeError(lookup_error)
|
||||
return lookup_error + _HINT, event, None
|
||||
return lookup_error + hint, event, RuntimeError(lookup_error)
|
||||
return lookup_error + hint, event, None
|
||||
prepare_call = getattr(spec.tools, "prepare_call", None)
|
||||
tool, params, prep_error = None, tool_call.arguments, None
|
||||
if callable(prepare_call):
|
||||
@ -730,7 +735,7 @@ class AgentRunner:
|
||||
"status": "error",
|
||||
"detail": prep_error.split(": ", 1)[-1][:120],
|
||||
}
|
||||
return prep_error + _HINT, event, RuntimeError(prep_error) if spec.fail_on_tool_error else None
|
||||
return prep_error + hint, event, RuntimeError(prep_error) if spec.fail_on_tool_error else None
|
||||
try:
|
||||
if tool is not None:
|
||||
result = await tool.execute(**params)
|
||||
@ -758,8 +763,8 @@ class AgentRunner:
|
||||
"detail": result.replace("\n", " ").strip()[:120],
|
||||
}
|
||||
if spec.fail_on_tool_error:
|
||||
return result + _HINT, event, RuntimeError(result)
|
||||
return result + _HINT, event, None
|
||||
return result + hint, event, RuntimeError(result)
|
||||
return result + hint, event, None
|
||||
|
||||
detail = "" if result is None else str(result)
|
||||
detail = detail.replace("\n", " ").strip()
|
||||
|
||||
@ -1,10 +1,13 @@
|
||||
"""Tool for pausing a turn until the user answers."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||
from nanobot.agent.tools.schema import ArraySchema, StringSchema, tool_parameters_schema
|
||||
|
||||
BUTTON_CHANNELS = frozenset({"telegram"})
|
||||
|
||||
|
||||
class AskUserInterrupt(BaseException):
|
||||
"""Internal signal: the runner should stop and wait for user input."""
|
||||
@ -48,3 +51,86 @@ class AskUserTool(Tool):
|
||||
|
||||
async def execute(self, question: str, options: list[str] | None = None, **_: Any) -> Any:
|
||||
raise AskUserInterrupt(question=question, options=options)
|
||||
|
||||
|
||||
def _tool_call_name(tool_call: dict[str, Any]) -> str:
|
||||
function = tool_call.get("function")
|
||||
if isinstance(function, dict) and isinstance(function.get("name"), str):
|
||||
return function["name"]
|
||||
name = tool_call.get("name")
|
||||
return name if isinstance(name, str) else ""
|
||||
|
||||
|
||||
def _tool_call_arguments(tool_call: dict[str, Any]) -> dict[str, Any]:
|
||||
function = tool_call.get("function")
|
||||
raw = function.get("arguments") if isinstance(function, dict) else tool_call.get("arguments")
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
return {}
|
||||
|
||||
|
||||
def pending_ask_user_id(history: list[dict[str, Any]]) -> str | None:
|
||||
pending: dict[str, str] = {}
|
||||
for message in history:
|
||||
if message.get("role") == "assistant":
|
||||
for tool_call in message.get("tool_calls") or []:
|
||||
if isinstance(tool_call, dict) and isinstance(tool_call.get("id"), str):
|
||||
pending[tool_call["id"]] = _tool_call_name(tool_call)
|
||||
elif message.get("role") == "tool":
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
if isinstance(tool_call_id, str):
|
||||
pending.pop(tool_call_id, None)
|
||||
for tool_call_id, name in reversed(pending.items()):
|
||||
if name == "ask_user":
|
||||
return tool_call_id
|
||||
return None
|
||||
|
||||
|
||||
def ask_user_tool_result_messages(
|
||||
system_prompt: str,
|
||||
history: list[dict[str, Any]],
|
||||
tool_call_id: str,
|
||||
content: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
return [
|
||||
{"role": "system", "content": system_prompt},
|
||||
*history,
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"name": "ask_user",
|
||||
"content": content,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def ask_user_options_from_messages(messages: list[dict[str, Any]]) -> list[str]:
|
||||
for message in reversed(messages):
|
||||
if message.get("role") != "assistant":
|
||||
continue
|
||||
for tool_call in reversed(message.get("tool_calls") or []):
|
||||
if not isinstance(tool_call, dict) or _tool_call_name(tool_call) != "ask_user":
|
||||
continue
|
||||
options = _tool_call_arguments(tool_call).get("options")
|
||||
if isinstance(options, list):
|
||||
return [str(option) for option in options if isinstance(option, str)]
|
||||
return []
|
||||
|
||||
|
||||
def ask_user_outbound(
|
||||
content: str | None,
|
||||
options: list[str],
|
||||
channel: str,
|
||||
) -> tuple[str | None, list[list[str]]]:
|
||||
if not options:
|
||||
return content, []
|
||||
if channel in BUTTON_CHANNELS:
|
||||
return content, [options]
|
||||
option_text = "\n".join(f"{index}. {option}" for index, option in enumerate(options, 1))
|
||||
return f"{content}\n\n{option_text}" if content else option_text, []
|
||||
|
||||
@ -212,12 +212,16 @@ async def _print_interactive_response(
|
||||
|
||||
def _print_cli_progress_line(text: str, thinking: ThinkingSpinner | None) -> None:
|
||||
"""Print a CLI progress line, pausing the spinner if needed."""
|
||||
if not text.strip():
|
||||
return
|
||||
with thinking.pause() if thinking else nullcontext():
|
||||
console.print(f" [dim]↳ {text}[/dim]")
|
||||
|
||||
|
||||
async def _print_interactive_progress_line(text: str, thinking: ThinkingSpinner | None) -> None:
|
||||
"""Print an interactive progress line, pausing the spinner if needed."""
|
||||
if not text.strip():
|
||||
return
|
||||
with thinking.pause() if thinking else nullcontext():
|
||||
await _print_interactive_line(text)
|
||||
|
||||
|
||||
@ -15,10 +15,15 @@ from nanobot.providers.base import GenerationSettings, LLMResponse, ToolCallRequ
|
||||
|
||||
|
||||
def _make_provider(chat_with_retry):
|
||||
async def chat_stream_with_retry(**kwargs):
|
||||
kwargs.pop("on_content_delta", None)
|
||||
return await chat_with_retry(**kwargs)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.generation = GenerationSettings()
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
provider.chat_stream_with_retry = chat_stream_with_retry
|
||||
return provider
|
||||
|
||||
|
||||
@ -88,7 +93,8 @@ async def test_runner_pauses_on_ask_user_without_executing_later_tools():
|
||||
assert "ask_user" in result.tools_used
|
||||
assert later.called is False
|
||||
assert result.messages[-1]["role"] == "assistant"
|
||||
assert result.messages[-1]["tool_calls"][0]["function"]["name"] == "ask_user"
|
||||
tool_calls = result.messages[-1]["tool_calls"]
|
||||
assert [tool_call["function"]["name"] for tool_call in tool_calls] == ["ask_user"]
|
||||
assert not any(message.get("name") == "ask_user" for message in result.messages)
|
||||
|
||||
|
||||
@ -122,13 +128,22 @@ async def test_ask_user_text_fallback_resumes_with_next_message(tmp_path):
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
async def on_stream(delta: str) -> None:
|
||||
pass
|
||||
|
||||
async def on_stream_end(**kwargs) -> None:
|
||||
pass
|
||||
|
||||
first = await loop._process_message(
|
||||
InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="set it up")
|
||||
InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="set it up"),
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
)
|
||||
|
||||
assert first is not None
|
||||
assert first.content == "Install the optional package?\n\n1. Install\n2. Skip"
|
||||
assert first.buttons == []
|
||||
assert "_streamed" not in first.metadata
|
||||
|
||||
session = loop.sessions.get_or_create("cli:direct")
|
||||
assert any(message.get("role") == "assistant" and message.get("tool_calls") for message in session.messages)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user