refactor(agent): trim model preset runtime wiring

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Xubin Ren 2026-05-12 11:20:08 +00:00 committed by Xubin Ren
parent 70b8daaee6
commit 8fcb24bb7c
3 changed files with 27 additions and 62 deletions

View File

@ -289,11 +289,10 @@ class AgentLoop:
tools_config: ToolsConfig | None = None, tools_config: ToolsConfig | None = None,
image_generation_provider_config: ProviderConfig | None = None, image_generation_provider_config: ProviderConfig | None = None,
image_generation_provider_configs: dict[str, ProviderConfig] | None = None, image_generation_provider_configs: dict[str, ProviderConfig] | None = None,
provider_snapshot_loader: Callable[[], ProviderSnapshot] | None = None, provider_snapshot_loader: Callable[..., ProviderSnapshot] | None = None,
provider_signature: tuple[object, ...] | None = None, provider_signature: tuple[object, ...] | None = None,
model_presets: dict[str, ModelPresetConfig] | None = None, model_presets: dict[str, ModelPresetConfig] | None = None,
model_preset: str | None = None, model_preset: str | None = None,
model_preset_snapshot_builder: Callable[[str], ProviderSnapshot] | None = None,
): ):
from nanobot.config.schema import ToolsConfig from nanobot.config.schema import ToolsConfig
@ -304,10 +303,7 @@ class AgentLoop:
self.provider = provider self.provider = provider
self._provider_snapshot_loader = provider_snapshot_loader self._provider_snapshot_loader = provider_snapshot_loader
self._provider_signature = provider_signature self._provider_signature = provider_signature
self._config_provider_signature = provider_signature self._default_selection_signature = provider_signature[:2] if provider_signature else None
self._config_default_selection_signature = (
provider_signature[:2] if provider_signature is not None else None
)
self.workspace = workspace self.workspace = workspace
self.model = model or provider.get_default_model() self.model = model or provider.get_default_model()
self.max_iterations = ( self.max_iterations = (
@ -403,7 +399,6 @@ class AgentLoop:
model=self.model, model=self.model,
) )
self.model_presets: dict[str, ModelPresetConfig] = model_presets or {} self.model_presets: dict[str, ModelPresetConfig] = model_presets or {}
self._model_preset_snapshot_builder = model_preset_snapshot_builder
self._active_preset: str | None = None self._active_preset: str | None = None
if model_preset: if model_preset:
self.set_model_preset(model_preset, notify=False) self.set_model_preset(model_preset, notify=False)
@ -435,9 +430,8 @@ class AgentLoop:
resolved = config.resolve_preset() resolved = config.resolve_preset()
model = extra.pop("model", None) or resolved.model model = extra.pop("model", None) or resolved.model
context_window_tokens = extra.pop("context_window_tokens", None) or resolved.context_window_tokens context_window_tokens = extra.pop("context_window_tokens", None) or resolved.context_window_tokens
model_preset_snapshot_builder = extra.pop("model_preset_snapshot_builder", None) provider_snapshot_loader = extra.pop("provider_snapshot_loader", None)
model_presets = dict(config.model_presets) model_presets = {**config.model_presets, "default": config.resolve_default_preset()}
model_presets["default"] = config.resolve_default_preset()
return cls( return cls(
bus=bus, bus=bus,
provider=provider, provider=provider,
@ -461,9 +455,8 @@ class AgentLoop:
tools_config=config.tools, tools_config=config.tools,
model_presets=model_presets, model_presets=model_presets,
model_preset=defaults.model_preset, model_preset=defaults.model_preset,
model_preset_snapshot_builder=( provider_snapshot_loader=provider_snapshot_loader or (
model_preset_snapshot_builder lambda preset_name=None: build_provider_snapshot(config, preset_name=preset_name)
or (lambda name: build_provider_snapshot(config, preset_name=name))
), ),
**extra, **extra,
) )
@ -475,14 +468,8 @@ class AgentLoop:
def _publish_runtime_model_updated(self, model_preset: str | None = None) -> None: def _publish_runtime_model_updated(self, model_preset: str | None = None) -> None:
"""Notify WebUI clients that the effective runtime model changed.""" """Notify WebUI clients that the effective runtime model changed."""
self.bus.outbound.put_nowait(OutboundMessage( self.bus.outbound.put_nowait(OutboundMessage(
channel="websocket", channel="websocket", chat_id="*", content="",
chat_id="*", metadata={"_runtime_model_updated": True, "model": self.model, "model_preset": model_preset if model_preset is not None else self.model_preset},
content="",
metadata={
"_runtime_model_updated": True,
"model": self.model,
"model_preset": model_preset if model_preset is not None else self.model_preset,
},
)) ))
def _apply_provider_snapshot( def _apply_provider_snapshot(
@ -517,36 +504,22 @@ class AgentLoop:
except Exception: except Exception:
logger.exception("Failed to refresh provider config") logger.exception("Failed to refresh provider config")
return return
if self._active_preset: default_selection = snapshot.signature[:2]
default_selection = snapshot.signature[:2] if self._active_preset and self._default_selection_signature in (None, default_selection):
if ( self._default_selection_signature = default_selection
self._config_default_selection_signature is not None
and default_selection != self._config_default_selection_signature
):
self._active_preset = None
self._config_provider_signature = snapshot.signature
self._config_default_selection_signature = default_selection
self._apply_provider_snapshot(snapshot)
return
self._config_provider_signature = snapshot.signature
self._config_default_selection_signature = default_selection
try: try:
snapshot = self._build_model_preset_snapshot(self._active_preset) snapshot = self._build_model_preset_snapshot(self._active_preset)
except Exception: except Exception:
logger.exception("Failed to refresh active model preset") logger.exception("Failed to refresh active model preset")
return return
if snapshot.signature == self._provider_signature: else:
return self._active_preset = None
self._apply_provider_snapshot(snapshot) self._default_selection_signature = default_selection
return
if snapshot.signature == self._provider_signature: if snapshot.signature == self._provider_signature:
return return
self._config_provider_signature = snapshot.signature self._default_selection_signature = snapshot.signature[:2]
self._config_default_selection_signature = snapshot.signature[:2]
self._apply_provider_snapshot(snapshot) self._apply_provider_snapshot(snapshot)
# -- model_preset property --
@property @property
def model_preset(self) -> str | None: def model_preset(self) -> str | None:
return self._active_preset return self._active_preset
@ -557,23 +530,14 @@ class AgentLoop:
def _build_model_preset_snapshot(self, name: str) -> ProviderSnapshot: def _build_model_preset_snapshot(self, name: str) -> ProviderSnapshot:
preset = self.model_presets[name] preset = self.model_presets[name]
if self._model_preset_snapshot_builder is not None: if self._provider_snapshot_loader is not None:
return self._model_preset_snapshot_builder(name) return self._provider_snapshot_loader(preset_name=name)
self.provider.generation = preset.to_generation_settings() self.provider.generation = preset.to_generation_settings()
return ProviderSnapshot( return ProviderSnapshot(
provider=self.provider, provider=self.provider,
model=preset.model, model=preset.model,
context_window_tokens=preset.context_window_tokens, context_window_tokens=preset.context_window_tokens,
signature=( signature=("model_preset", name, preset.model_dump_json()),
"model_preset",
name,
preset.model,
preset.provider,
preset.max_tokens,
preset.context_window_tokens,
preset.temperature,
preset.reasoning_effort,
),
) )
def set_model_preset(self, name: str | None, *, notify: bool = True) -> None: def set_model_preset(self, name: str | None, *, notify: bool = True) -> None:

View File

@ -672,7 +672,6 @@ def _run_gateway(
"aihubmix": config.providers.aihubmix, "aihubmix": config.providers.aihubmix,
}, },
provider_snapshot_loader=load_provider_snapshot, provider_snapshot_loader=load_provider_snapshot,
model_preset_snapshot_builder=lambda name: load_provider_snapshot(preset_name=name),
provider_signature=provider_snapshot.signature, provider_signature=provider_snapshot.signature,
) )

View File

@ -104,11 +104,11 @@ def test_model_preset_setter_replaces_provider_from_snapshot(tmp_path) -> None:
model="base-model", model="base-model",
context_window_tokens=1000, context_window_tokens=1000,
model_presets={"deep": preset}, model_presets={"deep": preset},
model_preset_snapshot_builder=lambda _name: ProviderSnapshot( provider_snapshot_loader=lambda preset_name=None: ProviderSnapshot(
provider=new_provider, provider=new_provider,
model=preset.model, model=preset.model,
context_window_tokens=preset.context_window_tokens, context_window_tokens=preset.context_window_tokens,
signature=("deep", preset.model), signature=(preset_name, preset.model),
), ),
) )
@ -135,7 +135,7 @@ def test_model_preset_setter_failure_leaves_old_state(tmp_path) -> None:
model="base-model", model="base-model",
context_window_tokens=1000, context_window_tokens=1000,
model_presets={"fast": preset}, model_presets={"fast": preset},
model_preset_snapshot_builder=lambda _name: (_ for _ in ()).throw( provider_snapshot_loader=lambda preset_name=None: (_ for _ in ()).throw(
RuntimeError("provider unavailable") RuntimeError("provider unavailable")
), ),
) )
@ -173,10 +173,11 @@ def test_active_model_preset_survives_unchanged_config_refresh(tmp_path) -> None
workspace=tmp_path, workspace=tmp_path,
model="base-model", model="base-model",
context_window_tokens=1000, context_window_tokens=1000,
provider_snapshot_loader=lambda: default_snapshot,
provider_signature=default_snapshot.signature, provider_signature=default_snapshot.signature,
model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")}, model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")},
model_preset_snapshot_builder=lambda _name: fast_snapshot, provider_snapshot_loader=lambda preset_name=None: (
fast_snapshot if preset_name == "fast" else default_snapshot
),
) )
loop.set_model_preset("fast") loop.set_model_preset("fast")
@ -209,10 +210,11 @@ def test_config_model_refresh_clears_active_model_preset(tmp_path) -> None:
workspace=tmp_path, workspace=tmp_path,
model="base-model", model="base-model",
context_window_tokens=1000, context_window_tokens=1000,
provider_snapshot_loader=lambda: webui_snapshot, provider_snapshot_loader=lambda preset_name=None: (
fast_snapshot if preset_name == "fast" else webui_snapshot
),
provider_signature=("base-model", "auto", "openai", "sk-old"), provider_signature=("base-model", "auto", "openai", "sk-old"),
model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")}, model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")},
model_preset_snapshot_builder=lambda _name: fast_snapshot,
) )
loop.set_model_preset("fast") loop.set_model_preset("fast")