refactor(agent): move preset helpers out of loop

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Xubin Ren 2026-05-12 11:37:38 +00:00 committed by Xubin Ren
parent e6103d9312
commit 6554c1f832
2 changed files with 100 additions and 43 deletions

View File

@ -19,6 +19,7 @@ from nanobot.agent.autocompact import AutoCompact
from nanobot.agent.context import ContextBuilder from nanobot.agent.context import ContextBuilder
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
from nanobot.agent.memory import Consolidator, Dream from nanobot.agent.memory import Consolidator, Dream
from nanobot.agent import model_presets as preset_helpers
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec
from nanobot.agent.subagent import SubagentManager from nanobot.agent.subagent import SubagentManager
from nanobot.agent.tools.ask import ( from nanobot.agent.tools.ask import (
@ -293,7 +294,7 @@ class AgentLoop:
provider_signature: tuple[object, ...] | None = None, provider_signature: tuple[object, ...] | None = None,
model_presets: dict[str, ModelPresetConfig] | None = None, model_presets: dict[str, ModelPresetConfig] | None = None,
model_preset: str | None = None, model_preset: str | None = None,
preset_snapshot_loader: Callable[[str], ProviderSnapshot] | None = None, preset_snapshot_loader: preset_helpers.PresetSnapshotLoader | None = None,
): ):
from nanobot.config.schema import ToolsConfig from nanobot.config.schema import ToolsConfig
@ -305,7 +306,7 @@ class AgentLoop:
self._provider_snapshot_loader = provider_snapshot_loader self._provider_snapshot_loader = provider_snapshot_loader
self._preset_snapshot_loader = preset_snapshot_loader self._preset_snapshot_loader = preset_snapshot_loader
self._provider_signature = provider_signature self._provider_signature = provider_signature
self._default_selection_signature = provider_signature[:2] if provider_signature else None self._default_selection_signature = preset_helpers.default_selection_signature(provider_signature)
self.workspace = workspace self.workspace = workspace
self.model = model or provider.get_default_model() self.model = model or provider.get_default_model()
self.max_iterations = ( self.max_iterations = (
@ -423,7 +424,7 @@ class AgentLoop:
allowing callers to override or extend the standard config-derived allowing callers to override or extend the standard config-derived
parameters (e.g. ``cron_service``, ``session_manager``). parameters (e.g. ``cron_service``, ``session_manager``).
""" """
from nanobot.providers.factory import build_provider_snapshot, make_provider from nanobot.providers.factory import make_provider
if bus is None: if bus is None:
bus = MessageBus() bus = MessageBus()
@ -433,15 +434,9 @@ class AgentLoop:
model = extra.pop("model", None) or resolved.model model = extra.pop("model", None) or resolved.model
context_window_tokens = extra.pop("context_window_tokens", None) or resolved.context_window_tokens context_window_tokens = extra.pop("context_window_tokens", None) or resolved.context_window_tokens
provider_snapshot_loader = extra.pop("provider_snapshot_loader", None) provider_snapshot_loader = extra.pop("provider_snapshot_loader", None)
preset_snapshot_loader = extra.pop("preset_snapshot_loader", None) preset_snapshot_loader = extra.pop("preset_snapshot_loader", None) or preset_helpers.make_preset_snapshot_loader(
model_presets = {**config.model_presets, "default": config.resolve_default_preset()}
if preset_snapshot_loader is None:
if provider_snapshot_loader is not None:
preset_snapshot_loader = lambda name: provider_snapshot_loader(preset_name=name)
else:
preset_snapshot_loader = lambda name: build_provider_snapshot(
config, config,
preset_name=name, provider_snapshot_loader,
) )
return cls( return cls(
bus=bus, bus=bus,
@ -464,7 +459,7 @@ class AgentLoop:
consolidation_ratio=defaults.consolidation_ratio, consolidation_ratio=defaults.consolidation_ratio,
max_messages=defaults.max_messages, max_messages=defaults.max_messages,
tools_config=config.tools, tools_config=config.tools,
model_presets=model_presets, model_presets=preset_helpers.configured_model_presets(config),
model_preset=defaults.model_preset, model_preset=defaults.model_preset,
provider_snapshot_loader=provider_snapshot_loader, provider_snapshot_loader=provider_snapshot_loader,
preset_snapshot_loader=preset_snapshot_loader, preset_snapshot_loader=preset_snapshot_loader,
@ -475,19 +470,6 @@ class AgentLoop:
"""Keep subagent runtime limits aligned with mutable loop settings.""" """Keep subagent runtime limits aligned with mutable loop settings."""
self.subagents.max_iterations = self.max_iterations self.subagents.max_iterations = self.max_iterations
def _publish_runtime_model_updated(self, model_preset: str | None = None) -> None:
"""Notify WebUI clients that the effective runtime model changed."""
self.bus.outbound.put_nowait(OutboundMessage(
channel="websocket",
chat_id="*",
content="",
metadata={
"_runtime_model_updated": True,
"model": self.model,
"model_preset": model_preset if model_preset is not None else self.model_preset,
},
))
def _apply_provider_snapshot( def _apply_provider_snapshot(
self, self,
snapshot: ProviderSnapshot, snapshot: ProviderSnapshot,
@ -509,7 +491,12 @@ class AgentLoop:
self.dream.set_provider(provider, model) self.dream.set_provider(provider, model)
self._provider_signature = snapshot.signature self._provider_signature = snapshot.signature
if notify: if notify:
self._publish_runtime_model_updated(model_preset) self.bus.outbound.put_nowait(
preset_helpers.runtime_model_updated_message(
self.model,
model_preset if model_preset is not None else self.model_preset,
)
)
logger.info("Runtime model switched for next turn: {} -> {}", old_model, model) logger.info("Runtime model switched for next turn: {} -> {}", old_model, model)
def _refresh_provider_snapshot(self) -> None: def _refresh_provider_snapshot(self) -> None:
@ -520,7 +507,7 @@ class AgentLoop:
except Exception: except Exception:
logger.exception("Failed to refresh provider config") logger.exception("Failed to refresh provider config")
return return
default_selection = snapshot.signature[:2] default_selection = preset_helpers.default_selection_signature(snapshot.signature)
if self._active_preset and self._default_selection_signature in (None, default_selection): if self._active_preset and self._default_selection_signature in (None, default_selection):
self._default_selection_signature = default_selection self._default_selection_signature = default_selection
try: try:
@ -533,7 +520,7 @@ class AgentLoop:
self._default_selection_signature = default_selection self._default_selection_signature = default_selection
if snapshot.signature == self._provider_signature: if snapshot.signature == self._provider_signature:
return return
self._default_selection_signature = snapshot.signature[:2] self._default_selection_signature = preset_helpers.default_selection_signature(snapshot.signature)
self._apply_provider_snapshot(snapshot) self._apply_provider_snapshot(snapshot)
@property @property
@ -545,24 +532,16 @@ class AgentLoop:
self.set_model_preset(name) self.set_model_preset(name)
def _build_model_preset_snapshot(self, name: str) -> ProviderSnapshot: def _build_model_preset_snapshot(self, name: str) -> ProviderSnapshot:
preset = self.model_presets[name] return preset_helpers.build_runtime_preset_snapshot(
if self._preset_snapshot_loader is not None: name=name,
return self._preset_snapshot_loader(name) presets=self.model_presets,
self.provider.generation = preset.to_generation_settings()
return ProviderSnapshot(
provider=self.provider, provider=self.provider,
model=preset.model, loader=self._preset_snapshot_loader,
context_window_tokens=preset.context_window_tokens,
signature=("model_preset", name, preset.model_dump_json()),
) )
def set_model_preset(self, name: str | None, *, notify: bool = True) -> None: def set_model_preset(self, name: str | None, *, notify: bool = True) -> None:
"""Resolve a preset by name and apply all runtime model dependents.""" """Resolve a preset by name and apply all runtime model dependents."""
if not isinstance(name, str) or not name.strip(): name = preset_helpers.normalize_preset_name(name, self.model_presets)
raise ValueError("model_preset must be a non-empty string")
name = name.strip()
if name not in self.model_presets:
raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(self.model_presets) or '(none)'}")
snapshot = self._build_model_preset_snapshot(name) snapshot = self._build_model_preset_snapshot(name)
self._apply_provider_snapshot(snapshot, notify=notify, model_preset=name) self._apply_provider_snapshot(snapshot, notify=notify, model_preset=name)
self._active_preset = name self._active_preset = name

View File

@ -0,0 +1,78 @@
"""Helpers for runtime model preset selection."""
from __future__ import annotations
from collections.abc import Callable
from typing import Any
from nanobot.bus.events import OutboundMessage
from nanobot.config.schema import ModelPresetConfig
from nanobot.providers.base import LLMProvider
from nanobot.providers.factory import ProviderSnapshot, build_provider_snapshot
PresetSnapshotLoader = Callable[[str], ProviderSnapshot]
def default_selection_signature(signature: tuple[object, ...] | None) -> tuple[object, ...] | None:
return signature[:2] if signature else None
def configured_model_presets(config: Any) -> dict[str, ModelPresetConfig]:
return {**config.model_presets, "default": config.resolve_default_preset()}
def make_preset_snapshot_loader(
config: Any,
provider_snapshot_loader: Callable[..., ProviderSnapshot] | None,
) -> PresetSnapshotLoader:
if provider_snapshot_loader is not None:
return lambda name: provider_snapshot_loader(preset_name=name)
return lambda name: build_provider_snapshot(config, preset_name=name)
def build_static_preset_snapshot(
provider: LLMProvider,
name: str,
preset: ModelPresetConfig,
) -> ProviderSnapshot:
provider.generation = preset.to_generation_settings()
return ProviderSnapshot(
provider=provider,
model=preset.model,
context_window_tokens=preset.context_window_tokens,
signature=("model_preset", name, preset.model_dump_json()),
)
def build_runtime_preset_snapshot(
*,
name: str,
presets: dict[str, ModelPresetConfig],
provider: LLMProvider,
loader: PresetSnapshotLoader | None,
) -> ProviderSnapshot:
if loader is not None:
return loader(name)
return build_static_preset_snapshot(provider, name, presets[name])
def normalize_preset_name(name: str | None, presets: dict[str, ModelPresetConfig]) -> str:
if not isinstance(name, str) or not name.strip():
raise ValueError("model_preset must be a non-empty string")
name = name.strip()
if name not in presets:
raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(presets) or '(none)'}")
return name
def runtime_model_updated_message(model: str, model_preset: str | None) -> OutboundMessage:
return OutboundMessage(
channel="websocket",
chat_id="*",
content="",
metadata={
"_runtime_model_updated": True,
"model": model,
"model_preset": model_preset,
},
)