diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index e7753df51..86d4684b0 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -289,11 +289,10 @@ class AgentLoop: tools_config: ToolsConfig | None = None, image_generation_provider_config: ProviderConfig | None = None, image_generation_provider_configs: dict[str, ProviderConfig] | None = None, - provider_snapshot_loader: Callable[[], ProviderSnapshot] | None = None, + provider_snapshot_loader: Callable[..., ProviderSnapshot] | None = None, provider_signature: tuple[object, ...] | None = None, model_presets: dict[str, ModelPresetConfig] | None = None, model_preset: str | None = None, - model_preset_snapshot_builder: Callable[[str], ProviderSnapshot] | None = None, ): from nanobot.config.schema import ToolsConfig @@ -304,10 +303,7 @@ class AgentLoop: self.provider = provider self._provider_snapshot_loader = provider_snapshot_loader self._provider_signature = provider_signature - self._config_provider_signature = provider_signature - self._config_default_selection_signature = ( - provider_signature[:2] if provider_signature is not None else None - ) + self._default_selection_signature = provider_signature[:2] if provider_signature else None self.workspace = workspace self.model = model or provider.get_default_model() self.max_iterations = ( @@ -403,7 +399,6 @@ class AgentLoop: model=self.model, ) self.model_presets: dict[str, ModelPresetConfig] = model_presets or {} - self._model_preset_snapshot_builder = model_preset_snapshot_builder self._active_preset: str | None = None if model_preset: self.set_model_preset(model_preset, notify=False) @@ -435,9 +430,8 @@ class AgentLoop: resolved = config.resolve_preset() model = extra.pop("model", None) or resolved.model context_window_tokens = extra.pop("context_window_tokens", None) or resolved.context_window_tokens - model_preset_snapshot_builder = extra.pop("model_preset_snapshot_builder", None) - model_presets = dict(config.model_presets) - model_presets["default"] = config.resolve_default_preset() + provider_snapshot_loader = extra.pop("provider_snapshot_loader", None) + model_presets = {**config.model_presets, "default": config.resolve_default_preset()} return cls( bus=bus, provider=provider, @@ -461,9 +455,8 @@ class AgentLoop: tools_config=config.tools, model_presets=model_presets, model_preset=defaults.model_preset, - model_preset_snapshot_builder=( - model_preset_snapshot_builder - or (lambda name: build_provider_snapshot(config, preset_name=name)) + provider_snapshot_loader=provider_snapshot_loader or ( + lambda preset_name=None: build_provider_snapshot(config, preset_name=preset_name) ), **extra, ) @@ -475,14 +468,8 @@ class AgentLoop: def _publish_runtime_model_updated(self, model_preset: str | None = None) -> None: """Notify WebUI clients that the effective runtime model changed.""" self.bus.outbound.put_nowait(OutboundMessage( - channel="websocket", - chat_id="*", - content="", - metadata={ - "_runtime_model_updated": True, - "model": self.model, - "model_preset": model_preset if model_preset is not None else self.model_preset, - }, + channel="websocket", chat_id="*", content="", + metadata={"_runtime_model_updated": True, "model": self.model, "model_preset": model_preset if model_preset is not None else self.model_preset}, )) def _apply_provider_snapshot( @@ -517,36 +504,22 @@ class AgentLoop: except Exception: logger.exception("Failed to refresh provider config") return - if self._active_preset: - default_selection = snapshot.signature[:2] - if ( - self._config_default_selection_signature is not None - and default_selection != self._config_default_selection_signature - ): - self._active_preset = None - self._config_provider_signature = snapshot.signature - self._config_default_selection_signature = default_selection - self._apply_provider_snapshot(snapshot) - return - self._config_provider_signature = snapshot.signature - self._config_default_selection_signature = default_selection + default_selection = snapshot.signature[:2] + if self._active_preset and self._default_selection_signature in (None, default_selection): + self._default_selection_signature = default_selection try: snapshot = self._build_model_preset_snapshot(self._active_preset) except Exception: logger.exception("Failed to refresh active model preset") return - if snapshot.signature == self._provider_signature: - return - self._apply_provider_snapshot(snapshot) - return + else: + self._active_preset = None + self._default_selection_signature = default_selection if snapshot.signature == self._provider_signature: return - self._config_provider_signature = snapshot.signature - self._config_default_selection_signature = snapshot.signature[:2] + self._default_selection_signature = snapshot.signature[:2] self._apply_provider_snapshot(snapshot) - # -- model_preset property -- - @property def model_preset(self) -> str | None: return self._active_preset @@ -557,23 +530,14 @@ class AgentLoop: def _build_model_preset_snapshot(self, name: str) -> ProviderSnapshot: preset = self.model_presets[name] - if self._model_preset_snapshot_builder is not None: - return self._model_preset_snapshot_builder(name) + if self._provider_snapshot_loader is not None: + return self._provider_snapshot_loader(preset_name=name) self.provider.generation = preset.to_generation_settings() return ProviderSnapshot( provider=self.provider, model=preset.model, context_window_tokens=preset.context_window_tokens, - signature=( - "model_preset", - name, - preset.model, - preset.provider, - preset.max_tokens, - preset.context_window_tokens, - preset.temperature, - preset.reasoning_effort, - ), + signature=("model_preset", name, preset.model_dump_json()), ) def set_model_preset(self, name: str | None, *, notify: bool = True) -> None: diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 48f800cf1..da829f62e 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -672,7 +672,6 @@ def _run_gateway( "aihubmix": config.providers.aihubmix, }, provider_snapshot_loader=load_provider_snapshot, - model_preset_snapshot_builder=lambda name: load_provider_snapshot(preset_name=name), provider_signature=provider_snapshot.signature, ) diff --git a/tests/agent/test_self_model_preset.py b/tests/agent/test_self_model_preset.py index a996d75f2..587e6359c 100644 --- a/tests/agent/test_self_model_preset.py +++ b/tests/agent/test_self_model_preset.py @@ -104,11 +104,11 @@ def test_model_preset_setter_replaces_provider_from_snapshot(tmp_path) -> None: model="base-model", context_window_tokens=1000, model_presets={"deep": preset}, - model_preset_snapshot_builder=lambda _name: ProviderSnapshot( + provider_snapshot_loader=lambda preset_name=None: ProviderSnapshot( provider=new_provider, model=preset.model, context_window_tokens=preset.context_window_tokens, - signature=("deep", preset.model), + signature=(preset_name, preset.model), ), ) @@ -135,7 +135,7 @@ def test_model_preset_setter_failure_leaves_old_state(tmp_path) -> None: model="base-model", context_window_tokens=1000, model_presets={"fast": preset}, - model_preset_snapshot_builder=lambda _name: (_ for _ in ()).throw( + provider_snapshot_loader=lambda preset_name=None: (_ for _ in ()).throw( RuntimeError("provider unavailable") ), ) @@ -173,10 +173,11 @@ def test_active_model_preset_survives_unchanged_config_refresh(tmp_path) -> None workspace=tmp_path, model="base-model", context_window_tokens=1000, - provider_snapshot_loader=lambda: default_snapshot, provider_signature=default_snapshot.signature, model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")}, - model_preset_snapshot_builder=lambda _name: fast_snapshot, + provider_snapshot_loader=lambda preset_name=None: ( + fast_snapshot if preset_name == "fast" else default_snapshot + ), ) loop.set_model_preset("fast") @@ -209,10 +210,11 @@ def test_config_model_refresh_clears_active_model_preset(tmp_path) -> None: workspace=tmp_path, model="base-model", context_window_tokens=1000, - provider_snapshot_loader=lambda: webui_snapshot, + provider_snapshot_loader=lambda preset_name=None: ( + fast_snapshot if preset_name == "fast" else webui_snapshot + ), provider_signature=("base-model", "auto", "openai", "sk-old"), model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")}, - model_preset_snapshot_builder=lambda _name: fast_snapshot, ) loop.set_model_preset("fast")