diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index ca80475a7..637bb5126 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -20,14 +20,15 @@ from nanobot.agent.memory import Consolidator, Dream from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec from nanobot.agent.skills import BUILTIN_SKILLS_DIR from nanobot.agent.subagent import SubagentManager +from nanobot.agent.tools.ask import AskUserTool from nanobot.agent.tools.cron import CronTool from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool from nanobot.agent.tools.message import MessageTool from nanobot.agent.tools.notebook import NotebookEditTool from nanobot.agent.tools.registry import ToolRegistry from nanobot.agent.tools.search import GlobTool, GrepTool -from nanobot.agent.tools.shell import ExecTool from nanobot.agent.tools.self import MyTool +from nanobot.agent.tools.shell import ExecTool from nanobot.agent.tools.spawn import SpawnTool from nanobot.agent.tools.web import WebFetchTool, WebSearchTool from nanobot.bus.events import InboundMessage, OutboundMessage @@ -287,6 +288,7 @@ class AgentLoop: self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None ) extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None + self.tools.register(AskUserTool()) self.tools.register( ReadFileTool( workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read @@ -407,6 +409,56 @@ class AgentLoop: return UNIFIED_SESSION_KEY return msg.session_key + @staticmethod + def _tool_call_name(tool_call: dict[str, Any]) -> str: + function = tool_call.get("function") + if isinstance(function, dict) and isinstance(function.get("name"), str): + return function["name"] + name = tool_call.get("name") + return name if isinstance(name, str) else "" + + @staticmethod + def _tool_call_arguments(tool_call: dict[str, Any]) -> dict[str, Any]: + function = tool_call.get("function") + raw = function.get("arguments") if isinstance(function, dict) else tool_call.get("arguments") + if isinstance(raw, dict): + return raw + if isinstance(raw, str): + try: + parsed = json.loads(raw) + except json.JSONDecodeError: + return {} + return parsed if isinstance(parsed, dict) else {} + return {} + + def _pending_ask_user_id(self, history: list[dict[str, Any]]) -> str | None: + pending: dict[str, str] = {} + for message in history: + if message.get("role") == "assistant": + for tool_call in message.get("tool_calls") or []: + if isinstance(tool_call, dict) and isinstance(tool_call.get("id"), str): + pending[tool_call["id"]] = self._tool_call_name(tool_call) + elif message.get("role") == "tool": + tool_call_id = message.get("tool_call_id") + if isinstance(tool_call_id, str): + pending.pop(tool_call_id, None) + for tool_call_id, name in reversed(pending.items()): + if name == "ask_user": + return tool_call_id + return None + + def _ask_user_options_from_messages(self, messages: list[dict[str, Any]]) -> list[str]: + for message in reversed(messages): + if message.get("role") != "assistant": + continue + for tool_call in reversed(message.get("tool_calls") or []): + if not isinstance(tool_call, dict) or self._tool_call_name(tool_call) != "ask_user": + continue + options = self._tool_call_arguments(tool_call).get("options") + if isinstance(options, list): + return [str(option) for option in options if isinstance(option, str)] + return [] + async def _run_agent_loop( self, initial_messages: list[dict], @@ -799,7 +851,7 @@ class AgentLoop: session_summary=pending, current_role=current_role, ) - final_content, _, all_msgs, _, _ = await self._run_agent_loop( + final_content, _, all_msgs, stop_reason, _ = await self._run_agent_loop( messages, session=session, channel=channel, chat_id=chat_id, message_id=msg.metadata.get("message_id"), pending_queue=pending_queue, @@ -808,10 +860,12 @@ class AgentLoop: self._clear_runtime_checkpoint(session) self.sessions.save(session) self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session)) + options = self._ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else [] return OutboundMessage( channel=channel, chat_id=chat_id, content=final_content or "Background task completed.", + buttons=[options] if options else [], ) # Extract document text from media at the processing boundary so all @@ -850,14 +904,27 @@ class AgentLoop: history = session.get_history(max_messages=0) - initial_messages = self.context.build_messages( - history=history, - current_message=msg.content, - session_summary=pending, - media=msg.media if msg.media else None, - channel=msg.channel, - chat_id=msg.chat_id, - ) + pending_ask_id = self._pending_ask_user_id(history) + if pending_ask_id: + initial_messages = [ + {"role": "system", "content": self.context.build_system_prompt(channel=msg.channel)}, + *history, + { + "role": "tool", + "tool_call_id": pending_ask_id, + "name": "ask_user", + "content": msg.content, + }, + ] + else: + initial_messages = self.context.build_messages( + history=history, + current_message=msg.content, + session_summary=pending, + media=msg.media if msg.media else None, + channel=msg.channel, + chat_id=msg.chat_id, + ) async def _bus_progress( content: str, @@ -898,7 +965,7 @@ class AgentLoop: user_persisted_early = False media_paths = [p for p in (msg.media or []) if isinstance(p, str) and p] has_text = isinstance(msg.content, str) and msg.content.strip() - if has_text or media_paths: + if not pending_ask_id and (has_text or media_paths): extra: dict[str, Any] = {"media": list(media_paths)} if media_paths else {} text = msg.content if isinstance(msg.content, str) else "" session.add_message("user", text, **extra) @@ -944,6 +1011,11 @@ class AgentLoop: logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview) meta = dict(msg.metadata or {}) + buttons: list[list[str]] = [] + if stop_reason == "ask_user": + options = self._ask_user_options_from_messages(all_msgs) + if options: + buttons = [options] if on_stream is not None and stop_reason != "error": meta["_streamed"] = True return OutboundMessage( @@ -951,6 +1023,7 @@ class AgentLoop: chat_id=msg.chat_id, content=final_content, metadata=meta, + buttons=buttons, ) def _sanitize_persisted_blocks( diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 3704f3030..688d38714 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -3,16 +3,16 @@ from __future__ import annotations import asyncio -from dataclasses import dataclass, field import inspect import os +from dataclasses import dataclass, field from pathlib import Path from typing import Any from loguru import logger from nanobot.agent.hook import AgentHook, AgentHookContext -from nanobot.utils.prompt_templates import render_template +from nanobot.agent.tools.ask import AskUserInterrupt from nanobot.agent.tools.registry import ToolRegistry from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest from nanobot.utils.helpers import ( @@ -23,6 +23,7 @@ from nanobot.utils.helpers import ( maybe_persist_tool_result, truncate_text, ) +from nanobot.utils.prompt_templates import render_template from nanobot.utils.runtime import ( EMPTY_FINAL_RESPONSE_MESSAGE, build_finalization_retry_message, @@ -312,6 +313,8 @@ class AgentRunner: context.tool_events = list(new_events) completed_tool_results: list[dict[str, Any]] = [] for tool_call, result in zip(response.tool_calls, results): + if isinstance(fatal_error, AskUserInterrupt) and tool_call.name == "ask_user": + continue tool_message = { "role": "tool", "tool_call_id": tool_call.id, @@ -326,6 +329,15 @@ class AgentRunner: messages.append(tool_message) completed_tool_results.append(tool_message) if fatal_error is not None: + if isinstance(fatal_error, AskUserInterrupt): + final_content = fatal_error.question + stop_reason = "ask_user" + context.final_content = final_content + context.stop_reason = stop_reason + if hook.wants_streaming(): + await hook.on_stream_end(context, resuming=False) + await hook.after_iteration(context) + break error = f"Error: {type(fatal_error).__name__}: {fatal_error}" final_content = error stop_reason = "tool_error" @@ -656,13 +668,21 @@ class AgentRunner: tool_results: list[tuple[Any, dict[str, str], BaseException | None]] = [] for batch in batches: if spec.concurrent_tools and len(batch) > 1: - tool_results.extend(await asyncio.gather(*( + batch_results = await asyncio.gather(*( self._run_tool(spec, tool_call, external_lookup_counts) for tool_call in batch - ))) + )) + tool_results.extend(batch_results) else: + batch_results = [] for tool_call in batch: - tool_results.append(await self._run_tool(spec, tool_call, external_lookup_counts)) + result = await self._run_tool(spec, tool_call, external_lookup_counts) + tool_results.append(result) + batch_results.append(result) + if isinstance(result[2], AskUserInterrupt): + break + if any(isinstance(error, AskUserInterrupt) for _, _, error in batch_results): + break results: list[Any] = [] events: list[dict[str, str]] = [] @@ -724,6 +744,9 @@ class AgentRunner: "status": "error", "detail": str(exc), } + if isinstance(exc, AskUserInterrupt): + event["status"] = "waiting" + return "", event, exc if spec.fail_on_tool_error: return f"Error: {type(exc).__name__}: {exc}", event, exc return f"Error: {type(exc).__name__}: {exc}", event, None diff --git a/nanobot/agent/tools/ask.py b/nanobot/agent/tools/ask.py new file mode 100644 index 000000000..0ce371ea8 --- /dev/null +++ b/nanobot/agent/tools/ask.py @@ -0,0 +1,50 @@ +"""Tool for pausing a turn until the user answers.""" + +from typing import Any + +from nanobot.agent.tools.base import Tool, tool_parameters +from nanobot.agent.tools.schema import ArraySchema, StringSchema, tool_parameters_schema + + +class AskUserInterrupt(BaseException): + """Internal signal: the runner should stop and wait for user input.""" + + def __init__(self, question: str, options: list[str] | None = None) -> None: + self.question = question + self.options = [str(option) for option in (options or []) if str(option)] + super().__init__(question) + + +@tool_parameters( + tool_parameters_schema( + question=StringSchema( + "The question to ask before continuing. Use this only when the task needs the user's answer." + ), + options=ArraySchema( + StringSchema("A possible answer label"), + description="Optional choices. The user may still reply with free text.", + ), + required=["question"], + ) +) +class AskUserTool(Tool): + """Ask the user a blocking question.""" + + @property + def name(self) -> str: + return "ask_user" + + @property + def description(self) -> str: + return ( + "Pause and ask the user a question when their answer is required to continue. " + "Use options for likely answers; the user's reply, typed or selected, is returned as the tool result. " + "For non-blocking notifications or buttons, use the message tool instead." + ) + + @property + def exclusive(self) -> bool: + return True + + async def execute(self, question: str, options: list[str] | None = None, **_: Any) -> Any: + raise AskUserInterrupt(question=question, options=options) diff --git a/tests/agent/test_ask_user.py b/tests/agent/test_ask_user.py new file mode 100644 index 000000000..fd8993ceb --- /dev/null +++ b/tests/agent/test_ask_user.py @@ -0,0 +1,158 @@ +import asyncio +from unittest.mock import MagicMock + +import pytest + +from nanobot.agent.loop import AgentLoop +from nanobot.agent.runner import AgentRunner, AgentRunSpec +from nanobot.agent.tools.ask import AskUserInterrupt, AskUserTool +from nanobot.agent.tools.base import Tool, tool_parameters +from nanobot.agent.tools.registry import ToolRegistry +from nanobot.agent.tools.schema import tool_parameters_schema +from nanobot.bus.events import InboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.providers.base import GenerationSettings, LLMResponse, ToolCallRequest + + +def _make_provider(chat_with_retry): + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + provider.generation = GenerationSettings() + provider.chat_with_retry = chat_with_retry + return provider + + +def test_ask_user_tool_schema_and_interrupt(): + tool = AskUserTool() + schema = tool.to_schema()["function"] + + assert schema["name"] == "ask_user" + assert "question" in schema["parameters"]["required"] + assert schema["parameters"]["properties"]["options"]["type"] == "array" + + with pytest.raises(AskUserInterrupt) as exc: + asyncio.run(tool.execute("Continue?", options=["Yes", "No"])) + + assert exc.value.question == "Continue?" + assert exc.value.options == ["Yes", "No"] + + +@pytest.mark.asyncio +async def test_runner_pauses_on_ask_user_without_executing_later_tools(): + @tool_parameters(tool_parameters_schema(required=[])) + class LaterTool(Tool): + called = False + + @property + def name(self) -> str: + return "later" + + @property + def description(self) -> str: + return "Should not run after ask_user pauses the turn." + + async def execute(self, **kwargs): + self.called = True + return "later result" + + async def chat_with_retry(**kwargs): + return LLMResponse( + content="", + finish_reason="tool_calls", + tool_calls=[ + ToolCallRequest( + id="call_ask", + name="ask_user", + arguments={"question": "Install this package?", "options": ["Yes", "No"]}, + ), + ToolCallRequest(id="call_later", name="later", arguments={}), + ], + ) + + later = LaterTool() + tools = ToolRegistry() + tools.register(AskUserTool()) + tools.register(later) + + result = await AgentRunner(_make_provider(chat_with_retry)).run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "continue"}], + tools=tools, + model="test-model", + max_iterations=3, + max_tool_result_chars=16_000, + concurrent_tools=True, + )) + + assert result.stop_reason == "ask_user" + assert result.final_content == "Install this package?" + assert "ask_user" in result.tools_used + assert later.called is False + assert result.messages[-1]["role"] == "assistant" + assert result.messages[-1]["tool_calls"][0]["function"]["name"] == "ask_user" + assert not any(message.get("name") == "ask_user" for message in result.messages) + + +@pytest.mark.asyncio +async def test_ask_user_sends_buttons_and_resumes_with_next_message(tmp_path): + seen_messages: list[list[dict]] = [] + + async def chat_with_retry(**kwargs): + seen_messages.append(kwargs["messages"]) + if len(seen_messages) == 1: + return LLMResponse( + content="", + finish_reason="tool_calls", + tool_calls=[ + ToolCallRequest( + id="call_ask", + name="ask_user", + arguments={ + "question": "Install the optional package?", + "options": ["Install", "Skip"], + }, + ) + ], + ) + return LLMResponse(content="Skipped install.", usage={}) + + loop = AgentLoop( + bus=MessageBus(), + provider=_make_provider(chat_with_retry), + workspace=tmp_path, + model="test-model", + ) + + first = await loop._process_message( + InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="set it up") + ) + + assert first is not None + assert first.content == "Install the optional package?" + assert first.buttons == [["Install", "Skip"]] + + session = loop.sessions.get_or_create("cli:direct") + assert any(message.get("role") == "assistant" and message.get("tool_calls") for message in session.messages) + assert not any(message.get("role") == "tool" and message.get("name") == "ask_user" for message in session.messages) + + second = await loop._process_message( + InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="Skip") + ) + + assert second is not None + assert second.content == "Skipped install." + assert any( + message.get("role") == "tool" + and message.get("name") == "ask_user" + and message.get("content") == "Skip" + for message in seen_messages[-1] + ) + assert not any( + message.get("role") == "user" and message.get("content") == "Skip" + for message in session.messages + ) + assert any( + message.get("role") == "tool" + and message.get("name") == "ask_user" + and message.get("content") == "Skip" + for message in session.messages + )