fix(webui): broadcast runtime model updates

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Xubin Ren 2026-05-12 09:05:24 +00:00 committed by Xubin Ren
parent c92345bbb1
commit bcc4b97183
16 changed files with 152 additions and 51 deletions

View File

@ -406,7 +406,7 @@ class AgentLoop:
self._model_preset_snapshot_builder = model_preset_snapshot_builder self._model_preset_snapshot_builder = model_preset_snapshot_builder
self._active_preset: str | None = None self._active_preset: str | None = None
if model_preset: if model_preset:
self.set_model_preset(model_preset) self.set_model_preset(model_preset, notify=False)
self._register_default_tools() self._register_default_tools()
self._runtime_vars: dict[str, Any] = {} self._runtime_vars: dict[str, Any] = {}
self._current_iteration: int = 0 self._current_iteration: int = 0
@ -473,7 +473,26 @@ class AgentLoop:
"""Keep subagent runtime limits aligned with mutable loop settings.""" """Keep subagent runtime limits aligned with mutable loop settings."""
self.subagents.max_iterations = self.max_iterations self.subagents.max_iterations = self.max_iterations
def _apply_provider_snapshot(self, snapshot: ProviderSnapshot) -> None: def _publish_runtime_model_updated(self, model_preset: str | None = None) -> None:
"""Notify WebUI clients that the effective runtime model changed."""
self.bus.outbound.put_nowait(OutboundMessage(
channel="websocket",
chat_id="*",
content="",
metadata={
"_runtime_model_updated": True,
"model": self.model,
"model_preset": model_preset if model_preset is not None else self.model_preset,
},
))
def _apply_provider_snapshot(
self,
snapshot: ProviderSnapshot,
*,
notify: bool = True,
model_preset: str | None = None,
) -> None:
"""Swap model/provider for future turns without disturbing an active one.""" """Swap model/provider for future turns without disturbing an active one."""
provider = snapshot.provider provider = snapshot.provider
model = snapshot.model model = snapshot.model
@ -487,6 +506,8 @@ class AgentLoop:
self.consolidator.set_provider(provider, model, context_window_tokens) self.consolidator.set_provider(provider, model, context_window_tokens)
self.dream.set_provider(provider, model) self.dream.set_provider(provider, model)
self._provider_signature = snapshot.signature self._provider_signature = snapshot.signature
if notify:
self._publish_runtime_model_updated(model_preset)
logger.info("Runtime model switched for next turn: {} -> {}", old_model, model) logger.info("Runtime model switched for next turn: {} -> {}", old_model, model)
def _refresh_provider_snapshot(self) -> None: def _refresh_provider_snapshot(self) -> None:
@ -556,7 +577,7 @@ class AgentLoop:
), ),
) )
def set_model_preset(self, name: str | None) -> None: def set_model_preset(self, name: str | None, *, notify: bool = True) -> None:
"""Resolve a preset by name and apply all runtime model dependents.""" """Resolve a preset by name and apply all runtime model dependents."""
if not isinstance(name, str) or not name.strip(): if not isinstance(name, str) or not name.strip():
raise ValueError("model_preset must be a non-empty string") raise ValueError("model_preset must be a non-empty string")
@ -564,7 +585,7 @@ class AgentLoop:
if name not in self.model_presets: if name not in self.model_presets:
raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(self.model_presets) or '(none)'}") raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(self.model_presets) or '(none)'}")
snapshot = self._build_model_preset_snapshot(name) snapshot = self._build_model_preset_snapshot(name)
self._apply_provider_snapshot(snapshot) self._apply_provider_snapshot(snapshot, notify=notify, model_preset=name)
self._active_preset = name self._active_preset = name
def _register_default_tools(self) -> None: def _register_default_tools(self) -> None:

View File

@ -292,6 +292,13 @@ class ChannelManager:
if msg.metadata.get("_retry_wait"): if msg.metadata.get("_retry_wait"):
continue continue
if (
msg.metadata.get("_runtime_model_updated")
and msg.channel == "websocket"
and "websocket" not in self.channels
):
continue
# Coalesce consecutive _stream_delta messages for the same (channel, chat_id) # Coalesce consecutive _stream_delta messages for the same (channel, chat_id)
# to reduce API calls and improve streaming latency # to reduce API calls and improve streaming latency
if msg.metadata.get("_stream_delta") and not msg.metadata.get("_stream_end"): if msg.metadata.get("_stream_delta") and not msg.metadata.get("_stream_end"):

View File

@ -156,11 +156,11 @@ def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
def _read_webui_model_name() -> str | None: def _read_webui_model_name() -> str | None:
"""Return the configured default model for readonly webui display.""" """Return the resolved startup model for readonly WebUI display."""
try: try:
from nanobot.config.loader import load_config from nanobot.config.loader import load_config
model = load_config().agents.defaults.model.strip() model = load_config().resolve_preset().model.strip()
return model or None return model or None
except Exception as e: except Exception as e:
logger.debug("webui bootstrap could not load model name: {}", e) logger.debug("webui bootstrap could not load model name: {}", e)
@ -1423,6 +1423,13 @@ class WebSocketChannel(BaseChannel):
raise raise
async def send(self, msg: OutboundMessage) -> None: async def send(self, msg: OutboundMessage) -> None:
if msg.metadata.get("_runtime_model_updated"):
await self.send_runtime_model_updated(
model_name=msg.metadata.get("model"),
model_preset=msg.metadata.get("model_preset"),
)
return
# Snapshot the subscriber set so ConnectionClosed cleanups mid-iteration are safe. # Snapshot the subscriber set so ConnectionClosed cleanups mid-iteration are safe.
conns = list(self._subs.get(msg.chat_id, ())) conns = list(self._subs.get(msg.chat_id, ()))
if not conns: if not conns:
@ -1471,9 +1478,6 @@ class WebSocketChannel(BaseChannel):
payload["kind"] = "tool_hint" payload["kind"] = "tool_hint"
elif msg.metadata.get("_progress"): elif msg.metadata.get("_progress"):
payload["kind"] = "progress" payload["kind"] = "progress"
webui_model_name = msg.metadata.get("_webui_model_name")
if isinstance(webui_model_name, str) and webui_model_name.strip():
payload["model_name"] = webui_model_name.strip()
raw = json.dumps(payload, ensure_ascii=False) raw = json.dumps(payload, ensure_ascii=False)
for connection in conns: for connection in conns:
await self._safe_send_to(connection, raw, label=" ") await self._safe_send_to(connection, raw, label=" ")
@ -1521,3 +1525,23 @@ class WebSocketChannel(BaseChannel):
raw = json.dumps(body, ensure_ascii=False) raw = json.dumps(body, ensure_ascii=False)
for connection in conns: for connection in conns:
await self._safe_send_to(connection, raw, label=" session_updated ") await self._safe_send_to(connection, raw, label=" session_updated ")
async def send_runtime_model_updated(
    self,
    *,
    model_name: Any,
    model_preset: Any = None,
) -> None:
    """Send a ``runtime_model_updated`` event to every entry in ``self._conn_chats``.

    Nothing is sent when there are no connections or when ``model_name`` is
    not a non-blank string. ``model_preset`` is added to the payload only
    when it is a non-blank string. Both values are stripped of surrounding
    whitespace before being serialized.
    """
    # Snapshot the connections up front so the target list is fixed
    # before the first await.
    targets = list(self._conn_chats)
    if not targets:
        return
    if not isinstance(model_name, str):
        return
    cleaned_name = model_name.strip()
    if not cleaned_name:
        return

    event: dict[str, Any] = {
        "event": "runtime_model_updated",
        "model_name": cleaned_name,
    }
    if isinstance(model_preset, str):
        cleaned_preset = model_preset.strip()
        if cleaned_preset:
            event["model_preset"] = cleaned_preset

    # Serialize once; every connection receives the same frame.
    encoded = json.dumps(event, ensure_ascii=False)
    for target in targets:
        await self._safe_send_to(target, encoded, label=" runtime_model_updated ")

View File

@ -225,7 +225,7 @@ async def cmd_model(ctx: CommandContext) -> OutboundMessage:
channel=ctx.msg.channel, channel=ctx.msg.channel,
chat_id=ctx.msg.chat_id, chat_id=ctx.msg.chat_id,
content=_model_command_status(loop), content=_model_command_status(loop),
metadata={**metadata, "_webui_model_name": loop.model}, metadata=metadata,
) )
parts = args.split() parts = args.split()
@ -264,7 +264,7 @@ async def cmd_model(ctx: CommandContext) -> OutboundMessage:
channel=ctx.msg.channel, channel=ctx.msg.channel,
chat_id=ctx.msg.chat_id, chat_id=ctx.msg.chat_id,
content="\n".join(lines), content="\n".join(lines),
metadata={**metadata, "_webui_model_name": loop.model}, metadata=metadata,
) )

View File

@ -64,6 +64,30 @@ def test_model_preset_setter_updates_state(tmp_path) -> None:
assert loop.dream.model == "openai/gpt-4.1" assert loop.dream.model == "openai/gpt-4.1"
def test_model_preset_setter_publishes_runtime_model_event(tmp_path) -> None:
    """Switching presets queues a wildcard websocket event carrying the new model."""
    outbound_bus = MessageBus()
    agent = AgentLoop(
        bus=outbound_bus,
        provider=_provider("base-model", max_tokens=123),
        workspace=tmp_path,
        model="base-model",
        context_window_tokens=1000,
        model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")},
    )

    agent.set_model_preset("fast")

    published = outbound_bus.outbound.get_nowait()
    expected_metadata = {
        "_runtime_model_updated": True,
        "model": "openai/gpt-4.1",
        "model_preset": "fast",
    }
    # The event is addressed to the websocket channel with a "*" chat id and no text.
    assert published.channel == "websocket"
    assert published.chat_id == "*"
    assert published.content == ""
    assert published.metadata == expected_metadata
def test_model_preset_setter_replaces_provider_from_snapshot(tmp_path) -> None: def test_model_preset_setter_replaces_provider_from_snapshot(tmp_path) -> None:
old_provider = _provider("base-model", max_tokens=123) old_provider = _provider("base-model", max_tokens=123)
new_provider = _provider("anthropic/claude-opus-4-5", max_tokens=2048) new_provider = _provider("anthropic/claude-opus-4-5", max_tokens=2048)

View File

@ -230,7 +230,7 @@ async def test_send_delivers_json_message_with_media_and_reply() -> None:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_includes_webui_model_name_metadata() -> None: async def test_send_broadcasts_runtime_model_updates() -> None:
bus = MagicMock() bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
mock_ws = AsyncMock() mock_ws = AsyncMock()
@ -239,14 +239,20 @@ async def test_send_includes_webui_model_name_metadata() -> None:
await channel.send( await channel.send(
OutboundMessage( OutboundMessage(
channel="websocket", channel="websocket",
chat_id="chat-1", chat_id="*",
content="switched", content="",
metadata={"_webui_model_name": "openai/gpt-4.1"}, metadata={
"_runtime_model_updated": True,
"model": "openai/gpt-4.1",
"model_preset": "fast",
},
) )
) )
payload = json.loads(mock_ws.send.call_args[0][0]) payload = json.loads(mock_ws.send.call_args[0][0])
assert payload["event"] == "runtime_model_updated"
assert payload["model_name"] == "openai/gpt-4.1" assert payload["model_name"] == "openai/gpt-4.1"
assert payload["model_preset"] == "fast"
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -64,8 +64,7 @@ async def test_model_command_lists_current_and_available_presets(tmp_path) -> No
assert "Active preset: `(none)`" in out.content assert "Active preset: `(none)`" in out.content
assert "`default`" in out.content assert "`default`" in out.content
assert "`fast`" in out.content assert "`fast`" in out.content
assert out.metadata["render_as"] == "text" assert out.metadata == {"render_as": "text"}
assert out.metadata["_webui_model_name"] == "base-model"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -76,7 +75,6 @@ async def test_model_command_switches_preset(tmp_path) -> None:
assert "Switched model preset to `fast`." in out.content assert "Switched model preset to `fast`." in out.content
assert "Model: `openai/gpt-4.1`" in out.content assert "Model: `openai/gpt-4.1`" in out.content
assert out.metadata["_webui_model_name"] == "openai/gpt-4.1"
assert loop.model_preset == "fast" assert loop.model_preset == "fast"
assert loop.model == "openai/gpt-4.1" assert loop.model == "openai/gpt-4.1"
assert loop.subagents.model == "openai/gpt-4.1" assert loop.subagents.model == "openai/gpt-4.1"
@ -92,7 +90,6 @@ async def test_model_command_switches_back_to_default(tmp_path) -> None:
out = await cmd_model(_ctx(loop, "/model default", args="default")) out = await cmd_model(_ctx(loop, "/model default", args="default"))
assert "Switched model preset to `default`." in out.content assert "Switched model preset to `default`." in out.content
assert out.metadata["_webui_model_name"] == "base-model"
assert loop.model_preset == "default" assert loop.model_preset == "default"
assert loop.model == "base-model" assert loop.model == "base-model"
assert loop.context_window_tokens == 1000 assert loop.context_window_tokens == 1000

View File

@ -355,6 +355,12 @@ function Shell({ onModelNameChange, onLogout }: { onModelNameChange: (modelName:
client.sendMessage(chatId, "/restart"); client.sendMessage(chatId, "/restart");
}, [activeSession?.chatId, client]); }, [activeSession?.chatId, client]);
useEffect(() => {
return client.onRuntimeModelUpdate((modelName) => {
onModelNameChange(modelName);
});
}, [client, onModelNameChange]);
useEffect(() => { useEffect(() => {
return client.onStatus((status) => { return client.onStatus((status) => {
let startedAt = 0; let startedAt = 0;
@ -492,7 +498,6 @@ function Shell({ onModelNameChange, onLogout }: { onModelNameChange: (modelName:
onNewChat={onNewChat} onNewChat={onNewChat}
onCreateChat={onCreateChat} onCreateChat={onCreateChat}
onTurnEnd={onTurnEnd} onTurnEnd={onTurnEnd}
onModelNameChange={onModelNameChange}
theme={theme} theme={theme}
onToggleTheme={toggle} onToggleTheme={toggle}
hideSidebarToggleOnDesktop={desktopSidebarOpen} hideSidebarToggleOnDesktop={desktopSidebarOpen}

View File

@ -32,7 +32,6 @@ interface ThreadShellProps {
onNewChat?: () => void; onNewChat?: () => void;
onCreateChat?: () => Promise<string | null>; onCreateChat?: () => Promise<string | null>;
onTurnEnd?: () => void; onTurnEnd?: () => void;
onModelNameChange?: (modelName: string | null) => void;
theme?: "light" | "dark"; theme?: "light" | "dark";
onToggleTheme?: () => void; onToggleTheme?: () => void;
hideSidebarToggleOnDesktop?: boolean; hideSidebarToggleOnDesktop?: boolean;
@ -76,7 +75,6 @@ export function ThreadShell({
onToggleSidebar, onToggleSidebar,
onCreateChat, onCreateChat,
onTurnEnd, onTurnEnd,
onModelNameChange,
theme = "light", theme = "light",
onToggleTheme = () => {}, onToggleTheme = () => {},
hideSidebarToggleOnDesktop = false, hideSidebarToggleOnDesktop = false,
@ -105,7 +103,7 @@ export function ThreadShell({
setMessages, setMessages,
streamError, streamError,
dismissStreamError, dismissStreamError,
} = useNanobotStream(chatId, initial, hasPendingToolCalls, onTurnEnd, onModelNameChange); } = useNanobotStream(chatId, initial, hasPendingToolCalls, onTurnEnd);
const showHeroComposer = messages.length === 0 && !loading; const showHeroComposer = messages.length === 0 && !loading;
const pendingAsk = useMemo(() => { const pendingAsk = useMemo(() => {
for (let index = messages.length - 1; index >= 0; index -= 1) { for (let index = messages.length - 1; index >= 0; index -= 1) {

View File

@ -44,7 +44,6 @@ export function useNanobotStream(
initialMessages: UIMessage[] = [], initialMessages: UIMessage[] = [],
hasPendingToolCalls = false, hasPendingToolCalls = false,
onTurnEnd?: () => void, onTurnEnd?: () => void,
onModelNameChange?: (modelName: string | null) => void,
): { ): {
messages: UIMessage[]; messages: UIMessage[];
isStreaming: boolean; isStreaming: boolean;
@ -182,9 +181,6 @@ export function useNanobotStream(
} }
if (ev.event === "message") { if (ev.event === "message") {
if (ev.model_name !== undefined) {
onModelNameChange?.(ev.model_name || null);
}
if ( if (
suppressStreamUntilTurnEndRef.current && suppressStreamUntilTurnEndRef.current &&
(ev.kind === "tool_hint" || ev.kind === "progress") (ev.kind === "tool_hint" || ev.kind === "progress")

View File

@ -14,6 +14,7 @@ const WS_CLOSING = 2;
type Unsubscribe = () => void; type Unsubscribe = () => void;
type EventHandler = (ev: InboundEvent) => void; type EventHandler = (ev: InboundEvent) => void;
type StatusHandler = (status: ConnectionStatus) => void; type StatusHandler = (status: ConnectionStatus) => void;
type RuntimeModelHandler = (modelName: string | null, modelPreset?: string | null) => void;
/** Structured connection-level errors surfaced to the UI. /** Structured connection-level errors surfaced to the UI.
* *
@ -58,6 +59,7 @@ export interface NanobotClientOptions {
export class NanobotClient { export class NanobotClient {
private socket: WebSocket | null = null; private socket: WebSocket | null = null;
private statusHandlers = new Set<StatusHandler>(); private statusHandlers = new Set<StatusHandler>();
private runtimeModelHandlers = new Set<RuntimeModelHandler>();
private errorHandlers = new Set<ErrorHandler>(); private errorHandlers = new Set<ErrorHandler>();
// chat_id -> handlers listening on it // chat_id -> handlers listening on it
private chatHandlers = new Map<string, Set<EventHandler>>(); private chatHandlers = new Map<string, Set<EventHandler>>();
@ -107,6 +109,13 @@ export class NanobotClient {
}; };
} }
/** Register a listener for runtime model changes; the returned function removes it again. */
onRuntimeModelUpdate(handler: RuntimeModelHandler): Unsubscribe {
  this.runtimeModelHandlers.add(handler);
  const unsubscribe: Unsubscribe = () => {
    this.runtimeModelHandlers.delete(handler);
  };
  return unsubscribe;
}
/** Subscribe to transport-level faults (see :type:`StreamError`). */ /** Subscribe to transport-level faults (see :type:`StreamError`). */
onError(handler: ErrorHandler): Unsubscribe { onError(handler: ErrorHandler): Unsubscribe {
this.errorHandlers.add(handler); this.errorHandlers.add(handler);
@ -245,10 +254,21 @@ export class NanobotClient {
return; return;
} }
if (parsed.event === "runtime_model_updated") {
this.emitRuntimeModelUpdate(parsed.model_name || null, parsed.model_preset ?? null);
return;
}
const chatId = (parsed as { chat_id?: string }).chat_id; const chatId = (parsed as { chat_id?: string }).chat_id;
if (chatId) this.dispatch(chatId, parsed); if (chatId) this.dispatch(chatId, parsed);
} }
/** Call every registered runtime-model listener with the given model name and preset. */
private emitRuntimeModelUpdate(modelName: string | null, modelPreset?: string | null): void {
  this.runtimeModelHandlers.forEach((listener) => {
    listener(modelName, modelPreset);
  });
}
private dispatch(chatId: string, ev: InboundEvent): void { private dispatch(chatId: string, ev: InboundEvent): void {
const handlers = this.chatHandlers.get(chatId); const handlers = this.chatHandlers.get(chatId);
if (!handlers) return; if (!handlers) return;

View File

@ -147,8 +147,6 @@ export type InboundEvent =
/** Present when the frame is an agent breadcrumb (e.g. tool hint, /** Present when the frame is an agent breadcrumb (e.g. tool hint,
* generic progress line) rather than a conversational reply. */ * generic progress line) rather than a conversational reply. */
kind?: "tool_hint" | "progress"; kind?: "tool_hint" | "progress";
/** Runtime model name after commands like `/model fast` update it. */
model_name?: string | null;
} }
| { | {
event: "delta"; event: "delta";
@ -161,6 +159,11 @@ export type InboundEvent =
chat_id: string; chat_id: string;
stream_id?: string; stream_id?: string;
} }
| {
event: "runtime_model_updated";
model_name: string;
model_preset?: string | null;
}
| { event: "turn_end"; chat_id: string } | { event: "turn_end"; chat_id: string }
| { event: "session_updated"; chat_id: string } | { event: "session_updated"; chat_id: string }
| { event: "error"; chat_id?: string; detail?: string }; | { event: "error"; chat_id?: string; detail?: string };

View File

@ -57,6 +57,7 @@ vi.mock("@/lib/nanobot-client", () => {
defaultChatId: string | null = null; defaultChatId: string | null = null;
connect = connectSpy; connect = connectSpy;
onStatus = () => () => {}; onStatus = () => () => {};
onRuntimeModelUpdate = () => () => {};
onError = () => () => {}; onError = () => () => {};
onChat = () => () => {}; onChat = () => () => {};
sendMessage = vi.fn(); sendMessage = vi.fn();

View File

@ -89,6 +89,26 @@ describe("NanobotClient", () => {
}); });
}); });
// A runtime_model_updated frame carries no chat_id, so it must reach
// listeners registered through onRuntimeModelUpdate.
it("dispatches runtime model updates globally", () => {
  const onModelUpdate = vi.fn();
  const client = new NanobotClient({
    url: "ws://test",
    reconnect: false,
    socketFactory: (url) => new FakeSocket(url) as unknown as WebSocket,
  });
  const frame = {
    event: "runtime_model_updated",
    model_name: "openai/gpt-4.1",
    model_preset: "fast",
  };

  client.onRuntimeModelUpdate(onModelUpdate);
  client.connect();
  lastSocket().fakeOpen();
  lastSocket().fakeMessage(frame);

  expect(onModelUpdate).toHaveBeenCalledWith("openai/gpt-4.1", "fast");
});
it("resolves newChat() via the server-assigned chat_id", async () => { it("resolves newChat() via the server-assigned chat_id", async () => {
const client = new NanobotClient({ const client = new NanobotClient({
url: "ws://test", url: "ws://test",

View File

@ -12,6 +12,7 @@ function makeClient() {
status: "open" as const, status: "open" as const,
defaultChatId: null as string | null, defaultChatId: null as string | null,
onStatus: () => () => {}, onStatus: () => () => {},
onRuntimeModelUpdate: () => () => {},
onChat: (chatId: string, handler: (ev: import("@/lib/types").InboundEvent) => void) => { onChat: (chatId: string, handler: (ev: import("@/lib/types").InboundEvent) => void) => {
let handlers = chatHandlers.get(chatId); let handlers = chatHandlers.get(chatId);
if (!handlers) { if (!handlers) {

View File

@ -134,28 +134,6 @@ describe("useNanobotStream", () => {
]); ]);
}); });
it("reports runtime model name updates from message frames", () => {
const fake = fakeClient();
const onModelNameChange = vi.fn();
renderHook(
() => useNanobotStream("chat-model", EMPTY_MESSAGES, false, undefined, onModelNameChange),
{
wrapper: wrap(fake.client),
},
);
act(() => {
fake.emit("chat-model", {
event: "message",
chat_id: "chat-model",
text: "Switched model preset to `fast`.",
model_name: "openai/gpt-4.1",
});
});
expect(onModelNameChange).toHaveBeenCalledWith("openai/gpt-4.1");
});
it("suppresses redundant stream confirmation after assistant media", () => { it("suppresses redundant stream confirmation after assistant media", () => {
const fake = fakeClient(); const fake = fakeClient();
const { result } = renderHook(() => useNanobotStream("chat-img-result", EMPTY_MESSAGES), { const { result } = renderHook(() => useNanobotStream("chat-img-result", EMPTY_MESSAGES), {