mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
fix(webui): broadcast runtime model updates
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
c92345bbb1
commit
bcc4b97183
@ -406,7 +406,7 @@ class AgentLoop:
|
||||
self._model_preset_snapshot_builder = model_preset_snapshot_builder
|
||||
self._active_preset: str | None = None
|
||||
if model_preset:
|
||||
self.set_model_preset(model_preset)
|
||||
self.set_model_preset(model_preset, notify=False)
|
||||
self._register_default_tools()
|
||||
self._runtime_vars: dict[str, Any] = {}
|
||||
self._current_iteration: int = 0
|
||||
@ -473,7 +473,26 @@ class AgentLoop:
|
||||
"""Keep subagent runtime limits aligned with mutable loop settings."""
|
||||
self.subagents.max_iterations = self.max_iterations
|
||||
|
||||
def _apply_provider_snapshot(self, snapshot: ProviderSnapshot) -> None:
|
||||
def _publish_runtime_model_updated(self, model_preset: str | None = None) -> None:
|
||||
"""Notify WebUI clients that the effective runtime model changed."""
|
||||
self.bus.outbound.put_nowait(OutboundMessage(
|
||||
channel="websocket",
|
||||
chat_id="*",
|
||||
content="",
|
||||
metadata={
|
||||
"_runtime_model_updated": True,
|
||||
"model": self.model,
|
||||
"model_preset": model_preset if model_preset is not None else self.model_preset,
|
||||
},
|
||||
))
|
||||
|
||||
def _apply_provider_snapshot(
|
||||
self,
|
||||
snapshot: ProviderSnapshot,
|
||||
*,
|
||||
notify: bool = True,
|
||||
model_preset: str | None = None,
|
||||
) -> None:
|
||||
"""Swap model/provider for future turns without disturbing an active one."""
|
||||
provider = snapshot.provider
|
||||
model = snapshot.model
|
||||
@ -487,6 +506,8 @@ class AgentLoop:
|
||||
self.consolidator.set_provider(provider, model, context_window_tokens)
|
||||
self.dream.set_provider(provider, model)
|
||||
self._provider_signature = snapshot.signature
|
||||
if notify:
|
||||
self._publish_runtime_model_updated(model_preset)
|
||||
logger.info("Runtime model switched for next turn: {} -> {}", old_model, model)
|
||||
|
||||
def _refresh_provider_snapshot(self) -> None:
|
||||
@ -556,7 +577,7 @@ class AgentLoop:
|
||||
),
|
||||
)
|
||||
|
||||
def set_model_preset(self, name: str | None) -> None:
|
||||
def set_model_preset(self, name: str | None, *, notify: bool = True) -> None:
|
||||
"""Resolve a preset by name and apply all runtime model dependents."""
|
||||
if not isinstance(name, str) or not name.strip():
|
||||
raise ValueError("model_preset must be a non-empty string")
|
||||
@ -564,7 +585,7 @@ class AgentLoop:
|
||||
if name not in self.model_presets:
|
||||
raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(self.model_presets) or '(none)'}")
|
||||
snapshot = self._build_model_preset_snapshot(name)
|
||||
self._apply_provider_snapshot(snapshot)
|
||||
self._apply_provider_snapshot(snapshot, notify=notify, model_preset=name)
|
||||
self._active_preset = name
|
||||
|
||||
def _register_default_tools(self) -> None:
|
||||
|
||||
@ -292,6 +292,13 @@ class ChannelManager:
|
||||
if msg.metadata.get("_retry_wait"):
|
||||
continue
|
||||
|
||||
if (
|
||||
msg.metadata.get("_runtime_model_updated")
|
||||
and msg.channel == "websocket"
|
||||
and "websocket" not in self.channels
|
||||
):
|
||||
continue
|
||||
|
||||
# Coalesce consecutive _stream_delta messages for the same (channel, chat_id)
|
||||
# to reduce API calls and improve streaming latency
|
||||
if msg.metadata.get("_stream_delta") and not msg.metadata.get("_stream_end"):
|
||||
|
||||
@ -156,11 +156,11 @@ def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
|
||||
|
||||
|
||||
def _read_webui_model_name() -> str | None:
|
||||
"""Return the configured default model for readonly webui display."""
|
||||
"""Return the resolved startup model for readonly WebUI display."""
|
||||
try:
|
||||
from nanobot.config.loader import load_config
|
||||
|
||||
model = load_config().agents.defaults.model.strip()
|
||||
model = load_config().resolve_preset().model.strip()
|
||||
return model or None
|
||||
except Exception as e:
|
||||
logger.debug("webui bootstrap could not load model name: {}", e)
|
||||
@ -1423,6 +1423,13 @@ class WebSocketChannel(BaseChannel):
|
||||
raise
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
if msg.metadata.get("_runtime_model_updated"):
|
||||
await self.send_runtime_model_updated(
|
||||
model_name=msg.metadata.get("model"),
|
||||
model_preset=msg.metadata.get("model_preset"),
|
||||
)
|
||||
return
|
||||
|
||||
# Snapshot the subscriber set so ConnectionClosed cleanups mid-iteration are safe.
|
||||
conns = list(self._subs.get(msg.chat_id, ()))
|
||||
if not conns:
|
||||
@ -1471,9 +1478,6 @@ class WebSocketChannel(BaseChannel):
|
||||
payload["kind"] = "tool_hint"
|
||||
elif msg.metadata.get("_progress"):
|
||||
payload["kind"] = "progress"
|
||||
webui_model_name = msg.metadata.get("_webui_model_name")
|
||||
if isinstance(webui_model_name, str) and webui_model_name.strip():
|
||||
payload["model_name"] = webui_model_name.strip()
|
||||
raw = json.dumps(payload, ensure_ascii=False)
|
||||
for connection in conns:
|
||||
await self._safe_send_to(connection, raw, label=" ")
|
||||
@ -1521,3 +1525,23 @@ class WebSocketChannel(BaseChannel):
|
||||
raw = json.dumps(body, ensure_ascii=False)
|
||||
for connection in conns:
|
||||
await self._safe_send_to(connection, raw, label=" session_updated ")
|
||||
|
||||
async def send_runtime_model_updated(
|
||||
self,
|
||||
*,
|
||||
model_name: Any,
|
||||
model_preset: Any = None,
|
||||
) -> None:
|
||||
"""Broadcast runtime model changes to all active WebUI clients."""
|
||||
conns = list(self._conn_chats)
|
||||
if not conns or not isinstance(model_name, str) or not model_name.strip():
|
||||
return
|
||||
body: dict[str, Any] = {
|
||||
"event": "runtime_model_updated",
|
||||
"model_name": model_name.strip(),
|
||||
}
|
||||
if isinstance(model_preset, str) and model_preset.strip():
|
||||
body["model_preset"] = model_preset.strip()
|
||||
raw = json.dumps(body, ensure_ascii=False)
|
||||
for connection in conns:
|
||||
await self._safe_send_to(connection, raw, label=" runtime_model_updated ")
|
||||
|
||||
@ -225,7 +225,7 @@ async def cmd_model(ctx: CommandContext) -> OutboundMessage:
|
||||
channel=ctx.msg.channel,
|
||||
chat_id=ctx.msg.chat_id,
|
||||
content=_model_command_status(loop),
|
||||
metadata={**metadata, "_webui_model_name": loop.model},
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
parts = args.split()
|
||||
@ -264,7 +264,7 @@ async def cmd_model(ctx: CommandContext) -> OutboundMessage:
|
||||
channel=ctx.msg.channel,
|
||||
chat_id=ctx.msg.chat_id,
|
||||
content="\n".join(lines),
|
||||
metadata={**metadata, "_webui_model_name": loop.model},
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -64,6 +64,30 @@ def test_model_preset_setter_updates_state(tmp_path) -> None:
|
||||
assert loop.dream.model == "openai/gpt-4.1"
|
||||
|
||||
|
||||
def test_model_preset_setter_publishes_runtime_model_event(tmp_path) -> None:
|
||||
bus = MessageBus()
|
||||
loop = AgentLoop(
|
||||
bus=bus,
|
||||
provider=_provider("base-model", max_tokens=123),
|
||||
workspace=tmp_path,
|
||||
model="base-model",
|
||||
context_window_tokens=1000,
|
||||
model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")},
|
||||
)
|
||||
|
||||
loop.set_model_preset("fast")
|
||||
|
||||
event = bus.outbound.get_nowait()
|
||||
assert event.channel == "websocket"
|
||||
assert event.chat_id == "*"
|
||||
assert event.content == ""
|
||||
assert event.metadata == {
|
||||
"_runtime_model_updated": True,
|
||||
"model": "openai/gpt-4.1",
|
||||
"model_preset": "fast",
|
||||
}
|
||||
|
||||
|
||||
def test_model_preset_setter_replaces_provider_from_snapshot(tmp_path) -> None:
|
||||
old_provider = _provider("base-model", max_tokens=123)
|
||||
new_provider = _provider("anthropic/claude-opus-4-5", max_tokens=2048)
|
||||
|
||||
@ -230,7 +230,7 @@ async def test_send_delivers_json_message_with_media_and_reply() -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_includes_webui_model_name_metadata() -> None:
|
||||
async def test_send_broadcasts_runtime_model_updates() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
|
||||
mock_ws = AsyncMock()
|
||||
@ -239,14 +239,20 @@ async def test_send_includes_webui_model_name_metadata() -> None:
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="websocket",
|
||||
chat_id="chat-1",
|
||||
content="switched",
|
||||
metadata={"_webui_model_name": "openai/gpt-4.1"},
|
||||
chat_id="*",
|
||||
content="",
|
||||
metadata={
|
||||
"_runtime_model_updated": True,
|
||||
"model": "openai/gpt-4.1",
|
||||
"model_preset": "fast",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
payload = json.loads(mock_ws.send.call_args[0][0])
|
||||
assert payload["event"] == "runtime_model_updated"
|
||||
assert payload["model_name"] == "openai/gpt-4.1"
|
||||
assert payload["model_preset"] == "fast"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -64,8 +64,7 @@ async def test_model_command_lists_current_and_available_presets(tmp_path) -> No
|
||||
assert "Active preset: `(none)`" in out.content
|
||||
assert "`default`" in out.content
|
||||
assert "`fast`" in out.content
|
||||
assert out.metadata["render_as"] == "text"
|
||||
assert out.metadata["_webui_model_name"] == "base-model"
|
||||
assert out.metadata == {"render_as": "text"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -76,7 +75,6 @@ async def test_model_command_switches_preset(tmp_path) -> None:
|
||||
|
||||
assert "Switched model preset to `fast`." in out.content
|
||||
assert "Model: `openai/gpt-4.1`" in out.content
|
||||
assert out.metadata["_webui_model_name"] == "openai/gpt-4.1"
|
||||
assert loop.model_preset == "fast"
|
||||
assert loop.model == "openai/gpt-4.1"
|
||||
assert loop.subagents.model == "openai/gpt-4.1"
|
||||
@ -92,7 +90,6 @@ async def test_model_command_switches_back_to_default(tmp_path) -> None:
|
||||
out = await cmd_model(_ctx(loop, "/model default", args="default"))
|
||||
|
||||
assert "Switched model preset to `default`." in out.content
|
||||
assert out.metadata["_webui_model_name"] == "base-model"
|
||||
assert loop.model_preset == "default"
|
||||
assert loop.model == "base-model"
|
||||
assert loop.context_window_tokens == 1000
|
||||
|
||||
@ -355,6 +355,12 @@ function Shell({ onModelNameChange, onLogout }: { onModelNameChange: (modelName:
|
||||
client.sendMessage(chatId, "/restart");
|
||||
}, [activeSession?.chatId, client]);
|
||||
|
||||
useEffect(() => {
|
||||
return client.onRuntimeModelUpdate((modelName) => {
|
||||
onModelNameChange(modelName);
|
||||
});
|
||||
}, [client, onModelNameChange]);
|
||||
|
||||
useEffect(() => {
|
||||
return client.onStatus((status) => {
|
||||
let startedAt = 0;
|
||||
@ -492,7 +498,6 @@ function Shell({ onModelNameChange, onLogout }: { onModelNameChange: (modelName:
|
||||
onNewChat={onNewChat}
|
||||
onCreateChat={onCreateChat}
|
||||
onTurnEnd={onTurnEnd}
|
||||
onModelNameChange={onModelNameChange}
|
||||
theme={theme}
|
||||
onToggleTheme={toggle}
|
||||
hideSidebarToggleOnDesktop={desktopSidebarOpen}
|
||||
|
||||
@ -32,7 +32,6 @@ interface ThreadShellProps {
|
||||
onNewChat?: () => void;
|
||||
onCreateChat?: () => Promise<string | null>;
|
||||
onTurnEnd?: () => void;
|
||||
onModelNameChange?: (modelName: string | null) => void;
|
||||
theme?: "light" | "dark";
|
||||
onToggleTheme?: () => void;
|
||||
hideSidebarToggleOnDesktop?: boolean;
|
||||
@ -76,7 +75,6 @@ export function ThreadShell({
|
||||
onToggleSidebar,
|
||||
onCreateChat,
|
||||
onTurnEnd,
|
||||
onModelNameChange,
|
||||
theme = "light",
|
||||
onToggleTheme = () => {},
|
||||
hideSidebarToggleOnDesktop = false,
|
||||
@ -105,7 +103,7 @@ export function ThreadShell({
|
||||
setMessages,
|
||||
streamError,
|
||||
dismissStreamError,
|
||||
} = useNanobotStream(chatId, initial, hasPendingToolCalls, onTurnEnd, onModelNameChange);
|
||||
} = useNanobotStream(chatId, initial, hasPendingToolCalls, onTurnEnd);
|
||||
const showHeroComposer = messages.length === 0 && !loading;
|
||||
const pendingAsk = useMemo(() => {
|
||||
for (let index = messages.length - 1; index >= 0; index -= 1) {
|
||||
|
||||
@ -44,7 +44,6 @@ export function useNanobotStream(
|
||||
initialMessages: UIMessage[] = [],
|
||||
hasPendingToolCalls = false,
|
||||
onTurnEnd?: () => void,
|
||||
onModelNameChange?: (modelName: string | null) => void,
|
||||
): {
|
||||
messages: UIMessage[];
|
||||
isStreaming: boolean;
|
||||
@ -182,9 +181,6 @@ export function useNanobotStream(
|
||||
}
|
||||
|
||||
if (ev.event === "message") {
|
||||
if (ev.model_name !== undefined) {
|
||||
onModelNameChange?.(ev.model_name || null);
|
||||
}
|
||||
if (
|
||||
suppressStreamUntilTurnEndRef.current &&
|
||||
(ev.kind === "tool_hint" || ev.kind === "progress")
|
||||
|
||||
@ -14,6 +14,7 @@ const WS_CLOSING = 2;
|
||||
type Unsubscribe = () => void;
|
||||
type EventHandler = (ev: InboundEvent) => void;
|
||||
type StatusHandler = (status: ConnectionStatus) => void;
|
||||
type RuntimeModelHandler = (modelName: string | null, modelPreset?: string | null) => void;
|
||||
|
||||
/** Structured connection-level errors surfaced to the UI.
|
||||
*
|
||||
@ -58,6 +59,7 @@ export interface NanobotClientOptions {
|
||||
export class NanobotClient {
|
||||
private socket: WebSocket | null = null;
|
||||
private statusHandlers = new Set<StatusHandler>();
|
||||
private runtimeModelHandlers = new Set<RuntimeModelHandler>();
|
||||
private errorHandlers = new Set<ErrorHandler>();
|
||||
// chat_id -> handlers listening on it
|
||||
private chatHandlers = new Map<string, Set<EventHandler>>();
|
||||
@ -107,6 +109,13 @@ export class NanobotClient {
|
||||
};
|
||||
}
|
||||
|
||||
onRuntimeModelUpdate(handler: RuntimeModelHandler): Unsubscribe {
|
||||
this.runtimeModelHandlers.add(handler);
|
||||
return () => {
|
||||
this.runtimeModelHandlers.delete(handler);
|
||||
};
|
||||
}
|
||||
|
||||
/** Subscribe to transport-level faults (see :type:`StreamError`). */
|
||||
onError(handler: ErrorHandler): Unsubscribe {
|
||||
this.errorHandlers.add(handler);
|
||||
@ -245,10 +254,21 @@ export class NanobotClient {
|
||||
return;
|
||||
}
|
||||
|
||||
if (parsed.event === "runtime_model_updated") {
|
||||
this.emitRuntimeModelUpdate(parsed.model_name || null, parsed.model_preset ?? null);
|
||||
return;
|
||||
}
|
||||
|
||||
const chatId = (parsed as { chat_id?: string }).chat_id;
|
||||
if (chatId) this.dispatch(chatId, parsed);
|
||||
}
|
||||
|
||||
private emitRuntimeModelUpdate(modelName: string | null, modelPreset?: string | null): void {
|
||||
for (const handler of this.runtimeModelHandlers) {
|
||||
handler(modelName, modelPreset);
|
||||
}
|
||||
}
|
||||
|
||||
private dispatch(chatId: string, ev: InboundEvent): void {
|
||||
const handlers = this.chatHandlers.get(chatId);
|
||||
if (!handlers) return;
|
||||
|
||||
@ -147,8 +147,6 @@ export type InboundEvent =
|
||||
/** Present when the frame is an agent breadcrumb (e.g. tool hint,
|
||||
* generic progress line) rather than a conversational reply. */
|
||||
kind?: "tool_hint" | "progress";
|
||||
/** Runtime model name after commands like `/model fast` update it. */
|
||||
model_name?: string | null;
|
||||
}
|
||||
| {
|
||||
event: "delta";
|
||||
@ -161,6 +159,11 @@ export type InboundEvent =
|
||||
chat_id: string;
|
||||
stream_id?: string;
|
||||
}
|
||||
| {
|
||||
event: "runtime_model_updated";
|
||||
model_name: string;
|
||||
model_preset?: string | null;
|
||||
}
|
||||
| { event: "turn_end"; chat_id: string }
|
||||
| { event: "session_updated"; chat_id: string }
|
||||
| { event: "error"; chat_id?: string; detail?: string };
|
||||
|
||||
@ -57,6 +57,7 @@ vi.mock("@/lib/nanobot-client", () => {
|
||||
defaultChatId: string | null = null;
|
||||
connect = connectSpy;
|
||||
onStatus = () => () => {};
|
||||
onRuntimeModelUpdate = () => () => {};
|
||||
onError = () => () => {};
|
||||
onChat = () => () => {};
|
||||
sendMessage = vi.fn();
|
||||
|
||||
@ -89,6 +89,26 @@ describe("NanobotClient", () => {
|
||||
});
|
||||
});
|
||||
|
||||
it("dispatches runtime model updates globally", () => {
|
||||
const client = new NanobotClient({
|
||||
url: "ws://test",
|
||||
reconnect: false,
|
||||
socketFactory: (url) => new FakeSocket(url) as unknown as WebSocket,
|
||||
});
|
||||
const handler = vi.fn();
|
||||
client.onRuntimeModelUpdate(handler);
|
||||
client.connect();
|
||||
lastSocket().fakeOpen();
|
||||
|
||||
lastSocket().fakeMessage({
|
||||
event: "runtime_model_updated",
|
||||
model_name: "openai/gpt-4.1",
|
||||
model_preset: "fast",
|
||||
});
|
||||
|
||||
expect(handler).toHaveBeenCalledWith("openai/gpt-4.1", "fast");
|
||||
});
|
||||
|
||||
it("resolves newChat() via the server-assigned chat_id", async () => {
|
||||
const client = new NanobotClient({
|
||||
url: "ws://test",
|
||||
|
||||
@ -12,6 +12,7 @@ function makeClient() {
|
||||
status: "open" as const,
|
||||
defaultChatId: null as string | null,
|
||||
onStatus: () => () => {},
|
||||
onRuntimeModelUpdate: () => () => {},
|
||||
onChat: (chatId: string, handler: (ev: import("@/lib/types").InboundEvent) => void) => {
|
||||
let handlers = chatHandlers.get(chatId);
|
||||
if (!handlers) {
|
||||
|
||||
@ -134,28 +134,6 @@ describe("useNanobotStream", () => {
|
||||
]);
|
||||
});
|
||||
|
||||
it("reports runtime model name updates from message frames", () => {
|
||||
const fake = fakeClient();
|
||||
const onModelNameChange = vi.fn();
|
||||
renderHook(
|
||||
() => useNanobotStream("chat-model", EMPTY_MESSAGES, false, undefined, onModelNameChange),
|
||||
{
|
||||
wrapper: wrap(fake.client),
|
||||
},
|
||||
);
|
||||
|
||||
act(() => {
|
||||
fake.emit("chat-model", {
|
||||
event: "message",
|
||||
chat_id: "chat-model",
|
||||
text: "Switched model preset to `fast`.",
|
||||
model_name: "openai/gpt-4.1",
|
||||
});
|
||||
});
|
||||
|
||||
expect(onModelNameChange).toHaveBeenCalledWith("openai/gpt-4.1");
|
||||
});
|
||||
|
||||
it("suppresses redundant stream confirmation after assistant media", () => {
|
||||
const fake = fakeClient();
|
||||
const { result } = renderHook(() => useNanobotStream("chat-img-result", EMPTY_MESSAGES), {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user