diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index a40928741..daebb22d2 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -19,6 +19,7 @@ from nanobot.agent.autocompact import AutoCompact from nanobot.agent.context import ContextBuilder from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook from nanobot.agent.memory import Consolidator, Dream +from nanobot.agent import model_presets as preset_helpers from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec from nanobot.agent.subagent import SubagentManager from nanobot.agent.tools.ask import ( @@ -293,7 +294,7 @@ class AgentLoop: provider_signature: tuple[object, ...] | None = None, model_presets: dict[str, ModelPresetConfig] | None = None, model_preset: str | None = None, - preset_snapshot_loader: Callable[[str], ProviderSnapshot] | None = None, + preset_snapshot_loader: preset_helpers.PresetSnapshotLoader | None = None, ): from nanobot.config.schema import ToolsConfig @@ -305,7 +306,7 @@ class AgentLoop: self._provider_snapshot_loader = provider_snapshot_loader self._preset_snapshot_loader = preset_snapshot_loader self._provider_signature = provider_signature - self._default_selection_signature = provider_signature[:2] if provider_signature else None + self._default_selection_signature = preset_helpers.default_selection_signature(provider_signature) self.workspace = workspace self.model = model or provider.get_default_model() self.max_iterations = ( @@ -423,7 +424,7 @@ class AgentLoop: allowing callers to override or extend the standard config-derived parameters (e.g. ``cron_service``, ``session_manager``). """ - from nanobot.providers.factory import build_provider_snapshot, make_provider + from nanobot.providers.factory import make_provider if bus is None: bus = MessageBus() @@ -433,16 +434,10 @@ class AgentLoop: model = extra.pop("model", None) or resolved.model context_window_tokens = extra.pop("context_window_tokens", None) or resolved.context_window_tokens provider_snapshot_loader = extra.pop("provider_snapshot_loader", None) - preset_snapshot_loader = extra.pop("preset_snapshot_loader", None) - model_presets = {**config.model_presets, "default": config.resolve_default_preset()} - if preset_snapshot_loader is None: - if provider_snapshot_loader is not None: - preset_snapshot_loader = lambda name: provider_snapshot_loader(preset_name=name) - else: - preset_snapshot_loader = lambda name: build_provider_snapshot( - config, - preset_name=name, - ) + preset_snapshot_loader = extra.pop("preset_snapshot_loader", None) or preset_helpers.make_preset_snapshot_loader( + config, + provider_snapshot_loader, + ) return cls( bus=bus, provider=provider, @@ -464,7 +459,7 @@ class AgentLoop: consolidation_ratio=defaults.consolidation_ratio, max_messages=defaults.max_messages, tools_config=config.tools, - model_presets=model_presets, + model_presets=preset_helpers.configured_model_presets(config), model_preset=defaults.model_preset, provider_snapshot_loader=provider_snapshot_loader, preset_snapshot_loader=preset_snapshot_loader, @@ -475,19 +470,6 @@ class AgentLoop: """Keep subagent runtime limits aligned with mutable loop settings.""" self.subagents.max_iterations = self.max_iterations - def _publish_runtime_model_updated(self, model_preset: str | None = None) -> None: - """Notify WebUI clients that the effective runtime model changed.""" - self.bus.outbound.put_nowait(OutboundMessage( - channel="websocket", - chat_id="*", - content="", - metadata={ - "_runtime_model_updated": True, - "model": self.model, - "model_preset": model_preset if model_preset is not None else self.model_preset, - }, - )) - def _apply_provider_snapshot( self, snapshot: ProviderSnapshot, @@ -509,7 +491,12 @@ class AgentLoop: self.dream.set_provider(provider, model) self._provider_signature = snapshot.signature if notify: - self._publish_runtime_model_updated(model_preset) + self.bus.outbound.put_nowait( + preset_helpers.runtime_model_updated_message( + self.model, + model_preset if model_preset is not None else self.model_preset, + ) + ) logger.info("Runtime model switched for next turn: {} -> {}", old_model, model) def _refresh_provider_snapshot(self) -> None: @@ -520,7 +507,7 @@ class AgentLoop: except Exception: logger.exception("Failed to refresh provider config") return - default_selection = snapshot.signature[:2] + default_selection = preset_helpers.default_selection_signature(snapshot.signature) if self._active_preset and self._default_selection_signature in (None, default_selection): self._default_selection_signature = default_selection try: @@ -533,7 +520,7 @@ class AgentLoop: self._default_selection_signature = default_selection if snapshot.signature == self._provider_signature: return - self._default_selection_signature = snapshot.signature[:2] + self._default_selection_signature = preset_helpers.default_selection_signature(snapshot.signature) self._apply_provider_snapshot(snapshot) @property @@ -545,24 +532,16 @@ class AgentLoop: self.set_model_preset(name) def _build_model_preset_snapshot(self, name: str) -> ProviderSnapshot: - preset = self.model_presets[name] - if self._preset_snapshot_loader is not None: - return self._preset_snapshot_loader(name) - self.provider.generation = preset.to_generation_settings() - return ProviderSnapshot( + return preset_helpers.build_runtime_preset_snapshot( + name=name, + presets=self.model_presets, provider=self.provider, - model=preset.model, - context_window_tokens=preset.context_window_tokens, - signature=("model_preset", name, preset.model_dump_json()), + loader=self._preset_snapshot_loader, ) def set_model_preset(self, name: str | None, *, notify: bool = True) -> None: """Resolve a preset by name and apply all runtime model dependents.""" - if not isinstance(name, str) or not name.strip(): - raise ValueError("model_preset must be a non-empty string") - name = name.strip() - if name not in self.model_presets: - raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(self.model_presets) or '(none)'}") + name = preset_helpers.normalize_preset_name(name, self.model_presets) snapshot = self._build_model_preset_snapshot(name) self._apply_provider_snapshot(snapshot, notify=notify, model_preset=name) self._active_preset = name diff --git a/nanobot/agent/model_presets.py b/nanobot/agent/model_presets.py new file mode 100644 index 000000000..a95959857 --- /dev/null +++ b/nanobot/agent/model_presets.py @@ -0,0 +1,78 @@ +"""Helpers for runtime model preset selection.""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +from nanobot.bus.events import OutboundMessage +from nanobot.config.schema import ModelPresetConfig +from nanobot.providers.base import LLMProvider +from nanobot.providers.factory import ProviderSnapshot, build_provider_snapshot + +PresetSnapshotLoader = Callable[[str], ProviderSnapshot] + + +def default_selection_signature(signature: tuple[object, ...] | None) -> tuple[object, ...] | None: + return signature[:2] if signature else None + + +def configured_model_presets(config: Any) -> dict[str, ModelPresetConfig]: + return {**config.model_presets, "default": config.resolve_default_preset()} + + +def make_preset_snapshot_loader( + config: Any, + provider_snapshot_loader: Callable[..., ProviderSnapshot] | None, +) -> PresetSnapshotLoader: + if provider_snapshot_loader is not None: + return lambda name: provider_snapshot_loader(preset_name=name) + return lambda name: build_provider_snapshot(config, preset_name=name) + + +def build_static_preset_snapshot( + provider: LLMProvider, + name: str, + preset: ModelPresetConfig, +) -> ProviderSnapshot: + provider.generation = preset.to_generation_settings() + return ProviderSnapshot( + provider=provider, + model=preset.model, + context_window_tokens=preset.context_window_tokens, + signature=("model_preset", name, preset.model_dump_json()), + ) + + +def build_runtime_preset_snapshot( + *, + name: str, + presets: dict[str, ModelPresetConfig], + provider: LLMProvider, + loader: PresetSnapshotLoader | None, +) -> ProviderSnapshot: + if loader is not None: + return loader(name) + return build_static_preset_snapshot(provider, name, presets[name]) + + +def normalize_preset_name(name: str | None, presets: dict[str, ModelPresetConfig]) -> str: + if not isinstance(name, str) or not name.strip(): + raise ValueError("model_preset must be a non-empty string") + name = name.strip() + if name not in presets: + raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(presets) or '(none)'}") + return name + + +def runtime_model_updated_message(model: str, model_preset: str | None) -> OutboundMessage: + return OutboundMessage( + channel="websocket", + chat_id="*", + content="", + metadata={ + "_runtime_model_updated": True, + "model": model, + "model_preset": model_preset, + }, + )