mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 16:42:25 +00:00
fix(agent): separate preset snapshots from config reload
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
8fcb24bb7c
commit
e6103d9312
@ -293,6 +293,7 @@ class AgentLoop:
|
|||||||
provider_signature: tuple[object, ...] | None = None,
|
provider_signature: tuple[object, ...] | None = None,
|
||||||
model_presets: dict[str, ModelPresetConfig] | None = None,
|
model_presets: dict[str, ModelPresetConfig] | None = None,
|
||||||
model_preset: str | None = None,
|
model_preset: str | None = None,
|
||||||
|
preset_snapshot_loader: Callable[[str], ProviderSnapshot] | None = None,
|
||||||
):
|
):
|
||||||
from nanobot.config.schema import ToolsConfig
|
from nanobot.config.schema import ToolsConfig
|
||||||
|
|
||||||
@ -302,6 +303,7 @@ class AgentLoop:
|
|||||||
self.channels_config = channels_config
|
self.channels_config = channels_config
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self._provider_snapshot_loader = provider_snapshot_loader
|
self._provider_snapshot_loader = provider_snapshot_loader
|
||||||
|
self._preset_snapshot_loader = preset_snapshot_loader
|
||||||
self._provider_signature = provider_signature
|
self._provider_signature = provider_signature
|
||||||
self._default_selection_signature = provider_signature[:2] if provider_signature else None
|
self._default_selection_signature = provider_signature[:2] if provider_signature else None
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
@ -431,7 +433,16 @@ class AgentLoop:
|
|||||||
model = extra.pop("model", None) or resolved.model
|
model = extra.pop("model", None) or resolved.model
|
||||||
context_window_tokens = extra.pop("context_window_tokens", None) or resolved.context_window_tokens
|
context_window_tokens = extra.pop("context_window_tokens", None) or resolved.context_window_tokens
|
||||||
provider_snapshot_loader = extra.pop("provider_snapshot_loader", None)
|
provider_snapshot_loader = extra.pop("provider_snapshot_loader", None)
|
||||||
|
preset_snapshot_loader = extra.pop("preset_snapshot_loader", None)
|
||||||
model_presets = {**config.model_presets, "default": config.resolve_default_preset()}
|
model_presets = {**config.model_presets, "default": config.resolve_default_preset()}
|
||||||
|
if preset_snapshot_loader is None:
|
||||||
|
if provider_snapshot_loader is not None:
|
||||||
|
preset_snapshot_loader = lambda name: provider_snapshot_loader(preset_name=name)
|
||||||
|
else:
|
||||||
|
preset_snapshot_loader = lambda name: build_provider_snapshot(
|
||||||
|
config,
|
||||||
|
preset_name=name,
|
||||||
|
)
|
||||||
return cls(
|
return cls(
|
||||||
bus=bus,
|
bus=bus,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
@ -455,9 +466,8 @@ class AgentLoop:
|
|||||||
tools_config=config.tools,
|
tools_config=config.tools,
|
||||||
model_presets=model_presets,
|
model_presets=model_presets,
|
||||||
model_preset=defaults.model_preset,
|
model_preset=defaults.model_preset,
|
||||||
provider_snapshot_loader=provider_snapshot_loader or (
|
provider_snapshot_loader=provider_snapshot_loader,
|
||||||
lambda preset_name=None: build_provider_snapshot(config, preset_name=preset_name)
|
preset_snapshot_loader=preset_snapshot_loader,
|
||||||
),
|
|
||||||
**extra,
|
**extra,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -468,8 +478,14 @@ class AgentLoop:
|
|||||||
def _publish_runtime_model_updated(self, model_preset: str | None = None) -> None:
|
def _publish_runtime_model_updated(self, model_preset: str | None = None) -> None:
|
||||||
"""Notify WebUI clients that the effective runtime model changed."""
|
"""Notify WebUI clients that the effective runtime model changed."""
|
||||||
self.bus.outbound.put_nowait(OutboundMessage(
|
self.bus.outbound.put_nowait(OutboundMessage(
|
||||||
channel="websocket", chat_id="*", content="",
|
channel="websocket",
|
||||||
metadata={"_runtime_model_updated": True, "model": self.model, "model_preset": model_preset if model_preset is not None else self.model_preset},
|
chat_id="*",
|
||||||
|
content="",
|
||||||
|
metadata={
|
||||||
|
"_runtime_model_updated": True,
|
||||||
|
"model": self.model,
|
||||||
|
"model_preset": model_preset if model_preset is not None else self.model_preset,
|
||||||
|
},
|
||||||
))
|
))
|
||||||
|
|
||||||
def _apply_provider_snapshot(
|
def _apply_provider_snapshot(
|
||||||
@ -530,8 +546,8 @@ class AgentLoop:
|
|||||||
|
|
||||||
def _build_model_preset_snapshot(self, name: str) -> ProviderSnapshot:
|
def _build_model_preset_snapshot(self, name: str) -> ProviderSnapshot:
|
||||||
preset = self.model_presets[name]
|
preset = self.model_presets[name]
|
||||||
if self._provider_snapshot_loader is not None:
|
if self._preset_snapshot_loader is not None:
|
||||||
return self._provider_snapshot_loader(preset_name=name)
|
return self._preset_snapshot_loader(name)
|
||||||
self.provider.generation = preset.to_generation_settings()
|
self.provider.generation = preset.to_generation_settings()
|
||||||
return ProviderSnapshot(
|
return ProviderSnapshot(
|
||||||
provider=self.provider,
|
provider=self.provider,
|
||||||
|
|||||||
@ -213,6 +213,10 @@ def _active_model_preset_name(loop) -> str:
|
|||||||
return loop.model_preset or "default"
|
return loop.model_preset or "default"
|
||||||
|
|
||||||
|
|
||||||
|
def _command_error_message(exc: Exception) -> str:
|
||||||
|
return str(exc.args[0]) if isinstance(exc, KeyError) and exc.args else str(exc)
|
||||||
|
|
||||||
|
|
||||||
def _model_command_status(loop) -> str:
|
def _model_command_status(loop) -> str:
|
||||||
names = _model_preset_names(loop)
|
names = _model_preset_names(loop)
|
||||||
active = _active_model_preset_name(loop)
|
active = _active_model_preset_name(loop)
|
||||||
@ -256,7 +260,7 @@ async def cmd_model(ctx: CommandContext) -> OutboundMessage:
|
|||||||
channel=ctx.msg.channel,
|
channel=ctx.msg.channel,
|
||||||
chat_id=ctx.msg.chat_id,
|
chat_id=ctx.msg.chat_id,
|
||||||
content=(
|
content=(
|
||||||
f"Could not switch model preset: {exc}\n\n"
|
f"Could not switch model preset: {_command_error_message(exc)}\n\n"
|
||||||
f"Available presets: {_format_preset_names(names)}"
|
f"Available presets: {_format_preset_names(names)}"
|
||||||
),
|
),
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
|
|||||||
@ -281,6 +281,8 @@ class Config(BaseSettings):
|
|||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def _validate_model_preset(self) -> "Config":
|
def _validate_model_preset(self) -> "Config":
|
||||||
|
if "default" in self.model_presets:
|
||||||
|
raise ValueError("model_preset name 'default' is reserved for agents.defaults")
|
||||||
name = self.agents.defaults.model_preset
|
name = self.agents.defaults.model_preset
|
||||||
if name and name != "default" and name not in self.model_presets:
|
if name and name != "default" and name not in self.model_presets:
|
||||||
raise ValueError(f"model_preset {name!r} not found in model_presets")
|
raise ValueError(f"model_preset {name!r} not found in model_presets")
|
||||||
|
|||||||
@ -104,11 +104,11 @@ def test_model_preset_setter_replaces_provider_from_snapshot(tmp_path) -> None:
|
|||||||
model="base-model",
|
model="base-model",
|
||||||
context_window_tokens=1000,
|
context_window_tokens=1000,
|
||||||
model_presets={"deep": preset},
|
model_presets={"deep": preset},
|
||||||
provider_snapshot_loader=lambda preset_name=None: ProviderSnapshot(
|
preset_snapshot_loader=lambda name: ProviderSnapshot(
|
||||||
provider=new_provider,
|
provider=new_provider,
|
||||||
model=preset.model,
|
model=preset.model,
|
||||||
context_window_tokens=preset.context_window_tokens,
|
context_window_tokens=preset.context_window_tokens,
|
||||||
signature=(preset_name, preset.model),
|
signature=(name, preset.model),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -135,7 +135,7 @@ def test_model_preset_setter_failure_leaves_old_state(tmp_path) -> None:
|
|||||||
model="base-model",
|
model="base-model",
|
||||||
context_window_tokens=1000,
|
context_window_tokens=1000,
|
||||||
model_presets={"fast": preset},
|
model_presets={"fast": preset},
|
||||||
provider_snapshot_loader=lambda preset_name=None: (_ for _ in ()).throw(
|
preset_snapshot_loader=lambda _name: (_ for _ in ()).throw(
|
||||||
RuntimeError("provider unavailable")
|
RuntimeError("provider unavailable")
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -175,9 +175,8 @@ def test_active_model_preset_survives_unchanged_config_refresh(tmp_path) -> None
|
|||||||
context_window_tokens=1000,
|
context_window_tokens=1000,
|
||||||
provider_signature=default_snapshot.signature,
|
provider_signature=default_snapshot.signature,
|
||||||
model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")},
|
model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")},
|
||||||
provider_snapshot_loader=lambda preset_name=None: (
|
provider_snapshot_loader=lambda: default_snapshot,
|
||||||
fast_snapshot if preset_name == "fast" else default_snapshot
|
preset_snapshot_loader=lambda _name: fast_snapshot,
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
loop.set_model_preset("fast")
|
loop.set_model_preset("fast")
|
||||||
@ -210,11 +209,10 @@ def test_config_model_refresh_clears_active_model_preset(tmp_path) -> None:
|
|||||||
workspace=tmp_path,
|
workspace=tmp_path,
|
||||||
model="base-model",
|
model="base-model",
|
||||||
context_window_tokens=1000,
|
context_window_tokens=1000,
|
||||||
provider_snapshot_loader=lambda preset_name=None: (
|
provider_snapshot_loader=lambda: webui_snapshot,
|
||||||
fast_snapshot if preset_name == "fast" else webui_snapshot
|
|
||||||
),
|
|
||||||
provider_signature=("base-model", "auto", "openai", "sk-old"),
|
provider_signature=("base-model", "auto", "openai", "sk-old"),
|
||||||
model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")},
|
model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")},
|
||||||
|
preset_snapshot_loader=lambda _name: fast_snapshot,
|
||||||
)
|
)
|
||||||
|
|
||||||
loop.set_model_preset("fast")
|
loop.set_model_preset("fast")
|
||||||
@ -286,17 +284,16 @@ def test_from_config_injects_default_preset(tmp_path) -> None:
|
|||||||
assert loop.model_presets["default"].model == "openai/gpt-4.1"
|
assert loop.model_presets["default"].model == "openai/gpt-4.1"
|
||||||
|
|
||||||
|
|
||||||
def test_from_config_reserves_default_for_agent_defaults(tmp_path) -> None:
|
def test_from_config_static_preset_loader_does_not_enable_hot_reload(tmp_path) -> None:
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from nanobot.config.schema import Config
|
from nanobot.config.schema import Config
|
||||||
config = Config.model_validate({
|
config = Config.model_validate({
|
||||||
"agents": {"defaults": {"model": "openai/gpt-4.1", "workspace": str(tmp_path)}},
|
"agents": {"defaults": {"model": "openai/gpt-4.1", "workspace": str(tmp_path)}},
|
||||||
"model_presets": {
|
"model_presets": {"fast": {"model": "openai/gpt-4.1-mini"}},
|
||||||
"default": {"model": "custom-model"}
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
fake_provider = _provider("openai/gpt-4.1")
|
fake_provider = _provider("openai/gpt-4.1")
|
||||||
with patch("nanobot.providers.factory.make_provider", return_value=fake_provider):
|
with patch("nanobot.providers.factory.make_provider", return_value=fake_provider):
|
||||||
loop = AgentLoop.from_config(config)
|
loop = AgentLoop.from_config(config)
|
||||||
assert loop.model_presets["default"].model == "openai/gpt-4.1"
|
assert loop._provider_snapshot_loader is None
|
||||||
|
assert loop._preset_snapshot_loader is not None
|
||||||
|
|||||||
@ -102,6 +102,7 @@ async def test_model_command_unknown_preset_keeps_old_state(tmp_path) -> None:
|
|||||||
out = await cmd_model(_ctx(loop, "/model missing", args="missing"))
|
out = await cmd_model(_ctx(loop, "/model missing", args="missing"))
|
||||||
|
|
||||||
assert "Could not switch model preset" in out.content
|
assert "Could not switch model preset" in out.content
|
||||||
|
assert "\"model_preset" not in out.content
|
||||||
assert "Available presets: `default`, `fast`" in out.content
|
assert "Available presets: `default`, `fast`" in out.content
|
||||||
assert loop.model_preset is None
|
assert loop.model_preset is None
|
||||||
assert loop.model == "base-model"
|
assert loop.model == "base-model"
|
||||||
|
|||||||
@ -110,6 +110,17 @@ def test_model_preset_accepts_explicit_default_name() -> None:
|
|||||||
assert config.resolve_preset().model == "openai/gpt-4.1"
|
assert config.resolve_preset().model == "openai/gpt-4.1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_presets_rejects_reserved_default_name() -> None:
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="model_preset name 'default' is reserved"):
|
||||||
|
Config.model_validate({
|
||||||
|
"modelPresets": {
|
||||||
|
"default": {"model": "custom-model"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_preset_rejects_unknown_named_preset() -> None:
|
def test_resolve_preset_rejects_unknown_named_preset() -> None:
|
||||||
import pytest
|
import pytest
|
||||||
with pytest.raises(KeyError, match="model_preset 'missing' not found"):
|
with pytest.raises(KeyError, match="model_preset 'missing' not found"):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user