mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
refactor(agent): move preset helpers out of loop
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
e6103d9312
commit
6554c1f832
@ -19,6 +19,7 @@ from nanobot.agent.autocompact import AutoCompact
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
||||
from nanobot.agent.memory import Consolidator, Dream
|
||||
from nanobot.agent import model_presets as preset_helpers
|
||||
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.agent.tools.ask import (
|
||||
@ -293,7 +294,7 @@ class AgentLoop:
|
||||
provider_signature: tuple[object, ...] | None = None,
|
||||
model_presets: dict[str, ModelPresetConfig] | None = None,
|
||||
model_preset: str | None = None,
|
||||
preset_snapshot_loader: Callable[[str], ProviderSnapshot] | None = None,
|
||||
preset_snapshot_loader: preset_helpers.PresetSnapshotLoader | None = None,
|
||||
):
|
||||
from nanobot.config.schema import ToolsConfig
|
||||
|
||||
@ -305,7 +306,7 @@ class AgentLoop:
|
||||
self._provider_snapshot_loader = provider_snapshot_loader
|
||||
self._preset_snapshot_loader = preset_snapshot_loader
|
||||
self._provider_signature = provider_signature
|
||||
self._default_selection_signature = provider_signature[:2] if provider_signature else None
|
||||
self._default_selection_signature = preset_helpers.default_selection_signature(provider_signature)
|
||||
self.workspace = workspace
|
||||
self.model = model or provider.get_default_model()
|
||||
self.max_iterations = (
|
||||
@ -423,7 +424,7 @@ class AgentLoop:
|
||||
allowing callers to override or extend the standard config-derived
|
||||
parameters (e.g. ``cron_service``, ``session_manager``).
|
||||
"""
|
||||
from nanobot.providers.factory import build_provider_snapshot, make_provider
|
||||
from nanobot.providers.factory import make_provider
|
||||
|
||||
if bus is None:
|
||||
bus = MessageBus()
|
||||
@ -433,16 +434,10 @@ class AgentLoop:
|
||||
model = extra.pop("model", None) or resolved.model
|
||||
context_window_tokens = extra.pop("context_window_tokens", None) or resolved.context_window_tokens
|
||||
provider_snapshot_loader = extra.pop("provider_snapshot_loader", None)
|
||||
preset_snapshot_loader = extra.pop("preset_snapshot_loader", None)
|
||||
model_presets = {**config.model_presets, "default": config.resolve_default_preset()}
|
||||
if preset_snapshot_loader is None:
|
||||
if provider_snapshot_loader is not None:
|
||||
preset_snapshot_loader = lambda name: provider_snapshot_loader(preset_name=name)
|
||||
else:
|
||||
preset_snapshot_loader = lambda name: build_provider_snapshot(
|
||||
config,
|
||||
preset_name=name,
|
||||
)
|
||||
preset_snapshot_loader = extra.pop("preset_snapshot_loader", None) or preset_helpers.make_preset_snapshot_loader(
|
||||
config,
|
||||
provider_snapshot_loader,
|
||||
)
|
||||
return cls(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
@ -464,7 +459,7 @@ class AgentLoop:
|
||||
consolidation_ratio=defaults.consolidation_ratio,
|
||||
max_messages=defaults.max_messages,
|
||||
tools_config=config.tools,
|
||||
model_presets=model_presets,
|
||||
model_presets=preset_helpers.configured_model_presets(config),
|
||||
model_preset=defaults.model_preset,
|
||||
provider_snapshot_loader=provider_snapshot_loader,
|
||||
preset_snapshot_loader=preset_snapshot_loader,
|
||||
@ -475,19 +470,6 @@ class AgentLoop:
|
||||
"""Keep subagent runtime limits aligned with mutable loop settings."""
|
||||
self.subagents.max_iterations = self.max_iterations
|
||||
|
||||
def _publish_runtime_model_updated(self, model_preset: str | None = None) -> None:
|
||||
"""Notify WebUI clients that the effective runtime model changed."""
|
||||
self.bus.outbound.put_nowait(OutboundMessage(
|
||||
channel="websocket",
|
||||
chat_id="*",
|
||||
content="",
|
||||
metadata={
|
||||
"_runtime_model_updated": True,
|
||||
"model": self.model,
|
||||
"model_preset": model_preset if model_preset is not None else self.model_preset,
|
||||
},
|
||||
))
|
||||
|
||||
def _apply_provider_snapshot(
|
||||
self,
|
||||
snapshot: ProviderSnapshot,
|
||||
@ -509,7 +491,12 @@ class AgentLoop:
|
||||
self.dream.set_provider(provider, model)
|
||||
self._provider_signature = snapshot.signature
|
||||
if notify:
|
||||
self._publish_runtime_model_updated(model_preset)
|
||||
self.bus.outbound.put_nowait(
|
||||
preset_helpers.runtime_model_updated_message(
|
||||
self.model,
|
||||
model_preset if model_preset is not None else self.model_preset,
|
||||
)
|
||||
)
|
||||
logger.info("Runtime model switched for next turn: {} -> {}", old_model, model)
|
||||
|
||||
def _refresh_provider_snapshot(self) -> None:
|
||||
@ -520,7 +507,7 @@ class AgentLoop:
|
||||
except Exception:
|
||||
logger.exception("Failed to refresh provider config")
|
||||
return
|
||||
default_selection = snapshot.signature[:2]
|
||||
default_selection = preset_helpers.default_selection_signature(snapshot.signature)
|
||||
if self._active_preset and self._default_selection_signature in (None, default_selection):
|
||||
self._default_selection_signature = default_selection
|
||||
try:
|
||||
@ -533,7 +520,7 @@ class AgentLoop:
|
||||
self._default_selection_signature = default_selection
|
||||
if snapshot.signature == self._provider_signature:
|
||||
return
|
||||
self._default_selection_signature = snapshot.signature[:2]
|
||||
self._default_selection_signature = preset_helpers.default_selection_signature(snapshot.signature)
|
||||
self._apply_provider_snapshot(snapshot)
|
||||
|
||||
@property
|
||||
@ -545,24 +532,16 @@ class AgentLoop:
|
||||
self.set_model_preset(name)
|
||||
|
||||
def _build_model_preset_snapshot(self, name: str) -> ProviderSnapshot:
|
||||
preset = self.model_presets[name]
|
||||
if self._preset_snapshot_loader is not None:
|
||||
return self._preset_snapshot_loader(name)
|
||||
self.provider.generation = preset.to_generation_settings()
|
||||
return ProviderSnapshot(
|
||||
return preset_helpers.build_runtime_preset_snapshot(
|
||||
name=name,
|
||||
presets=self.model_presets,
|
||||
provider=self.provider,
|
||||
model=preset.model,
|
||||
context_window_tokens=preset.context_window_tokens,
|
||||
signature=("model_preset", name, preset.model_dump_json()),
|
||||
loader=self._preset_snapshot_loader,
|
||||
)
|
||||
|
||||
def set_model_preset(self, name: str | None, *, notify: bool = True) -> None:
|
||||
"""Resolve a preset by name and apply all runtime model dependents."""
|
||||
if not isinstance(name, str) or not name.strip():
|
||||
raise ValueError("model_preset must be a non-empty string")
|
||||
name = name.strip()
|
||||
if name not in self.model_presets:
|
||||
raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(self.model_presets) or '(none)'}")
|
||||
name = preset_helpers.normalize_preset_name(name, self.model_presets)
|
||||
snapshot = self._build_model_preset_snapshot(name)
|
||||
self._apply_provider_snapshot(snapshot, notify=notify, model_preset=name)
|
||||
self._active_preset = name
|
||||
|
||||
78
nanobot/agent/model_presets.py
Normal file
78
nanobot/agent/model_presets.py
Normal file
@ -0,0 +1,78 @@
|
||||
"""Helpers for runtime model preset selection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.config.schema import ModelPresetConfig
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.providers.factory import ProviderSnapshot, build_provider_snapshot
|
||||
|
||||
PresetSnapshotLoader = Callable[[str], ProviderSnapshot]
|
||||
|
||||
|
||||
def default_selection_signature(signature: tuple[object, ...] | None) -> tuple[object, ...] | None:
|
||||
return signature[:2] if signature else None
|
||||
|
||||
|
||||
def configured_model_presets(config: Any) -> dict[str, ModelPresetConfig]:
|
||||
return {**config.model_presets, "default": config.resolve_default_preset()}
|
||||
|
||||
|
||||
def make_preset_snapshot_loader(
|
||||
config: Any,
|
||||
provider_snapshot_loader: Callable[..., ProviderSnapshot] | None,
|
||||
) -> PresetSnapshotLoader:
|
||||
if provider_snapshot_loader is not None:
|
||||
return lambda name: provider_snapshot_loader(preset_name=name)
|
||||
return lambda name: build_provider_snapshot(config, preset_name=name)
|
||||
|
||||
|
||||
def build_static_preset_snapshot(
|
||||
provider: LLMProvider,
|
||||
name: str,
|
||||
preset: ModelPresetConfig,
|
||||
) -> ProviderSnapshot:
|
||||
provider.generation = preset.to_generation_settings()
|
||||
return ProviderSnapshot(
|
||||
provider=provider,
|
||||
model=preset.model,
|
||||
context_window_tokens=preset.context_window_tokens,
|
||||
signature=("model_preset", name, preset.model_dump_json()),
|
||||
)
|
||||
|
||||
|
||||
def build_runtime_preset_snapshot(
|
||||
*,
|
||||
name: str,
|
||||
presets: dict[str, ModelPresetConfig],
|
||||
provider: LLMProvider,
|
||||
loader: PresetSnapshotLoader | None,
|
||||
) -> ProviderSnapshot:
|
||||
if loader is not None:
|
||||
return loader(name)
|
||||
return build_static_preset_snapshot(provider, name, presets[name])
|
||||
|
||||
|
||||
def normalize_preset_name(name: str | None, presets: dict[str, ModelPresetConfig]) -> str:
|
||||
if not isinstance(name, str) or not name.strip():
|
||||
raise ValueError("model_preset must be a non-empty string")
|
||||
name = name.strip()
|
||||
if name not in presets:
|
||||
raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(presets) or '(none)'}")
|
||||
return name
|
||||
|
||||
|
||||
def runtime_model_updated_message(model: str, model_preset: str | None) -> OutboundMessage:
|
||||
return OutboundMessage(
|
||||
channel="websocket",
|
||||
chat_id="*",
|
||||
content="",
|
||||
metadata={
|
||||
"_runtime_model_updated": True,
|
||||
"model": model,
|
||||
"model_preset": model_preset,
|
||||
},
|
||||
)
|
||||
Loading…
x
Reference in New Issue
Block a user