fix(config): make model preset switching atomic

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Xubin Ren 2026-05-12 07:55:01 +00:00 committed by Xubin Ren
parent 6f78267c82
commit c450d6fd3f
8 changed files with 420 additions and 39 deletions

View File

@ -293,6 +293,7 @@ class AgentLoop:
provider_signature: tuple[object, ...] | None = None,
model_presets: dict[str, ModelPresetConfig] | None = None,
model_preset: str | None = None,
model_preset_snapshot_builder: Callable[[ModelPresetConfig], ProviderSnapshot] | None = None,
):
from nanobot.config.schema import ToolsConfig
@ -398,7 +399,10 @@ class AgentLoop:
model=self.model,
)
self.model_presets: dict[str, ModelPresetConfig] = model_presets or {}
self._active_preset: str | None = model_preset if model_presets and model_preset in model_presets else None
self._model_preset_snapshot_builder = model_preset_snapshot_builder
self._active_preset: str | None = None
if model_preset:
self.set_model_preset(model_preset)
self._register_default_tools()
self._runtime_vars: dict[str, Any] = {}
self._current_iteration: int = 0
@ -418,7 +422,7 @@ class AgentLoop:
allowing callers to override or extend the standard config-derived
parameters (e.g. ``cron_service``, ``session_manager``).
"""
from nanobot.providers.factory import make_provider
from nanobot.providers.factory import build_provider_snapshot, make_provider
if bus is None:
bus = MessageBus()
@ -453,6 +457,7 @@ class AgentLoop:
tools_config=config.tools,
model_presets=model_presets,
model_preset=defaults.model_preset,
model_preset_snapshot_builder=lambda preset: build_provider_snapshot(config, preset=preset),
**extra,
)
@ -465,8 +470,6 @@ class AgentLoop:
provider = snapshot.provider
model = snapshot.model
context_window_tokens = snapshot.context_window_tokens
if self.provider is provider and self.model == model:
return
old_model = self.model
self.provider = provider
self.model = model
@ -498,15 +501,38 @@ class AgentLoop:
@model_preset.setter
def model_preset(self, name: str | None) -> None:
"""Resolve a preset by name and apply all fields atomically."""
self.set_model_preset(name)
def _build_model_preset_snapshot(self, name: str) -> ProviderSnapshot:
preset = self.model_presets[name]
if self._model_preset_snapshot_builder is not None:
return self._model_preset_snapshot_builder(preset)
self.provider.generation = preset.to_generation_settings()
return ProviderSnapshot(
provider=self.provider,
model=preset.model,
context_window_tokens=preset.context_window_tokens,
signature=(
"model_preset",
name,
preset.model,
preset.provider,
preset.max_tokens,
preset.context_window_tokens,
preset.temperature,
preset.reasoning_effort,
),
)
def set_model_preset(self, name: str | None) -> None:
"""Resolve a preset by name and apply all runtime model dependents."""
if not isinstance(name, str) or not name.strip():
raise ValueError("model_preset must be a non-empty string")
name = name.strip()
if name not in self.model_presets:
raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(self.model_presets) or '(none)'}")
p = self.model_presets[name]
self.model = p.model
self.context_window_tokens = p.context_window_tokens
self.provider.generation = p.to_generation_settings()
snapshot = self._build_model_preset_snapshot(name)
self._apply_provider_snapshot(snapshot)
self._active_preset = name
def _register_default_tools(self) -> None:

View File

@ -58,6 +58,13 @@ BUILTIN_COMMAND_SPECS: tuple[BuiltinCommandSpec, ...] = (
"Display runtime, provider, and channel status.",
"activity",
),
BuiltinCommandSpec(
"/model",
"Switch model preset",
"Show or switch the active model preset.",
"brain",
"[preset]",
),
BuiltinCommandSpec(
"/history",
"Show conversation history",
@ -192,6 +199,75 @@ async def cmd_new(ctx: CommandContext) -> OutboundMessage:
)
def _format_preset_names(names: list[str]) -> str:
return ", ".join(f"`{name}`" for name in names) if names else "(none configured)"
def _model_command_status(loop) -> str:
names = sorted(loop.model_presets)
active = loop.model_preset or "(none)"
return "\n".join([
"## Model",
f"- Current model: `{loop.model}`",
f"- Active preset: `{active}`",
f"- Available presets: {_format_preset_names(names)}",
])
async def cmd_model(ctx: CommandContext) -> OutboundMessage:
"""Show or switch model presets."""
loop = ctx.loop
args = ctx.args.strip()
metadata = {**dict(ctx.msg.metadata or {}), "render_as": "text"}
if not args:
return OutboundMessage(
channel=ctx.msg.channel,
chat_id=ctx.msg.chat_id,
content=_model_command_status(loop),
metadata=metadata,
)
parts = args.split()
if len(parts) != 1:
return OutboundMessage(
channel=ctx.msg.channel,
chat_id=ctx.msg.chat_id,
content="Usage: `/model [preset]`",
metadata=metadata,
)
name = parts[0]
try:
loop.set_model_preset(name)
except (KeyError, ValueError) as exc:
names = sorted(loop.model_presets)
return OutboundMessage(
channel=ctx.msg.channel,
chat_id=ctx.msg.chat_id,
content=(
f"Could not switch model preset: {exc}\n\n"
f"Available presets: {_format_preset_names(names)}"
),
metadata=metadata,
)
max_tokens = getattr(getattr(loop.provider, "generation", None), "max_tokens", None)
lines = [
f"Switched model preset to `{loop.model_preset}`.",
f"- Model: `{loop.model}`",
f"- Context window: {loop.context_window_tokens}",
]
if max_tokens is not None:
lines.append(f"- Max output tokens: {max_tokens}")
return OutboundMessage(
channel=ctx.msg.channel,
chat_id=ctx.msg.chat_id,
content="\n".join(lines),
metadata=metadata,
)
async def cmd_dream(ctx: CommandContext) -> OutboundMessage:
"""Manually trigger a Dream consolidation run."""
import time
@ -477,6 +553,8 @@ def register_builtin_commands(router: CommandRouter) -> None:
router.priority("/status", cmd_status)
router.exact("/new", cmd_new)
router.exact("/status", cmd_status)
router.exact("/model", cmd_model)
router.prefix("/model ", cmd_model)
router.exact("/history", cmd_history)
router.prefix("/history ", cmd_history)
router.exact("/dream", cmd_dream)

View File

@ -283,10 +283,12 @@ class Config(BaseSettings):
raise ValueError(f"model_preset {name!r} not found in model_presets")
return self
def resolve_preset(self) -> ModelPresetConfig:
def resolve_preset(self, name: str | None = None) -> ModelPresetConfig:
"""Return effective model params: from active preset, or individual defaults."""
name = self.agents.defaults.model_preset
name = self.agents.defaults.model_preset if name is None else name
if name:
if name not in self.model_presets:
raise KeyError(f"model_preset {name!r} not found in model_presets")
return self.model_presets[name]
d = self.agents.defaults
return ModelPresetConfig(
@ -301,12 +303,14 @@ class Config(BaseSettings):
return Path(self.agents.defaults.workspace).expanduser()
def _match_provider(
self, model: str | None = None
self, model: str | None = None,
*,
preset: ModelPresetConfig | None = None,
) -> tuple["ProviderConfig | None", str | None]:
"""Match provider config and its registry name. Returns (config, spec_name)."""
from nanobot.providers.registry import PROVIDERS, find_by_name
resolved = self.resolve_preset()
resolved = preset or self.resolve_preset()
forced = resolved.provider
if forced != "auto":
spec = find_by_name(forced)
@ -366,26 +370,46 @@ class Config(BaseSettings):
return p, spec.name
return None, None
def get_provider(self, model: str | None = None) -> ProviderConfig | None:
def get_provider(
self,
model: str | None = None,
*,
preset: ModelPresetConfig | None = None,
) -> ProviderConfig | None:
"""Get matched provider config (api_key, api_base, extra_headers). Falls back to first available."""
p, _ = self._match_provider(model)
p, _ = self._match_provider(model, preset=preset)
return p
def get_provider_name(self, model: str | None = None) -> str | None:
def get_provider_name(
self,
model: str | None = None,
*,
preset: ModelPresetConfig | None = None,
) -> str | None:
"""Get the registry name of the matched provider (e.g. "deepseek", "openrouter")."""
_, name = self._match_provider(model)
_, name = self._match_provider(model, preset=preset)
return name
def get_api_key(self, model: str | None = None) -> str | None:
def get_api_key(
self,
model: str | None = None,
*,
preset: ModelPresetConfig | None = None,
) -> str | None:
"""Get API key for the given model. Falls back to first available key."""
p = self.get_provider(model)
p = self.get_provider(model, preset=preset)
return p.api_key if p else None
def get_api_base(self, model: str | None = None) -> str | None:
def get_api_base(
self,
model: str | None = None,
*,
preset: ModelPresetConfig | None = None,
) -> str | None:
"""Get API base URL for the given model, falling back to the provider default when present."""
from nanobot.providers.registry import find_by_name
p, name = self._match_provider(model)
p, name = self._match_provider(model, preset=preset)
if p and p.api_base:
return p.api_base
if name:

View File

@ -5,7 +5,7 @@ from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from nanobot.config.schema import Config
from nanobot.config.schema import Config, ModelPresetConfig
from nanobot.providers.base import LLMProvider
from nanobot.providers.registry import find_by_name
@ -18,12 +18,26 @@ class ProviderSnapshot:
signature: tuple[object, ...]
def make_provider(config: Config) -> LLMProvider:
def _resolve_model_preset(
config: Config,
*,
preset_name: str | None = None,
preset: ModelPresetConfig | None = None,
) -> ModelPresetConfig:
return preset if preset is not None else config.resolve_preset(preset_name)
def make_provider(
config: Config,
*,
preset_name: str | None = None,
preset: ModelPresetConfig | None = None,
) -> LLMProvider:
"""Create the LLM provider implied by config."""
resolved = config.resolve_preset()
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
model = resolved.model
provider_name = config.get_provider_name(model)
p = config.get_provider(model)
provider_name = config.get_provider_name(model, preset=resolved)
p = config.get_provider(model, preset=resolved)
spec = find_by_name(provider_name) if provider_name else None
backend = spec.backend if spec else "openai_compat"
@ -57,7 +71,7 @@ def make_provider(config: Config) -> LLMProvider:
provider = AnthropicProvider(
api_key=p.api_key if p else None,
api_base=config.get_api_base(model),
api_base=config.get_api_base(model, preset=resolved),
default_model=model,
extra_headers=p.extra_headers if p else None,
)
@ -77,7 +91,7 @@ def make_provider(config: Config) -> LLMProvider:
provider = OpenAICompatProvider(
api_key=p.api_key if p else None,
api_base=config.get_api_base(model),
api_base=config.get_api_base(model, preset=resolved),
default_model=model,
extra_headers=p.extra_headers if p else None,
spec=spec,
@ -88,16 +102,21 @@ def make_provider(config: Config) -> LLMProvider:
return provider
def provider_signature(config: Config) -> tuple[object, ...]:
def provider_signature(
config: Config,
*,
preset_name: str | None = None,
preset: ModelPresetConfig | None = None,
) -> tuple[object, ...]:
"""Return the config fields that affect the primary LLM provider."""
resolved = config.resolve_preset()
p = config.get_provider(resolved.model)
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
p = config.get_provider(resolved.model, preset=resolved)
return (
resolved.model,
resolved.provider,
config.get_provider_name(resolved.model),
config.get_api_key(resolved.model),
config.get_api_base(resolved.model),
config.get_provider_name(resolved.model, preset=resolved),
config.get_api_key(resolved.model, preset=resolved),
config.get_api_base(resolved.model, preset=resolved),
p.extra_headers if p else None,
p.extra_body if p else None,
getattr(p, "region", None) if p else None,
@ -109,13 +128,18 @@ def provider_signature(config: Config) -> tuple[object, ...]:
)
def build_provider_snapshot(config: Config) -> ProviderSnapshot:
resolved = config.resolve_preset()
def build_provider_snapshot(
config: Config,
*,
preset_name: str | None = None,
preset: ModelPresetConfig | None = None,
) -> ProviderSnapshot:
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
return ProviderSnapshot(
provider=make_provider(config),
provider=make_provider(config, preset=resolved),
model=resolved.model,
context_window_tokens=resolved.context_window_tokens,
signature=provider_signature(config),
signature=provider_signature(config, preset=resolved),
)

View File

@ -7,6 +7,7 @@ from nanobot.agent.loop import AgentLoop
from nanobot.agent.tools.self import MyTool
from nanobot.bus.queue import MessageBus
from nanobot.config.schema import ModelPresetConfig
from nanobot.providers.factory import ProviderSnapshot
def _provider(default_model: str, max_tokens: int = 123) -> MagicMock:
@ -56,6 +57,75 @@ def test_model_preset_setter_updates_state(tmp_path) -> None:
assert loop.provider.generation.temperature == 0.5
assert loop.provider.generation.max_tokens == 4096
assert loop.provider.generation.reasoning_effort == "low"
assert loop.subagents.model == "openai/gpt-4.1"
assert loop.consolidator.model == "openai/gpt-4.1"
assert loop.consolidator.context_window_tokens == 32_768
assert loop.consolidator.max_completion_tokens == 4096
assert loop.dream.model == "openai/gpt-4.1"
def test_model_preset_setter_replaces_provider_from_snapshot(tmp_path) -> None:
old_provider = _provider("base-model", max_tokens=123)
new_provider = _provider("anthropic/claude-opus-4-5", max_tokens=2048)
preset = ModelPresetConfig(
model="anthropic/claude-opus-4-5",
provider="anthropic",
max_tokens=2048,
context_window_tokens=200_000,
)
loop = AgentLoop(
bus=MessageBus(),
provider=old_provider,
workspace=tmp_path,
model="base-model",
context_window_tokens=1000,
model_presets={"deep": preset},
model_preset_snapshot_builder=lambda _preset: ProviderSnapshot(
provider=new_provider,
model=_preset.model,
context_window_tokens=_preset.context_window_tokens,
signature=("deep", _preset.model),
),
)
loop.set_model_preset("deep")
assert loop.provider is new_provider
assert loop.runner.provider is new_provider
assert loop.subagents.provider is new_provider
assert loop.subagents.runner.provider is new_provider
assert loop.consolidator.provider is new_provider
assert loop.dream.provider is new_provider
assert loop.dream._runner.provider is new_provider
assert loop.model == "anthropic/claude-opus-4-5"
assert loop.context_window_tokens == 200_000
assert loop.consolidator.max_completion_tokens == 2048
def test_model_preset_setter_failure_leaves_old_state(tmp_path) -> None:
preset = ModelPresetConfig(model="openai/gpt-4.1", max_tokens=4096)
loop = AgentLoop(
bus=MessageBus(),
provider=_provider("base-model", max_tokens=123),
workspace=tmp_path,
model="base-model",
context_window_tokens=1000,
model_presets={"fast": preset},
model_preset_snapshot_builder=lambda _preset: (_ for _ in ()).throw(
RuntimeError("provider unavailable")
),
)
with pytest.raises(RuntimeError, match="provider unavailable"):
loop.set_model_preset("fast")
assert loop.model_preset is None
assert loop.model == "base-model"
assert loop.subagents.model == "base-model"
assert loop.consolidator.model == "base-model"
assert loop.dream.model == "base-model"
assert loop.context_window_tokens == 1000
assert loop.consolidator.max_completion_tokens == 123
def test_model_preset_setter_raises_on_unknown(tmp_path) -> None:

View File

@ -0,0 +1,137 @@
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.command.builtin import (
build_help_text,
builtin_command_palette,
cmd_model,
register_builtin_commands,
)
from nanobot.command.router import CommandContext, CommandRouter
from nanobot.config.schema import ModelPresetConfig
def _provider(default_model: str, max_tokens: int = 123) -> MagicMock:
provider = MagicMock()
provider.get_default_model.return_value = default_model
provider.generation = SimpleNamespace(
max_tokens=max_tokens,
temperature=0.1,
reasoning_effort=None,
)
return provider
def _make_loop(tmp_path) -> AgentLoop:
return AgentLoop(
bus=MessageBus(),
provider=_provider("base-model", max_tokens=123),
workspace=tmp_path,
model="base-model",
context_window_tokens=1000,
model_presets={
"default": ModelPresetConfig(
model="base-model",
max_tokens=123,
context_window_tokens=1000,
),
"fast": ModelPresetConfig(
model="openai/gpt-4.1",
max_tokens=4096,
context_window_tokens=32_768,
),
},
)
def _ctx(loop: AgentLoop, raw: str, args: str = "") -> CommandContext:
msg = InboundMessage(channel="cli", sender_id="user", chat_id="direct", content=raw)
return CommandContext(msg=msg, session=None, key=msg.session_key, raw=raw, args=args, loop=loop)
@pytest.mark.asyncio
async def test_model_command_lists_current_and_available_presets(tmp_path) -> None:
loop = _make_loop(tmp_path)
out = await cmd_model(_ctx(loop, "/model"))
assert "Current model: `base-model`" in out.content
assert "Active preset: `(none)`" in out.content
assert "`default`" in out.content
assert "`fast`" in out.content
assert out.metadata == {"render_as": "text"}
@pytest.mark.asyncio
async def test_model_command_switches_preset(tmp_path) -> None:
loop = _make_loop(tmp_path)
out = await cmd_model(_ctx(loop, "/model fast", args="fast"))
assert "Switched model preset to `fast`." in out.content
assert "Model: `openai/gpt-4.1`" in out.content
assert loop.model_preset == "fast"
assert loop.model == "openai/gpt-4.1"
assert loop.subagents.model == "openai/gpt-4.1"
assert loop.consolidator.model == "openai/gpt-4.1"
assert loop.dream.model == "openai/gpt-4.1"
@pytest.mark.asyncio
async def test_model_command_switches_back_to_default(tmp_path) -> None:
loop = _make_loop(tmp_path)
loop.set_model_preset("fast")
out = await cmd_model(_ctx(loop, "/model default", args="default"))
assert "Switched model preset to `default`." in out.content
assert loop.model_preset == "default"
assert loop.model == "base-model"
assert loop.context_window_tokens == 1000
@pytest.mark.asyncio
async def test_model_command_unknown_preset_keeps_old_state(tmp_path) -> None:
loop = _make_loop(tmp_path)
out = await cmd_model(_ctx(loop, "/model missing", args="missing"))
assert "Could not switch model preset" in out.content
assert "Available presets: `default`, `fast`" in out.content
assert loop.model_preset is None
assert loop.model == "base-model"
@pytest.mark.asyncio
async def test_model_command_does_not_depend_on_my_allow_set(tmp_path) -> None:
loop = _make_loop(tmp_path)
assert loop.tools_config.my.allow_set is False
await cmd_model(_ctx(loop, "/model fast", args="fast"))
assert loop.model_preset == "fast"
@pytest.mark.asyncio
async def test_model_command_registered_as_exact_and_prefix(tmp_path) -> None:
router = CommandRouter()
register_builtin_commands(router)
loop = _make_loop(tmp_path)
out = await router.dispatch(_ctx(loop, "/model fast"))
assert out is not None
assert "Switched model preset" in out.content
assert loop.model_preset == "fast"
def test_model_command_in_help_and_palette() -> None:
palette = builtin_command_palette()
assert any(item["command"] == "/model" and item["arg_hint"] == "[preset]" for item in palette)
assert "/model [preset]" in build_help_text()

View File

@ -22,6 +22,7 @@ class TestIsDispatchableCommand:
def test_exact_commands_match(self, router: CommandRouter) -> None:
assert router.is_dispatchable_command("/new")
assert router.is_dispatchable_command("/help")
assert router.is_dispatchable_command("/model")
assert router.is_dispatchable_command("/dream")
assert router.is_dispatchable_command("/dream-log")
assert router.is_dispatchable_command("/dream-restore")
@ -29,6 +30,7 @@ class TestIsDispatchableCommand:
def test_prefix_commands_match(self, router: CommandRouter) -> None:
assert router.is_dispatchable_command("/dream-log abc123")
assert router.is_dispatchable_command("/dream-restore def456")
assert router.is_dispatchable_command("/model fast")
def test_priority_commands_not_matched(self, router: CommandRouter) -> None:
# Priority commands are NOT in the dispatchable tiers — they are

View File

@ -1,4 +1,4 @@
from nanobot.config.schema import Config, ModelPresetConfig
from nanobot.config.schema import Config
def test_resolve_preset_returns_defaults_when_no_preset() -> None:
@ -39,6 +39,20 @@ def test_resolve_preset_returns_active_preset() -> None:
assert resolved.reasoning_effort == "low"
def test_resolve_preset_can_target_named_preset_without_activating() -> None:
config = Config.model_validate({
"model_presets": {
"fast": {"model": "openai/gpt-4.1", "provider": "openai"},
"deep": {"model": "anthropic/claude-opus-4-5", "provider": "anthropic"},
},
"agents": {"defaults": {"modelPreset": "fast"}},
})
resolved = config.resolve_preset("deep")
assert resolved.model == "anthropic/claude-opus-4-5"
assert resolved.provider == "anthropic"
def test_validator_rejects_unknown_preset() -> None:
import pytest
with pytest.raises(ValueError, match="model_preset 'unknown' not found in model_presets"):
@ -51,6 +65,12 @@ def test_validator_rejects_unknown_preset() -> None:
})
def test_resolve_preset_rejects_unknown_named_preset() -> None:
import pytest
with pytest.raises(KeyError, match="model_preset 'missing' not found"):
Config().resolve_preset("missing")
def test_match_provider_uses_preset_model() -> None:
config = Config.model_validate({
"providers": {