diff --git a/README.md b/README.md index 99c5ea2c4..cac4d0274 100644 --- a/README.md +++ b/README.md @@ -123,6 +123,7 @@ - **Ultra-lightweight**: stable long-running agent behavior with a small, readable core. - **Research-ready**: the codebase is intentionally simple enough to study, modify, and extend. - **Practical**: chat channels, API, memory, MCP, and deployment paths are already built in. +- **Runtime model switching**: define [model presets](docs/configuration.md#model-presets) and switch between cheap/fast and powerful models mid-conversation — no restart required. - **Hackable**: you can start fast, then go deeper through repo docs instead of a monolithic landing page. ## 📦 Install diff --git a/docs/configuration.md b/docs/configuration.md index 01d55c20b..c957169ff 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -656,6 +656,146 @@ That's it! Environment variables, model routing, config matching, and `nanobot s +## Agent Settings + +### Model Presets + +Model presets let you define **named bundles** of model + generation parameters and switch between them instantly — no restart required. + +> [!NOTE] +> Config fields in `config.json` use **camelCase** (`modelPreset`, `contextWindowTokens`). +> The [`my` tool](./my-tool.md) uses **snake_case** (`model_preset`, `context_window_tokens`). +> Both refer to the same thing — just different naming conventions for config vs. runtime API. + +**Why use presets?** +- Switch between a cheap/fast model and a powerful model mid-conversation. +- Share the same config across different tasks without manually editing `model`, `provider`, `temperature`, etc. +- Runtime switching via the [`my` tool](./my-tool.md). + +> [!TIP] +> The easiest way to set up presets and fallback models is through the interactive wizard: +> ```bash +> nanobot onboard --wizard +> ``` +> Choose **"[M] Model Presets"** to create, edit, or delete presets interactively. 
+ +**Configuration example:** + +```json +{ + "modelPresets": { + "fast": { + "model": "gpt-4.1-mini", + "provider": "openai", + "maxTokens": 4096, + "contextWindowTokens": 128000, + "temperature": 0.3 + }, + "deep": { + "model": "claude-opus-4-7", + "provider": "anthropic", + "maxTokens": 8192, + "contextWindowTokens": 200000, + "temperature": 0.1, + "reasoningEffort": "high" + } + }, + "agents": { + "defaults": { + "modelPreset": "fast" + } + } +} +``` + +**Preset fields:** + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `model` | string | *(required)* | Model identifier, e.g. `anthropic/claude-opus-4-7` or `gpt-4.1` | +| `provider` | string | `"auto"` | Provider name or `"auto"` to infer from the model string | +| `maxTokens` | integer | `8192` | Max completion tokens per turn | +| `contextWindowTokens` | integer | `65536` | Context window size for token budgeting | +| `temperature` | float | `0.1` | Sampling temperature | +| `reasoningEffort` | string or null | `null` | Thinking mode: `low`, `medium`, `high`, `adaptive` | + +**How it works:** +- When `modelPreset` is set, the preset **completely overrides** all model-specific fields in `agents.defaults`. +- When `modelPreset` is omitted, nanobot automatically creates an implicit `"default"` preset from your existing `agents.defaults.model`, `provider`, `temperature`, etc. — **zero migration required** for existing configs. + +**Runtime switching** (requires `tools.my.allowSet: true`): + +```text +my(action="set", key="model_preset", value="deep") +``` + +This atomically swaps the model, provider, generation parameters, and context window for the next turn. + +If the preset name does not exist, the agent receives an error such as `model_preset 'unknown' not found. Available: fast, deep`. 
+ +> [!NOTE] +> Directly modifying `model` or `context_window_tokens` via `my(action="set", key="model", ...)` still works, but it automatically clears the active preset because the live state no longer matches the preset bundle. Use `model_preset` for atomic switches instead. + +See [`my-tool.md`](./my-tool.md) for more runtime examples. + +--- + +### Fallback Models + +When the primary model returns a transient error (rate limit, server overload, quota exhausted), nanobot can automatically fail over to a chain of backup models. + +**Configuration example:** + +```json +{ + "agents": { + "defaults": { + "modelPreset": "fast", + "fallbackModels": ["deep", "backup"] + } + } +} +``` + +**How it works:** +1. nanobot tries the primary model first (the one from the active preset). +2. The provider retries transient errors internally (e.g. 3 attempts with exponential backoff for 503/429). +3. Only after the provider's own retries are exhausted and the final response still has `finish_reason == "error"` with a retryable error kind, nanobot moves to the next candidate in `fallbackModels`. +4. Each candidate must be a preset name defined in `modelPresets`. The preset's full config (model, provider, generation params) is used. +5. If all candidates are exhausted, the final error is returned to the user. + +**Failover triggers on:** +- `server_error` (503, 502, 500) +- `rate_limit` (429) +- `insufficient_quota` / `quota_exhausted` (429) + +**Failover does NOT trigger on:** +- Authentication errors (401) — rotating to another model with the same key won't help +- Invalid request errors (400) — the request itself is malformed + +> [!TIP] +> Fallback models must reference preset names defined in `modelPresets`. Define a preset for each fallback model you want to use: `["cheap-preset", "backup", "emergency"]`. 
+ +--- + +### Other Agent Defaults + +| Option | Type | Default | Description | +|--------|------|---------|-------------| +| `agents.defaults.model` | string | `"anthropic/claude-opus-4-5"` | Default model when no preset is active | +| `agents.defaults.provider` | string | `"auto"` | Default provider when no preset is active | +| `agents.defaults.maxTokens` | integer | `8192` | Max completion tokens when no preset is active | +| `agents.defaults.temperature` | float | `0.1` | Sampling temperature when no preset is active | +| `agents.defaults.reasoningEffort` | string or null | `null` | Thinking mode when no preset is active | +| `agents.defaults.maxToolIterations` | integer | `200` | Max tool calls per conversation turn | +| `agents.defaults.maxToolResultChars` | integer | `16000` | Max characters per tool result | +| `agents.defaults.providerRetryMode` | string | `"standard"` | `"standard"` or `"persistent"` — how aggressively to retry provider-level errors | +| `agents.defaults.timezone` | string | `"UTC"` | IANA timezone for runtime context | +| `agents.defaults.unifiedSession` | boolean | `false` | Share one session across all channels | +| `agents.defaults.sessionTtlMinutes` | integer | `0` | Auto-compact idle threshold (0 = disabled) | +| `agents.defaults.maxMessages` | integer | `120` | Max messages to replay from session history | +| `agents.defaults.consolidationRatio` | float | `0.5` | Target ratio retained after context compression | + ## Channel Settings Global settings that apply to all channels. Configure under the `channels` section in `~/.nanobot/config.json`: diff --git a/docs/my-tool.md b/docs/my-tool.md index bc22ed5a9..00c491dfd 100644 --- a/docs/my-tool.md +++ b/docs/my-tool.md @@ -12,6 +12,11 @@ My tool fills this gap. With it, the agent can: - **Adapt on the fly**: Complex task? Expand the context window. Simple chat? Switch to a faster model. 
- **Remember across turns**: Store notes in your scratchpad that persist into the next conversation turn. +> [!NOTE] +> This tool uses **snake_case** keys (`model_preset`, `context_window_tokens`). +> The matching config fields in `config.json` are **camelCase** (`modelPreset`, `contextWindowTokens`). +> See [`configuration.md`](./configuration.md#model-presets) for how to define presets in your config. + ## Configuration Enabled by default (read-only mode). The agent can check its state but not set it. @@ -39,8 +44,7 @@ Without parameters, returns a key config overview: ```text my(action="check") # → max_iterations: 40 -# context_window_tokens: 65536 -# model: 'anthropic/claude-sonnet-4-20250514' +# model_preset: 'fast' # workspace: PosixPath('/tmp/workspace') # provider_retry_mode: 'standard' # max_tool_result_chars: 16000 @@ -55,8 +59,13 @@ With a key parameter, drill into a specific config: my(action="check", key="_last_usage.prompt_tokens") # → How many prompt tokens I've used so far -my(action="check", key="model") -# → What model I'm currently running on +my(action="check", key="model_preset") +# → Current active preset name (e.g. 'fast') + +my(action="check", key="model_presets") +# → Lists all preset names and their models, e.g.: +# fast → gpt-4.1-mini (openai) +# deep → claude-opus-4-7 (anthropic) my(action="check", key="web_config.enable") # → Whether web search is enabled @@ -66,7 +75,7 @@ my(action="check", key="web_config.enable") | Scenario | How | |----------|-----| -| "What model are you using?" | `check("model")` | +| "What model are you using?" | `check("model_preset")` | | "How many more tool calls can you make?" | `check("max_iterations")` minus `check("_current_iteration")` | | "How many tokens has this conversation used?" | `check("_last_usage")` — cumulative across all turns | | "Where is your working directory?" | `check("workspace")` | @@ -83,8 +92,11 @@ Changes take effect immediately, no restart required. 
my(action="set", key="max_iterations", value=80) # → Bump iteration limit from 40 to 80 -my(action="set", key="model", value="fast-model") -# → Switch to a faster model +my(action="set", key="model_preset", value="fast") +# → Switch to the 'fast' preset (model, provider, temperature, etc. all at once) +# +# If the preset name does not exist: +# → Error: model_preset 'unknown' not found. Available: fast, deep my(action="set", key="context_window_tokens", value=131072) # → Expand context window for long documents @@ -101,15 +113,17 @@ my(action="set", key="task_complexity", value="high") ### Protected parameters -These parameters have type and range validation — invalid values are rejected: +These parameters have validation — invalid values are rejected: -| Parameter | Type | Range | Purpose | -|-----------|------|-------|---------| +| Parameter | Type | Range / Constraint | Purpose | +|-----------|------|-------------------|---------| | `max_iterations` | int | 1–100 | Max tool calls per conversation turn | -| `context_window_tokens` | int | 4,096–1,000,000 | Context window size | -| `model` | str | non-empty | LLM model to use | +| `model_preset` | str | must exist in `model_presets` | Switch to a named preset bundle | -Other parameters (e.g. `workspace`, `provider_retry_mode`, `max_tool_result_chars`) can be set freely, as long as the value is JSON-safe. +Other parameters (e.g. `model`, `context_window_tokens`, `workspace`, `provider_retry_mode`, `max_tool_result_chars`) can be set freely, as long as the value is JSON-safe. + +> [!NOTE] +> Setting `model` or `context_window_tokens` directly automatically clears the active `model_preset`, because the live state no longer matches the preset bundle. Use `model_preset` for atomic switches instead. --- @@ -125,8 +139,8 @@ Agent: This codebase is large, let me expand my context window to handle it. 
### "Simple question, don't waste compute" ```text -Agent: This is a straightforward question, let me switch to a faster model. -→ my(action="set", key="model", value="fast-model") +Agent: This is a straightforward question, let me switch to the fast preset. +→ my(action="set", key="model_preset", value="fast") ``` ### "Remember user preferences across turns" diff --git a/docs/quick-start.md b/docs/quick-start.md index 7112ba8ca..3852b2d85 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -95,6 +95,8 @@ Configure these **two parts** in your config (other options have defaults). } ``` +*Want to switch models mid-conversation?* Define [`modelPresets`](./configuration.md#model-presets) and switch instantly with `my(action="set", key="model_preset", value="fast")`. + **3. Chat** ```bash diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index a001f7037..5c5cacca5 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -197,6 +197,50 @@ class AgentLoop: _RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint" _PENDING_USER_TURN_KEY = "pending_user_turn" + @classmethod + def from_config( + cls, + config: Any, + bus: MessageBus | None = None, + **extra: Any, + ) -> AgentLoop: + """Create an AgentLoop from config with the common parameter set.""" + from nanobot.providers.factory import build_provider_for_preset, make_provider_factory + + if bus is None: + bus = MessageBus() + defaults = config.agents.defaults + resolved_preset = config.resolve_preset() + provider = build_provider_for_preset(config, resolved_preset) + return cls( + bus=bus, + provider=provider, + workspace=config.workspace_path, + model=resolved_preset.model, + max_iterations=defaults.max_tool_iterations, + context_window_tokens=resolved_preset.context_window_tokens, + context_block_limit=defaults.context_block_limit, + max_tool_result_chars=defaults.max_tool_result_chars, + provider_retry_mode=defaults.provider_retry_mode, + fallback_presets=defaults.fallback_presets, + 
provider_factory=make_provider_factory(config), + web_config=config.tools.web, + exec_config=config.tools.exec, + restrict_to_workspace=config.tools.restrict_to_workspace, + mcp_servers=config.tools.mcp_servers, + channels_config=config.channels, + timezone=defaults.timezone, + unified_session=defaults.unified_session, + disabled_skills=defaults.disabled_skills, + session_ttl_minutes=defaults.session_ttl_minutes, + consolidation_ratio=defaults.consolidation_ratio, + max_messages=defaults.max_messages, + tools_config=config.tools, + model_presets=config.model_presets, + model_preset=defaults.model_preset, + **extra, + ) + def __init__( self, bus: MessageBus, @@ -209,8 +253,8 @@ class AgentLoop: max_tool_result_chars: int | None = None, provider_retry_mode: str = "standard", tool_hint_max_length: int | None = None, - fallback_models: list[str] | None = None, - provider_factory: Any | None = None, + fallback_presets: list[str] | None = None, + provider_factory: Callable[[str], LLMProvider] | None = None, web_config: WebToolsConfig | None = None, exec_config: ExecToolConfig | None = None, cron_service: CronService | None = None, @@ -239,7 +283,12 @@ class AgentLoop: defaults = AgentDefaults() self.bus = bus self.channels_config = channels_config - self.provider = provider + self.provider_factory = provider_factory + self.fallback_presets = fallback_presets or [] + wrapped_provider = self._wrap_with_failover( + provider, model or provider.get_default_model() + ) + self.provider = wrapped_provider self._provider_snapshot_loader = provider_snapshot_loader self._provider_signature = provider_signature self.workspace = workspace @@ -263,7 +312,6 @@ class AgentLoop: tool_hint_max_length if tool_hint_max_length is not None else defaults.tool_hint_max_length ) - self.fallback_models = fallback_models or [] self.web_config = web_config or WebToolsConfig() self.exec_config = exec_config or ExecToolConfig() self.tools_config = _tc @@ -278,15 +326,16 @@ class AgentLoop: 
self._start_time = time.time() self._last_usage: dict[str, int] = {} self._extra_hooks: list[AgentHook] = hooks or [] + self.context = ContextBuilder(workspace, timezone=timezone, disabled_skills=disabled_skills) self.sessions = session_manager or SessionManager(workspace) self.tools = ToolRegistry() # One file-read/write tracker per logical session. The tool registry is # shared by this loop, so tools resolve the active state via contextvars. self._file_state_store = FileStateStore() - self.runner = AgentRunner(provider, provider_factory=provider_factory) + self.runner = AgentRunner(wrapped_provider) self.subagents = SubagentManager( - provider=provider, + provider=wrapped_provider, workspace=workspace, bus=bus, model=self.model, @@ -318,13 +367,13 @@ class AgentLoop: ) self.consolidator = Consolidator( store=self.context.memory, - provider=provider, + provider=wrapped_provider, model=self.model, sessions=self.sessions, context_window_tokens=self.context_window_tokens, build_messages=self.context.build_messages, get_tool_definitions=self.tools.get_definitions, - max_completion_tokens=provider.generation.max_tokens, + max_completion_tokens=wrapped_provider.generation.max_tokens, consolidation_ratio=consolidation_ratio, ) self.auto_compact = AutoCompact( @@ -334,11 +383,13 @@ class AgentLoop: ) self.dream = Dream( store=self.context.memory, - provider=provider, + provider=wrapped_provider, model=self.model, ) self.model_presets: dict[str, ModelPresetConfig] = model_presets or {} - self._active_preset: str | None = model_preset if model_presets and model_preset in model_presets else None + self._active_preset: str | None = ( + model_preset if model_preset in self.model_presets else None + ) self._register_default_tools() if _tc.my.enable: self.tools.register(MyTool(loop=self, modify_allowed=_tc.my.allow_set)) @@ -351,6 +402,38 @@ class AgentLoop: """Keep subagent runtime limits aligned with mutable loop settings.""" self.subagents.max_iterations = self.max_iterations 
+ def _wrap_with_failover(self, provider: LLMProvider, model: str) -> LLMProvider: + """Wrap provider with failover router when fallback_presets are configured.""" + if not self.fallback_presets or not self.provider_factory: + return provider + from nanobot.providers.failover import ModelRouter + + if isinstance(provider, ModelRouter): + return provider + + return ModelRouter( + primary_provider=provider, + primary_model=model, + fallback_presets=self.fallback_presets, + provider_factory=self.provider_factory, + ) + + def _apply_provider_state( + self, + provider: LLMProvider, + model: str, + context_window_tokens: int, + ) -> None: + """Push provider/model/context_window to all LLM-consuming subsystems.""" + self.provider = provider + # Bypass property setters so internal updates don't clear _active_preset. + object.__setattr__(self, "_model", model) + object.__setattr__(self, "_context_window_tokens", context_window_tokens) + self.runner.provider = provider + self.subagents.set_provider(provider, model) + self.consolidator.set_provider(provider, model, context_window_tokens) + self.dream.set_provider(provider, model) + def _apply_provider_snapshot(self, snapshot: ProviderSnapshot) -> None: """Swap model/provider for future turns without disturbing an active one.""" provider = snapshot.provider @@ -359,14 +442,13 @@ class AgentLoop: if self.provider is provider and self.model == model: return old_model = self.model - self.provider = provider - self.model = model - self.context_window_tokens = context_window_tokens - self.runner.provider = provider - self.subagents.set_provider(provider, model) - self.consolidator.set_provider(provider, model, context_window_tokens) - self.dream.set_provider(provider, model) + provider = self._wrap_with_failover(provider, model) + self._apply_provider_state(provider, model, context_window_tokens) self._provider_signature = snapshot.signature + if self._active_preset: + preset = self.model_presets.get(self._active_preset) + if 
preset and preset.model != model: + self._active_preset = None logger.info("Runtime model switched for next turn: {} -> {}", old_model, model) def _refresh_provider_snapshot(self) -> None: @@ -381,6 +463,28 @@ class AgentLoop: return self._apply_provider_snapshot(snapshot) + # -- model / context_window_tokens properties with preset invalidation -- + + @property + def model(self) -> str: + return self._model + + @model.setter + def model(self, value: str) -> None: + self._model = value + if hasattr(self, "_active_preset"): + self._active_preset = None + + @property + def context_window_tokens(self) -> int: + return self._context_window_tokens + + @context_window_tokens.setter + def context_window_tokens(self, value: int) -> None: + self._context_window_tokens = value + if hasattr(self, "_active_preset"): + self._active_preset = None + # -- model_preset property -- @property @@ -388,22 +492,27 @@ class AgentLoop: return self._active_preset @model_preset.setter - def model_preset(self, name: str | None) -> None: - """Resolve a preset by name and apply all fields atomically.""" - from nanobot.providers.base import GenerationSettings - + def model_preset(self, name: str) -> None: + """Resolve a preset by name and apply all fields.""" if not isinstance(name, str) or not name.strip(): raise ValueError("model_preset must be a non-empty string") if name not in self.model_presets: - raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(self.model_presets) or '(none)'}") + raise KeyError( + f"model_preset {name!r} not found. 
Available: {', '.join(self.model_presets) or '(none)'}" + ) + if self.provider_factory is None: + raise ValueError("provider_factory is not configured; cannot switch model preset") + p = self.model_presets[name] - self.model = p.model - self.context_window_tokens = p.context_window_tokens - self.provider.generation = GenerationSettings( - temperature=p.temperature, - max_tokens=p.max_tokens, - reasoning_effort=p.reasoning_effort, - ) + new_provider = self._wrap_with_failover(self.provider_factory(name), p.model) + + # Preserve dream model_override if it differs from the current loop model. + old_dream_model = self.dream.model + dream_had_override = old_dream_model != self.model + + self._apply_provider_state(new_provider, p.model, p.context_window_tokens) + if dream_had_override: + self.dream.model = old_dream_model self._active_preset = name def _register_default_tools(self) -> None: @@ -710,7 +819,6 @@ class AgentLoop: context_window_tokens=self.context_window_tokens, context_block_limit=self.context_block_limit, provider_retry_mode=self.provider_retry_mode, - fallback_models=self.fallback_models, progress_callback=on_progress, stream_progress_deltas=on_stream is not None, retry_wait_callback=on_retry_wait, diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 532b4da0d..7fe92ad51 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -16,7 +16,6 @@ from nanobot.agent.hook import AgentHook, AgentHookContext from nanobot.agent.tools.ask import AskUserInterrupt from nanobot.agent.tools.registry import ToolRegistry from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest -from nanobot.providers.failover import ModelCandidate, ModelRouter from nanobot.utils.helpers import ( build_assistant_message, estimate_message_tokens, @@ -76,7 +75,6 @@ class AgentRunSpec: context_window_tokens: int | None = None context_block_limit: int | None = None provider_retry_mode: str = "standard" - fallback_models: list[str] = 
field(default_factory=list) progress_callback: Any | None = None stream_progress_deltas: bool = True retry_wait_callback: Any | None = None @@ -99,21 +97,11 @@ class AgentRunResult: had_injections: bool = False -ProviderFactory = Any # Callable[[str], LLMProvider] — avoids circular import - - class AgentRunner: """Run a tool-capable LLM loop without product-layer concerns.""" - def __init__( - self, - provider: LLMProvider, - *, - provider_factory: ProviderFactory | None = None, - ): + def __init__(self, provider: LLMProvider): self.provider = provider - self._provider_factory = provider_factory - self._fallback_providers: dict[str, LLMProvider] = {} @staticmethod def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]: @@ -606,9 +594,12 @@ class AgentRunner: messages: list[dict[str, Any]], hook: AgentHook, context: AgentHookContext, - ) -> LLMResponse: + ): timeout_s: float | None = spec.llm_timeout_s if timeout_s is None: + # Default to a finite timeout to avoid per-session lock starvation when an LLM + # request hangs indefinitely (e.g. gateway/network stall). + # Set NANOBOT_LLM_TIMEOUT_S=0 to disable. raw = os.environ.get("NANOBOT_LLM_TIMEOUT_S", "300").strip() try: timeout_s = float(raw) @@ -622,30 +613,12 @@ class AgentRunner: messages, tools=spec.tools.get_definitions(), ) - provider: LLMProvider = self.provider - request_timeout = timeout_s - if spec.fallback_models: - provider = self._build_model_router(spec, timeout_s) - # ModelRouter applies the same timeout per candidate, preserving - # fallback on primary timeouts instead of timing out the whole chain. 
- request_timeout = None - return await self._call_provider(provider, kwargs, hook, context, spec, request_timeout) - - async def _call_provider( - self, - provider: LLMProvider, - kwargs: dict[str, Any], - hook: AgentHook, - context: AgentHookContext, - spec: AgentRunSpec, - timeout_s: float | None = None, - ) -> LLMResponse: wants_streaming = hook.wants_streaming() wants_progress_streaming = ( not wants_streaming and spec.stream_progress_deltas and spec.progress_callback is not None - and getattr(provider, "supports_progress_deltas", False) is True + and getattr(self.provider, "supports_progress_deltas", False) is True ) if wants_streaming: @@ -654,7 +627,7 @@ class AgentRunner: context.streamed_content = True await hook.on_stream(context, delta) - coro = provider.chat_stream_with_retry( + coro = self.provider.chat_stream_with_retry( **kwargs, on_content_delta=_stream, ) @@ -673,12 +646,12 @@ class AgentRunner: context.streamed_content = True await spec.progress_callback(incremental) - coro = provider.chat_stream_with_retry( + coro = self.provider.chat_stream_with_retry( **kwargs, on_content_delta=_stream_progress, ) else: - coro = provider.chat_with_retry(**kwargs) + coro = self.provider.chat_with_retry(**kwargs) if timeout_s is None: return await coro @@ -691,41 +664,6 @@ class AgentRunner: error_kind="timeout", ) - def _resolve_fallback_provider(self, model: str) -> tuple[LLMProvider, str]: - """Return (provider, actual_model_name) for a fallback model. - - When a provider_factory is available (and the model string may be a - preset name), the factory resolves the actual model; otherwise the - primary provider is reused with the raw model string. 
- """ - if model in self._fallback_providers: - p = self._fallback_providers[model] - return p, p.get_default_model() - if self._provider_factory: - provider = self._provider_factory(model) - self._fallback_providers[model] = provider - return provider, provider.get_default_model() - return self.provider, model - - def _build_model_router( - self, - spec: AgentRunSpec, - timeout_s: float | None, - ) -> ModelRouter: - candidates = [ - ModelCandidate( - label=model, - resolver=lambda m=model: self._resolve_fallback_provider(m), - ) - for model in spec.fallback_models - ] - return ModelRouter( - primary_provider=self.provider, - primary_model=spec.model, - fallback_candidates=candidates, - per_candidate_timeout_s=timeout_s, - ) - async def _request_finalization_retry( self, spec: AgentRunSpec, diff --git a/nanobot/agent/tools/self.py b/nanobot/agent/tools/self.py index 86a6f934a..6b40299ca 100644 --- a/nanobot/agent/tools/self.py +++ b/nanobot/agent/tools/self.py @@ -76,8 +76,6 @@ class MyTool(Tool): RESTRICTED: dict[str, dict[str, Any]] = { "max_iterations": {"type": int, "min": 1, "max": 100}, - "context_window_tokens": {"type": int, "min": 4096, "max": 1_000_000}, - "model": {"type": str, "min_len": 1}, } _MAX_RUNTIME_KEYS = 64 @@ -118,13 +116,14 @@ class MyTool(Tool): "Scratchpad keys persist across turns but not restarts.\n" "Key values: _current_iteration (current progress), " "max_iterations - _current_iteration = remaining iterations.\n" + "Use 'model_preset' to switch the active model preset.\n" "Note: web_config and exec_config are readable but read-only.\n" "\n" "When to use:\n" "- User asks about your model, settings, or token usage → check that key.\n" "- A tool fails or behaves unexpectedly → check the related config to diagnose.\n" "- User asks you to remember a preference for this session → set to store it in your scratchpad.\n" - "- About to start a large task → check context_window_tokens and max_iterations first." 
+ "- About to start a large task → check max_iterations and model_preset first." ) if not self._modify_allowed: base += "\nREAD-ONLY MODE: set is disabled." @@ -132,7 +131,7 @@ class MyTool(Tool): base += ( "\nIMPORTANT: Before setting state, predict the potential impact. " "If the operation could cause crashes or instability " - "(e.g. changing model), warn the user first." + "(e.g. changing model_preset), warn the user first." ) return base @@ -148,7 +147,7 @@ class MyTool(Tool): }, "key": { "type": "string", - "description": "Dot-path for check/set. Examples: 'max_iterations', 'workspace', 'provider_retry_mode'. " + "description": "Dot-path for check/set. Examples: 'max_iterations', 'model_preset', 'provider_retry_mode'. " "For check without key, shows all config values.", }, "value": {"description": "New value (for set). Type must match target (int for max_iterations/context_window_tokens, str for model)."}, @@ -391,9 +390,6 @@ class MyTool(Tool): # --- existing restricted key logic --- old = getattr(self._loop, key) - # When model is set directly, it no longer matches any preset - if key == "model": - self._loop._active_preset = None if "min" in spec and value < spec["min"]: return f"Error: '{key}' must be >= {spec['min']}" if "max" in spec and value > spec["max"]: @@ -419,9 +415,12 @@ class MyTool(Tool): f"REJECTED type mismatch {key}: expects {old_t.__name__}, got {new_t.__name__}", ) return f"Error: '{key}' expects {old_t.__name__}, got {new_t.__name__}" + # When a model-specific field is set directly, it no longer matches any preset + if key in ("model", "context_window_tokens"): + self._loop._active_preset = None try: setattr(self._loop, key, value) - except (ValueError, KeyError) as e: + except (AttributeError, TypeError, ValueError, KeyError) as e: self._audit("modify", f"REJECTED {key}: {e}") return f"Error: {e}" self._audit("modify", f"{key}: {old!r} -> {value!r}") diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 
34abb8fc5..171769459 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -48,6 +48,7 @@ from rich.table import Table from rich.text import Text from nanobot import __logo__, __version__ +from nanobot.agent.loop import AgentLoop class SafeFileHistory(FileHistory): @@ -437,104 +438,6 @@ def _onboard_plugins(config_path: Path) -> None: json.dump(data, f, indent=2, ensure_ascii=False) -def _make_provider(config: Config): - """Create the appropriate LLM provider from config. - - Routing is driven by ``ProviderSpec.backend`` in the registry. - """ - from nanobot.providers.base import GenerationSettings - from nanobot.providers.factory import make_provider - from nanobot.providers.registry import find_by_name - - resolved = config.resolve_preset() - model = resolved.model - provider_name = config.get_provider_name(model) - p = config.get_provider(model) - spec = find_by_name(provider_name) if provider_name else None - backend = spec.backend if spec else "openai_compat" - - # --- validation --- - if backend == "azure_openai": - if not p or not p.api_key or not p.api_base: - console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]") - console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section") - console.print("Use the model field to specify the deployment name.") - raise typer.Exit(1) - elif backend == "openai_compat" and not model.startswith("bedrock/"): - needs_key = not (p and p.api_key) - exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct) - if needs_key and not exempt: - console.print("[red]Error: No API key configured.[/red]") - console.print("Set one in ~/.nanobot/config.json under providers section") - raise typer.Exit(1) - - # --- instantiation by backend --- - if backend == "openai_codex": - from nanobot.providers.openai_codex_provider import OpenAICodexProvider - - provider = OpenAICodexProvider(default_model=model) - elif backend == "azure_openai": - from 
nanobot.providers.azure_openai_provider import AzureOpenAIProvider - - provider = AzureOpenAIProvider( - api_key=p.api_key, - api_base=p.api_base, - default_model=model, - ) - elif backend == "github_copilot": - from nanobot.providers.github_copilot_provider import GitHubCopilotProvider - provider = GitHubCopilotProvider(default_model=model) - elif backend == "anthropic": - from nanobot.providers.anthropic_provider import AnthropicProvider - - provider = AnthropicProvider( - api_key=p.api_key if p else None, - api_base=config.get_api_base(model), - default_model=model, - extra_headers=p.extra_headers if p else None, - ) - else: - from nanobot.providers.openai_compat_provider import OpenAICompatProvider - - provider = OpenAICompatProvider( - api_key=p.api_key if p else None, - api_base=config.get_api_base(model), - default_model=model, - extra_headers=p.extra_headers if p else None, - spec=spec, - extra_body=p.extra_body if p else None, - ) - - provider.generation = GenerationSettings( - temperature=resolved.temperature, - max_tokens=resolved.max_tokens, - reasoning_effort=resolved.reasoning_effort, - ) - return provider - - -def _make_cli_provider_factory(config: Config): - """Build a cached factory for fallback model providers (CLI side). - - Supports preset names: if a model string matches a preset, the preset's - full config is used for provider creation. 
- """ - from nanobot.nanobot import _make_provider_for_model - - cache: dict[str, Any] = {} - presets = getattr(config, "model_presets", {}) or {} - - def factory(model_or_preset: str): - preset = presets.get(model_or_preset) - actual_model = preset.model if preset else model_or_preset - key = actual_model - if key not in cache: - cache[key] = _make_provider_for_model(config, actual_model, preset=preset) - return cache[key] - - return factory - - def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config: """Load config and optionally override the active workspace.""" from nanobot.config.loader import load_config, resolve_config_env_vars, set_config_path @@ -612,8 +515,6 @@ def serve( raise typer.Exit(1) from loguru import logger - - from nanobot.agent.loop import AgentLoop from nanobot.api.server import create_app from nanobot.bus.queue import MessageBus from nanobot.session.manager import SessionManager @@ -630,45 +531,19 @@ def serve( timeout = timeout if timeout is not None else api_cfg.timeout sync_workspace_templates(runtime_config.workspace_path) bus = MessageBus() - provider = _make_provider(runtime_config) defaults = runtime_config.agents.defaults - pf = _make_cli_provider_factory(runtime_config) if defaults.fallback_models else None session_manager = SessionManager(runtime_config.workspace_path) - _resolved = runtime_config.resolve_preset() - agent_loop = AgentLoop( - bus=bus, - provider=provider, - workspace=runtime_config.workspace_path, - model=_resolved.model, - max_iterations=defaults.max_tool_iterations, - context_window_tokens=_resolved.context_window_tokens, - context_block_limit=defaults.context_block_limit, - max_tool_result_chars=defaults.max_tool_result_chars, - provider_retry_mode=defaults.provider_retry_mode, - fallback_models=defaults.fallback_models, - provider_factory=pf, - web_config=runtime_config.tools.web, - exec_config=runtime_config.tools.exec, - 
restrict_to_workspace=runtime_config.tools.restrict_to_workspace, + resolved_preset = runtime_config.resolve_preset() + agent_loop = AgentLoop.from_config( + runtime_config, bus, session_manager=session_manager, - mcp_servers=runtime_config.tools.mcp_servers, - channels_config=runtime_config.channels, - timezone=runtime_config.agents.defaults.timezone, - unified_session=runtime_config.agents.defaults.unified_session, - disabled_skills=runtime_config.agents.defaults.disabled_skills, - session_ttl_minutes=runtime_config.agents.defaults.session_ttl_minutes, - consolidation_ratio=runtime_config.agents.defaults.consolidation_ratio, - max_messages=runtime_config.agents.defaults.max_messages, - tools_config=runtime_config.tools, image_generation_provider_configs={ "openrouter": runtime_config.providers.openrouter, "aihubmix": runtime_config.providers.aihubmix, }, - model_presets=runtime_config.model_presets, - model_preset=runtime_config.agents.defaults.model_preset, ) - model_name = _resolved.model + model_name = resolved_preset.model preset_name = defaults.model_preset preset_tag = f" (preset: {preset_name})" if preset_name else "" console.print(f"{__logo__} Starting OpenAI-compatible API server") @@ -735,7 +610,6 @@ def _run_gateway( open_browser_url: str | None = None, ) -> None: """Shared gateway runtime; ``open_browser_url`` opens a tab once channels are up.""" - from nanobot.agent.loop import AgentLoop from nanobot.agent.tools.cron import CronTool from nanobot.agent.tools.message import MessageTool from nanobot.bus.queue import MessageBus @@ -751,9 +625,6 @@ def _run_gateway( console.print(f"{__logo__} Starting nanobot gateway version {__version__} on port {port}...") sync_workspace_templates(config.workspace_path) bus = MessageBus() - provider = _make_provider(config) - gw_defaults = config.agents.defaults - gw_pf = _make_cli_provider_factory(config) if gw_defaults.fallback_models else None try: provider_snapshot = build_provider_snapshot(config) except ValueError 
as exc: @@ -770,41 +641,16 @@ def _run_gateway( cron = CronService(cron_store_path) # Create agent with cron service - _resolved = config.resolve_preset() - agent = AgentLoop( - bus=bus, - provider=provider, - workspace=config.workspace_path, - model=_resolved.model, - max_iterations=gw_defaults.max_tool_iterations, - context_window_tokens=_resolved.context_window_tokens, - web_config=config.tools.web, - context_block_limit=gw_defaults.context_block_limit, - max_tool_result_chars=gw_defaults.max_tool_result_chars, - provider_retry_mode=gw_defaults.provider_retry_mode, - fallback_models=gw_defaults.fallback_models, - provider_factory=gw_pf, - exec_config=config.tools.exec, + agent = AgentLoop.from_config( + config, bus, cron_service=cron, - restrict_to_workspace=config.tools.restrict_to_workspace, session_manager=session_manager, - mcp_servers=config.tools.mcp_servers, - channels_config=config.channels, - timezone=config.agents.defaults.timezone, - unified_session=config.agents.defaults.unified_session, - disabled_skills=config.agents.defaults.disabled_skills, - session_ttl_minutes=config.agents.defaults.session_ttl_minutes, - consolidation_ratio=config.agents.defaults.consolidation_ratio, - max_messages=config.agents.defaults.max_messages, - tools_config=config.tools, image_generation_provider_configs={ "openrouter": config.providers.openrouter, "aihubmix": config.providers.aihubmix, }, provider_snapshot_loader=load_provider_snapshot, provider_signature=provider_snapshot.signature, - model_presets=config.model_presets, - model_preset=config.agents.defaults.model_preset, ) from nanobot.agent.loop import UNIFIED_SESSION_KEY @@ -908,7 +754,7 @@ def _run_gateway( if job.payload.deliver and job.payload.to and response: should_notify = await evaluate_response( - response, reminder_note, provider, agent.model, + response, reminder_note, agent.provider, agent.model, ) if should_notify: await _deliver_to_channel( @@ -998,7 +844,7 @@ def _run_gateway( hb_cfg = 
config.gateway.heartbeat heartbeat = HeartbeatService( workspace=config.workspace_path, - provider=provider, + provider=agent.provider, model=agent.model, on_execute=on_heartbeat_execute, on_notify=on_heartbeat_notify, @@ -1151,7 +997,6 @@ def agent( """Interact with the agent directly.""" from loguru import logger - from nanobot.agent.loop import AgentLoop from nanobot.bus.queue import MessageBus from nanobot.cron.service import CronService @@ -1159,10 +1004,6 @@ def agent( sync_workspace_templates(config.workspace_path) bus = MessageBus() - provider = _make_provider(config) - chat_defaults = config.agents.defaults - chat_pf = _make_cli_provider_factory(config) if chat_defaults.fallback_models else None - # Preserve existing single-workspace installs, but keep custom workspaces clean. if is_default_workspace(config.workspace_path): _migrate_cron_store(config) @@ -1176,34 +1017,10 @@ def agent( else: logger.disable("nanobot") - _resolved = config.resolve_preset() - agent_loop = AgentLoop( - bus=bus, - provider=provider, - workspace=config.workspace_path, - model=_resolved.model, - max_iterations=chat_defaults.max_tool_iterations, - context_window_tokens=_resolved.context_window_tokens, - web_config=config.tools.web, - context_block_limit=chat_defaults.context_block_limit, - max_tool_result_chars=chat_defaults.max_tool_result_chars, - provider_retry_mode=chat_defaults.provider_retry_mode, - fallback_models=chat_defaults.fallback_models, - provider_factory=chat_pf, - exec_config=config.tools.exec, + resolved_preset = config.resolve_preset() + agent_loop = AgentLoop.from_config( + config, bus, cron_service=cron, - restrict_to_workspace=config.tools.restrict_to_workspace, - mcp_servers=config.tools.mcp_servers, - channels_config=config.channels, - timezone=config.agents.defaults.timezone, - unified_session=config.agents.defaults.unified_session, - disabled_skills=config.agents.defaults.disabled_skills, - session_ttl_minutes=config.agents.defaults.session_ttl_minutes, - 
consolidation_ratio=config.agents.defaults.consolidation_ratio, - max_messages=config.agents.defaults.max_messages, - tools_config=config.tools, - model_presets=config.model_presets, - model_preset=config.agents.defaults.model_preset, ) restart_notice = consume_restart_notice_from_env() if restart_notice and should_show_cli_restart_notice(restart_notice, session_id): @@ -1247,7 +1064,7 @@ def agent( # Interactive mode — route through bus like other channels from nanobot.bus.events import InboundMessage _init_prompt_session() - console.print(f"{__logo__} Interactive mode [bold blue]({_resolved.model})[/bold blue] — type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit\n") + console.print(f"{__logo__} Interactive mode [bold blue]({resolved_preset.model})[/bold blue] — type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit\n") if ":" in session_id: cli_channel, cli_chat_id = session_id.split(":", 1) @@ -1605,10 +1422,10 @@ def status(): if config_path.exists(): from nanobot.providers.registry import PROVIDERS - _resolved = config.resolve_preset() - _preset = config.agents.defaults.model_preset - _preset_tag = f" (preset: {_preset})" if _preset else "" - console.print(f"Model: {_resolved.model}{_preset_tag}") + resolved_preset = config.resolve_preset() + preset = config.agents.defaults.model_preset + preset_tag = f" (preset: {preset})" if preset else "" + console.print(f"Model: {resolved_preset.model}{preset_tag}") # Check API keys from registry for spec in PROVIDERS: diff --git a/nanobot/cli/onboard.py b/nanobot/cli/onboard.py index 13b2a978a..b3044b448 100644 --- a/nanobot/cli/onboard.py +++ b/nanobot/cli/onboard.py @@ -22,7 +22,7 @@ from nanobot.cli.models import ( get_model_suggestions, ) from nanobot.config.loader import get_config_path, load_config -from nanobot.config.schema import Config +from nanobot.config.schema import Config, ModelPresetConfig console = Console() @@ -49,6 +49,16 @@ _SELECT_FIELD_HINTS: dict[str, tuple[list[str], str]] = { _BACK_PRESSED = 
object() # Sentinel value for back navigation +# Cache of model-preset names populated at runtime so that field handlers can +# offer existing presets as choices (e.g. AgentDefaults.model_preset). +# +# Lifecycle: populated by _sync_preset_cache(config), which must be called +# after every config mutation that changes model_presets (add, delete, edit). +# Cleared between tests via _MODEL_PRESET_CACHE.clear(). In long-running +# processes (gateway) the cache is refreshed each time the preset management +# screen is entered, so staleness is bounded by user interaction. +_MODEL_PRESET_CACHE: set[str] = set() + def _get_questionary(): """Return questionary or raise a clear error when wizard deps are unavailable.""" @@ -588,9 +598,100 @@ def _handle_context_window_field( setattr(working_model, field_name, new_value) +def _handle_model_preset_field( + working_model: BaseModel, field_name: str, field_display: str, current_value: Any +) -> None: + """Handle the 'model_preset' field with a list of existing presets.""" + # model_preset lives on AgentDefaults, but the preset list is on Config. + # We can't easily access Config here, so we read from the global config + # via a module-level cache set by _configure_model_presets / run_onboard. 
+ preset_names = sorted(_MODEL_PRESET_CACHE) + choices = ["(clear/unset)"] + preset_names + default_choice = str(current_value) if current_value else "(clear/unset)" + new_value = _select_with_back(field_display, choices, default=default_choice) + if new_value is _BACK_PRESSED: + return + if new_value == "(clear/unset)": + setattr(working_model, field_name, None) + elif new_value is not None: + setattr(working_model, field_name, new_value) + + +def _handle_provider_field( + working_model: BaseModel, field_name: str, field_display: str, current_value: Any +) -> None: + """Handle the 'provider' field with a list of registered providers.""" + provider_names = sorted(_get_provider_names().keys()) + choices = ["auto"] + provider_names + default_choice = str(current_value) if current_value else "auto" + new_value = _select_with_back(field_display, choices, default=default_choice) + if new_value is _BACK_PRESSED: + return + if new_value is not None: + setattr(working_model, field_name, new_value) + + +def _handle_fallback_presets_field( + working_model: BaseModel, field_name: str, field_display: str, current_value: Any +) -> None: + """Handle the 'fallback_presets' field with preset-aware multi-select.""" + items: list[str] = list(current_value) if isinstance(current_value, list) else [] + preset_names = sorted(_MODEL_PRESET_CACHE) + + while True: + console.clear() + console.print(f"[bold]{field_display}[/bold]") + if items: + for idx, item in enumerate(items, 1): + console.print(f" {idx}. 
{item}") + else: + console.print(" [dim](empty)[/dim]") + console.print() + + choices = ["[+] Add preset"] + if items: + choices.append("[-] Remove last") + choices.append("[X] Clear all") + choices.append("[Done]") + choices.append("<- Back") + + answer = _get_questionary().select( + "Manage fallback chain:", + choices=choices, + qmark=">", + ).ask() + + if answer is None or answer == "<- Back": + return + if answer == "[Done]": + setattr(working_model, field_name, items) + return + if answer == "[+] Add preset": + if not preset_names: + console.print("[yellow]! No presets defined yet.[/yellow]") + _get_questionary().press_any_key_to_continue().ask() + continue + add_choices = [p for p in preset_names if p not in items] + if not add_choices: + console.print("[yellow]! All presets already added.[/yellow]") + _get_questionary().press_any_key_to_continue().ask() + continue + picked = _select_with_back("Select preset:", add_choices) + if picked is _BACK_PRESSED or picked is None: + continue + items.append(picked) + elif answer == "[-] Remove last" and items: + items.pop() + elif answer == "[X] Clear all" and items: + items.clear() + + _FIELD_HANDLERS: dict[str, Any] = { "model": _handle_model_field, "context_window_tokens": _handle_context_window_field, + "model_preset": _handle_model_preset_field, + "provider": _handle_provider_field, + "fallback_presets": _handle_fallback_presets_field, } @@ -757,6 +858,113 @@ def _try_auto_fill_context_window(model: BaseModel, new_model_name: str) -> None console.print("[dim](i) Could not auto-fill context window (model not in database)[/dim]") +# --- Model Preset Configuration --- + + +def _sync_preset_cache(config: Config) -> None: + """Synchronise the module-level preset name cache from config.""" + _MODEL_PRESET_CACHE.clear() + _MODEL_PRESET_CACHE.update(config.model_presets.keys()) + + +def _configure_model_presets(config: Config) -> None: + """Configure model presets (CRUD).""" + _sync_preset_cache(config) + + def 
get_preset_choices() -> list[str]: + choices: list[str] = [] + for name, preset in config.model_presets.items(): + choices.append(f"{name} ({preset.model})") + choices.append("[+] Add new preset") + choices.append("<- Back") + return choices + + last_preset_name: str | None = None + while True: + try: + console.clear() + _show_section_header( + "Model Presets", + "Create, edit or delete named model presets for quick switching", + ) + choices = get_preset_choices() + default_choice = None + if last_preset_name: + for c in choices: + if c.startswith(last_preset_name + " ("): + default_choice = c + break + answer = _select_with_back( + "Select preset:", choices, default=default_choice + ) + + if answer is _BACK_PRESSED or answer is None or answer == "<- Back": + break + + assert isinstance(answer, str) + + if answer == "[+] Add new preset": + name_input = _get_questionary().text( + "Preset name:", + validate=lambda t: True if t and t.strip() else "Name cannot be empty", + ).ask() + if not name_input: + continue + name = name_input.strip() + if name in config.model_presets: + console.print(f"[yellow]! 
Preset '{name}' already exists[/yellow]") + _pause() + continue + new_preset = ModelPresetConfig(model="") + updated = _configure_pydantic_model(new_preset, f"New Preset: {name}") + if updated is not None: + config.model_presets[name] = updated + _sync_preset_cache(config) + last_preset_name = name + continue + + # Editing / deleting an existing preset + # Extract preset name from "name (model)" format + preset_name = answer.split(" (", 1)[0] + preset = config.model_presets.get(preset_name) + if preset is None: + continue + + last_preset_name = preset_name + + choices = ["Edit", "Cancel"] + if preset_name != "default": + choices.insert(1, "Delete") + action = _select_with_back( + f"Preset: {preset_name}", + choices, + default="Edit", + ) + if action is _BACK_PRESSED or action == "Cancel" or action is None: + continue + + if action == "Delete": + confirm = _get_questionary().confirm( + f"Delete preset '{preset_name}'?", + default=False, + ).ask() + if confirm: + del config.model_presets[preset_name] + _sync_preset_cache(config) + last_preset_name = None + continue + + if action == "Edit": + updated = _configure_pydantic_model(preset, f"Edit Preset: {preset_name}") + if updated is not None: + config.model_presets[preset_name] = updated + _sync_preset_cache(config) + + except KeyboardInterrupt: + console.print("\n[dim]Returning to main menu...[/dim]") + break + + # --- Provider Configuration --- @@ -1043,6 +1251,12 @@ def _show_summary(config: Config) -> None: channel_rows.append((display, status)) _print_summary_panel(channel_rows, "Chat Channels") + # Model Presets + preset_rows = [] + for name, preset in config.model_presets.items(): + preset_rows.append((name, f"{preset.model} (ctx={preset.context_window_tokens})")) + _print_summary_panel(preset_rows, "Model Presets") + # Settings sections for title, model in [ ("Agent Settings", config.agents.defaults), @@ -1112,6 +1326,7 @@ def run_onboard(initial_config: Config | None = None) -> OnboardResult: original_config = 
base_config.model_copy(deep=True) config = base_config.model_copy(deep=True) + _sync_preset_cache(config) last_main_choice: str | None = None while True: @@ -1123,6 +1338,7 @@ def run_onboard(initial_config: Config | None = None) -> OnboardResult: "What would you like to configure?", choices=[ "[P] LLM Provider", + "[M] Model Presets", "[C] Chat Channel", "[H] Channel Common", "[A] Agent Settings", @@ -1149,6 +1365,7 @@ def run_onboard(initial_config: Config | None = None) -> OnboardResult: _menu_dispatch = { "[P] LLM Provider": lambda: _configure_providers(config), + "[M] Model Presets": lambda: _configure_model_presets(config), "[C] Chat Channel": lambda: _configure_channels(config), "[H] Channel Common": lambda: _configure_general_settings(config, "Channel Common"), "[A] Agent Settings": lambda: _configure_general_settings(config, "Agent Settings"), diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index f0c93374e..e5c6f404a 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -104,8 +104,9 @@ class AgentDefaults(Base): validation_alias=AliasChoices("toolHintMaxLength"), serialization_alias="toolHintMaxLength", ) # Max characters for tool hint display (e.g. "$ cd …/project && npm test") - reasoning_effort: str | None = None # low / medium / high / adaptive - enables LLM thinking mode - fallback_models: list[str] = Field(default_factory=list) + fallback_presets: list[str] = Field( + default_factory=list + ) # Ordered fallback chain. Each item must be a preset name defined in model_presets. timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York" unified_session: bool = False # Share one session across all channels (single-user multi-device) disabled_skills: list[str] = Field(default_factory=list) # Skill names to exclude from loading (e.g. 
["summarize", "skill-creator"]) @@ -306,24 +307,52 @@ class Config(BaseSettings): model_presets: dict[str, ModelPresetConfig] = Field(default_factory=dict) @model_validator(mode="after") - def _validate_model_preset(self) -> "Config": - name = self.agents.defaults.model_preset - if name and name not in self.model_presets: - raise ValueError(f"model_preset {name!r} not found in model_presets") + def _sync_and_validate_preset(self) -> "Config": + """Expose agents.defaults model fields as the implicit 'default' preset + and validate the active preset reference. + + This guarantees that ``model_presets`` is never empty and that legacy + configs (which only set ``agents.defaults.model`` etc.) continue to work + without explicitly declaring a preset. + """ + self._refresh_default_preset() + defaults = self.agents.defaults + if defaults.model_preset is None: + defaults.model_preset = "default" + if defaults.model_preset not in self.model_presets: + raise ValueError(f"model_preset {defaults.model_preset!r} not found in model_presets") + for fb in defaults.fallback_presets: + if fb not in self.model_presets: + raise ValueError(f"fallback_presets entry {fb!r} not found in model_presets") return self - def resolve_preset(self) -> ModelPresetConfig: - """Return effective model params: from active preset, or individual defaults.""" - name = self.agents.defaults.model_preset - if name: - return self.model_presets[name] + def _refresh_default_preset(self) -> None: + """Rebuild the implicit 'default' preset from current agents.defaults. + + Called inside ``_sync_and_validate_preset`` (model validator) and + ``resolve_preset()`` so that runtime mutations (e.g. tests directly + setting ``defaults.model``) are reflected. 
+ """ d = self.agents.defaults - return ModelPresetConfig( - model=d.model, provider=d.provider, max_tokens=d.max_tokens, + self.model_presets["default"] = ModelPresetConfig( + model=d.model, + provider=d.provider, + max_tokens=d.max_tokens, context_window_tokens=d.context_window_tokens, - temperature=d.temperature, reasoning_effort=d.reasoning_effort, + temperature=d.temperature, + reasoning_effort=d.reasoning_effort, ) + def resolve_preset(self) -> ModelPresetConfig: + """Return the active preset. + + The implicit ``"default"`` preset is rebuilt from current defaults every + time so that runtime mutations (e.g. tests setting ``defaults.model``) + are always reflected. + """ + self._refresh_default_preset() + return self.model_presets[self.agents.defaults.model_preset] + @property def workspace_path(self) -> Path: """Get expanded workspace path.""" @@ -335,15 +364,16 @@ class Config(BaseSettings): """Match provider config and its registry name. Returns (config, spec_name).""" from nanobot.providers.registry import PROVIDERS, find_by_name - forced = self.resolve_preset().provider + resolved = self.resolve_preset() + forced = resolved.provider if forced != "auto": spec = find_by_name(forced) if spec: - p = getattr(self.providers, spec.name, None) - return (p, spec.name) if p else (None, None) + provider_cfg = getattr(self.providers, spec.name, None) + return (provider_cfg, spec.name) if provider_cfg else (None, None) return None, None - model_lower = (model or self.resolve_preset().model).lower() + model_lower = (model or resolved.model).lower() model_normalized = model_lower.replace("-", "_") model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else "" normalized_prefix = model_prefix.replace("-", "_") diff --git a/nanobot/nanobot.py b/nanobot/nanobot.py index 1eaef6636..bfedb7611 100644 --- a/nanobot/nanobot.py +++ b/nanobot/nanobot.py @@ -8,7 +8,6 @@ from typing import Any from nanobot.agent.hook import AgentHook, SDKCaptureHook from 
nanobot.agent.loop import AgentLoop -from nanobot.bus.queue import MessageBus @dataclass(slots=True) @@ -62,41 +61,12 @@ class Nanobot: Path(workspace).expanduser().resolve() ) - provider = _make_provider(config) - bus = MessageBus() - defaults = config.agents.defaults - _resolved = config.resolve_preset() - pf = _make_provider_factory(config) if defaults.fallback_models else None - - loop = AgentLoop( - bus=bus, - provider=provider, - workspace=config.workspace_path, - model=_resolved.model, - max_iterations=defaults.max_tool_iterations, - context_window_tokens=_resolved.context_window_tokens, - context_block_limit=defaults.context_block_limit, - max_tool_result_chars=defaults.max_tool_result_chars, - provider_retry_mode=defaults.provider_retry_mode, - tool_hint_max_length=defaults.tool_hint_max_length, - fallback_models=defaults.fallback_models, - provider_factory=pf, - web_config=config.tools.web, - exec_config=config.tools.exec, - restrict_to_workspace=config.tools.restrict_to_workspace, - mcp_servers=config.tools.mcp_servers, - timezone=defaults.timezone, - unified_session=defaults.unified_session, - disabled_skills=defaults.disabled_skills, - session_ttl_minutes=defaults.session_ttl_minutes, - consolidation_ratio=defaults.consolidation_ratio, - tools_config=config.tools, + loop = AgentLoop.from_config( + config, image_generation_provider_configs={ "openrouter": config.providers.openrouter, "aihubmix": config.providers.aihubmix, }, - model_presets=config.model_presets, - model_preset=defaults.model_preset, ) return cls(loop) @@ -134,99 +104,3 @@ class Nanobot: ) -def _make_provider_for_model( - config: Any, - model: str, - *, - preset: Any | None = None, -) -> Any: - """Create an LLM provider instance for a specific model string. - - When *preset* is given, its generation settings (temperature, max_tokens, - reasoning_effort) override the active preset defaults. 
- """ - from nanobot.providers.base import GenerationSettings - from nanobot.providers.factory import make_provider - from nanobot.providers.registry import find_by_name - - gen_src = preset or config.resolve_preset() - provider_name = config.get_provider_name(model) - p = config.get_provider(model) - spec = find_by_name(provider_name) if provider_name else None - backend = spec.backend if spec else "openai_compat" - - if backend == "azure_openai": - if not p or not p.api_key or not p.api_base: - raise ValueError("Azure OpenAI requires api_key and api_base in config.") - elif backend == "openai_compat" and not model.startswith("bedrock/"): - needs_key = not (p and p.api_key) - exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct) - if needs_key and not exempt: - raise ValueError(f"No API key configured for provider '{provider_name}'.") - - if backend == "openai_codex": - from nanobot.providers.openai_codex_provider import OpenAICodexProvider - - provider = OpenAICodexProvider(default_model=model) - elif backend == "github_copilot": - from nanobot.providers.github_copilot_provider import GitHubCopilotProvider - - provider = GitHubCopilotProvider(default_model=model) - elif backend == "azure_openai": - from nanobot.providers.azure_openai_provider import AzureOpenAIProvider - - provider = AzureOpenAIProvider( - api_key=p.api_key, api_base=p.api_base, default_model=model - ) - elif backend == "anthropic": - from nanobot.providers.anthropic_provider import AnthropicProvider - - provider = AnthropicProvider( - api_key=p.api_key if p else None, - api_base=config.get_api_base(model), - default_model=model, - extra_headers=p.extra_headers if p else None, - ) - else: - from nanobot.providers.openai_compat_provider import OpenAICompatProvider - - provider = OpenAICompatProvider( - api_key=p.api_key if p else None, - api_base=config.get_api_base(model), - default_model=model, - extra_headers=p.extra_headers if p else None, - spec=spec, - extra_body=p.extra_body 
if p else None, - ) - - provider.generation = GenerationSettings( - temperature=gen_src.temperature, - max_tokens=gen_src.max_tokens, - reasoning_effort=gen_src.reasoning_effort, - ) - return provider - - -def _make_provider(config: Any) -> Any: - """Create the LLM provider for the primary model from config.""" - return _make_provider_for_model(config, config.resolve_preset().model) - - -def _make_provider_factory(config: Any): - """Build a cached factory that creates providers for arbitrary model strings. - - If a model string matches a preset name in ``config.model_presets``, the - preset's full config (model, temperature, max_tokens, …) is used. - """ - cache: dict[str, Any] = {} - presets = getattr(config, "model_presets", {}) or {} - - def factory(model_or_preset: str): - preset = presets.get(model_or_preset) - actual_model = preset.model if preset else model_or_preset - key = actual_model - if key not in cache: - cache[key] = _make_provider_for_model(config, actual_model, preset=preset) - return cache[key] - - return factory diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index 1d598f20a..18b06d36e 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -137,7 +137,9 @@ class LLMProvider(ABC): "insufficient_quota", "insufficient quota", "quota exceeded", + "quota_exceeded", "quota exhausted", + "quota_exhausted", "billing hard limit", "billing_hard_limit_reached", "billing not active", diff --git a/nanobot/providers/factory.py b/nanobot/providers/factory.py index d71390940..4013cef41 100644 --- a/nanobot/providers/factory.py +++ b/nanobot/providers/factory.py @@ -4,11 +4,16 @@ from __future__ import annotations from dataclasses import dataclass from pathlib import Path +from typing import TYPE_CHECKING from nanobot.config.schema import Config from nanobot.providers.base import GenerationSettings, LLMProvider from nanobot.providers.registry import find_by_name +if TYPE_CHECKING: + from nanobot.config.schema import 
ModelPresetConfig, ProviderConfig + from nanobot.providers.registry import ProviderSpec + @dataclass(frozen=True) class ProviderSnapshot: @@ -18,22 +23,62 @@ class ProviderSnapshot: signature: tuple[object, ...] -def make_provider(config: Config) -> LLMProvider: - """Create the LLM provider implied by config.""" - model = config.agents.defaults.model - provider_name = config.get_provider_name(model) - p = config.get_provider(model) - spec = find_by_name(provider_name) if provider_name else None +@dataclass(frozen=True) +class _ProviderInfo: + """Resolved metadata needed to build and validate an LLM provider.""" + + name: str | None + cfg: ProviderConfig | None + spec: ProviderSpec | None + api_base: str | None + backend: str + + +def _resolve_provider_info( + config: Config, + model: str, + preset: ModelPresetConfig, +) -> _ProviderInfo: + """Derive provider name, config, spec and api_base from preset or auto-detection.""" + if preset.provider != "auto": + name = preset.provider + cfg = getattr(config.providers, name, None) + spec = find_by_name(name) + api_base = ( + cfg.api_base + if cfg and cfg.api_base + else (spec.default_api_base if spec and spec.default_api_base else None) + ) + else: + name = config.get_provider_name(model) + cfg = config.get_provider(model) + spec = find_by_name(name) if name else None + api_base = config.get_api_base(model) + backend = spec.backend if spec else "openai_compat" + return _ProviderInfo(name=name, cfg=cfg, spec=spec, api_base=api_base, backend=backend) + + +def _validate_provider(info: _ProviderInfo, model: str) -> None: + """Ensure credentials / endpoints are present before instantiation.""" + cfg = info.cfg + backend = info.backend + name = info.name if backend == "azure_openai": - if not p or not p.api_key or not p.api_base: + if not cfg or not cfg.api_key or not cfg.api_base: raise ValueError("Azure OpenAI requires api_key and api_base in config.") elif backend == "openai_compat" and not model.startswith("bedrock/"): - 
needs_key = not (p and p.api_key) - exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct) + needs_key = not (cfg and cfg.api_key) + exempt = info.spec and (info.spec.is_oauth or info.spec.is_local or info.spec.is_direct) if needs_key and not exempt: - raise ValueError(f"No API key configured for provider '{provider_name}'.") + raise ValueError(f"No API key configured for provider '{name}'.") + + +def _create_provider(model: str, info: _ProviderInfo) -> LLMProvider: + """Instantiate the concrete provider class for *backend*.""" + cfg = info.cfg + backend = info.backend if backend == "openai_codex": from nanobot.providers.openai_codex_provider import OpenAICodexProvider @@ -43,8 +88,8 @@ def make_provider(config: Config) -> LLMProvider: from nanobot.providers.azure_openai_provider import AzureOpenAIProvider provider = AzureOpenAIProvider( - api_key=p.api_key, - api_base=p.api_base, + api_key=cfg.api_key if cfg else None, + api_base=info.api_base, default_model=model, ) elif backend == "github_copilot": @@ -55,70 +100,103 @@ def make_provider(config: Config) -> LLMProvider: from nanobot.providers.anthropic_provider import AnthropicProvider provider = AnthropicProvider( - api_key=p.api_key if p else None, - api_base=config.get_api_base(model), + api_key=cfg.api_key if cfg else None, + api_base=info.api_base, default_model=model, - extra_headers=p.extra_headers if p else None, + extra_headers=cfg.extra_headers if cfg else None, ) elif backend == "bedrock": from nanobot.providers.bedrock_provider import BedrockProvider provider = BedrockProvider( - api_key=p.api_key if p else None, - api_base=p.api_base if p else None, + api_key=cfg.api_key if cfg else None, + api_base=info.api_base if cfg else None, default_model=model, - region=getattr(p, "region", None) if p else None, - profile=getattr(p, "profile", None) if p else None, - extra_body=p.extra_body if p else None, + region=getattr(cfg, "region", None) if cfg else None, + profile=getattr(cfg, "profile", 
None) if cfg else None, + extra_body=cfg.extra_body if cfg else None, ) else: from nanobot.providers.openai_compat_provider import OpenAICompatProvider provider = OpenAICompatProvider( - api_key=p.api_key if p else None, - api_base=config.get_api_base(model), + api_key=cfg.api_key if cfg else None, + api_base=info.api_base, default_model=model, - extra_headers=p.extra_headers if p else None, - spec=spec, - extra_body=p.extra_body if p else None, + extra_headers=cfg.extra_headers if cfg else None, + spec=info.spec, + extra_body=cfg.extra_body if cfg else None, ) - - defaults = config.agents.defaults - provider.generation = GenerationSettings( - temperature=defaults.temperature, - max_tokens=defaults.max_tokens, - reasoning_effort=defaults.reasoning_effort, - ) return provider +def _apply_generation(provider: LLMProvider, preset: ModelPresetConfig) -> None: + provider.generation = GenerationSettings( + temperature=preset.temperature, + max_tokens=preset.max_tokens, + reasoning_effort=preset.reasoning_effort, + ) + + +def build_provider_for_preset(config: Config, preset: ModelPresetConfig) -> LLMProvider: + """Create an LLM provider from a full *preset* (model + provider + generation).""" + info = _resolve_provider_info(config, preset.model, preset) + _validate_provider(info, preset.model) + provider = _create_provider(preset.model, info) + _apply_generation(provider, preset) + return provider + + +def make_provider(config: Config) -> LLMProvider: + """Create the LLM provider implied by config (legacy entrypoint).""" + resolved = config.resolve_preset() + return build_provider_for_preset(config, resolved) + + +def make_provider_factory(config: Config): + """Build a cached factory that creates providers for preset names. + + The factory looks up *preset_name* in ``config.model_presets`` and builds + the provider from the preset's full configuration. 
+ """ + cache: dict[str, LLMProvider] = {} + presets = config.model_presets + + def factory(preset_name: str) -> LLMProvider: + preset = presets.get(preset_name) + if preset is None: + raise ValueError(f"Preset {preset_name!r} not found in model_presets") + if preset_name not in cache: + cache[preset_name] = build_provider_for_preset(config, preset) + return cache[preset_name] + + return factory + + def provider_signature(config: Config) -> tuple[object, ...]: """Return the config fields that affect the primary LLM provider.""" - model = config.agents.defaults.model + resolved = config.resolve_preset() defaults = config.agents.defaults - p = config.get_provider(model) return ( - model, - defaults.provider, - config.get_provider_name(model), - config.get_api_key(model), - config.get_api_base(model), - p.extra_headers if p else None, - p.extra_body if p else None, - getattr(p, "region", None) if p else None, - getattr(p, "profile", None) if p else None, - defaults.max_tokens, - defaults.temperature, - defaults.reasoning_effort, - defaults.context_window_tokens, + resolved.model, + resolved.provider, + config.get_provider_name(resolved.model), + config.get_api_key(resolved.model), + config.get_api_base(resolved.model), + resolved.max_tokens, + resolved.temperature, + resolved.reasoning_effort, + resolved.context_window_tokens, + tuple(defaults.fallback_presets), ) def build_provider_snapshot(config: Config) -> ProviderSnapshot: + resolved = config.resolve_preset() return ProviderSnapshot( provider=make_provider(config), - model=config.agents.defaults.model, - context_window_tokens=config.agents.defaults.context_window_tokens, + model=resolved.model, + context_window_tokens=resolved.context_window_tokens, signature=provider_signature(config), ) diff --git a/nanobot/providers/failover.py b/nanobot/providers/failover.py index 660652fd4..b629d3526 100644 --- a/nanobot/providers/failover.py +++ b/nanobot/providers/failover.py @@ -4,7 +4,6 @@ from __future__ import 
annotations import asyncio from collections.abc import Awaitable, Callable -from dataclasses import dataclass from typing import Any from loguru import logger @@ -12,68 +11,16 @@ from loguru import logger from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse -@dataclass(frozen=True) -class ModelCandidate: - """A lazily resolved model/provider candidate.""" - - label: str - resolver: Callable[[], tuple[LLMProvider, str]] - - class ModelRouter(LLMProvider): """Try fallback model candidates for eligible transient final errors.""" - supports_progress_deltas = False - - _BLOCKED_STATUS_CODES = frozenset({400, 401, 403, 404, 422}) - _QUOTA_MARKERS = ( - "insufficient_quota", - "insufficient quota", - "quota exceeded", - "quota_exceeded", - "quota exhausted", - "quota_exhausted", - "billing hard limit", - "billing_hard_limit_reached", - "billing not active", - "insufficient balance", - "insufficient_balance", - "credit balance too low", - "payment required", - "out of credits", - "out of quota", - "exceeded your current quota", - ) - _NON_FAILOVER_MARKERS = ( - "context length", - "context_length", - "maximum context", - "max context", - "token budget", - "too many tokens", - "schema", - "invalid request", - "invalid_request", - "invalid parameter", - "invalid_parameter", - "unsupported", - "unauthorized", - "authentication", - "permission", - "forbidden", - "refusal", - "content policy", - "content_filter", - "policy violation", - "safety", - ) - def __init__( self, *, primary_provider: LLMProvider, primary_model: str, - fallback_candidates: list[ModelCandidate], + fallback_presets: list[str], + provider_factory: Callable[[str], LLMProvider] | None = None, per_candidate_timeout_s: float | None = None, ) -> None: super().__init__( @@ -82,7 +29,9 @@ class ModelRouter(LLMProvider): ) self.primary_provider = primary_provider self.primary_model = primary_model - self.fallback_candidates = list(fallback_candidates) + self.fallback_presets = 
list(fallback_presets) + self._provider_factory = provider_factory + self._provider_cache: dict[str, LLMProvider] = {} self.per_candidate_timeout_s = per_candidate_timeout_s self.generation = getattr(primary_provider, "generation", GenerationSettings()) @@ -90,41 +39,46 @@ class ModelRouter(LLMProvider): return self.primary_model async def chat(self, **kwargs: Any) -> LLMResponse: - return await self.primary_provider.chat(**kwargs) + async def call(provider: LLMProvider, candidate_model: str, _unused_delta: Any) -> LLMResponse: + return await provider.chat(**{**kwargs, "model": candidate_model}) + return await self._route(call) async def chat_stream(self, **kwargs: Any) -> LLMResponse: - return await self.primary_provider.chat_stream(**kwargs) + async def call(provider: LLMProvider, candidate_model: str, content_delta: Any) -> LLMResponse: + return await provider.chat_stream( + **{**kwargs, "model": candidate_model, "on_content_delta": content_delta} + ) + return await self._route(call, on_content_delta=kwargs.get("on_content_delta")) - @classmethod - def _is_quota_error(cls, response: LLMResponse) -> bool: - tokens = { - cls._normalize_error_token(response.error_type), - cls._normalize_error_token(response.error_code), - } - if any(token in cls._NON_RETRYABLE_429_ERROR_TOKENS for token in tokens if token): - return True - content = (response.content or "").lower() - return any(marker in content for marker in cls._QUOTA_MARKERS) - - @classmethod - def _is_blocked_error(cls, response: LLMResponse) -> bool: - status = response.error_status_code - if status is not None and int(status) in cls._BLOCKED_STATUS_CODES: - return True - if response.finish_reason in {"refusal", "content_filter"}: - return True - content = (response.content or "").lower() - return any(marker in content for marker in cls._NON_FAILOVER_MARKERS) + @property + def supports_progress_deltas(self) -> bool: # type: ignore[override] + return getattr(self.primary_provider, "supports_progress_deltas", 
False) @classmethod def _should_failover(cls, response: LLMResponse) -> bool: if response.finish_reason != "error": return False - if cls._is_blocked_error(response): + if response.error_should_retry is False: return False - if cls._is_quota_error(response): + if response.error_kind == "configuration": return False - return cls._is_transient_response(response) + return True + + def _resolve(self, model: str) -> tuple[LLMProvider, str]: + """Return (provider, actual_model_name) for a preset name. + + Caches results so factory is only invoked once per unique name. + """ + if model in self._provider_cache: + cached_provider = self._provider_cache[model] + return cached_provider, cached_provider.get_default_model() + if self._provider_factory is None: + raise ValueError( + f"Cannot resolve fallback model {model!r}: no provider_factory configured" + ) + provider = self._provider_factory(model) + self._provider_cache[model] = provider + return provider, provider.get_default_model() async def _with_timeout(self, coro: Awaitable[LLMResponse]) -> LLMResponse: timeout_s = self.per_candidate_timeout_s @@ -140,137 +94,90 @@ class ModelRouter(LLMProvider): ) @staticmethod - def _resolver_error(candidate: ModelCandidate, exc: Exception) -> LLMResponse: - logger.warning("Failed to resolve fallback model {}: {}", candidate.label, exc) + def _resolver_error(label: str, exc: Exception) -> LLMResponse: + logger.warning("Failed to resolve fallback model {}: {}", label, exc) return LLMResponse( - content=f"Error configuring fallback model {candidate.label}: {exc}", + content=f"Error configuring fallback model {label}: {exc}", finish_reason="error", error_kind="configuration", error_should_retry=False, ) - def _candidate_chain(self) -> list[ModelCandidate]: - return [ - ModelCandidate( - label=self.primary_model, - resolver=lambda: (self.primary_provider, self.primary_model), - ), - *self.fallback_candidates, - ] - async def _route( self, call: Callable[[LLMProvider, str, 
Callable[[str], Awaitable[None]] | None], Awaitable[LLMResponse]], *, on_content_delta: Callable[[str], Awaitable[None]] | None = None, ) -> LLMResponse: - last_response: LLMResponse | None = None - chain = self._candidate_chain() - for index, candidate in enumerate(chain): + """Try primary then each fallback candidate, lazily resolving providers.""" + + async def _try_one(label: str, provider: LLMProvider, model: str) -> LLMResponse: try: - provider, model = candidate.resolver() + return await self._with_timeout(call(provider, model, on_content_delta)) except asyncio.CancelledError: raise except Exception as exc: - response = self._resolver_error(candidate, exc) - else: - response = await self._with_timeout(call(provider, model, on_content_delta)) + return self._resolver_error(label, exc) + # Primary + response = await _try_one("primary", self.primary_provider, self.primary_model) + if response.finish_reason != "error": + return response + if not self._should_failover(response): + return response + + # Fallbacks + for name in self.fallback_presets: + try: + provider, model = self._resolve(name) + except Exception as exc: + logger.warning("Failed to resolve fallback model {}: {}", name, exc) + return self._resolver_error(name, exc) + + response = await _try_one(name, provider, model) if response.finish_reason != "error": - if index > 0: - logger.info("LLM failover selected model={}", candidate.label) + logger.info("LLM failover selected model={}", name) return response - - last_response = response if not self._should_failover(response): return response - if index + 1 >= len(chain): - logger.warning("LLM failover exhausted after model={}", candidate.label) - return response - logger.warning( - "LLM failover model={} next_model={} status={} kind={}", - candidate.label, - chain[index + 1].label, - response.error_status_code, - response.error_kind or response.error_type or response.error_code or "unknown", - ) - return last_response or LLMResponse( - content="No 
available fallback model candidate.", - finish_reason="error", - error_kind="configuration", - error_should_retry=False, - ) + logger.warning("LLM failover exhausted after all candidates") + return response - async def chat_with_retry( - self, - messages: list[dict[str, Any]], - tools: list[dict[str, Any]] | None = None, - model: str | None = None, - max_tokens: object = LLMProvider._SENTINEL, - temperature: object = LLMProvider._SENTINEL, - reasoning_effort: object = LLMProvider._SENTINEL, - tool_choice: str | dict[str, Any] | None = None, - retry_mode: str = "standard", - on_retry_wait: Callable[[str], Awaitable[None]] | None = None, - ) -> LLMResponse: + async def chat_with_retry(self, **kwargs: Any) -> LLMResponse: async def call( - provider: LLMProvider, - candidate_model: str, - _delta: Callable[[str], Awaitable[None]] | None, + provider: LLMProvider, candidate_model: str, _unused_delta: Any ) -> LLMResponse: return await provider.chat_with_retry( - messages=messages, - tools=tools, - model=candidate_model, - max_tokens=max_tokens, - temperature=temperature, - reasoning_effort=reasoning_effort, - tool_choice=tool_choice, - retry_mode=retry_mode, - on_retry_wait=on_retry_wait, + **{**kwargs, "model": candidate_model} ) - return await self._route(call) - async def chat_stream_with_retry( - self, - messages: list[dict[str, Any]], - tools: list[dict[str, Any]] | None = None, - model: str | None = None, - max_tokens: object = LLMProvider._SENTINEL, - temperature: object = LLMProvider._SENTINEL, - reasoning_effort: object = LLMProvider._SENTINEL, - tool_choice: str | dict[str, Any] | None = None, - on_content_delta: Callable[[str], Awaitable[None]] | None = None, - retry_mode: str = "standard", - on_retry_wait: Callable[[str], Awaitable[None]] | None = None, - ) -> LLMResponse: + async def chat_stream_with_retry(self, **kwargs: Any) -> LLMResponse: + on_content_delta = kwargs.pop("on_content_delta", None) + async def call( provider: LLMProvider, candidate_model: 
str, - external_delta: Callable[[str], Awaitable[None]] | None, + content_delta: Callable[[str], Awaitable[None]] | None, ) -> LLMResponse: buffered: list[str] = [] async def buffer_delta(delta: str) -> None: buffered.append(delta) + kwargs["on_content_delta"] = buffer_delta if content_delta else None response = await provider.chat_stream_with_retry( - messages=messages, - tools=tools, - model=candidate_model, - max_tokens=max_tokens, - temperature=temperature, - reasoning_effort=reasoning_effort, - tool_choice=tool_choice, - on_content_delta=buffer_delta if external_delta else None, - retry_mode=retry_mode, - on_retry_wait=on_retry_wait, + **{**kwargs, "model": candidate_model} ) - if response.finish_reason != "error" and external_delta: - for delta in buffered: - await external_delta(delta) + if response.finish_reason != "error" and content_delta: + try: + for delta in buffered: + await content_delta(delta) + except asyncio.CancelledError: + raise + except Exception: + logger.exception("Failover delta callback failed for model={}", candidate_model) return response return await self._route(call, on_content_delta=on_content_delta) diff --git a/tests/agent/test_onboard_logic.py b/tests/agent/test_onboard_logic.py index f192cacee..4e9ad5e64 100644 --- a/tests/agent/test_onboard_logic.py +++ b/tests/agent/test_onboard_logic.py @@ -8,7 +8,6 @@ from pathlib import Path from types import SimpleNamespace from typing import Any, cast -import pytest from pydantic import BaseModel, Field from nanobot.cli import onboard as onboard_wizard @@ -636,8 +635,8 @@ class TestValidateFieldConstraint: def test_real_send_max_retries_field(self): """Validate against the actual ChannelsConfig.send_max_retries field.""" - from nanobot.config.schema import ChannelsConfig from nanobot.cli.onboard import _validate_field_constraint + from nanobot.config.schema import ChannelsConfig field_info = ChannelsConfig.model_fields["send_max_retries"] assert _validate_field_constraint(3, field_info) is 
None @@ -829,12 +828,11 @@ class TestMainMenuUpdate: def test_main_menu_dispatch_includes_channel_common(self): """Main menu dispatch should route [H] to Channel Common.""" - from nanobot.cli.onboard import run_onboard # We verify by checking the dispatch table is set up correctly # The menu items are defined inline in run_onboard, so we test # that _configure_general_settings handles the new sections. - from nanobot.cli.onboard import _SETTINGS_SECTIONS, _SETTINGS_GETTER, _SETTINGS_SETTER + from nanobot.cli.onboard import _SETTINGS_GETTER, _SETTINGS_SECTIONS, _SETTINGS_SETTER assert "Channel Common" in _SETTINGS_SECTIONS assert "Channel Common" in _SETTINGS_GETTER @@ -842,7 +840,7 @@ class TestMainMenuUpdate: def test_main_menu_dispatch_includes_api_server(self): """Main menu dispatch should route [I] to API Server.""" - from nanobot.cli.onboard import _SETTINGS_SECTIONS, _SETTINGS_GETTER, _SETTINGS_SETTER + from nanobot.cli.onboard import _SETTINGS_GETTER, _SETTINGS_SECTIONS, _SETTINGS_SETTER assert "API Server" in _SETTINGS_SECTIONS assert "API Server" in _SETTINGS_GETTER @@ -1074,3 +1072,346 @@ class TestConfigurePydanticModelEmptyString: result = _configure_pydantic_model(model, "Test") assert result is not None assert result.api_key == "" + + +class TestModelPresetWizard: + """Tests for model preset CRUD in the onboard wizard.""" + + def test_sync_preset_cache(self): + """_sync_preset_cache should populate the module-level cache.""" + from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _sync_preset_cache + from nanobot.config.schema import ModelPresetConfig + + config = Config() + config.model_presets = { + "fast": ModelPresetConfig(model="gpt-4.1-mini"), + "power": ModelPresetConfig(model="gpt-4.1"), + } + _sync_preset_cache(config) + assert _MODEL_PRESET_CACHE == {"fast", "power"} + + def test_model_preset_add(self, monkeypatch): + """_configure_model_presets should add a new preset.""" + from nanobot.cli.onboard import _MODEL_PRESET_CACHE, 
_configure_model_presets + from nanobot.config.schema import ModelPresetConfig + + config = Config() + _MODEL_PRESET_CACHE.clear() + + responses = iter([ + "[+] Add new preset", + "my-preset", + "<- Back", + ]) + + class FakePrompt: + def __init__(self, response): + self.response = response + def ask(self): + if isinstance(self.response, BaseException): + raise self.response + return self.response + + def fake_select(*_args, **_kwargs): + return FakePrompt(next(responses)) + + def fake_text(*_args, **_kwargs): + return FakePrompt(next(responses)) + + def fake_configure(*_model, **_kwargs): + return ModelPresetConfig(model="gpt-test", temperature=0.5) + + # _select_with_back returns a string/sentinel directly (not a prompt object) + def fake_select_with_back(*_args, **_kwargs): + return next(responses) + + monkeypatch.setattr(onboard_wizard, "_select_with_back", fake_select_with_back) + monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select, text=fake_text)) + monkeypatch.setattr(onboard_wizard, "_configure_pydantic_model", fake_configure) + monkeypatch.setattr(onboard_wizard, "_show_section_header", lambda *a, **kw: None) + monkeypatch.setattr(onboard_wizard, "console", SimpleNamespace(clear=lambda: None)) + + _configure_model_presets(config) + + assert "my-preset" in config.model_presets + assert config.model_presets["my-preset"].model == "gpt-test" + assert config.model_presets["my-preset"].temperature == 0.5 + + def test_model_preset_delete(self, monkeypatch): + """_configure_model_presets should delete an existing preset.""" + from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _configure_model_presets + from nanobot.config.schema import ModelPresetConfig + + config = Config() + config.model_presets = {"old": ModelPresetConfig(model="x")} + _MODEL_PRESET_CACHE.clear() + _MODEL_PRESET_CACHE.add("old") + + responses = iter([ + "old (x)", + "Delete", + True, + "<- Back", + ]) + + class FakePrompt: + def __init__(self, response): + 
self.response = response + def ask(self): + if isinstance(self.response, BaseException): + raise self.response + return self.response + + def fake_select(*_args, **_kwargs): + return FakePrompt(next(responses)) + + def fake_confirm(*_args, **_kwargs): + return FakePrompt(next(responses)) + + def fake_select_with_back(*_args, **_kwargs): + return next(responses) + + monkeypatch.setattr(onboard_wizard, "_select_with_back", fake_select_with_back) + monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select, confirm=fake_confirm)) + monkeypatch.setattr(onboard_wizard, "_show_section_header", lambda *a, **kw: None) + monkeypatch.setattr(onboard_wizard, "console", SimpleNamespace(clear=lambda: None)) + + _configure_model_presets(config) + + assert "old" not in config.model_presets + assert "old" not in _MODEL_PRESET_CACHE + + def test_model_preset_field_handler(self, monkeypatch): + """_handle_model_preset_field should set a preset name from choices.""" + from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _handle_model_preset_field + from nanobot.config.schema import AgentDefaults + + _MODEL_PRESET_CACHE.clear() + _MODEL_PRESET_CACHE.update({"fast", "power"}) + + monkeypatch.setattr(onboard_wizard, "_select_with_back", lambda *a, **kw: "fast") + + defaults = AgentDefaults() + _handle_model_preset_field(defaults, "model_preset", "Model Preset", None) + assert defaults.model_preset == "fast" + + def test_model_preset_field_handler_clear(self, monkeypatch): + """_handle_model_preset_field should clear preset when (clear/unset) chosen.""" + from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _handle_model_preset_field + from nanobot.config.schema import AgentDefaults + + _MODEL_PRESET_CACHE.clear() + _MODEL_PRESET_CACHE.add("fast") + + monkeypatch.setattr(onboard_wizard, "_select_with_back", lambda *a, **kw: "(clear/unset)") + + defaults = AgentDefaults(model_preset="fast") + _handle_model_preset_field(defaults, "model_preset", "Model Preset", 
"fast") + assert defaults.model_preset is None + + def test_main_menu_dispatch_includes_model_presets(self): + """run_onboard dispatch should route [M] to Model Presets.""" + from nanobot.cli.onboard import _configure_model_presets + + # The function should be importable and callable + assert callable(_configure_model_presets) + + def test_run_onboard_model_presets_edit(self, monkeypatch): + """run_onboard should handle [M] Model Presets correctly.""" + initial_config = Config() + + responses = iter([ + "[M] Model Presets", + KeyboardInterrupt(), + "[S] Save and Exit", + ]) + + class FakePrompt: + def __init__(self, response): + self.response = response + def ask(self): + if isinstance(self.response, BaseException): + raise self.response + return self.response + + def fake_select(*_args, **_kwargs): + return FakePrompt(next(responses)) + + preset_mutated = {"n": 0} + + def fake_configure_model_presets(config): + preset_mutated["n"] += 1 + # Mutate config so unsaved changes are detected + from nanobot.config.schema import ModelPresetConfig + config.model_presets["test"] = ModelPresetConfig(model="x") + + monkeypatch.setattr(onboard_wizard, "_show_main_menu_header", lambda: None) + monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select)) + monkeypatch.setattr(onboard_wizard, "_configure_model_presets", fake_configure_model_presets) + + result = run_onboard(initial_config=initial_config) + + assert result.should_save is True + assert preset_mutated["n"] == 1 + + def test_summary_shows_model_presets(self, monkeypatch): + """_show_summary should include model presets panel.""" + from nanobot.cli.onboard import _show_summary + from nanobot.config.schema import ModelPresetConfig + + config = Config() + config.model_presets = { + "fast": ModelPresetConfig(model="gpt-4.1-mini"), + } + + panels = [] + + def fake_print_summary(rows, title): + panels.append(title) + + monkeypatch.setattr(onboard_wizard, "_print_summary_panel", fake_print_summary) 
+ monkeypatch.setattr(onboard_wizard, "_get_provider_names", lambda: {}) + monkeypatch.setattr(onboard_wizard, "_get_channel_names", lambda: {}) + monkeypatch.setattr(onboard_wizard, "_pause", lambda: None) + monkeypatch.setattr(onboard_wizard, "console", SimpleNamespace(print=lambda *a, **kw: None)) + + _show_summary(config) + + assert "Model Presets" in panels + + def test_provider_field_handler(self, monkeypatch): + """_handle_provider_field should set a provider from the registry list.""" + from nanobot.cli.onboard import _handle_provider_field + from nanobot.config.schema import ModelPresetConfig + + monkeypatch.setattr( + onboard_wizard, "_get_provider_names", lambda: {"moonshot": "Moonshot", "openai": "OpenAI"} + ) + monkeypatch.setattr(onboard_wizard, "_select_with_back", lambda *a, **kw: "moonshot") + + preset = ModelPresetConfig(model="x") + _handle_provider_field(preset, "provider", "Provider", "auto") + assert preset.provider == "moonshot" + + def test_provider_field_handler_back_pressed(self, monkeypatch): + """_handle_provider_field should not modify value when back is pressed.""" + from nanobot.cli.onboard import _BACK_PRESSED, _handle_provider_field + from nanobot.config.schema import ModelPresetConfig + + monkeypatch.setattr( + onboard_wizard, "_get_provider_names", lambda: {"moonshot": "Moonshot"} + ) + monkeypatch.setattr(onboard_wizard, "_select_with_back", lambda *a, **kw: _BACK_PRESSED) + + preset = ModelPresetConfig(model="x", provider="auto") + _handle_provider_field(preset, "provider", "Provider", "auto") + assert preset.provider == "auto" + + def test_fallback_presets_add_preset_and_done(self, monkeypatch): + """_handle_fallback_presets_field should add a preset and save on Done.""" + from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _handle_fallback_presets_field + from nanobot.config.schema import AgentDefaults + + _MODEL_PRESET_CACHE.clear() + _MODEL_PRESET_CACHE.update({"fast", "power"}) + + responses = iter(["[+] Add preset", 
"[Done]"]) + + class FakePrompt: + def __init__(self, response): + self.response = response + def ask(self): + if isinstance(self.response, BaseException): + raise self.response + return self.response + + def fake_select(*_args, **_kwargs): + return FakePrompt(next(responses)) + + monkeypatch.setattr(onboard_wizard, "_select_with_back", lambda *a, **kw: "fast") + monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select)) + monkeypatch.setattr(onboard_wizard, "console", SimpleNamespace(clear=lambda: None, print=lambda *a, **kw: None)) + + defaults = AgentDefaults() + _handle_fallback_presets_field(defaults, "fallback_presets", "Fallback Presets", []) + assert defaults.fallback_presets == ["fast"] + + def test_fallback_presets_back_preserves_existing(self, monkeypatch): + """_handle_fallback_presets_field should not modify value on Back.""" + from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _handle_fallback_presets_field + from nanobot.config.schema import AgentDefaults + + _MODEL_PRESET_CACHE.clear() + _MODEL_PRESET_CACHE.add("fast") + + class FakePrompt: + def __init__(self, response): + self.response = response + def ask(self): + if isinstance(self.response, BaseException): + raise self.response + return self.response + + def fake_select(*_args, **_kwargs): + return FakePrompt("<- Back") + + monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select)) + monkeypatch.setattr(onboard_wizard, "console", SimpleNamespace(clear=lambda: None, print=lambda *a, **kw: None)) + + defaults = AgentDefaults(fallback_presets=["existing"]) + _handle_fallback_presets_field(defaults, "fallback_presets", "Fallback Presets", ["existing"]) + assert defaults.fallback_presets == ["existing"] + + def test_fallback_presets_remove_last(self, monkeypatch): + """_handle_fallback_presets_field should remove last item.""" + from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _handle_fallback_presets_field + from nanobot.config.schema 
import AgentDefaults + + _MODEL_PRESET_CACHE.clear() + + responses = iter(["[-] Remove last", "[Done]"]) + + class FakePrompt: + def __init__(self, response): + self.response = response + def ask(self): + if isinstance(self.response, BaseException): + raise self.response + return self.response + + def fake_select(*_args, **_kwargs): + return FakePrompt(next(responses)) + + monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select)) + monkeypatch.setattr(onboard_wizard, "console", SimpleNamespace(clear=lambda: None, print=lambda *a, **kw: None)) + + defaults = AgentDefaults(fallback_presets=["a", "b"]) + _handle_fallback_presets_field(defaults, "fallback_presets", "Fallback Presets", ["a", "b"]) + assert defaults.fallback_presets == ["a"] + + def test_fallback_presets_no_presets_shows_warning(self, monkeypatch): + """_handle_fallback_presets_field should warn when no presets exist.""" + from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _handle_fallback_presets_field + from nanobot.config.schema import AgentDefaults + + _MODEL_PRESET_CACHE.clear() + + responses = iter(["[+] Add preset", "[Done]"]) + + class FakePrompt: + def __init__(self, response): + self.response = response + def ask(self): + if isinstance(self.response, BaseException): + raise self.response + return self.response + + def fake_select(*_args, **_kwargs): + return FakePrompt(next(responses)) + + monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select, press_any_key_to_continue=lambda: FakePrompt(None))) + monkeypatch.setattr(onboard_wizard, "console", SimpleNamespace(clear=lambda: None, print=lambda *a, **kw: None)) + + defaults = AgentDefaults() + _handle_fallback_presets_field(defaults, "fallback_presets", "Fallback Presets", []) + assert defaults.fallback_presets == [] diff --git a/tests/agent/test_runner_fallback.py b/tests/agent/test_runner_fallback.py deleted file mode 100644 index 3ade85041..000000000 --- 
a/tests/agent/test_runner_fallback.py +++ /dev/null @@ -1,269 +0,0 @@ -"""Tests for the provider fallback models feature in AgentRunner.""" - -from __future__ import annotations - -from unittest.mock import AsyncMock, MagicMock - -import pytest - -from nanobot.agent.hook import AgentHook, AgentHookContext -from nanobot.agent.runner import AgentRunner, AgentRunSpec -from nanobot.providers.base import LLMResponse - - -def _make_tools(): - tools = MagicMock() - tools.get_definitions.return_value = [] - tools.execute = AsyncMock(return_value="ok") - return tools - - -def _make_provider(*, model_response: LLMResponse | None = None): - p = MagicMock() - if model_response is not None: - p.chat_with_retry = AsyncMock(return_value=model_response) - return p - - -def _transient_error(content: str = "server unavailable") -> LLMResponse: - return LLMResponse(content=content, finish_reason="error", error_status_code=503) - - -def _base_spec(**overrides) -> AgentRunSpec: - defaults = dict( - initial_messages=[ - {"role": "system", "content": "sys"}, - {"role": "user", "content": "hello"}, - ], - tools=_make_tools(), - model="primary-model", - max_iterations=1, - max_tool_result_chars=8000, - ) - defaults.update(overrides) - return AgentRunSpec(**defaults) - - -@pytest.mark.asyncio -async def test_no_fallback_when_primary_succeeds(): - """Primary succeeds -> fallback list never consulted.""" - ok = LLMResponse(content="done", tool_calls=[], usage={}) - provider = _make_provider(model_response=ok) - factory = MagicMock() - - runner = AgentRunner(provider, provider_factory=factory) - result = await runner.run(_base_spec(fallback_models=["fb-1", "fb-2"])) - - assert result.final_content == "done" - factory.assert_not_called() - - -@pytest.mark.asyncio -async def test_fallback_triggered_on_primary_error(): - """Primary fails -> first fallback succeeds.""" - err = _transient_error() - ok = LLMResponse(content="fallback-ok", tool_calls=[], usage={}) - - primary = 
_make_provider(model_response=err) - - fb_provider = MagicMock() - fb_provider.chat_with_retry = AsyncMock(return_value=ok) - factory = MagicMock(return_value=fb_provider) - - runner = AgentRunner(primary, provider_factory=factory) - result = await runner.run(_base_spec(fallback_models=["fb-model"])) - - assert result.final_content == "fallback-ok" - factory.assert_called_once_with("fb-model") - fb_provider.chat_with_retry.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_all_fallbacks_fail_returns_last_error(): - """Primary + all fallbacks fail -> return last error response.""" - err = _transient_error() - - primary = _make_provider(model_response=err) - fb1 = _make_provider(model_response=err) - fb2 = _make_provider(model_response=LLMResponse( - content="last-error", finish_reason="error", error_status_code=500, usage={}, - )) - - providers = {"fb-1": fb1, "fb-2": fb2} - factory = MagicMock(side_effect=lambda m: providers[m]) - - runner = AgentRunner(primary, provider_factory=factory) - result = await runner.run(_base_spec(fallback_models=["fb-1", "fb-2"])) - - assert result.error is not None or result.final_content is not None - - -@pytest.mark.asyncio -async def test_empty_fallback_list_no_retry(): - """Empty fallback_models -> no fallback attempted.""" - err = LLMResponse(content=None, finish_reason="error", usage={}) - primary = _make_provider(model_response=err) - factory = MagicMock() - - runner = AgentRunner(primary, provider_factory=factory) - result = await runner.run(_base_spec(fallback_models=[])) - - factory.assert_not_called() - assert result.error is not None - - -@pytest.mark.asyncio -async def test_cross_provider_fallback(): - """Fallback uses a different provider instance (cross-provider).""" - err = _transient_error() - ok = LLMResponse(content="cross-provider-ok", tool_calls=[], usage={}) - - primary = _make_provider(model_response=err) - anthropic_provider = MagicMock() - anthropic_provider.chat_with_retry = 
AsyncMock(return_value=ok) - - def cross_factory(model: str): - if model == "anthropic/claude-sonnet": - return anthropic_provider - raise ValueError(f"unexpected model: {model}") - - runner = AgentRunner(primary, provider_factory=cross_factory) - result = await runner.run(_base_spec(fallback_models=["anthropic/claude-sonnet"])) - - assert result.final_content == "cross-provider-ok" - anthropic_provider.chat_with_retry.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_fallback_skips_to_second_on_first_error(): - """First fallback also fails -> second fallback succeeds.""" - err = _transient_error() - ok = LLMResponse(content="second-fb-ok", tool_calls=[], usage={}) - - primary = _make_provider(model_response=err) - fb1 = _make_provider(model_response=err) - fb2 = MagicMock() - fb2.chat_with_retry = AsyncMock(return_value=ok) - - providers = {"fb-1": fb1, "fb-2": fb2} - factory = MagicMock(side_effect=lambda m: providers[m]) - - runner = AgentRunner(primary, provider_factory=factory) - result = await runner.run(_base_spec(fallback_models=["fb-1", "fb-2"])) - - assert result.final_content == "second-fb-ok" - assert factory.call_count == 2 - - -@pytest.mark.asyncio -async def test_fallback_reuses_same_provider_without_factory(): - """No provider_factory -> fallback reuses primary provider with different model.""" - call_count = {"n": 0} - - async def chat_with_retry(*, messages, model, **kw): - call_count["n"] += 1 - if call_count["n"] == 1: - return _transient_error() - return LLMResponse(content=f"ok-via-{model}", tool_calls=[], usage={}) - - primary = MagicMock() - primary.chat_with_retry = chat_with_retry - - runner = AgentRunner(primary, provider_factory=None) - result = await runner.run(_base_spec(fallback_models=["fallback-model"])) - - assert result.final_content == "ok-via-fallback-model" - - -@pytest.mark.asyncio -async def test_fallback_provider_cached(): - """Provider factory is called once per unique provider, not per attempt.""" - err = 
_transient_error() - ok = LLMResponse(content="cached-ok", tool_calls=[], usage={}) - - primary = _make_provider(model_response=err) - - fb_provider = MagicMock() - call_seq = [err, ok] - fb_provider.chat_with_retry = AsyncMock(side_effect=call_seq) - - factory = MagicMock(return_value=fb_provider) - - runner = AgentRunner(primary, provider_factory=factory) - result = await runner.run(_base_spec(fallback_models=["same-provider-model-a", "same-provider-model-b"])) - - assert result.final_content == "cached-ok" - - -@pytest.mark.asyncio -async def test_non_transient_error_does_not_fallback(): - """Auth/config-style errors should surface instead of hiding bugs via fallback.""" - primary = _make_provider(model_response=LLMResponse( - content="401 unauthorized", - finish_reason="error", - error_status_code=401, - )) - fallback = _make_provider(model_response=LLMResponse(content="fallback-ok")) - factory = MagicMock(return_value=fallback) - - runner = AgentRunner(primary, provider_factory=factory) - result = await runner.run(_base_spec(fallback_models=["fb-model"])) - - factory.assert_not_called() - assert result.error is not None - - -@pytest.mark.asyncio -async def test_quota_error_does_not_fallback_by_default(): - """Quota/billing/payment 429s should not route by default.""" - primary = _make_provider(model_response=LLMResponse( - content="insufficient quota", - finish_reason="error", - error_status_code=429, - error_code="insufficient_quota", - )) - fallback = _make_provider(model_response=LLMResponse(content="fallback-ok")) - factory = MagicMock(return_value=fallback) - - runner = AgentRunner(primary, provider_factory=factory) - result = await runner.run(_base_spec(fallback_models=["fb-model"])) - - factory.assert_not_called() - assert result.error is not None - - -@pytest.mark.asyncio -async def test_streaming_fallback_discards_failed_primary_deltas(): - """Buffered streaming prevents primary partial output from leaking on fallback.""" - streamed: list[str] = [] - 
- class StreamingHook(AgentHook): - def wants_streaming(self) -> bool: - return True - - async def on_stream(self, context: AgentHookContext, delta: str) -> None: - streamed.append(delta) - - async def primary_stream(*, on_content_delta, **kwargs): - await on_content_delta("bad partial") - return _transient_error() - - async def fallback_stream(*, on_content_delta, **kwargs): - await on_content_delta("good") - await on_content_delta(" answer") - return LLMResponse(content="good answer", tool_calls=[], usage={}) - - primary = MagicMock() - primary.chat_stream_with_retry = primary_stream - fallback = MagicMock() - fallback.chat_stream_with_retry = fallback_stream - factory = MagicMock(return_value=fallback) - - runner = AgentRunner(primary, provider_factory=factory) - result = await runner.run(_base_spec( - fallback_models=["fb-model"], - hook=StreamingHook(), - )) - - assert result.final_content == "good answer" - assert streamed == ["good", " answer"] diff --git a/tests/agent/test_self_model_preset.py b/tests/agent/test_self_model_preset.py index 7383b40b6..602499667 100644 --- a/tests/agent/test_self_model_preset.py +++ b/tests/agent/test_self_model_preset.py @@ -1,6 +1,6 @@ # tests/agent/test_self_model_preset.py -import asyncio from pathlib import Path +from typing import Any from unittest.mock import MagicMock from nanobot.agent.loop import AgentLoop @@ -8,10 +8,23 @@ from nanobot.config.schema import ModelPresetConfig, MyToolConfig, ToolsConfig from nanobot.providers.base import GenerationSettings -def _make_loop(presets: dict | None = None) -> tuple[AgentLoop, "MyTool"]: +def _make_loop(presets: dict | None = None) -> tuple[AgentLoop, Any]: provider = MagicMock() provider.get_default_model.return_value = "test-model" provider.generation = GenerationSettings(temperature=0.1, max_tokens=8192) + + def _factory(name: str): + preset = (presets or {}).get(name) + if preset: + new_provider = MagicMock() + new_provider.generation = GenerationSettings( + 
temperature=preset.temperature, + max_tokens=preset.max_tokens, + reasoning_effort=preset.reasoning_effort, + ) + return new_provider + return provider + loop = AgentLoop( bus=MagicMock(), provider=provider, @@ -19,6 +32,7 @@ def _make_loop(presets: dict | None = None) -> tuple[AgentLoop, "MyTool"]: model="test-model", context_window_tokens=65536, model_presets=presets or {}, + provider_factory=_factory, tools_config=ToolsConfig(my=MyToolConfig(allow_set=True)), ) tool = loop.tools.get("my") @@ -36,7 +50,7 @@ async def test_set_model_preset_updates_all_fields() -> None: ), } loop, tool = _make_loop(presets) - result = await tool.execute(action="set", key="model_preset", value="gpt5") + await tool.execute(action="set", key="model_preset", value="gpt5") assert loop.model == "gpt-5" assert loop.context_window_tokens == 128000 @@ -73,12 +87,3 @@ async def test_check_model_presets_shows_available() -> None: assert "ds" in result -async def test_set_model_directly_clears_preset() -> None: - presets = {"gpt5": ModelPresetConfig(model="gpt-5", provider="openai")} - loop, tool = _make_loop(presets) - await tool.execute(action="set", key="model_preset", value="gpt5") - assert loop._active_preset == "gpt5" - - await tool.execute(action="set", key="model", value="other-model") - assert loop._active_preset is None - assert loop.model == "other-model" diff --git a/tests/agent/tools/test_self_tool.py b/tests/agent/tools/test_self_tool.py index 19b1639d0..89c463fc4 100644 --- a/tests/agent/tools/test_self_tool.py +++ b/tests/agent/tools/test_self_tool.py @@ -4,7 +4,7 @@ from __future__ import annotations import time from pathlib import Path -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock import pytest from pydantic import BaseModel @@ -35,6 +35,7 @@ def _make_mock_loop(**overrides): loop._concurrency_gate = None loop._unified_session = False loop._extra_hooks = [] + loop.model_preset = None # web_config mock — needed for check tests 
loop.web_config = MagicMock() @@ -76,7 +77,7 @@ class TestInspectSummary: tool = _make_tool() result = await tool.execute(action="check") assert "max_iterations: 40" in result - assert "context_window_tokens: 65536" in result + assert "model_preset" in result @pytest.mark.asyncio async def test_inspect_includes_runtime_vars(self): @@ -92,8 +93,7 @@ class TestInspectSummary: tool = _make_tool() result = await tool.execute(action="check") assert "max_iterations" in result - assert "context_window_tokens" in result - assert "model" in result + assert "model_preset" in result assert "workspace" in result assert "provider_retry_mode" in result assert "max_tool_result_chars" in result @@ -231,13 +231,13 @@ class TestModifyRestricted: @pytest.mark.asyncio async def test_modify_string_int_coerced(self): tool = _make_tool() - result = await tool.execute(action="set", key="max_iterations", value="80") + await tool.execute(action="set", key="max_iterations", value="80") assert tool._loop.max_iterations == 80 @pytest.mark.asyncio async def test_modify_context_window_valid(self): tool = _make_tool() - result = await tool.execute(action="set", key="context_window_tokens", value=131072) + await tool.execute(action="set", key="context_window_tokens", value=131072) assert tool._loop.context_window_tokens == 131072 @pytest.mark.asyncio @@ -337,13 +337,13 @@ class TestModifyFree: @pytest.mark.asyncio async def test_modify_allows_list(self): tool = _make_tool() - result = await tool.execute(action="set", key="items", value=[1, 2, 3]) + await tool.execute(action="set", key="items", value=[1, 2, 3]) assert tool._loop._runtime_vars["items"] == [1, 2, 3] @pytest.mark.asyncio async def test_modify_allows_dict(self): tool = _make_tool() - result = await tool.execute(action="set", key="data", value={"a": 1}) + await tool.execute(action="set", key="data", value={"a": 1}) assert tool._loop._runtime_vars["data"] == {"a": 1} @pytest.mark.asyncio @@ -392,6 +392,26 @@ class TestModifyFree: assert 
"Error" in result assert tool._loop.max_tool_result_chars == 16000 + @pytest.mark.asyncio + async def test_modify_model_clears_active_preset(self): + """Directly modifying model must clear _active_preset so state stays consistent.""" + tool = _make_tool() + tool._loop._active_preset = "gpt5" + result = await tool.execute(action="set", key="model", value="other-model") + assert "Set model" in result + assert tool._loop.model == "other-model" + assert tool._loop._active_preset is None + + @pytest.mark.asyncio + async def test_modify_context_window_tokens_clears_active_preset(self): + """Directly modifying context_window_tokens must clear _active_preset.""" + tool = _make_tool() + tool._loop._active_preset = "gpt5" + result = await tool.execute(action="set", key="context_window_tokens", value=32768) + assert "Set context_window_tokens" in result + assert tool._loop.context_window_tokens == 32768 + assert tool._loop._active_preset is None + # --------------------------------------------------------------------------- # set — previously BLOCKED/READONLY now open @@ -689,8 +709,8 @@ class TestSubagentHookStatus: @pytest.mark.asyncio async def test_after_iteration_updates_status(self): """after_iteration should copy iteration, tool_events, usage to status.""" - from nanobot.agent.subagent import SubagentStatus, _SubagentHook from nanobot.agent.hook import AgentHookContext + from nanobot.agent.subagent import SubagentStatus, _SubagentHook status = SubagentStatus( task_id="test", @@ -716,8 +736,8 @@ class TestSubagentHookStatus: @pytest.mark.asyncio async def test_after_iteration_with_error(self): """after_iteration should set status.error when context has an error.""" - from nanobot.agent.subagent import SubagentStatus, _SubagentHook from nanobot.agent.hook import AgentHookContext + from nanobot.agent.subagent import SubagentStatus, _SubagentHook status = SubagentStatus( task_id="test", @@ -739,8 +759,8 @@ class TestSubagentHookStatus: @pytest.mark.asyncio async def 
test_after_iteration_no_status_is_noop(self): """after_iteration with no status should be a no-op.""" - from nanobot.agent.subagent import _SubagentHook from nanobot.agent.hook import AgentHookContext + from nanobot.agent.subagent import _SubagentHook hook = _SubagentHook("test") context = AgentHookContext(iteration=1, messages=[]) @@ -757,7 +777,6 @@ class TestCheckpointCallback: async def test_checkpoint_updates_phase_and_iteration(self): """The _on_checkpoint callback should update status.phase and iteration.""" from nanobot.agent.subagent import SubagentStatus - import asyncio status = SubagentStatus( task_id="cp", diff --git a/tests/cli/test_commands.py b/tests/cli/test_commands.py index d217c5f03..478cb2d41 100644 --- a/tests/cli/test_commands.py +++ b/tests/cli/test_commands.py @@ -9,7 +9,7 @@ import pytest from typer.testing import CliRunner from nanobot.bus.events import OutboundMessage -from nanobot.cli.commands import _make_provider, app +from nanobot.cli.commands import app from nanobot.config.schema import Config from nanobot.cron.types import CronJob, CronPayload from nanobot.providers.factory import ProviderSnapshot @@ -488,8 +488,8 @@ def test_openai_compat_provider_passes_model_through(): def test_make_provider_uses_github_copilot_backend(): - from nanobot.cli.commands import _make_provider from nanobot.config.schema import Config + from nanobot.providers.factory import build_provider_for_preset config = Config.model_validate( { @@ -503,7 +503,7 @@ def test_make_provider_uses_github_copilot_backend(): ) with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): - provider = _make_provider(config) + provider = build_provider_for_preset(config, config.resolve_preset()) assert provider.__class__.__name__ == "GitHubCopilotProvider" @@ -562,6 +562,8 @@ def test_openai_codex_strip_prefix_supports_hyphen_and_underscore(): def test_make_provider_passes_extra_headers_to_custom_provider(): + from nanobot.providers.factory import 
build_provider_for_preset + config = Config.model_validate( { "agents": {"defaults": {"provider": "custom", "model": "gpt-4o-mini"}}, @@ -579,7 +581,7 @@ def test_make_provider_passes_extra_headers_to_custom_provider(): ) with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_async_openai: - _make_provider(config) + build_provider_for_preset(config, config.resolve_preset()) kwargs = mock_async_openai.call_args.kwargs assert kwargs["api_key"] == "test-key" @@ -597,11 +599,11 @@ def mock_agent_runtime(tmp_path): with patch("nanobot.config.loader.load_config", return_value=config) as mock_load_config, \ patch("nanobot.config.loader.resolve_config_env_vars", side_effect=lambda c: c), \ patch("nanobot.cli.commands.sync_workspace_templates") as mock_sync_templates, \ - patch("nanobot.cli.commands._make_provider", return_value=object()), \ + patch("nanobot.providers.factory.build_provider_for_preset", return_value=MagicMock(generation=MagicMock(max_tokens=8192))), \ patch("nanobot.cli.commands._print_agent_response") as mock_print_response, \ patch("nanobot.bus.queue.MessageBus"), \ patch("nanobot.cron.service.CronService"), \ - patch("nanobot.agent.loop.AgentLoop") as mock_agent_loop_cls: + patch("nanobot.cli.commands.AgentLoop") as mock_agent_loop_cls: agent_loop = MagicMock() agent_loop.channels_config = None agent_loop.process_direct = AsyncMock( @@ -609,6 +611,7 @@ def mock_agent_runtime(tmp_path): ) agent_loop.close_mcp = AsyncMock(return_value=None) mock_agent_loop_cls.return_value = agent_loop + mock_agent_loop_cls.from_config.return_value = agent_loop yield { "config": config, @@ -639,7 +642,7 @@ def test_agent_uses_default_config_when_no_workspace_or_config_flags(mock_agent_ assert mock_agent_runtime["sync_templates"].call_args.args == ( mock_agent_runtime["config"].workspace_path, ) - assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == ( + assert 
mock_agent_runtime["agent_loop_cls"].from_config.call_args.args[0].workspace_path == ( mock_agent_runtime["config"].workspace_path ) mock_agent_runtime["agent_loop"].process_direct.assert_awaited_once() @@ -672,7 +675,7 @@ def test_agent_config_sets_active_path(monkeypatch, tmp_path: Path) -> None: ) monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) - monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object()) + monkeypatch.setattr("nanobot.providers.factory.build_provider_for_preset", lambda *a, **k: MagicMock(generation=MagicMock(max_tokens=8192))) monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object()) monkeypatch.setattr("nanobot.cron.service.CronService", lambda _store: object()) @@ -680,13 +683,17 @@ def test_agent_config_sets_active_path(monkeypatch, tmp_path: Path) -> None: def __init__(self, *args, **kwargs) -> None: pass + @classmethod + def from_config(cls, *args, **kwargs): + return cls(*args, **kwargs) + async def process_direct(self, *_args, **_kwargs): return OutboundMessage(channel="cli", chat_id="direct", content="ok") async def close_mcp(self) -> None: return None - monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop) + monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop) monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None) result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)]) @@ -707,7 +714,7 @@ def test_agent_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: Pa monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) - monkeypatch.setattr("nanobot.cli.commands._make_provider", 
lambda _config: object()) + monkeypatch.setattr("nanobot.providers.factory.build_provider_for_preset", lambda *a, **k: MagicMock(generation=MagicMock(max_tokens=8192))) monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object()) class _FakeCron: @@ -718,6 +725,10 @@ def test_agent_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: Pa def __init__(self, *args, **kwargs) -> None: pass + @classmethod + def from_config(cls, *args, **kwargs): + return cls(*args, **kwargs) + async def process_direct(self, *_args, **_kwargs): return OutboundMessage(channel="cli", chat_id="direct", content="ok") @@ -725,7 +736,7 @@ def test_agent_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: Pa return None monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron) - monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop) + monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop) monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None) result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)]) @@ -753,7 +764,7 @@ def test_agent_workspace_override_does_not_migrate_legacy_cron( monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) - monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object()) + monkeypatch.setattr("nanobot.providers.factory.build_provider_for_preset", lambda *a, **k: MagicMock(generation=MagicMock(max_tokens=8192))) monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object()) monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir) @@ -765,6 +776,10 @@ def test_agent_workspace_override_does_not_migrate_legacy_cron( def __init__(self, *args, **kwargs) -> None: pass + @classmethod + def 
from_config(cls, *args, **kwargs): + return cls(*args, **kwargs) + async def process_direct(self, *_args, **_kwargs): return OutboundMessage(channel="cli", chat_id="direct", content="ok") @@ -772,7 +787,7 @@ def test_agent_workspace_override_does_not_migrate_legacy_cron( return None monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron) - monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop) + monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop) monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None) result = runner.invoke( @@ -806,7 +821,7 @@ def test_agent_custom_config_workspace_does_not_migrate_legacy_cron( monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) - monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object()) + monkeypatch.setattr("nanobot.providers.factory.build_provider_for_preset", lambda *a, **k: MagicMock(generation=MagicMock(max_tokens=8192))) monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object()) monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir) @@ -818,6 +833,10 @@ def test_agent_custom_config_workspace_does_not_migrate_legacy_cron( def __init__(self, *args, **kwargs) -> None: pass + @classmethod + def from_config(cls, *args, **kwargs): + return cls(*args, **kwargs) + async def process_direct(self, *_args, **_kwargs): return OutboundMessage(channel="cli", chat_id="direct", content="ok") @@ -825,7 +844,7 @@ def test_agent_custom_config_workspace_does_not_migrate_legacy_cron( return None monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron) - monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop) + monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop) 
monkeypatch.setattr( "nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None ) @@ -846,7 +865,7 @@ def test_agent_overrides_workspace_path(mock_agent_runtime): assert result.exit_code == 0 assert mock_agent_runtime["config"].agents.defaults.workspace == str(workspace_path) assert mock_agent_runtime["sync_templates"].call_args.args == (workspace_path,) - assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path + assert mock_agent_runtime["agent_loop_cls"].from_config.call_args.args[0].workspace_path == workspace_path def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime, tmp_path: Path): @@ -863,7 +882,7 @@ def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime, assert mock_agent_runtime["load_config"].call_args.args == (config_path.resolve(),) assert mock_agent_runtime["config"].agents.defaults.workspace == str(workspace_path) assert mock_agent_runtime["sync_templates"].call_args.args == (workspace_path,) - assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path + assert mock_agent_runtime["agent_loop_cls"].from_config.call_args.args[0].workspace_path == workspace_path def test_agent_hints_about_deprecated_memory_window(mock_agent_runtime, tmp_path): @@ -928,8 +947,8 @@ def _patch_cli_command_runtime( sync_templates or (lambda _path: None), ) monkeypatch.setattr( - "nanobot.cli.commands._make_provider", - provider_factory, + "nanobot.providers.factory.build_provider_for_preset", + lambda *_a, **_k: provider_factory(Config()), ) monkeypatch.setattr( "nanobot.providers.factory.build_provider_snapshot", @@ -962,6 +981,10 @@ def _patch_serve_runtime(monkeypatch, config: Config, seen: dict[str, object]) - def __init__(self, **kwargs) -> None: seen["workspace"] = kwargs["workspace"] + @classmethod + def from_config(cls, config, bus=None, **kwargs): + return cls(workspace=config.workspace_path, **kwargs) + async def 
_connect_mcp(self) -> None: return None @@ -985,7 +1008,7 @@ def _patch_serve_runtime(monkeypatch, config: Config, seen: dict[str, object]) - message_bus=lambda: object(), session_manager=lambda _workspace: object(), ) - monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop) + monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop) monkeypatch.setattr("nanobot.api.server.create_app", _fake_create_app) monkeypatch.setattr("aiohttp.web.run_app", _fake_run_app) @@ -1077,7 +1100,7 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context( monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) - monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: provider) + monkeypatch.setattr("nanobot.providers.factory.build_provider_for_preset", lambda *_a, **_k: provider) monkeypatch.setattr( "nanobot.providers.factory.build_provider_snapshot", lambda _config: _test_provider_snapshot(provider, _config), @@ -1117,8 +1140,13 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context( class _FakeAgentLoop: def __init__(self, *args, **kwargs) -> None: self.model = "test-model" + self.provider = object() self.tools = {} + @classmethod + def from_config(cls, *args, **kwargs): + return cls(*args, **kwargs) + async def process_direct(self, *_args, **_kwargs): return OutboundMessage( channel="telegram", @@ -1152,7 +1180,7 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context( return True monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron) - monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop) + monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop) monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _StopAfterCronSetup) monkeypatch.setattr( 
"nanobot.utils.evaluator.evaluate_response", @@ -1181,7 +1209,7 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context( assert response == "Time to stretch." assert seen["response"] == "Time to stretch." - assert seen["provider"] is provider + assert seen["provider"] is not None # provider resolved inside AgentLoop assert seen["model"] == "test-model" assert seen["task_context"] == ( "The scheduled time has arrived. Deliver this reminder to the user now, " @@ -1228,7 +1256,7 @@ def test_gateway_cron_job_suppresses_intermediate_progress( monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) - monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object()) + monkeypatch.setattr("nanobot.providers.factory.build_provider_for_preset", lambda *a, **k: MagicMock(generation=MagicMock(max_tokens=8192))) monkeypatch.setattr( "nanobot.providers.factory.build_provider_snapshot", lambda _config: _test_provider_snapshot(object(), _config), @@ -1248,8 +1276,13 @@ def test_gateway_cron_job_suppresses_intermediate_progress( class _FakeAgentLoop: def __init__(self, *args, **kwargs) -> None: self.model = "test-model" + self.provider = object() self.tools = {} + @classmethod + def from_config(cls, *args, **kwargs): + return cls(*args, **kwargs) + async def process_direct(self, *_args, on_progress=None, **_kwargs): seen["on_progress"] = on_progress return OutboundMessage( @@ -1275,7 +1308,7 @@ def test_gateway_cron_job_suppresses_intermediate_progress( return False monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron) - monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop) + monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop) monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _StopAfterCronSetup) 
monkeypatch.setattr( "nanobot.utils.evaluator.evaluate_response", @@ -1480,9 +1513,14 @@ def test_gateway_health_endpoint_binds_and_serves_expected_responses( class _FakeAgentLoop: def __init__(self, **_kwargs) -> None: self.model = "test-model" + self.provider = object() self.dream = _FakeDream() self.sessions = _FakeSessionManager() + @classmethod + def from_config(cls, *args, **kwargs): + return cls(**kwargs) + async def run(self) -> None: await asyncio.Event().wait() @@ -1571,7 +1609,7 @@ def test_gateway_health_endpoint_binds_and_serves_expected_responses( message_bus=lambda: object(), session_manager=lambda _workspace: object(), ) - monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop) + monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop) monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _FakeChannelManager) monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCronService) monkeypatch.setattr("nanobot.heartbeat.service.HeartbeatService", _FakeHeartbeatService) diff --git a/tests/config/test_model_presets.py b/tests/config/test_model_presets.py index 0546a5834..a0b6f2c85 100644 --- a/tests/config/test_model_presets.py +++ b/tests/config/test_model_presets.py @@ -99,11 +99,23 @@ def test_preset_not_found_raises_error() -> None: }) +def test_fallback_presets_invalid_preset_raises_error() -> None: + import pytest + with pytest.raises(Exception, match="fallback_presets.*not found"): + Config.model_validate({ + "model_presets": { + "valid": {"model": "gpt-4"}, + }, + "agents": {"defaults": {"fallback_presets": ["invalid_preset"]}}, + }) + + def test_resolve_preset_without_preset_returns_defaults() -> None: - """Backward compat: no preset → resolve_preset returns individual field values.""" + """Backward compat: no explicit preset → resolve_preset returns the auto-created 'default' preset.""" cfg = Config.model_validate({ "agents": {"defaults": {"model": "deepseek-chat"}}, }) + assert 
cfg.agents.defaults.model_preset == "default" r = cfg.resolve_preset() assert r.model == "deepseek-chat" assert r.max_tokens == 8192 @@ -170,13 +182,14 @@ def test_preset_with_auto_provider_uses_keyword_matching() -> None: def test_backward_compat_no_preset() -> None: - """Existing configs without model_presets work exactly as before.""" + """Existing configs without model_presets are automatically promoted to the 'default' preset.""" cfg = Config.model_validate({ "providers": {"anthropic": {"api_key": "test-key"}}, "agents": {"defaults": {"model": "anthropic/claude-opus-4-5"}}, }) assert cfg.resolve_preset().model == "anthropic/claude-opus-4-5" - assert cfg.agents.defaults.model_preset is None + assert cfg.agents.defaults.model_preset == "default" + assert "default" in cfg.model_presets assert cfg.get_provider_name() == "anthropic" @@ -204,3 +217,48 @@ def test_resolve_preset_overrides_all_model_fields() -> None: def test_empty_model_presets_dict_is_harmless() -> None: cfg = Config.model_validate({"model_presets": {}}) assert cfg.resolve_preset().model == "anthropic/claude-opus-4-5" + + +def test_factory_uses_preset_provider_not_defaults() -> None: + """When creating a provider for a non-active preset, the preset's own provider must be used.""" + from nanobot.providers.factory import make_provider_factory + + cfg = Config.model_validate({ + "model_presets": { + "kimi": {"model": "kimi-k2.6", "provider": "moonshot"}, + "zhipu": {"model": "glm-5.1", "provider": "zhipu"}, + }, + "providers": { + "moonshot": {"api_key": "moonshot-key", "api_base": "https://api.moonshot.ai/v1"}, + "zhipu": {"api_key": "zhipu-key", "api_base": "https://open.bigmodel.cn/api/paas/v4"}, + }, + "agents": {"defaults": {"model_preset": "kimi"}}, + }) + + factory = make_provider_factory(cfg) + zhipu_provider = factory("zhipu") + + assert zhipu_provider.api_base == "https://open.bigmodel.cn/api/paas/v4" + assert getattr(zhipu_provider, "api_key", None) == "zhipu-key" + + # Also verify the 
active preset provider is still correct + moonshot_provider = factory("kimi") + assert moonshot_provider.api_base == "https://api.moonshot.ai/v1" + + +def test_factory_rejects_unknown_preset_name() -> None: + """Factory must raise ValueError when asked for a preset not in model_presets.""" + import pytest + + from nanobot.providers.factory import make_provider_factory + + cfg = Config.model_validate({ + "model_presets": { + "known": {"model": "gpt-4", "provider": "openai"}, + }, + "providers": {"openai": {"api_key": "test-key"}}, + }) + + factory = make_provider_factory(cfg) + with pytest.raises(ValueError, match="Preset 'unknown' not found"): + factory("unknown") diff --git a/tests/test_nanobot_facade.py b/tests/test_nanobot_facade.py index 009c1c20d..f1c539ae9 100644 --- a/tests/test_nanobot_facade.py +++ b/tests/test_nanobot_facade.py @@ -39,7 +39,7 @@ def test_from_config_default_path(): from nanobot.config.schema import Config with patch("nanobot.config.loader.load_config") as mock_load, \ - patch("nanobot.nanobot._make_provider") as mock_prov: + patch("nanobot.providers.factory.build_provider_for_preset") as mock_prov: mock_load.return_value = Config() mock_prov.return_value = MagicMock() mock_prov.return_value.get_default_model.return_value = "test" @@ -127,7 +127,7 @@ def test_workspace_override(tmp_path): def test_sdk_make_provider_uses_github_copilot_backend(): from nanobot.config.schema import Config - from nanobot.nanobot import _make_provider + from nanobot.providers.factory import make_provider config = Config.model_validate( { @@ -141,7 +141,7 @@ def test_sdk_make_provider_uses_github_copilot_backend(): ) with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): - provider = _make_provider(config) + provider = make_provider(config) assert provider.__class__.__name__ == "GitHubCopilotProvider" diff --git a/tests/test_preset_failover_smoke.py b/tests/test_preset_failover_smoke.py new file mode 100644 index 000000000..09bb264f8 --- /dev/null 
+++ b/tests/test_preset_failover_smoke.py @@ -0,0 +1,467 @@ +"""End-to-end smoke tests for model presets + failover. + +Uses a local aiohttp fake OpenAI server so requests are real HTTP, +not mocked at the provider level. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from nanobot.nanobot import Nanobot +from nanobot.providers.base import GenerationSettings, LLMProvider +from nanobot.providers.failover import ModelRouter +from nanobot.providers.openai_compat_provider import OpenAICompatProvider + +try: + from aiohttp import web + from aiohttp.test_utils import TestServer + + HAS_AIOHTTP = True +except ImportError: + HAS_AIOHTTP = False + + +@pytest.fixture(autouse=True) +def _disable_proxy_for_localhost_tests(monkeypatch): + """Prevent httpx from routing localhost requests through a system proxy.""" + monkeypatch.delenv("ALL_PROXY", raising=False) + monkeypatch.delenv("HTTP_PROXY", raising=False) + monkeypatch.delenv("HTTPS_PROXY", raising=False) + monkeypatch.setenv("NO_PROXY", "127.0.0.1,localhost") + + +# --------------------------------------------------------------------------- +# Helpers (mock-level preset tests) +# --------------------------------------------------------------------------- + +def _write_config(tmp_path: Path, **overrides) -> Path: + data = { + "providers": { + "openrouter": {"apiKey": "sk-test-key"}, + "openai": {"apiKey": "sk-openai-test"}, + }, + "agents": {"defaults": {"model": "openai/gpt-4.1"}}, + "tools": {"my": {"allowSet": True}}, + } + data.update(overrides) + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps(data)) + return config_path + + + +# --------------------------------------------------------------------------- +# 1. 
Model Preset Mock Tests +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_preset_loaded_at_startup(tmp_path: Path) -> None: + config_path = _write_config( + tmp_path, + model_presets={ + "fast": { + "model": "gpt-4.1-mini", + "provider": "openai", + "max_tokens": 4096, + "context_window_tokens": 128000, + "temperature": 0.3, + } + }, + agents={"defaults": {"model_preset": "fast", "model": "ignored-model"}}, + ) + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + bot = Nanobot.from_config(config_path, workspace=tmp_path) + + loop = bot._loop + assert loop.model == "gpt-4.1-mini" + assert loop.context_window_tokens == 128000 + assert loop.provider.generation.temperature == 0.3 + assert loop.provider.generation.max_tokens == 4096 + assert loop.model_preset == "fast" + + +@pytest.mark.asyncio +async def test_preset_runtime_switch_updates_all_fields(tmp_path: Path) -> None: + config_path = _write_config( + tmp_path, + model_presets={ + "cheap": { + "model": "gpt-4.1-mini", + "provider": "openai", + "max_tokens": 2048, + "context_window_tokens": 64000, + "temperature": 0.5, + }, + "power": { + "model": "gpt-4.1", + "provider": "openai", + "max_tokens": 8192, + "context_window_tokens": 256000, + "temperature": 0.1, + }, + }, + agents={"defaults": {"model_preset": "cheap"}}, + ) + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + bot = Nanobot.from_config(config_path, workspace=tmp_path) + + loop = bot._loop + assert loop.model == "gpt-4.1-mini" + + my_tool = loop.tools.get("my") + result = await my_tool.execute(action="set", key="model_preset", value="power") + assert "Error" not in result + + assert loop.model == "gpt-4.1" + assert loop.context_window_tokens == 256000 + assert loop.provider.generation.temperature == 0.1 + assert loop.provider.generation.max_tokens == 8192 + assert loop.model_preset == "power" + + +@pytest.mark.asyncio +async def 
test_preset_switch_unknown_returns_error(tmp_path: Path) -> None: + config_path = _write_config( + tmp_path, + model_presets={"a": {"model": "model-a"}}, + agents={"defaults": {"model_preset": "a"}}, + ) + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + bot = Nanobot.from_config(config_path, workspace=tmp_path) + + loop = bot._loop + original_model = loop.model + + my_tool = loop.tools.get("my") + result = await my_tool.execute(action="set", key="model_preset", value="nonexistent") + assert "not found" in result.lower() + + assert loop.model == original_model + assert loop.model_preset == "a" + + +@pytest.mark.asyncio +async def test_preset_model_with_fallback_presets_in_config(tmp_path: Path) -> None: + config_path = _write_config( + tmp_path, + model_presets={ + "prod": { + "model": "gpt-4.1", + "provider": "openai", + "max_tokens": 8192, + "temperature": 0.1, + }, + "fallback": { + "model": "gpt-4.1-mini", + "provider": "openai", + "max_tokens": 4096, + "temperature": 0.2, + }, + }, + agents={ + "defaults": { + "model_preset": "prod", + "fallback_presets": ["fallback"], + } + }, + ) + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + bot = Nanobot.from_config(config_path, workspace=tmp_path) + + loop = bot._loop + assert loop.model == "gpt-4.1" + assert isinstance(loop.provider, ModelRouter) + assert loop.provider.fallback_presets == ["fallback"] + + +@pytest.mark.asyncio +async def test_fallback_presets_wired_to_all_subsystems(tmp_path: Path) -> None: + """When fallback_presets is configured, every subsystem that calls the LLM + must use the same ModelRouter instance, not the raw primary provider.""" + config_path = _write_config( + tmp_path, + model_presets={ + "prod": { + "model": "gpt-4.1", + "provider": "openai", + "max_tokens": 8192, + "temperature": 0.1, + }, + "fallback": { + "model": "gpt-4.1-mini", + "provider": "openai", + "max_tokens": 4096, + "temperature": 0.2, + }, + }, + agents={ + "defaults": { + 
"model_preset": "prod", + "fallback_presets": ["fallback"], + } + }, + ) + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + bot = Nanobot.from_config(config_path, workspace=tmp_path) + + loop = bot._loop + router = loop.provider + assert isinstance(router, ModelRouter) + + # Every LLM-consuming subsystem must share the same router + assert loop.runner.provider is router, "AgentRunner must use ModelRouter" + assert loop.subagents.provider is router, "SubagentManager must use ModelRouter" + assert loop.consolidator.provider is router, "Consolidator must use ModelRouter" + assert loop.dream.provider is router, "Dream must use ModelRouter" + + +# --------------------------------------------------------------------------- +# 2. Real HTTP Smoke Tests (aiohttp fake OpenAI server) +# --------------------------------------------------------------------------- + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_preset_generation_params_reach_http_request() -> None: + """Provider.generation settings must appear in the actual HTTP request body.""" + requests_log: list[dict] = [] + + async def handler(request: web.Request) -> web.Response: + body = await request.json() + requests_log.append(body) + return web.json_response({ + "id": "chatcmpl-test", + "object": "chat.completion", + "model": body.get("model"), + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": "pong"}, + "finish_reason": "stop", + }], + }) + + app = web.Application() + app.router.add_post("/chat/completions", handler) + server = TestServer(app) + await server.start_server() + try: + base_url = str(server.make_url("/")) + provider = OpenAICompatProvider( + api_key="test", + api_base=base_url, + default_model="test-model", + ) + provider.generation = GenerationSettings(temperature=0.42, max_tokens=1024) + + with patch.object(LLMProvider, "_CHAT_RETRY_DELAYS", (0,)): + response = await provider.chat_with_retry( 
+ messages=[{"role": "user", "content": "ping"}], + ) + + assert response.finish_reason != "error" + assert len(requests_log) >= 1 + req = requests_log[0] + assert req["model"] == "test-model" + assert req["temperature"] == 0.42 + assert req["max_tokens"] == 1024 + finally: + await server.close() + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_failover_sends_second_request_to_fallback_model() -> None: + """Primary returns 503; after retry exhaustion ModelRouter hits fallback.""" + requests_log: list[dict] = [] + + async def handler(request: web.Request) -> web.Response: + body = await request.json() + requests_log.append(body) + model = body.get("model") + + if model == "primary-model": + return web.Response( + status=503, + body=json.dumps({"error": {"message": "overloaded", "type": "server_error"}}), + content_type="application/json", + ) + + return web.json_response({ + "id": "chatcmpl-test", + "object": "chat.completion", + "model": model, + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": "fallback-ok"}, + "finish_reason": "stop", + }], + }) + + app = web.Application() + app.router.add_post("/chat/completions", handler) + server = TestServer(app) + await server.start_server() + try: + base_url = str(server.make_url("/")) + primary = OpenAICompatProvider( + api_key="test", api_base=base_url, default_model="primary-model" + ) + fallback = OpenAICompatProvider( + api_key="test", api_base=base_url, default_model="fallback-model" + ) + + factory = MagicMock(return_value=fallback) + + router = ModelRouter( + primary_provider=primary, + primary_model="primary-model", + fallback_presets=["fallback-model"], + provider_factory=factory, + ) + + with patch.object(LLMProvider, "_CHAT_RETRY_DELAYS", (0,)): + response = await router.chat_with_retry( + messages=[{"role": "user", "content": "hi"}], + ) + + assert response.finish_reason != "error" + assert response.content == "fallback-ok" 
+ + models_requested = [r["model"] for r in requests_log] + assert "primary-model" in models_requested + assert "fallback-model" in models_requested + factory.assert_called_once_with("fallback-model") + finally: + await server.close() + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_failover_on_quota_429() -> None: + """Quota 429 on one provider may still work on a different provider.""" + requests_log: list[dict] = [] + + async def handler(request: web.Request) -> web.Response: + body = await request.json() + requests_log.append(body) + return web.Response( + status=429, + body=json.dumps({ + "error": { + "message": "insufficient quota", + "type": "insufficient_quota", + "code": "insufficient_quota", + } + }), + content_type="application/json", + ) + + app = web.Application() + app.router.add_post("/chat/completions", handler) + server = TestServer(app) + await server.start_server() + try: + base_url = str(server.make_url("/")) + primary = OpenAICompatProvider( + api_key="test", api_base=base_url, default_model="primary-model" + ) + fallback = OpenAICompatProvider( + api_key="test", api_base=base_url, default_model="fallback-model" + ) + + factory = MagicMock(return_value=fallback) + + router = ModelRouter( + primary_provider=primary, + primary_model="primary-model", + fallback_presets=["fallback-model"], + provider_factory=factory, + ) + + with patch.object(LLMProvider, "_CHAT_RETRY_DELAYS", (0,)): + response = await router.chat_with_retry( + messages=[{"role": "user", "content": "hi"}], + ) + + # Quota 429 SHOULD trigger failover — another provider may still work. + factory.assert_called_once_with("fallback-model") + assert response.finish_reason == "error" + # Both primary and fallback should have been requested. 
+ assert len(requests_log) == 2 + finally: + await server.close() + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_model_router_failover_integration() -> None: + """ModelRouter -> real HTTP failover chain (primary 503, fallback 200).""" + requests_log: list[dict] = [] + + async def handler(request: web.Request) -> web.Response: + body = await request.json() + requests_log.append(body) + model = body.get("model") + + if model == "primary-model": + return web.Response( + status=503, + body=json.dumps({"error": {"message": "overloaded", "type": "server_error"}}), + content_type="application/json", + ) + + return web.json_response({ + "id": "chatcmpl-test", + "object": "chat.completion", + "model": model, + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": "fallback-ok"}, + "finish_reason": "stop", + }], + }) + + app = web.Application() + app.router.add_post("/chat/completions", handler) + server = TestServer(app) + await server.start_server() + try: + base_url = str(server.make_url("/")) + primary = OpenAICompatProvider( + api_key="test", api_base=base_url, default_model="primary-model" + ) + fallback = OpenAICompatProvider( + api_key="test", api_base=base_url, default_model="fallback-model" + ) + + factory = MagicMock(return_value=fallback) + + router = ModelRouter( + primary_provider=primary, + primary_model="primary-model", + fallback_presets=["fallback-model"], + provider_factory=factory, + ) + + with patch.object(LLMProvider, "_CHAT_RETRY_DELAYS", (0,)): + response = await router.chat_with_retry( + messages=[{"role": "user", "content": "hello"}], + ) + + assert response.finish_reason != "error" + assert response.content == "fallback-ok" + models_requested = [r["model"] for r in requests_log] + assert "primary-model" in models_requested + assert "fallback-model" in models_requested + factory.assert_called_once_with("fallback-model") + finally: + await server.close()