mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 00:22:31 +00:00
refactor(agent): inject runtime model publisher
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
6554c1f832
commit
13eede5803
@ -295,6 +295,7 @@ class AgentLoop:
|
|||||||
model_presets: dict[str, ModelPresetConfig] | None = None,
|
model_presets: dict[str, ModelPresetConfig] | None = None,
|
||||||
model_preset: str | None = None,
|
model_preset: str | None = None,
|
||||||
preset_snapshot_loader: preset_helpers.PresetSnapshotLoader | None = None,
|
preset_snapshot_loader: preset_helpers.PresetSnapshotLoader | None = None,
|
||||||
|
runtime_model_publisher: Callable[[str, str | None], None] | None = None,
|
||||||
):
|
):
|
||||||
from nanobot.config.schema import ToolsConfig
|
from nanobot.config.schema import ToolsConfig
|
||||||
|
|
||||||
@ -305,6 +306,7 @@ class AgentLoop:
|
|||||||
self.provider = provider
|
self.provider = provider
|
||||||
self._provider_snapshot_loader = provider_snapshot_loader
|
self._provider_snapshot_loader = provider_snapshot_loader
|
||||||
self._preset_snapshot_loader = preset_snapshot_loader
|
self._preset_snapshot_loader = preset_snapshot_loader
|
||||||
|
self._runtime_model_publisher = runtime_model_publisher
|
||||||
self._provider_signature = provider_signature
|
self._provider_signature = provider_signature
|
||||||
self._default_selection_signature = preset_helpers.default_selection_signature(provider_signature)
|
self._default_selection_signature = preset_helpers.default_selection_signature(provider_signature)
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
@ -404,7 +406,7 @@ class AgentLoop:
|
|||||||
self.model_presets: dict[str, ModelPresetConfig] = model_presets or {}
|
self.model_presets: dict[str, ModelPresetConfig] = model_presets or {}
|
||||||
self._active_preset: str | None = None
|
self._active_preset: str | None = None
|
||||||
if model_preset:
|
if model_preset:
|
||||||
self.set_model_preset(model_preset, notify=False)
|
self.set_model_preset(model_preset, publish_update=False)
|
||||||
self._register_default_tools()
|
self._register_default_tools()
|
||||||
self._runtime_vars: dict[str, Any] = {}
|
self._runtime_vars: dict[str, Any] = {}
|
||||||
self._current_iteration: int = 0
|
self._current_iteration: int = 0
|
||||||
@ -474,7 +476,7 @@ class AgentLoop:
|
|||||||
self,
|
self,
|
||||||
snapshot: ProviderSnapshot,
|
snapshot: ProviderSnapshot,
|
||||||
*,
|
*,
|
||||||
notify: bool = True,
|
publish_update: bool = True,
|
||||||
model_preset: str | None = None,
|
model_preset: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Swap model/provider for future turns without disturbing an active one."""
|
"""Swap model/provider for future turns without disturbing an active one."""
|
||||||
@ -490,13 +492,11 @@ class AgentLoop:
|
|||||||
self.consolidator.set_provider(provider, model, context_window_tokens)
|
self.consolidator.set_provider(provider, model, context_window_tokens)
|
||||||
self.dream.set_provider(provider, model)
|
self.dream.set_provider(provider, model)
|
||||||
self._provider_signature = snapshot.signature
|
self._provider_signature = snapshot.signature
|
||||||
if notify:
|
if publish_update and self._runtime_model_publisher is not None:
|
||||||
self.bus.outbound.put_nowait(
|
self._runtime_model_publisher(
|
||||||
preset_helpers.runtime_model_updated_message(
|
|
||||||
self.model,
|
self.model,
|
||||||
model_preset if model_preset is not None else self.model_preset,
|
model_preset if model_preset is not None else self.model_preset,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
logger.info("Runtime model switched for next turn: {} -> {}", old_model, model)
|
logger.info("Runtime model switched for next turn: {} -> {}", old_model, model)
|
||||||
|
|
||||||
def _refresh_provider_snapshot(self) -> None:
|
def _refresh_provider_snapshot(self) -> None:
|
||||||
@ -539,11 +539,11 @@ class AgentLoop:
|
|||||||
loader=self._preset_snapshot_loader,
|
loader=self._preset_snapshot_loader,
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_model_preset(self, name: str | None, *, notify: bool = True) -> None:
|
def set_model_preset(self, name: str | None, *, publish_update: bool = True) -> None:
|
||||||
"""Resolve a preset by name and apply all runtime model dependents."""
|
"""Resolve a preset by name and apply all runtime model dependents."""
|
||||||
name = preset_helpers.normalize_preset_name(name, self.model_presets)
|
name = preset_helpers.normalize_preset_name(name, self.model_presets)
|
||||||
snapshot = self._build_model_preset_snapshot(name)
|
snapshot = self._build_model_preset_snapshot(name)
|
||||||
self._apply_provider_snapshot(snapshot, notify=notify, model_preset=name)
|
self._apply_provider_snapshot(snapshot, publish_update=publish_update, model_preset=name)
|
||||||
self._active_preset = name
|
self._active_preset = name
|
||||||
|
|
||||||
def _register_default_tools(self) -> None:
|
def _register_default_tools(self) -> None:
|
||||||
|
|||||||
@ -5,7 +5,6 @@ from __future__ import annotations
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
|
||||||
from nanobot.config.schema import ModelPresetConfig
|
from nanobot.config.schema import ModelPresetConfig
|
||||||
from nanobot.providers.base import LLMProvider
|
from nanobot.providers.base import LLMProvider
|
||||||
from nanobot.providers.factory import ProviderSnapshot, build_provider_snapshot
|
from nanobot.providers.factory import ProviderSnapshot, build_provider_snapshot
|
||||||
@ -64,15 +63,3 @@ def normalize_preset_name(name: str | None, presets: dict[str, ModelPresetConfig
|
|||||||
raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(presets) or '(none)'}")
|
raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(presets) or '(none)'}")
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
def runtime_model_updated_message(model: str, model_preset: str | None) -> OutboundMessage:
|
|
||||||
return OutboundMessage(
|
|
||||||
channel="websocket",
|
|
||||||
chat_id="*",
|
|
||||||
content="",
|
|
||||||
metadata={
|
|
||||||
"_runtime_model_updated": True,
|
|
||||||
"model": model,
|
|
||||||
"model_preset": model_preset,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|||||||
@ -155,6 +155,24 @@ def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
|
|||||||
return Response(status, reason, headers, body)
|
return Response(status, reason, headers, body)
|
||||||
|
|
||||||
|
|
||||||
|
def publish_runtime_model_update(
|
||||||
|
bus: MessageBus,
|
||||||
|
model: str,
|
||||||
|
model_preset: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Publish a WebUI runtime-model update onto the outbound bus."""
|
||||||
|
bus.outbound.put_nowait(OutboundMessage(
|
||||||
|
channel="websocket",
|
||||||
|
chat_id="*",
|
||||||
|
content="",
|
||||||
|
metadata={
|
||||||
|
"_runtime_model_updated": True,
|
||||||
|
"model": model,
|
||||||
|
"model_preset": model_preset,
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
def _read_webui_model_name() -> str | None:
|
def _read_webui_model_name() -> str | None:
|
||||||
"""Return the resolved startup model for readonly WebUI display."""
|
"""Return the resolved startup model for readonly WebUI display."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -633,6 +633,7 @@ def _run_gateway(
|
|||||||
from nanobot.agent.tools.message import MessageTool
|
from nanobot.agent.tools.message import MessageTool
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.manager import ChannelManager
|
from nanobot.channels.manager import ChannelManager
|
||||||
|
from nanobot.channels.websocket import publish_runtime_model_update
|
||||||
from nanobot.cron.service import CronService
|
from nanobot.cron.service import CronService
|
||||||
from nanobot.cron.types import CronJob
|
from nanobot.cron.types import CronJob
|
||||||
from nanobot.heartbeat.service import HeartbeatService
|
from nanobot.heartbeat.service import HeartbeatService
|
||||||
@ -672,6 +673,11 @@ def _run_gateway(
|
|||||||
"aihubmix": config.providers.aihubmix,
|
"aihubmix": config.providers.aihubmix,
|
||||||
},
|
},
|
||||||
provider_snapshot_loader=load_provider_snapshot,
|
provider_snapshot_loader=load_provider_snapshot,
|
||||||
|
runtime_model_publisher=lambda model, preset: publish_runtime_model_update(
|
||||||
|
bus,
|
||||||
|
model,
|
||||||
|
preset,
|
||||||
|
),
|
||||||
provider_signature=provider_snapshot.signature,
|
provider_signature=provider_snapshot.signature,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -64,28 +64,21 @@ def test_model_preset_setter_updates_state(tmp_path) -> None:
|
|||||||
assert loop.dream.model == "openai/gpt-4.1"
|
assert loop.dream.model == "openai/gpt-4.1"
|
||||||
|
|
||||||
|
|
||||||
def test_model_preset_setter_publishes_runtime_model_event(tmp_path) -> None:
|
def test_model_preset_setter_calls_runtime_model_publisher(tmp_path) -> None:
|
||||||
bus = MessageBus()
|
published: list[tuple[str, str | None]] = []
|
||||||
loop = AgentLoop(
|
loop = AgentLoop(
|
||||||
bus=bus,
|
bus=MessageBus(),
|
||||||
provider=_provider("base-model", max_tokens=123),
|
provider=_provider("base-model", max_tokens=123),
|
||||||
workspace=tmp_path,
|
workspace=tmp_path,
|
||||||
model="base-model",
|
model="base-model",
|
||||||
context_window_tokens=1000,
|
context_window_tokens=1000,
|
||||||
model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")},
|
model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")},
|
||||||
|
runtime_model_publisher=lambda model, preset: published.append((model, preset)),
|
||||||
)
|
)
|
||||||
|
|
||||||
loop.set_model_preset("fast")
|
loop.set_model_preset("fast")
|
||||||
|
|
||||||
event = bus.outbound.get_nowait()
|
assert published == [("openai/gpt-4.1", "fast")]
|
||||||
assert event.channel == "websocket"
|
|
||||||
assert event.chat_id == "*"
|
|
||||||
assert event.content == ""
|
|
||||||
assert event.metadata == {
|
|
||||||
"_runtime_model_updated": True,
|
|
||||||
"model": "openai/gpt-4.1",
|
|
||||||
"model_preset": "fast",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_model_preset_setter_replaces_provider_from_snapshot(tmp_path) -> None:
|
def test_model_preset_setter_replaces_provider_from_snapshot(tmp_path) -> None:
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from websockets.exceptions import ConnectionClosed
|
|||||||
from websockets.frames import Close
|
from websockets.frames import Close
|
||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.websocket import (
|
from nanobot.channels.websocket import (
|
||||||
WebSocketChannel,
|
WebSocketChannel,
|
||||||
WebSocketConfig,
|
WebSocketConfig,
|
||||||
@ -25,6 +26,7 @@ from nanobot.channels.websocket import (
|
|||||||
_parse_inbound_payload,
|
_parse_inbound_payload,
|
||||||
_parse_query,
|
_parse_query,
|
||||||
_parse_request_path,
|
_parse_request_path,
|
||||||
|
publish_runtime_model_update,
|
||||||
)
|
)
|
||||||
from nanobot.config.loader import load_config, save_config
|
from nanobot.config.loader import load_config, save_config
|
||||||
from nanobot.config.schema import Config
|
from nanobot.config.schema import Config
|
||||||
@ -231,23 +233,13 @@ async def test_send_delivers_json_message_with_media_and_reply() -> None:
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_broadcasts_runtime_model_updates() -> None:
|
async def test_send_broadcasts_runtime_model_updates() -> None:
|
||||||
bus = MagicMock()
|
bus = MessageBus()
|
||||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
|
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
|
||||||
mock_ws = AsyncMock()
|
mock_ws = AsyncMock()
|
||||||
channel._attach(mock_ws, "chat-1")
|
channel._attach(mock_ws, "chat-1")
|
||||||
|
|
||||||
await channel.send(
|
publish_runtime_model_update(bus, "openai/gpt-4.1", "fast")
|
||||||
OutboundMessage(
|
await channel.send(bus.outbound.get_nowait())
|
||||||
channel="websocket",
|
|
||||||
chat_id="*",
|
|
||||||
content="",
|
|
||||||
metadata={
|
|
||||||
"_runtime_model_updated": True,
|
|
||||||
"model": "openai/gpt-4.1",
|
|
||||||
"model_preset": "fast",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
payload = json.loads(mock_ws.send.call_args[0][0])
|
payload = json.loads(mock_ws.send.call_args[0][0])
|
||||||
assert payload["event"] == "runtime_model_updated"
|
assert payload["event"] == "runtime_model_updated"
|
||||||
@ -255,6 +247,27 @@ async def test_send_broadcasts_runtime_model_updates() -> None:
|
|||||||
assert payload["model_preset"] == "fast"
|
assert payload["model_preset"] == "fast"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runtime_model_update_publisher_uses_websocket_outbound_event() -> None:
|
||||||
|
bus = MessageBus()
|
||||||
|
|
||||||
|
publish_runtime_model_update(
|
||||||
|
bus,
|
||||||
|
"openai/gpt-4.1",
|
||||||
|
"fast",
|
||||||
|
)
|
||||||
|
|
||||||
|
event = bus.outbound.get_nowait()
|
||||||
|
assert event.channel == "websocket"
|
||||||
|
assert event.chat_id == "*"
|
||||||
|
assert event.content == ""
|
||||||
|
assert event.metadata == {
|
||||||
|
"_runtime_model_updated": True,
|
||||||
|
"model": "openai/gpt-4.1",
|
||||||
|
"model_preset": "fast",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_stages_external_media_as_signed_url(monkeypatch, tmp_path) -> None:
|
async def test_send_stages_external_media_as_signed_url(monkeypatch, tmp_path) -> None:
|
||||||
bus = MagicMock()
|
bus = MagicMock()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user