mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-10 03:36:02 +00:00
feat(webui): render ask_user choices
Made-with: Cursor
This commit is contained in:
parent
403ce23d22
commit
a58d9fd357
@ -6,7 +6,7 @@ from typing import Any
|
|||||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||||
from nanobot.agent.tools.schema import ArraySchema, StringSchema, tool_parameters_schema
|
from nanobot.agent.tools.schema import ArraySchema, StringSchema, tool_parameters_schema
|
||||||
|
|
||||||
BUTTON_CHANNELS = frozenset({"telegram"})
|
STRUCTURED_BUTTON_CHANNELS = frozenset({"telegram", "websocket"})
|
||||||
|
|
||||||
|
|
||||||
class AskUserInterrupt(BaseException):
|
class AskUserInterrupt(BaseException):
|
||||||
@ -130,7 +130,7 @@ def ask_user_outbound(
|
|||||||
) -> tuple[str | None, list[list[str]]]:
|
) -> tuple[str | None, list[list[str]]]:
|
||||||
if not options:
|
if not options:
|
||||||
return content, []
|
return content, []
|
||||||
if channel in BUTTON_CHANNELS:
|
if channel in STRUCTURED_BUTTON_CHANNELS:
|
||||||
return content, [options]
|
return content, [options]
|
||||||
option_text = "\n".join(f"{index}. {option}" for index, option in enumerate(options, 1))
|
option_text = "\n".join(f"{index}. {option}" for index, option in enumerate(options, 1))
|
||||||
return f"{content}\n\n{option_text}" if content else option_text, []
|
return f"{content}\n\n{option_text}" if content else option_text, []
|
||||||
|
|||||||
@ -54,6 +54,14 @@ def _normalize_config_path(path: str) -> str:
|
|||||||
return _strip_trailing_slash(path)
|
return _strip_trailing_slash(path)
|
||||||
|
|
||||||
|
|
||||||
|
def _append_buttons_as_text(text: str, buttons: list[list[str]]) -> str:
|
||||||
|
labels = [label for row in buttons for label in row if label]
|
||||||
|
if not labels:
|
||||||
|
return text
|
||||||
|
fallback = "\n".join(f"{index}. {label}" for index, label in enumerate(labels, 1))
|
||||||
|
return f"{text}\n\n{fallback}" if text else fallback
|
||||||
|
|
||||||
|
|
||||||
class WebSocketConfig(Base):
|
class WebSocketConfig(Base):
|
||||||
"""WebSocket server channel configuration.
|
"""WebSocket server channel configuration.
|
||||||
|
|
||||||
@ -1146,11 +1154,17 @@ class WebSocketChannel(BaseChannel):
|
|||||||
if not conns:
|
if not conns:
|
||||||
logger.warning("websocket: no active subscribers for chat_id={}", msg.chat_id)
|
logger.warning("websocket: no active subscribers for chat_id={}", msg.chat_id)
|
||||||
return
|
return
|
||||||
|
text = msg.content
|
||||||
|
if msg.buttons:
|
||||||
|
text = _append_buttons_as_text(text, msg.buttons)
|
||||||
payload: dict[str, Any] = {
|
payload: dict[str, Any] = {
|
||||||
"event": "message",
|
"event": "message",
|
||||||
"chat_id": msg.chat_id,
|
"chat_id": msg.chat_id,
|
||||||
"text": msg.content,
|
"text": text,
|
||||||
}
|
}
|
||||||
|
if msg.buttons:
|
||||||
|
payload["buttons"] = msg.buttons
|
||||||
|
payload["button_prompt"] = msg.content
|
||||||
if msg.media:
|
if msg.media:
|
||||||
payload["media"] = msg.media
|
payload["media"] = msg.media
|
||||||
urls: list[dict[str, str]] = []
|
urls: list[dict[str, str]] = []
|
||||||
|
|||||||
@ -205,3 +205,37 @@ async def test_ask_user_keeps_buttons_for_telegram(tmp_path):
|
|||||||
assert response is not None
|
assert response is not None
|
||||||
assert response.content == "Install the optional package?"
|
assert response.content == "Install the optional package?"
|
||||||
assert response.buttons == [["Install", "Skip"]]
|
assert response.buttons == [["Install", "Skip"]]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ask_user_keeps_buttons_for_websocket(tmp_path):
|
||||||
|
async def chat_with_retry(**kwargs):
|
||||||
|
return LLMResponse(
|
||||||
|
content="",
|
||||||
|
finish_reason="tool_calls",
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="call_ask",
|
||||||
|
name="ask_user",
|
||||||
|
arguments={
|
||||||
|
"question": "Install the optional package?",
|
||||||
|
"options": ["Install", "Skip"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=MessageBus(),
|
||||||
|
provider=_make_provider(chat_with_retry),
|
||||||
|
workspace=tmp_path,
|
||||||
|
model="test-model",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await loop._process_message(
|
||||||
|
InboundMessage(channel="websocket", sender_id="user", chat_id="123", content="set it up")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response is not None
|
||||||
|
assert response.content == "Install the optional package?"
|
||||||
|
assert response.buttons == [["Install", "Skip"]]
|
||||||
|
|||||||
@ -178,6 +178,7 @@ async def test_send_delivers_json_message_with_media_and_reply() -> None:
|
|||||||
content="hello",
|
content="hello",
|
||||||
reply_to="m1",
|
reply_to="m1",
|
||||||
media=["/tmp/a.png"],
|
media=["/tmp/a.png"],
|
||||||
|
buttons=[["Yes", "No"]],
|
||||||
)
|
)
|
||||||
await channel.send(msg)
|
await channel.send(msg)
|
||||||
|
|
||||||
@ -185,9 +186,11 @@ async def test_send_delivers_json_message_with_media_and_reply() -> None:
|
|||||||
payload = json.loads(mock_ws.send.call_args[0][0])
|
payload = json.loads(mock_ws.send.call_args[0][0])
|
||||||
assert payload["event"] == "message"
|
assert payload["event"] == "message"
|
||||||
assert payload["chat_id"] == "chat-1"
|
assert payload["chat_id"] == "chat-1"
|
||||||
assert payload["text"] == "hello"
|
assert payload["text"] == "hello\n\n1. Yes\n2. No"
|
||||||
|
assert payload["button_prompt"] == "hello"
|
||||||
assert payload["reply_to"] == "m1"
|
assert payload["reply_to"] == "m1"
|
||||||
assert payload["media"] == ["/tmp/a.png"]
|
assert payload["media"] == ["/tmp/a.png"]
|
||||||
|
assert payload["buttons"] == [["Yes", "No"]]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
108
webui/src/components/thread/AskUserPrompt.tsx
Normal file
108
webui/src/components/thread/AskUserPrompt.tsx
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
import { useCallback, useEffect, useRef, useState } from "react";
|
||||||
|
import { MessageSquareText } from "lucide-react";
|
||||||
|
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
|
interface AskUserPromptProps {
|
||||||
|
question: string;
|
||||||
|
buttons: string[][];
|
||||||
|
onAnswer: (answer: string) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function AskUserPrompt({
|
||||||
|
question,
|
||||||
|
buttons,
|
||||||
|
onAnswer,
|
||||||
|
}: AskUserPromptProps) {
|
||||||
|
const [customOpen, setCustomOpen] = useState(false);
|
||||||
|
const [custom, setCustom] = useState("");
|
||||||
|
const inputRef = useRef<HTMLTextAreaElement>(null);
|
||||||
|
const options = buttons.flat().filter(Boolean);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (customOpen) {
|
||||||
|
inputRef.current?.focus();
|
||||||
|
}
|
||||||
|
}, [customOpen]);
|
||||||
|
|
||||||
|
const submitCustom = useCallback(() => {
|
||||||
|
const answer = custom.trim();
|
||||||
|
if (!answer) return;
|
||||||
|
onAnswer(answer);
|
||||||
|
setCustom("");
|
||||||
|
setCustomOpen(false);
|
||||||
|
}, [custom, onAnswer]);
|
||||||
|
|
||||||
|
if (options.length === 0) return null;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
"mx-auto mb-2 w-full max-w-[49.5rem] rounded-[16px] border border-primary/30",
|
||||||
|
"bg-card/95 p-3 shadow-sm backdrop-blur",
|
||||||
|
)}
|
||||||
|
role="group"
|
||||||
|
aria-label="Question"
|
||||||
|
>
|
||||||
|
<div className="mb-2 flex items-start gap-2">
|
||||||
|
<div className="mt-0.5 rounded-full bg-primary/10 p-1.5 text-primary">
|
||||||
|
<MessageSquareText className="h-3.5 w-3.5" aria-hidden />
|
||||||
|
</div>
|
||||||
|
<p className="min-w-0 flex-1 text-sm font-medium leading-5 text-foreground">
|
||||||
|
{question}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="grid gap-1.5 sm:grid-cols-2">
|
||||||
|
{options.map((option) => (
|
||||||
|
<Button
|
||||||
|
key={option}
|
||||||
|
type="button"
|
||||||
|
variant="outline"
|
||||||
|
size="sm"
|
||||||
|
onClick={() => onAnswer(option)}
|
||||||
|
className="justify-start rounded-[10px] px-3 text-left"
|
||||||
|
>
|
||||||
|
<span className="truncate">{option}</span>
|
||||||
|
</Button>
|
||||||
|
))}
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
onClick={() => setCustomOpen((open) => !open)}
|
||||||
|
className="justify-start rounded-[10px] px-3 text-muted-foreground"
|
||||||
|
>
|
||||||
|
Other...
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{customOpen ? (
|
||||||
|
<div className="mt-2 flex gap-2">
|
||||||
|
<textarea
|
||||||
|
ref={inputRef}
|
||||||
|
value={custom}
|
||||||
|
onChange={(event) => setCustom(event.target.value)}
|
||||||
|
onKeyDown={(event) => {
|
||||||
|
if (event.key === "Enter" && !event.shiftKey && !event.nativeEvent.isComposing) {
|
||||||
|
event.preventDefault();
|
||||||
|
submitCustom();
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
rows={1}
|
||||||
|
placeholder="Type your own answer..."
|
||||||
|
className={cn(
|
||||||
|
"min-h-9 flex-1 resize-none rounded-[10px] border border-border/70 bg-background",
|
||||||
|
"px-3 py-2 text-sm leading-5 outline-none placeholder:text-muted-foreground",
|
||||||
|
"focus-visible:ring-1 focus-visible:ring-primary/40",
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
<Button type="button" size="sm" onClick={submitCustom} disabled={!custom.trim()}>
|
||||||
|
Send
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
) : null}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@ -1,6 +1,7 @@
|
|||||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||||
import { useTranslation } from "react-i18next";
|
import { useTranslation } from "react-i18next";
|
||||||
|
|
||||||
|
import { AskUserPrompt } from "@/components/thread/AskUserPrompt";
|
||||||
import { ThreadComposer } from "@/components/thread/ThreadComposer";
|
import { ThreadComposer } from "@/components/thread/ThreadComposer";
|
||||||
import { ThreadHeader } from "@/components/thread/ThreadHeader";
|
import { ThreadHeader } from "@/components/thread/ThreadHeader";
|
||||||
import { StreamErrorNotice } from "@/components/thread/StreamErrorNotice";
|
import { StreamErrorNotice } from "@/components/thread/StreamErrorNotice";
|
||||||
@ -57,6 +58,21 @@ export function ThreadShell({
|
|||||||
dismissStreamError,
|
dismissStreamError,
|
||||||
} = useNanobotStream(chatId, initial);
|
} = useNanobotStream(chatId, initial);
|
||||||
const showHeroComposer = messages.length === 0 && !loading;
|
const showHeroComposer = messages.length === 0 && !loading;
|
||||||
|
const pendingAsk = useMemo(() => {
|
||||||
|
for (let index = messages.length - 1; index >= 0; index -= 1) {
|
||||||
|
const message = messages[index];
|
||||||
|
if (message.kind === "trace") continue;
|
||||||
|
if (message.role === "user") return null;
|
||||||
|
if (message.role === "assistant" && message.buttons?.some((row) => row.length > 0)) {
|
||||||
|
return {
|
||||||
|
question: message.content,
|
||||||
|
buttons: message.buttons,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
if (message.role === "assistant") return null;
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}, [messages]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!chatId || loading) return;
|
if (!chatId || loading) return;
|
||||||
@ -152,6 +168,13 @@ export function ThreadShell({
|
|||||||
onDismiss={dismissStreamError}
|
onDismiss={dismissStreamError}
|
||||||
/>
|
/>
|
||||||
) : null}
|
) : null}
|
||||||
|
{pendingAsk ? (
|
||||||
|
<AskUserPrompt
|
||||||
|
question={pendingAsk.question}
|
||||||
|
buttons={pendingAsk.buttons}
|
||||||
|
onAnswer={send}
|
||||||
|
/>
|
||||||
|
) : null}
|
||||||
{session ? (
|
{session ? (
|
||||||
<ThreadComposer
|
<ThreadComposer
|
||||||
onSend={send}
|
onSend={send}
|
||||||
|
|||||||
@ -160,13 +160,15 @@ export function useNanobotStream(
|
|||||||
setIsStreaming(false);
|
setIsStreaming(false);
|
||||||
setMessages((prev) => {
|
setMessages((prev) => {
|
||||||
const filtered = activeId ? prev.filter((m) => m.id !== activeId) : prev;
|
const filtered = activeId ? prev.filter((m) => m.id !== activeId) : prev;
|
||||||
|
const content = ev.buttons?.length ? (ev.button_prompt ?? ev.text) : ev.text;
|
||||||
return [
|
return [
|
||||||
...filtered,
|
...filtered,
|
||||||
{
|
{
|
||||||
id: crypto.randomUUID(),
|
id: crypto.randomUUID(),
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: ev.text,
|
content,
|
||||||
createdAt: Date.now(),
|
createdAt: Date.now(),
|
||||||
|
...(ev.buttons && ev.buttons.length > 0 ? { buttons: ev.buttons } : {}),
|
||||||
...(media && media.length > 0 ? { media } : {}),
|
...(media && media.length > 0 ? { media } : {}),
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|||||||
@ -44,6 +44,8 @@ export interface UIMessage {
|
|||||||
images?: UIImage[];
|
images?: UIImage[];
|
||||||
/** Signed or local UI-renderable media attachments. */
|
/** Signed or local UI-renderable media attachments. */
|
||||||
media?: UIMediaAttachment[];
|
media?: UIMediaAttachment[];
|
||||||
|
/** Optional answer choices for a pending ask_user question. */
|
||||||
|
buttons?: string[][];
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ChatSummary {
|
export interface ChatSummary {
|
||||||
@ -82,6 +84,9 @@ export type InboundEvent =
|
|||||||
reply_to?: string;
|
reply_to?: string;
|
||||||
media?: string[];
|
media?: string[];
|
||||||
media_urls?: Array<{ url: string; name?: string }>;
|
media_urls?: Array<{ url: string; name?: string }>;
|
||||||
|
buttons?: string[][];
|
||||||
|
/** Original prompt before the websocket text fallback appends buttons. */
|
||||||
|
button_prompt?: string;
|
||||||
/** Present when the frame is an agent breadcrumb (e.g. tool hint,
|
/** Present when the frame is an agent breadcrumb (e.g. tool hint,
|
||||||
* generic progress line) rather than a conversational reply. */
|
* generic progress line) rather than a conversational reply. */
|
||||||
kind?: "tool_hint" | "progress";
|
kind?: "tool_hint" | "progress";
|
||||||
|
|||||||
@ -7,11 +7,22 @@ import { ClientProvider } from "@/providers/ClientProvider";
|
|||||||
|
|
||||||
function makeClient() {
|
function makeClient() {
|
||||||
const errorHandlers = new Set<(err: { kind: string }) => void>();
|
const errorHandlers = new Set<(err: { kind: string }) => void>();
|
||||||
|
const chatHandlers = new Map<string, Set<(ev: import("@/lib/types").InboundEvent) => void>>();
|
||||||
return {
|
return {
|
||||||
status: "open" as const,
|
status: "open" as const,
|
||||||
defaultChatId: null as string | null,
|
defaultChatId: null as string | null,
|
||||||
onStatus: () => () => {},
|
onStatus: () => () => {},
|
||||||
onChat: () => () => {},
|
onChat: (chatId: string, handler: (ev: import("@/lib/types").InboundEvent) => void) => {
|
||||||
|
let handlers = chatHandlers.get(chatId);
|
||||||
|
if (!handlers) {
|
||||||
|
handlers = new Set();
|
||||||
|
chatHandlers.set(chatId, handlers);
|
||||||
|
}
|
||||||
|
handlers.add(handler);
|
||||||
|
return () => {
|
||||||
|
handlers?.delete(handler);
|
||||||
|
};
|
||||||
|
},
|
||||||
onError: (handler: (err: { kind: string }) => void) => {
|
onError: (handler: (err: { kind: string }) => void) => {
|
||||||
errorHandlers.add(handler);
|
errorHandlers.add(handler);
|
||||||
return () => {
|
return () => {
|
||||||
@ -21,6 +32,9 @@ function makeClient() {
|
|||||||
_emitError(err: { kind: string }) {
|
_emitError(err: { kind: string }) {
|
||||||
for (const h of errorHandlers) h(err);
|
for (const h of errorHandlers) h(err);
|
||||||
},
|
},
|
||||||
|
_emitChat(chatId: string, ev: import("@/lib/types").InboundEvent) {
|
||||||
|
for (const h of chatHandlers.get(chatId) ?? []) h(ev);
|
||||||
|
},
|
||||||
sendMessage: vi.fn(),
|
sendMessage: vi.fn(),
|
||||||
newChat: vi.fn(),
|
newChat: vi.fn(),
|
||||||
attach: vi.fn(),
|
attach: vi.fn(),
|
||||||
@ -411,4 +425,46 @@ describe("ThreadShell", () => {
|
|||||||
await waitFor(() => expect(screen.getByText("from chat b")).toBeInTheDocument());
|
await waitFor(() => expect(screen.getByText("from chat b")).toBeInTheDocument());
|
||||||
expect(screen.queryByText("from chat a")).not.toBeInTheDocument();
|
expect(screen.queryByText("from chat a")).not.toBeInTheDocument();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("renders ask_user options above the composer and sends selected answers", async () => {
|
||||||
|
const client = makeClient();
|
||||||
|
const onNewChat = vi.fn().mockResolvedValue("chat-a");
|
||||||
|
|
||||||
|
render(
|
||||||
|
wrap(
|
||||||
|
client,
|
||||||
|
<ThreadShell
|
||||||
|
session={session("chat-a")}
|
||||||
|
title="Chat chat-a"
|
||||||
|
onToggleSidebar={() => {}}
|
||||||
|
onGoHome={() => {}}
|
||||||
|
onNewChat={onNewChat}
|
||||||
|
/>,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
client._emitChat("chat-a", {
|
||||||
|
event: "message",
|
||||||
|
chat_id: "chat-a",
|
||||||
|
text: "How should I continue?",
|
||||||
|
buttons: [["Short answer", "Detailed answer"]],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(screen.getByRole("group", { name: "Question" })).toHaveTextContent(
|
||||||
|
"How should I continue?",
|
||||||
|
);
|
||||||
|
|
||||||
|
fireEvent.click(screen.getByRole("button", { name: "Short answer" }));
|
||||||
|
|
||||||
|
expect(client.sendMessage).toHaveBeenCalledWith(
|
||||||
|
"chat-a",
|
||||||
|
"Short answer",
|
||||||
|
undefined,
|
||||||
|
);
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.queryByRole("group", { name: "Question" })).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@ -113,4 +113,27 @@ describe("useNanobotStream", () => {
|
|||||||
{ kind: "video", url: "/api/media/sig/payload", name: "demo.mp4" },
|
{ kind: "video", url: "/api/media/sig/payload", name: "demo.mp4" },
|
||||||
]);
|
]);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("keeps assistant buttons on complete messages", () => {
|
||||||
|
const fake = fakeClient();
|
||||||
|
const { result } = renderHook(() => useNanobotStream("chat-q", []), {
|
||||||
|
wrapper: wrap(fake.client),
|
||||||
|
});
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
fake.emit("chat-q", {
|
||||||
|
event: "message",
|
||||||
|
chat_id: "chat-q",
|
||||||
|
text: "How should I continue?\n\n1. Short answer\n2. Detailed answer",
|
||||||
|
button_prompt: "How should I continue?",
|
||||||
|
buttons: [["Short answer", "Detailed answer"]],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.current.messages).toHaveLength(1);
|
||||||
|
expect(result.current.messages[0].content).toBe("How should I continue?");
|
||||||
|
expect(result.current.messages[0].buttons).toEqual([
|
||||||
|
["Short answer", "Detailed answer"],
|
||||||
|
]);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user