refactor(agent): remove ask_user tool

The ask_user tool used AskUserInterrupt(BaseException) for mid-turn
blocking, creating heavy coupling across runner, loop, and session
management. The model now asks questions naturally in response text,
the turn ends normally, and the user's next message starts a new turn
with session history providing continuity.

Removed:
- nanobot/agent/tools/ask.py (tool, interrupt, helpers)
- tests/agent/test_ask_user.py
- webui/src/components/thread/AskUserPrompt.tsx
- AskUserInterrupt handling in runner.py
- Dual-path message building in loop.py
- Pending ask detection via history scanning
- button_prompt/buttons emission in WebSocket channel
- ask_user references in Slack channel docstrings

Preserved (MessageTool uses these independently):
- OutboundMessage.buttons field
- Channel button rendering (Telegram, Slack); the WebSocket channel no longer emits buttons (see Removed)
This commit is contained in:
chengyongru 2026-05-12 18:36:03 +08:00 committed by Xubin Ren
parent 07f9ab580a
commit 9e15925cf4
16 changed files with 24 additions and 673 deletions

View File

@ -22,12 +22,6 @@ from nanobot.agent.memory import Consolidator, Dream
from nanobot.agent import model_presets as preset_helpers from nanobot.agent import model_presets as preset_helpers
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec
from nanobot.agent.subagent import SubagentManager from nanobot.agent.subagent import SubagentManager
from nanobot.agent.tools.ask import (
ask_user_options_from_messages,
ask_user_outbound,
ask_user_tool_result_messages,
pending_ask_user_id,
)
from nanobot.agent.tools.file_state import FileStateStore, bind_file_states, reset_file_states from nanobot.agent.tools.file_state import FileStateStore, bind_file_states, reset_file_states
from nanobot.agent.tools.message import MessageTool from nanobot.agent.tools.message import MessageTool
from nanobot.agent.tools.registry import ToolRegistry from nanobot.agent.tools.registry import ToolRegistry
@ -693,7 +687,6 @@ class AgentLoop:
self, self,
msg: InboundMessage, msg: InboundMessage,
session: Session, session: Session,
pending_ask_id: str | None,
) -> bool: ) -> bool:
"""Persist the triggering user message before the turn starts. """Persist the triggering user message before the turn starts.
@ -701,7 +694,7 @@ class AgentLoop:
""" """
media_paths = [p for p in (msg.media or []) if isinstance(p, str) and p] media_paths = [p for p in (msg.media or []) if isinstance(p, str) and p]
has_text = isinstance(msg.content, str) and msg.content.strip() has_text = isinstance(msg.content, str) and msg.content.strip()
if not pending_ask_id and (has_text or media_paths): if has_text or media_paths:
extra: dict[str, Any] = {"media": list(media_paths)} if media_paths else {} extra: dict[str, Any] = {"media": list(media_paths)} if media_paths else {}
text = msg.content if isinstance(msg.content, str) else "" text = msg.content if isinstance(msg.content, str) else ""
session.add_message("user", text, **extra) session.add_message("user", text, **extra)
@ -715,21 +708,9 @@ class AgentLoop:
msg: InboundMessage, msg: InboundMessage,
session: Session, session: Session,
history: list[dict[str, Any]], history: list[dict[str, Any]],
pending_ask_id: str | None,
pending_summary: str | None, pending_summary: str | None,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Build the initial message list for the LLM turn.""" """Build the initial message list for the LLM turn."""
if pending_ask_id:
system_prompt = self.context.build_system_prompt(
channel=msg.channel,
session_summary=pending_summary,
)
return ask_user_tool_result_messages(
system_prompt,
history,
pending_ask_id,
image_generation_prompt(msg.content, msg.metadata),
)
return self.context.build_messages( return self.context.build_messages(
history=history, history=history,
current_message=image_generation_prompt(msg.content, msg.metadata), current_message=image_generation_prompt(msg.content, msg.metadata),
@ -1237,12 +1218,7 @@ class AgentLoop:
replay_max_messages=self._max_messages, replay_max_messages=self._max_messages,
) )
) )
options = ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else [] content = final_content or "Background task completed."
content, buttons = ask_user_outbound(
final_content or "Background task completed.",
options,
channel,
)
outbound_metadata: dict[str, Any] = {} outbound_metadata: dict[str, Any] = {}
if channel == "slack" and key.startswith("slack:") and key.count(":") >= 2: if channel == "slack" and key.startswith("slack:") and key.count(":") >= 2:
outbound_metadata["slack"] = {"thread_ts": key.split(":", 2)[2]} outbound_metadata["slack"] = {"thread_ts": key.split(":", 2)[2]}
@ -1252,7 +1228,6 @@ class AgentLoop:
channel=channel, channel=channel,
chat_id=chat_id, chat_id=chat_id,
content=content, content=content,
buttons=buttons,
metadata=outbound_metadata, metadata=outbound_metadata,
) )
@ -1365,21 +1340,15 @@ class AgentLoop:
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview) logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
meta = dict(msg.metadata or {}) meta = dict(msg.metadata or {})
content, buttons = ask_user_outbound( if on_stream is not None and stop_reason not in {"error", "tool_error"}:
final_content,
ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else [],
msg.channel,
)
if on_stream is not None and stop_reason not in {"ask_user", "error", "tool_error"}:
meta["_streamed"] = True meta["_streamed"] = True
return OutboundMessage( return OutboundMessage(
channel=msg.channel, channel=msg.channel,
chat_id=msg.chat_id, chat_id=msg.chat_id,
content=content, content=final_content,
media=generated_media, media=generated_media,
metadata=meta, metadata=meta,
buttons=buttons,
) )
async def _state_restore(self, ctx: TurnContext) -> TurnState: async def _state_restore(self, ctx: TurnContext) -> TurnState:
@ -1446,12 +1415,11 @@ class AgentLoop:
} }
ctx.history = ctx.session.get_history(**_hist_kwargs) ctx.history = ctx.session.get_history(**_hist_kwargs)
pending_ask_id = pending_ask_user_id(ctx.history)
ctx.initial_messages = self._build_initial_messages( ctx.initial_messages = self._build_initial_messages(
ctx.msg, ctx.session, ctx.history, pending_ask_id, ctx.pending_summary ctx.msg, ctx.session, ctx.history, ctx.pending_summary
) )
ctx.user_persisted_early = self._persist_user_message_early( ctx.user_persisted_early = self._persist_user_message_early(
ctx.msg, ctx.session, pending_ask_id ctx.msg, ctx.session
) )
if ctx.on_progress is None: if ctx.on_progress is None:

View File

@ -13,7 +13,6 @@ from typing import Any
from loguru import logger from loguru import logger
from nanobot.agent.hook import AgentHook, AgentHookContext from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.tools.ask import AskUserInterrupt
from nanobot.agent.tools.registry import ToolRegistry from nanobot.agent.tools.registry import ToolRegistry
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.utils.helpers import ( from nanobot.utils.helpers import (
@ -283,22 +282,18 @@ class AgentRunner:
self._accumulate_usage(usage, raw_usage) self._accumulate_usage(usage, raw_usage)
if response.should_execute_tools: if response.should_execute_tools:
tool_calls = list(response.tool_calls) context.tool_calls = list(response.tool_calls)
ask_index = next((i for i, tc in enumerate(tool_calls) if tc.name == "ask_user"), None)
if ask_index is not None:
tool_calls = tool_calls[: ask_index + 1]
context.tool_calls = list(tool_calls)
if hook.wants_streaming(): if hook.wants_streaming():
await hook.on_stream_end(context, resuming=True) await hook.on_stream_end(context, resuming=True)
assistant_message = build_assistant_message( assistant_message = build_assistant_message(
response.content or "", response.content or "",
tool_calls=[tc.to_openai_tool_call() for tc in tool_calls], tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls],
reasoning_content=response.reasoning_content, reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks, thinking_blocks=response.thinking_blocks,
) )
messages.append(assistant_message) messages.append(assistant_message)
tools_used.extend(tc.name for tc in tool_calls) tools_used.extend(tc.name for tc in response.tool_calls)
await self._emit_checkpoint( await self._emit_checkpoint(
spec, spec,
{ {
@ -307,7 +302,7 @@ class AgentRunner:
"model": spec.model, "model": spec.model,
"assistant_message": assistant_message, "assistant_message": assistant_message,
"completed_tool_results": [], "completed_tool_results": [],
"pending_tool_calls": [tc.to_openai_tool_call() for tc in tool_calls], "pending_tool_calls": [tc.to_openai_tool_call() for tc in response.tool_calls],
}, },
) )
@ -315,7 +310,7 @@ class AgentRunner:
results, new_events, fatal_error = await self._execute_tools( results, new_events, fatal_error = await self._execute_tools(
spec, spec,
tool_calls, response.tool_calls,
external_lookup_counts, external_lookup_counts,
workspace_violation_counts, workspace_violation_counts,
) )
@ -323,9 +318,7 @@ class AgentRunner:
context.tool_results = list(results) context.tool_results = list(results)
context.tool_events = list(new_events) context.tool_events = list(new_events)
completed_tool_results: list[dict[str, Any]] = [] completed_tool_results: list[dict[str, Any]] = []
for tool_call, result in zip(tool_calls, results): for tool_call, result in zip(response.tool_calls, results):
if isinstance(fatal_error, AskUserInterrupt) and tool_call.name == "ask_user":
continue
tool_message = { tool_message = {
"role": "tool", "role": "tool",
"tool_call_id": tool_call.id, "tool_call_id": tool_call.id,
@ -340,15 +333,6 @@ class AgentRunner:
messages.append(tool_message) messages.append(tool_message)
completed_tool_results.append(tool_message) completed_tool_results.append(tool_message)
if fatal_error is not None: if fatal_error is not None:
if isinstance(fatal_error, AskUserInterrupt):
final_content = fatal_error.question
stop_reason = "ask_user"
context.final_content = final_content
context.stop_reason = stop_reason
if hook.wants_streaming():
await hook.on_stream_end(context, resuming=False)
await hook.after_iteration(context)
break
error = f"Error: {type(fatal_error).__name__}: {fatal_error}" error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
final_content = error final_content = error
stop_reason = "tool_error" stop_reason = "tool_error"
@ -724,10 +708,6 @@ class AgentRunner:
) )
tool_results.append(result) tool_results.append(result)
batch_results.append(result) batch_results.append(result)
if isinstance(result[2], AskUserInterrupt):
break
if any(isinstance(error, AskUserInterrupt) for _, _, error in batch_results):
break
results: list[Any] = [] results: list[Any] = []
events: list[dict[str, str]] = [] events: list[dict[str, str]] = []
@ -799,9 +779,6 @@ class AgentRunner:
"status": "error", "status": "error",
"detail": str(exc), "detail": str(exc),
} }
if isinstance(exc, AskUserInterrupt):
event["status"] = "waiting"
return "", event, exc
payload = f"Error: {type(exc).__name__}: {exc}" payload = f"Error: {type(exc).__name__}: {exc}"
handled = self._classify_violation( handled = self._classify_violation(
raw_text=str(exc), raw_text=str(exc),

View File

@ -1,136 +0,0 @@
"""Tool for pausing a turn until the user answers."""
import json
from typing import Any
from nanobot.agent.tools.base import Tool, tool_parameters
from nanobot.agent.tools.schema import ArraySchema, StringSchema, tool_parameters_schema
# Channels for which ask_user_outbound returns the options as a structured
# button row; for every other channel the options are appended to the message
# text as a numbered list.
STRUCTURED_BUTTON_CHANNELS = frozenset({"telegram", "websocket"})
class AskUserInterrupt(BaseException):
"""Internal signal: the runner should stop and wait for user input."""
def __init__(self, question: str, options: list[str] | None = None) -> None:
self.question = question
self.options = [str(option) for option in (options or []) if str(option)]
super().__init__(question)
@tool_parameters(
    tool_parameters_schema(
        question=StringSchema(
            "The question to ask before continuing. Use this only when the task needs the user's answer."
        ),
        options=ArraySchema(
            StringSchema("A possible answer label"),
            description="Optional choices. The user may still reply with free text.",
        ),
        required=["question"],
    )
)
class AskUserTool(Tool):
    """Ask the user a blocking question.

    The tool does no work itself: ``execute`` always raises
    ``AskUserInterrupt`` carrying the question and the optional choices.
    """

    @property
    def name(self) -> str:
        """Return the tool name, ``ask_user``."""
        return "ask_user"

    @property
    def description(self) -> str:
        """Return the usage guidance text for this tool."""
        sentences = (
            "Pause and ask the user a question when their answer is required to continue.",
            "Use options for likely answers; the user's reply, typed or selected, is returned as the tool result.",
            "For non-blocking notifications or buttons, use the message tool instead.",
        )
        return " ".join(sentences)

    @property
    def exclusive(self) -> bool:
        """Always ``True``.

        NOTE(review): presumably tells the runner not to batch this tool with
        others -- confirm against the ``Tool`` base class.
        """
        return True

    async def execute(self, question: str, options: list[str] | None = None, **_ignored: Any) -> Any:
        """Raise ``AskUserInterrupt`` for ``question``; never returns normally.

        Any extra keyword arguments are accepted and ignored.
        """
        raise AskUserInterrupt(question=question, options=options)
def _tool_call_name(tool_call: dict[str, Any]) -> str:
function = tool_call.get("function")
if isinstance(function, dict) and isinstance(function.get("name"), str):
return function["name"]
name = tool_call.get("name")
return name if isinstance(name, str) else ""
def _tool_call_arguments(tool_call: dict[str, Any]) -> dict[str, Any]:
function = tool_call.get("function")
raw = function.get("arguments") if isinstance(function, dict) else tool_call.get("arguments")
if isinstance(raw, dict):
return raw
if isinstance(raw, str):
try:
parsed = json.loads(raw)
except json.JSONDecodeError:
return {}
return parsed if isinstance(parsed, dict) else {}
return {}
def pending_ask_user_id(history: list[dict[str, Any]]) -> str | None:
pending: dict[str, str] = {}
for message in history:
if message.get("role") == "assistant":
for tool_call in message.get("tool_calls") or []:
if isinstance(tool_call, dict) and isinstance(tool_call.get("id"), str):
pending[tool_call["id"]] = _tool_call_name(tool_call)
elif message.get("role") == "tool":
tool_call_id = message.get("tool_call_id")
if isinstance(tool_call_id, str):
pending.pop(tool_call_id, None)
for tool_call_id, name in reversed(pending.items()):
if name == "ask_user":
return tool_call_id
return None
def ask_user_tool_result_messages(
    system_prompt: str,
    history: list[dict[str, Any]],
    tool_call_id: str,
    content: str,
) -> list[dict[str, Any]]:
    """Build a message list that answers a pending ``ask_user`` call.

    The result is a new list: a system message holding ``system_prompt``, the
    entries of ``history`` in order, then a ``tool`` message named
    ``ask_user`` that carries ``content`` as the result for ``tool_call_id``.
    ``history`` itself is not modified.
    """
    answer: dict[str, Any] = {
        "role": "tool",
        "tool_call_id": tool_call_id,
        "name": "ask_user",
        "content": content,
    }
    messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
    messages.extend(history)
    messages.append(answer)
    return messages
def ask_user_options_from_messages(messages: list[dict[str, Any]]) -> list[str]:
    """Return the option labels of the most recent ``ask_user`` tool call.

    Messages are scanned newest to oldest, and within an assistant message the
    tool calls are scanned last to first. The first ``ask_user`` call found
    decides the result.

    Args:
        messages: Conversation messages; only ``assistant`` entries with
            ``tool_calls`` are inspected.

    Returns:
        The non-empty string entries of that call's ``options`` argument, or
        an empty list when the call has no ``options`` list or when no
        ``ask_user`` call exists.
    """
    for message in reversed(messages):
        if message.get("role") != "assistant":
            continue
        for tool_call in reversed(message.get("tool_calls") or []):
            if not isinstance(tool_call, dict) or _tool_call_name(tool_call) != "ask_user":
                continue
            options = _tool_call_arguments(tool_call).get("options")
            # The newest ask_user call is authoritative. If it has no options,
            # stop here instead of scanning further back, which would return
            # stale choices from an earlier, already answered question.
            if not isinstance(options, list):
                return []
            # Drop empty labels, as AskUserInterrupt does.
            return [option for option in options if isinstance(option, str) and option]
    return []
def ask_user_outbound(
content: str | None,
options: list[str],
channel: str,
) -> tuple[str | None, list[list[str]]]:
if not options:
return content, []
if channel in STRUCTURED_BUTTON_CHANNELS:
return content, [options]
option_text = "\n".join(f"{index}. {option}" for index, option in enumerate(options, 1))
return f"{content}\n\n{option_text}" if content else option_text, []

View File

@ -471,7 +471,7 @@ class SlackChannel(BaseChannel):
return preview.startswith(_HTML_DOWNLOAD_PREFIXES) return preview.startswith(_HTML_DOWNLOAD_PREFIXES)
async def _on_block_action(self, client: SocketModeClient, req: SocketModeRequest) -> None: async def _on_block_action(self, client: SocketModeClient, req: SocketModeRequest) -> None:
"""Handle button clicks from ask_user blocks.""" """Handle button clicks from inline action buttons."""
await client.send_socket_mode_response(SocketModeResponse(envelope_id=req.envelope_id)) await client.send_socket_mode_response(SocketModeResponse(envelope_id=req.envelope_id))
payload = req.payload or {} payload = req.payload or {}
actions = payload.get("actions") or [] actions = payload.get("actions") or []
@ -568,7 +568,7 @@ class SlackChannel(BaseChannel):
@staticmethod @staticmethod
def _build_button_blocks(text: str, buttons: list[list[str]]) -> list[dict[str, Any]]: def _build_button_blocks(text: str, buttons: list[list[str]]) -> list[dict[str, Any]]:
"""Build Slack Block Kit blocks with action buttons for ask_user choices.""" """Build Slack Block Kit blocks with action buttons."""
blocks: list[dict[str, Any]] = [ blocks: list[dict[str, Any]] = [
{"type": "section", "text": {"type": "mrkdwn", "text": text[:3000]}}, {"type": "section", "text": {"type": "mrkdwn", "text": text[:3000]}},
] ]
@ -579,7 +579,7 @@ class SlackChannel(BaseChannel):
"type": "button", "type": "button",
"text": {"type": "plain_text", "text": label[:75]}, "text": {"type": "plain_text", "text": label[:75]},
"value": label[:75], "value": label[:75],
"action_id": f"ask_user_{label[:50]}", "action_id": f"btn_{label[:50]}",
}) })
if elements: if elements:
blocks.append({"type": "actions", "elements": elements[:25]}) blocks.append({"type": "actions", "elements": elements[:25]})

View File

@ -55,14 +55,6 @@ def _normalize_config_path(path: str) -> str:
return _strip_trailing_slash(path) return _strip_trailing_slash(path)
def _append_buttons_as_text(text: str, buttons: list[list[str]]) -> str:
labels = [label for row in buttons for label in row if label]
if not labels:
return text
fallback = "\n".join(f"{index}. {label}" for index, label in enumerate(labels, 1))
return f"{text}\n\n{fallback}" if text else fallback
class WebSocketConfig(Base): class WebSocketConfig(Base):
"""WebSocket server channel configuration. """WebSocket server channel configuration.
@ -1468,16 +1460,11 @@ class WebSocketChannel(BaseChannel):
await self.send_session_updated(msg.chat_id) await self.send_session_updated(msg.chat_id)
return return
text = msg.content text = msg.content
if msg.buttons:
text = _append_buttons_as_text(text, msg.buttons)
payload: dict[str, Any] = { payload: dict[str, Any] = {
"event": "message", "event": "message",
"chat_id": msg.chat_id, "chat_id": msg.chat_id,
"text": text, "text": text,
} }
if msg.buttons:
payload["buttons"] = msg.buttons
payload["button_prompt"] = msg.content
if msg.media: if msg.media:
payload["media"] = msg.media payload["media"] = msg.media
urls: list[dict[str, str]] = [] urls: list[dict[str, str]] = []

View File

@ -11,7 +11,7 @@ Generate a personalized upgrade skill for this workspace.
Use `read_file` to check if `skills/update/SKILL.md` already exists in the workspace. Use `read_file` to check if `skills/update/SKILL.md` already exists in the workspace.
If it exists, use `ask_user` to ask: "An upgrade skill already exists. Reconfigure?" with options ["yes", "no"]. If no, stop here. If it exists, ask the user: "An upgrade skill already exists. Reconfigure?" Wait for the user's reply. If no, stop here.
## Step 2: Current Version and Install Clues ## Step 2: Current Version and Install Clues
@ -38,9 +38,9 @@ answer or confirmation, not from inference alone. If you cannot get a clear
answer, stop and ask the user to rerun this setup when they know how nanobot was answer, stop and ask the user to rerun this setup when they know how nanobot was
installed. installed.
Use `ask_user` for the questions below, one question per call. If `ask_user` is Ask the user the questions below, one at a time, in your response text. Wait for
not available or cannot collect the answer, ask in normal chat and stop without the user's reply before proceeding to the next question. If you cannot get a clear
writing the skill. answer, stop without writing the skill.
**Question 1 — Install method:** **Question 1 — Install method:**

View File

@ -1,241 +0,0 @@
import asyncio
from unittest.mock import MagicMock
import pytest
from nanobot.agent.loop import AgentLoop
from nanobot.agent.runner import AgentRunner, AgentRunSpec
from nanobot.agent.tools.ask import AskUserInterrupt, AskUserTool
from nanobot.agent.tools.base import Tool, tool_parameters
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.schema import tool_parameters_schema
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import GenerationSettings, LLMResponse, ToolCallRequest
def _make_provider(chat_with_retry):
    """Build a ``MagicMock`` provider backed by the given ``chat_with_retry`` coroutine.

    The streaming entry point delegates to the same coroutine after dropping
    the ``on_content_delta`` keyword argument.
    """

    async def _stream_via_chat(**kwargs):
        kwargs.pop("on_content_delta", None)
        return await chat_with_retry(**kwargs)

    mock_provider = MagicMock()
    mock_provider.generation = GenerationSettings()
    mock_provider.get_default_model.return_value = "test-model"
    mock_provider.chat_with_retry = chat_with_retry
    mock_provider.chat_stream_with_retry = _stream_via_chat
    return mock_provider
def test_ask_user_tool_schema_and_interrupt():
    """The tool exposes the expected schema and ``execute`` raises the interrupt."""
    ask_tool = AskUserTool()
    function_schema = ask_tool.to_schema()["function"]
    assert function_schema["name"] == "ask_user"

    parameters = function_schema["parameters"]
    assert "question" in parameters["required"]
    assert parameters["properties"]["options"]["type"] == "array"

    with pytest.raises(AskUserInterrupt) as raised:
        asyncio.run(ask_tool.execute("Continue?", options=["Yes", "No"]))

    interrupt = raised.value
    assert interrupt.question == "Continue?"
    assert interrupt.options == ["Yes", "No"]
@pytest.mark.asyncio
async def test_runner_pauses_on_ask_user_without_executing_later_tools():
    """A tool call listed after ``ask_user`` in the same response must not run."""

    @tool_parameters(tool_parameters_schema(required=[]))
    class LaterTool(Tool):
        # Flipped to True by execute(); the test asserts it stays False.
        called = False

        @property
        def name(self) -> str:
            return "later"

        @property
        def description(self) -> str:
            return "Should not run after ask_user pauses the turn."

        async def execute(self, **kwargs):
            self.called = True
            return "later result"

    async def chat_with_retry(**kwargs):
        requested_calls = [
            ToolCallRequest(
                id="call_ask",
                name="ask_user",
                arguments={"question": "Install this package?", "options": ["Yes", "No"]},
            ),
            ToolCallRequest(id="call_later", name="later", arguments={}),
        ]
        return LLMResponse(
            content="",
            finish_reason="tool_calls",
            tool_calls=requested_calls,
        )

    later_tool = LaterTool()
    registry = ToolRegistry()
    registry.register(AskUserTool())
    registry.register(later_tool)

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "continue"}],
        tools=registry,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=16_000,
        concurrent_tools=True,
    )
    outcome = await AgentRunner(_make_provider(chat_with_retry)).run(spec)

    assert outcome.stop_reason == "ask_user"
    assert outcome.final_content == "Install this package?"
    assert "ask_user" in outcome.tools_used
    assert later_tool.called is False

    # The transcript ends on the assistant message, trimmed to the ask_user call,
    # and contains no message named ask_user (i.e. no tool result for it yet).
    last_message = outcome.messages[-1]
    assert last_message["role"] == "assistant"
    recorded_names = [call["function"]["name"] for call in last_message["tool_calls"]]
    assert recorded_names == ["ask_user"]
    assert not any(entry.get("name") == "ask_user" for entry in outcome.messages)
@pytest.mark.asyncio
async def test_ask_user_text_fallback_resumes_with_next_message(tmp_path):
    """On the ``cli`` channel options become numbered text and the next message answers the call."""
    seen_messages: list[list[dict]] = []

    async def chat_with_retry(**kwargs):
        seen_messages.append(kwargs["messages"])
        if len(seen_messages) > 1:
            return LLMResponse(content="Skipped install.", usage={})
        # First model call: ask the user instead of answering.
        return LLMResponse(
            content="",
            finish_reason="tool_calls",
            tool_calls=[
                ToolCallRequest(
                    id="call_ask",
                    name="ask_user",
                    arguments={
                        "question": "Install the optional package?",
                        "options": ["Install", "Skip"],
                    },
                )
            ],
        )

    def is_skip_answer(message: dict) -> bool:
        # A tool result for ask_user whose content is the user's reply "Skip".
        return (
            message.get("role") == "tool"
            and message.get("name") == "ask_user"
            and message.get("content") == "Skip"
        )

    loop = AgentLoop(
        bus=MessageBus(),
        provider=_make_provider(chat_with_retry),
        workspace=tmp_path,
        model="test-model",
    )

    async def on_stream(delta: str) -> None:
        pass

    async def on_stream_end(**kwargs) -> None:
        pass

    question_reply = await loop._process_message(
        InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="set it up"),
        on_stream=on_stream,
        on_stream_end=on_stream_end,
    )

    assert question_reply is not None
    assert question_reply.content == "Install the optional package?\n\n1. Install\n2. Skip"
    assert question_reply.buttons == []
    assert "_streamed" not in question_reply.metadata

    session = loop.sessions.get_or_create("cli:direct")
    assert any(
        entry.get("role") == "assistant" and entry.get("tool_calls")
        for entry in session.messages
    )
    assert not any(
        entry.get("role") == "tool" and entry.get("name") == "ask_user"
        for entry in session.messages
    )

    answer_reply = await loop._process_message(
        InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="Skip")
    )

    assert answer_reply is not None
    assert answer_reply.content == "Skipped install."
    # The reply reaches the model as the ask_user tool result ...
    assert any(is_skip_answer(entry) for entry in seen_messages[-1])
    # ... and is stored in the session as that tool result, not as a user message.
    assert not any(
        entry.get("role") == "user" and entry.get("content") == "Skip"
        for entry in session.messages
    )
    assert any(is_skip_answer(entry) for entry in session.messages)
@pytest.mark.asyncio
async def test_ask_user_keeps_buttons_for_telegram(tmp_path):
    """On the ``telegram`` channel the options are returned as one button row."""

    async def chat_with_retry(**kwargs):
        ask_call = ToolCallRequest(
            id="call_ask",
            name="ask_user",
            arguments={
                "question": "Install the optional package?",
                "options": ["Install", "Skip"],
            },
        )
        return LLMResponse(content="", finish_reason="tool_calls", tool_calls=[ask_call])

    loop = AgentLoop(
        bus=MessageBus(),
        provider=_make_provider(chat_with_retry),
        workspace=tmp_path,
        model="test-model",
    )
    inbound = InboundMessage(
        channel="telegram", sender_id="user", chat_id="123", content="set it up"
    )

    response = await loop._process_message(inbound)

    assert response is not None
    # The question text stays clean; the choices travel separately as buttons.
    assert response.content == "Install the optional package?"
    assert response.buttons == [["Install", "Skip"]]
@pytest.mark.asyncio
async def test_ask_user_keeps_buttons_for_websocket(tmp_path):
    """On the ``websocket`` channel the options are returned as one button row."""

    async def chat_with_retry(**kwargs):
        ask_call = ToolCallRequest(
            id="call_ask",
            name="ask_user",
            arguments={
                "question": "Install the optional package?",
                "options": ["Install", "Skip"],
            },
        )
        return LLMResponse(content="", finish_reason="tool_calls", tool_calls=[ask_call])

    loop = AgentLoop(
        bus=MessageBus(),
        provider=_make_provider(chat_with_retry),
        workspace=tmp_path,
        model="test-model",
    )
    inbound = InboundMessage(
        channel="websocket", sender_id="user", chat_id="123", content="set it up"
    )

    response = await loop._process_message(inbound)

    assert response is not None
    # The question text stays clean; the choices travel separately as buttons.
    assert response.content == "Install the optional package?"
    assert response.buttons == [["Install", "Skip"]]

View File

@ -234,13 +234,13 @@ async def test_send_renders_buttons_on_last_message_chunk() -> None:
"type": "button", "type": "button",
"text": {"type": "plain_text", "text": "Yes"}, "text": {"type": "plain_text", "text": "Yes"},
"value": "Yes", "value": "Yes",
"action_id": "ask_user_Yes", "action_id": "btn_Yes",
}, },
{ {
"type": "button", "type": "button",
"text": {"type": "plain_text", "text": "No"}, "text": {"type": "plain_text", "text": "No"},
"value": "No", "value": "No",
"action_id": "ask_user_No", "action_id": "btn_No",
}, },
], ],
} }

View File

@ -224,11 +224,9 @@ async def test_send_delivers_json_message_with_media_and_reply() -> None:
payload = json.loads(mock_ws.send.call_args[0][0]) payload = json.loads(mock_ws.send.call_args[0][0])
assert payload["event"] == "message" assert payload["event"] == "message"
assert payload["chat_id"] == "chat-1" assert payload["chat_id"] == "chat-1"
assert payload["text"] == "hello\n\n1. Yes\n2. No" assert payload["text"] == "hello"
assert payload["button_prompt"] == "hello"
assert payload["reply_to"] == "m1" assert payload["reply_to"] == "m1"
assert payload["media"] == ["/tmp/a.png"] assert payload["media"] == ["/tmp/a.png"]
assert payload["buttons"] == [["Yes", "No"]]
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -405,7 +405,7 @@ def test_loader_registers_same_tools_as_old_hardcoded():
registered = loader.load(ctx, registry) registered = loader.load(ctx, registry)
expected = { expected = {
"ask_user", "read_file", "write_file", "edit_file", "list_dir", "read_file", "write_file", "edit_file", "list_dir",
"glob", "grep", "notebook_edit", "exec", "web_search", "web_fetch", "glob", "grep", "notebook_edit", "exec", "web_search", "web_fetch",
"message", "spawn", "cron", "message", "spawn", "cron",
} }

View File

@ -1,108 +0,0 @@
import { useCallback, useEffect, useRef, useState } from "react";
import { MessageSquareText } from "lucide-react";
import { Button } from "@/components/ui/button";
import { cn } from "@/lib/utils";
/** Props accepted by the AskUserPrompt component. */
interface AskUserPromptProps {
  /** Question text shown at the top of the prompt. */
  question: string;
  /** Rows of answer labels; the component flattens them and drops empty labels. */
  buttons: string[][];
  /** Called with the chosen option label or the trimmed free-text answer. */
  onAnswer: (answer: string) => void;
}
/**
 * Prompt card that shows a question, one button per answer option, and an
 * "Other..." toggle that reveals a free-text box.
 *
 * Renders nothing when no non-empty option is supplied. Clicking an option
 * calls `onAnswer` with its label. The free-text answer is submitted with the
 * Send button or with Enter (Shift+Enter and IME composition are excluded),
 * and is trimmed before being passed to `onAnswer`.
 */
export function AskUserPrompt({
  question,
  buttons,
  onAnswer,
}: AskUserPromptProps) {
  const [customOpen, setCustomOpen] = useState(false);
  const [custom, setCustom] = useState("");
  const inputRef = useRef<HTMLTextAreaElement>(null);
  // Flatten the button rows into a single list and drop empty labels.
  const options = buttons.flat().filter(Boolean);

  // Move focus into the free-text box as soon as it is opened.
  useEffect(() => {
    if (customOpen) {
      inputRef.current?.focus();
    }
  }, [customOpen]);

  // Submit the trimmed free-text answer, then reset and close the box.
  const submitCustom = useCallback(() => {
    const answer = custom.trim();
    if (!answer) return;
    onAnswer(answer);
    setCustom("");
    setCustomOpen(false);
  }, [custom, onAnswer]);

  if (options.length === 0) return null;

  return (
    <div
      className={cn(
        "mx-auto mb-2 w-full max-w-[49.5rem] rounded-[16px] border border-primary/30",
        "bg-card/95 p-3 shadow-sm backdrop-blur",
      )}
      role="group"
      aria-label="Question"
    >
      <div className="mb-2 flex items-start gap-2">
        <div className="mt-0.5 rounded-full bg-primary/10 p-1.5 text-primary">
          <MessageSquareText className="h-3.5 w-3.5" aria-hidden />
        </div>
        <p className="min-w-0 flex-1 text-[13.5px] font-medium leading-5 text-foreground">
          {question}
        </p>
      </div>
      <div className="grid gap-1.5 sm:grid-cols-2">
        {options.map((option, index) => (
          <Button
            // Labels are not guaranteed to be unique, so the position is part
            // of the key to keep sibling keys distinct.
            key={`${index}-${option}`}
            type="button"
            variant="outline"
            size="sm"
            onClick={() => onAnswer(option)}
            className="justify-start rounded-[10px] px-3 text-left"
          >
            <span className="truncate">{option}</span>
          </Button>
        ))}
        <Button
          type="button"
          variant="ghost"
          size="sm"
          onClick={() => setCustomOpen((open) => !open)}
          className="justify-start rounded-[10px] px-3 text-muted-foreground"
        >
          Other...
        </Button>
      </div>
      {customOpen ? (
        <div className="mt-2 flex gap-2">
          <textarea
            ref={inputRef}
            value={custom}
            onChange={(event) => setCustom(event.target.value)}
            onKeyDown={(event) => {
              // Plain Enter submits; Shift+Enter and IME composition do not.
              if (event.key === "Enter" && !event.shiftKey && !event.nativeEvent.isComposing) {
                event.preventDefault();
                submitCustom();
              }
            }}
            rows={1}
            placeholder="Type your own answer..."
            className={cn(
              "min-h-9 flex-1 resize-none rounded-[10px] border border-border/70 bg-background",
              "px-3 py-2 text-[13.5px] leading-5 outline-none placeholder:text-muted-foreground",
              "focus-visible:ring-1 focus-visible:ring-primary/40",
            )}
          />
          <Button type="button" size="sm" onClick={submitCustom} disabled={!custom.trim()}>
            Send
          </Button>
        </div>
      ) : null}
    </div>
  );
}

View File

@ -13,7 +13,6 @@ import {
} from "lucide-react"; } from "lucide-react";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import { AskUserPrompt } from "@/components/thread/AskUserPrompt";
import { ThreadComposer } from "@/components/thread/ThreadComposer"; import { ThreadComposer } from "@/components/thread/ThreadComposer";
import { ThreadHeader } from "@/components/thread/ThreadHeader"; import { ThreadHeader } from "@/components/thread/ThreadHeader";
import { StreamErrorNotice } from "@/components/thread/StreamErrorNotice"; import { StreamErrorNotice } from "@/components/thread/StreamErrorNotice";
@ -105,21 +104,6 @@ export function ThreadShell({
dismissStreamError, dismissStreamError,
} = useNanobotStream(chatId, initial, hasPendingToolCalls, onTurnEnd); } = useNanobotStream(chatId, initial, hasPendingToolCalls, onTurnEnd);
const showHeroComposer = messages.length === 0 && !loading; const showHeroComposer = messages.length === 0 && !loading;
// Derive the question that is still awaiting an answer, if any. Only the most
// recent user/assistant message matters: trace entries (and any other roles)
// are skipped, and the result is non-null only when that message is an
// assistant message carrying at least one non-empty row of buttons.
const pendingAsk = useMemo(() => {
  let cursor = messages.length;
  while (cursor-- > 0) {
    const candidate = messages[cursor];
    if (candidate.kind === "trace") continue;
    if (candidate.role !== "user" && candidate.role !== "assistant") continue;
    // First conversational message from the end decides the outcome.
    const offersChoices =
      candidate.role === "assistant" &&
      Boolean(candidate.buttons?.some((row) => row.length > 0));
    return offersChoices
      ? { question: candidate.content, buttons: candidate.buttons }
      : null;
  }
  return null;
}, [messages]);
useEffect(() => { useEffect(() => {
if (!chatId || loading) return; if (!chatId || loading) return;
@ -247,13 +231,6 @@ export function ThreadShell({
onDismiss={dismissStreamError} onDismiss={dismissStreamError}
/> />
) : null} ) : null}
{pendingAsk ? (
<AskUserPrompt
question={pendingAsk.question}
buttons={pendingAsk.buttons}
onAnswer={send}
/>
) : null}
{session ? ( {session ? (
<ThreadComposer <ThreadComposer
onSend={send} onSend={send}

View File

@ -230,7 +230,7 @@ export function useNanobotStream(
// the full turn (all tool calls + final text) is complete. // the full turn (all tool calls + final text) is complete.
setMessages((prev) => { setMessages((prev) => {
const filtered = activeId ? prev.filter((m) => m.id !== activeId) : prev; const filtered = activeId ? prev.filter((m) => m.id !== activeId) : prev;
const content = ev.buttons?.length ? (ev.button_prompt ?? ev.text) : ev.text; const content = ev.text;
return [ return [
...filtered, ...filtered,
{ {
@ -238,7 +238,6 @@ export function useNanobotStream(
role: "assistant", role: "assistant",
content, content,
createdAt: Date.now(), createdAt: Date.now(),
...(ev.buttons && ev.buttons.length > 0 ? { buttons: ev.buttons } : {}),
...(hasMedia ? { media } : {}), ...(hasMedia ? { media } : {}),
}, },
]; ];

View File

@ -44,8 +44,6 @@ export interface UIMessage {
images?: UIImage[]; images?: UIImage[];
/** Signed or local UI-renderable media attachments. */ /** Signed or local UI-renderable media attachments. */
media?: UIMediaAttachment[]; media?: UIMediaAttachment[];
/** Optional answer choices for a pending ask_user question. */
buttons?: string[][];
} }
export interface ChatSummary { export interface ChatSummary {
@ -141,9 +139,6 @@ export type InboundEvent =
reply_to?: string; reply_to?: string;
media?: string[]; media?: string[];
media_urls?: Array<{ url: string; name?: string }>; media_urls?: Array<{ url: string; name?: string }>;
buttons?: string[][];
/** Original prompt before the websocket text fallback appends buttons. */
button_prompt?: string;
/** Present when the frame is an agent breadcrumb (e.g. tool hint, /** Present when the frame is an agent breadcrumb (e.g. tool hint,
* generic progress line) rather than a conversational reply. */ * generic progress line) rather than a conversational reply. */
kind?: "tool_hint" | "progress"; kind?: "tool_hint" | "progress";

View File

@ -809,46 +809,4 @@ describe("ThreadShell", () => {
await waitFor(() => expect(screen.getByText("from chat b")).toBeInTheDocument()); await waitFor(() => expect(screen.getByText("from chat b")).toBeInTheDocument());
expect(screen.queryByText("from chat a")).not.toBeInTheDocument(); expect(screen.queryByText("from chat a")).not.toBeInTheDocument();
}); });
// Checks the button-prompt flow in ThreadShell: an incoming message frame that
// carries `buttons` shows a "Question" group, clicking an option sends that
// option's label as the next message, and the group then disappears.
it("renders ask_user options above the composer and sends selected answers", async () => {
const client = makeClient();
const onNewChat = vi.fn().mockResolvedValue("chat-a");
render(
wrap(
client,
<ThreadShell
session={session("chat-a")}
title="Chat chat-a"
onToggleSidebar={() => {}}
onGoHome={() => {}}
onNewChat={onNewChat}
/>,
),
);
// Simulate the server delivering a complete message with one row of two options.
await act(async () => {
client._emitChat("chat-a", {
event: "message",
chat_id: "chat-a",
text: "How should I continue?",
buttons: [["Short answer", "Detailed answer"]],
});
});
// The prompt is exposed as a group labelled "Question" containing the message text.
expect(screen.getByRole("group", { name: "Question" })).toHaveTextContent(
"How should I continue?",
);
// Choosing an option must send its label for this chat; the third argument is undefined.
fireEvent.click(screen.getByRole("button", { name: "Short answer" }));
expect(client.sendMessage).toHaveBeenCalledWith(
"chat-a",
"Short answer",
undefined,
);
// After answering, the question group is no longer rendered.
await waitFor(() => {
expect(screen.queryByRole("group", { name: "Question" })).not.toBeInTheDocument();
});
});
}); });

View File

@ -217,29 +217,6 @@ describe("useNanobotStream", () => {
expect(result.current.messages[0].content).toBe("long task"); expect(result.current.messages[0].content).toBe("long task");
}); });
// Checks that a complete "message" event carrying `buttons` yields a single
// message whose content is the `button_prompt` value (not the longer `text`
// with the numbered options appended) and whose `buttons` are kept as sent.
it("keeps assistant buttons on complete messages", () => {
const fake = fakeClient();
const { result } = renderHook(() => useNanobotStream("chat-q", EMPTY_MESSAGES), {
wrapper: wrap(fake.client),
});
// Emit one frame where `text` includes the numbered fallback list and
// `button_prompt` holds only the question.
act(() => {
fake.emit("chat-q", {
event: "message",
chat_id: "chat-q",
text: "How should I continue?\n\n1. Short answer\n2. Detailed answer",
button_prompt: "How should I continue?",
buttons: [["Short answer", "Detailed answer"]],
});
});
expect(result.current.messages).toHaveLength(1);
// Content matches `button_prompt`, so the options are not repeated in the text.
expect(result.current.messages[0].content).toBe("How should I continue?");
expect(result.current.messages[0].buttons).toEqual([
["Short answer", "Detailed answer"],
]);
});
it("keeps streaming alive across stream_end and completes on turn_end", () => { it("keeps streaming alive across stream_end and completes on turn_end", () => {
const fake = fakeClient(); const fake = fakeClient();
const onTurnEnd = vi.fn(); const onTurnEnd = vi.fn();