refactor(agent): move preset helpers out of loop

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Xubin Ren 2026-05-12 11:37:38 +00:00 committed by Xubin Ren
parent e6103d9312
commit 6554c1f832
2 changed files with 100 additions and 43 deletions

View File

@ -19,6 +19,7 @@ from nanobot.agent.autocompact import AutoCompact
from nanobot.agent.context import ContextBuilder
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
from nanobot.agent.memory import Consolidator, Dream
from nanobot.agent import model_presets as preset_helpers
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec
from nanobot.agent.subagent import SubagentManager
from nanobot.agent.tools.ask import (
@ -293,7 +294,7 @@ class AgentLoop:
provider_signature: tuple[object, ...] | None = None,
model_presets: dict[str, ModelPresetConfig] | None = None,
model_preset: str | None = None,
preset_snapshot_loader: Callable[[str], ProviderSnapshot] | None = None,
preset_snapshot_loader: preset_helpers.PresetSnapshotLoader | None = None,
):
from nanobot.config.schema import ToolsConfig
@ -305,7 +306,7 @@ class AgentLoop:
self._provider_snapshot_loader = provider_snapshot_loader
self._preset_snapshot_loader = preset_snapshot_loader
self._provider_signature = provider_signature
self._default_selection_signature = provider_signature[:2] if provider_signature else None
self._default_selection_signature = preset_helpers.default_selection_signature(provider_signature)
self.workspace = workspace
self.model = model or provider.get_default_model()
self.max_iterations = (
@ -423,7 +424,7 @@ class AgentLoop:
allowing callers to override or extend the standard config-derived
parameters (e.g. ``cron_service``, ``session_manager``).
"""
from nanobot.providers.factory import build_provider_snapshot, make_provider
from nanobot.providers.factory import make_provider
if bus is None:
bus = MessageBus()
@ -433,16 +434,10 @@ class AgentLoop:
model = extra.pop("model", None) or resolved.model
context_window_tokens = extra.pop("context_window_tokens", None) or resolved.context_window_tokens
provider_snapshot_loader = extra.pop("provider_snapshot_loader", None)
preset_snapshot_loader = extra.pop("preset_snapshot_loader", None)
model_presets = {**config.model_presets, "default": config.resolve_default_preset()}
if preset_snapshot_loader is None:
if provider_snapshot_loader is not None:
preset_snapshot_loader = lambda name: provider_snapshot_loader(preset_name=name)
else:
preset_snapshot_loader = lambda name: build_provider_snapshot(
config,
preset_name=name,
)
preset_snapshot_loader = extra.pop("preset_snapshot_loader", None) or preset_helpers.make_preset_snapshot_loader(
config,
provider_snapshot_loader,
)
return cls(
bus=bus,
provider=provider,
@ -464,7 +459,7 @@ class AgentLoop:
consolidation_ratio=defaults.consolidation_ratio,
max_messages=defaults.max_messages,
tools_config=config.tools,
model_presets=model_presets,
model_presets=preset_helpers.configured_model_presets(config),
model_preset=defaults.model_preset,
provider_snapshot_loader=provider_snapshot_loader,
preset_snapshot_loader=preset_snapshot_loader,
@ -475,19 +470,6 @@ class AgentLoop:
"""Keep subagent runtime limits aligned with mutable loop settings."""
self.subagents.max_iterations = self.max_iterations
def _publish_runtime_model_updated(self, model_preset: str | None = None) -> None:
"""Notify WebUI clients that the effective runtime model changed."""
self.bus.outbound.put_nowait(OutboundMessage(
channel="websocket",
chat_id="*",
content="",
metadata={
"_runtime_model_updated": True,
"model": self.model,
"model_preset": model_preset if model_preset is not None else self.model_preset,
},
))
def _apply_provider_snapshot(
self,
snapshot: ProviderSnapshot,
@ -509,7 +491,12 @@ class AgentLoop:
self.dream.set_provider(provider, model)
self._provider_signature = snapshot.signature
if notify:
self._publish_runtime_model_updated(model_preset)
self.bus.outbound.put_nowait(
preset_helpers.runtime_model_updated_message(
self.model,
model_preset if model_preset is not None else self.model_preset,
)
)
logger.info("Runtime model switched for next turn: {} -> {}", old_model, model)
def _refresh_provider_snapshot(self) -> None:
@ -520,7 +507,7 @@ class AgentLoop:
except Exception:
logger.exception("Failed to refresh provider config")
return
default_selection = snapshot.signature[:2]
default_selection = preset_helpers.default_selection_signature(snapshot.signature)
if self._active_preset and self._default_selection_signature in (None, default_selection):
self._default_selection_signature = default_selection
try:
@ -533,7 +520,7 @@ class AgentLoop:
self._default_selection_signature = default_selection
if snapshot.signature == self._provider_signature:
return
self._default_selection_signature = snapshot.signature[:2]
self._default_selection_signature = preset_helpers.default_selection_signature(snapshot.signature)
self._apply_provider_snapshot(snapshot)
@property
@ -545,24 +532,16 @@ class AgentLoop:
self.set_model_preset(name)
def _build_model_preset_snapshot(self, name: str) -> ProviderSnapshot:
preset = self.model_presets[name]
if self._preset_snapshot_loader is not None:
return self._preset_snapshot_loader(name)
self.provider.generation = preset.to_generation_settings()
return ProviderSnapshot(
return preset_helpers.build_runtime_preset_snapshot(
name=name,
presets=self.model_presets,
provider=self.provider,
model=preset.model,
context_window_tokens=preset.context_window_tokens,
signature=("model_preset", name, preset.model_dump_json()),
loader=self._preset_snapshot_loader,
)
def set_model_preset(self, name: str | None, *, notify: bool = True) -> None:
"""Resolve a preset by name and apply all runtime model dependents."""
if not isinstance(name, str) or not name.strip():
raise ValueError("model_preset must be a non-empty string")
name = name.strip()
if name not in self.model_presets:
raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(self.model_presets) or '(none)'}")
name = preset_helpers.normalize_preset_name(name, self.model_presets)
snapshot = self._build_model_preset_snapshot(name)
self._apply_provider_snapshot(snapshot, notify=notify, model_preset=name)
self._active_preset = name

View File

@ -0,0 +1,78 @@
"""Helpers for runtime model preset selection."""
from __future__ import annotations
from collections.abc import Callable
from typing import Any
from nanobot.bus.events import OutboundMessage
from nanobot.config.schema import ModelPresetConfig
from nanobot.providers.base import LLMProvider
from nanobot.providers.factory import ProviderSnapshot, build_provider_snapshot
PresetSnapshotLoader = Callable[[str], ProviderSnapshot]
def default_selection_signature(signature: tuple[object, ...] | None) -> tuple[object, ...] | None:
return signature[:2] if signature else None
def configured_model_presets(config: Any) -> dict[str, ModelPresetConfig]:
return {**config.model_presets, "default": config.resolve_default_preset()}
def make_preset_snapshot_loader(
config: Any,
provider_snapshot_loader: Callable[..., ProviderSnapshot] | None,
) -> PresetSnapshotLoader:
if provider_snapshot_loader is not None:
return lambda name: provider_snapshot_loader(preset_name=name)
return lambda name: build_provider_snapshot(config, preset_name=name)
def build_static_preset_snapshot(
provider: LLMProvider,
name: str,
preset: ModelPresetConfig,
) -> ProviderSnapshot:
provider.generation = preset.to_generation_settings()
return ProviderSnapshot(
provider=provider,
model=preset.model,
context_window_tokens=preset.context_window_tokens,
signature=("model_preset", name, preset.model_dump_json()),
)
def build_runtime_preset_snapshot(
*,
name: str,
presets: dict[str, ModelPresetConfig],
provider: LLMProvider,
loader: PresetSnapshotLoader | None,
) -> ProviderSnapshot:
if loader is not None:
return loader(name)
return build_static_preset_snapshot(provider, name, presets[name])
def normalize_preset_name(name: str | None, presets: dict[str, ModelPresetConfig]) -> str:
if not isinstance(name, str) or not name.strip():
raise ValueError("model_preset must be a non-empty string")
name = name.strip()
if name not in presets:
raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(presets) or '(none)'}")
return name
def runtime_model_updated_message(model: str, model_preset: str | None) -> OutboundMessage:
return OutboundMessage(
channel="websocket",
chat_id="*",
content="",
metadata={
"_runtime_model_updated": True,
"model": model,
"model_preset": model_preset,
},
)