diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index d83c8bd41..e44cf1c2e 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -293,7 +293,7 @@ class AgentLoop: provider_signature: tuple[object, ...] | None = None, model_presets: dict[str, ModelPresetConfig] | None = None, model_preset: str | None = None, - model_preset_snapshot_builder: Callable[[ModelPresetConfig], ProviderSnapshot] | None = None, + model_preset_snapshot_builder: Callable[[str], ProviderSnapshot] | None = None, ): from nanobot.config.schema import ToolsConfig @@ -304,6 +304,10 @@ class AgentLoop: self.provider = provider self._provider_snapshot_loader = provider_snapshot_loader self._provider_signature = provider_signature + self._config_provider_signature = provider_signature + self._config_default_selection_signature = ( + provider_signature[:2] if provider_signature is not None else None + ) self.workspace = workspace self.model = model or provider.get_default_model() self.max_iterations = ( @@ -431,6 +435,7 @@ class AgentLoop: resolved = config.resolve_preset() model = extra.pop("model", None) or resolved.model context_window_tokens = extra.pop("context_window_tokens", None) or resolved.context_window_tokens + model_preset_snapshot_builder = extra.pop("model_preset_snapshot_builder", None) model_presets = dict(config.model_presets) if "default" not in model_presets: model_presets["default"] = resolved @@ -457,7 +462,10 @@ class AgentLoop: tools_config=config.tools, model_presets=model_presets, model_preset=defaults.model_preset, - model_preset_snapshot_builder=lambda preset: build_provider_snapshot(config, preset=preset), + model_preset_snapshot_builder=( + model_preset_snapshot_builder + or (lambda name: build_provider_snapshot(config, preset_name=name)) + ), **extra, ) @@ -489,8 +497,32 @@ class AgentLoop: except Exception: logger.exception("Failed to refresh provider config") return + if self._active_preset: + default_selection = snapshot.signature[:2] + if ( + self._config_default_selection_signature is not None + and default_selection != self._config_default_selection_signature + ): + self._active_preset = None + self._config_provider_signature = snapshot.signature + self._config_default_selection_signature = default_selection + self._apply_provider_snapshot(snapshot) + return + self._config_provider_signature = snapshot.signature + self._config_default_selection_signature = default_selection + try: + snapshot = self._build_model_preset_snapshot(self._active_preset) + except Exception: + logger.exception("Failed to refresh active model preset") + return + if snapshot.signature == self._provider_signature: + return + self._apply_provider_snapshot(snapshot) + return if snapshot.signature == self._provider_signature: return + self._config_provider_signature = snapshot.signature + self._config_default_selection_signature = snapshot.signature[:2] self._apply_provider_snapshot(snapshot) # -- model_preset property -- @@ -506,7 +538,7 @@ class AgentLoop: def _build_model_preset_snapshot(self, name: str) -> ProviderSnapshot: preset = self.model_presets[name] if self._model_preset_snapshot_builder is not None: - return self._model_preset_snapshot_builder(preset) + return self._model_preset_snapshot_builder(name) self.provider.generation = preset.to_generation_settings() return ProviderSnapshot( provider=self.provider, diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index da829f62e..48f800cf1 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -672,6 +672,7 @@ def _run_gateway( "aihubmix": config.providers.aihubmix, }, provider_snapshot_loader=load_provider_snapshot, + model_preset_snapshot_builder=lambda name: load_provider_snapshot(preset_name=name), provider_signature=provider_snapshot.signature, ) diff --git a/nanobot/providers/factory.py b/nanobot/providers/factory.py index 6422f047f..3473afff3 100644 --- a/nanobot/providers/factory.py +++ b/nanobot/providers/factory.py @@ -143,7 +143,14 @@ def build_provider_snapshot( ) -def load_provider_snapshot(config_path: Path | None = None) -> ProviderSnapshot: +def load_provider_snapshot( + config_path: Path | None = None, + *, + preset_name: str | None = None, +) -> ProviderSnapshot: from nanobot.config.loader import load_config, resolve_config_env_vars - return build_provider_snapshot(resolve_config_env_vars(load_config(config_path))) + return build_provider_snapshot( + resolve_config_env_vars(load_config(config_path)), + preset_name=preset_name, + ) diff --git a/tests/agent/test_self_model_preset.py b/tests/agent/test_self_model_preset.py index b41b3581b..45fa0db36 100644 --- a/tests/agent/test_self_model_preset.py +++ b/tests/agent/test_self_model_preset.py @@ -80,11 +80,11 @@ def test_model_preset_setter_replaces_provider_from_snapshot(tmp_path) -> None: model="base-model", context_window_tokens=1000, model_presets={"deep": preset}, - model_preset_snapshot_builder=lambda _preset: ProviderSnapshot( + model_preset_snapshot_builder=lambda _name: ProviderSnapshot( provider=new_provider, - model=_preset.model, - context_window_tokens=_preset.context_window_tokens, - signature=("deep", _preset.model), + model=preset.model, + context_window_tokens=preset.context_window_tokens, + signature=("deep", preset.model), ), ) @@ -111,7 +111,7 @@ def test_model_preset_setter_failure_leaves_old_state(tmp_path) -> None: model="base-model", context_window_tokens=1000, model_presets={"fast": preset}, - model_preset_snapshot_builder=lambda _preset: (_ for _ in ()).throw( + model_preset_snapshot_builder=lambda _name: (_ for _ in ()).throw( RuntimeError("provider unavailable") ), ) @@ -128,6 +128,78 @@ def test_model_preset_setter_failure_leaves_old_state(tmp_path) -> None: assert loop.consolidator.max_completion_tokens == 123 +def test_active_model_preset_survives_unchanged_config_refresh(tmp_path) -> None: + base_provider = _provider("base-model", max_tokens=123) + fast_provider = _provider("openai/gpt-4.1", max_tokens=4096) + default_snapshot = ProviderSnapshot( + provider=base_provider, + model="base-model", + context_window_tokens=1000, + signature=("base-model", "auto", "openai", "sk-old"), + ) + fast_snapshot = ProviderSnapshot( + provider=fast_provider, + model="openai/gpt-4.1", + context_window_tokens=32_768, + signature=("openai/gpt-4.1", "auto", "openai", "sk-old"), + ) + loop = AgentLoop( + bus=MessageBus(), + provider=base_provider, + workspace=tmp_path, + model="base-model", + context_window_tokens=1000, + provider_snapshot_loader=lambda: default_snapshot, + provider_signature=default_snapshot.signature, + model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")}, + model_preset_snapshot_builder=lambda _name: fast_snapshot, + ) + + loop.set_model_preset("fast") + loop._refresh_provider_snapshot() + + assert loop.model_preset == "fast" + assert loop.provider is fast_provider + assert loop.model == "openai/gpt-4.1" + + +def test_config_model_refresh_clears_active_model_preset(tmp_path) -> None: + base_provider = _provider("base-model", max_tokens=123) + fast_provider = _provider("openai/gpt-4.1", max_tokens=4096) + webui_provider = _provider("anthropic/claude-opus-4-5", max_tokens=2048) + webui_snapshot = ProviderSnapshot( + provider=webui_provider, + model="anthropic/claude-opus-4-5", + context_window_tokens=200_000, + signature=("anthropic/claude-opus-4-5", "anthropic", "anthropic", "sk-old"), + ) + fast_snapshot = ProviderSnapshot( + provider=fast_provider, + model="openai/gpt-4.1", + context_window_tokens=32_768, + signature=("openai/gpt-4.1", "auto", "openai", "sk-old"), + ) + loop = AgentLoop( + bus=MessageBus(), + provider=base_provider, + workspace=tmp_path, + model="base-model", + context_window_tokens=1000, + provider_snapshot_loader=lambda: webui_snapshot, + provider_signature=("base-model", "auto", "openai", "sk-old"), + model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")}, + model_preset_snapshot_builder=lambda _name: fast_snapshot, + ) + + loop.set_model_preset("fast") + loop._refresh_provider_snapshot() + + assert loop.model_preset is None + assert loop.provider is webui_provider + assert loop.model == "anthropic/claude-opus-4-5" + assert loop.context_window_tokens == 200_000 + + def test_model_preset_setter_raises_on_unknown(tmp_path) -> None: loop = _make_loop(tmp_path) with pytest.raises(KeyError, match="model_preset 'missing' not found"):