fix(webui): broadcast runtime model updates

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Xubin Ren 2026-05-12 09:05:24 +00:00 committed by Xubin Ren
parent c92345bbb1
commit bcc4b97183
16 changed files with 152 additions and 51 deletions

View File

@ -406,7 +406,7 @@ class AgentLoop:
self._model_preset_snapshot_builder = model_preset_snapshot_builder
self._active_preset: str | None = None
if model_preset:
self.set_model_preset(model_preset)
self.set_model_preset(model_preset, notify=False)
self._register_default_tools()
self._runtime_vars: dict[str, Any] = {}
self._current_iteration: int = 0
@ -473,7 +473,26 @@ class AgentLoop:
"""Keep subagent runtime limits aligned with mutable loop settings."""
self.subagents.max_iterations = self.max_iterations
def _apply_provider_snapshot(self, snapshot: ProviderSnapshot) -> None:
def _publish_runtime_model_updated(self, model_preset: str | None = None) -> None:
    """Queue a websocket event announcing the current runtime model.

    The event is placed on ``self.bus.outbound`` without waiting. It is
    addressed to the ``"websocket"`` channel with the wildcard chat id
    ``"*"`` and carries no text content; the payload lives in ``metadata``.

    Args:
        model_preset: Preset name to report. When ``None``, the value of
            ``self.model_preset`` is reported instead.
    """
    metadata = {
        "_runtime_model_updated": True,
        "model": self.model,
    }
    # An explicit preset wins over the loop's own attribute.
    metadata["model_preset"] = (
        self.model_preset if model_preset is None else model_preset
    )
    event = OutboundMessage(
        channel="websocket",
        chat_id="*",
        content="",
        metadata=metadata,
    )
    self.bus.outbound.put_nowait(event)
def _apply_provider_snapshot(
self,
snapshot: ProviderSnapshot,
*,
notify: bool = True,
model_preset: str | None = None,
) -> None:
"""Swap model/provider for future turns without disturbing an active one."""
provider = snapshot.provider
model = snapshot.model
@ -487,6 +506,8 @@ class AgentLoop:
self.consolidator.set_provider(provider, model, context_window_tokens)
self.dream.set_provider(provider, model)
self._provider_signature = snapshot.signature
if notify:
self._publish_runtime_model_updated(model_preset)
logger.info("Runtime model switched for next turn: {} -> {}", old_model, model)
def _refresh_provider_snapshot(self) -> None:
@ -556,7 +577,7 @@ class AgentLoop:
),
)
def set_model_preset(self, name: str | None) -> None:
def set_model_preset(self, name: str | None, *, notify: bool = True) -> None:
"""Resolve a preset by name and apply all runtime model dependents."""
if not isinstance(name, str) or not name.strip():
raise ValueError("model_preset must be a non-empty string")
@ -564,7 +585,7 @@ class AgentLoop:
if name not in self.model_presets:
raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(self.model_presets) or '(none)'}")
snapshot = self._build_model_preset_snapshot(name)
self._apply_provider_snapshot(snapshot)
self._apply_provider_snapshot(snapshot, notify=notify, model_preset=name)
self._active_preset = name
def _register_default_tools(self) -> None:

View File

@ -292,6 +292,13 @@ class ChannelManager:
if msg.metadata.get("_retry_wait"):
continue
if (
msg.metadata.get("_runtime_model_updated")
and msg.channel == "websocket"
and "websocket" not in self.channels
):
continue
# Coalesce consecutive _stream_delta messages for the same (channel, chat_id)
# to reduce API calls and improve streaming latency
if msg.metadata.get("_stream_delta") and not msg.metadata.get("_stream_end"):

View File

@ -156,11 +156,11 @@ def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
def _read_webui_model_name() -> str | None:
"""Return the configured default model for readonly webui display."""
"""Return the resolved startup model for readonly WebUI display."""
try:
from nanobot.config.loader import load_config
model = load_config().agents.defaults.model.strip()
model = load_config().resolve_preset().model.strip()
return model or None
except Exception as e:
logger.debug("webui bootstrap could not load model name: {}", e)
@ -1423,6 +1423,13 @@ class WebSocketChannel(BaseChannel):
raise
async def send(self, msg: OutboundMessage) -> None:
if msg.metadata.get("_runtime_model_updated"):
await self.send_runtime_model_updated(
model_name=msg.metadata.get("model"),
model_preset=msg.metadata.get("model_preset"),
)
return
# Snapshot the subscriber set so ConnectionClosed cleanups mid-iteration are safe.
conns = list(self._subs.get(msg.chat_id, ()))
if not conns:
@ -1471,9 +1478,6 @@ class WebSocketChannel(BaseChannel):
payload["kind"] = "tool_hint"
elif msg.metadata.get("_progress"):
payload["kind"] = "progress"
webui_model_name = msg.metadata.get("_webui_model_name")
if isinstance(webui_model_name, str) and webui_model_name.strip():
payload["model_name"] = webui_model_name.strip()
raw = json.dumps(payload, ensure_ascii=False)
for connection in conns:
await self._safe_send_to(connection, raw, label=" ")
@ -1521,3 +1525,23 @@ class WebSocketChannel(BaseChannel):
raw = json.dumps(body, ensure_ascii=False)
for connection in conns:
await self._safe_send_to(connection, raw, label=" session_updated ")
async def send_runtime_model_updated(
    self,
    *,
    model_name: Any,
    model_preset: Any = None,
) -> None:
    """Send a ``runtime_model_updated`` frame to each entry of ``self._conn_chats``.

    Nothing is sent when there are no entries or when ``model_name`` is not
    a non-blank string. ``model_preset`` is added to the frame only when it
    is a non-blank string. Both values are stripped of surrounding
    whitespace before being serialized.
    """
    # Snapshot first so the send loop iterates over a stable list.
    targets = list(self._conn_chats)
    if not targets:
        return
    name = model_name.strip() if isinstance(model_name, str) else ""
    if not name:
        return

    frame: dict[str, Any] = {"event": "runtime_model_updated", "model_name": name}
    preset = model_preset.strip() if isinstance(model_preset, str) else ""
    if preset:
        frame["model_preset"] = preset

    encoded = json.dumps(frame, ensure_ascii=False)
    for target in targets:
        await self._safe_send_to(target, encoded, label=" runtime_model_updated ")

View File

@ -225,7 +225,7 @@ async def cmd_model(ctx: CommandContext) -> OutboundMessage:
channel=ctx.msg.channel,
chat_id=ctx.msg.chat_id,
content=_model_command_status(loop),
metadata={**metadata, "_webui_model_name": loop.model},
metadata=metadata,
)
parts = args.split()
@ -264,7 +264,7 @@ async def cmd_model(ctx: CommandContext) -> OutboundMessage:
channel=ctx.msg.channel,
chat_id=ctx.msg.chat_id,
content="\n".join(lines),
metadata={**metadata, "_webui_model_name": loop.model},
metadata=metadata,
)

View File

@ -64,6 +64,30 @@ def test_model_preset_setter_updates_state(tmp_path) -> None:
assert loop.dream.model == "openai/gpt-4.1"
def test_model_preset_setter_publishes_runtime_model_event(tmp_path) -> None:
    """Switching to a preset enqueues a wildcard websocket runtime-model event."""
    message_bus = MessageBus()
    agent = AgentLoop(
        bus=message_bus,
        provider=_provider("base-model", max_tokens=123),
        workspace=tmp_path,
        model="base-model",
        context_window_tokens=1000,
        model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")},
    )

    agent.set_model_preset("fast")

    # The setter is expected to have queued the event synchronously.
    published = message_bus.outbound.get_nowait()
    expected_metadata = {
        "_runtime_model_updated": True,
        "model": "openai/gpt-4.1",
        "model_preset": "fast",
    }
    assert published.channel == "websocket"
    assert published.chat_id == "*"
    assert published.content == ""
    assert published.metadata == expected_metadata
def test_model_preset_setter_replaces_provider_from_snapshot(tmp_path) -> None:
old_provider = _provider("base-model", max_tokens=123)
new_provider = _provider("anthropic/claude-opus-4-5", max_tokens=2048)

View File

@ -230,7 +230,7 @@ async def test_send_delivers_json_message_with_media_and_reply() -> None:
@pytest.mark.asyncio
async def test_send_includes_webui_model_name_metadata() -> None:
async def test_send_broadcasts_runtime_model_updates() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
mock_ws = AsyncMock()
@ -239,14 +239,20 @@ async def test_send_includes_webui_model_name_metadata() -> None:
await channel.send(
OutboundMessage(
channel="websocket",
chat_id="chat-1",
content="switched",
metadata={"_webui_model_name": "openai/gpt-4.1"},
chat_id="*",
content="",
metadata={
"_runtime_model_updated": True,
"model": "openai/gpt-4.1",
"model_preset": "fast",
},
)
)
payload = json.loads(mock_ws.send.call_args[0][0])
assert payload["event"] == "runtime_model_updated"
assert payload["model_name"] == "openai/gpt-4.1"
assert payload["model_preset"] == "fast"
@pytest.mark.asyncio

View File

@ -64,8 +64,7 @@ async def test_model_command_lists_current_and_available_presets(tmp_path) -> No
assert "Active preset: `(none)`" in out.content
assert "`default`" in out.content
assert "`fast`" in out.content
assert out.metadata["render_as"] == "text"
assert out.metadata["_webui_model_name"] == "base-model"
assert out.metadata == {"render_as": "text"}
@pytest.mark.asyncio
@ -76,7 +75,6 @@ async def test_model_command_switches_preset(tmp_path) -> None:
assert "Switched model preset to `fast`." in out.content
assert "Model: `openai/gpt-4.1`" in out.content
assert out.metadata["_webui_model_name"] == "openai/gpt-4.1"
assert loop.model_preset == "fast"
assert loop.model == "openai/gpt-4.1"
assert loop.subagents.model == "openai/gpt-4.1"
@ -92,7 +90,6 @@ async def test_model_command_switches_back_to_default(tmp_path) -> None:
out = await cmd_model(_ctx(loop, "/model default", args="default"))
assert "Switched model preset to `default`." in out.content
assert out.metadata["_webui_model_name"] == "base-model"
assert loop.model_preset == "default"
assert loop.model == "base-model"
assert loop.context_window_tokens == 1000

View File

@ -355,6 +355,12 @@ function Shell({ onModelNameChange, onLogout }: { onModelNameChange: (modelName:
client.sendMessage(chatId, "/restart");
}, [activeSession?.chatId, client]);
useEffect(() => {
return client.onRuntimeModelUpdate((modelName) => {
onModelNameChange(modelName);
});
}, [client, onModelNameChange]);
useEffect(() => {
return client.onStatus((status) => {
let startedAt = 0;
@ -492,7 +498,6 @@ function Shell({ onModelNameChange, onLogout }: { onModelNameChange: (modelName:
onNewChat={onNewChat}
onCreateChat={onCreateChat}
onTurnEnd={onTurnEnd}
onModelNameChange={onModelNameChange}
theme={theme}
onToggleTheme={toggle}
hideSidebarToggleOnDesktop={desktopSidebarOpen}

View File

@ -32,7 +32,6 @@ interface ThreadShellProps {
onNewChat?: () => void;
onCreateChat?: () => Promise<string | null>;
onTurnEnd?: () => void;
onModelNameChange?: (modelName: string | null) => void;
theme?: "light" | "dark";
onToggleTheme?: () => void;
hideSidebarToggleOnDesktop?: boolean;
@ -76,7 +75,6 @@ export function ThreadShell({
onToggleSidebar,
onCreateChat,
onTurnEnd,
onModelNameChange,
theme = "light",
onToggleTheme = () => {},
hideSidebarToggleOnDesktop = false,
@ -105,7 +103,7 @@ export function ThreadShell({
setMessages,
streamError,
dismissStreamError,
} = useNanobotStream(chatId, initial, hasPendingToolCalls, onTurnEnd, onModelNameChange);
} = useNanobotStream(chatId, initial, hasPendingToolCalls, onTurnEnd);
const showHeroComposer = messages.length === 0 && !loading;
const pendingAsk = useMemo(() => {
for (let index = messages.length - 1; index >= 0; index -= 1) {

View File

@ -44,7 +44,6 @@ export function useNanobotStream(
initialMessages: UIMessage[] = [],
hasPendingToolCalls = false,
onTurnEnd?: () => void,
onModelNameChange?: (modelName: string | null) => void,
): {
messages: UIMessage[];
isStreaming: boolean;
@ -182,9 +181,6 @@ export function useNanobotStream(
}
if (ev.event === "message") {
if (ev.model_name !== undefined) {
onModelNameChange?.(ev.model_name || null);
}
if (
suppressStreamUntilTurnEndRef.current &&
(ev.kind === "tool_hint" || ev.kind === "progress")

View File

@ -14,6 +14,7 @@ const WS_CLOSING = 2;
type Unsubscribe = () => void;
type EventHandler = (ev: InboundEvent) => void;
type StatusHandler = (status: ConnectionStatus) => void;
type RuntimeModelHandler = (modelName: string | null, modelPreset?: string | null) => void;
/** Structured connection-level errors surfaced to the UI.
*
@ -58,6 +59,7 @@ export interface NanobotClientOptions {
export class NanobotClient {
private socket: WebSocket | null = null;
private statusHandlers = new Set<StatusHandler>();
private runtimeModelHandlers = new Set<RuntimeModelHandler>();
private errorHandlers = new Set<ErrorHandler>();
// chat_id -> handlers listening on it
private chatHandlers = new Map<string, Set<EventHandler>>();
@ -107,6 +109,13 @@ export class NanobotClient {
};
}
/** Register a listener for runtime model changes.
 *
 * Returns a function that removes the same listener again. */
onRuntimeModelUpdate(handler: RuntimeModelHandler): Unsubscribe {
  this.runtimeModelHandlers.add(handler);
  const unsubscribe: Unsubscribe = () => {
    this.runtimeModelHandlers.delete(handler);
  };
  return unsubscribe;
}
/** Subscribe to transport-level faults (see :type:`StreamError`). */
onError(handler: ErrorHandler): Unsubscribe {
this.errorHandlers.add(handler);
@ -245,10 +254,21 @@ export class NanobotClient {
return;
}
if (parsed.event === "runtime_model_updated") {
this.emitRuntimeModelUpdate(parsed.model_name || null, parsed.model_preset ?? null);
return;
}
const chatId = (parsed as { chat_id?: string }).chat_id;
if (chatId) this.dispatch(chatId, parsed);
}
/** Invoke every registered runtime-model listener with the given values. */
private emitRuntimeModelUpdate(modelName: string | null, modelPreset?: string | null): void {
  this.runtimeModelHandlers.forEach((listener) => {
    listener(modelName, modelPreset);
  });
}
private dispatch(chatId: string, ev: InboundEvent): void {
const handlers = this.chatHandlers.get(chatId);
if (!handlers) return;

View File

@ -147,8 +147,6 @@ export type InboundEvent =
/** Present when the frame is an agent breadcrumb (e.g. tool hint,
* generic progress line) rather than a conversational reply. */
kind?: "tool_hint" | "progress";
/** Runtime model name after commands like `/model fast` update it. */
model_name?: string | null;
}
| {
event: "delta";
@ -161,6 +159,11 @@ export type InboundEvent =
chat_id: string;
stream_id?: string;
}
| {
event: "runtime_model_updated";
model_name: string;
model_preset?: string | null;
}
| { event: "turn_end"; chat_id: string }
| { event: "session_updated"; chat_id: string }
| { event: "error"; chat_id?: string; detail?: string };

View File

@ -57,6 +57,7 @@ vi.mock("@/lib/nanobot-client", () => {
defaultChatId: string | null = null;
connect = connectSpy;
onStatus = () => () => {};
onRuntimeModelUpdate = () => () => {};
onError = () => () => {};
onChat = () => () => {};
sendMessage = vi.fn();

View File

@ -89,6 +89,26 @@ describe("NanobotClient", () => {
});
});
it("dispatches runtime model updates globally", () => {
  const onRuntimeModel = vi.fn();
  const client = new NanobotClient({
    url: "ws://test",
    reconnect: false,
    socketFactory: (url) => new FakeSocket(url) as unknown as WebSocket,
  });
  client.onRuntimeModelUpdate(onRuntimeModel);

  client.connect();
  lastSocket().fakeOpen();

  // This frame deliberately has no chat_id: it targets the global listener.
  const frame = {
    event: "runtime_model_updated",
    model_name: "openai/gpt-4.1",
    model_preset: "fast",
  };
  lastSocket().fakeMessage(frame);

  expect(onRuntimeModel).toHaveBeenCalledWith("openai/gpt-4.1", "fast");
});
it("resolves newChat() via the server-assigned chat_id", async () => {
const client = new NanobotClient({
url: "ws://test",

View File

@ -12,6 +12,7 @@ function makeClient() {
status: "open" as const,
defaultChatId: null as string | null,
onStatus: () => () => {},
onRuntimeModelUpdate: () => () => {},
onChat: (chatId: string, handler: (ev: import("@/lib/types").InboundEvent) => void) => {
let handlers = chatHandlers.get(chatId);
if (!handlers) {

View File

@ -134,28 +134,6 @@ describe("useNanobotStream", () => {
]);
});
it("reports runtime model name updates from message frames", () => {
const fake = fakeClient();
const onModelNameChange = vi.fn();
renderHook(
() => useNanobotStream("chat-model", EMPTY_MESSAGES, false, undefined, onModelNameChange),
{
wrapper: wrap(fake.client),
},
);
act(() => {
fake.emit("chat-model", {
event: "message",
chat_id: "chat-model",
text: "Switched model preset to `fast`.",
model_name: "openai/gpt-4.1",
});
});
expect(onModelNameChange).toHaveBeenCalledWith("openai/gpt-4.1");
});
it("suppresses redundant stream confirmation after assistant media", () => {
const fake = fakeClient();
const { result } = renderHook(() => useNanobotStream("chat-img-result", EMPTY_MESSAGES), {