From c450d6fd3fbbcaa28172e49649b13505a2a3ed49 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Tue, 12 May 2026 07:55:01 +0000 Subject: [PATCH] fix(config): make model preset switching atomic Co-authored-by: Cursor --- nanobot/agent/loop.py | 44 +++++-- nanobot/command/builtin.py | 78 ++++++++++++ nanobot/config/schema.py | 48 ++++++-- nanobot/providers/factory.py | 58 ++++++--- tests/agent/test_self_model_preset.py | 70 +++++++++++ tests/command/test_model_command.py | 137 ++++++++++++++++++++++ tests/command/test_router_dispatchable.py | 2 + tests/config/test_model_presets.py | 22 +++- 8 files changed, 420 insertions(+), 39 deletions(-) create mode 100644 tests/command/test_model_command.py diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 9b97ab378..d83c8bd41 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -293,6 +293,7 @@ class AgentLoop: provider_signature: tuple[object, ...] | None = None, model_presets: dict[str, ModelPresetConfig] | None = None, model_preset: str | None = None, + model_preset_snapshot_builder: Callable[[ModelPresetConfig], ProviderSnapshot] | None = None, ): from nanobot.config.schema import ToolsConfig @@ -398,7 +399,10 @@ class AgentLoop: model=self.model, ) self.model_presets: dict[str, ModelPresetConfig] = model_presets or {} - self._active_preset: str | None = model_preset if model_presets and model_preset in model_presets else None + self._model_preset_snapshot_builder = model_preset_snapshot_builder + self._active_preset: str | None = None + if model_preset: + self.set_model_preset(model_preset) self._register_default_tools() self._runtime_vars: dict[str, Any] = {} self._current_iteration: int = 0 @@ -418,7 +422,7 @@ class AgentLoop: allowing callers to override or extend the standard config-derived parameters (e.g. ``cron_service``, ``session_manager``). """ - from nanobot.providers.factory import make_provider + from nanobot.providers.factory import build_provider_snapshot, make_provider if bus is None: bus = MessageBus() @@ -453,6 +457,7 @@ class AgentLoop: tools_config=config.tools, model_presets=model_presets, model_preset=defaults.model_preset, + model_preset_snapshot_builder=lambda preset: build_provider_snapshot(config, preset=preset), **extra, ) @@ -465,8 +470,6 @@ class AgentLoop: provider = snapshot.provider model = snapshot.model context_window_tokens = snapshot.context_window_tokens - if self.provider is provider and self.model == model: - return old_model = self.model self.provider = provider self.model = model @@ -498,15 +501,38 @@ class AgentLoop: @model_preset.setter def model_preset(self, name: str | None) -> None: - """Resolve a preset by name and apply all fields atomically.""" + self.set_model_preset(name) + + def _build_model_preset_snapshot(self, name: str) -> ProviderSnapshot: + preset = self.model_presets[name] + if self._model_preset_snapshot_builder is not None: + return self._model_preset_snapshot_builder(preset) + self.provider.generation = preset.to_generation_settings() + return ProviderSnapshot( + provider=self.provider, + model=preset.model, + context_window_tokens=preset.context_window_tokens, + signature=( + "model_preset", + name, + preset.model, + preset.provider, + preset.max_tokens, + preset.context_window_tokens, + preset.temperature, + preset.reasoning_effort, + ), + ) + + def set_model_preset(self, name: str | None) -> None: + """Resolve a preset by name and apply all runtime model dependents.""" if not isinstance(name, str) or not name.strip(): raise ValueError("model_preset must be a non-empty string") + name = name.strip() if name not in self.model_presets: raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(self.model_presets) or '(none)'}") - p = self.model_presets[name] - self.model = p.model - self.context_window_tokens = p.context_window_tokens - self.provider.generation = p.to_generation_settings() + snapshot = self._build_model_preset_snapshot(name) + self._apply_provider_snapshot(snapshot) self._active_preset = name def _register_default_tools(self) -> None: diff --git a/nanobot/command/builtin.py b/nanobot/command/builtin.py index b71a77f91..2310be181 100644 --- a/nanobot/command/builtin.py +++ b/nanobot/command/builtin.py @@ -58,6 +58,13 @@ BUILTIN_COMMAND_SPECS: tuple[BuiltinCommandSpec, ...] = ( "Display runtime, provider, and channel status.", "activity", ), + BuiltinCommandSpec( + "/model", + "Switch model preset", + "Show or switch the active model preset.", + "brain", + "[preset]", + ), BuiltinCommandSpec( "/history", "Show conversation history", @@ -192,6 +199,75 @@ async def cmd_new(ctx: CommandContext) -> OutboundMessage: ) +def _format_preset_names(names: list[str]) -> str: + return ", ".join(f"`{name}`" for name in names) if names else "(none configured)" + + +def _model_command_status(loop) -> str: + names = sorted(loop.model_presets) + active = loop.model_preset or "(none)" + return "\n".join([ + "## Model", + f"- Current model: `{loop.model}`", + f"- Active preset: `{active}`", + f"- Available presets: {_format_preset_names(names)}", + ]) + + +async def cmd_model(ctx: CommandContext) -> OutboundMessage: + """Show or switch model presets.""" + loop = ctx.loop + args = ctx.args.strip() + metadata = {**dict(ctx.msg.metadata or {}), "render_as": "text"} + + if not args: + return OutboundMessage( + channel=ctx.msg.channel, + chat_id=ctx.msg.chat_id, + content=_model_command_status(loop), + metadata=metadata, + ) + + parts = args.split() + if len(parts) != 1: + return OutboundMessage( + channel=ctx.msg.channel, + chat_id=ctx.msg.chat_id, + content="Usage: `/model [preset]`", + metadata=metadata, + ) + + name = parts[0] + try: + loop.set_model_preset(name) + except (KeyError, ValueError) as exc: + names = sorted(loop.model_presets) + return OutboundMessage( + channel=ctx.msg.channel, + chat_id=ctx.msg.chat_id, + content=( + f"Could not switch model preset: {exc}\n\n" + f"Available presets: {_format_preset_names(names)}" + ), + metadata=metadata, + ) + + max_tokens = getattr(getattr(loop.provider, "generation", None), "max_tokens", None) + lines = [ + f"Switched model preset to `{loop.model_preset}`.", + f"- Model: `{loop.model}`", + f"- Context window: {loop.context_window_tokens}", + ] + if max_tokens is not None: + lines.append(f"- Max output tokens: {max_tokens}") + return OutboundMessage( + channel=ctx.msg.channel, + chat_id=ctx.msg.chat_id, + content="\n".join(lines), + metadata=metadata, + ) + + async def cmd_dream(ctx: CommandContext) -> OutboundMessage: """Manually trigger a Dream consolidation run.""" import time @@ -477,6 +553,8 @@ def register_builtin_commands(router: CommandRouter) -> None: router.priority("/status", cmd_status) router.exact("/new", cmd_new) router.exact("/status", cmd_status) + router.exact("/model", cmd_model) + router.prefix("/model ", cmd_model) router.exact("/history", cmd_history) router.prefix("/history ", cmd_history) router.exact("/dream", cmd_dream) diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index b688c820e..3d1bb9e0a 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -283,10 +283,12 @@ class Config(BaseSettings): raise ValueError(f"model_preset {name!r} not found in model_presets") return self - def resolve_preset(self) -> ModelPresetConfig: + def resolve_preset(self, name: str | None = None) -> ModelPresetConfig: """Return effective model params: from active preset, or individual defaults.""" - name = self.agents.defaults.model_preset + name = self.agents.defaults.model_preset if name is None else name if name: + if name not in self.model_presets: + raise KeyError(f"model_preset {name!r} not found in model_presets") return self.model_presets[name] d = self.agents.defaults return ModelPresetConfig( @@ -301,12 +303,14 @@ class Config(BaseSettings): return Path(self.agents.defaults.workspace).expanduser() def _match_provider( - self, model: str | None = None + self, model: str | None = None, + *, + preset: ModelPresetConfig | None = None, ) -> tuple["ProviderConfig | None", str | None]: """Match provider config and its registry name. Returns (config, spec_name).""" from nanobot.providers.registry import PROVIDERS, find_by_name - resolved = self.resolve_preset() + resolved = preset or self.resolve_preset() forced = resolved.provider if forced != "auto": spec = find_by_name(forced) @@ -366,26 +370,46 @@ class Config(BaseSettings): return p, spec.name return None, None - def get_provider(self, model: str | None = None) -> ProviderConfig | None: + def get_provider( + self, + model: str | None = None, + *, + preset: ModelPresetConfig | None = None, + ) -> ProviderConfig | None: """Get matched provider config (api_key, api_base, extra_headers). Falls back to first available.""" - p, _ = self._match_provider(model) + p, _ = self._match_provider(model, preset=preset) return p - def get_provider_name(self, model: str | None = None) -> str | None: + def get_provider_name( + self, + model: str | None = None, + *, + preset: ModelPresetConfig | None = None, + ) -> str | None: """Get the registry name of the matched provider (e.g. "deepseek", "openrouter").""" - _, name = self._match_provider(model) + _, name = self._match_provider(model, preset=preset) return name - def get_api_key(self, model: str | None = None) -> str | None: + def get_api_key( + self, + model: str | None = None, + *, + preset: ModelPresetConfig | None = None, + ) -> str | None: """Get API key for the given model. Falls back to first available key.""" - p = self.get_provider(model) + p = self.get_provider(model, preset=preset) return p.api_key if p else None - def get_api_base(self, model: str | None = None) -> str | None: + def get_api_base( + self, + model: str | None = None, + *, + preset: ModelPresetConfig | None = None, + ) -> str | None: """Get API base URL for the given model, falling back to the provider default when present.""" from nanobot.providers.registry import find_by_name - p, name = self._match_provider(model) + p, name = self._match_provider(model, preset=preset) if p and p.api_base: return p.api_base if name: diff --git a/nanobot/providers/factory.py b/nanobot/providers/factory.py index 1257eb3a5..6422f047f 100644 --- a/nanobot/providers/factory.py +++ b/nanobot/providers/factory.py @@ -5,7 +5,7 @@ from __future__ import annotations from dataclasses import dataclass from pathlib import Path -from nanobot.config.schema import Config +from nanobot.config.schema import Config, ModelPresetConfig from nanobot.providers.base import LLMProvider from nanobot.providers.registry import find_by_name @@ -18,12 +18,26 @@ class ProviderSnapshot: signature: tuple[object, ...] -def make_provider(config: Config) -> LLMProvider: +def _resolve_model_preset( + config: Config, + *, + preset_name: str | None = None, + preset: ModelPresetConfig | None = None, +) -> ModelPresetConfig: + return preset if preset is not None else config.resolve_preset(preset_name) + + +def make_provider( + config: Config, + *, + preset_name: str | None = None, + preset: ModelPresetConfig | None = None, +) -> LLMProvider: """Create the LLM provider implied by config.""" - resolved = config.resolve_preset() + resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset) model = resolved.model - provider_name = config.get_provider_name(model) - p = config.get_provider(model) + provider_name = config.get_provider_name(model, preset=resolved) + p = config.get_provider(model, preset=resolved) spec = find_by_name(provider_name) if provider_name else None backend = spec.backend if spec else "openai_compat" @@ -57,7 +71,7 @@ def make_provider(config: Config) -> LLMProvider: provider = AnthropicProvider( api_key=p.api_key if p else None, - api_base=config.get_api_base(model), + api_base=config.get_api_base(model, preset=resolved), default_model=model, extra_headers=p.extra_headers if p else None, ) @@ -77,7 +91,7 @@ def make_provider(config: Config) -> LLMProvider: provider = OpenAICompatProvider( api_key=p.api_key if p else None, - api_base=config.get_api_base(model), + api_base=config.get_api_base(model, preset=resolved), default_model=model, extra_headers=p.extra_headers if p else None, spec=spec, @@ -88,16 +102,21 @@ def make_provider(config: Config) -> LLMProvider: return provider -def provider_signature(config: Config) -> tuple[object, ...]: +def provider_signature( + config: Config, + *, + preset_name: str | None = None, + preset: ModelPresetConfig | None = None, +) -> tuple[object, ...]: """Return the config fields that affect the primary LLM provider.""" - resolved = config.resolve_preset() - p = config.get_provider(resolved.model) + resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset) + p = config.get_provider(resolved.model, preset=resolved) return ( resolved.model, resolved.provider, - config.get_provider_name(resolved.model), - config.get_api_key(resolved.model), - config.get_api_base(resolved.model), + config.get_provider_name(resolved.model, preset=resolved), + config.get_api_key(resolved.model, preset=resolved), + config.get_api_base(resolved.model, preset=resolved), p.extra_headers if p else None, p.extra_body if p else None, getattr(p, "region", None) if p else None, @@ -109,13 +128,18 @@ def provider_signature(config: Config) -> tuple[object, ...]: ) -def build_provider_snapshot(config: Config) -> ProviderSnapshot: - resolved = config.resolve_preset() +def build_provider_snapshot( + config: Config, + *, + preset_name: str | None = None, + preset: ModelPresetConfig | None = None, +) -> ProviderSnapshot: + resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset) return ProviderSnapshot( - provider=make_provider(config), + provider=make_provider(config, preset=resolved), model=resolved.model, context_window_tokens=resolved.context_window_tokens, - signature=provider_signature(config), + signature=provider_signature(config, preset=resolved), ) diff --git a/tests/agent/test_self_model_preset.py b/tests/agent/test_self_model_preset.py index fa81ab8e6..b41b3581b 100644 --- a/tests/agent/test_self_model_preset.py +++ b/tests/agent/test_self_model_preset.py @@ -7,6 +7,7 @@ from nanobot.agent.loop import AgentLoop from nanobot.agent.tools.self import MyTool from nanobot.bus.queue import MessageBus from nanobot.config.schema import ModelPresetConfig +from nanobot.providers.factory import ProviderSnapshot def _provider(default_model: str, max_tokens: int = 123) -> MagicMock: @@ -56,6 +57,75 @@ def test_model_preset_setter_updates_state(tmp_path) -> None: assert loop.provider.generation.temperature == 0.5 assert loop.provider.generation.max_tokens == 4096 assert loop.provider.generation.reasoning_effort == "low" + assert loop.subagents.model == "openai/gpt-4.1" + assert loop.consolidator.model == "openai/gpt-4.1" + assert loop.consolidator.context_window_tokens == 32_768 + assert loop.consolidator.max_completion_tokens == 4096 + assert loop.dream.model == "openai/gpt-4.1" + + +def test_model_preset_setter_replaces_provider_from_snapshot(tmp_path) -> None: + old_provider = _provider("base-model", max_tokens=123) + new_provider = _provider("anthropic/claude-opus-4-5", max_tokens=2048) + preset = ModelPresetConfig( + model="anthropic/claude-opus-4-5", + provider="anthropic", + max_tokens=2048, + context_window_tokens=200_000, + ) + loop = AgentLoop( + bus=MessageBus(), + provider=old_provider, + workspace=tmp_path, + model="base-model", + context_window_tokens=1000, + model_presets={"deep": preset}, + model_preset_snapshot_builder=lambda _preset: ProviderSnapshot( + provider=new_provider, + model=_preset.model, + context_window_tokens=_preset.context_window_tokens, + signature=("deep", _preset.model), + ), + ) + + loop.set_model_preset("deep") + + assert loop.provider is new_provider + assert loop.runner.provider is new_provider + assert loop.subagents.provider is new_provider + assert loop.subagents.runner.provider is new_provider + assert loop.consolidator.provider is new_provider + assert loop.dream.provider is new_provider + assert loop.dream._runner.provider is new_provider + assert loop.model == "anthropic/claude-opus-4-5" + assert loop.context_window_tokens == 200_000 + assert loop.consolidator.max_completion_tokens == 2048 + + +def test_model_preset_setter_failure_leaves_old_state(tmp_path) -> None: + preset = ModelPresetConfig(model="openai/gpt-4.1", max_tokens=4096) + loop = AgentLoop( + bus=MessageBus(), + provider=_provider("base-model", max_tokens=123), + workspace=tmp_path, + model="base-model", + context_window_tokens=1000, + model_presets={"fast": preset}, + model_preset_snapshot_builder=lambda _preset: (_ for _ in ()).throw( + RuntimeError("provider unavailable") + ), + ) + + with pytest.raises(RuntimeError, match="provider unavailable"): + loop.set_model_preset("fast") + + assert loop.model_preset is None + assert loop.model == "base-model" + assert loop.subagents.model == "base-model" + assert loop.consolidator.model == "base-model" + assert loop.dream.model == "base-model" + assert loop.context_window_tokens == 1000 + assert loop.consolidator.max_completion_tokens == 123 def test_model_preset_setter_raises_on_unknown(tmp_path) -> None: diff --git a/tests/command/test_model_command.py b/tests/command/test_model_command.py new file mode 100644 index 000000000..f81fb0226 --- /dev/null +++ b/tests/command/test_model_command.py @@ -0,0 +1,137 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from nanobot.agent.loop import AgentLoop +from nanobot.bus.events import InboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.command.builtin import ( + build_help_text, + builtin_command_palette, + cmd_model, + register_builtin_commands, +) +from nanobot.command.router import CommandContext, CommandRouter +from nanobot.config.schema import ModelPresetConfig + + +def _provider(default_model: str, max_tokens: int = 123) -> MagicMock: + provider = MagicMock() + provider.get_default_model.return_value = default_model + provider.generation = SimpleNamespace( + max_tokens=max_tokens, + temperature=0.1, + reasoning_effort=None, + ) + return provider + + +def _make_loop(tmp_path) -> AgentLoop: + return AgentLoop( + bus=MessageBus(), + provider=_provider("base-model", max_tokens=123), + workspace=tmp_path, + model="base-model", + context_window_tokens=1000, + model_presets={ + "default": ModelPresetConfig( + model="base-model", + max_tokens=123, + context_window_tokens=1000, + ), + "fast": ModelPresetConfig( + model="openai/gpt-4.1", + max_tokens=4096, + context_window_tokens=32_768, + ), + }, + ) + + +def _ctx(loop: AgentLoop, raw: str, args: str = "") -> CommandContext: + msg = InboundMessage(channel="cli", sender_id="user", chat_id="direct", content=raw) + return CommandContext(msg=msg, session=None, key=msg.session_key, raw=raw, args=args, loop=loop) + + +@pytest.mark.asyncio +async def test_model_command_lists_current_and_available_presets(tmp_path) -> None: + loop = _make_loop(tmp_path) + + out = await cmd_model(_ctx(loop, "/model")) + + assert "Current model: `base-model`" in out.content + assert "Active preset: `(none)`" in out.content + assert "`default`" in out.content + assert "`fast`" in out.content + assert out.metadata == {"render_as": "text"} + + +@pytest.mark.asyncio +async def test_model_command_switches_preset(tmp_path) -> None: + loop = _make_loop(tmp_path) + + out = await cmd_model(_ctx(loop, "/model fast", args="fast")) + + assert "Switched model preset to `fast`." in out.content + assert "Model: `openai/gpt-4.1`" in out.content + assert loop.model_preset == "fast" + assert loop.model == "openai/gpt-4.1" + assert loop.subagents.model == "openai/gpt-4.1" + assert loop.consolidator.model == "openai/gpt-4.1" + assert loop.dream.model == "openai/gpt-4.1" + + +@pytest.mark.asyncio +async def test_model_command_switches_back_to_default(tmp_path) -> None: + loop = _make_loop(tmp_path) + loop.set_model_preset("fast") + + out = await cmd_model(_ctx(loop, "/model default", args="default")) + + assert "Switched model preset to `default`." in out.content + assert loop.model_preset == "default" + assert loop.model == "base-model" + assert loop.context_window_tokens == 1000 + + +@pytest.mark.asyncio +async def test_model_command_unknown_preset_keeps_old_state(tmp_path) -> None: + loop = _make_loop(tmp_path) + + out = await cmd_model(_ctx(loop, "/model missing", args="missing")) + + assert "Could not switch model preset" in out.content + assert "Available presets: `default`, `fast`" in out.content + assert loop.model_preset is None + assert loop.model == "base-model" + + +@pytest.mark.asyncio +async def test_model_command_does_not_depend_on_my_allow_set(tmp_path) -> None: + loop = _make_loop(tmp_path) + assert loop.tools_config.my.allow_set is False + + await cmd_model(_ctx(loop, "/model fast", args="fast")) + + assert loop.model_preset == "fast" + + +@pytest.mark.asyncio +async def test_model_command_registered_as_exact_and_prefix(tmp_path) -> None: + router = CommandRouter() + register_builtin_commands(router) + loop = _make_loop(tmp_path) + + out = await router.dispatch(_ctx(loop, "/model fast")) + + assert out is not None + assert "Switched model preset" in out.content + assert loop.model_preset == "fast" + + +def test_model_command_in_help_and_palette() -> None: + palette = builtin_command_palette() + + assert any(item["command"] == "/model" and item["arg_hint"] == "[preset]" for item in palette) + assert "/model [preset]" in build_help_text() diff --git a/tests/command/test_router_dispatchable.py b/tests/command/test_router_dispatchable.py index 3be684072..0157f2a90 100644 --- a/tests/command/test_router_dispatchable.py +++ b/tests/command/test_router_dispatchable.py @@ -22,6 +22,7 @@ class TestIsDispatchableCommand: def test_exact_commands_match(self, router: CommandRouter) -> None: assert router.is_dispatchable_command("/new") assert router.is_dispatchable_command("/help") + assert router.is_dispatchable_command("/model") assert router.is_dispatchable_command("/dream") assert router.is_dispatchable_command("/dream-log") assert router.is_dispatchable_command("/dream-restore") @@ -29,6 +30,7 @@ class TestIsDispatchableCommand: def test_prefix_commands_match(self, router: CommandRouter) -> None: assert router.is_dispatchable_command("/dream-log abc123") assert router.is_dispatchable_command("/dream-restore def456") + assert router.is_dispatchable_command("/model fast") def test_priority_commands_not_matched(self, router: CommandRouter) -> None: # Priority commands are NOT in the dispatchable tiers — they are diff --git a/tests/config/test_model_presets.py b/tests/config/test_model_presets.py index 44713acb6..581202b7b 100644 --- a/tests/config/test_model_presets.py +++ b/tests/config/test_model_presets.py @@ -1,4 +1,4 @@ -from nanobot.config.schema import Config, ModelPresetConfig +from nanobot.config.schema import Config def test_resolve_preset_returns_defaults_when_no_preset() -> None: @@ -39,6 +39,20 @@ def test_resolve_preset_returns_active_preset() -> None: assert resolved.reasoning_effort == "low" +def test_resolve_preset_can_target_named_preset_without_activating() -> None: + config = Config.model_validate({ + "model_presets": { + "fast": {"model": "openai/gpt-4.1", "provider": "openai"}, + "deep": {"model": "anthropic/claude-opus-4-5", "provider": "anthropic"}, + }, + "agents": {"defaults": {"modelPreset": "fast"}}, + }) + + resolved = config.resolve_preset("deep") + assert resolved.model == "anthropic/claude-opus-4-5" + assert resolved.provider == "anthropic" + + def test_validator_rejects_unknown_preset() -> None: import pytest with pytest.raises(ValueError, match="model_preset 'unknown' not found in model_presets"): @@ -51,6 +65,12 @@ def test_validator_rejects_unknown_preset() -> None: }) +def test_resolve_preset_rejects_unknown_named_preset() -> None: + import pytest + with pytest.raises(KeyError, match="model_preset 'missing' not found"): + Config().resolve_preset("missing") + + def test_match_provider_uses_preset_model() -> None: config = Config.model_validate({ "providers": {