mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-26 19:42:41 +00:00
feat(agent): add ask_user tool
Made-with: Cursor
This commit is contained in:
parent
830211b5d4
commit
cfc76ffbbf
@ -20,14 +20,15 @@ from nanobot.agent.memory import Consolidator, Dream
|
|||||||
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec
|
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec
|
||||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||||
from nanobot.agent.subagent import SubagentManager
|
from nanobot.agent.subagent import SubagentManager
|
||||||
|
from nanobot.agent.tools.ask import AskUserTool
|
||||||
from nanobot.agent.tools.cron import CronTool
|
from nanobot.agent.tools.cron import CronTool
|
||||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||||
from nanobot.agent.tools.message import MessageTool
|
from nanobot.agent.tools.message import MessageTool
|
||||||
from nanobot.agent.tools.notebook import NotebookEditTool
|
from nanobot.agent.tools.notebook import NotebookEditTool
|
||||||
from nanobot.agent.tools.registry import ToolRegistry
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
from nanobot.agent.tools.search import GlobTool, GrepTool
|
from nanobot.agent.tools.search import GlobTool, GrepTool
|
||||||
from nanobot.agent.tools.shell import ExecTool
|
|
||||||
from nanobot.agent.tools.self import MyTool
|
from nanobot.agent.tools.self import MyTool
|
||||||
|
from nanobot.agent.tools.shell import ExecTool
|
||||||
from nanobot.agent.tools.spawn import SpawnTool
|
from nanobot.agent.tools.spawn import SpawnTool
|
||||||
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
||||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||||
@ -287,6 +288,7 @@ class AgentLoop:
|
|||||||
self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
|
self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
|
||||||
)
|
)
|
||||||
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
|
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
|
||||||
|
self.tools.register(AskUserTool())
|
||||||
self.tools.register(
|
self.tools.register(
|
||||||
ReadFileTool(
|
ReadFileTool(
|
||||||
workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read
|
workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read
|
||||||
@ -407,6 +409,56 @@ class AgentLoop:
|
|||||||
return UNIFIED_SESSION_KEY
|
return UNIFIED_SESSION_KEY
|
||||||
return msg.session_key
|
return msg.session_key
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _tool_call_name(tool_call: dict[str, Any]) -> str:
|
||||||
|
function = tool_call.get("function")
|
||||||
|
if isinstance(function, dict) and isinstance(function.get("name"), str):
|
||||||
|
return function["name"]
|
||||||
|
name = tool_call.get("name")
|
||||||
|
return name if isinstance(name, str) else ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _tool_call_arguments(tool_call: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
function = tool_call.get("function")
|
||||||
|
raw = function.get("arguments") if isinstance(function, dict) else tool_call.get("arguments")
|
||||||
|
if isinstance(raw, dict):
|
||||||
|
return raw
|
||||||
|
if isinstance(raw, str):
|
||||||
|
try:
|
||||||
|
parsed = json.loads(raw)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return {}
|
||||||
|
return parsed if isinstance(parsed, dict) else {}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _pending_ask_user_id(self, history: list[dict[str, Any]]) -> str | None:
|
||||||
|
pending: dict[str, str] = {}
|
||||||
|
for message in history:
|
||||||
|
if message.get("role") == "assistant":
|
||||||
|
for tool_call in message.get("tool_calls") or []:
|
||||||
|
if isinstance(tool_call, dict) and isinstance(tool_call.get("id"), str):
|
||||||
|
pending[tool_call["id"]] = self._tool_call_name(tool_call)
|
||||||
|
elif message.get("role") == "tool":
|
||||||
|
tool_call_id = message.get("tool_call_id")
|
||||||
|
if isinstance(tool_call_id, str):
|
||||||
|
pending.pop(tool_call_id, None)
|
||||||
|
for tool_call_id, name in reversed(pending.items()):
|
||||||
|
if name == "ask_user":
|
||||||
|
return tool_call_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _ask_user_options_from_messages(self, messages: list[dict[str, Any]]) -> list[str]:
|
||||||
|
for message in reversed(messages):
|
||||||
|
if message.get("role") != "assistant":
|
||||||
|
continue
|
||||||
|
for tool_call in reversed(message.get("tool_calls") or []):
|
||||||
|
if not isinstance(tool_call, dict) or self._tool_call_name(tool_call) != "ask_user":
|
||||||
|
continue
|
||||||
|
options = self._tool_call_arguments(tool_call).get("options")
|
||||||
|
if isinstance(options, list):
|
||||||
|
return [str(option) for option in options if isinstance(option, str)]
|
||||||
|
return []
|
||||||
|
|
||||||
async def _run_agent_loop(
|
async def _run_agent_loop(
|
||||||
self,
|
self,
|
||||||
initial_messages: list[dict],
|
initial_messages: list[dict],
|
||||||
@ -799,7 +851,7 @@ class AgentLoop:
|
|||||||
session_summary=pending,
|
session_summary=pending,
|
||||||
current_role=current_role,
|
current_role=current_role,
|
||||||
)
|
)
|
||||||
final_content, _, all_msgs, _, _ = await self._run_agent_loop(
|
final_content, _, all_msgs, stop_reason, _ = await self._run_agent_loop(
|
||||||
messages, session=session, channel=channel, chat_id=chat_id,
|
messages, session=session, channel=channel, chat_id=chat_id,
|
||||||
message_id=msg.metadata.get("message_id"),
|
message_id=msg.metadata.get("message_id"),
|
||||||
pending_queue=pending_queue,
|
pending_queue=pending_queue,
|
||||||
@ -808,10 +860,12 @@ class AgentLoop:
|
|||||||
self._clear_runtime_checkpoint(session)
|
self._clear_runtime_checkpoint(session)
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
||||||
|
options = self._ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else []
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=channel,
|
channel=channel,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
content=final_content or "Background task completed.",
|
content=final_content or "Background task completed.",
|
||||||
|
buttons=[options] if options else [],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract document text from media at the processing boundary so all
|
# Extract document text from media at the processing boundary so all
|
||||||
@ -850,6 +904,19 @@ class AgentLoop:
|
|||||||
|
|
||||||
history = session.get_history(max_messages=0)
|
history = session.get_history(max_messages=0)
|
||||||
|
|
||||||
|
pending_ask_id = self._pending_ask_user_id(history)
|
||||||
|
if pending_ask_id:
|
||||||
|
initial_messages = [
|
||||||
|
{"role": "system", "content": self.context.build_system_prompt(channel=msg.channel)},
|
||||||
|
*history,
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": pending_ask_id,
|
||||||
|
"name": "ask_user",
|
||||||
|
"content": msg.content,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
else:
|
||||||
initial_messages = self.context.build_messages(
|
initial_messages = self.context.build_messages(
|
||||||
history=history,
|
history=history,
|
||||||
current_message=msg.content,
|
current_message=msg.content,
|
||||||
@ -898,7 +965,7 @@ class AgentLoop:
|
|||||||
user_persisted_early = False
|
user_persisted_early = False
|
||||||
media_paths = [p for p in (msg.media or []) if isinstance(p, str) and p]
|
media_paths = [p for p in (msg.media or []) if isinstance(p, str) and p]
|
||||||
has_text = isinstance(msg.content, str) and msg.content.strip()
|
has_text = isinstance(msg.content, str) and msg.content.strip()
|
||||||
if has_text or media_paths:
|
if not pending_ask_id and (has_text or media_paths):
|
||||||
extra: dict[str, Any] = {"media": list(media_paths)} if media_paths else {}
|
extra: dict[str, Any] = {"media": list(media_paths)} if media_paths else {}
|
||||||
text = msg.content if isinstance(msg.content, str) else ""
|
text = msg.content if isinstance(msg.content, str) else ""
|
||||||
session.add_message("user", text, **extra)
|
session.add_message("user", text, **extra)
|
||||||
@ -944,6 +1011,11 @@ class AgentLoop:
|
|||||||
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||||
|
|
||||||
meta = dict(msg.metadata or {})
|
meta = dict(msg.metadata or {})
|
||||||
|
buttons: list[list[str]] = []
|
||||||
|
if stop_reason == "ask_user":
|
||||||
|
options = self._ask_user_options_from_messages(all_msgs)
|
||||||
|
if options:
|
||||||
|
buttons = [options]
|
||||||
if on_stream is not None and stop_reason != "error":
|
if on_stream is not None and stop_reason != "error":
|
||||||
meta["_streamed"] = True
|
meta["_streamed"] = True
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
@ -951,6 +1023,7 @@ class AgentLoop:
|
|||||||
chat_id=msg.chat_id,
|
chat_id=msg.chat_id,
|
||||||
content=final_content,
|
content=final_content,
|
||||||
metadata=meta,
|
metadata=meta,
|
||||||
|
buttons=buttons,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _sanitize_persisted_blocks(
|
def _sanitize_persisted_blocks(
|
||||||
|
|||||||
@ -3,16 +3,16 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from dataclasses import dataclass, field
|
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||||
from nanobot.utils.prompt_templates import render_template
|
from nanobot.agent.tools.ask import AskUserInterrupt
|
||||||
from nanobot.agent.tools.registry import ToolRegistry
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
from nanobot.utils.helpers import (
|
from nanobot.utils.helpers import (
|
||||||
@ -23,6 +23,7 @@ from nanobot.utils.helpers import (
|
|||||||
maybe_persist_tool_result,
|
maybe_persist_tool_result,
|
||||||
truncate_text,
|
truncate_text,
|
||||||
)
|
)
|
||||||
|
from nanobot.utils.prompt_templates import render_template
|
||||||
from nanobot.utils.runtime import (
|
from nanobot.utils.runtime import (
|
||||||
EMPTY_FINAL_RESPONSE_MESSAGE,
|
EMPTY_FINAL_RESPONSE_MESSAGE,
|
||||||
build_finalization_retry_message,
|
build_finalization_retry_message,
|
||||||
@ -312,6 +313,8 @@ class AgentRunner:
|
|||||||
context.tool_events = list(new_events)
|
context.tool_events = list(new_events)
|
||||||
completed_tool_results: list[dict[str, Any]] = []
|
completed_tool_results: list[dict[str, Any]] = []
|
||||||
for tool_call, result in zip(response.tool_calls, results):
|
for tool_call, result in zip(response.tool_calls, results):
|
||||||
|
if isinstance(fatal_error, AskUserInterrupt) and tool_call.name == "ask_user":
|
||||||
|
continue
|
||||||
tool_message = {
|
tool_message = {
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
"tool_call_id": tool_call.id,
|
"tool_call_id": tool_call.id,
|
||||||
@ -326,6 +329,15 @@ class AgentRunner:
|
|||||||
messages.append(tool_message)
|
messages.append(tool_message)
|
||||||
completed_tool_results.append(tool_message)
|
completed_tool_results.append(tool_message)
|
||||||
if fatal_error is not None:
|
if fatal_error is not None:
|
||||||
|
if isinstance(fatal_error, AskUserInterrupt):
|
||||||
|
final_content = fatal_error.question
|
||||||
|
stop_reason = "ask_user"
|
||||||
|
context.final_content = final_content
|
||||||
|
context.stop_reason = stop_reason
|
||||||
|
if hook.wants_streaming():
|
||||||
|
await hook.on_stream_end(context, resuming=False)
|
||||||
|
await hook.after_iteration(context)
|
||||||
|
break
|
||||||
error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
|
error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
|
||||||
final_content = error
|
final_content = error
|
||||||
stop_reason = "tool_error"
|
stop_reason = "tool_error"
|
||||||
@ -656,13 +668,21 @@ class AgentRunner:
|
|||||||
tool_results: list[tuple[Any, dict[str, str], BaseException | None]] = []
|
tool_results: list[tuple[Any, dict[str, str], BaseException | None]] = []
|
||||||
for batch in batches:
|
for batch in batches:
|
||||||
if spec.concurrent_tools and len(batch) > 1:
|
if spec.concurrent_tools and len(batch) > 1:
|
||||||
tool_results.extend(await asyncio.gather(*(
|
batch_results = await asyncio.gather(*(
|
||||||
self._run_tool(spec, tool_call, external_lookup_counts)
|
self._run_tool(spec, tool_call, external_lookup_counts)
|
||||||
for tool_call in batch
|
for tool_call in batch
|
||||||
)))
|
))
|
||||||
|
tool_results.extend(batch_results)
|
||||||
else:
|
else:
|
||||||
|
batch_results = []
|
||||||
for tool_call in batch:
|
for tool_call in batch:
|
||||||
tool_results.append(await self._run_tool(spec, tool_call, external_lookup_counts))
|
result = await self._run_tool(spec, tool_call, external_lookup_counts)
|
||||||
|
tool_results.append(result)
|
||||||
|
batch_results.append(result)
|
||||||
|
if isinstance(result[2], AskUserInterrupt):
|
||||||
|
break
|
||||||
|
if any(isinstance(error, AskUserInterrupt) for _, _, error in batch_results):
|
||||||
|
break
|
||||||
|
|
||||||
results: list[Any] = []
|
results: list[Any] = []
|
||||||
events: list[dict[str, str]] = []
|
events: list[dict[str, str]] = []
|
||||||
@ -724,6 +744,9 @@ class AgentRunner:
|
|||||||
"status": "error",
|
"status": "error",
|
||||||
"detail": str(exc),
|
"detail": str(exc),
|
||||||
}
|
}
|
||||||
|
if isinstance(exc, AskUserInterrupt):
|
||||||
|
event["status"] = "waiting"
|
||||||
|
return "", event, exc
|
||||||
if spec.fail_on_tool_error:
|
if spec.fail_on_tool_error:
|
||||||
return f"Error: {type(exc).__name__}: {exc}", event, exc
|
return f"Error: {type(exc).__name__}: {exc}", event, exc
|
||||||
return f"Error: {type(exc).__name__}: {exc}", event, None
|
return f"Error: {type(exc).__name__}: {exc}", event, None
|
||||||
|
|||||||
50
nanobot/agent/tools/ask.py
Normal file
50
nanobot/agent/tools/ask.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
"""Tool for pausing a turn until the user answers."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||||
|
from nanobot.agent.tools.schema import ArraySchema, StringSchema, tool_parameters_schema
|
||||||
|
|
||||||
|
|
||||||
|
class AskUserInterrupt(BaseException):
|
||||||
|
"""Internal signal: the runner should stop and wait for user input."""
|
||||||
|
|
||||||
|
def __init__(self, question: str, options: list[str] | None = None) -> None:
|
||||||
|
self.question = question
|
||||||
|
self.options = [str(option) for option in (options or []) if str(option)]
|
||||||
|
super().__init__(question)
|
||||||
|
|
||||||
|
|
||||||
|
@tool_parameters(
|
||||||
|
tool_parameters_schema(
|
||||||
|
question=StringSchema(
|
||||||
|
"The question to ask before continuing. Use this only when the task needs the user's answer."
|
||||||
|
),
|
||||||
|
options=ArraySchema(
|
||||||
|
StringSchema("A possible answer label"),
|
||||||
|
description="Optional choices. The user may still reply with free text.",
|
||||||
|
),
|
||||||
|
required=["question"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
class AskUserTool(Tool):
|
||||||
|
"""Ask the user a blocking question."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "ask_user"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Pause and ask the user a question when their answer is required to continue. "
|
||||||
|
"Use options for likely answers; the user's reply, typed or selected, is returned as the tool result. "
|
||||||
|
"For non-blocking notifications or buttons, use the message tool instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def exclusive(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def execute(self, question: str, options: list[str] | None = None, **_: Any) -> Any:
|
||||||
|
raise AskUserInterrupt(question=question, options=options)
|
||||||
158
tests/agent/test_ask_user.py
Normal file
158
tests/agent/test_ask_user.py
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
import asyncio
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||||
|
from nanobot.agent.tools.ask import AskUserInterrupt, AskUserTool
|
||||||
|
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||||
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
|
from nanobot.agent.tools.schema import tool_parameters_schema
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import GenerationSettings, LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
|
||||||
|
def _make_provider(chat_with_retry):
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
provider.generation = GenerationSettings()
|
||||||
|
provider.chat_with_retry = chat_with_retry
|
||||||
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
def test_ask_user_tool_schema_and_interrupt():
|
||||||
|
tool = AskUserTool()
|
||||||
|
schema = tool.to_schema()["function"]
|
||||||
|
|
||||||
|
assert schema["name"] == "ask_user"
|
||||||
|
assert "question" in schema["parameters"]["required"]
|
||||||
|
assert schema["parameters"]["properties"]["options"]["type"] == "array"
|
||||||
|
|
||||||
|
with pytest.raises(AskUserInterrupt) as exc:
|
||||||
|
asyncio.run(tool.execute("Continue?", options=["Yes", "No"]))
|
||||||
|
|
||||||
|
assert exc.value.question == "Continue?"
|
||||||
|
assert exc.value.options == ["Yes", "No"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runner_pauses_on_ask_user_without_executing_later_tools():
|
||||||
|
@tool_parameters(tool_parameters_schema(required=[]))
|
||||||
|
class LaterTool(Tool):
|
||||||
|
called = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "later"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "Should not run after ask_user pauses the turn."
|
||||||
|
|
||||||
|
async def execute(self, **kwargs):
|
||||||
|
self.called = True
|
||||||
|
return "later result"
|
||||||
|
|
||||||
|
async def chat_with_retry(**kwargs):
|
||||||
|
return LLMResponse(
|
||||||
|
content="",
|
||||||
|
finish_reason="tool_calls",
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="call_ask",
|
||||||
|
name="ask_user",
|
||||||
|
arguments={"question": "Install this package?", "options": ["Yes", "No"]},
|
||||||
|
),
|
||||||
|
ToolCallRequest(id="call_later", name="later", arguments={}),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
later = LaterTool()
|
||||||
|
tools = ToolRegistry()
|
||||||
|
tools.register(AskUserTool())
|
||||||
|
tools.register(later)
|
||||||
|
|
||||||
|
result = await AgentRunner(_make_provider(chat_with_retry)).run(AgentRunSpec(
|
||||||
|
initial_messages=[{"role": "user", "content": "continue"}],
|
||||||
|
tools=tools,
|
||||||
|
model="test-model",
|
||||||
|
max_iterations=3,
|
||||||
|
max_tool_result_chars=16_000,
|
||||||
|
concurrent_tools=True,
|
||||||
|
))
|
||||||
|
|
||||||
|
assert result.stop_reason == "ask_user"
|
||||||
|
assert result.final_content == "Install this package?"
|
||||||
|
assert "ask_user" in result.tools_used
|
||||||
|
assert later.called is False
|
||||||
|
assert result.messages[-1]["role"] == "assistant"
|
||||||
|
assert result.messages[-1]["tool_calls"][0]["function"]["name"] == "ask_user"
|
||||||
|
assert not any(message.get("name") == "ask_user" for message in result.messages)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ask_user_sends_buttons_and_resumes_with_next_message(tmp_path):
|
||||||
|
seen_messages: list[list[dict]] = []
|
||||||
|
|
||||||
|
async def chat_with_retry(**kwargs):
|
||||||
|
seen_messages.append(kwargs["messages"])
|
||||||
|
if len(seen_messages) == 1:
|
||||||
|
return LLMResponse(
|
||||||
|
content="",
|
||||||
|
finish_reason="tool_calls",
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="call_ask",
|
||||||
|
name="ask_user",
|
||||||
|
arguments={
|
||||||
|
"question": "Install the optional package?",
|
||||||
|
"options": ["Install", "Skip"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
return LLMResponse(content="Skipped install.", usage={})
|
||||||
|
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=MessageBus(),
|
||||||
|
provider=_make_provider(chat_with_retry),
|
||||||
|
workspace=tmp_path,
|
||||||
|
model="test-model",
|
||||||
|
)
|
||||||
|
|
||||||
|
first = await loop._process_message(
|
||||||
|
InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="set it up")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert first is not None
|
||||||
|
assert first.content == "Install the optional package?"
|
||||||
|
assert first.buttons == [["Install", "Skip"]]
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:direct")
|
||||||
|
assert any(message.get("role") == "assistant" and message.get("tool_calls") for message in session.messages)
|
||||||
|
assert not any(message.get("role") == "tool" and message.get("name") == "ask_user" for message in session.messages)
|
||||||
|
|
||||||
|
second = await loop._process_message(
|
||||||
|
InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="Skip")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert second is not None
|
||||||
|
assert second.content == "Skipped install."
|
||||||
|
assert any(
|
||||||
|
message.get("role") == "tool"
|
||||||
|
and message.get("name") == "ask_user"
|
||||||
|
and message.get("content") == "Skip"
|
||||||
|
for message in seen_messages[-1]
|
||||||
|
)
|
||||||
|
assert not any(
|
||||||
|
message.get("role") == "user" and message.get("content") == "Skip"
|
||||||
|
for message in session.messages
|
||||||
|
)
|
||||||
|
assert any(
|
||||||
|
message.get("role") == "tool"
|
||||||
|
and message.get("name") == "ask_user"
|
||||||
|
and message.get("content") == "Skip"
|
||||||
|
for message in session.messages
|
||||||
|
)
|
||||||
Loading…
x
Reference in New Issue
Block a user