refactor(agent): inject runtime model publisher

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Xubin Ren 2026-05-12 11:51:45 +00:00 committed by Xubin Ren
parent 6554c1f832
commit 13eede5803
6 changed files with 65 additions and 48 deletions

View File

@ -295,6 +295,7 @@ class AgentLoop:
model_presets: dict[str, ModelPresetConfig] | None = None, model_presets: dict[str, ModelPresetConfig] | None = None,
model_preset: str | None = None, model_preset: str | None = None,
preset_snapshot_loader: preset_helpers.PresetSnapshotLoader | None = None, preset_snapshot_loader: preset_helpers.PresetSnapshotLoader | None = None,
runtime_model_publisher: Callable[[str, str | None], None] | None = None,
): ):
from nanobot.config.schema import ToolsConfig from nanobot.config.schema import ToolsConfig
@ -305,6 +306,7 @@ class AgentLoop:
self.provider = provider self.provider = provider
self._provider_snapshot_loader = provider_snapshot_loader self._provider_snapshot_loader = provider_snapshot_loader
self._preset_snapshot_loader = preset_snapshot_loader self._preset_snapshot_loader = preset_snapshot_loader
self._runtime_model_publisher = runtime_model_publisher
self._provider_signature = provider_signature self._provider_signature = provider_signature
self._default_selection_signature = preset_helpers.default_selection_signature(provider_signature) self._default_selection_signature = preset_helpers.default_selection_signature(provider_signature)
self.workspace = workspace self.workspace = workspace
@ -404,7 +406,7 @@ class AgentLoop:
self.model_presets: dict[str, ModelPresetConfig] = model_presets or {} self.model_presets: dict[str, ModelPresetConfig] = model_presets or {}
self._active_preset: str | None = None self._active_preset: str | None = None
if model_preset: if model_preset:
self.set_model_preset(model_preset, notify=False) self.set_model_preset(model_preset, publish_update=False)
self._register_default_tools() self._register_default_tools()
self._runtime_vars: dict[str, Any] = {} self._runtime_vars: dict[str, Any] = {}
self._current_iteration: int = 0 self._current_iteration: int = 0
@ -474,7 +476,7 @@ class AgentLoop:
self, self,
snapshot: ProviderSnapshot, snapshot: ProviderSnapshot,
*, *,
notify: bool = True, publish_update: bool = True,
model_preset: str | None = None, model_preset: str | None = None,
) -> None: ) -> None:
"""Swap model/provider for future turns without disturbing an active one.""" """Swap model/provider for future turns without disturbing an active one."""
@ -490,13 +492,11 @@ class AgentLoop:
self.consolidator.set_provider(provider, model, context_window_tokens) self.consolidator.set_provider(provider, model, context_window_tokens)
self.dream.set_provider(provider, model) self.dream.set_provider(provider, model)
self._provider_signature = snapshot.signature self._provider_signature = snapshot.signature
if notify: if publish_update and self._runtime_model_publisher is not None:
self.bus.outbound.put_nowait( self._runtime_model_publisher(
preset_helpers.runtime_model_updated_message(
self.model, self.model,
model_preset if model_preset is not None else self.model_preset, model_preset if model_preset is not None else self.model_preset,
) )
)
logger.info("Runtime model switched for next turn: {} -> {}", old_model, model) logger.info("Runtime model switched for next turn: {} -> {}", old_model, model)
def _refresh_provider_snapshot(self) -> None: def _refresh_provider_snapshot(self) -> None:
@ -539,11 +539,11 @@ class AgentLoop:
loader=self._preset_snapshot_loader, loader=self._preset_snapshot_loader,
) )
def set_model_preset(self, name: str | None, *, notify: bool = True) -> None: def set_model_preset(self, name: str | None, *, publish_update: bool = True) -> None:
"""Resolve a preset by name and apply all runtime model dependents.""" """Resolve a preset by name and apply all runtime model dependents."""
name = preset_helpers.normalize_preset_name(name, self.model_presets) name = preset_helpers.normalize_preset_name(name, self.model_presets)
snapshot = self._build_model_preset_snapshot(name) snapshot = self._build_model_preset_snapshot(name)
self._apply_provider_snapshot(snapshot, notify=notify, model_preset=name) self._apply_provider_snapshot(snapshot, publish_update=publish_update, model_preset=name)
self._active_preset = name self._active_preset = name
def _register_default_tools(self) -> None: def _register_default_tools(self) -> None:

View File

@ -5,7 +5,6 @@ from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from typing import Any from typing import Any
from nanobot.bus.events import OutboundMessage
from nanobot.config.schema import ModelPresetConfig from nanobot.config.schema import ModelPresetConfig
from nanobot.providers.base import LLMProvider from nanobot.providers.base import LLMProvider
from nanobot.providers.factory import ProviderSnapshot, build_provider_snapshot from nanobot.providers.factory import ProviderSnapshot, build_provider_snapshot
@ -64,15 +63,3 @@ def normalize_preset_name(name: str | None, presets: dict[str, ModelPresetConfig
raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(presets) or '(none)'}") raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(presets) or '(none)'}")
return name return name
def runtime_model_updated_message(model: str, model_preset: str | None) -> OutboundMessage:
return OutboundMessage(
channel="websocket",
chat_id="*",
content="",
metadata={
"_runtime_model_updated": True,
"model": model,
"model_preset": model_preset,
},
)

View File

@ -155,6 +155,24 @@ def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
return Response(status, reason, headers, body) return Response(status, reason, headers, body)
def publish_runtime_model_update(
bus: MessageBus,
model: str,
model_preset: str | None,
) -> None:
"""Publish a WebUI runtime-model update onto the outbound bus."""
bus.outbound.put_nowait(OutboundMessage(
channel="websocket",
chat_id="*",
content="",
metadata={
"_runtime_model_updated": True,
"model": model,
"model_preset": model_preset,
},
))
def _read_webui_model_name() -> str | None: def _read_webui_model_name() -> str | None:
"""Return the resolved startup model for readonly WebUI display.""" """Return the resolved startup model for readonly WebUI display."""
try: try:

View File

@ -633,6 +633,7 @@ def _run_gateway(
from nanobot.agent.tools.message import MessageTool from nanobot.agent.tools.message import MessageTool
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.channels.manager import ChannelManager from nanobot.channels.manager import ChannelManager
from nanobot.channels.websocket import publish_runtime_model_update
from nanobot.cron.service import CronService from nanobot.cron.service import CronService
from nanobot.cron.types import CronJob from nanobot.cron.types import CronJob
from nanobot.heartbeat.service import HeartbeatService from nanobot.heartbeat.service import HeartbeatService
@ -672,6 +673,11 @@ def _run_gateway(
"aihubmix": config.providers.aihubmix, "aihubmix": config.providers.aihubmix,
}, },
provider_snapshot_loader=load_provider_snapshot, provider_snapshot_loader=load_provider_snapshot,
runtime_model_publisher=lambda model, preset: publish_runtime_model_update(
bus,
model,
preset,
),
provider_signature=provider_snapshot.signature, provider_signature=provider_snapshot.signature,
) )

View File

@ -64,28 +64,21 @@ def test_model_preset_setter_updates_state(tmp_path) -> None:
assert loop.dream.model == "openai/gpt-4.1" assert loop.dream.model == "openai/gpt-4.1"
def test_model_preset_setter_publishes_runtime_model_event(tmp_path) -> None: def test_model_preset_setter_calls_runtime_model_publisher(tmp_path) -> None:
bus = MessageBus() published: list[tuple[str, str | None]] = []
loop = AgentLoop( loop = AgentLoop(
bus=bus, bus=MessageBus(),
provider=_provider("base-model", max_tokens=123), provider=_provider("base-model", max_tokens=123),
workspace=tmp_path, workspace=tmp_path,
model="base-model", model="base-model",
context_window_tokens=1000, context_window_tokens=1000,
model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")}, model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")},
runtime_model_publisher=lambda model, preset: published.append((model, preset)),
) )
loop.set_model_preset("fast") loop.set_model_preset("fast")
event = bus.outbound.get_nowait() assert published == [("openai/gpt-4.1", "fast")]
assert event.channel == "websocket"
assert event.chat_id == "*"
assert event.content == ""
assert event.metadata == {
"_runtime_model_updated": True,
"model": "openai/gpt-4.1",
"model_preset": "fast",
}
def test_model_preset_setter_replaces_provider_from_snapshot(tmp_path) -> None: def test_model_preset_setter_replaces_provider_from_snapshot(tmp_path) -> None:

View File

@ -14,6 +14,7 @@ from websockets.exceptions import ConnectionClosed
from websockets.frames import Close from websockets.frames import Close
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.websocket import ( from nanobot.channels.websocket import (
WebSocketChannel, WebSocketChannel,
WebSocketConfig, WebSocketConfig,
@ -25,6 +26,7 @@ from nanobot.channels.websocket import (
_parse_inbound_payload, _parse_inbound_payload,
_parse_query, _parse_query,
_parse_request_path, _parse_request_path,
publish_runtime_model_update,
) )
from nanobot.config.loader import load_config, save_config from nanobot.config.loader import load_config, save_config
from nanobot.config.schema import Config from nanobot.config.schema import Config
@ -231,23 +233,13 @@ async def test_send_delivers_json_message_with_media_and_reply() -> None:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_broadcasts_runtime_model_updates() -> None: async def test_send_broadcasts_runtime_model_updates() -> None:
bus = MagicMock() bus = MessageBus()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
mock_ws = AsyncMock() mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1") channel._attach(mock_ws, "chat-1")
await channel.send( publish_runtime_model_update(bus, "openai/gpt-4.1", "fast")
OutboundMessage( await channel.send(bus.outbound.get_nowait())
channel="websocket",
chat_id="*",
content="",
metadata={
"_runtime_model_updated": True,
"model": "openai/gpt-4.1",
"model_preset": "fast",
},
)
)
payload = json.loads(mock_ws.send.call_args[0][0]) payload = json.loads(mock_ws.send.call_args[0][0])
assert payload["event"] == "runtime_model_updated" assert payload["event"] == "runtime_model_updated"
@ -255,6 +247,27 @@ async def test_send_broadcasts_runtime_model_updates() -> None:
assert payload["model_preset"] == "fast" assert payload["model_preset"] == "fast"
@pytest.mark.asyncio
async def test_runtime_model_update_publisher_uses_websocket_outbound_event() -> None:
bus = MessageBus()
publish_runtime_model_update(
bus,
"openai/gpt-4.1",
"fast",
)
event = bus.outbound.get_nowait()
assert event.channel == "websocket"
assert event.chat_id == "*"
assert event.content == ""
assert event.metadata == {
"_runtime_model_updated": True,
"model": "openai/gpt-4.1",
"model_preset": "fast",
}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_stages_external_media_as_signed_url(monkeypatch, tmp_path) -> None: async def test_send_stages_external_media_as_signed_url(monkeypatch, tmp_path) -> None:
bus = MagicMock() bus = MagicMock()