mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 00:22:31 +00:00
refactor(agent): remove ask_user tool
The ask_user tool used AskUserInterrupt(BaseException) for mid-turn blocking, creating heavy coupling across runner, loop, and session management. The model now asks questions naturally in response text, the turn ends normally, and the user's next message starts a new turn with session history providing continuity.

Removed:
- nanobot/agent/tools/ask.py (tool, interrupt, helpers)
- tests/agent/test_ask_user.py
- webui/src/components/thread/AskUserPrompt.tsx
- AskUserInterrupt handling in runner.py
- Dual-path message building in loop.py
- Pending ask detection via history scanning
- button_prompt/buttons emission in WebSocket channel
- ask_user references in Slack channel docstrings

Preserved (MessageTool uses these independently):
- OutboundMessage.buttons field
- Channel button rendering (Telegram, Slack, WebSocket)
This commit is contained in:
parent
07f9ab580a
commit
9e15925cf4
@ -22,12 +22,6 @@ from nanobot.agent.memory import Consolidator, Dream
|
||||
from nanobot.agent import model_presets as preset_helpers
|
||||
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.agent.tools.ask import (
|
||||
ask_user_options_from_messages,
|
||||
ask_user_outbound,
|
||||
ask_user_tool_result_messages,
|
||||
pending_ask_user_id,
|
||||
)
|
||||
from nanobot.agent.tools.file_state import FileStateStore, bind_file_states, reset_file_states
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
@ -693,7 +687,6 @@ class AgentLoop:
|
||||
self,
|
||||
msg: InboundMessage,
|
||||
session: Session,
|
||||
pending_ask_id: str | None,
|
||||
) -> bool:
|
||||
"""Persist the triggering user message before the turn starts.
|
||||
|
||||
@ -701,7 +694,7 @@ class AgentLoop:
|
||||
"""
|
||||
media_paths = [p for p in (msg.media or []) if isinstance(p, str) and p]
|
||||
has_text = isinstance(msg.content, str) and msg.content.strip()
|
||||
if not pending_ask_id and (has_text or media_paths):
|
||||
if has_text or media_paths:
|
||||
extra: dict[str, Any] = {"media": list(media_paths)} if media_paths else {}
|
||||
text = msg.content if isinstance(msg.content, str) else ""
|
||||
session.add_message("user", text, **extra)
|
||||
@ -715,21 +708,9 @@ class AgentLoop:
|
||||
msg: InboundMessage,
|
||||
session: Session,
|
||||
history: list[dict[str, Any]],
|
||||
pending_ask_id: str | None,
|
||||
pending_summary: str | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build the initial message list for the LLM turn."""
|
||||
if pending_ask_id:
|
||||
system_prompt = self.context.build_system_prompt(
|
||||
channel=msg.channel,
|
||||
session_summary=pending_summary,
|
||||
)
|
||||
return ask_user_tool_result_messages(
|
||||
system_prompt,
|
||||
history,
|
||||
pending_ask_id,
|
||||
image_generation_prompt(msg.content, msg.metadata),
|
||||
)
|
||||
return self.context.build_messages(
|
||||
history=history,
|
||||
current_message=image_generation_prompt(msg.content, msg.metadata),
|
||||
@ -1237,12 +1218,7 @@ class AgentLoop:
|
||||
replay_max_messages=self._max_messages,
|
||||
)
|
||||
)
|
||||
options = ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else []
|
||||
content, buttons = ask_user_outbound(
|
||||
final_content or "Background task completed.",
|
||||
options,
|
||||
channel,
|
||||
)
|
||||
content = final_content or "Background task completed."
|
||||
outbound_metadata: dict[str, Any] = {}
|
||||
if channel == "slack" and key.startswith("slack:") and key.count(":") >= 2:
|
||||
outbound_metadata["slack"] = {"thread_ts": key.split(":", 2)[2]}
|
||||
@ -1252,7 +1228,6 @@ class AgentLoop:
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
buttons=buttons,
|
||||
metadata=outbound_metadata,
|
||||
)
|
||||
|
||||
@ -1365,21 +1340,15 @@ class AgentLoop:
|
||||
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||
|
||||
meta = dict(msg.metadata or {})
|
||||
content, buttons = ask_user_outbound(
|
||||
final_content,
|
||||
ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else [],
|
||||
msg.channel,
|
||||
)
|
||||
if on_stream is not None and stop_reason not in {"ask_user", "error", "tool_error"}:
|
||||
if on_stream is not None and stop_reason not in {"error", "tool_error"}:
|
||||
meta["_streamed"] = True
|
||||
|
||||
return OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content=content,
|
||||
content=final_content,
|
||||
media=generated_media,
|
||||
metadata=meta,
|
||||
buttons=buttons,
|
||||
)
|
||||
|
||||
async def _state_restore(self, ctx: TurnContext) -> TurnState:
|
||||
@ -1446,12 +1415,11 @@ class AgentLoop:
|
||||
}
|
||||
ctx.history = ctx.session.get_history(**_hist_kwargs)
|
||||
|
||||
pending_ask_id = pending_ask_user_id(ctx.history)
|
||||
ctx.initial_messages = self._build_initial_messages(
|
||||
ctx.msg, ctx.session, ctx.history, pending_ask_id, ctx.pending_summary
|
||||
ctx.msg, ctx.session, ctx.history, ctx.pending_summary
|
||||
)
|
||||
ctx.user_persisted_early = self._persist_user_message_early(
|
||||
ctx.msg, ctx.session, pending_ask_id
|
||||
ctx.msg, ctx.session
|
||||
)
|
||||
|
||||
if ctx.on_progress is None:
|
||||
|
||||
@ -13,7 +13,6 @@ from typing import Any
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.agent.tools.ask import AskUserInterrupt
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.utils.helpers import (
|
||||
@ -283,22 +282,18 @@ class AgentRunner:
|
||||
self._accumulate_usage(usage, raw_usage)
|
||||
|
||||
if response.should_execute_tools:
|
||||
tool_calls = list(response.tool_calls)
|
||||
ask_index = next((i for i, tc in enumerate(tool_calls) if tc.name == "ask_user"), None)
|
||||
if ask_index is not None:
|
||||
tool_calls = tool_calls[: ask_index + 1]
|
||||
context.tool_calls = list(tool_calls)
|
||||
context.tool_calls = list(response.tool_calls)
|
||||
if hook.wants_streaming():
|
||||
await hook.on_stream_end(context, resuming=True)
|
||||
|
||||
assistant_message = build_assistant_message(
|
||||
response.content or "",
|
||||
tool_calls=[tc.to_openai_tool_call() for tc in tool_calls],
|
||||
tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls],
|
||||
reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
)
|
||||
messages.append(assistant_message)
|
||||
tools_used.extend(tc.name for tc in tool_calls)
|
||||
tools_used.extend(tc.name for tc in response.tool_calls)
|
||||
await self._emit_checkpoint(
|
||||
spec,
|
||||
{
|
||||
@ -307,7 +302,7 @@ class AgentRunner:
|
||||
"model": spec.model,
|
||||
"assistant_message": assistant_message,
|
||||
"completed_tool_results": [],
|
||||
"pending_tool_calls": [tc.to_openai_tool_call() for tc in tool_calls],
|
||||
"pending_tool_calls": [tc.to_openai_tool_call() for tc in response.tool_calls],
|
||||
},
|
||||
)
|
||||
|
||||
@ -315,7 +310,7 @@ class AgentRunner:
|
||||
|
||||
results, new_events, fatal_error = await self._execute_tools(
|
||||
spec,
|
||||
tool_calls,
|
||||
response.tool_calls,
|
||||
external_lookup_counts,
|
||||
workspace_violation_counts,
|
||||
)
|
||||
@ -323,9 +318,7 @@ class AgentRunner:
|
||||
context.tool_results = list(results)
|
||||
context.tool_events = list(new_events)
|
||||
completed_tool_results: list[dict[str, Any]] = []
|
||||
for tool_call, result in zip(tool_calls, results):
|
||||
if isinstance(fatal_error, AskUserInterrupt) and tool_call.name == "ask_user":
|
||||
continue
|
||||
for tool_call, result in zip(response.tool_calls, results):
|
||||
tool_message = {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
@ -340,15 +333,6 @@ class AgentRunner:
|
||||
messages.append(tool_message)
|
||||
completed_tool_results.append(tool_message)
|
||||
if fatal_error is not None:
|
||||
if isinstance(fatal_error, AskUserInterrupt):
|
||||
final_content = fatal_error.question
|
||||
stop_reason = "ask_user"
|
||||
context.final_content = final_content
|
||||
context.stop_reason = stop_reason
|
||||
if hook.wants_streaming():
|
||||
await hook.on_stream_end(context, resuming=False)
|
||||
await hook.after_iteration(context)
|
||||
break
|
||||
error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
|
||||
final_content = error
|
||||
stop_reason = "tool_error"
|
||||
@ -724,10 +708,6 @@ class AgentRunner:
|
||||
)
|
||||
tool_results.append(result)
|
||||
batch_results.append(result)
|
||||
if isinstance(result[2], AskUserInterrupt):
|
||||
break
|
||||
if any(isinstance(error, AskUserInterrupt) for _, _, error in batch_results):
|
||||
break
|
||||
|
||||
results: list[Any] = []
|
||||
events: list[dict[str, str]] = []
|
||||
@ -799,9 +779,6 @@ class AgentRunner:
|
||||
"status": "error",
|
||||
"detail": str(exc),
|
||||
}
|
||||
if isinstance(exc, AskUserInterrupt):
|
||||
event["status"] = "waiting"
|
||||
return "", event, exc
|
||||
payload = f"Error: {type(exc).__name__}: {exc}"
|
||||
handled = self._classify_violation(
|
||||
raw_text=str(exc),
|
||||
|
||||
@ -1,136 +0,0 @@
|
||||
"""Tool for pausing a turn until the user answers."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||
from nanobot.agent.tools.schema import ArraySchema, StringSchema, tool_parameters_schema
|
||||
|
||||
STRUCTURED_BUTTON_CHANNELS = frozenset({"telegram", "websocket"})
|
||||
|
||||
|
||||
class AskUserInterrupt(BaseException):
|
||||
"""Internal signal: the runner should stop and wait for user input."""
|
||||
|
||||
def __init__(self, question: str, options: list[str] | None = None) -> None:
|
||||
self.question = question
|
||||
self.options = [str(option) for option in (options or []) if str(option)]
|
||||
super().__init__(question)
|
||||
|
||||
|
||||
@tool_parameters(
|
||||
tool_parameters_schema(
|
||||
question=StringSchema(
|
||||
"The question to ask before continuing. Use this only when the task needs the user's answer."
|
||||
),
|
||||
options=ArraySchema(
|
||||
StringSchema("A possible answer label"),
|
||||
description="Optional choices. The user may still reply with free text.",
|
||||
),
|
||||
required=["question"],
|
||||
)
|
||||
)
|
||||
class AskUserTool(Tool):
|
||||
"""Ask the user a blocking question."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "ask_user"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Pause and ask the user a question when their answer is required to continue. "
|
||||
"Use options for likely answers; the user's reply, typed or selected, is returned as the tool result. "
|
||||
"For non-blocking notifications or buttons, use the message tool instead."
|
||||
)
|
||||
|
||||
@property
|
||||
def exclusive(self) -> bool:
|
||||
return True
|
||||
|
||||
async def execute(self, question: str, options: list[str] | None = None, **_: Any) -> Any:
|
||||
raise AskUserInterrupt(question=question, options=options)
|
||||
|
||||
|
||||
def _tool_call_name(tool_call: dict[str, Any]) -> str:
|
||||
function = tool_call.get("function")
|
||||
if isinstance(function, dict) and isinstance(function.get("name"), str):
|
||||
return function["name"]
|
||||
name = tool_call.get("name")
|
||||
return name if isinstance(name, str) else ""
|
||||
|
||||
|
||||
def _tool_call_arguments(tool_call: dict[str, Any]) -> dict[str, Any]:
|
||||
function = tool_call.get("function")
|
||||
raw = function.get("arguments") if isinstance(function, dict) else tool_call.get("arguments")
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
return {}
|
||||
|
||||
|
||||
def pending_ask_user_id(history: list[dict[str, Any]]) -> str | None:
|
||||
pending: dict[str, str] = {}
|
||||
for message in history:
|
||||
if message.get("role") == "assistant":
|
||||
for tool_call in message.get("tool_calls") or []:
|
||||
if isinstance(tool_call, dict) and isinstance(tool_call.get("id"), str):
|
||||
pending[tool_call["id"]] = _tool_call_name(tool_call)
|
||||
elif message.get("role") == "tool":
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
if isinstance(tool_call_id, str):
|
||||
pending.pop(tool_call_id, None)
|
||||
for tool_call_id, name in reversed(pending.items()):
|
||||
if name == "ask_user":
|
||||
return tool_call_id
|
||||
return None
|
||||
|
||||
|
||||
def ask_user_tool_result_messages(
|
||||
system_prompt: str,
|
||||
history: list[dict[str, Any]],
|
||||
tool_call_id: str,
|
||||
content: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
return [
|
||||
{"role": "system", "content": system_prompt},
|
||||
*history,
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"name": "ask_user",
|
||||
"content": content,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def ask_user_options_from_messages(messages: list[dict[str, Any]]) -> list[str]:
|
||||
for message in reversed(messages):
|
||||
if message.get("role") != "assistant":
|
||||
continue
|
||||
for tool_call in reversed(message.get("tool_calls") or []):
|
||||
if not isinstance(tool_call, dict) or _tool_call_name(tool_call) != "ask_user":
|
||||
continue
|
||||
options = _tool_call_arguments(tool_call).get("options")
|
||||
if isinstance(options, list):
|
||||
return [str(option) for option in options if isinstance(option, str)]
|
||||
return []
|
||||
|
||||
|
||||
def ask_user_outbound(
|
||||
content: str | None,
|
||||
options: list[str],
|
||||
channel: str,
|
||||
) -> tuple[str | None, list[list[str]]]:
|
||||
if not options:
|
||||
return content, []
|
||||
if channel in STRUCTURED_BUTTON_CHANNELS:
|
||||
return content, [options]
|
||||
option_text = "\n".join(f"{index}. {option}" for index, option in enumerate(options, 1))
|
||||
return f"{content}\n\n{option_text}" if content else option_text, []
|
||||
@ -471,7 +471,7 @@ class SlackChannel(BaseChannel):
|
||||
return preview.startswith(_HTML_DOWNLOAD_PREFIXES)
|
||||
|
||||
async def _on_block_action(self, client: SocketModeClient, req: SocketModeRequest) -> None:
|
||||
"""Handle button clicks from ask_user blocks."""
|
||||
"""Handle button clicks from inline action buttons."""
|
||||
await client.send_socket_mode_response(SocketModeResponse(envelope_id=req.envelope_id))
|
||||
payload = req.payload or {}
|
||||
actions = payload.get("actions") or []
|
||||
@ -568,7 +568,7 @@ class SlackChannel(BaseChannel):
|
||||
|
||||
@staticmethod
|
||||
def _build_button_blocks(text: str, buttons: list[list[str]]) -> list[dict[str, Any]]:
|
||||
"""Build Slack Block Kit blocks with action buttons for ask_user choices."""
|
||||
"""Build Slack Block Kit blocks with action buttons."""
|
||||
blocks: list[dict[str, Any]] = [
|
||||
{"type": "section", "text": {"type": "mrkdwn", "text": text[:3000]}},
|
||||
]
|
||||
@ -579,7 +579,7 @@ class SlackChannel(BaseChannel):
|
||||
"type": "button",
|
||||
"text": {"type": "plain_text", "text": label[:75]},
|
||||
"value": label[:75],
|
||||
"action_id": f"ask_user_{label[:50]}",
|
||||
"action_id": f"btn_{label[:50]}",
|
||||
})
|
||||
if elements:
|
||||
blocks.append({"type": "actions", "elements": elements[:25]})
|
||||
|
||||
@ -55,14 +55,6 @@ def _normalize_config_path(path: str) -> str:
|
||||
return _strip_trailing_slash(path)
|
||||
|
||||
|
||||
def _append_buttons_as_text(text: str, buttons: list[list[str]]) -> str:
|
||||
labels = [label for row in buttons for label in row if label]
|
||||
if not labels:
|
||||
return text
|
||||
fallback = "\n".join(f"{index}. {label}" for index, label in enumerate(labels, 1))
|
||||
return f"{text}\n\n{fallback}" if text else fallback
|
||||
|
||||
|
||||
class WebSocketConfig(Base):
|
||||
"""WebSocket server channel configuration.
|
||||
|
||||
@ -1468,16 +1460,11 @@ class WebSocketChannel(BaseChannel):
|
||||
await self.send_session_updated(msg.chat_id)
|
||||
return
|
||||
text = msg.content
|
||||
if msg.buttons:
|
||||
text = _append_buttons_as_text(text, msg.buttons)
|
||||
payload: dict[str, Any] = {
|
||||
"event": "message",
|
||||
"chat_id": msg.chat_id,
|
||||
"text": text,
|
||||
}
|
||||
if msg.buttons:
|
||||
payload["buttons"] = msg.buttons
|
||||
payload["button_prompt"] = msg.content
|
||||
if msg.media:
|
||||
payload["media"] = msg.media
|
||||
urls: list[dict[str, str]] = []
|
||||
|
||||
@ -11,7 +11,7 @@ Generate a personalized upgrade skill for this workspace.
|
||||
|
||||
Use `read_file` to check if `skills/update/SKILL.md` already exists in the workspace.
|
||||
|
||||
If it exists, use `ask_user` to ask: "An upgrade skill already exists. Reconfigure?" with options ["yes", "no"]. If no, stop here.
|
||||
If it exists, ask the user: "An upgrade skill already exists. Reconfigure?" Wait for the user's reply. If no, stop here.
|
||||
|
||||
## Step 2: Current Version and Install Clues
|
||||
|
||||
@ -38,9 +38,9 @@ answer or confirmation, not from inference alone. If you cannot get a clear
|
||||
answer, stop and ask the user to rerun this setup when they know how nanobot was
|
||||
installed.
|
||||
|
||||
Use `ask_user` for the questions below, one question per call. If `ask_user` is
|
||||
not available or cannot collect the answer, ask in normal chat and stop without
|
||||
writing the skill.
|
||||
Ask the user the questions below, one at a time, in your response text. Wait for
|
||||
the user's reply before proceeding to the next question. If you cannot get a clear
|
||||
answer, stop without writing the skill.
|
||||
|
||||
**Question 1 — Install method:**
|
||||
|
||||
|
||||
@ -1,241 +0,0 @@
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||
from nanobot.agent.tools.ask import AskUserInterrupt, AskUserTool
|
||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.schema import tool_parameters_schema
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import GenerationSettings, LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
def _make_provider(chat_with_retry):
|
||||
async def chat_stream_with_retry(**kwargs):
|
||||
kwargs.pop("on_content_delta", None)
|
||||
return await chat_with_retry(**kwargs)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.generation = GenerationSettings()
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
provider.chat_stream_with_retry = chat_stream_with_retry
|
||||
return provider
|
||||
|
||||
|
||||
def test_ask_user_tool_schema_and_interrupt():
|
||||
tool = AskUserTool()
|
||||
schema = tool.to_schema()["function"]
|
||||
|
||||
assert schema["name"] == "ask_user"
|
||||
assert "question" in schema["parameters"]["required"]
|
||||
assert schema["parameters"]["properties"]["options"]["type"] == "array"
|
||||
|
||||
with pytest.raises(AskUserInterrupt) as exc:
|
||||
asyncio.run(tool.execute("Continue?", options=["Yes", "No"]))
|
||||
|
||||
assert exc.value.question == "Continue?"
|
||||
assert exc.value.options == ["Yes", "No"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_pauses_on_ask_user_without_executing_later_tools():
|
||||
@tool_parameters(tool_parameters_schema(required=[]))
|
||||
class LaterTool(Tool):
|
||||
called = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "later"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Should not run after ask_user pauses the turn."
|
||||
|
||||
async def execute(self, **kwargs):
|
||||
self.called = True
|
||||
return "later result"
|
||||
|
||||
async def chat_with_retry(**kwargs):
|
||||
return LLMResponse(
|
||||
content="",
|
||||
finish_reason="tool_calls",
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_ask",
|
||||
name="ask_user",
|
||||
arguments={"question": "Install this package?", "options": ["Yes", "No"]},
|
||||
),
|
||||
ToolCallRequest(id="call_later", name="later", arguments={}),
|
||||
],
|
||||
)
|
||||
|
||||
later = LaterTool()
|
||||
tools = ToolRegistry()
|
||||
tools.register(AskUserTool())
|
||||
tools.register(later)
|
||||
|
||||
result = await AgentRunner(_make_provider(chat_with_retry)).run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "continue"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=3,
|
||||
max_tool_result_chars=16_000,
|
||||
concurrent_tools=True,
|
||||
))
|
||||
|
||||
assert result.stop_reason == "ask_user"
|
||||
assert result.final_content == "Install this package?"
|
||||
assert "ask_user" in result.tools_used
|
||||
assert later.called is False
|
||||
assert result.messages[-1]["role"] == "assistant"
|
||||
tool_calls = result.messages[-1]["tool_calls"]
|
||||
assert [tool_call["function"]["name"] for tool_call in tool_calls] == ["ask_user"]
|
||||
assert not any(message.get("name") == "ask_user" for message in result.messages)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ask_user_text_fallback_resumes_with_next_message(tmp_path):
|
||||
seen_messages: list[list[dict]] = []
|
||||
|
||||
async def chat_with_retry(**kwargs):
|
||||
seen_messages.append(kwargs["messages"])
|
||||
if len(seen_messages) == 1:
|
||||
return LLMResponse(
|
||||
content="",
|
||||
finish_reason="tool_calls",
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_ask",
|
||||
name="ask_user",
|
||||
arguments={
|
||||
"question": "Install the optional package?",
|
||||
"options": ["Install", "Skip"],
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
return LLMResponse(content="Skipped install.", usage={})
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=MessageBus(),
|
||||
provider=_make_provider(chat_with_retry),
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
async def on_stream(delta: str) -> None:
|
||||
pass
|
||||
|
||||
async def on_stream_end(**kwargs) -> None:
|
||||
pass
|
||||
|
||||
first = await loop._process_message(
|
||||
InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="set it up"),
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
)
|
||||
|
||||
assert first is not None
|
||||
assert first.content == "Install the optional package?\n\n1. Install\n2. Skip"
|
||||
assert first.buttons == []
|
||||
assert "_streamed" not in first.metadata
|
||||
|
||||
session = loop.sessions.get_or_create("cli:direct")
|
||||
assert any(message.get("role") == "assistant" and message.get("tool_calls") for message in session.messages)
|
||||
assert not any(message.get("role") == "tool" and message.get("name") == "ask_user" for message in session.messages)
|
||||
|
||||
second = await loop._process_message(
|
||||
InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="Skip")
|
||||
)
|
||||
|
||||
assert second is not None
|
||||
assert second.content == "Skipped install."
|
||||
assert any(
|
||||
message.get("role") == "tool"
|
||||
and message.get("name") == "ask_user"
|
||||
and message.get("content") == "Skip"
|
||||
for message in seen_messages[-1]
|
||||
)
|
||||
assert not any(
|
||||
message.get("role") == "user" and message.get("content") == "Skip"
|
||||
for message in session.messages
|
||||
)
|
||||
assert any(
|
||||
message.get("role") == "tool"
|
||||
and message.get("name") == "ask_user"
|
||||
and message.get("content") == "Skip"
|
||||
for message in session.messages
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ask_user_keeps_buttons_for_telegram(tmp_path):
|
||||
async def chat_with_retry(**kwargs):
|
||||
return LLMResponse(
|
||||
content="",
|
||||
finish_reason="tool_calls",
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_ask",
|
||||
name="ask_user",
|
||||
arguments={
|
||||
"question": "Install the optional package?",
|
||||
"options": ["Install", "Skip"],
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=MessageBus(),
|
||||
provider=_make_provider(chat_with_retry),
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
response = await loop._process_message(
|
||||
InboundMessage(channel="telegram", sender_id="user", chat_id="123", content="set it up")
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.content == "Install the optional package?"
|
||||
assert response.buttons == [["Install", "Skip"]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ask_user_keeps_buttons_for_websocket(tmp_path):
|
||||
async def chat_with_retry(**kwargs):
|
||||
return LLMResponse(
|
||||
content="",
|
||||
finish_reason="tool_calls",
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_ask",
|
||||
name="ask_user",
|
||||
arguments={
|
||||
"question": "Install the optional package?",
|
||||
"options": ["Install", "Skip"],
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=MessageBus(),
|
||||
provider=_make_provider(chat_with_retry),
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
response = await loop._process_message(
|
||||
InboundMessage(channel="websocket", sender_id="user", chat_id="123", content="set it up")
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.content == "Install the optional package?"
|
||||
assert response.buttons == [["Install", "Skip"]]
|
||||
@ -234,13 +234,13 @@ async def test_send_renders_buttons_on_last_message_chunk() -> None:
|
||||
"type": "button",
|
||||
"text": {"type": "plain_text", "text": "Yes"},
|
||||
"value": "Yes",
|
||||
"action_id": "ask_user_Yes",
|
||||
"action_id": "btn_Yes",
|
||||
},
|
||||
{
|
||||
"type": "button",
|
||||
"text": {"type": "plain_text", "text": "No"},
|
||||
"value": "No",
|
||||
"action_id": "ask_user_No",
|
||||
"action_id": "btn_No",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
@ -224,11 +224,9 @@ async def test_send_delivers_json_message_with_media_and_reply() -> None:
|
||||
payload = json.loads(mock_ws.send.call_args[0][0])
|
||||
assert payload["event"] == "message"
|
||||
assert payload["chat_id"] == "chat-1"
|
||||
assert payload["text"] == "hello\n\n1. Yes\n2. No"
|
||||
assert payload["button_prompt"] == "hello"
|
||||
assert payload["text"] == "hello"
|
||||
assert payload["reply_to"] == "m1"
|
||||
assert payload["media"] == ["/tmp/a.png"]
|
||||
assert payload["buttons"] == [["Yes", "No"]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -405,7 +405,7 @@ def test_loader_registers_same_tools_as_old_hardcoded():
|
||||
registered = loader.load(ctx, registry)
|
||||
|
||||
expected = {
|
||||
"ask_user", "read_file", "write_file", "edit_file", "list_dir",
|
||||
"read_file", "write_file", "edit_file", "list_dir",
|
||||
"glob", "grep", "notebook_edit", "exec", "web_search", "web_fetch",
|
||||
"message", "spawn", "cron",
|
||||
}
|
||||
|
||||
@ -1,108 +0,0 @@
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { MessageSquareText } from "lucide-react";
|
||||
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface AskUserPromptProps {
|
||||
question: string;
|
||||
buttons: string[][];
|
||||
onAnswer: (answer: string) => void;
|
||||
}
|
||||
|
||||
export function AskUserPrompt({
|
||||
question,
|
||||
buttons,
|
||||
onAnswer,
|
||||
}: AskUserPromptProps) {
|
||||
const [customOpen, setCustomOpen] = useState(false);
|
||||
const [custom, setCustom] = useState("");
|
||||
const inputRef = useRef<HTMLTextAreaElement>(null);
|
||||
const options = buttons.flat().filter(Boolean);
|
||||
|
||||
useEffect(() => {
|
||||
if (customOpen) {
|
||||
inputRef.current?.focus();
|
||||
}
|
||||
}, [customOpen]);
|
||||
|
||||
const submitCustom = useCallback(() => {
|
||||
const answer = custom.trim();
|
||||
if (!answer) return;
|
||||
onAnswer(answer);
|
||||
setCustom("");
|
||||
setCustomOpen(false);
|
||||
}, [custom, onAnswer]);
|
||||
|
||||
if (options.length === 0) return null;
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"mx-auto mb-2 w-full max-w-[49.5rem] rounded-[16px] border border-primary/30",
|
||||
"bg-card/95 p-3 shadow-sm backdrop-blur",
|
||||
)}
|
||||
role="group"
|
||||
aria-label="Question"
|
||||
>
|
||||
<div className="mb-2 flex items-start gap-2">
|
||||
<div className="mt-0.5 rounded-full bg-primary/10 p-1.5 text-primary">
|
||||
<MessageSquareText className="h-3.5 w-3.5" aria-hidden />
|
||||
</div>
|
||||
<p className="min-w-0 flex-1 text-[13.5px] font-medium leading-5 text-foreground">
|
||||
{question}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-1.5 sm:grid-cols-2">
|
||||
{options.map((option) => (
|
||||
<Button
|
||||
key={option}
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => onAnswer(option)}
|
||||
className="justify-start rounded-[10px] px-3 text-left"
|
||||
>
|
||||
<span className="truncate">{option}</span>
|
||||
</Button>
|
||||
))}
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={() => setCustomOpen((open) => !open)}
|
||||
className="justify-start rounded-[10px] px-3 text-muted-foreground"
|
||||
>
|
||||
Other...
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{customOpen ? (
|
||||
<div className="mt-2 flex gap-2">
|
||||
<textarea
|
||||
ref={inputRef}
|
||||
value={custom}
|
||||
onChange={(event) => setCustom(event.target.value)}
|
||||
onKeyDown={(event) => {
|
||||
if (event.key === "Enter" && !event.shiftKey && !event.nativeEvent.isComposing) {
|
||||
event.preventDefault();
|
||||
submitCustom();
|
||||
}
|
||||
}}
|
||||
rows={1}
|
||||
placeholder="Type your own answer..."
|
||||
className={cn(
|
||||
"min-h-9 flex-1 resize-none rounded-[10px] border border-border/70 bg-background",
|
||||
"px-3 py-2 text-[13.5px] leading-5 outline-none placeholder:text-muted-foreground",
|
||||
"focus-visible:ring-1 focus-visible:ring-primary/40",
|
||||
)}
|
||||
/>
|
||||
<Button type="button" size="sm" onClick={submitCustom} disabled={!custom.trim()}>
|
||||
Send
|
||||
</Button>
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@ -13,7 +13,6 @@ import {
|
||||
} from "lucide-react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
|
||||
import { AskUserPrompt } from "@/components/thread/AskUserPrompt";
|
||||
import { ThreadComposer } from "@/components/thread/ThreadComposer";
|
||||
import { ThreadHeader } from "@/components/thread/ThreadHeader";
|
||||
import { StreamErrorNotice } from "@/components/thread/StreamErrorNotice";
|
||||
@ -105,21 +104,6 @@ export function ThreadShell({
|
||||
dismissStreamError,
|
||||
} = useNanobotStream(chatId, initial, hasPendingToolCalls, onTurnEnd);
|
||||
const showHeroComposer = messages.length === 0 && !loading;
|
||||
const pendingAsk = useMemo(() => {
|
||||
for (let index = messages.length - 1; index >= 0; index -= 1) {
|
||||
const message = messages[index];
|
||||
if (message.kind === "trace") continue;
|
||||
if (message.role === "user") return null;
|
||||
if (message.role === "assistant" && message.buttons?.some((row) => row.length > 0)) {
|
||||
return {
|
||||
question: message.content,
|
||||
buttons: message.buttons,
|
||||
};
|
||||
}
|
||||
if (message.role === "assistant") return null;
|
||||
}
|
||||
return null;
|
||||
}, [messages]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!chatId || loading) return;
|
||||
@ -247,13 +231,6 @@ export function ThreadShell({
|
||||
onDismiss={dismissStreamError}
|
||||
/>
|
||||
) : null}
|
||||
{pendingAsk ? (
|
||||
<AskUserPrompt
|
||||
question={pendingAsk.question}
|
||||
buttons={pendingAsk.buttons}
|
||||
onAnswer={send}
|
||||
/>
|
||||
) : null}
|
||||
{session ? (
|
||||
<ThreadComposer
|
||||
onSend={send}
|
||||
|
||||
@ -230,7 +230,7 @@ export function useNanobotStream(
|
||||
// the full turn (all tool calls + final text) is complete.
|
||||
setMessages((prev) => {
|
||||
const filtered = activeId ? prev.filter((m) => m.id !== activeId) : prev;
|
||||
const content = ev.buttons?.length ? (ev.button_prompt ?? ev.text) : ev.text;
|
||||
const content = ev.text;
|
||||
return [
|
||||
...filtered,
|
||||
{
|
||||
@ -238,7 +238,6 @@ export function useNanobotStream(
|
||||
role: "assistant",
|
||||
content,
|
||||
createdAt: Date.now(),
|
||||
...(ev.buttons && ev.buttons.length > 0 ? { buttons: ev.buttons } : {}),
|
||||
...(hasMedia ? { media } : {}),
|
||||
},
|
||||
];
|
||||
|
||||
@ -44,8 +44,6 @@ export interface UIMessage {
|
||||
images?: UIImage[];
|
||||
/** Signed or local UI-renderable media attachments. */
|
||||
media?: UIMediaAttachment[];
|
||||
/** Optional answer choices for a pending ask_user question. */
|
||||
buttons?: string[][];
|
||||
}
|
||||
|
||||
export interface ChatSummary {
|
||||
@ -141,9 +139,6 @@ export type InboundEvent =
|
||||
reply_to?: string;
|
||||
media?: string[];
|
||||
media_urls?: Array<{ url: string; name?: string }>;
|
||||
buttons?: string[][];
|
||||
/** Original prompt before the websocket text fallback appends buttons. */
|
||||
button_prompt?: string;
|
||||
/** Present when the frame is an agent breadcrumb (e.g. tool hint,
|
||||
* generic progress line) rather than a conversational reply. */
|
||||
kind?: "tool_hint" | "progress";
|
||||
|
||||
@ -809,46 +809,4 @@ describe("ThreadShell", () => {
|
||||
await waitFor(() => expect(screen.getByText("from chat b")).toBeInTheDocument());
|
||||
expect(screen.queryByText("from chat a")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("renders ask_user options above the composer and sends selected answers", async () => {
|
||||
const client = makeClient();
|
||||
const onNewChat = vi.fn().mockResolvedValue("chat-a");
|
||||
|
||||
render(
|
||||
wrap(
|
||||
client,
|
||||
<ThreadShell
|
||||
session={session("chat-a")}
|
||||
title="Chat chat-a"
|
||||
onToggleSidebar={() => {}}
|
||||
onGoHome={() => {}}
|
||||
onNewChat={onNewChat}
|
||||
/>,
|
||||
),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
client._emitChat("chat-a", {
|
||||
event: "message",
|
||||
chat_id: "chat-a",
|
||||
text: "How should I continue?",
|
||||
buttons: [["Short answer", "Detailed answer"]],
|
||||
});
|
||||
});
|
||||
|
||||
expect(screen.getByRole("group", { name: "Question" })).toHaveTextContent(
|
||||
"How should I continue?",
|
||||
);
|
||||
|
||||
fireEvent.click(screen.getByRole("button", { name: "Short answer" }));
|
||||
|
||||
expect(client.sendMessage).toHaveBeenCalledWith(
|
||||
"chat-a",
|
||||
"Short answer",
|
||||
undefined,
|
||||
);
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByRole("group", { name: "Question" })).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@ -217,29 +217,6 @@ describe("useNanobotStream", () => {
|
||||
expect(result.current.messages[0].content).toBe("long task");
|
||||
});
|
||||
|
||||
it("keeps assistant buttons on complete messages", () => {
|
||||
const fake = fakeClient();
|
||||
const { result } = renderHook(() => useNanobotStream("chat-q", EMPTY_MESSAGES), {
|
||||
wrapper: wrap(fake.client),
|
||||
});
|
||||
|
||||
act(() => {
|
||||
fake.emit("chat-q", {
|
||||
event: "message",
|
||||
chat_id: "chat-q",
|
||||
text: "How should I continue?\n\n1. Short answer\n2. Detailed answer",
|
||||
button_prompt: "How should I continue?",
|
||||
buttons: [["Short answer", "Detailed answer"]],
|
||||
});
|
||||
});
|
||||
|
||||
expect(result.current.messages).toHaveLength(1);
|
||||
expect(result.current.messages[0].content).toBe("How should I continue?");
|
||||
expect(result.current.messages[0].buttons).toEqual([
|
||||
["Short answer", "Detailed answer"],
|
||||
]);
|
||||
});
|
||||
|
||||
it("keeps streaming alive across stream_end and completes on turn_end", () => {
|
||||
const fake = fakeClient();
|
||||
const onTurnEnd = vi.fn();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user