mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 08:32:25 +00:00
refactor(agent): move preset helpers out of loop
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
e6103d9312
commit
6554c1f832
@ -19,6 +19,7 @@ from nanobot.agent.autocompact import AutoCompact
|
|||||||
from nanobot.agent.context import ContextBuilder
|
from nanobot.agent.context import ContextBuilder
|
||||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
||||||
from nanobot.agent.memory import Consolidator, Dream
|
from nanobot.agent.memory import Consolidator, Dream
|
||||||
|
from nanobot.agent import model_presets as preset_helpers
|
||||||
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec
|
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec
|
||||||
from nanobot.agent.subagent import SubagentManager
|
from nanobot.agent.subagent import SubagentManager
|
||||||
from nanobot.agent.tools.ask import (
|
from nanobot.agent.tools.ask import (
|
||||||
@ -293,7 +294,7 @@ class AgentLoop:
|
|||||||
provider_signature: tuple[object, ...] | None = None,
|
provider_signature: tuple[object, ...] | None = None,
|
||||||
model_presets: dict[str, ModelPresetConfig] | None = None,
|
model_presets: dict[str, ModelPresetConfig] | None = None,
|
||||||
model_preset: str | None = None,
|
model_preset: str | None = None,
|
||||||
preset_snapshot_loader: Callable[[str], ProviderSnapshot] | None = None,
|
preset_snapshot_loader: preset_helpers.PresetSnapshotLoader | None = None,
|
||||||
):
|
):
|
||||||
from nanobot.config.schema import ToolsConfig
|
from nanobot.config.schema import ToolsConfig
|
||||||
|
|
||||||
@ -305,7 +306,7 @@ class AgentLoop:
|
|||||||
self._provider_snapshot_loader = provider_snapshot_loader
|
self._provider_snapshot_loader = provider_snapshot_loader
|
||||||
self._preset_snapshot_loader = preset_snapshot_loader
|
self._preset_snapshot_loader = preset_snapshot_loader
|
||||||
self._provider_signature = provider_signature
|
self._provider_signature = provider_signature
|
||||||
self._default_selection_signature = provider_signature[:2] if provider_signature else None
|
self._default_selection_signature = preset_helpers.default_selection_signature(provider_signature)
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.model = model or provider.get_default_model()
|
self.model = model or provider.get_default_model()
|
||||||
self.max_iterations = (
|
self.max_iterations = (
|
||||||
@ -423,7 +424,7 @@ class AgentLoop:
|
|||||||
allowing callers to override or extend the standard config-derived
|
allowing callers to override or extend the standard config-derived
|
||||||
parameters (e.g. ``cron_service``, ``session_manager``).
|
parameters (e.g. ``cron_service``, ``session_manager``).
|
||||||
"""
|
"""
|
||||||
from nanobot.providers.factory import build_provider_snapshot, make_provider
|
from nanobot.providers.factory import make_provider
|
||||||
|
|
||||||
if bus is None:
|
if bus is None:
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
@ -433,15 +434,9 @@ class AgentLoop:
|
|||||||
model = extra.pop("model", None) or resolved.model
|
model = extra.pop("model", None) or resolved.model
|
||||||
context_window_tokens = extra.pop("context_window_tokens", None) or resolved.context_window_tokens
|
context_window_tokens = extra.pop("context_window_tokens", None) or resolved.context_window_tokens
|
||||||
provider_snapshot_loader = extra.pop("provider_snapshot_loader", None)
|
provider_snapshot_loader = extra.pop("provider_snapshot_loader", None)
|
||||||
preset_snapshot_loader = extra.pop("preset_snapshot_loader", None)
|
preset_snapshot_loader = extra.pop("preset_snapshot_loader", None) or preset_helpers.make_preset_snapshot_loader(
|
||||||
model_presets = {**config.model_presets, "default": config.resolve_default_preset()}
|
|
||||||
if preset_snapshot_loader is None:
|
|
||||||
if provider_snapshot_loader is not None:
|
|
||||||
preset_snapshot_loader = lambda name: provider_snapshot_loader(preset_name=name)
|
|
||||||
else:
|
|
||||||
preset_snapshot_loader = lambda name: build_provider_snapshot(
|
|
||||||
config,
|
config,
|
||||||
preset_name=name,
|
provider_snapshot_loader,
|
||||||
)
|
)
|
||||||
return cls(
|
return cls(
|
||||||
bus=bus,
|
bus=bus,
|
||||||
@ -464,7 +459,7 @@ class AgentLoop:
|
|||||||
consolidation_ratio=defaults.consolidation_ratio,
|
consolidation_ratio=defaults.consolidation_ratio,
|
||||||
max_messages=defaults.max_messages,
|
max_messages=defaults.max_messages,
|
||||||
tools_config=config.tools,
|
tools_config=config.tools,
|
||||||
model_presets=model_presets,
|
model_presets=preset_helpers.configured_model_presets(config),
|
||||||
model_preset=defaults.model_preset,
|
model_preset=defaults.model_preset,
|
||||||
provider_snapshot_loader=provider_snapshot_loader,
|
provider_snapshot_loader=provider_snapshot_loader,
|
||||||
preset_snapshot_loader=preset_snapshot_loader,
|
preset_snapshot_loader=preset_snapshot_loader,
|
||||||
@ -475,19 +470,6 @@ class AgentLoop:
|
|||||||
"""Keep subagent runtime limits aligned with mutable loop settings."""
|
"""Keep subagent runtime limits aligned with mutable loop settings."""
|
||||||
self.subagents.max_iterations = self.max_iterations
|
self.subagents.max_iterations = self.max_iterations
|
||||||
|
|
||||||
def _publish_runtime_model_updated(self, model_preset: str | None = None) -> None:
|
|
||||||
"""Notify WebUI clients that the effective runtime model changed."""
|
|
||||||
self.bus.outbound.put_nowait(OutboundMessage(
|
|
||||||
channel="websocket",
|
|
||||||
chat_id="*",
|
|
||||||
content="",
|
|
||||||
metadata={
|
|
||||||
"_runtime_model_updated": True,
|
|
||||||
"model": self.model,
|
|
||||||
"model_preset": model_preset if model_preset is not None else self.model_preset,
|
|
||||||
},
|
|
||||||
))
|
|
||||||
|
|
||||||
def _apply_provider_snapshot(
|
def _apply_provider_snapshot(
|
||||||
self,
|
self,
|
||||||
snapshot: ProviderSnapshot,
|
snapshot: ProviderSnapshot,
|
||||||
@ -509,7 +491,12 @@ class AgentLoop:
|
|||||||
self.dream.set_provider(provider, model)
|
self.dream.set_provider(provider, model)
|
||||||
self._provider_signature = snapshot.signature
|
self._provider_signature = snapshot.signature
|
||||||
if notify:
|
if notify:
|
||||||
self._publish_runtime_model_updated(model_preset)
|
self.bus.outbound.put_nowait(
|
||||||
|
preset_helpers.runtime_model_updated_message(
|
||||||
|
self.model,
|
||||||
|
model_preset if model_preset is not None else self.model_preset,
|
||||||
|
)
|
||||||
|
)
|
||||||
logger.info("Runtime model switched for next turn: {} -> {}", old_model, model)
|
logger.info("Runtime model switched for next turn: {} -> {}", old_model, model)
|
||||||
|
|
||||||
def _refresh_provider_snapshot(self) -> None:
|
def _refresh_provider_snapshot(self) -> None:
|
||||||
@ -520,7 +507,7 @@ class AgentLoop:
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to refresh provider config")
|
logger.exception("Failed to refresh provider config")
|
||||||
return
|
return
|
||||||
default_selection = snapshot.signature[:2]
|
default_selection = preset_helpers.default_selection_signature(snapshot.signature)
|
||||||
if self._active_preset and self._default_selection_signature in (None, default_selection):
|
if self._active_preset and self._default_selection_signature in (None, default_selection):
|
||||||
self._default_selection_signature = default_selection
|
self._default_selection_signature = default_selection
|
||||||
try:
|
try:
|
||||||
@ -533,7 +520,7 @@ class AgentLoop:
|
|||||||
self._default_selection_signature = default_selection
|
self._default_selection_signature = default_selection
|
||||||
if snapshot.signature == self._provider_signature:
|
if snapshot.signature == self._provider_signature:
|
||||||
return
|
return
|
||||||
self._default_selection_signature = snapshot.signature[:2]
|
self._default_selection_signature = preset_helpers.default_selection_signature(snapshot.signature)
|
||||||
self._apply_provider_snapshot(snapshot)
|
self._apply_provider_snapshot(snapshot)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -545,24 +532,16 @@ class AgentLoop:
|
|||||||
self.set_model_preset(name)
|
self.set_model_preset(name)
|
||||||
|
|
||||||
def _build_model_preset_snapshot(self, name: str) -> ProviderSnapshot:
|
def _build_model_preset_snapshot(self, name: str) -> ProviderSnapshot:
|
||||||
preset = self.model_presets[name]
|
return preset_helpers.build_runtime_preset_snapshot(
|
||||||
if self._preset_snapshot_loader is not None:
|
name=name,
|
||||||
return self._preset_snapshot_loader(name)
|
presets=self.model_presets,
|
||||||
self.provider.generation = preset.to_generation_settings()
|
|
||||||
return ProviderSnapshot(
|
|
||||||
provider=self.provider,
|
provider=self.provider,
|
||||||
model=preset.model,
|
loader=self._preset_snapshot_loader,
|
||||||
context_window_tokens=preset.context_window_tokens,
|
|
||||||
signature=("model_preset", name, preset.model_dump_json()),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_model_preset(self, name: str | None, *, notify: bool = True) -> None:
|
def set_model_preset(self, name: str | None, *, notify: bool = True) -> None:
|
||||||
"""Resolve a preset by name and apply all runtime model dependents."""
|
"""Resolve a preset by name and apply all runtime model dependents."""
|
||||||
if not isinstance(name, str) or not name.strip():
|
name = preset_helpers.normalize_preset_name(name, self.model_presets)
|
||||||
raise ValueError("model_preset must be a non-empty string")
|
|
||||||
name = name.strip()
|
|
||||||
if name not in self.model_presets:
|
|
||||||
raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(self.model_presets) or '(none)'}")
|
|
||||||
snapshot = self._build_model_preset_snapshot(name)
|
snapshot = self._build_model_preset_snapshot(name)
|
||||||
self._apply_provider_snapshot(snapshot, notify=notify, model_preset=name)
|
self._apply_provider_snapshot(snapshot, notify=notify, model_preset=name)
|
||||||
self._active_preset = name
|
self._active_preset = name
|
||||||
|
|||||||
78
nanobot/agent/model_presets.py
Normal file
78
nanobot/agent/model_presets.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
"""Helpers for runtime model preset selection."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from nanobot.bus.events import OutboundMessage
|
||||||
|
from nanobot.config.schema import ModelPresetConfig
|
||||||
|
from nanobot.providers.base import LLMProvider
|
||||||
|
from nanobot.providers.factory import ProviderSnapshot, build_provider_snapshot
|
||||||
|
|
||||||
|
PresetSnapshotLoader = Callable[[str], ProviderSnapshot]
|
||||||
|
|
||||||
|
|
||||||
|
def default_selection_signature(signature: tuple[object, ...] | None) -> tuple[object, ...] | None:
|
||||||
|
return signature[:2] if signature else None
|
||||||
|
|
||||||
|
|
||||||
|
def configured_model_presets(config: Any) -> dict[str, ModelPresetConfig]:
|
||||||
|
return {**config.model_presets, "default": config.resolve_default_preset()}
|
||||||
|
|
||||||
|
|
||||||
|
def make_preset_snapshot_loader(
|
||||||
|
config: Any,
|
||||||
|
provider_snapshot_loader: Callable[..., ProviderSnapshot] | None,
|
||||||
|
) -> PresetSnapshotLoader:
|
||||||
|
if provider_snapshot_loader is not None:
|
||||||
|
return lambda name: provider_snapshot_loader(preset_name=name)
|
||||||
|
return lambda name: build_provider_snapshot(config, preset_name=name)
|
||||||
|
|
||||||
|
|
||||||
|
def build_static_preset_snapshot(
|
||||||
|
provider: LLMProvider,
|
||||||
|
name: str,
|
||||||
|
preset: ModelPresetConfig,
|
||||||
|
) -> ProviderSnapshot:
|
||||||
|
provider.generation = preset.to_generation_settings()
|
||||||
|
return ProviderSnapshot(
|
||||||
|
provider=provider,
|
||||||
|
model=preset.model,
|
||||||
|
context_window_tokens=preset.context_window_tokens,
|
||||||
|
signature=("model_preset", name, preset.model_dump_json()),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_runtime_preset_snapshot(
|
||||||
|
*,
|
||||||
|
name: str,
|
||||||
|
presets: dict[str, ModelPresetConfig],
|
||||||
|
provider: LLMProvider,
|
||||||
|
loader: PresetSnapshotLoader | None,
|
||||||
|
) -> ProviderSnapshot:
|
||||||
|
if loader is not None:
|
||||||
|
return loader(name)
|
||||||
|
return build_static_preset_snapshot(provider, name, presets[name])
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_preset_name(name: str | None, presets: dict[str, ModelPresetConfig]) -> str:
|
||||||
|
if not isinstance(name, str) or not name.strip():
|
||||||
|
raise ValueError("model_preset must be a non-empty string")
|
||||||
|
name = name.strip()
|
||||||
|
if name not in presets:
|
||||||
|
raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(presets) or '(none)'}")
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def runtime_model_updated_message(model: str, model_preset: str | None) -> OutboundMessage:
|
||||||
|
return OutboundMessage(
|
||||||
|
channel="websocket",
|
||||||
|
chat_id="*",
|
||||||
|
content="",
|
||||||
|
metadata={
|
||||||
|
"_runtime_model_updated": True,
|
||||||
|
"model": model,
|
||||||
|
"model_preset": model_preset,
|
||||||
|
},
|
||||||
|
)
|
||||||
Loading…
x
Reference in New Issue
Block a user