mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
refactor(agent): inject runtime model publisher
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
6554c1f832
commit
13eede5803
@ -295,6 +295,7 @@ class AgentLoop:
|
||||
model_presets: dict[str, ModelPresetConfig] | None = None,
|
||||
model_preset: str | None = None,
|
||||
preset_snapshot_loader: preset_helpers.PresetSnapshotLoader | None = None,
|
||||
runtime_model_publisher: Callable[[str, str | None], None] | None = None,
|
||||
):
|
||||
from nanobot.config.schema import ToolsConfig
|
||||
|
||||
@ -305,6 +306,7 @@ class AgentLoop:
|
||||
self.provider = provider
|
||||
self._provider_snapshot_loader = provider_snapshot_loader
|
||||
self._preset_snapshot_loader = preset_snapshot_loader
|
||||
self._runtime_model_publisher = runtime_model_publisher
|
||||
self._provider_signature = provider_signature
|
||||
self._default_selection_signature = preset_helpers.default_selection_signature(provider_signature)
|
||||
self.workspace = workspace
|
||||
@ -404,7 +406,7 @@ class AgentLoop:
|
||||
self.model_presets: dict[str, ModelPresetConfig] = model_presets or {}
|
||||
self._active_preset: str | None = None
|
||||
if model_preset:
|
||||
self.set_model_preset(model_preset, notify=False)
|
||||
self.set_model_preset(model_preset, publish_update=False)
|
||||
self._register_default_tools()
|
||||
self._runtime_vars: dict[str, Any] = {}
|
||||
self._current_iteration: int = 0
|
||||
@ -474,7 +476,7 @@ class AgentLoop:
|
||||
self,
|
||||
snapshot: ProviderSnapshot,
|
||||
*,
|
||||
notify: bool = True,
|
||||
publish_update: bool = True,
|
||||
model_preset: str | None = None,
|
||||
) -> None:
|
||||
"""Swap model/provider for future turns without disturbing an active one."""
|
||||
@ -490,12 +492,10 @@ class AgentLoop:
|
||||
self.consolidator.set_provider(provider, model, context_window_tokens)
|
||||
self.dream.set_provider(provider, model)
|
||||
self._provider_signature = snapshot.signature
|
||||
if notify:
|
||||
self.bus.outbound.put_nowait(
|
||||
preset_helpers.runtime_model_updated_message(
|
||||
self.model,
|
||||
model_preset if model_preset is not None else self.model_preset,
|
||||
)
|
||||
if publish_update and self._runtime_model_publisher is not None:
|
||||
self._runtime_model_publisher(
|
||||
self.model,
|
||||
model_preset if model_preset is not None else self.model_preset,
|
||||
)
|
||||
logger.info("Runtime model switched for next turn: {} -> {}", old_model, model)
|
||||
|
||||
@ -539,11 +539,11 @@ class AgentLoop:
|
||||
loader=self._preset_snapshot_loader,
|
||||
)
|
||||
|
||||
def set_model_preset(self, name: str | None, *, notify: bool = True) -> None:
|
||||
def set_model_preset(self, name: str | None, *, publish_update: bool = True) -> None:
|
||||
"""Resolve a preset by name and apply all runtime model dependents."""
|
||||
name = preset_helpers.normalize_preset_name(name, self.model_presets)
|
||||
snapshot = self._build_model_preset_snapshot(name)
|
||||
self._apply_provider_snapshot(snapshot, notify=notify, model_preset=name)
|
||||
self._apply_provider_snapshot(snapshot, publish_update=publish_update, model_preset=name)
|
||||
self._active_preset = name
|
||||
|
||||
def _register_default_tools(self) -> None:
|
||||
|
||||
@ -5,7 +5,6 @@ from __future__ import annotations
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.config.schema import ModelPresetConfig
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.providers.factory import ProviderSnapshot, build_provider_snapshot
|
||||
@ -64,15 +63,3 @@ def normalize_preset_name(name: str | None, presets: dict[str, ModelPresetConfig
|
||||
raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(presets) or '(none)'}")
|
||||
return name
|
||||
|
||||
|
||||
def runtime_model_updated_message(model: str, model_preset: str | None) -> OutboundMessage:
|
||||
return OutboundMessage(
|
||||
channel="websocket",
|
||||
chat_id="*",
|
||||
content="",
|
||||
metadata={
|
||||
"_runtime_model_updated": True,
|
||||
"model": model,
|
||||
"model_preset": model_preset,
|
||||
},
|
||||
)
|
||||
|
||||
@ -155,6 +155,24 @@ def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
|
||||
return Response(status, reason, headers, body)
|
||||
|
||||
|
||||
def publish_runtime_model_update(
|
||||
bus: MessageBus,
|
||||
model: str,
|
||||
model_preset: str | None,
|
||||
) -> None:
|
||||
"""Publish a WebUI runtime-model update onto the outbound bus."""
|
||||
bus.outbound.put_nowait(OutboundMessage(
|
||||
channel="websocket",
|
||||
chat_id="*",
|
||||
content="",
|
||||
metadata={
|
||||
"_runtime_model_updated": True,
|
||||
"model": model,
|
||||
"model_preset": model_preset,
|
||||
},
|
||||
))
|
||||
|
||||
|
||||
def _read_webui_model_name() -> str | None:
|
||||
"""Return the resolved startup model for readonly WebUI display."""
|
||||
try:
|
||||
|
||||
@ -633,6 +633,7 @@ def _run_gateway(
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.manager import ChannelManager
|
||||
from nanobot.channels.websocket import publish_runtime_model_update
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.cron.types import CronJob
|
||||
from nanobot.heartbeat.service import HeartbeatService
|
||||
@ -672,6 +673,11 @@ def _run_gateway(
|
||||
"aihubmix": config.providers.aihubmix,
|
||||
},
|
||||
provider_snapshot_loader=load_provider_snapshot,
|
||||
runtime_model_publisher=lambda model, preset: publish_runtime_model_update(
|
||||
bus,
|
||||
model,
|
||||
preset,
|
||||
),
|
||||
provider_signature=provider_snapshot.signature,
|
||||
)
|
||||
|
||||
|
||||
@ -64,28 +64,21 @@ def test_model_preset_setter_updates_state(tmp_path) -> None:
|
||||
assert loop.dream.model == "openai/gpt-4.1"
|
||||
|
||||
|
||||
def test_model_preset_setter_publishes_runtime_model_event(tmp_path) -> None:
|
||||
bus = MessageBus()
|
||||
def test_model_preset_setter_calls_runtime_model_publisher(tmp_path) -> None:
|
||||
published: list[tuple[str, str | None]] = []
|
||||
loop = AgentLoop(
|
||||
bus=bus,
|
||||
bus=MessageBus(),
|
||||
provider=_provider("base-model", max_tokens=123),
|
||||
workspace=tmp_path,
|
||||
model="base-model",
|
||||
context_window_tokens=1000,
|
||||
model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")},
|
||||
runtime_model_publisher=lambda model, preset: published.append((model, preset)),
|
||||
)
|
||||
|
||||
loop.set_model_preset("fast")
|
||||
|
||||
event = bus.outbound.get_nowait()
|
||||
assert event.channel == "websocket"
|
||||
assert event.chat_id == "*"
|
||||
assert event.content == ""
|
||||
assert event.metadata == {
|
||||
"_runtime_model_updated": True,
|
||||
"model": "openai/gpt-4.1",
|
||||
"model_preset": "fast",
|
||||
}
|
||||
assert published == [("openai/gpt-4.1", "fast")]
|
||||
|
||||
|
||||
def test_model_preset_setter_replaces_provider_from_snapshot(tmp_path) -> None:
|
||||
|
||||
@ -14,6 +14,7 @@ from websockets.exceptions import ConnectionClosed
|
||||
from websockets.frames import Close
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.websocket import (
|
||||
WebSocketChannel,
|
||||
WebSocketConfig,
|
||||
@ -25,6 +26,7 @@ from nanobot.channels.websocket import (
|
||||
_parse_inbound_payload,
|
||||
_parse_query,
|
||||
_parse_request_path,
|
||||
publish_runtime_model_update,
|
||||
)
|
||||
from nanobot.config.loader import load_config, save_config
|
||||
from nanobot.config.schema import Config
|
||||
@ -231,23 +233,13 @@ async def test_send_delivers_json_message_with_media_and_reply() -> None:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_broadcasts_runtime_model_updates() -> None:
|
||||
bus = MagicMock()
|
||||
bus = MessageBus()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="websocket",
|
||||
chat_id="*",
|
||||
content="",
|
||||
metadata={
|
||||
"_runtime_model_updated": True,
|
||||
"model": "openai/gpt-4.1",
|
||||
"model_preset": "fast",
|
||||
},
|
||||
)
|
||||
)
|
||||
publish_runtime_model_update(bus, "openai/gpt-4.1", "fast")
|
||||
await channel.send(bus.outbound.get_nowait())
|
||||
|
||||
payload = json.loads(mock_ws.send.call_args[0][0])
|
||||
assert payload["event"] == "runtime_model_updated"
|
||||
@ -255,6 +247,27 @@ async def test_send_broadcasts_runtime_model_updates() -> None:
|
||||
assert payload["model_preset"] == "fast"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_model_update_publisher_uses_websocket_outbound_event() -> None:
|
||||
bus = MessageBus()
|
||||
|
||||
publish_runtime_model_update(
|
||||
bus,
|
||||
"openai/gpt-4.1",
|
||||
"fast",
|
||||
)
|
||||
|
||||
event = bus.outbound.get_nowait()
|
||||
assert event.channel == "websocket"
|
||||
assert event.chat_id == "*"
|
||||
assert event.content == ""
|
||||
assert event.metadata == {
|
||||
"_runtime_model_updated": True,
|
||||
"model": "openai/gpt-4.1",
|
||||
"model_preset": "fast",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_stages_external_media_as_signed_url(monkeypatch, tmp_path) -> None:
|
||||
bus = MagicMock()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user