mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 08:32:25 +00:00
fix(agent): separate preset snapshots from config reload
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
8fcb24bb7c
commit
e6103d9312
@ -293,6 +293,7 @@ class AgentLoop:
|
||||
provider_signature: tuple[object, ...] | None = None,
|
||||
model_presets: dict[str, ModelPresetConfig] | None = None,
|
||||
model_preset: str | None = None,
|
||||
preset_snapshot_loader: Callable[[str], ProviderSnapshot] | None = None,
|
||||
):
|
||||
from nanobot.config.schema import ToolsConfig
|
||||
|
||||
@ -302,6 +303,7 @@ class AgentLoop:
|
||||
self.channels_config = channels_config
|
||||
self.provider = provider
|
||||
self._provider_snapshot_loader = provider_snapshot_loader
|
||||
self._preset_snapshot_loader = preset_snapshot_loader
|
||||
self._provider_signature = provider_signature
|
||||
self._default_selection_signature = provider_signature[:2] if provider_signature else None
|
||||
self.workspace = workspace
|
||||
@ -431,7 +433,16 @@ class AgentLoop:
|
||||
model = extra.pop("model", None) or resolved.model
|
||||
context_window_tokens = extra.pop("context_window_tokens", None) or resolved.context_window_tokens
|
||||
provider_snapshot_loader = extra.pop("provider_snapshot_loader", None)
|
||||
preset_snapshot_loader = extra.pop("preset_snapshot_loader", None)
|
||||
model_presets = {**config.model_presets, "default": config.resolve_default_preset()}
|
||||
if preset_snapshot_loader is None:
|
||||
if provider_snapshot_loader is not None:
|
||||
preset_snapshot_loader = lambda name: provider_snapshot_loader(preset_name=name)
|
||||
else:
|
||||
preset_snapshot_loader = lambda name: build_provider_snapshot(
|
||||
config,
|
||||
preset_name=name,
|
||||
)
|
||||
return cls(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
@ -455,9 +466,8 @@ class AgentLoop:
|
||||
tools_config=config.tools,
|
||||
model_presets=model_presets,
|
||||
model_preset=defaults.model_preset,
|
||||
provider_snapshot_loader=provider_snapshot_loader or (
|
||||
lambda preset_name=None: build_provider_snapshot(config, preset_name=preset_name)
|
||||
),
|
||||
provider_snapshot_loader=provider_snapshot_loader,
|
||||
preset_snapshot_loader=preset_snapshot_loader,
|
||||
**extra,
|
||||
)
|
||||
|
||||
@ -468,8 +478,14 @@ class AgentLoop:
|
||||
def _publish_runtime_model_updated(self, model_preset: str | None = None) -> None:
|
||||
"""Notify WebUI clients that the effective runtime model changed."""
|
||||
self.bus.outbound.put_nowait(OutboundMessage(
|
||||
channel="websocket", chat_id="*", content="",
|
||||
metadata={"_runtime_model_updated": True, "model": self.model, "model_preset": model_preset if model_preset is not None else self.model_preset},
|
||||
channel="websocket",
|
||||
chat_id="*",
|
||||
content="",
|
||||
metadata={
|
||||
"_runtime_model_updated": True,
|
||||
"model": self.model,
|
||||
"model_preset": model_preset if model_preset is not None else self.model_preset,
|
||||
},
|
||||
))
|
||||
|
||||
def _apply_provider_snapshot(
|
||||
@ -530,8 +546,8 @@ class AgentLoop:
|
||||
|
||||
def _build_model_preset_snapshot(self, name: str) -> ProviderSnapshot:
|
||||
preset = self.model_presets[name]
|
||||
if self._provider_snapshot_loader is not None:
|
||||
return self._provider_snapshot_loader(preset_name=name)
|
||||
if self._preset_snapshot_loader is not None:
|
||||
return self._preset_snapshot_loader(name)
|
||||
self.provider.generation = preset.to_generation_settings()
|
||||
return ProviderSnapshot(
|
||||
provider=self.provider,
|
||||
|
||||
@ -213,6 +213,10 @@ def _active_model_preset_name(loop) -> str:
|
||||
return loop.model_preset or "default"
|
||||
|
||||
|
||||
def _command_error_message(exc: Exception) -> str:
|
||||
return str(exc.args[0]) if isinstance(exc, KeyError) and exc.args else str(exc)
|
||||
|
||||
|
||||
def _model_command_status(loop) -> str:
|
||||
names = _model_preset_names(loop)
|
||||
active = _active_model_preset_name(loop)
|
||||
@ -256,7 +260,7 @@ async def cmd_model(ctx: CommandContext) -> OutboundMessage:
|
||||
channel=ctx.msg.channel,
|
||||
chat_id=ctx.msg.chat_id,
|
||||
content=(
|
||||
f"Could not switch model preset: {exc}\n\n"
|
||||
f"Could not switch model preset: {_command_error_message(exc)}\n\n"
|
||||
f"Available presets: {_format_preset_names(names)}"
|
||||
),
|
||||
metadata=metadata,
|
||||
|
||||
@ -281,6 +281,8 @@ class Config(BaseSettings):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_model_preset(self) -> "Config":
|
||||
if "default" in self.model_presets:
|
||||
raise ValueError("model_preset name 'default' is reserved for agents.defaults")
|
||||
name = self.agents.defaults.model_preset
|
||||
if name and name != "default" and name not in self.model_presets:
|
||||
raise ValueError(f"model_preset {name!r} not found in model_presets")
|
||||
|
||||
@ -104,11 +104,11 @@ def test_model_preset_setter_replaces_provider_from_snapshot(tmp_path) -> None:
|
||||
model="base-model",
|
||||
context_window_tokens=1000,
|
||||
model_presets={"deep": preset},
|
||||
provider_snapshot_loader=lambda preset_name=None: ProviderSnapshot(
|
||||
preset_snapshot_loader=lambda name: ProviderSnapshot(
|
||||
provider=new_provider,
|
||||
model=preset.model,
|
||||
context_window_tokens=preset.context_window_tokens,
|
||||
signature=(preset_name, preset.model),
|
||||
signature=(name, preset.model),
|
||||
),
|
||||
)
|
||||
|
||||
@ -135,7 +135,7 @@ def test_model_preset_setter_failure_leaves_old_state(tmp_path) -> None:
|
||||
model="base-model",
|
||||
context_window_tokens=1000,
|
||||
model_presets={"fast": preset},
|
||||
provider_snapshot_loader=lambda preset_name=None: (_ for _ in ()).throw(
|
||||
preset_snapshot_loader=lambda _name: (_ for _ in ()).throw(
|
||||
RuntimeError("provider unavailable")
|
||||
),
|
||||
)
|
||||
@ -175,9 +175,8 @@ def test_active_model_preset_survives_unchanged_config_refresh(tmp_path) -> None
|
||||
context_window_tokens=1000,
|
||||
provider_signature=default_snapshot.signature,
|
||||
model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")},
|
||||
provider_snapshot_loader=lambda preset_name=None: (
|
||||
fast_snapshot if preset_name == "fast" else default_snapshot
|
||||
),
|
||||
provider_snapshot_loader=lambda: default_snapshot,
|
||||
preset_snapshot_loader=lambda _name: fast_snapshot,
|
||||
)
|
||||
|
||||
loop.set_model_preset("fast")
|
||||
@ -210,11 +209,10 @@ def test_config_model_refresh_clears_active_model_preset(tmp_path) -> None:
|
||||
workspace=tmp_path,
|
||||
model="base-model",
|
||||
context_window_tokens=1000,
|
||||
provider_snapshot_loader=lambda preset_name=None: (
|
||||
fast_snapshot if preset_name == "fast" else webui_snapshot
|
||||
),
|
||||
provider_snapshot_loader=lambda: webui_snapshot,
|
||||
provider_signature=("base-model", "auto", "openai", "sk-old"),
|
||||
model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")},
|
||||
preset_snapshot_loader=lambda _name: fast_snapshot,
|
||||
)
|
||||
|
||||
loop.set_model_preset("fast")
|
||||
@ -286,17 +284,16 @@ def test_from_config_injects_default_preset(tmp_path) -> None:
|
||||
assert loop.model_presets["default"].model == "openai/gpt-4.1"
|
||||
|
||||
|
||||
def test_from_config_reserves_default_for_agent_defaults(tmp_path) -> None:
|
||||
def test_from_config_static_preset_loader_does_not_enable_hot_reload(tmp_path) -> None:
|
||||
from unittest.mock import patch
|
||||
|
||||
from nanobot.config.schema import Config
|
||||
config = Config.model_validate({
|
||||
"agents": {"defaults": {"model": "openai/gpt-4.1", "workspace": str(tmp_path)}},
|
||||
"model_presets": {
|
||||
"default": {"model": "custom-model"}
|
||||
},
|
||||
"model_presets": {"fast": {"model": "openai/gpt-4.1-mini"}},
|
||||
})
|
||||
fake_provider = _provider("openai/gpt-4.1")
|
||||
with patch("nanobot.providers.factory.make_provider", return_value=fake_provider):
|
||||
loop = AgentLoop.from_config(config)
|
||||
assert loop.model_presets["default"].model == "openai/gpt-4.1"
|
||||
assert loop._provider_snapshot_loader is None
|
||||
assert loop._preset_snapshot_loader is not None
|
||||
|
||||
@ -102,6 +102,7 @@ async def test_model_command_unknown_preset_keeps_old_state(tmp_path) -> None:
|
||||
out = await cmd_model(_ctx(loop, "/model missing", args="missing"))
|
||||
|
||||
assert "Could not switch model preset" in out.content
|
||||
assert "\"model_preset" not in out.content
|
||||
assert "Available presets: `default`, `fast`" in out.content
|
||||
assert loop.model_preset is None
|
||||
assert loop.model == "base-model"
|
||||
|
||||
@ -110,6 +110,17 @@ def test_model_preset_accepts_explicit_default_name() -> None:
|
||||
assert config.resolve_preset().model == "openai/gpt-4.1"
|
||||
|
||||
|
||||
def test_model_presets_rejects_reserved_default_name() -> None:
|
||||
import pytest
|
||||
|
||||
with pytest.raises(ValueError, match="model_preset name 'default' is reserved"):
|
||||
Config.model_validate({
|
||||
"modelPresets": {
|
||||
"default": {"model": "custom-model"},
|
||||
},
|
||||
})
|
||||
|
||||
|
||||
def test_resolve_preset_rejects_unknown_named_preset() -> None:
|
||||
import pytest
|
||||
with pytest.raises(KeyError, match="model_preset 'missing' not found"):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user