mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 16:42:25 +00:00
fix(config): make model preset switching atomic
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
6f78267c82
commit
c450d6fd3f
@ -293,6 +293,7 @@ class AgentLoop:
|
|||||||
provider_signature: tuple[object, ...] | None = None,
|
provider_signature: tuple[object, ...] | None = None,
|
||||||
model_presets: dict[str, ModelPresetConfig] | None = None,
|
model_presets: dict[str, ModelPresetConfig] | None = None,
|
||||||
model_preset: str | None = None,
|
model_preset: str | None = None,
|
||||||
|
model_preset_snapshot_builder: Callable[[ModelPresetConfig], ProviderSnapshot] | None = None,
|
||||||
):
|
):
|
||||||
from nanobot.config.schema import ToolsConfig
|
from nanobot.config.schema import ToolsConfig
|
||||||
|
|
||||||
@ -398,7 +399,10 @@ class AgentLoop:
|
|||||||
model=self.model,
|
model=self.model,
|
||||||
)
|
)
|
||||||
self.model_presets: dict[str, ModelPresetConfig] = model_presets or {}
|
self.model_presets: dict[str, ModelPresetConfig] = model_presets or {}
|
||||||
self._active_preset: str | None = model_preset if model_presets and model_preset in model_presets else None
|
self._model_preset_snapshot_builder = model_preset_snapshot_builder
|
||||||
|
self._active_preset: str | None = None
|
||||||
|
if model_preset:
|
||||||
|
self.set_model_preset(model_preset)
|
||||||
self._register_default_tools()
|
self._register_default_tools()
|
||||||
self._runtime_vars: dict[str, Any] = {}
|
self._runtime_vars: dict[str, Any] = {}
|
||||||
self._current_iteration: int = 0
|
self._current_iteration: int = 0
|
||||||
@ -418,7 +422,7 @@ class AgentLoop:
|
|||||||
allowing callers to override or extend the standard config-derived
|
allowing callers to override or extend the standard config-derived
|
||||||
parameters (e.g. ``cron_service``, ``session_manager``).
|
parameters (e.g. ``cron_service``, ``session_manager``).
|
||||||
"""
|
"""
|
||||||
from nanobot.providers.factory import make_provider
|
from nanobot.providers.factory import build_provider_snapshot, make_provider
|
||||||
|
|
||||||
if bus is None:
|
if bus is None:
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
@ -453,6 +457,7 @@ class AgentLoop:
|
|||||||
tools_config=config.tools,
|
tools_config=config.tools,
|
||||||
model_presets=model_presets,
|
model_presets=model_presets,
|
||||||
model_preset=defaults.model_preset,
|
model_preset=defaults.model_preset,
|
||||||
|
model_preset_snapshot_builder=lambda preset: build_provider_snapshot(config, preset=preset),
|
||||||
**extra,
|
**extra,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -465,8 +470,6 @@ class AgentLoop:
|
|||||||
provider = snapshot.provider
|
provider = snapshot.provider
|
||||||
model = snapshot.model
|
model = snapshot.model
|
||||||
context_window_tokens = snapshot.context_window_tokens
|
context_window_tokens = snapshot.context_window_tokens
|
||||||
if self.provider is provider and self.model == model:
|
|
||||||
return
|
|
||||||
old_model = self.model
|
old_model = self.model
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.model = model
|
self.model = model
|
||||||
@ -498,15 +501,38 @@ class AgentLoop:
|
|||||||
|
|
||||||
@model_preset.setter
def model_preset(self, name: str | None) -> None:
    """Property-assignment form of :meth:`set_model_preset`.

    Assigning to ``model_preset`` forwards *name* unchanged, so validation
    and error behaviour are whatever ``set_model_preset`` implements.
    """
    self.set_model_preset(name)
|
||||||
|
|
||||||
|
def _build_model_preset_snapshot(self, name: str) -> ProviderSnapshot:
    """Build the provider snapshot for the preset registered under *name*.

    If a ``model_preset_snapshot_builder`` was supplied, the preset is handed
    to it and its result is returned as-is. Otherwise the current provider is
    reused: its ``generation`` settings are replaced with the preset's and a
    snapshot describing the preset is assembled locally.

    Raises:
        KeyError: if *name* is not a key of ``self.model_presets``.
    """
    preset = self.model_presets[name]

    builder = self._model_preset_snapshot_builder
    if builder is not None:
        return builder(preset)

    # Fallback path: keep the existing provider object and only retune it.
    # NOTE(review): this assignment happens while *building* the snapshot,
    # i.e. before the snapshot is applied — confirm that is intended.
    self.provider.generation = preset.to_generation_settings()

    signature = (
        "model_preset",
        name,
        preset.model,
        preset.provider,
        preset.max_tokens,
        preset.context_window_tokens,
        preset.temperature,
        preset.reasoning_effort,
    )
    return ProviderSnapshot(
        provider=self.provider,
        model=preset.model,
        context_window_tokens=preset.context_window_tokens,
        signature=signature,
    )
|
||||||
|
|
||||||
|
def set_model_preset(self, name: str | None) -> None:
    """Activate the preset called *name*.

    The name is stripped of surrounding whitespace, looked up in
    ``self.model_presets``, turned into a provider snapshot and applied via
    ``_apply_provider_snapshot``. ``_active_preset`` is only updated after
    both of those steps returned without raising.

    Raises:
        ValueError: if *name* is not a string or is blank.
        KeyError: if no preset with that name is configured.
    """
    cleaned = name.strip() if isinstance(name, str) else ""
    if not cleaned:
        raise ValueError("model_preset must be a non-empty string")

    if cleaned not in self.model_presets:
        raise KeyError(f"model_preset {cleaned!r} not found. Available: {', '.join(self.model_presets) or '(none)'}")

    self._apply_provider_snapshot(self._build_model_preset_snapshot(cleaned))
    # Record the new preset last so a failure above leaves the old one active.
    self._active_preset = cleaned
|
||||||
|
|
||||||
def _register_default_tools(self) -> None:
|
def _register_default_tools(self) -> None:
|
||||||
|
|||||||
@ -58,6 +58,13 @@ BUILTIN_COMMAND_SPECS: tuple[BuiltinCommandSpec, ...] = (
|
|||||||
"Display runtime, provider, and channel status.",
|
"Display runtime, provider, and channel status.",
|
||||||
"activity",
|
"activity",
|
||||||
),
|
),
|
||||||
|
BuiltinCommandSpec(
|
||||||
|
"/model",
|
||||||
|
"Switch model preset",
|
||||||
|
"Show or switch the active model preset.",
|
||||||
|
"brain",
|
||||||
|
"[preset]",
|
||||||
|
),
|
||||||
BuiltinCommandSpec(
|
BuiltinCommandSpec(
|
||||||
"/history",
|
"/history",
|
||||||
"Show conversation history",
|
"Show conversation history",
|
||||||
@ -192,6 +199,75 @@ async def cmd_new(ctx: CommandContext) -> OutboundMessage:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_preset_names(names: list[str]) -> str:
    """Render preset names as a comma-separated list of backtick-quoted items.

    An empty list yields the placeholder ``"(none configured)"``.
    """
    if not names:
        return "(none configured)"
    return ", ".join(f"`{preset}`" for preset in names)
|
||||||
|
|
||||||
|
|
||||||
|
def _model_command_status(loop) -> str:
    """Return the text shown by a bare ``/model`` command.

    The summary lists the loop's current model, its active preset (or
    ``(none)`` when that is falsy) and the configured preset names in sorted
    order.
    """
    preset_names = sorted(loop.model_presets)
    active_preset = loop.model_preset or "(none)"
    status_lines = [
        "## Model",
        f"- Current model: `{loop.model}`",
        f"- Active preset: `{active_preset}`",
        f"- Available presets: {_format_preset_names(preset_names)}",
    ]
    return "\n".join(status_lines)
|
||||||
|
|
||||||
|
|
||||||
|
async def cmd_model(ctx: CommandContext) -> OutboundMessage:
    """Show or switch model presets.

    * No argument: reply with the current model/preset summary.
    * Exactly one argument: switch to that preset and report the resulting
      model, context window and (when the provider exposes it) max output
      tokens.
    * Anything else: reply with a usage hint.

    ``KeyError`` / ``ValueError`` raised by ``loop.set_model_preset`` are
    turned into a reply; any other exception propagates to the caller.
    Every reply carries the inbound metadata plus ``render_as="text"``.
    """
    loop = ctx.loop
    args = ctx.args.strip()
    metadata = {**dict(ctx.msg.metadata or {}), "render_as": "text"}

    def reply(content: str) -> OutboundMessage:
        # All branches answer on the channel/chat the command came from.
        return OutboundMessage(
            channel=ctx.msg.channel,
            chat_id=ctx.msg.chat_id,
            content=content,
            metadata=metadata,
        )

    if not args:
        return reply(_model_command_status(loop))

    parts = args.split()
    if len(parts) != 1:
        return reply("Usage: `/model [preset]`")

    name = parts[0]
    try:
        loop.set_model_preset(name)
    except (KeyError, ValueError) as exc:
        # str(KeyError(msg)) is repr(msg): the text would be wrapped in quotes
        # with the inner quotes escaped. Use the raw message argument instead.
        reason = exc.args[0] if exc.args else exc
        names = sorted(loop.model_presets)
        return reply(
            f"Could not switch model preset: {reason}\n\n"
            f"Available presets: {_format_preset_names(names)}"
        )

    # Not every provider is guaranteed to expose generation settings.
    max_tokens = getattr(getattr(loop.provider, "generation", None), "max_tokens", None)
    lines = [
        f"Switched model preset to `{loop.model_preset}`.",
        f"- Model: `{loop.model}`",
        f"- Context window: {loop.context_window_tokens}",
    ]
    if max_tokens is not None:
        lines.append(f"- Max output tokens: {max_tokens}")
    return reply("\n".join(lines))
|
||||||
|
|
||||||
|
|
||||||
async def cmd_dream(ctx: CommandContext) -> OutboundMessage:
|
async def cmd_dream(ctx: CommandContext) -> OutboundMessage:
|
||||||
"""Manually trigger a Dream consolidation run."""
|
"""Manually trigger a Dream consolidation run."""
|
||||||
import time
|
import time
|
||||||
@ -477,6 +553,8 @@ def register_builtin_commands(router: CommandRouter) -> None:
|
|||||||
router.priority("/status", cmd_status)
|
router.priority("/status", cmd_status)
|
||||||
router.exact("/new", cmd_new)
|
router.exact("/new", cmd_new)
|
||||||
router.exact("/status", cmd_status)
|
router.exact("/status", cmd_status)
|
||||||
|
router.exact("/model", cmd_model)
|
||||||
|
router.prefix("/model ", cmd_model)
|
||||||
router.exact("/history", cmd_history)
|
router.exact("/history", cmd_history)
|
||||||
router.prefix("/history ", cmd_history)
|
router.prefix("/history ", cmd_history)
|
||||||
router.exact("/dream", cmd_dream)
|
router.exact("/dream", cmd_dream)
|
||||||
|
|||||||
@ -283,10 +283,12 @@ class Config(BaseSettings):
|
|||||||
raise ValueError(f"model_preset {name!r} not found in model_presets")
|
raise ValueError(f"model_preset {name!r} not found in model_presets")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def resolve_preset(self) -> ModelPresetConfig:
|
def resolve_preset(self, name: str | None = None) -> ModelPresetConfig:
|
||||||
"""Return effective model params: from active preset, or individual defaults."""
|
"""Return effective model params: from active preset, or individual defaults."""
|
||||||
name = self.agents.defaults.model_preset
|
name = self.agents.defaults.model_preset if name is None else name
|
||||||
if name:
|
if name:
|
||||||
|
if name not in self.model_presets:
|
||||||
|
raise KeyError(f"model_preset {name!r} not found in model_presets")
|
||||||
return self.model_presets[name]
|
return self.model_presets[name]
|
||||||
d = self.agents.defaults
|
d = self.agents.defaults
|
||||||
return ModelPresetConfig(
|
return ModelPresetConfig(
|
||||||
@ -301,12 +303,14 @@ class Config(BaseSettings):
|
|||||||
return Path(self.agents.defaults.workspace).expanduser()
|
return Path(self.agents.defaults.workspace).expanduser()
|
||||||
|
|
||||||
def _match_provider(
|
def _match_provider(
|
||||||
self, model: str | None = None
|
self, model: str | None = None,
|
||||||
|
*,
|
||||||
|
preset: ModelPresetConfig | None = None,
|
||||||
) -> tuple["ProviderConfig | None", str | None]:
|
) -> tuple["ProviderConfig | None", str | None]:
|
||||||
"""Match provider config and its registry name. Returns (config, spec_name)."""
|
"""Match provider config and its registry name. Returns (config, spec_name)."""
|
||||||
from nanobot.providers.registry import PROVIDERS, find_by_name
|
from nanobot.providers.registry import PROVIDERS, find_by_name
|
||||||
|
|
||||||
resolved = self.resolve_preset()
|
resolved = preset or self.resolve_preset()
|
||||||
forced = resolved.provider
|
forced = resolved.provider
|
||||||
if forced != "auto":
|
if forced != "auto":
|
||||||
spec = find_by_name(forced)
|
spec = find_by_name(forced)
|
||||||
@ -366,26 +370,46 @@ class Config(BaseSettings):
|
|||||||
return p, spec.name
|
return p, spec.name
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
def get_provider(
    self,
    model: str | None = None,
    *,
    preset: ModelPresetConfig | None = None,
) -> ProviderConfig | None:
    """Get matched provider config (api_key, api_base, extra_headers). Falls back to first available.

    *model* and *preset* are forwarded unchanged to ``_match_provider``; only
    the config half of its ``(config, spec_name)`` result is returned.
    """
    matched_config, _spec_name = self._match_provider(model, preset=preset)
    return matched_config
|
||||||
|
|
||||||
def get_provider_name(
    self,
    model: str | None = None,
    *,
    preset: ModelPresetConfig | None = None,
) -> str | None:
    """Get the registry name of the matched provider (e.g. "deepseek", "openrouter").

    *model* and *preset* are forwarded unchanged to ``_match_provider``; only
    the name half of its ``(config, spec_name)`` result is returned.
    """
    _matched_config, spec_name = self._match_provider(model, preset=preset)
    return spec_name
|
||||||
|
|
||||||
def get_api_key(
    self,
    model: str | None = None,
    *,
    preset: ModelPresetConfig | None = None,
) -> str | None:
    """Get API key for the given model. Falls back to first available key.

    Returns ``None`` when ``get_provider`` yields no (or a falsy) provider
    config for the given *model* / *preset*.
    """
    provider_cfg = self.get_provider(model, preset=preset)
    if not provider_cfg:
        return None
    return provider_cfg.api_key
|
||||||
|
|
||||||
def get_api_base(self, model: str | None = None) -> str | None:
|
def get_api_base(
|
||||||
|
self,
|
||||||
|
model: str | None = None,
|
||||||
|
*,
|
||||||
|
preset: ModelPresetConfig | None = None,
|
||||||
|
) -> str | None:
|
||||||
"""Get API base URL for the given model, falling back to the provider default when present."""
|
"""Get API base URL for the given model, falling back to the provider default when present."""
|
||||||
from nanobot.providers.registry import find_by_name
|
from nanobot.providers.registry import find_by_name
|
||||||
|
|
||||||
p, name = self._match_provider(model)
|
p, name = self._match_provider(model, preset=preset)
|
||||||
if p and p.api_base:
|
if p and p.api_base:
|
||||||
return p.api_base
|
return p.api_base
|
||||||
if name:
|
if name:
|
||||||
|
|||||||
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from nanobot.config.schema import Config
|
from nanobot.config.schema import Config, ModelPresetConfig
|
||||||
from nanobot.providers.base import LLMProvider
|
from nanobot.providers.base import LLMProvider
|
||||||
from nanobot.providers.registry import find_by_name
|
from nanobot.providers.registry import find_by_name
|
||||||
|
|
||||||
@ -18,12 +18,26 @@ class ProviderSnapshot:
|
|||||||
signature: tuple[object, ...]
|
signature: tuple[object, ...]
|
||||||
|
|
||||||
|
|
||||||
def make_provider(config: Config) -> LLMProvider:
|
def _resolve_model_preset(
    config: Config,
    *,
    preset_name: str | None = None,
    preset: ModelPresetConfig | None = None,
) -> ModelPresetConfig:
    """Pick the preset to build a provider from.

    An explicitly passed *preset* object wins; otherwise the lookup is
    delegated to ``config.resolve_preset(preset_name)``.
    """
    if preset is not None:
        return preset
    return config.resolve_preset(preset_name)
|
||||||
|
|
||||||
|
|
||||||
|
def make_provider(
|
||||||
|
config: Config,
|
||||||
|
*,
|
||||||
|
preset_name: str | None = None,
|
||||||
|
preset: ModelPresetConfig | None = None,
|
||||||
|
) -> LLMProvider:
|
||||||
"""Create the LLM provider implied by config."""
|
"""Create the LLM provider implied by config."""
|
||||||
resolved = config.resolve_preset()
|
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
|
||||||
model = resolved.model
|
model = resolved.model
|
||||||
provider_name = config.get_provider_name(model)
|
provider_name = config.get_provider_name(model, preset=resolved)
|
||||||
p = config.get_provider(model)
|
p = config.get_provider(model, preset=resolved)
|
||||||
spec = find_by_name(provider_name) if provider_name else None
|
spec = find_by_name(provider_name) if provider_name else None
|
||||||
backend = spec.backend if spec else "openai_compat"
|
backend = spec.backend if spec else "openai_compat"
|
||||||
|
|
||||||
@ -57,7 +71,7 @@ def make_provider(config: Config) -> LLMProvider:
|
|||||||
|
|
||||||
provider = AnthropicProvider(
|
provider = AnthropicProvider(
|
||||||
api_key=p.api_key if p else None,
|
api_key=p.api_key if p else None,
|
||||||
api_base=config.get_api_base(model),
|
api_base=config.get_api_base(model, preset=resolved),
|
||||||
default_model=model,
|
default_model=model,
|
||||||
extra_headers=p.extra_headers if p else None,
|
extra_headers=p.extra_headers if p else None,
|
||||||
)
|
)
|
||||||
@ -77,7 +91,7 @@ def make_provider(config: Config) -> LLMProvider:
|
|||||||
|
|
||||||
provider = OpenAICompatProvider(
|
provider = OpenAICompatProvider(
|
||||||
api_key=p.api_key if p else None,
|
api_key=p.api_key if p else None,
|
||||||
api_base=config.get_api_base(model),
|
api_base=config.get_api_base(model, preset=resolved),
|
||||||
default_model=model,
|
default_model=model,
|
||||||
extra_headers=p.extra_headers if p else None,
|
extra_headers=p.extra_headers if p else None,
|
||||||
spec=spec,
|
spec=spec,
|
||||||
@ -88,16 +102,21 @@ def make_provider(config: Config) -> LLMProvider:
|
|||||||
return provider
|
return provider
|
||||||
|
|
||||||
|
|
||||||
def provider_signature(config: Config) -> tuple[object, ...]:
|
def provider_signature(
|
||||||
|
config: Config,
|
||||||
|
*,
|
||||||
|
preset_name: str | None = None,
|
||||||
|
preset: ModelPresetConfig | None = None,
|
||||||
|
) -> tuple[object, ...]:
|
||||||
"""Return the config fields that affect the primary LLM provider."""
|
"""Return the config fields that affect the primary LLM provider."""
|
||||||
resolved = config.resolve_preset()
|
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
|
||||||
p = config.get_provider(resolved.model)
|
p = config.get_provider(resolved.model, preset=resolved)
|
||||||
return (
|
return (
|
||||||
resolved.model,
|
resolved.model,
|
||||||
resolved.provider,
|
resolved.provider,
|
||||||
config.get_provider_name(resolved.model),
|
config.get_provider_name(resolved.model, preset=resolved),
|
||||||
config.get_api_key(resolved.model),
|
config.get_api_key(resolved.model, preset=resolved),
|
||||||
config.get_api_base(resolved.model),
|
config.get_api_base(resolved.model, preset=resolved),
|
||||||
p.extra_headers if p else None,
|
p.extra_headers if p else None,
|
||||||
p.extra_body if p else None,
|
p.extra_body if p else None,
|
||||||
getattr(p, "region", None) if p else None,
|
getattr(p, "region", None) if p else None,
|
||||||
@ -109,13 +128,18 @@ def provider_signature(config: Config) -> tuple[object, ...]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_provider_snapshot(
    config: Config,
    *,
    preset_name: str | None = None,
    preset: ModelPresetConfig | None = None,
) -> ProviderSnapshot:
    """Bundle a freshly built provider with the model settings it was built for.

    The preset is resolved once (explicit *preset*, else *preset_name*, else
    whatever ``config.resolve_preset`` yields for ``None``) and that same
    object drives both the provider construction and the signature.
    """
    resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
    provider = make_provider(config, preset=resolved)
    signature = provider_signature(config, preset=resolved)
    return ProviderSnapshot(
        provider=provider,
        model=resolved.model,
        context_window_tokens=resolved.context_window_tokens,
        signature=signature,
    )
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from nanobot.agent.loop import AgentLoop
|
|||||||
from nanobot.agent.tools.self import MyTool
|
from nanobot.agent.tools.self import MyTool
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.config.schema import ModelPresetConfig
|
from nanobot.config.schema import ModelPresetConfig
|
||||||
|
from nanobot.providers.factory import ProviderSnapshot
|
||||||
|
|
||||||
|
|
||||||
def _provider(default_model: str, max_tokens: int = 123) -> MagicMock:
|
def _provider(default_model: str, max_tokens: int = 123) -> MagicMock:
|
||||||
@ -56,6 +57,75 @@ def test_model_preset_setter_updates_state(tmp_path) -> None:
|
|||||||
assert loop.provider.generation.temperature == 0.5
|
assert loop.provider.generation.temperature == 0.5
|
||||||
assert loop.provider.generation.max_tokens == 4096
|
assert loop.provider.generation.max_tokens == 4096
|
||||||
assert loop.provider.generation.reasoning_effort == "low"
|
assert loop.provider.generation.reasoning_effort == "low"
|
||||||
|
assert loop.subagents.model == "openai/gpt-4.1"
|
||||||
|
assert loop.consolidator.model == "openai/gpt-4.1"
|
||||||
|
assert loop.consolidator.context_window_tokens == 32_768
|
||||||
|
assert loop.consolidator.max_completion_tokens == 4096
|
||||||
|
assert loop.dream.model == "openai/gpt-4.1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_preset_setter_replaces_provider_from_snapshot(tmp_path) -> None:
    """The provider returned by the snapshot builder replaces the old one everywhere."""
    old_provider = _provider("base-model", max_tokens=123)
    new_provider = _provider("anthropic/claude-opus-4-5", max_tokens=2048)
    preset = ModelPresetConfig(
        model="anthropic/claude-opus-4-5",
        provider="anthropic",
        max_tokens=2048,
        context_window_tokens=200_000,
    )

    def build_snapshot(chosen):
        # Stand-in for the config-backed builder: always hands out new_provider.
        return ProviderSnapshot(
            provider=new_provider,
            model=chosen.model,
            context_window_tokens=chosen.context_window_tokens,
            signature=("deep", chosen.model),
        )

    loop = AgentLoop(
        bus=MessageBus(),
        provider=old_provider,
        workspace=tmp_path,
        model="base-model",
        context_window_tokens=1000,
        model_presets={"deep": preset},
        model_preset_snapshot_builder=build_snapshot,
    )

    loop.set_model_preset("deep")

    # Every component holding a provider reference must see the new one.
    assert loop.provider is new_provider
    assert loop.runner.provider is new_provider
    assert loop.subagents.provider is new_provider
    assert loop.subagents.runner.provider is new_provider
    assert loop.consolidator.provider is new_provider
    assert loop.dream.provider is new_provider
    assert loop.dream._runner.provider is new_provider
    # Model-level settings follow the preset as well.
    assert loop.model == "anthropic/claude-opus-4-5"
    assert loop.context_window_tokens == 200_000
    assert loop.consolidator.max_completion_tokens == 2048
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_preset_setter_failure_leaves_old_state(tmp_path) -> None:
    """If building the snapshot raises, nothing about the loop's model state changes."""
    preset = ModelPresetConfig(model="openai/gpt-4.1", max_tokens=4096)

    def failing_builder(_preset):
        raise RuntimeError("provider unavailable")

    loop = AgentLoop(
        bus=MessageBus(),
        provider=_provider("base-model", max_tokens=123),
        workspace=tmp_path,
        model="base-model",
        context_window_tokens=1000,
        model_presets={"fast": preset},
        model_preset_snapshot_builder=failing_builder,
    )

    with pytest.raises(RuntimeError, match="provider unavailable"):
        loop.set_model_preset("fast")

    # The failed switch must not leak into any dependent component.
    assert loop.model_preset is None
    assert loop.model == "base-model"
    assert loop.subagents.model == "base-model"
    assert loop.consolidator.model == "base-model"
    assert loop.dream.model == "base-model"
    assert loop.context_window_tokens == 1000
    assert loop.consolidator.max_completion_tokens == 123
|
||||||
|
|
||||||
|
|
||||||
def test_model_preset_setter_raises_on_unknown(tmp_path) -> None:
|
def test_model_preset_setter_raises_on_unknown(tmp_path) -> None:
|
||||||
|
|||||||
137
tests/command/test_model_command.py
Normal file
137
tests/command/test_model_command.py
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.command.builtin import (
|
||||||
|
build_help_text,
|
||||||
|
builtin_command_palette,
|
||||||
|
cmd_model,
|
||||||
|
register_builtin_commands,
|
||||||
|
)
|
||||||
|
from nanobot.command.router import CommandContext, CommandRouter
|
||||||
|
from nanobot.config.schema import ModelPresetConfig
|
||||||
|
|
||||||
|
|
||||||
|
def _provider(default_model: str, max_tokens: int = 123) -> MagicMock:
    """Create a mock provider reporting *default_model* with simple generation settings."""
    generation = SimpleNamespace(
        max_tokens=max_tokens,
        temperature=0.1,
        reasoning_effort=None,
    )
    fake = MagicMock()
    fake.get_default_model.return_value = default_model
    fake.generation = generation
    return fake
|
||||||
|
|
||||||
|
|
||||||
|
def _make_loop(tmp_path) -> AgentLoop:
    """Build an AgentLoop on ``base-model`` with ``default`` and ``fast`` presets configured."""
    presets = {
        "default": ModelPresetConfig(
            model="base-model",
            max_tokens=123,
            context_window_tokens=1000,
        ),
        "fast": ModelPresetConfig(
            model="openai/gpt-4.1",
            max_tokens=4096,
            context_window_tokens=32_768,
        ),
    }
    return AgentLoop(
        bus=MessageBus(),
        provider=_provider("base-model", max_tokens=123),
        workspace=tmp_path,
        model="base-model",
        context_window_tokens=1000,
        model_presets=presets,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _ctx(loop: AgentLoop, raw: str, args: str = "") -> CommandContext:
    """Wrap *raw* in a CLI inbound message and return a session-less command context for it."""
    inbound = InboundMessage(channel="cli", sender_id="user", chat_id="direct", content=raw)
    return CommandContext(
        msg=inbound,
        session=None,
        key=inbound.session_key,
        raw=raw,
        args=args,
        loop=loop,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_model_command_lists_current_and_available_presets(tmp_path) -> None:
    """A bare ``/model`` reports the current model, active preset and all preset names."""
    loop = _make_loop(tmp_path)

    reply = await cmd_model(_ctx(loop, "/model"))

    text = reply.content
    assert "Current model: `base-model`" in text
    assert "Active preset: `(none)`" in text
    assert "`default`" in text
    assert "`fast`" in text
    assert reply.metadata == {"render_as": "text"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_model_command_switches_preset(tmp_path) -> None:
    """``/model fast`` activates the preset and updates every dependent component's model."""
    loop = _make_loop(tmp_path)

    reply = await cmd_model(_ctx(loop, "/model fast", args="fast"))

    assert "Switched model preset to `fast`." in reply.content
    assert "Model: `openai/gpt-4.1`" in reply.content
    assert loop.model_preset == "fast"
    assert loop.model == "openai/gpt-4.1"
    assert loop.subagents.model == "openai/gpt-4.1"
    assert loop.consolidator.model == "openai/gpt-4.1"
    assert loop.dream.model == "openai/gpt-4.1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_model_command_switches_back_to_default(tmp_path) -> None:
    """After moving to ``fast``, ``/model default`` restores the base model settings."""
    loop = _make_loop(tmp_path)
    loop.set_model_preset("fast")

    reply = await cmd_model(_ctx(loop, "/model default", args="default"))

    assert "Switched model preset to `default`." in reply.content
    assert loop.model_preset == "default"
    assert loop.model == "base-model"
    assert loop.context_window_tokens == 1000
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_model_command_unknown_preset_keeps_old_state(tmp_path) -> None:
    """An unknown preset name yields an error reply and leaves the loop untouched."""
    loop = _make_loop(tmp_path)

    reply = await cmd_model(_ctx(loop, "/model missing", args="missing"))

    assert "Could not switch model preset" in reply.content
    assert "Available presets: `default`, `fast`" in reply.content
    assert loop.model_preset is None
    assert loop.model == "base-model"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_model_command_does_not_depend_on_my_allow_set(tmp_path) -> None:
    """Switching via ``/model`` works even though ``tools_config.my.allow_set`` is off."""
    loop = _make_loop(tmp_path)
    # Precondition: the tool-side setting is disabled for this loop.
    assert loop.tools_config.my.allow_set is False

    await cmd_model(_ctx(loop, "/model fast", args="fast"))

    assert loop.model_preset == "fast"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_model_command_registered_as_exact_and_prefix(tmp_path) -> None:
    """Dispatching ``/model fast`` through a router with the builtins reaches ``cmd_model``."""
    router = CommandRouter()
    register_builtin_commands(router)
    loop = _make_loop(tmp_path)

    reply = await router.dispatch(_ctx(loop, "/model fast"))

    assert reply is not None
    assert "Switched model preset" in reply.content
    assert loop.model_preset == "fast"
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_command_in_help_and_palette() -> None:
    """``/model`` shows up in the command palette (with its arg hint) and in the help text."""
    palette = builtin_command_palette()

    model_entries = (
        item for item in palette
        if item["command"] == "/model" and item["arg_hint"] == "[preset]"
    )
    assert any(True for _ in model_entries)
    assert "/model [preset]" in build_help_text()
|
||||||
@ -22,6 +22,7 @@ class TestIsDispatchableCommand:
|
|||||||
def test_exact_commands_match(self, router: CommandRouter) -> None:
|
def test_exact_commands_match(self, router: CommandRouter) -> None:
|
||||||
assert router.is_dispatchable_command("/new")
|
assert router.is_dispatchable_command("/new")
|
||||||
assert router.is_dispatchable_command("/help")
|
assert router.is_dispatchable_command("/help")
|
||||||
|
assert router.is_dispatchable_command("/model")
|
||||||
assert router.is_dispatchable_command("/dream")
|
assert router.is_dispatchable_command("/dream")
|
||||||
assert router.is_dispatchable_command("/dream-log")
|
assert router.is_dispatchable_command("/dream-log")
|
||||||
assert router.is_dispatchable_command("/dream-restore")
|
assert router.is_dispatchable_command("/dream-restore")
|
||||||
@ -29,6 +30,7 @@ class TestIsDispatchableCommand:
|
|||||||
def test_prefix_commands_match(self, router: CommandRouter) -> None:
|
def test_prefix_commands_match(self, router: CommandRouter) -> None:
|
||||||
assert router.is_dispatchable_command("/dream-log abc123")
|
assert router.is_dispatchable_command("/dream-log abc123")
|
||||||
assert router.is_dispatchable_command("/dream-restore def456")
|
assert router.is_dispatchable_command("/dream-restore def456")
|
||||||
|
assert router.is_dispatchable_command("/model fast")
|
||||||
|
|
||||||
def test_priority_commands_not_matched(self, router: CommandRouter) -> None:
|
def test_priority_commands_not_matched(self, router: CommandRouter) -> None:
|
||||||
# Priority commands are NOT in the dispatchable tiers — they are
|
# Priority commands are NOT in the dispatchable tiers — they are
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from nanobot.config.schema import Config, ModelPresetConfig
|
from nanobot.config.schema import Config
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_preset_returns_defaults_when_no_preset() -> None:
|
def test_resolve_preset_returns_defaults_when_no_preset() -> None:
|
||||||
@ -39,6 +39,20 @@ def test_resolve_preset_returns_active_preset() -> None:
|
|||||||
assert resolved.reasoning_effort == "low"
|
assert resolved.reasoning_effort == "low"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_preset_can_target_named_preset_without_activating() -> None:
    """An explicit name overrides the configured default preset for that one lookup."""
    raw_config = {
        "model_presets": {
            "fast": {"model": "openai/gpt-4.1", "provider": "openai"},
            "deep": {"model": "anthropic/claude-opus-4-5", "provider": "anthropic"},
        },
        # "fast" is the default preset; the lookup below asks for "deep".
        "agents": {"defaults": {"modelPreset": "fast"}},
    }
    config = Config.model_validate(raw_config)

    deep = config.resolve_preset("deep")

    assert deep.model == "anthropic/claude-opus-4-5"
    assert deep.provider == "anthropic"
|
||||||
|
|
||||||
|
|
||||||
def test_validator_rejects_unknown_preset() -> None:
|
def test_validator_rejects_unknown_preset() -> None:
|
||||||
import pytest
|
import pytest
|
||||||
with pytest.raises(ValueError, match="model_preset 'unknown' not found in model_presets"):
|
with pytest.raises(ValueError, match="model_preset 'unknown' not found in model_presets"):
|
||||||
@ -51,6 +65,12 @@ def test_validator_rejects_unknown_preset() -> None:
|
|||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_preset_rejects_unknown_named_preset() -> None:
    """Asking a default Config for a preset it does not have raises ``KeyError``."""
    import pytest

    empty_config = Config()
    with pytest.raises(KeyError, match="model_preset 'missing' not found"):
        empty_config.resolve_preset("missing")
|
||||||
|
|
||||||
|
|
||||||
def test_match_provider_uses_preset_model() -> None:
|
def test_match_provider_uses_preset_model() -> None:
|
||||||
config = Config.model_validate({
|
config = Config.model_validate({
|
||||||
"providers": {
|
"providers": {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user