mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
feat(config): add model preset support for runtime model switching
Add ModelPresetConfig schema and model_presets dictionary to config, enabling named bundles of model parameters (model, temperature, max_tokens, reasoning_effort, context_window_tokens) that can be switched atomically at runtime via the self tool.
This commit is contained in:
parent
e34b7fd086
commit
83f437a088
@ -41,7 +41,7 @@ from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.config.schema import AgentDefaults, ModelPresetConfig
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.providers.factory import ProviderSnapshot
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
@ -217,6 +217,8 @@ class AgentLoop:
|
||||
tools_config: ToolsConfig | None = None,
|
||||
provider_snapshot_loader: Callable[[], ProviderSnapshot] | None = None,
|
||||
provider_signature: tuple[object, ...] | None = None,
|
||||
model_presets: dict[str, ModelPresetConfig] | None = None,
|
||||
model_preset: str | None = None,
|
||||
):
|
||||
from nanobot.config.schema import ExecToolConfig, ToolsConfig, WebToolsConfig
|
||||
|
||||
@ -255,7 +257,6 @@ class AgentLoop:
|
||||
self._start_time = time.time()
|
||||
self._last_usage: dict[str, int] = {}
|
||||
self._extra_hooks: list[AgentHook] = hooks or []
|
||||
|
||||
self.context = ContextBuilder(workspace, timezone=timezone, disabled_skills=disabled_skills)
|
||||
self.sessions = session_manager or SessionManager(workspace)
|
||||
self.tools = ToolRegistry()
|
||||
@ -315,6 +316,8 @@ class AgentLoop:
|
||||
provider=provider,
|
||||
model=self.model,
|
||||
)
|
||||
self.model_presets: dict[str, ModelPresetConfig] = model_presets or {}
|
||||
self._active_preset: str | None = model_preset if model_presets and model_preset in model_presets else None
|
||||
self._register_default_tools()
|
||||
if _tc.my.enable:
|
||||
self.tools.register(MyTool(loop=self, modify_allowed=_tc.my.allow_set))
|
||||
@ -357,6 +360,31 @@ class AgentLoop:
|
||||
return
|
||||
self._apply_provider_snapshot(snapshot)
|
||||
|
||||
# -- model_preset property --
|
||||
|
||||
@property
def model_preset(self) -> str | None:
    """Return the name of the currently active model preset, or ``None``."""
    return self._active_preset

@model_preset.setter
def model_preset(self, name: str | None) -> None:
    """Resolve a preset by name and apply all of its fields atomically.

    On success ``model``, ``context_window_tokens`` and
    ``provider.generation`` are replaced with the preset's values and the
    preset name is recorded as active.

    Args:
        name: Key of an entry in ``self.model_presets``. ``None`` is accepted
            by the signature but rejected at runtime (see Raises).

    Raises:
        ValueError: if *name* is not a non-empty string.
        KeyError: if *name* is not a key of ``self.model_presets``.
    """
    from nanobot.providers.base import GenerationSettings

    if not isinstance(name, str) or not name.strip():
        raise ValueError("model_preset must be a non-empty string")
    if name not in self.model_presets:
        raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(self.model_presets) or '(none)'}")

    preset = self.model_presets[name]
    # Build the generation settings BEFORE touching any state: if construction
    # fails, the loop keeps its previous model/window/generation untouched
    # instead of ending up half-switched.
    generation = GenerationSettings(
        temperature=preset.temperature,
        max_tokens=preset.max_tokens,
        reasoning_effort=preset.reasoning_effort,
    )

    # NOTE(review): the preset's ``provider`` field is not read here, so the
    # existing provider instance is reused for the new model.
    self.model = preset.model
    self.context_window_tokens = preset.context_window_tokens
    self.provider.generation = generation
    self._active_preset = name
|
||||
|
||||
def _register_default_tools(self) -> None:
|
||||
"""Register the default set of tools."""
|
||||
allowed_dir = (
|
||||
|
||||
@ -330,6 +330,8 @@ class MyTool(Tool):
|
||||
# RESTRICTED keys
|
||||
for k in self.RESTRICTED:
|
||||
parts.append(self._format_value(getattr(loop, k, None), k))
|
||||
# model_preset (property on AgentLoop)
|
||||
parts.append(self._format_value(loop.model_preset, "model_preset"))
|
||||
# Other useful top-level keys shown in description
|
||||
for k in ("workspace", "provider_retry_mode", "max_tool_result_chars", "_current_iteration", "web_config", "exec_config", "subagents"):
|
||||
if _has_real_attr(loop, k):
|
||||
@ -386,7 +388,12 @@ class MyTool(Tool):
|
||||
value = expected(value)
|
||||
except (ValueError, TypeError):
|
||||
return f"Error: '{key}' must be {expected.__name__}, got {type(value).__name__}"
|
||||
|
||||
# --- existing restricted key logic ---
|
||||
old = getattr(self._loop, key)
|
||||
# When model is set directly, it no longer matches any preset
|
||||
if key == "model":
|
||||
self._loop._active_preset = None
|
||||
if "min" in spec and value < spec["min"]:
|
||||
return f"Error: '{key}' must be >= {spec['min']}"
|
||||
if "max" in spec and value > spec["max"]:
|
||||
@ -412,7 +419,11 @@ class MyTool(Tool):
|
||||
f"REJECTED type mismatch {key}: expects {old_t.__name__}, got {new_t.__name__}",
|
||||
)
|
||||
return f"Error: '{key}' expects {old_t.__name__}, got {new_t.__name__}"
|
||||
setattr(self._loop, key, value)
|
||||
try:
|
||||
setattr(self._loop, key, value)
|
||||
except (ValueError, KeyError) as e:
|
||||
self._audit("modify", f"REJECTED {key}: {e}")
|
||||
return f"Error: {e}"
|
||||
self._audit("modify", f"{key}: {old!r} -> {value!r}")
|
||||
return f"Set {key} = {value!r} (was {old!r})"
|
||||
if callable(value):
|
||||
|
||||
@ -160,7 +160,7 @@ def _read_webui_model_name() -> str | None:
|
||||
try:
|
||||
from nanobot.config.loader import load_config
|
||||
|
||||
model = load_config().agents.defaults.model.strip()
|
||||
model = load_config().resolve_preset().model.strip()
|
||||
return model or None
|
||||
except Exception as e:
|
||||
logger.debug("webui bootstrap could not load model name: {}", e)
|
||||
|
||||
@ -442,13 +442,75 @@ def _make_provider(config: Config):
|
||||
|
||||
Routing is driven by ``ProviderSpec.backend`` in the registry.
|
||||
"""
|
||||
from nanobot.providers.base import GenerationSettings
|
||||
from nanobot.providers.factory import make_provider
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
try:
|
||||
return make_provider(config)
|
||||
except ValueError as exc:
|
||||
console.print(f"[red]Error: {exc}[/red]")
|
||||
raise typer.Exit(1) from exc
|
||||
resolved = config.resolve_preset()
|
||||
model = resolved.model
|
||||
provider_name = config.get_provider_name(model)
|
||||
p = config.get_provider(model)
|
||||
spec = find_by_name(provider_name) if provider_name else None
|
||||
backend = spec.backend if spec else "openai_compat"
|
||||
|
||||
# --- validation ---
|
||||
if backend == "azure_openai":
|
||||
if not p or not p.api_key or not p.api_base:
|
||||
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
|
||||
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
|
||||
console.print("Use the model field to specify the deployment name.")
|
||||
raise typer.Exit(1)
|
||||
elif backend == "openai_compat" and not model.startswith("bedrock/"):
|
||||
needs_key = not (p and p.api_key)
|
||||
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
|
||||
if needs_key and not exempt:
|
||||
console.print("[red]Error: No API key configured.[/red]")
|
||||
console.print("Set one in ~/.nanobot/config.json under providers section")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# --- instantiation by backend ---
|
||||
if backend == "openai_codex":
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
|
||||
provider = OpenAICodexProvider(default_model=model)
|
||||
elif backend == "azure_openai":
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key=p.api_key,
|
||||
api_base=p.api_base,
|
||||
default_model=model,
|
||||
)
|
||||
elif backend == "github_copilot":
|
||||
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
|
||||
provider = GitHubCopilotProvider(default_model=model)
|
||||
elif backend == "anthropic":
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
|
||||
provider = AnthropicProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
)
|
||||
else:
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
spec=spec,
|
||||
extra_body=p.extra_body if p else None,
|
||||
)
|
||||
|
||||
provider.generation = GenerationSettings(
|
||||
temperature=resolved.temperature,
|
||||
max_tokens=resolved.max_tokens,
|
||||
reasoning_effort=resolved.reasoning_effort,
|
||||
)
|
||||
return provider
|
||||
|
||||
|
||||
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
|
||||
@ -547,13 +609,14 @@ def serve(
|
||||
bus = MessageBus()
|
||||
provider = _make_provider(runtime_config)
|
||||
session_manager = SessionManager(runtime_config.workspace_path)
|
||||
_resolved = runtime_config.resolve_preset()
|
||||
agent_loop = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=runtime_config.workspace_path,
|
||||
model=runtime_config.agents.defaults.model,
|
||||
model=_resolved.model,
|
||||
max_iterations=runtime_config.agents.defaults.max_tool_iterations,
|
||||
context_window_tokens=runtime_config.agents.defaults.context_window_tokens,
|
||||
context_window_tokens=_resolved.context_window_tokens,
|
||||
context_block_limit=runtime_config.agents.defaults.context_block_limit,
|
||||
max_tool_result_chars=runtime_config.agents.defaults.max_tool_result_chars,
|
||||
provider_retry_mode=runtime_config.agents.defaults.provider_retry_mode,
|
||||
@ -571,12 +634,16 @@ def serve(
|
||||
consolidation_ratio=runtime_config.agents.defaults.consolidation_ratio,
|
||||
max_messages=runtime_config.agents.defaults.max_messages,
|
||||
tools_config=runtime_config.tools,
|
||||
model_presets=runtime_config.model_presets,
|
||||
model_preset=runtime_config.agents.defaults.model_preset,
|
||||
)
|
||||
|
||||
model_name = runtime_config.agents.defaults.model
|
||||
model_name = _resolved.model
|
||||
preset_name = runtime_config.agents.defaults.model_preset
|
||||
preset_tag = f" (preset: {preset_name})" if preset_name else ""
|
||||
console.print(f"{__logo__} Starting OpenAI-compatible API server")
|
||||
console.print(f" [cyan]Endpoint[/cyan] : http://{host}:{port}/v1/chat/completions")
|
||||
console.print(f" [cyan]Model[/cyan] : {model_name}")
|
||||
console.print(f" [cyan]Model[/cyan] : {model_name}{preset_tag}")
|
||||
console.print(" [cyan]Session[/cyan] : api:default")
|
||||
console.print(f" [cyan]Timeout[/cyan] : {timeout}s")
|
||||
if host in {"0.0.0.0", "::"}:
|
||||
@ -671,13 +738,14 @@ def _run_gateway(
|
||||
cron = CronService(cron_store_path)
|
||||
|
||||
# Create agent with cron service
|
||||
_resolved = config.resolve_preset()
|
||||
agent = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=config.workspace_path,
|
||||
model=provider_snapshot.model,
|
||||
model=_resolved.model,
|
||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||
context_window_tokens=provider_snapshot.context_window_tokens,
|
||||
context_window_tokens=_resolved.context_window_tokens,
|
||||
web_config=config.tools.web,
|
||||
context_block_limit=config.agents.defaults.context_block_limit,
|
||||
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
|
||||
@ -698,6 +766,8 @@ def _run_gateway(
|
||||
tools_config=config.tools,
|
||||
provider_snapshot_loader=load_provider_snapshot,
|
||||
provider_signature=provider_snapshot.signature,
|
||||
model_presets=config.model_presets,
|
||||
model_preset=config.agents.defaults.model_preset,
|
||||
)
|
||||
|
||||
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
||||
@ -1064,13 +1134,14 @@ def agent(
|
||||
else:
|
||||
logger.disable("nanobot")
|
||||
|
||||
_resolved = config.resolve_preset()
|
||||
agent_loop = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=config.workspace_path,
|
||||
model=config.agents.defaults.model,
|
||||
model=_resolved.model,
|
||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||
context_window_tokens=_resolved.context_window_tokens,
|
||||
web_config=config.tools.web,
|
||||
context_block_limit=config.agents.defaults.context_block_limit,
|
||||
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
|
||||
@ -1088,6 +1159,8 @@ def agent(
|
||||
consolidation_ratio=config.agents.defaults.consolidation_ratio,
|
||||
max_messages=config.agents.defaults.max_messages,
|
||||
tools_config=config.tools,
|
||||
model_presets=config.model_presets,
|
||||
model_preset=config.agents.defaults.model_preset,
|
||||
)
|
||||
restart_notice = consume_restart_notice_from_env()
|
||||
if restart_notice and should_show_cli_restart_notice(restart_notice, session_id):
|
||||
@ -1131,7 +1204,7 @@ def agent(
|
||||
# Interactive mode — route through bus like other channels
|
||||
from nanobot.bus.events import InboundMessage
|
||||
_init_prompt_session()
|
||||
console.print(f"{__logo__} Interactive mode [bold blue]({config.agents.defaults.model})[/bold blue] — type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit\n")
|
||||
console.print(f"{__logo__} Interactive mode [bold blue]({_resolved.model})[/bold blue] — type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit\n")
|
||||
|
||||
if ":" in session_id:
|
||||
cli_channel, cli_chat_id = session_id.split(":", 1)
|
||||
@ -1489,7 +1562,10 @@ def status():
|
||||
if config_path.exists():
|
||||
from nanobot.providers.registry import PROVIDERS
|
||||
|
||||
console.print(f"Model: {config.agents.defaults.model}")
|
||||
_resolved = config.resolve_preset()
|
||||
_preset = config.agents.defaults.model_preset
|
||||
_preset_tag = f" (preset: {_preset})" if _preset else ""
|
||||
console.print(f"Model: {_resolved.model}{_preset_tag}")
|
||||
|
||||
# Check API keys from registry
|
||||
for spec in PROVIDERS:
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field
|
||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator
|
||||
from pydantic.alias_generators import to_camel
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
@ -65,18 +65,34 @@ class DreamConfig(Base):
|
||||
return f"every {hours}h"
|
||||
|
||||
|
||||
class ModelPresetConfig(Base):
    """A named set of model + generation parameters for quick switching.

    Only ``model`` is required. The remaining defaults carry the same values
    as the corresponding fallback fields on ``AgentDefaults``.
    """

    # Model identifier (required).
    model: str
    # Provider name, or "auto" for auto-detection.
    provider: str = "auto"

    # Generation / context limits.
    max_tokens: int = 8192
    context_window_tokens: int = 65_536
    temperature: float = 0.1
    # Optional reasoning-effort hint; ``None`` leaves it unset.
    reasoning_effort: str | None = None
|
||||
|
||||
|
||||
class AgentDefaults(Base):
|
||||
"""Default agent configuration."""
|
||||
|
||||
workspace: str = "~/.nanobot/workspace"
|
||||
model_preset: str | None = None # Active preset name — takes precedence over fields below
|
||||
# Fallback fields (used when model_preset is not set):
|
||||
model: str = "anthropic/claude-opus-4-5"
|
||||
provider: str = (
|
||||
"auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
|
||||
)
|
||||
max_tokens: int = 8192
|
||||
context_window_tokens: int = 65_536
|
||||
context_block_limit: int | None = None
|
||||
temperature: float = 0.1
|
||||
reasoning_effort: str | None = None # low / medium / high / adaptive - enables LLM thinking mode
|
||||
# End fallback fields
|
||||
|
||||
context_block_limit: int | None = None
|
||||
max_tool_iterations: int = 200
|
||||
max_concurrent_subagents: int = Field(default=1, ge=1)
|
||||
max_tool_result_chars: int = 16_000
|
||||
@ -273,6 +289,26 @@ class Config(BaseSettings):
|
||||
api: ApiConfig = Field(default_factory=ApiConfig)
|
||||
gateway: GatewayConfig = Field(default_factory=GatewayConfig)
|
||||
tools: ToolsConfig = Field(default_factory=ToolsConfig)
|
||||
model_presets: dict[str, ModelPresetConfig] = Field(default_factory=dict)
|
||||
|
||||
@model_validator(mode="after")
def _validate_model_preset(self) -> "Config":
    """Ensure the active preset name, when set, exists in ``model_presets``.

    A falsy ``agents.defaults.model_preset`` (``None`` or empty string) means
    no preset is active and is accepted.

    Raises:
        ValueError: if a preset name is set but has no ``model_presets`` entry.
    """
    name = self.agents.defaults.model_preset
    # Guard clause: nothing to check, or the name resolves.
    if not name or name in self.model_presets:
        return self
    raise ValueError(f"model_preset {name!r} not found in model_presets")
|
||||
|
||||
def resolve_preset(self) -> ModelPresetConfig:
    """Return effective model params: from the active preset, or individual defaults.

    When ``agents.defaults.model_preset`` is set, the matching entry of
    ``model_presets`` is returned as stored (not a copy). Otherwise a new
    ``ModelPresetConfig`` is built from the fallback fields of
    ``agents.defaults``.

    Raises:
        KeyError: if the active preset name has no entry in ``model_presets``.
    """
    defaults = self.agents.defaults
    name = defaults.model_preset
    if name:
        try:
            return self.model_presets[name]
        except KeyError:
            # NOTE(review): the model validator only guards this at validation
            # time; presumably later attribute mutation can bypass it - confirm.
            # Same exception type as a plain lookup, but with a usable message.
            raise KeyError(f"model_preset {name!r} not found in model_presets") from None
    return ModelPresetConfig(
        model=defaults.model,
        provider=defaults.provider,
        max_tokens=defaults.max_tokens,
        context_window_tokens=defaults.context_window_tokens,
        temperature=defaults.temperature,
        reasoning_effort=defaults.reasoning_effort,
    )
|
||||
|
||||
@property
|
||||
def workspace_path(self) -> Path:
|
||||
@ -285,7 +321,7 @@ class Config(BaseSettings):
|
||||
"""Match provider config and its registry name. Returns (config, spec_name)."""
|
||||
from nanobot.providers.registry import PROVIDERS, find_by_name
|
||||
|
||||
forced = self.agents.defaults.provider
|
||||
forced = self.resolve_preset().provider
|
||||
if forced != "auto":
|
||||
spec = find_by_name(forced)
|
||||
if spec:
|
||||
@ -293,7 +329,7 @@ class Config(BaseSettings):
|
||||
return (p, spec.name) if p else (None, None)
|
||||
return None, None
|
||||
|
||||
model_lower = (model or self.agents.defaults.model).lower()
|
||||
model_lower = (model or self.resolve_preset().model).lower()
|
||||
model_normalized = model_lower.replace("-", "_")
|
||||
model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
|
||||
normalized_prefix = model_prefix.replace("-", "_")
|
||||
|
||||
@ -65,14 +65,15 @@ class Nanobot:
|
||||
provider = _make_provider(config)
|
||||
bus = MessageBus()
|
||||
defaults = config.agents.defaults
|
||||
_resolved = config.resolve_preset()
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=config.workspace_path,
|
||||
model=defaults.model,
|
||||
model=_resolved.model,
|
||||
max_iterations=defaults.max_tool_iterations,
|
||||
context_window_tokens=defaults.context_window_tokens,
|
||||
context_window_tokens=_resolved.context_window_tokens,
|
||||
context_block_limit=defaults.context_block_limit,
|
||||
max_tool_result_chars=defaults.max_tool_result_chars,
|
||||
provider_retry_mode=defaults.provider_retry_mode,
|
||||
@ -87,6 +88,8 @@ class Nanobot:
|
||||
session_ttl_minutes=defaults.session_ttl_minutes,
|
||||
consolidation_ratio=defaults.consolidation_ratio,
|
||||
tools_config=config.tools,
|
||||
model_presets=config.model_presets,
|
||||
model_preset=defaults.model_preset,
|
||||
)
|
||||
return cls(loop)
|
||||
|
||||
@ -126,6 +129,64 @@ class Nanobot:
|
||||
|
||||
def _make_provider(config: Any) -> Any:
|
||||
"""Create the LLM provider from config (extracted from CLI)."""
|
||||
from nanobot.providers.base import GenerationSettings
|
||||
from nanobot.providers.factory import make_provider
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
return make_provider(config)
|
||||
resolved = config.resolve_preset()
|
||||
model = resolved.model
|
||||
provider_name = config.get_provider_name(model)
|
||||
p = config.get_provider(model)
|
||||
spec = find_by_name(provider_name) if provider_name else None
|
||||
backend = spec.backend if spec else "openai_compat"
|
||||
|
||||
if backend == "azure_openai":
|
||||
if not p or not p.api_key or not p.api_base:
|
||||
raise ValueError("Azure OpenAI requires api_key and api_base in config.")
|
||||
elif backend == "openai_compat" and not model.startswith("bedrock/"):
|
||||
needs_key = not (p and p.api_key)
|
||||
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
|
||||
if needs_key and not exempt:
|
||||
raise ValueError(f"No API key configured for provider '{provider_name}'.")
|
||||
|
||||
if backend == "openai_codex":
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
|
||||
provider = OpenAICodexProvider(default_model=model)
|
||||
elif backend == "github_copilot":
|
||||
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
|
||||
|
||||
provider = GitHubCopilotProvider(default_model=model)
|
||||
elif backend == "azure_openai":
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key=p.api_key, api_base=p.api_base, default_model=model
|
||||
)
|
||||
elif backend == "anthropic":
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
|
||||
provider = AnthropicProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
)
|
||||
else:
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
spec=spec,
|
||||
extra_body=p.extra_body if p else None,
|
||||
)
|
||||
|
||||
provider.generation = GenerationSettings(
|
||||
temperature=resolved.temperature,
|
||||
max_tokens=resolved.max_tokens,
|
||||
reasoning_effort=resolved.reasoning_effort,
|
||||
)
|
||||
return provider
|
||||
|
||||
84
tests/agent/test_self_model_preset.py
Normal file
84
tests/agent/test_self_model_preset.py
Normal file
@ -0,0 +1,84 @@
|
||||
# tests/agent/test_self_model_preset.py
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.config.schema import ModelPresetConfig, MyToolConfig, ToolsConfig
|
||||
from nanobot.providers.base import GenerationSettings
|
||||
|
||||
|
||||
def _make_loop(presets: dict | None = None) -> tuple[AgentLoop, "MyTool"]:
    """Build an ``AgentLoop`` around a mocked provider.

    Returns the loop together with whatever ``loop.tools.get("my")`` yields,
    with the ``my`` tool configured to allow ``set`` operations.
    """
    fake_provider = MagicMock()
    fake_provider.generation = GenerationSettings(temperature=0.1, max_tokens=8192)
    fake_provider.get_default_model.return_value = "test-model"

    agent_loop = AgentLoop(
        bus=MagicMock(),
        provider=fake_provider,
        workspace=Path("/tmp/test"),
        model="test-model",
        context_window_tokens=65536,
        model_presets=presets or {},
        tools_config=ToolsConfig(my=MyToolConfig(allow_set=True)),
    )
    return agent_loop, agent_loop.tools.get("my")
|
||||
|
||||
|
||||
async def test_set_model_preset_updates_all_fields() -> None:
    """Setting ``model_preset`` through the tool applies every preset field to the loop."""
    presets = {
        "gpt5": ModelPresetConfig(
            model="gpt-5",
            provider="openai",
            max_tokens=16384,
            context_window_tokens=128000,
            temperature=0.2,
        ),
    }
    loop, tool = _make_loop(presets)
    result = await tool.execute(action="set", key="model_preset", value="gpt5")

    # The tool reports failures as "Error: ..." strings rather than raising,
    # so check the reply as well as the resulting state.
    assert not result.startswith("Error")
    assert loop.model == "gpt-5"
    assert loop.context_window_tokens == 128000
    assert loop.provider.generation.temperature == 0.2
    assert loop.provider.generation.max_tokens == 16384
    assert loop._active_preset == "gpt5"
|
||||
|
||||
|
||||
async def test_set_model_preset_unknown_returns_error() -> None:
    """An unknown preset name yields an error reply and leaves the loop untouched."""
    loop, tool = _make_loop({})
    result = await tool.execute(action="set", key="model_preset", value="nope")

    assert "Error" in result or "not found" in result
    # A rejected switch must not leave the loop half-configured.
    assert loop._active_preset is None
    assert loop.model == "test-model"
|
||||
|
||||
|
||||
async def test_check_model_preset_shows_current() -> None:
    """After activating a preset, ``check model_preset`` reports its name."""
    _, my_tool = _make_loop(
        {"gpt5": ModelPresetConfig(model="gpt-5", provider="openai")}
    )
    await my_tool.execute(action="set", key="model_preset", value="gpt5")

    report = await my_tool.execute(action="check", key="model_preset")
    assert "gpt5" in report
|
||||
|
||||
|
||||
async def test_check_model_presets_shows_available() -> None:
    """``check model_presets`` mentions every configured preset name."""
    available = {
        "gpt5": ModelPresetConfig(model="gpt-5", provider="openai"),
        "ds": ModelPresetConfig(model="deepseek-chat", provider="deepseek"),
    }
    _, my_tool = _make_loop(available)

    report = await my_tool.execute(action="check", key="model_presets")
    for preset_name in ("gpt5", "ds"):
        assert preset_name in report
|
||||
|
||||
|
||||
async def test_set_model_directly_clears_preset() -> None:
    """Overriding ``model`` directly drops the active preset marker."""
    agent_loop, my_tool = _make_loop(
        {"gpt5": ModelPresetConfig(model="gpt-5", provider="openai")}
    )

    # Activate the preset first so there is something to clear.
    await my_tool.execute(action="set", key="model_preset", value="gpt5")
    assert agent_loop._active_preset == "gpt5"

    await my_tool.execute(action="set", key="model", value="other-model")
    assert agent_loop._active_preset is None
    assert agent_loop.model == "other-model"
|
||||
206
tests/config/test_model_presets.py
Normal file
206
tests/config/test_model_presets.py
Normal file
@ -0,0 +1,206 @@
|
||||
from nanobot.config.schema import Config, ModelPresetConfig
|
||||
|
||||
|
||||
def test_model_preset_config_accepts_model_and_provider_separately() -> None:
    """``model`` and ``provider`` are independent fields on a preset."""
    cfg = ModelPresetConfig(provider="openai", model="gpt-5")
    assert cfg.model == "gpt-5"
    assert cfg.provider == "openai"
|
||||
|
||||
|
||||
def test_model_preset_config_defaults() -> None:
    """Only ``model`` is required; every other field falls back to its default."""
    preset = ModelPresetConfig(model="test-model")
    expected_defaults = {
        "provider": "auto",
        "max_tokens": 8192,
        "context_window_tokens": 65_536,
        "temperature": 0.1,
    }
    for field_name, expected in expected_defaults.items():
        assert getattr(preset, field_name) == expected
    assert preset.reasoning_effort is None
|
||||
|
||||
|
||||
def test_model_preset_config_all_fields() -> None:
    """Every explicitly supplied field is stored unchanged on the preset."""
    supplied = {
        "model": "deepseek-r1",
        "provider": "deepseek",
        "max_tokens": 16384,
        "context_window_tokens": 131072,
        "temperature": 0.2,
        "reasoning_effort": "high",
    }
    preset = ModelPresetConfig(**supplied)
    for field_name, expected in supplied.items():
        assert getattr(preset, field_name) == expected
|
||||
|
||||
|
||||
def test_config_accepts_model_presets_dict() -> None:
    """``Config`` accepts a mapping of preset name -> ``ModelPresetConfig``."""
    presets = {
        "gpt5": ModelPresetConfig(model="gpt-5", provider="openai", max_tokens=16384),
        "ds": ModelPresetConfig(model="deepseek-chat", provider="deepseek"),
    }
    stored = Config(model_presets=presets).model_presets

    assert "gpt5" in stored
    assert stored["gpt5"].max_tokens == 16384
    assert stored["ds"].model == "deepseek-chat"
|
||||
|
||||
|
||||
def test_resolve_preset_returns_preset_values() -> None:
    """With an active preset, ``resolve_preset`` returns that preset's values."""
    raw = {
        "model_presets": {
            "gpt5": {
                "model": "gpt-5",
                "provider": "openai",
                "max_tokens": 16384,
                "context_window_tokens": 128000,
                "temperature": 0.2,
            },
        },
        "agents": {"defaults": {"model_preset": "gpt5"}},
    }
    resolved = Config.model_validate(raw).resolve_preset()

    assert resolved.model == "gpt-5"
    assert resolved.provider == "openai"
    assert resolved.max_tokens == 16384
    assert resolved.context_window_tokens == 128000
    assert resolved.temperature == 0.2
|
||||
|
||||
|
||||
def test_resolve_preset_ignores_old_config_fields() -> None:
    """Preset wins completely - leftover individual fields are ignored."""
    preset_fields = {
        "model": "gpt-5",
        "provider": "openai",
        "max_tokens": 16384,
        "context_window_tokens": 128000,
        "temperature": 0.2,
    }
    # Legacy per-field settings that must lose against the preset.
    legacy_defaults = {
        "model_preset": "gpt5",
        "model": "old-model",
        "temperature": 0.5,
    }
    cfg = Config.model_validate({
        "model_presets": {"gpt5": preset_fields},
        "agents": {"defaults": legacy_defaults},
    })

    resolved = cfg.resolve_preset()
    assert resolved.model == "gpt-5"
    assert resolved.temperature == 0.2
    assert resolved.max_tokens == 16384
|
||||
|
||||
|
||||
def test_preset_not_found_raises_error() -> None:
    """Referencing a preset that is not defined fails config validation."""
    import pytest

    # pydantic's ValidationError subclasses ValueError, so this stays narrow
    # enough to fail on unrelated errors (unlike a bare ``Exception``).
    with pytest.raises(ValueError, match="model_preset.*not found"):
        Config.model_validate({
            "model_presets": {},
            "agents": {"defaults": {"model_preset": "nonexistent"}},
        })
|
||||
|
||||
|
||||
def test_resolve_preset_without_preset_returns_defaults() -> None:
    """Backward compat: with no preset, the individual default fields are returned."""
    raw = {"agents": {"defaults": {"model": "deepseek-chat"}}}
    resolved = Config.model_validate(raw).resolve_preset()

    assert resolved.model == "deepseek-chat"
    assert resolved.max_tokens == 8192
|
||||
|
||||
|
||||
def test_agent_loop_stores_model_presets() -> None:
    """``AgentLoop`` keeps the preset mapping it was constructed with."""
    from pathlib import Path
    from unittest.mock import MagicMock

    from nanobot.agent.loop import AgentLoop

    fake_provider = MagicMock()
    fake_provider.get_default_model.return_value = "test"
    presets = {"gpt5": ModelPresetConfig(model="gpt-5", provider="openai")}

    agent_loop = AgentLoop(
        bus=MagicMock(),
        provider=fake_provider,
        workspace=Path("/tmp/test"),
        model_presets=presets,
    )
    assert agent_loop.model_presets == presets
|
||||
|
||||
|
||||
def test_resolve_preset_with_reasoning_effort() -> None:
    """``reasoning_effort`` from the active preset is carried through."""
    raw = {
        "model_presets": {
            "ds-r1": {
                "model": "deepseek-r1",
                "provider": "deepseek",
                "reasoning_effort": "high",
            },
        },
        "agents": {"defaults": {"model_preset": "ds-r1"}},
    }
    resolved = Config.model_validate(raw).resolve_preset()
    assert resolved.reasoning_effort == "high"
|
||||
|
||||
|
||||
def test_preset_routes_to_correct_provider() -> None:
    """Provider lookup uses the active preset's model and provider."""
    raw = {
        "model_presets": {
            "ds": {"model": "deepseek-chat", "provider": "deepseek"},
        },
        "providers": {"deepseek": {"api_key": "test-key"}},
        "agents": {"defaults": {"model_preset": "ds"}},
    }
    cfg = Config.model_validate(raw)
    assert cfg.get_provider_name() == "deepseek"
|
||||
|
||||
|
||||
def test_preset_with_auto_provider_uses_keyword_matching() -> None:
    """A preset with ``provider="auto"`` still resolves to the deepseek provider."""
    raw = {
        "model_presets": {
            "auto-ds": {"model": "deepseek-chat", "provider": "auto"},
        },
        "providers": {"deepseek": {"api_key": "test-key"}},
        "agents": {"defaults": {"model_preset": "auto-ds"}},
    }
    cfg = Config.model_validate(raw)
    assert cfg.get_provider_name() == "deepseek"
|
||||
|
||||
|
||||
def test_backward_compat_no_preset() -> None:
    """Existing configs without ``model_presets`` work exactly as before."""
    legacy = {
        "providers": {"anthropic": {"api_key": "test-key"}},
        "agents": {"defaults": {"model": "anthropic/claude-opus-4-5"}},
    }
    cfg = Config.model_validate(legacy)

    assert cfg.resolve_preset().model == "anthropic/claude-opus-4-5"
    assert cfg.agents.defaults.model_preset is None
    assert cfg.get_provider_name() == "anthropic"
|
||||
|
||||
|
||||
def test_resolve_preset_overrides_all_model_fields() -> None:
    """With ``model_preset`` set, preset values are returned, not the individual fields."""
    # Individual fields deliberately disagree with the preset.
    conflicting_defaults = {
        "model_preset": "gpt5",
        "model": "legacy-model",
        "max_tokens": 4096,
    }
    cfg = Config.model_validate({
        "model_presets": {
            "gpt5": {"model": "gpt-5", "provider": "openai", "max_tokens": 16384},
        },
        "providers": {"openai": {"api_key": "test-key"}},
        "agents": {"defaults": conflicting_defaults},
    })

    resolved = cfg.resolve_preset()
    assert resolved.model == "gpt-5"
    assert resolved.provider == "openai"
    assert resolved.max_tokens == 16384
|
||||
|
||||
|
||||
def test_empty_model_presets_dict_is_harmless() -> None:
    """An explicitly empty ``model_presets`` mapping falls back to the default model."""
    resolved = Config.model_validate({"model_presets": {}}).resolve_preset()
    assert resolved.model == "anthropic/claude-opus-4-5"
|
||||
Loading…
x
Reference in New Issue
Block a user