mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
fix(config): make model preset switching atomic
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
6f78267c82
commit
c450d6fd3f
@ -293,6 +293,7 @@ class AgentLoop:
|
||||
provider_signature: tuple[object, ...] | None = None,
|
||||
model_presets: dict[str, ModelPresetConfig] | None = None,
|
||||
model_preset: str | None = None,
|
||||
model_preset_snapshot_builder: Callable[[ModelPresetConfig], ProviderSnapshot] | None = None,
|
||||
):
|
||||
from nanobot.config.schema import ToolsConfig
|
||||
|
||||
@ -398,7 +399,10 @@ class AgentLoop:
|
||||
model=self.model,
|
||||
)
|
||||
self.model_presets: dict[str, ModelPresetConfig] = model_presets or {}
|
||||
self._active_preset: str | None = model_preset if model_presets and model_preset in model_presets else None
|
||||
self._model_preset_snapshot_builder = model_preset_snapshot_builder
|
||||
self._active_preset: str | None = None
|
||||
if model_preset:
|
||||
self.set_model_preset(model_preset)
|
||||
self._register_default_tools()
|
||||
self._runtime_vars: dict[str, Any] = {}
|
||||
self._current_iteration: int = 0
|
||||
@ -418,7 +422,7 @@ class AgentLoop:
|
||||
allowing callers to override or extend the standard config-derived
|
||||
parameters (e.g. ``cron_service``, ``session_manager``).
|
||||
"""
|
||||
from nanobot.providers.factory import make_provider
|
||||
from nanobot.providers.factory import build_provider_snapshot, make_provider
|
||||
|
||||
if bus is None:
|
||||
bus = MessageBus()
|
||||
@ -453,6 +457,7 @@ class AgentLoop:
|
||||
tools_config=config.tools,
|
||||
model_presets=model_presets,
|
||||
model_preset=defaults.model_preset,
|
||||
model_preset_snapshot_builder=lambda preset: build_provider_snapshot(config, preset=preset),
|
||||
**extra,
|
||||
)
|
||||
|
||||
@ -465,8 +470,6 @@ class AgentLoop:
|
||||
provider = snapshot.provider
|
||||
model = snapshot.model
|
||||
context_window_tokens = snapshot.context_window_tokens
|
||||
if self.provider is provider and self.model == model:
|
||||
return
|
||||
old_model = self.model
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
@ -498,15 +501,38 @@ class AgentLoop:
|
||||
|
||||
@model_preset.setter
|
||||
def model_preset(self, name: str | None) -> None:
|
||||
"""Resolve a preset by name and apply all fields atomically."""
|
||||
self.set_model_preset(name)
|
||||
|
||||
def _build_model_preset_snapshot(self, name: str) -> ProviderSnapshot:
|
||||
preset = self.model_presets[name]
|
||||
if self._model_preset_snapshot_builder is not None:
|
||||
return self._model_preset_snapshot_builder(preset)
|
||||
self.provider.generation = preset.to_generation_settings()
|
||||
return ProviderSnapshot(
|
||||
provider=self.provider,
|
||||
model=preset.model,
|
||||
context_window_tokens=preset.context_window_tokens,
|
||||
signature=(
|
||||
"model_preset",
|
||||
name,
|
||||
preset.model,
|
||||
preset.provider,
|
||||
preset.max_tokens,
|
||||
preset.context_window_tokens,
|
||||
preset.temperature,
|
||||
preset.reasoning_effort,
|
||||
),
|
||||
)
|
||||
|
||||
def set_model_preset(self, name: str | None) -> None:
|
||||
"""Resolve a preset by name and apply all runtime model dependents."""
|
||||
if not isinstance(name, str) or not name.strip():
|
||||
raise ValueError("model_preset must be a non-empty string")
|
||||
name = name.strip()
|
||||
if name not in self.model_presets:
|
||||
raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(self.model_presets) or '(none)'}")
|
||||
p = self.model_presets[name]
|
||||
self.model = p.model
|
||||
self.context_window_tokens = p.context_window_tokens
|
||||
self.provider.generation = p.to_generation_settings()
|
||||
snapshot = self._build_model_preset_snapshot(name)
|
||||
self._apply_provider_snapshot(snapshot)
|
||||
self._active_preset = name
|
||||
|
||||
def _register_default_tools(self) -> None:
|
||||
|
||||
@ -58,6 +58,13 @@ BUILTIN_COMMAND_SPECS: tuple[BuiltinCommandSpec, ...] = (
|
||||
"Display runtime, provider, and channel status.",
|
||||
"activity",
|
||||
),
|
||||
BuiltinCommandSpec(
|
||||
"/model",
|
||||
"Switch model preset",
|
||||
"Show or switch the active model preset.",
|
||||
"brain",
|
||||
"[preset]",
|
||||
),
|
||||
BuiltinCommandSpec(
|
||||
"/history",
|
||||
"Show conversation history",
|
||||
@ -192,6 +199,75 @@ async def cmd_new(ctx: CommandContext) -> OutboundMessage:
|
||||
)
|
||||
|
||||
|
||||
def _format_preset_names(names: list[str]) -> str:
|
||||
return ", ".join(f"`{name}`" for name in names) if names else "(none configured)"
|
||||
|
||||
|
||||
def _model_command_status(loop) -> str:
|
||||
names = sorted(loop.model_presets)
|
||||
active = loop.model_preset or "(none)"
|
||||
return "\n".join([
|
||||
"## Model",
|
||||
f"- Current model: `{loop.model}`",
|
||||
f"- Active preset: `{active}`",
|
||||
f"- Available presets: {_format_preset_names(names)}",
|
||||
])
|
||||
|
||||
|
||||
async def cmd_model(ctx: CommandContext) -> OutboundMessage:
|
||||
"""Show or switch model presets."""
|
||||
loop = ctx.loop
|
||||
args = ctx.args.strip()
|
||||
metadata = {**dict(ctx.msg.metadata or {}), "render_as": "text"}
|
||||
|
||||
if not args:
|
||||
return OutboundMessage(
|
||||
channel=ctx.msg.channel,
|
||||
chat_id=ctx.msg.chat_id,
|
||||
content=_model_command_status(loop),
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
parts = args.split()
|
||||
if len(parts) != 1:
|
||||
return OutboundMessage(
|
||||
channel=ctx.msg.channel,
|
||||
chat_id=ctx.msg.chat_id,
|
||||
content="Usage: `/model [preset]`",
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
name = parts[0]
|
||||
try:
|
||||
loop.set_model_preset(name)
|
||||
except (KeyError, ValueError) as exc:
|
||||
names = sorted(loop.model_presets)
|
||||
return OutboundMessage(
|
||||
channel=ctx.msg.channel,
|
||||
chat_id=ctx.msg.chat_id,
|
||||
content=(
|
||||
f"Could not switch model preset: {exc}\n\n"
|
||||
f"Available presets: {_format_preset_names(names)}"
|
||||
),
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
max_tokens = getattr(getattr(loop.provider, "generation", None), "max_tokens", None)
|
||||
lines = [
|
||||
f"Switched model preset to `{loop.model_preset}`.",
|
||||
f"- Model: `{loop.model}`",
|
||||
f"- Context window: {loop.context_window_tokens}",
|
||||
]
|
||||
if max_tokens is not None:
|
||||
lines.append(f"- Max output tokens: {max_tokens}")
|
||||
return OutboundMessage(
|
||||
channel=ctx.msg.channel,
|
||||
chat_id=ctx.msg.chat_id,
|
||||
content="\n".join(lines),
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
async def cmd_dream(ctx: CommandContext) -> OutboundMessage:
|
||||
"""Manually trigger a Dream consolidation run."""
|
||||
import time
|
||||
@ -477,6 +553,8 @@ def register_builtin_commands(router: CommandRouter) -> None:
|
||||
router.priority("/status", cmd_status)
|
||||
router.exact("/new", cmd_new)
|
||||
router.exact("/status", cmd_status)
|
||||
router.exact("/model", cmd_model)
|
||||
router.prefix("/model ", cmd_model)
|
||||
router.exact("/history", cmd_history)
|
||||
router.prefix("/history ", cmd_history)
|
||||
router.exact("/dream", cmd_dream)
|
||||
|
||||
@ -283,10 +283,12 @@ class Config(BaseSettings):
|
||||
raise ValueError(f"model_preset {name!r} not found in model_presets")
|
||||
return self
|
||||
|
||||
def resolve_preset(self) -> ModelPresetConfig:
|
||||
def resolve_preset(self, name: str | None = None) -> ModelPresetConfig:
|
||||
"""Return effective model params: from active preset, or individual defaults."""
|
||||
name = self.agents.defaults.model_preset
|
||||
name = self.agents.defaults.model_preset if name is None else name
|
||||
if name:
|
||||
if name not in self.model_presets:
|
||||
raise KeyError(f"model_preset {name!r} not found in model_presets")
|
||||
return self.model_presets[name]
|
||||
d = self.agents.defaults
|
||||
return ModelPresetConfig(
|
||||
@ -301,12 +303,14 @@ class Config(BaseSettings):
|
||||
return Path(self.agents.defaults.workspace).expanduser()
|
||||
|
||||
def _match_provider(
|
||||
self, model: str | None = None
|
||||
self, model: str | None = None,
|
||||
*,
|
||||
preset: ModelPresetConfig | None = None,
|
||||
) -> tuple["ProviderConfig | None", str | None]:
|
||||
"""Match provider config and its registry name. Returns (config, spec_name)."""
|
||||
from nanobot.providers.registry import PROVIDERS, find_by_name
|
||||
|
||||
resolved = self.resolve_preset()
|
||||
resolved = preset or self.resolve_preset()
|
||||
forced = resolved.provider
|
||||
if forced != "auto":
|
||||
spec = find_by_name(forced)
|
||||
@ -366,26 +370,46 @@ class Config(BaseSettings):
|
||||
return p, spec.name
|
||||
return None, None
|
||||
|
||||
def get_provider(self, model: str | None = None) -> ProviderConfig | None:
|
||||
def get_provider(
|
||||
self,
|
||||
model: str | None = None,
|
||||
*,
|
||||
preset: ModelPresetConfig | None = None,
|
||||
) -> ProviderConfig | None:
|
||||
"""Get matched provider config (api_key, api_base, extra_headers). Falls back to first available."""
|
||||
p, _ = self._match_provider(model)
|
||||
p, _ = self._match_provider(model, preset=preset)
|
||||
return p
|
||||
|
||||
def get_provider_name(self, model: str | None = None) -> str | None:
|
||||
def get_provider_name(
|
||||
self,
|
||||
model: str | None = None,
|
||||
*,
|
||||
preset: ModelPresetConfig | None = None,
|
||||
) -> str | None:
|
||||
"""Get the registry name of the matched provider (e.g. "deepseek", "openrouter")."""
|
||||
_, name = self._match_provider(model)
|
||||
_, name = self._match_provider(model, preset=preset)
|
||||
return name
|
||||
|
||||
def get_api_key(self, model: str | None = None) -> str | None:
|
||||
def get_api_key(
|
||||
self,
|
||||
model: str | None = None,
|
||||
*,
|
||||
preset: ModelPresetConfig | None = None,
|
||||
) -> str | None:
|
||||
"""Get API key for the given model. Falls back to first available key."""
|
||||
p = self.get_provider(model)
|
||||
p = self.get_provider(model, preset=preset)
|
||||
return p.api_key if p else None
|
||||
|
||||
def get_api_base(self, model: str | None = None) -> str | None:
|
||||
def get_api_base(
|
||||
self,
|
||||
model: str | None = None,
|
||||
*,
|
||||
preset: ModelPresetConfig | None = None,
|
||||
) -> str | None:
|
||||
"""Get API base URL for the given model, falling back to the provider default when present."""
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
p, name = self._match_provider(model)
|
||||
p, name = self._match_provider(model, preset=preset)
|
||||
if p and p.api_base:
|
||||
return p.api_base
|
||||
if name:
|
||||
|
||||
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.config.schema import Config, ModelPresetConfig
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
@ -18,12 +18,26 @@ class ProviderSnapshot:
|
||||
signature: tuple[object, ...]
|
||||
|
||||
|
||||
def make_provider(config: Config) -> LLMProvider:
|
||||
def _resolve_model_preset(
|
||||
config: Config,
|
||||
*,
|
||||
preset_name: str | None = None,
|
||||
preset: ModelPresetConfig | None = None,
|
||||
) -> ModelPresetConfig:
|
||||
return preset if preset is not None else config.resolve_preset(preset_name)
|
||||
|
||||
|
||||
def make_provider(
|
||||
config: Config,
|
||||
*,
|
||||
preset_name: str | None = None,
|
||||
preset: ModelPresetConfig | None = None,
|
||||
) -> LLMProvider:
|
||||
"""Create the LLM provider implied by config."""
|
||||
resolved = config.resolve_preset()
|
||||
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
|
||||
model = resolved.model
|
||||
provider_name = config.get_provider_name(model)
|
||||
p = config.get_provider(model)
|
||||
provider_name = config.get_provider_name(model, preset=resolved)
|
||||
p = config.get_provider(model, preset=resolved)
|
||||
spec = find_by_name(provider_name) if provider_name else None
|
||||
backend = spec.backend if spec else "openai_compat"
|
||||
|
||||
@ -57,7 +71,7 @@ def make_provider(config: Config) -> LLMProvider:
|
||||
|
||||
provider = AnthropicProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
api_base=config.get_api_base(model, preset=resolved),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
)
|
||||
@ -77,7 +91,7 @@ def make_provider(config: Config) -> LLMProvider:
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
api_base=config.get_api_base(model, preset=resolved),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
spec=spec,
|
||||
@ -88,16 +102,21 @@ def make_provider(config: Config) -> LLMProvider:
|
||||
return provider
|
||||
|
||||
|
||||
def provider_signature(config: Config) -> tuple[object, ...]:
|
||||
def provider_signature(
|
||||
config: Config,
|
||||
*,
|
||||
preset_name: str | None = None,
|
||||
preset: ModelPresetConfig | None = None,
|
||||
) -> tuple[object, ...]:
|
||||
"""Return the config fields that affect the primary LLM provider."""
|
||||
resolved = config.resolve_preset()
|
||||
p = config.get_provider(resolved.model)
|
||||
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
|
||||
p = config.get_provider(resolved.model, preset=resolved)
|
||||
return (
|
||||
resolved.model,
|
||||
resolved.provider,
|
||||
config.get_provider_name(resolved.model),
|
||||
config.get_api_key(resolved.model),
|
||||
config.get_api_base(resolved.model),
|
||||
config.get_provider_name(resolved.model, preset=resolved),
|
||||
config.get_api_key(resolved.model, preset=resolved),
|
||||
config.get_api_base(resolved.model, preset=resolved),
|
||||
p.extra_headers if p else None,
|
||||
p.extra_body if p else None,
|
||||
getattr(p, "region", None) if p else None,
|
||||
@ -109,13 +128,18 @@ def provider_signature(config: Config) -> tuple[object, ...]:
|
||||
)
|
||||
|
||||
|
||||
def build_provider_snapshot(config: Config) -> ProviderSnapshot:
|
||||
resolved = config.resolve_preset()
|
||||
def build_provider_snapshot(
|
||||
config: Config,
|
||||
*,
|
||||
preset_name: str | None = None,
|
||||
preset: ModelPresetConfig | None = None,
|
||||
) -> ProviderSnapshot:
|
||||
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
|
||||
return ProviderSnapshot(
|
||||
provider=make_provider(config),
|
||||
provider=make_provider(config, preset=resolved),
|
||||
model=resolved.model,
|
||||
context_window_tokens=resolved.context_window_tokens,
|
||||
signature=provider_signature(config),
|
||||
signature=provider_signature(config, preset=resolved),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.tools.self import MyTool
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.config.schema import ModelPresetConfig
|
||||
from nanobot.providers.factory import ProviderSnapshot
|
||||
|
||||
|
||||
def _provider(default_model: str, max_tokens: int = 123) -> MagicMock:
|
||||
@ -56,6 +57,75 @@ def test_model_preset_setter_updates_state(tmp_path) -> None:
|
||||
assert loop.provider.generation.temperature == 0.5
|
||||
assert loop.provider.generation.max_tokens == 4096
|
||||
assert loop.provider.generation.reasoning_effort == "low"
|
||||
assert loop.subagents.model == "openai/gpt-4.1"
|
||||
assert loop.consolidator.model == "openai/gpt-4.1"
|
||||
assert loop.consolidator.context_window_tokens == 32_768
|
||||
assert loop.consolidator.max_completion_tokens == 4096
|
||||
assert loop.dream.model == "openai/gpt-4.1"
|
||||
|
||||
|
||||
def test_model_preset_setter_replaces_provider_from_snapshot(tmp_path) -> None:
|
||||
old_provider = _provider("base-model", max_tokens=123)
|
||||
new_provider = _provider("anthropic/claude-opus-4-5", max_tokens=2048)
|
||||
preset = ModelPresetConfig(
|
||||
model="anthropic/claude-opus-4-5",
|
||||
provider="anthropic",
|
||||
max_tokens=2048,
|
||||
context_window_tokens=200_000,
|
||||
)
|
||||
loop = AgentLoop(
|
||||
bus=MessageBus(),
|
||||
provider=old_provider,
|
||||
workspace=tmp_path,
|
||||
model="base-model",
|
||||
context_window_tokens=1000,
|
||||
model_presets={"deep": preset},
|
||||
model_preset_snapshot_builder=lambda _preset: ProviderSnapshot(
|
||||
provider=new_provider,
|
||||
model=_preset.model,
|
||||
context_window_tokens=_preset.context_window_tokens,
|
||||
signature=("deep", _preset.model),
|
||||
),
|
||||
)
|
||||
|
||||
loop.set_model_preset("deep")
|
||||
|
||||
assert loop.provider is new_provider
|
||||
assert loop.runner.provider is new_provider
|
||||
assert loop.subagents.provider is new_provider
|
||||
assert loop.subagents.runner.provider is new_provider
|
||||
assert loop.consolidator.provider is new_provider
|
||||
assert loop.dream.provider is new_provider
|
||||
assert loop.dream._runner.provider is new_provider
|
||||
assert loop.model == "anthropic/claude-opus-4-5"
|
||||
assert loop.context_window_tokens == 200_000
|
||||
assert loop.consolidator.max_completion_tokens == 2048
|
||||
|
||||
|
||||
def test_model_preset_setter_failure_leaves_old_state(tmp_path) -> None:
|
||||
preset = ModelPresetConfig(model="openai/gpt-4.1", max_tokens=4096)
|
||||
loop = AgentLoop(
|
||||
bus=MessageBus(),
|
||||
provider=_provider("base-model", max_tokens=123),
|
||||
workspace=tmp_path,
|
||||
model="base-model",
|
||||
context_window_tokens=1000,
|
||||
model_presets={"fast": preset},
|
||||
model_preset_snapshot_builder=lambda _preset: (_ for _ in ()).throw(
|
||||
RuntimeError("provider unavailable")
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="provider unavailable"):
|
||||
loop.set_model_preset("fast")
|
||||
|
||||
assert loop.model_preset is None
|
||||
assert loop.model == "base-model"
|
||||
assert loop.subagents.model == "base-model"
|
||||
assert loop.consolidator.model == "base-model"
|
||||
assert loop.dream.model == "base-model"
|
||||
assert loop.context_window_tokens == 1000
|
||||
assert loop.consolidator.max_completion_tokens == 123
|
||||
|
||||
|
||||
def test_model_preset_setter_raises_on_unknown(tmp_path) -> None:
|
||||
|
||||
137
tests/command/test_model_command.py
Normal file
137
tests/command/test_model_command.py
Normal file
@ -0,0 +1,137 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.command.builtin import (
|
||||
build_help_text,
|
||||
builtin_command_palette,
|
||||
cmd_model,
|
||||
register_builtin_commands,
|
||||
)
|
||||
from nanobot.command.router import CommandContext, CommandRouter
|
||||
from nanobot.config.schema import ModelPresetConfig
|
||||
|
||||
|
||||
def _provider(default_model: str, max_tokens: int = 123) -> MagicMock:
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = default_model
|
||||
provider.generation = SimpleNamespace(
|
||||
max_tokens=max_tokens,
|
||||
temperature=0.1,
|
||||
reasoning_effort=None,
|
||||
)
|
||||
return provider
|
||||
|
||||
|
||||
def _make_loop(tmp_path) -> AgentLoop:
|
||||
return AgentLoop(
|
||||
bus=MessageBus(),
|
||||
provider=_provider("base-model", max_tokens=123),
|
||||
workspace=tmp_path,
|
||||
model="base-model",
|
||||
context_window_tokens=1000,
|
||||
model_presets={
|
||||
"default": ModelPresetConfig(
|
||||
model="base-model",
|
||||
max_tokens=123,
|
||||
context_window_tokens=1000,
|
||||
),
|
||||
"fast": ModelPresetConfig(
|
||||
model="openai/gpt-4.1",
|
||||
max_tokens=4096,
|
||||
context_window_tokens=32_768,
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _ctx(loop: AgentLoop, raw: str, args: str = "") -> CommandContext:
|
||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="direct", content=raw)
|
||||
return CommandContext(msg=msg, session=None, key=msg.session_key, raw=raw, args=args, loop=loop)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_command_lists_current_and_available_presets(tmp_path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
|
||||
out = await cmd_model(_ctx(loop, "/model"))
|
||||
|
||||
assert "Current model: `base-model`" in out.content
|
||||
assert "Active preset: `(none)`" in out.content
|
||||
assert "`default`" in out.content
|
||||
assert "`fast`" in out.content
|
||||
assert out.metadata == {"render_as": "text"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_command_switches_preset(tmp_path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
|
||||
out = await cmd_model(_ctx(loop, "/model fast", args="fast"))
|
||||
|
||||
assert "Switched model preset to `fast`." in out.content
|
||||
assert "Model: `openai/gpt-4.1`" in out.content
|
||||
assert loop.model_preset == "fast"
|
||||
assert loop.model == "openai/gpt-4.1"
|
||||
assert loop.subagents.model == "openai/gpt-4.1"
|
||||
assert loop.consolidator.model == "openai/gpt-4.1"
|
||||
assert loop.dream.model == "openai/gpt-4.1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_command_switches_back_to_default(tmp_path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
loop.set_model_preset("fast")
|
||||
|
||||
out = await cmd_model(_ctx(loop, "/model default", args="default"))
|
||||
|
||||
assert "Switched model preset to `default`." in out.content
|
||||
assert loop.model_preset == "default"
|
||||
assert loop.model == "base-model"
|
||||
assert loop.context_window_tokens == 1000
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_command_unknown_preset_keeps_old_state(tmp_path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
|
||||
out = await cmd_model(_ctx(loop, "/model missing", args="missing"))
|
||||
|
||||
assert "Could not switch model preset" in out.content
|
||||
assert "Available presets: `default`, `fast`" in out.content
|
||||
assert loop.model_preset is None
|
||||
assert loop.model == "base-model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_command_does_not_depend_on_my_allow_set(tmp_path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
assert loop.tools_config.my.allow_set is False
|
||||
|
||||
await cmd_model(_ctx(loop, "/model fast", args="fast"))
|
||||
|
||||
assert loop.model_preset == "fast"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_command_registered_as_exact_and_prefix(tmp_path) -> None:
|
||||
router = CommandRouter()
|
||||
register_builtin_commands(router)
|
||||
loop = _make_loop(tmp_path)
|
||||
|
||||
out = await router.dispatch(_ctx(loop, "/model fast"))
|
||||
|
||||
assert out is not None
|
||||
assert "Switched model preset" in out.content
|
||||
assert loop.model_preset == "fast"
|
||||
|
||||
|
||||
def test_model_command_in_help_and_palette() -> None:
|
||||
palette = builtin_command_palette()
|
||||
|
||||
assert any(item["command"] == "/model" and item["arg_hint"] == "[preset]" for item in palette)
|
||||
assert "/model [preset]" in build_help_text()
|
||||
@ -22,6 +22,7 @@ class TestIsDispatchableCommand:
|
||||
def test_exact_commands_match(self, router: CommandRouter) -> None:
|
||||
assert router.is_dispatchable_command("/new")
|
||||
assert router.is_dispatchable_command("/help")
|
||||
assert router.is_dispatchable_command("/model")
|
||||
assert router.is_dispatchable_command("/dream")
|
||||
assert router.is_dispatchable_command("/dream-log")
|
||||
assert router.is_dispatchable_command("/dream-restore")
|
||||
@ -29,6 +30,7 @@ class TestIsDispatchableCommand:
|
||||
def test_prefix_commands_match(self, router: CommandRouter) -> None:
|
||||
assert router.is_dispatchable_command("/dream-log abc123")
|
||||
assert router.is_dispatchable_command("/dream-restore def456")
|
||||
assert router.is_dispatchable_command("/model fast")
|
||||
|
||||
def test_priority_commands_not_matched(self, router: CommandRouter) -> None:
|
||||
# Priority commands are NOT in the dispatchable tiers — they are
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from nanobot.config.schema import Config, ModelPresetConfig
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
|
||||
def test_resolve_preset_returns_defaults_when_no_preset() -> None:
|
||||
@ -39,6 +39,20 @@ def test_resolve_preset_returns_active_preset() -> None:
|
||||
assert resolved.reasoning_effort == "low"
|
||||
|
||||
|
||||
def test_resolve_preset_can_target_named_preset_without_activating() -> None:
|
||||
config = Config.model_validate({
|
||||
"model_presets": {
|
||||
"fast": {"model": "openai/gpt-4.1", "provider": "openai"},
|
||||
"deep": {"model": "anthropic/claude-opus-4-5", "provider": "anthropic"},
|
||||
},
|
||||
"agents": {"defaults": {"modelPreset": "fast"}},
|
||||
})
|
||||
|
||||
resolved = config.resolve_preset("deep")
|
||||
assert resolved.model == "anthropic/claude-opus-4-5"
|
||||
assert resolved.provider == "anthropic"
|
||||
|
||||
|
||||
def test_validator_rejects_unknown_preset() -> None:
|
||||
import pytest
|
||||
with pytest.raises(ValueError, match="model_preset 'unknown' not found in model_presets"):
|
||||
@ -51,6 +65,12 @@ def test_validator_rejects_unknown_preset() -> None:
|
||||
})
|
||||
|
||||
|
||||
def test_resolve_preset_rejects_unknown_named_preset() -> None:
|
||||
import pytest
|
||||
with pytest.raises(KeyError, match="model_preset 'missing' not found"):
|
||||
Config().resolve_preset("missing")
|
||||
|
||||
|
||||
def test_match_provider_uses_preset_model() -> None:
|
||||
config = Config.model_validate({
|
||||
"providers": {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user