mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
refactor: restrict fallback_models to preset-only and clean up provider factory
- Restrict fallback_models to only reference preset names in model_presets. - Add schema validation to reject unknown preset names in fallback_models. - Remove build_provider_for_model() since bare model fallback is no longer supported. - Simplify make_provider_factory() to only look up presets by name. - Update onboard UI to remove "Add custom model" option from fallback chain. - Update tests to use preset names instead of bare model strings in fallback chains. - Fix test imports referencing deleted _make_provider function.
This commit is contained in:
parent
ecbe56dd92
commit
0bc42e2ab2
@ -123,6 +123,7 @@
|
||||
- **Ultra-lightweight**: stable long-running agent behavior with a small, readable core.
|
||||
- **Research-ready**: the codebase is intentionally simple enough to study, modify, and extend.
|
||||
- **Practical**: chat channels, API, memory, MCP, and deployment paths are already built in.
|
||||
- **Runtime model switching**: define [model presets](docs/configuration.md#model-presets) and switch between cheap/fast and powerful models mid-conversation — no restart required.
|
||||
- **Hackable**: you can start fast, then go deeper through repo docs instead of a monolithic landing page.
|
||||
|
||||
## 📦 Install
|
||||
|
||||
@ -656,6 +656,146 @@ That's it! Environment variables, model routing, config matching, and `nanobot s
|
||||
|
||||
</details>
|
||||
|
||||
## Agent Settings
|
||||
|
||||
### Model Presets
|
||||
|
||||
Model presets let you define **named bundles** of model + generation parameters and switch between them instantly — no restart required.
|
||||
|
||||
> [!NOTE]
|
||||
> Config fields in `config.json` use **camelCase** (`modelPreset`, `contextWindowTokens`).
|
||||
> The [`my` tool](./my-tool.md) uses **snake_case** (`model_preset`, `context_window_tokens`).
|
||||
> Both refer to the same thing — just different naming conventions for config vs. runtime API.
|
||||
|
||||
**Why use presets?**
|
||||
- Switch between a cheap/fast model and a powerful model mid-conversation.
|
||||
- Share the same config across different tasks without manually editing `model`, `provider`, `temperature`, etc.
|
||||
- Runtime switching via the [`my` tool](./my-tool.md).
|
||||
|
||||
> [!TIP]
|
||||
> The easiest way to set up presets and fallback models is through the interactive wizard:
|
||||
> ```bash
|
||||
> nanobot onboard --wizard
|
||||
> ```
|
||||
> Choose **"[M] Model Presets"** to create, edit, or delete presets interactively.
|
||||
|
||||
**Configuration example:**
|
||||
|
||||
```json
|
||||
{
|
||||
"modelPresets": {
|
||||
"fast": {
|
||||
"model": "gpt-4.1-mini",
|
||||
"provider": "openai",
|
||||
"maxTokens": 4096,
|
||||
"contextWindowTokens": 128000,
|
||||
"temperature": 0.3
|
||||
},
|
||||
"deep": {
|
||||
"model": "claude-opus-4-7",
|
||||
"provider": "anthropic",
|
||||
"maxTokens": 8192,
|
||||
"contextWindowTokens": 200000,
|
||||
"temperature": 0.1,
|
||||
"reasoningEffort": "high"
|
||||
}
|
||||
},
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"modelPreset": "fast"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Preset fields:**
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `model` | string | *(required)* | Model identifier, e.g. `anthropic/claude-opus-4-7` or `gpt-4.1` |
|
||||
| `provider` | string | `"auto"` | Provider name or `"auto"` to infer from the model string |
|
||||
| `maxTokens` | integer | `8192` | Max completion tokens per turn |
|
||||
| `contextWindowTokens` | integer | `65536` | Context window size for token budgeting |
|
||||
| `temperature` | float | `0.1` | Sampling temperature |
|
||||
| `reasoningEffort` | string or null | `null` | Thinking mode: `low`, `medium`, `high`, `adaptive` |
|
||||
|
||||
**How it works:**
|
||||
- When `modelPreset` is set, the preset **completely overrides** all model-specific fields in `agents.defaults`.
|
||||
- When `modelPreset` is omitted, nanobot automatically creates an implicit `"default"` preset from your existing `agents.defaults.model`, `provider`, `temperature`, etc. — **zero migration required** for existing configs.
|
||||
|
||||
**Runtime switching** (requires `tools.my.allowSet: true`):
|
||||
|
||||
```text
|
||||
my(action="set", key="model_preset", value="deep")
|
||||
```
|
||||
|
||||
This atomically swaps the model, provider, generation parameters, and context window for the next turn.
|
||||
|
||||
If the preset name does not exist, the agent receives an error such as `model_preset 'unknown' not found. Available: fast, deep`.
|
||||
|
||||
> [!NOTE]
|
||||
> Directly modifying `model` or `contextWindowTokens` via `my(action="set", key="model", ...)` still works, but it automatically clears the active preset because the live state no longer matches the preset bundle. Use `model_preset` for atomic switches instead.
|
||||
|
||||
See [`my-tool.md`](./my-tool.md) for more runtime examples.
|
||||
|
||||
---
|
||||
|
||||
### Fallback Models
|
||||
|
||||
When the primary model returns a transient error (rate limit, server overload, quota exhausted), nanobot can automatically fail over to a chain of backup models.
|
||||
|
||||
**Configuration example:**
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"modelPreset": "fast",
|
||||
"fallbackModels": ["deep", "backup"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
1. nanobot tries the primary model first (the one from the active preset).
|
||||
2. The provider retries transient errors internally (e.g. 3 attempts with exponential backoff for 503/429).
|
||||
3. Only after the provider's own retries are exhausted and the final response still has `finish_reason == "error"` with a retryable error kind, nanobot moves to the next candidate in `fallbackModels`.
|
||||
4. Each candidate must be a preset name defined in `modelPresets`. The preset's full config (model, provider, generation params) is used.
|
||||
5. If all candidates are exhausted, the final error is returned to the user.
|
||||
|
||||
**Failover triggers on:**
|
||||
- `server_error` (503, 502, 500)
|
||||
- `rate_limit` (429)
|
||||
- `insufficient_quota` / `quota_exhausted` (429)
|
||||
|
||||
**Failover does NOT trigger on:**
|
||||
- Authentication errors (401) — rotating to another model with the same key won't help
|
||||
- Invalid request errors (400) — the request itself is malformed
|
||||
|
||||
> [!TIP]
|
||||
> Fallback models must reference preset names defined in `modelPresets`. Define a preset for each fallback model you want to use: `["cheap-preset", "backup", "emergency"]`.
|
||||
|
||||
---
|
||||
|
||||
### Other Agent Defaults
|
||||
|
||||
| Option | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `agents.defaults.model` | string | `"anthropic/claude-opus-4-5"` | Default model when no preset is active |
|
||||
| `agents.defaults.provider` | string | `"auto"` | Default provider when no preset is active |
|
||||
| `agents.defaults.maxTokens` | integer | `8192` | Max completion tokens when no preset is active |
|
||||
| `agents.defaults.temperature` | float | `0.1` | Sampling temperature when no preset is active |
|
||||
| `agents.defaults.reasoningEffort` | string or null | `null` | Thinking mode when no preset is active |
|
||||
| `agents.defaults.maxToolIterations` | integer | `200` | Max tool calls per conversation turn |
|
||||
| `agents.defaults.maxToolResultChars` | integer | `16000` | Max characters per tool result |
|
||||
| `agents.defaults.providerRetryMode` | string | `"standard"` | `"standard"` or `"persistent"` — how aggressively to retry provider-level errors |
|
||||
| `agents.defaults.timezone` | string | `"UTC"` | IANA timezone for runtime context |
|
||||
| `agents.defaults.unifiedSession` | boolean | `false` | Share one session across all channels |
|
||||
| `agents.defaults.sessionTtlMinutes` | integer | `0` | Auto-compact idle threshold (0 = disabled) |
|
||||
| `agents.defaults.maxMessages` | integer | `120` | Max messages to replay from session history |
|
||||
| `agents.defaults.consolidationRatio` | float | `0.5` | Target ratio retained after context compression |
|
||||
|
||||
## Channel Settings
|
||||
|
||||
Global settings that apply to all channels. Configure under the `channels` section in `~/.nanobot/config.json`:
|
||||
|
||||
@ -12,6 +12,11 @@ My tool fills this gap. With it, the agent can:
|
||||
- **Adapt on the fly**: Complex task? Expand the context window. Simple chat? Switch to a faster model.
|
||||
- **Remember across turns**: Store notes in your scratchpad that persist into the next conversation turn.
|
||||
|
||||
> [!NOTE]
|
||||
> This tool uses **snake_case** keys (`model_preset`, `context_window_tokens`).
|
||||
> The matching config fields in `config.json` are **camelCase** (`modelPreset`, `contextWindowTokens`).
|
||||
> See [`configuration.md`](./configuration.md#model-presets) for how to define presets in your config.
|
||||
|
||||
## Configuration
|
||||
|
||||
Enabled by default (read-only mode). The agent can check its state but not set it.
|
||||
@ -39,8 +44,7 @@ Without parameters, returns a key config overview:
|
||||
```text
|
||||
my(action="check")
|
||||
# → max_iterations: 40
|
||||
# context_window_tokens: 65536
|
||||
# model: 'anthropic/claude-sonnet-4-20250514'
|
||||
# model_preset: 'fast'
|
||||
# workspace: PosixPath('/tmp/workspace')
|
||||
# provider_retry_mode: 'standard'
|
||||
# max_tool_result_chars: 16000
|
||||
@ -55,8 +59,13 @@ With a key parameter, drill into a specific config:
|
||||
my(action="check", key="_last_usage.prompt_tokens")
|
||||
# → How many prompt tokens I've used so far
|
||||
|
||||
my(action="check", key="model")
|
||||
# → What model I'm currently running on
|
||||
my(action="check", key="model_preset")
|
||||
# → Current active preset name (e.g. 'fast')
|
||||
|
||||
my(action="check", key="model_presets")
|
||||
# → Lists all preset names and their models, e.g.:
|
||||
# fast → gpt-4.1-mini (openai)
|
||||
# deep → claude-opus-4-7 (anthropic)
|
||||
|
||||
my(action="check", key="web_config.enable")
|
||||
# → Whether web search is enabled
|
||||
@ -66,7 +75,7 @@ my(action="check", key="web_config.enable")
|
||||
|
||||
| Scenario | How |
|
||||
|----------|-----|
|
||||
| "What model are you using?" | `check("model")` |
|
||||
| "What model are you using?" | `check("model_preset")` |
|
||||
| "How many more tool calls can you make?" | `check("max_iterations")` minus `check("_current_iteration")` |
|
||||
| "How many tokens has this conversation used?" | `check("_last_usage")` — cumulative across all turns |
|
||||
| "Where is your working directory?" | `check("workspace")` |
|
||||
@ -83,8 +92,11 @@ Changes take effect immediately, no restart required.
|
||||
my(action="set", key="max_iterations", value=80)
|
||||
# → Bump iteration limit from 40 to 80
|
||||
|
||||
my(action="set", key="model", value="fast-model")
|
||||
# → Switch to a faster model
|
||||
my(action="set", key="model_preset", value="fast")
|
||||
# → Switch to the 'fast' preset (model, provider, temperature, etc. all at once)
|
||||
#
|
||||
# If the preset name does not exist:
|
||||
# → Error: model_preset 'unknown' not found. Available: fast, deep
|
||||
|
||||
my(action="set", key="context_window_tokens", value=131072)
|
||||
# → Expand context window for long documents
|
||||
@ -101,15 +113,17 @@ my(action="set", key="task_complexity", value="high")
|
||||
|
||||
### Protected parameters
|
||||
|
||||
These parameters have type and range validation — invalid values are rejected:
|
||||
These parameters have validation — invalid values are rejected:
|
||||
|
||||
| Parameter | Type | Range | Purpose |
|
||||
|-----------|------|-------|---------|
|
||||
| Parameter | Type | Range / Constraint | Purpose |
|
||||
|-----------|------|-------------------|---------|
|
||||
| `max_iterations` | int | 1–100 | Max tool calls per conversation turn |
|
||||
| `context_window_tokens` | int | 4,096–1,000,000 | Context window size |
|
||||
| `model` | str | non-empty | LLM model to use |
|
||||
| `model_preset` | str | must exist in `model_presets` | Switch to a named preset bundle |
|
||||
|
||||
Other parameters (e.g. `workspace`, `provider_retry_mode`, `max_tool_result_chars`) can be set freely, as long as the value is JSON-safe.
|
||||
Other parameters (e.g. `model`, `context_window_tokens`, `workspace`, `provider_retry_mode`, `max_tool_result_chars`) can be set freely, as long as the value is JSON-safe.
|
||||
|
||||
> [!NOTE]
|
||||
> Setting `model` or `context_window_tokens` directly automatically clears the active `model_preset`, because the live state no longer matches the preset bundle. Use `model_preset` for atomic switches instead.
|
||||
|
||||
---
|
||||
|
||||
@ -125,8 +139,8 @@ Agent: This codebase is large, let me expand my context window to handle it.
|
||||
### "Simple question, don't waste compute"
|
||||
|
||||
```text
|
||||
Agent: This is a straightforward question, let me switch to a faster model.
|
||||
→ my(action="set", key="model", value="fast-model")
|
||||
Agent: This is a straightforward question, let me switch to the fast preset.
|
||||
→ my(action="set", key="model_preset", value="fast")
|
||||
```
|
||||
|
||||
### "Remember user preferences across turns"
|
||||
|
||||
@ -95,6 +95,8 @@ Configure these **two parts** in your config (other options have defaults).
|
||||
}
|
||||
```
|
||||
|
||||
*Want to switch models mid-conversation?* Define [`modelPresets`](./configuration.md#model-presets) and switch instantly with `my(action="set", key="model_preset", value="fast")`.
|
||||
|
||||
**3. Chat**
|
||||
|
||||
```bash
|
||||
|
||||
@ -197,6 +197,50 @@ class AgentLoop:
|
||||
_RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
|
||||
_PENDING_USER_TURN_KEY = "pending_user_turn"
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: Any,
|
||||
bus: MessageBus | None = None,
|
||||
**extra: Any,
|
||||
) -> AgentLoop:
|
||||
"""Create an AgentLoop from config with the common parameter set."""
|
||||
from nanobot.providers.factory import build_provider_for_preset, make_provider_factory
|
||||
|
||||
if bus is None:
|
||||
bus = MessageBus()
|
||||
defaults = config.agents.defaults
|
||||
resolved_preset = config.resolve_preset()
|
||||
provider = build_provider_for_preset(config, resolved_preset)
|
||||
return cls(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=config.workspace_path,
|
||||
model=resolved_preset.model,
|
||||
max_iterations=defaults.max_tool_iterations,
|
||||
context_window_tokens=resolved_preset.context_window_tokens,
|
||||
context_block_limit=defaults.context_block_limit,
|
||||
max_tool_result_chars=defaults.max_tool_result_chars,
|
||||
provider_retry_mode=defaults.provider_retry_mode,
|
||||
fallback_presets=defaults.fallback_presets,
|
||||
provider_factory=make_provider_factory(config),
|
||||
web_config=config.tools.web,
|
||||
exec_config=config.tools.exec,
|
||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||
mcp_servers=config.tools.mcp_servers,
|
||||
channels_config=config.channels,
|
||||
timezone=defaults.timezone,
|
||||
unified_session=defaults.unified_session,
|
||||
disabled_skills=defaults.disabled_skills,
|
||||
session_ttl_minutes=defaults.session_ttl_minutes,
|
||||
consolidation_ratio=defaults.consolidation_ratio,
|
||||
max_messages=defaults.max_messages,
|
||||
tools_config=config.tools,
|
||||
model_presets=config.model_presets,
|
||||
model_preset=defaults.model_preset,
|
||||
**extra,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bus: MessageBus,
|
||||
@ -209,8 +253,8 @@ class AgentLoop:
|
||||
max_tool_result_chars: int | None = None,
|
||||
provider_retry_mode: str = "standard",
|
||||
tool_hint_max_length: int | None = None,
|
||||
fallback_models: list[str] | None = None,
|
||||
provider_factory: Any | None = None,
|
||||
fallback_presets: list[str] | None = None,
|
||||
provider_factory: Callable[[str], LLMProvider] | None = None,
|
||||
web_config: WebToolsConfig | None = None,
|
||||
exec_config: ExecToolConfig | None = None,
|
||||
cron_service: CronService | None = None,
|
||||
@ -239,7 +283,12 @@ class AgentLoop:
|
||||
defaults = AgentDefaults()
|
||||
self.bus = bus
|
||||
self.channels_config = channels_config
|
||||
self.provider = provider
|
||||
self.provider_factory = provider_factory
|
||||
self.fallback_presets = fallback_presets or []
|
||||
wrapped_provider = self._wrap_with_failover(
|
||||
provider, model or provider.get_default_model()
|
||||
)
|
||||
self.provider = wrapped_provider
|
||||
self._provider_snapshot_loader = provider_snapshot_loader
|
||||
self._provider_signature = provider_signature
|
||||
self.workspace = workspace
|
||||
@ -263,7 +312,6 @@ class AgentLoop:
|
||||
tool_hint_max_length if tool_hint_max_length is not None
|
||||
else defaults.tool_hint_max_length
|
||||
)
|
||||
self.fallback_models = fallback_models or []
|
||||
self.web_config = web_config or WebToolsConfig()
|
||||
self.exec_config = exec_config or ExecToolConfig()
|
||||
self.tools_config = _tc
|
||||
@ -278,15 +326,16 @@ class AgentLoop:
|
||||
self._start_time = time.time()
|
||||
self._last_usage: dict[str, int] = {}
|
||||
self._extra_hooks: list[AgentHook] = hooks or []
|
||||
|
||||
self.context = ContextBuilder(workspace, timezone=timezone, disabled_skills=disabled_skills)
|
||||
self.sessions = session_manager or SessionManager(workspace)
|
||||
self.tools = ToolRegistry()
|
||||
# One file-read/write tracker per logical session. The tool registry is
|
||||
# shared by this loop, so tools resolve the active state via contextvars.
|
||||
self._file_state_store = FileStateStore()
|
||||
self.runner = AgentRunner(provider, provider_factory=provider_factory)
|
||||
self.runner = AgentRunner(wrapped_provider)
|
||||
self.subagents = SubagentManager(
|
||||
provider=provider,
|
||||
provider=wrapped_provider,
|
||||
workspace=workspace,
|
||||
bus=bus,
|
||||
model=self.model,
|
||||
@ -318,13 +367,13 @@ class AgentLoop:
|
||||
)
|
||||
self.consolidator = Consolidator(
|
||||
store=self.context.memory,
|
||||
provider=provider,
|
||||
provider=wrapped_provider,
|
||||
model=self.model,
|
||||
sessions=self.sessions,
|
||||
context_window_tokens=self.context_window_tokens,
|
||||
build_messages=self.context.build_messages,
|
||||
get_tool_definitions=self.tools.get_definitions,
|
||||
max_completion_tokens=provider.generation.max_tokens,
|
||||
max_completion_tokens=wrapped_provider.generation.max_tokens,
|
||||
consolidation_ratio=consolidation_ratio,
|
||||
)
|
||||
self.auto_compact = AutoCompact(
|
||||
@ -334,11 +383,13 @@ class AgentLoop:
|
||||
)
|
||||
self.dream = Dream(
|
||||
store=self.context.memory,
|
||||
provider=provider,
|
||||
provider=wrapped_provider,
|
||||
model=self.model,
|
||||
)
|
||||
self.model_presets: dict[str, ModelPresetConfig] = model_presets or {}
|
||||
self._active_preset: str | None = model_preset if model_presets and model_preset in model_presets else None
|
||||
self._active_preset: str | None = (
|
||||
model_preset if model_preset in self.model_presets else None
|
||||
)
|
||||
self._register_default_tools()
|
||||
if _tc.my.enable:
|
||||
self.tools.register(MyTool(loop=self, modify_allowed=_tc.my.allow_set))
|
||||
@ -351,6 +402,38 @@ class AgentLoop:
|
||||
"""Keep subagent runtime limits aligned with mutable loop settings."""
|
||||
self.subagents.max_iterations = self.max_iterations
|
||||
|
||||
def _wrap_with_failover(self, provider: LLMProvider, model: str) -> LLMProvider:
|
||||
"""Wrap provider with failover router when fallback_presets are configured."""
|
||||
if not self.fallback_presets or not self.provider_factory:
|
||||
return provider
|
||||
from nanobot.providers.failover import ModelRouter
|
||||
|
||||
if isinstance(provider, ModelRouter):
|
||||
return provider
|
||||
|
||||
return ModelRouter(
|
||||
primary_provider=provider,
|
||||
primary_model=model,
|
||||
fallback_presets=self.fallback_presets,
|
||||
provider_factory=self.provider_factory,
|
||||
)
|
||||
|
||||
def _apply_provider_state(
|
||||
self,
|
||||
provider: LLMProvider,
|
||||
model: str,
|
||||
context_window_tokens: int,
|
||||
) -> None:
|
||||
"""Push provider/model/context_window to all LLM-consuming subsystems."""
|
||||
self.provider = provider
|
||||
# Bypass property setters so internal updates don't clear _active_preset.
|
||||
object.__setattr__(self, "_model", model)
|
||||
object.__setattr__(self, "_context_window_tokens", context_window_tokens)
|
||||
self.runner.provider = provider
|
||||
self.subagents.set_provider(provider, model)
|
||||
self.consolidator.set_provider(provider, model, context_window_tokens)
|
||||
self.dream.set_provider(provider, model)
|
||||
|
||||
def _apply_provider_snapshot(self, snapshot: ProviderSnapshot) -> None:
|
||||
"""Swap model/provider for future turns without disturbing an active one."""
|
||||
provider = snapshot.provider
|
||||
@ -359,14 +442,13 @@ class AgentLoop:
|
||||
if self.provider is provider and self.model == model:
|
||||
return
|
||||
old_model = self.model
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
self.context_window_tokens = context_window_tokens
|
||||
self.runner.provider = provider
|
||||
self.subagents.set_provider(provider, model)
|
||||
self.consolidator.set_provider(provider, model, context_window_tokens)
|
||||
self.dream.set_provider(provider, model)
|
||||
provider = self._wrap_with_failover(provider, model)
|
||||
self._apply_provider_state(provider, model, context_window_tokens)
|
||||
self._provider_signature = snapshot.signature
|
||||
if self._active_preset:
|
||||
preset = self.model_presets.get(self._active_preset)
|
||||
if preset and preset.model != model:
|
||||
self._active_preset = None
|
||||
logger.info("Runtime model switched for next turn: {} -> {}", old_model, model)
|
||||
|
||||
def _refresh_provider_snapshot(self) -> None:
|
||||
@ -381,6 +463,28 @@ class AgentLoop:
|
||||
return
|
||||
self._apply_provider_snapshot(snapshot)
|
||||
|
||||
# -- model / context_window_tokens properties with preset invalidation --
|
||||
|
||||
@property
|
||||
def model(self) -> str:
|
||||
return self._model
|
||||
|
||||
@model.setter
|
||||
def model(self, value: str) -> None:
|
||||
self._model = value
|
||||
if hasattr(self, "_active_preset"):
|
||||
self._active_preset = None
|
||||
|
||||
@property
|
||||
def context_window_tokens(self) -> int:
|
||||
return self._context_window_tokens
|
||||
|
||||
@context_window_tokens.setter
|
||||
def context_window_tokens(self, value: int) -> None:
|
||||
self._context_window_tokens = value
|
||||
if hasattr(self, "_active_preset"):
|
||||
self._active_preset = None
|
||||
|
||||
# -- model_preset property --
|
||||
|
||||
@property
|
||||
@ -388,22 +492,27 @@ class AgentLoop:
|
||||
return self._active_preset
|
||||
|
||||
@model_preset.setter
|
||||
def model_preset(self, name: str | None) -> None:
|
||||
"""Resolve a preset by name and apply all fields atomically."""
|
||||
from nanobot.providers.base import GenerationSettings
|
||||
|
||||
def model_preset(self, name: str) -> None:
|
||||
"""Resolve a preset by name and apply all fields."""
|
||||
if not isinstance(name, str) or not name.strip():
|
||||
raise ValueError("model_preset must be a non-empty string")
|
||||
if name not in self.model_presets:
|
||||
raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(self.model_presets) or '(none)'}")
|
||||
raise KeyError(
|
||||
f"model_preset {name!r} not found. Available: {', '.join(self.model_presets) or '(none)'}"
|
||||
)
|
||||
if self.provider_factory is None:
|
||||
raise ValueError("provider_factory is not configured; cannot switch model preset")
|
||||
|
||||
p = self.model_presets[name]
|
||||
self.model = p.model
|
||||
self.context_window_tokens = p.context_window_tokens
|
||||
self.provider.generation = GenerationSettings(
|
||||
temperature=p.temperature,
|
||||
max_tokens=p.max_tokens,
|
||||
reasoning_effort=p.reasoning_effort,
|
||||
)
|
||||
new_provider = self._wrap_with_failover(self.provider_factory(name), p.model)
|
||||
|
||||
# Preserve dream model_override if it differs from the current loop model.
|
||||
old_dream_model = self.dream.model
|
||||
dream_had_override = old_dream_model != self.model
|
||||
|
||||
self._apply_provider_state(new_provider, p.model, p.context_window_tokens)
|
||||
if dream_had_override:
|
||||
self.dream.model = old_dream_model
|
||||
self._active_preset = name
|
||||
|
||||
def _register_default_tools(self) -> None:
|
||||
@ -710,7 +819,6 @@ class AgentLoop:
|
||||
context_window_tokens=self.context_window_tokens,
|
||||
context_block_limit=self.context_block_limit,
|
||||
provider_retry_mode=self.provider_retry_mode,
|
||||
fallback_models=self.fallback_models,
|
||||
progress_callback=on_progress,
|
||||
stream_progress_deltas=on_stream is not None,
|
||||
retry_wait_callback=on_retry_wait,
|
||||
|
||||
@ -16,7 +16,6 @@ from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.agent.tools.ask import AskUserInterrupt
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.failover import ModelCandidate, ModelRouter
|
||||
from nanobot.utils.helpers import (
|
||||
build_assistant_message,
|
||||
estimate_message_tokens,
|
||||
@ -76,7 +75,6 @@ class AgentRunSpec:
|
||||
context_window_tokens: int | None = None
|
||||
context_block_limit: int | None = None
|
||||
provider_retry_mode: str = "standard"
|
||||
fallback_models: list[str] = field(default_factory=list)
|
||||
progress_callback: Any | None = None
|
||||
stream_progress_deltas: bool = True
|
||||
retry_wait_callback: Any | None = None
|
||||
@ -99,21 +97,11 @@ class AgentRunResult:
|
||||
had_injections: bool = False
|
||||
|
||||
|
||||
ProviderFactory = Any # Callable[[str], LLMProvider] — avoids circular import
|
||||
|
||||
|
||||
class AgentRunner:
|
||||
"""Run a tool-capable LLM loop without product-layer concerns."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: LLMProvider,
|
||||
*,
|
||||
provider_factory: ProviderFactory | None = None,
|
||||
):
|
||||
def __init__(self, provider: LLMProvider):
|
||||
self.provider = provider
|
||||
self._provider_factory = provider_factory
|
||||
self._fallback_providers: dict[str, LLMProvider] = {}
|
||||
|
||||
@staticmethod
|
||||
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
|
||||
@ -606,9 +594,12 @@ class AgentRunner:
|
||||
messages: list[dict[str, Any]],
|
||||
hook: AgentHook,
|
||||
context: AgentHookContext,
|
||||
) -> LLMResponse:
|
||||
):
|
||||
timeout_s: float | None = spec.llm_timeout_s
|
||||
if timeout_s is None:
|
||||
# Default to a finite timeout to avoid per-session lock starvation when an LLM
|
||||
# request hangs indefinitely (e.g. gateway/network stall).
|
||||
# Set NANOBOT_LLM_TIMEOUT_S=0 to disable.
|
||||
raw = os.environ.get("NANOBOT_LLM_TIMEOUT_S", "300").strip()
|
||||
try:
|
||||
timeout_s = float(raw)
|
||||
@ -622,30 +613,12 @@ class AgentRunner:
|
||||
messages,
|
||||
tools=spec.tools.get_definitions(),
|
||||
)
|
||||
provider: LLMProvider = self.provider
|
||||
request_timeout = timeout_s
|
||||
if spec.fallback_models:
|
||||
provider = self._build_model_router(spec, timeout_s)
|
||||
# ModelRouter applies the same timeout per candidate, preserving
|
||||
# fallback on primary timeouts instead of timing out the whole chain.
|
||||
request_timeout = None
|
||||
return await self._call_provider(provider, kwargs, hook, context, spec, request_timeout)
|
||||
|
||||
async def _call_provider(
|
||||
self,
|
||||
provider: LLMProvider,
|
||||
kwargs: dict[str, Any],
|
||||
hook: AgentHook,
|
||||
context: AgentHookContext,
|
||||
spec: AgentRunSpec,
|
||||
timeout_s: float | None = None,
|
||||
) -> LLMResponse:
|
||||
wants_streaming = hook.wants_streaming()
|
||||
wants_progress_streaming = (
|
||||
not wants_streaming
|
||||
and spec.stream_progress_deltas
|
||||
and spec.progress_callback is not None
|
||||
and getattr(provider, "supports_progress_deltas", False) is True
|
||||
and getattr(self.provider, "supports_progress_deltas", False) is True
|
||||
)
|
||||
|
||||
if wants_streaming:
|
||||
@ -654,7 +627,7 @@ class AgentRunner:
|
||||
context.streamed_content = True
|
||||
await hook.on_stream(context, delta)
|
||||
|
||||
coro = provider.chat_stream_with_retry(
|
||||
coro = self.provider.chat_stream_with_retry(
|
||||
**kwargs,
|
||||
on_content_delta=_stream,
|
||||
)
|
||||
@ -673,12 +646,12 @@ class AgentRunner:
|
||||
context.streamed_content = True
|
||||
await spec.progress_callback(incremental)
|
||||
|
||||
coro = provider.chat_stream_with_retry(
|
||||
coro = self.provider.chat_stream_with_retry(
|
||||
**kwargs,
|
||||
on_content_delta=_stream_progress,
|
||||
)
|
||||
else:
|
||||
coro = provider.chat_with_retry(**kwargs)
|
||||
coro = self.provider.chat_with_retry(**kwargs)
|
||||
|
||||
if timeout_s is None:
|
||||
return await coro
|
||||
@ -691,41 +664,6 @@ class AgentRunner:
|
||||
error_kind="timeout",
|
||||
)
|
||||
|
||||
def _resolve_fallback_provider(self, model: str) -> tuple[LLMProvider, str]:
|
||||
"""Return (provider, actual_model_name) for a fallback model.
|
||||
|
||||
When a provider_factory is available (and the model string may be a
|
||||
preset name), the factory resolves the actual model; otherwise the
|
||||
primary provider is reused with the raw model string.
|
||||
"""
|
||||
if model in self._fallback_providers:
|
||||
p = self._fallback_providers[model]
|
||||
return p, p.get_default_model()
|
||||
if self._provider_factory:
|
||||
provider = self._provider_factory(model)
|
||||
self._fallback_providers[model] = provider
|
||||
return provider, provider.get_default_model()
|
||||
return self.provider, model
|
||||
|
||||
def _build_model_router(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
timeout_s: float | None,
|
||||
) -> ModelRouter:
|
||||
candidates = [
|
||||
ModelCandidate(
|
||||
label=model,
|
||||
resolver=lambda m=model: self._resolve_fallback_provider(m),
|
||||
)
|
||||
for model in spec.fallback_models
|
||||
]
|
||||
return ModelRouter(
|
||||
primary_provider=self.provider,
|
||||
primary_model=spec.model,
|
||||
fallback_candidates=candidates,
|
||||
per_candidate_timeout_s=timeout_s,
|
||||
)
|
||||
|
||||
async def _request_finalization_retry(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
|
||||
@ -76,8 +76,6 @@ class MyTool(Tool):
|
||||
|
||||
RESTRICTED: dict[str, dict[str, Any]] = {
|
||||
"max_iterations": {"type": int, "min": 1, "max": 100},
|
||||
"context_window_tokens": {"type": int, "min": 4096, "max": 1_000_000},
|
||||
"model": {"type": str, "min_len": 1},
|
||||
}
|
||||
|
||||
_MAX_RUNTIME_KEYS = 64
|
||||
@ -118,13 +116,14 @@ class MyTool(Tool):
|
||||
"Scratchpad keys persist across turns but not restarts.\n"
|
||||
"Key values: _current_iteration (current progress), "
|
||||
"max_iterations - _current_iteration = remaining iterations.\n"
|
||||
"Use 'model_preset' to switch the active model preset.\n"
|
||||
"Note: web_config and exec_config are readable but read-only.\n"
|
||||
"\n"
|
||||
"When to use:\n"
|
||||
"- User asks about your model, settings, or token usage → check that key.\n"
|
||||
"- A tool fails or behaves unexpectedly → check the related config to diagnose.\n"
|
||||
"- User asks you to remember a preference for this session → set to store it in your scratchpad.\n"
|
||||
"- About to start a large task → check context_window_tokens and max_iterations first."
|
||||
"- About to start a large task → check max_iterations and model_preset first."
|
||||
)
|
||||
if not self._modify_allowed:
|
||||
base += "\nREAD-ONLY MODE: set is disabled."
|
||||
@ -132,7 +131,7 @@ class MyTool(Tool):
|
||||
base += (
|
||||
"\nIMPORTANT: Before setting state, predict the potential impact. "
|
||||
"If the operation could cause crashes or instability "
|
||||
"(e.g. changing model), warn the user first."
|
||||
"(e.g. changing model_preset), warn the user first."
|
||||
)
|
||||
return base
|
||||
|
||||
@ -148,7 +147,7 @@ class MyTool(Tool):
|
||||
},
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "Dot-path for check/set. Examples: 'max_iterations', 'workspace', 'provider_retry_mode'. "
|
||||
"description": "Dot-path for check/set. Examples: 'max_iterations', 'model_preset', 'provider_retry_mode'. "
|
||||
"For check without key, shows all config values.",
|
||||
},
|
||||
"value": {"description": "New value (for set). Type must match target (int for max_iterations/context_window_tokens, str for model)."},
|
||||
@ -391,9 +390,6 @@ class MyTool(Tool):
|
||||
|
||||
# --- existing restricted key logic ---
|
||||
old = getattr(self._loop, key)
|
||||
# When model is set directly, it no longer matches any preset
|
||||
if key == "model":
|
||||
self._loop._active_preset = None
|
||||
if "min" in spec and value < spec["min"]:
|
||||
return f"Error: '{key}' must be >= {spec['min']}"
|
||||
if "max" in spec and value > spec["max"]:
|
||||
@ -419,9 +415,12 @@ class MyTool(Tool):
|
||||
f"REJECTED type mismatch {key}: expects {old_t.__name__}, got {new_t.__name__}",
|
||||
)
|
||||
return f"Error: '{key}' expects {old_t.__name__}, got {new_t.__name__}"
|
||||
# When a model-specific field is set directly, it no longer matches any preset
|
||||
if key in ("model", "context_window_tokens"):
|
||||
self._loop._active_preset = None
|
||||
try:
|
||||
setattr(self._loop, key, value)
|
||||
except (ValueError, KeyError) as e:
|
||||
except (AttributeError, TypeError, ValueError, KeyError) as e:
|
||||
self._audit("modify", f"REJECTED {key}: {e}")
|
||||
return f"Error: {e}"
|
||||
self._audit("modify", f"{key}: {old!r} -> {value!r}")
|
||||
|
||||
@ -48,6 +48,7 @@ from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
from nanobot import __logo__, __version__
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
|
||||
|
||||
class SafeFileHistory(FileHistory):
|
||||
@ -437,104 +438,6 @@ def _onboard_plugins(config_path: Path) -> None:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def _make_provider(config: Config):
|
||||
"""Create the appropriate LLM provider from config.
|
||||
|
||||
Routing is driven by ``ProviderSpec.backend`` in the registry.
|
||||
"""
|
||||
from nanobot.providers.base import GenerationSettings
|
||||
from nanobot.providers.factory import make_provider
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
resolved = config.resolve_preset()
|
||||
model = resolved.model
|
||||
provider_name = config.get_provider_name(model)
|
||||
p = config.get_provider(model)
|
||||
spec = find_by_name(provider_name) if provider_name else None
|
||||
backend = spec.backend if spec else "openai_compat"
|
||||
|
||||
# --- validation ---
|
||||
if backend == "azure_openai":
|
||||
if not p or not p.api_key or not p.api_base:
|
||||
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
|
||||
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
|
||||
console.print("Use the model field to specify the deployment name.")
|
||||
raise typer.Exit(1)
|
||||
elif backend == "openai_compat" and not model.startswith("bedrock/"):
|
||||
needs_key = not (p and p.api_key)
|
||||
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
|
||||
if needs_key and not exempt:
|
||||
console.print("[red]Error: No API key configured.[/red]")
|
||||
console.print("Set one in ~/.nanobot/config.json under providers section")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# --- instantiation by backend ---
|
||||
if backend == "openai_codex":
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
|
||||
provider = OpenAICodexProvider(default_model=model)
|
||||
elif backend == "azure_openai":
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key=p.api_key,
|
||||
api_base=p.api_base,
|
||||
default_model=model,
|
||||
)
|
||||
elif backend == "github_copilot":
|
||||
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
|
||||
provider = GitHubCopilotProvider(default_model=model)
|
||||
elif backend == "anthropic":
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
|
||||
provider = AnthropicProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
)
|
||||
else:
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
spec=spec,
|
||||
extra_body=p.extra_body if p else None,
|
||||
)
|
||||
|
||||
provider.generation = GenerationSettings(
|
||||
temperature=resolved.temperature,
|
||||
max_tokens=resolved.max_tokens,
|
||||
reasoning_effort=resolved.reasoning_effort,
|
||||
)
|
||||
return provider
|
||||
|
||||
|
||||
def _make_cli_provider_factory(config: Config):
|
||||
"""Build a cached factory for fallback model providers (CLI side).
|
||||
|
||||
Supports preset names: if a model string matches a preset, the preset's
|
||||
full config is used for provider creation.
|
||||
"""
|
||||
from nanobot.nanobot import _make_provider_for_model
|
||||
|
||||
cache: dict[str, Any] = {}
|
||||
presets = getattr(config, "model_presets", {}) or {}
|
||||
|
||||
def factory(model_or_preset: str):
|
||||
preset = presets.get(model_or_preset)
|
||||
actual_model = preset.model if preset else model_or_preset
|
||||
key = actual_model
|
||||
if key not in cache:
|
||||
cache[key] = _make_provider_for_model(config, actual_model, preset=preset)
|
||||
return cache[key]
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
|
||||
"""Load config and optionally override the active workspace."""
|
||||
from nanobot.config.loader import load_config, resolve_config_env_vars, set_config_path
|
||||
@ -612,8 +515,6 @@ def serve(
|
||||
raise typer.Exit(1)
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.api.server import create_app
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.session.manager import SessionManager
|
||||
@ -630,45 +531,19 @@ def serve(
|
||||
timeout = timeout if timeout is not None else api_cfg.timeout
|
||||
sync_workspace_templates(runtime_config.workspace_path)
|
||||
bus = MessageBus()
|
||||
provider = _make_provider(runtime_config)
|
||||
defaults = runtime_config.agents.defaults
|
||||
pf = _make_cli_provider_factory(runtime_config) if defaults.fallback_models else None
|
||||
session_manager = SessionManager(runtime_config.workspace_path)
|
||||
_resolved = runtime_config.resolve_preset()
|
||||
agent_loop = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=runtime_config.workspace_path,
|
||||
model=_resolved.model,
|
||||
max_iterations=defaults.max_tool_iterations,
|
||||
context_window_tokens=_resolved.context_window_tokens,
|
||||
context_block_limit=defaults.context_block_limit,
|
||||
max_tool_result_chars=defaults.max_tool_result_chars,
|
||||
provider_retry_mode=defaults.provider_retry_mode,
|
||||
fallback_models=defaults.fallback_models,
|
||||
provider_factory=pf,
|
||||
web_config=runtime_config.tools.web,
|
||||
exec_config=runtime_config.tools.exec,
|
||||
restrict_to_workspace=runtime_config.tools.restrict_to_workspace,
|
||||
resolved_preset = runtime_config.resolve_preset()
|
||||
agent_loop = AgentLoop.from_config(
|
||||
runtime_config, bus,
|
||||
session_manager=session_manager,
|
||||
mcp_servers=runtime_config.tools.mcp_servers,
|
||||
channels_config=runtime_config.channels,
|
||||
timezone=runtime_config.agents.defaults.timezone,
|
||||
unified_session=runtime_config.agents.defaults.unified_session,
|
||||
disabled_skills=runtime_config.agents.defaults.disabled_skills,
|
||||
session_ttl_minutes=runtime_config.agents.defaults.session_ttl_minutes,
|
||||
consolidation_ratio=runtime_config.agents.defaults.consolidation_ratio,
|
||||
max_messages=runtime_config.agents.defaults.max_messages,
|
||||
tools_config=runtime_config.tools,
|
||||
image_generation_provider_configs={
|
||||
"openrouter": runtime_config.providers.openrouter,
|
||||
"aihubmix": runtime_config.providers.aihubmix,
|
||||
},
|
||||
model_presets=runtime_config.model_presets,
|
||||
model_preset=runtime_config.agents.defaults.model_preset,
|
||||
)
|
||||
|
||||
model_name = _resolved.model
|
||||
model_name = resolved_preset.model
|
||||
preset_name = defaults.model_preset
|
||||
preset_tag = f" (preset: {preset_name})" if preset_name else ""
|
||||
console.print(f"{__logo__} Starting OpenAI-compatible API server")
|
||||
@ -735,7 +610,6 @@ def _run_gateway(
|
||||
open_browser_url: str | None = None,
|
||||
) -> None:
|
||||
"""Shared gateway runtime; ``open_browser_url`` opens a tab once channels are up."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.bus.queue import MessageBus
|
||||
@ -751,9 +625,6 @@ def _run_gateway(
|
||||
console.print(f"{__logo__} Starting nanobot gateway version {__version__} on port {port}...")
|
||||
sync_workspace_templates(config.workspace_path)
|
||||
bus = MessageBus()
|
||||
provider = _make_provider(config)
|
||||
gw_defaults = config.agents.defaults
|
||||
gw_pf = _make_cli_provider_factory(config) if gw_defaults.fallback_models else None
|
||||
try:
|
||||
provider_snapshot = build_provider_snapshot(config)
|
||||
except ValueError as exc:
|
||||
@ -770,41 +641,16 @@ def _run_gateway(
|
||||
cron = CronService(cron_store_path)
|
||||
|
||||
# Create agent with cron service
|
||||
_resolved = config.resolve_preset()
|
||||
agent = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=config.workspace_path,
|
||||
model=_resolved.model,
|
||||
max_iterations=gw_defaults.max_tool_iterations,
|
||||
context_window_tokens=_resolved.context_window_tokens,
|
||||
web_config=config.tools.web,
|
||||
context_block_limit=gw_defaults.context_block_limit,
|
||||
max_tool_result_chars=gw_defaults.max_tool_result_chars,
|
||||
provider_retry_mode=gw_defaults.provider_retry_mode,
|
||||
fallback_models=gw_defaults.fallback_models,
|
||||
provider_factory=gw_pf,
|
||||
exec_config=config.tools.exec,
|
||||
agent = AgentLoop.from_config(
|
||||
config, bus,
|
||||
cron_service=cron,
|
||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||
session_manager=session_manager,
|
||||
mcp_servers=config.tools.mcp_servers,
|
||||
channels_config=config.channels,
|
||||
timezone=config.agents.defaults.timezone,
|
||||
unified_session=config.agents.defaults.unified_session,
|
||||
disabled_skills=config.agents.defaults.disabled_skills,
|
||||
session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
|
||||
consolidation_ratio=config.agents.defaults.consolidation_ratio,
|
||||
max_messages=config.agents.defaults.max_messages,
|
||||
tools_config=config.tools,
|
||||
image_generation_provider_configs={
|
||||
"openrouter": config.providers.openrouter,
|
||||
"aihubmix": config.providers.aihubmix,
|
||||
},
|
||||
provider_snapshot_loader=load_provider_snapshot,
|
||||
provider_signature=provider_snapshot.signature,
|
||||
model_presets=config.model_presets,
|
||||
model_preset=config.agents.defaults.model_preset,
|
||||
)
|
||||
|
||||
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
||||
@ -908,7 +754,7 @@ def _run_gateway(
|
||||
|
||||
if job.payload.deliver and job.payload.to and response:
|
||||
should_notify = await evaluate_response(
|
||||
response, reminder_note, provider, agent.model,
|
||||
response, reminder_note, agent.provider, agent.model,
|
||||
)
|
||||
if should_notify:
|
||||
await _deliver_to_channel(
|
||||
@ -998,7 +844,7 @@ def _run_gateway(
|
||||
hb_cfg = config.gateway.heartbeat
|
||||
heartbeat = HeartbeatService(
|
||||
workspace=config.workspace_path,
|
||||
provider=provider,
|
||||
provider=agent.provider,
|
||||
model=agent.model,
|
||||
on_execute=on_heartbeat_execute,
|
||||
on_notify=on_heartbeat_notify,
|
||||
@ -1151,7 +997,6 @@ def agent(
|
||||
"""Interact with the agent directly."""
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.cron.service import CronService
|
||||
|
||||
@ -1159,10 +1004,6 @@ def agent(
|
||||
sync_workspace_templates(config.workspace_path)
|
||||
|
||||
bus = MessageBus()
|
||||
provider = _make_provider(config)
|
||||
chat_defaults = config.agents.defaults
|
||||
chat_pf = _make_cli_provider_factory(config) if chat_defaults.fallback_models else None
|
||||
|
||||
# Preserve existing single-workspace installs, but keep custom workspaces clean.
|
||||
if is_default_workspace(config.workspace_path):
|
||||
_migrate_cron_store(config)
|
||||
@ -1176,34 +1017,10 @@ def agent(
|
||||
else:
|
||||
logger.disable("nanobot")
|
||||
|
||||
_resolved = config.resolve_preset()
|
||||
agent_loop = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=config.workspace_path,
|
||||
model=_resolved.model,
|
||||
max_iterations=chat_defaults.max_tool_iterations,
|
||||
context_window_tokens=_resolved.context_window_tokens,
|
||||
web_config=config.tools.web,
|
||||
context_block_limit=chat_defaults.context_block_limit,
|
||||
max_tool_result_chars=chat_defaults.max_tool_result_chars,
|
||||
provider_retry_mode=chat_defaults.provider_retry_mode,
|
||||
fallback_models=chat_defaults.fallback_models,
|
||||
provider_factory=chat_pf,
|
||||
exec_config=config.tools.exec,
|
||||
resolved_preset = config.resolve_preset()
|
||||
agent_loop = AgentLoop.from_config(
|
||||
config, bus,
|
||||
cron_service=cron,
|
||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||
mcp_servers=config.tools.mcp_servers,
|
||||
channels_config=config.channels,
|
||||
timezone=config.agents.defaults.timezone,
|
||||
unified_session=config.agents.defaults.unified_session,
|
||||
disabled_skills=config.agents.defaults.disabled_skills,
|
||||
session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
|
||||
consolidation_ratio=config.agents.defaults.consolidation_ratio,
|
||||
max_messages=config.agents.defaults.max_messages,
|
||||
tools_config=config.tools,
|
||||
model_presets=config.model_presets,
|
||||
model_preset=config.agents.defaults.model_preset,
|
||||
)
|
||||
restart_notice = consume_restart_notice_from_env()
|
||||
if restart_notice and should_show_cli_restart_notice(restart_notice, session_id):
|
||||
@ -1247,7 +1064,7 @@ def agent(
|
||||
# Interactive mode — route through bus like other channels
|
||||
from nanobot.bus.events import InboundMessage
|
||||
_init_prompt_session()
|
||||
console.print(f"{__logo__} Interactive mode [bold blue]({_resolved.model})[/bold blue] — type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit\n")
|
||||
console.print(f"{__logo__} Interactive mode [bold blue]({resolved_preset.model})[/bold blue] — type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit\n")
|
||||
|
||||
if ":" in session_id:
|
||||
cli_channel, cli_chat_id = session_id.split(":", 1)
|
||||
@ -1605,10 +1422,10 @@ def status():
|
||||
if config_path.exists():
|
||||
from nanobot.providers.registry import PROVIDERS
|
||||
|
||||
_resolved = config.resolve_preset()
|
||||
_preset = config.agents.defaults.model_preset
|
||||
_preset_tag = f" (preset: {_preset})" if _preset else ""
|
||||
console.print(f"Model: {_resolved.model}{_preset_tag}")
|
||||
resolved_preset = config.resolve_preset()
|
||||
preset = config.agents.defaults.model_preset
|
||||
preset_tag = f" (preset: {preset})" if preset else ""
|
||||
console.print(f"Model: {resolved_preset.model}{preset_tag}")
|
||||
|
||||
# Check API keys from registry
|
||||
for spec in PROVIDERS:
|
||||
|
||||
@ -22,7 +22,7 @@ from nanobot.cli.models import (
|
||||
get_model_suggestions,
|
||||
)
|
||||
from nanobot.config.loader import get_config_path, load_config
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.config.schema import Config, ModelPresetConfig
|
||||
|
||||
console = Console()
|
||||
|
||||
@ -49,6 +49,16 @@ _SELECT_FIELD_HINTS: dict[str, tuple[list[str], str]] = {
|
||||
|
||||
_BACK_PRESSED = object() # Sentinel value for back navigation
|
||||
|
||||
# Cache of model-preset names populated at runtime so that field handlers can
|
||||
# offer existing presets as choices (e.g. AgentDefaults.model_preset).
|
||||
#
|
||||
# Lifecycle: populated by _sync_preset_cache(config), which must be called
|
||||
# after every config mutation that changes model_presets (add, delete, edit).
|
||||
# Cleared between tests via _MODEL_PRESET_CACHE.clear(). In long-running
|
||||
# processes (gateway) the cache is refreshed each time the preset management
|
||||
# screen is entered, so staleness is bounded by user interaction.
|
||||
_MODEL_PRESET_CACHE: set[str] = set()
|
||||
|
||||
|
||||
def _get_questionary():
|
||||
"""Return questionary or raise a clear error when wizard deps are unavailable."""
|
||||
@ -588,9 +598,100 @@ def _handle_context_window_field(
|
||||
setattr(working_model, field_name, new_value)
|
||||
|
||||
|
||||
def _handle_model_preset_field(
|
||||
working_model: BaseModel, field_name: str, field_display: str, current_value: Any
|
||||
) -> None:
|
||||
"""Handle the 'model_preset' field with a list of existing presets."""
|
||||
# model_preset lives on AgentDefaults, but the preset list is on Config.
|
||||
# We can't easily access Config here, so we read from the global config
|
||||
# via a module-level cache set by _configure_model_presets / run_onboard.
|
||||
preset_names = sorted(_MODEL_PRESET_CACHE)
|
||||
choices = ["(clear/unset)"] + preset_names
|
||||
default_choice = str(current_value) if current_value else "(clear/unset)"
|
||||
new_value = _select_with_back(field_display, choices, default=default_choice)
|
||||
if new_value is _BACK_PRESSED:
|
||||
return
|
||||
if new_value == "(clear/unset)":
|
||||
setattr(working_model, field_name, None)
|
||||
elif new_value is not None:
|
||||
setattr(working_model, field_name, new_value)
|
||||
|
||||
|
||||
def _handle_provider_field(
|
||||
working_model: BaseModel, field_name: str, field_display: str, current_value: Any
|
||||
) -> None:
|
||||
"""Handle the 'provider' field with a list of registered providers."""
|
||||
provider_names = sorted(_get_provider_names().keys())
|
||||
choices = ["auto"] + provider_names
|
||||
default_choice = str(current_value) if current_value else "auto"
|
||||
new_value = _select_with_back(field_display, choices, default=default_choice)
|
||||
if new_value is _BACK_PRESSED:
|
||||
return
|
||||
if new_value is not None:
|
||||
setattr(working_model, field_name, new_value)
|
||||
|
||||
|
||||
def _handle_fallback_presets_field(
|
||||
working_model: BaseModel, field_name: str, field_display: str, current_value: Any
|
||||
) -> None:
|
||||
"""Handle the 'fallback_presets' field with preset-aware multi-select."""
|
||||
items: list[str] = list(current_value) if isinstance(current_value, list) else []
|
||||
preset_names = sorted(_MODEL_PRESET_CACHE)
|
||||
|
||||
while True:
|
||||
console.clear()
|
||||
console.print(f"[bold]{field_display}[/bold]")
|
||||
if items:
|
||||
for idx, item in enumerate(items, 1):
|
||||
console.print(f" {idx}. {item}")
|
||||
else:
|
||||
console.print(" [dim](empty)[/dim]")
|
||||
console.print()
|
||||
|
||||
choices = ["[+] Add preset"]
|
||||
if items:
|
||||
choices.append("[-] Remove last")
|
||||
choices.append("[X] Clear all")
|
||||
choices.append("[Done]")
|
||||
choices.append("<- Back")
|
||||
|
||||
answer = _get_questionary().select(
|
||||
"Manage fallback chain:",
|
||||
choices=choices,
|
||||
qmark=">",
|
||||
).ask()
|
||||
|
||||
if answer is None or answer == "<- Back":
|
||||
return
|
||||
if answer == "[Done]":
|
||||
setattr(working_model, field_name, items)
|
||||
return
|
||||
if answer == "[+] Add preset":
|
||||
if not preset_names:
|
||||
console.print("[yellow]! No presets defined yet.[/yellow]")
|
||||
_get_questionary().press_any_key_to_continue().ask()
|
||||
continue
|
||||
add_choices = [p for p in preset_names if p not in items]
|
||||
if not add_choices:
|
||||
console.print("[yellow]! All presets already added.[/yellow]")
|
||||
_get_questionary().press_any_key_to_continue().ask()
|
||||
continue
|
||||
picked = _select_with_back("Select preset:", add_choices)
|
||||
if picked is _BACK_PRESSED or picked is None:
|
||||
continue
|
||||
items.append(picked)
|
||||
elif answer == "[-] Remove last" and items:
|
||||
items.pop()
|
||||
elif answer == "[X] Clear all" and items:
|
||||
items.clear()
|
||||
|
||||
|
||||
_FIELD_HANDLERS: dict[str, Any] = {
|
||||
"model": _handle_model_field,
|
||||
"context_window_tokens": _handle_context_window_field,
|
||||
"model_preset": _handle_model_preset_field,
|
||||
"provider": _handle_provider_field,
|
||||
"fallback_presets": _handle_fallback_presets_field,
|
||||
}
|
||||
|
||||
|
||||
@ -757,6 +858,113 @@ def _try_auto_fill_context_window(model: BaseModel, new_model_name: str) -> None
|
||||
console.print("[dim](i) Could not auto-fill context window (model not in database)[/dim]")
|
||||
|
||||
|
||||
# --- Model Preset Configuration ---
|
||||
|
||||
|
||||
def _sync_preset_cache(config: Config) -> None:
|
||||
"""Synchronise the module-level preset name cache from config."""
|
||||
_MODEL_PRESET_CACHE.clear()
|
||||
_MODEL_PRESET_CACHE.update(config.model_presets.keys())
|
||||
|
||||
|
||||
def _configure_model_presets(config: Config) -> None:
|
||||
"""Configure model presets (CRUD)."""
|
||||
_sync_preset_cache(config)
|
||||
|
||||
def get_preset_choices() -> list[str]:
|
||||
choices: list[str] = []
|
||||
for name, preset in config.model_presets.items():
|
||||
choices.append(f"{name} ({preset.model})")
|
||||
choices.append("[+] Add new preset")
|
||||
choices.append("<- Back")
|
||||
return choices
|
||||
|
||||
last_preset_name: str | None = None
|
||||
while True:
|
||||
try:
|
||||
console.clear()
|
||||
_show_section_header(
|
||||
"Model Presets",
|
||||
"Create, edit or delete named model presets for quick switching",
|
||||
)
|
||||
choices = get_preset_choices()
|
||||
default_choice = None
|
||||
if last_preset_name:
|
||||
for c in choices:
|
||||
if c.startswith(last_preset_name + " ("):
|
||||
default_choice = c
|
||||
break
|
||||
answer = _select_with_back(
|
||||
"Select preset:", choices, default=default_choice
|
||||
)
|
||||
|
||||
if answer is _BACK_PRESSED or answer is None or answer == "<- Back":
|
||||
break
|
||||
|
||||
assert isinstance(answer, str)
|
||||
|
||||
if answer == "[+] Add new preset":
|
||||
name_input = _get_questionary().text(
|
||||
"Preset name:",
|
||||
validate=lambda t: True if t and t.strip() else "Name cannot be empty",
|
||||
).ask()
|
||||
if not name_input:
|
||||
continue
|
||||
name = name_input.strip()
|
||||
if name in config.model_presets:
|
||||
console.print(f"[yellow]! Preset '{name}' already exists[/yellow]")
|
||||
_pause()
|
||||
continue
|
||||
new_preset = ModelPresetConfig(model="")
|
||||
updated = _configure_pydantic_model(new_preset, f"New Preset: {name}")
|
||||
if updated is not None:
|
||||
config.model_presets[name] = updated
|
||||
_sync_preset_cache(config)
|
||||
last_preset_name = name
|
||||
continue
|
||||
|
||||
# Editing / deleting an existing preset
|
||||
# Extract preset name from "name (model)" format
|
||||
preset_name = answer.split(" (", 1)[0]
|
||||
preset = config.model_presets.get(preset_name)
|
||||
if preset is None:
|
||||
continue
|
||||
|
||||
last_preset_name = preset_name
|
||||
|
||||
choices = ["Edit", "Cancel"]
|
||||
if preset_name != "default":
|
||||
choices.insert(1, "Delete")
|
||||
action = _select_with_back(
|
||||
f"Preset: {preset_name}",
|
||||
choices,
|
||||
default="Edit",
|
||||
)
|
||||
if action is _BACK_PRESSED or action == "Cancel" or action is None:
|
||||
continue
|
||||
|
||||
if action == "Delete":
|
||||
confirm = _get_questionary().confirm(
|
||||
f"Delete preset '{preset_name}'?",
|
||||
default=False,
|
||||
).ask()
|
||||
if confirm:
|
||||
del config.model_presets[preset_name]
|
||||
_sync_preset_cache(config)
|
||||
last_preset_name = None
|
||||
continue
|
||||
|
||||
if action == "Edit":
|
||||
updated = _configure_pydantic_model(preset, f"Edit Preset: {preset_name}")
|
||||
if updated is not None:
|
||||
config.model_presets[preset_name] = updated
|
||||
_sync_preset_cache(config)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n[dim]Returning to main menu...[/dim]")
|
||||
break
|
||||
|
||||
|
||||
# --- Provider Configuration ---
|
||||
|
||||
|
||||
@ -1043,6 +1251,12 @@ def _show_summary(config: Config) -> None:
|
||||
channel_rows.append((display, status))
|
||||
_print_summary_panel(channel_rows, "Chat Channels")
|
||||
|
||||
# Model Presets
|
||||
preset_rows = []
|
||||
for name, preset in config.model_presets.items():
|
||||
preset_rows.append((name, f"{preset.model} (ctx={preset.context_window_tokens})"))
|
||||
_print_summary_panel(preset_rows, "Model Presets")
|
||||
|
||||
# Settings sections
|
||||
for title, model in [
|
||||
("Agent Settings", config.agents.defaults),
|
||||
@ -1112,6 +1326,7 @@ def run_onboard(initial_config: Config | None = None) -> OnboardResult:
|
||||
|
||||
original_config = base_config.model_copy(deep=True)
|
||||
config = base_config.model_copy(deep=True)
|
||||
_sync_preset_cache(config)
|
||||
|
||||
last_main_choice: str | None = None
|
||||
while True:
|
||||
@ -1123,6 +1338,7 @@ def run_onboard(initial_config: Config | None = None) -> OnboardResult:
|
||||
"What would you like to configure?",
|
||||
choices=[
|
||||
"[P] LLM Provider",
|
||||
"[M] Model Presets",
|
||||
"[C] Chat Channel",
|
||||
"[H] Channel Common",
|
||||
"[A] Agent Settings",
|
||||
@ -1149,6 +1365,7 @@ def run_onboard(initial_config: Config | None = None) -> OnboardResult:
|
||||
|
||||
_menu_dispatch = {
|
||||
"[P] LLM Provider": lambda: _configure_providers(config),
|
||||
"[M] Model Presets": lambda: _configure_model_presets(config),
|
||||
"[C] Chat Channel": lambda: _configure_channels(config),
|
||||
"[H] Channel Common": lambda: _configure_general_settings(config, "Channel Common"),
|
||||
"[A] Agent Settings": lambda: _configure_general_settings(config, "Agent Settings"),
|
||||
|
||||
@ -104,8 +104,9 @@ class AgentDefaults(Base):
|
||||
validation_alias=AliasChoices("toolHintMaxLength"),
|
||||
serialization_alias="toolHintMaxLength",
|
||||
) # Max characters for tool hint display (e.g. "$ cd …/project && npm test")
|
||||
reasoning_effort: str | None = None # low / medium / high / adaptive - enables LLM thinking mode
|
||||
fallback_models: list[str] = Field(default_factory=list)
|
||||
fallback_presets: list[str] = Field(
|
||||
default_factory=list
|
||||
) # Ordered fallback chain. Each item must be a preset name defined in model_presets.
|
||||
timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"
|
||||
unified_session: bool = False # Share one session across all channels (single-user multi-device)
|
||||
disabled_skills: list[str] = Field(default_factory=list) # Skill names to exclude from loading (e.g. ["summarize", "skill-creator"])
|
||||
@ -306,24 +307,52 @@ class Config(BaseSettings):
|
||||
model_presets: dict[str, ModelPresetConfig] = Field(default_factory=dict)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_model_preset(self) -> "Config":
|
||||
name = self.agents.defaults.model_preset
|
||||
if name and name not in self.model_presets:
|
||||
raise ValueError(f"model_preset {name!r} not found in model_presets")
|
||||
def _sync_and_validate_preset(self) -> "Config":
|
||||
"""Expose agents.defaults model fields as the implicit 'default' preset
|
||||
and validate the active preset reference.
|
||||
|
||||
This guarantees that ``model_presets`` is never empty and that legacy
|
||||
configs (which only set ``agents.defaults.model`` etc.) continue to work
|
||||
without explicitly declaring a preset.
|
||||
"""
|
||||
self._refresh_default_preset()
|
||||
defaults = self.agents.defaults
|
||||
if defaults.model_preset is None:
|
||||
defaults.model_preset = "default"
|
||||
if defaults.model_preset not in self.model_presets:
|
||||
raise ValueError(f"model_preset {defaults.model_preset!r} not found in model_presets")
|
||||
for fb in defaults.fallback_presets:
|
||||
if fb not in self.model_presets:
|
||||
raise ValueError(f"fallback_presets entry {fb!r} not found in model_presets")
|
||||
return self
|
||||
|
||||
def resolve_preset(self) -> ModelPresetConfig:
|
||||
"""Return effective model params: from active preset, or individual defaults."""
|
||||
name = self.agents.defaults.model_preset
|
||||
if name:
|
||||
return self.model_presets[name]
|
||||
def _refresh_default_preset(self) -> None:
|
||||
"""Rebuild the implicit 'default' preset from current agents.defaults.
|
||||
|
||||
Called inside ``_sync_and_validate_preset`` (model validator) and
|
||||
``resolve_preset()`` so that runtime mutations (e.g. tests directly
|
||||
setting ``defaults.model``) are reflected.
|
||||
"""
|
||||
d = self.agents.defaults
|
||||
return ModelPresetConfig(
|
||||
model=d.model, provider=d.provider, max_tokens=d.max_tokens,
|
||||
self.model_presets["default"] = ModelPresetConfig(
|
||||
model=d.model,
|
||||
provider=d.provider,
|
||||
max_tokens=d.max_tokens,
|
||||
context_window_tokens=d.context_window_tokens,
|
||||
temperature=d.temperature, reasoning_effort=d.reasoning_effort,
|
||||
temperature=d.temperature,
|
||||
reasoning_effort=d.reasoning_effort,
|
||||
)
|
||||
|
||||
def resolve_preset(self) -> ModelPresetConfig:
|
||||
"""Return the active preset.
|
||||
|
||||
The implicit ``"default"`` preset is rebuilt from current defaults every
|
||||
time so that runtime mutations (e.g. tests setting ``defaults.model``)
|
||||
are always reflected.
|
||||
"""
|
||||
self._refresh_default_preset()
|
||||
return self.model_presets[self.agents.defaults.model_preset]
|
||||
|
||||
@property
|
||||
def workspace_path(self) -> Path:
|
||||
"""Get expanded workspace path."""
|
||||
@ -335,15 +364,16 @@ class Config(BaseSettings):
|
||||
"""Match provider config and its registry name. Returns (config, spec_name)."""
|
||||
from nanobot.providers.registry import PROVIDERS, find_by_name
|
||||
|
||||
forced = self.resolve_preset().provider
|
||||
resolved = self.resolve_preset()
|
||||
forced = resolved.provider
|
||||
if forced != "auto":
|
||||
spec = find_by_name(forced)
|
||||
if spec:
|
||||
p = getattr(self.providers, spec.name, None)
|
||||
return (p, spec.name) if p else (None, None)
|
||||
provider_cfg = getattr(self.providers, spec.name, None)
|
||||
return (provider_cfg, spec.name) if provider_cfg else (None, None)
|
||||
return None, None
|
||||
|
||||
model_lower = (model or self.resolve_preset().model).lower()
|
||||
model_lower = (model or resolved.model).lower()
|
||||
model_normalized = model_lower.replace("-", "_")
|
||||
model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
|
||||
normalized_prefix = model_prefix.replace("-", "_")
|
||||
|
||||
@ -8,7 +8,6 @@ from typing import Any
|
||||
|
||||
from nanobot.agent.hook import AgentHook, SDKCaptureHook
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -62,41 +61,12 @@ class Nanobot:
|
||||
Path(workspace).expanduser().resolve()
|
||||
)
|
||||
|
||||
provider = _make_provider(config)
|
||||
bus = MessageBus()
|
||||
defaults = config.agents.defaults
|
||||
_resolved = config.resolve_preset()
|
||||
pf = _make_provider_factory(config) if defaults.fallback_models else None
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=config.workspace_path,
|
||||
model=_resolved.model,
|
||||
max_iterations=defaults.max_tool_iterations,
|
||||
context_window_tokens=_resolved.context_window_tokens,
|
||||
context_block_limit=defaults.context_block_limit,
|
||||
max_tool_result_chars=defaults.max_tool_result_chars,
|
||||
provider_retry_mode=defaults.provider_retry_mode,
|
||||
tool_hint_max_length=defaults.tool_hint_max_length,
|
||||
fallback_models=defaults.fallback_models,
|
||||
provider_factory=pf,
|
||||
web_config=config.tools.web,
|
||||
exec_config=config.tools.exec,
|
||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||
mcp_servers=config.tools.mcp_servers,
|
||||
timezone=defaults.timezone,
|
||||
unified_session=defaults.unified_session,
|
||||
disabled_skills=defaults.disabled_skills,
|
||||
session_ttl_minutes=defaults.session_ttl_minutes,
|
||||
consolidation_ratio=defaults.consolidation_ratio,
|
||||
tools_config=config.tools,
|
||||
loop = AgentLoop.from_config(
|
||||
config,
|
||||
image_generation_provider_configs={
|
||||
"openrouter": config.providers.openrouter,
|
||||
"aihubmix": config.providers.aihubmix,
|
||||
},
|
||||
model_presets=config.model_presets,
|
||||
model_preset=defaults.model_preset,
|
||||
)
|
||||
return cls(loop)
|
||||
|
||||
@ -134,99 +104,3 @@ class Nanobot:
|
||||
)
|
||||
|
||||
|
||||
def _make_provider_for_model(
|
||||
config: Any,
|
||||
model: str,
|
||||
*,
|
||||
preset: Any | None = None,
|
||||
) -> Any:
|
||||
"""Create an LLM provider instance for a specific model string.
|
||||
|
||||
When *preset* is given, its generation settings (temperature, max_tokens,
|
||||
reasoning_effort) override the active preset defaults.
|
||||
"""
|
||||
from nanobot.providers.base import GenerationSettings
|
||||
from nanobot.providers.factory import make_provider
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
gen_src = preset or config.resolve_preset()
|
||||
provider_name = config.get_provider_name(model)
|
||||
p = config.get_provider(model)
|
||||
spec = find_by_name(provider_name) if provider_name else None
|
||||
backend = spec.backend if spec else "openai_compat"
|
||||
|
||||
if backend == "azure_openai":
|
||||
if not p or not p.api_key or not p.api_base:
|
||||
raise ValueError("Azure OpenAI requires api_key and api_base in config.")
|
||||
elif backend == "openai_compat" and not model.startswith("bedrock/"):
|
||||
needs_key = not (p and p.api_key)
|
||||
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
|
||||
if needs_key and not exempt:
|
||||
raise ValueError(f"No API key configured for provider '{provider_name}'.")
|
||||
|
||||
if backend == "openai_codex":
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
|
||||
provider = OpenAICodexProvider(default_model=model)
|
||||
elif backend == "github_copilot":
|
||||
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
|
||||
|
||||
provider = GitHubCopilotProvider(default_model=model)
|
||||
elif backend == "azure_openai":
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key=p.api_key, api_base=p.api_base, default_model=model
|
||||
)
|
||||
elif backend == "anthropic":
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
|
||||
provider = AnthropicProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
)
|
||||
else:
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
spec=spec,
|
||||
extra_body=p.extra_body if p else None,
|
||||
)
|
||||
|
||||
provider.generation = GenerationSettings(
|
||||
temperature=gen_src.temperature,
|
||||
max_tokens=gen_src.max_tokens,
|
||||
reasoning_effort=gen_src.reasoning_effort,
|
||||
)
|
||||
return provider
|
||||
|
||||
|
||||
def _make_provider(config: Any) -> Any:
|
||||
"""Create the LLM provider for the primary model from config."""
|
||||
return _make_provider_for_model(config, config.resolve_preset().model)
|
||||
|
||||
|
||||
def _make_provider_factory(config: Any):
|
||||
"""Build a cached factory that creates providers for arbitrary model strings.
|
||||
|
||||
If a model string matches a preset name in ``config.model_presets``, the
|
||||
preset's full config (model, temperature, max_tokens, …) is used.
|
||||
"""
|
||||
cache: dict[str, Any] = {}
|
||||
presets = getattr(config, "model_presets", {}) or {}
|
||||
|
||||
def factory(model_or_preset: str):
|
||||
preset = presets.get(model_or_preset)
|
||||
actual_model = preset.model if preset else model_or_preset
|
||||
key = actual_model
|
||||
if key not in cache:
|
||||
cache[key] = _make_provider_for_model(config, actual_model, preset=preset)
|
||||
return cache[key]
|
||||
|
||||
return factory
|
||||
|
||||
@ -137,7 +137,9 @@ class LLMProvider(ABC):
|
||||
"insufficient_quota",
|
||||
"insufficient quota",
|
||||
"quota exceeded",
|
||||
"quota_exceeded",
|
||||
"quota exhausted",
|
||||
"quota_exhausted",
|
||||
"billing hard limit",
|
||||
"billing_hard_limit_reached",
|
||||
"billing not active",
|
||||
|
||||
@ -4,11 +4,16 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.providers.base import GenerationSettings, LLMProvider
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.config.schema import ModelPresetConfig, ProviderConfig
|
||||
from nanobot.providers.registry import ProviderSpec
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProviderSnapshot:
|
||||
@ -18,22 +23,62 @@ class ProviderSnapshot:
|
||||
signature: tuple[object, ...]
|
||||
|
||||
|
||||
def make_provider(config: Config) -> LLMProvider:
|
||||
"""Create the LLM provider implied by config."""
|
||||
model = config.agents.defaults.model
|
||||
provider_name = config.get_provider_name(model)
|
||||
p = config.get_provider(model)
|
||||
spec = find_by_name(provider_name) if provider_name else None
|
||||
@dataclass(frozen=True)
|
||||
class _ProviderInfo:
|
||||
"""Resolved metadata needed to build and validate an LLM provider."""
|
||||
|
||||
name: str | None
|
||||
cfg: ProviderConfig | None
|
||||
spec: ProviderSpec | None
|
||||
api_base: str | None
|
||||
backend: str
|
||||
|
||||
|
||||
def _resolve_provider_info(
|
||||
config: Config,
|
||||
model: str,
|
||||
preset: ModelPresetConfig,
|
||||
) -> _ProviderInfo:
|
||||
"""Derive provider name, config, spec and api_base from preset or auto-detection."""
|
||||
if preset.provider != "auto":
|
||||
name = preset.provider
|
||||
cfg = getattr(config.providers, name, None)
|
||||
spec = find_by_name(name)
|
||||
api_base = (
|
||||
cfg.api_base
|
||||
if cfg and cfg.api_base
|
||||
else (spec.default_api_base if spec and spec.default_api_base else None)
|
||||
)
|
||||
else:
|
||||
name = config.get_provider_name(model)
|
||||
cfg = config.get_provider(model)
|
||||
spec = find_by_name(name) if name else None
|
||||
api_base = config.get_api_base(model)
|
||||
|
||||
backend = spec.backend if spec else "openai_compat"
|
||||
return _ProviderInfo(name=name, cfg=cfg, spec=spec, api_base=api_base, backend=backend)
|
||||
|
||||
|
||||
def _validate_provider(info: _ProviderInfo, model: str) -> None:
|
||||
"""Ensure credentials / endpoints are present before instantiation."""
|
||||
cfg = info.cfg
|
||||
backend = info.backend
|
||||
name = info.name
|
||||
|
||||
if backend == "azure_openai":
|
||||
if not p or not p.api_key or not p.api_base:
|
||||
if not cfg or not cfg.api_key or not cfg.api_base:
|
||||
raise ValueError("Azure OpenAI requires api_key and api_base in config.")
|
||||
elif backend == "openai_compat" and not model.startswith("bedrock/"):
|
||||
needs_key = not (p and p.api_key)
|
||||
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
|
||||
needs_key = not (cfg and cfg.api_key)
|
||||
exempt = info.spec and (info.spec.is_oauth or info.spec.is_local or info.spec.is_direct)
|
||||
if needs_key and not exempt:
|
||||
raise ValueError(f"No API key configured for provider '{provider_name}'.")
|
||||
raise ValueError(f"No API key configured for provider '{name}'.")
|
||||
|
||||
|
||||
def _create_provider(model: str, info: _ProviderInfo) -> LLMProvider:
|
||||
"""Instantiate the concrete provider class for *backend*."""
|
||||
cfg = info.cfg
|
||||
backend = info.backend
|
||||
|
||||
if backend == "openai_codex":
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
@ -43,8 +88,8 @@ def make_provider(config: Config) -> LLMProvider:
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key=p.api_key,
|
||||
api_base=p.api_base,
|
||||
api_key=cfg.api_key if cfg else None,
|
||||
api_base=info.api_base,
|
||||
default_model=model,
|
||||
)
|
||||
elif backend == "github_copilot":
|
||||
@ -55,70 +100,103 @@ def make_provider(config: Config) -> LLMProvider:
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
|
||||
provider = AnthropicProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
api_key=cfg.api_key if cfg else None,
|
||||
api_base=info.api_base,
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
extra_headers=cfg.extra_headers if cfg else None,
|
||||
)
|
||||
elif backend == "bedrock":
|
||||
from nanobot.providers.bedrock_provider import BedrockProvider
|
||||
|
||||
provider = BedrockProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=p.api_base if p else None,
|
||||
api_key=cfg.api_key if cfg else None,
|
||||
api_base=info.api_base if cfg else None,
|
||||
default_model=model,
|
||||
region=getattr(p, "region", None) if p else None,
|
||||
profile=getattr(p, "profile", None) if p else None,
|
||||
extra_body=p.extra_body if p else None,
|
||||
region=getattr(cfg, "region", None) if cfg else None,
|
||||
profile=getattr(cfg, "profile", None) if cfg else None,
|
||||
extra_body=cfg.extra_body if cfg else None,
|
||||
)
|
||||
else:
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
api_key=cfg.api_key if cfg else None,
|
||||
api_base=info.api_base,
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
spec=spec,
|
||||
extra_body=p.extra_body if p else None,
|
||||
extra_headers=cfg.extra_headers if cfg else None,
|
||||
spec=info.spec,
|
||||
extra_body=cfg.extra_body if cfg else None,
|
||||
)
|
||||
|
||||
defaults = config.agents.defaults
|
||||
provider.generation = GenerationSettings(
|
||||
temperature=defaults.temperature,
|
||||
max_tokens=defaults.max_tokens,
|
||||
reasoning_effort=defaults.reasoning_effort,
|
||||
)
|
||||
return provider
|
||||
|
||||
|
||||
def _apply_generation(provider: LLMProvider, preset: ModelPresetConfig) -> None:
|
||||
provider.generation = GenerationSettings(
|
||||
temperature=preset.temperature,
|
||||
max_tokens=preset.max_tokens,
|
||||
reasoning_effort=preset.reasoning_effort,
|
||||
)
|
||||
|
||||
|
||||
def build_provider_for_preset(config: Config, preset: ModelPresetConfig) -> LLMProvider:
|
||||
"""Create an LLM provider from a full *preset* (model + provider + generation)."""
|
||||
info = _resolve_provider_info(config, preset.model, preset)
|
||||
_validate_provider(info, preset.model)
|
||||
provider = _create_provider(preset.model, info)
|
||||
_apply_generation(provider, preset)
|
||||
return provider
|
||||
|
||||
|
||||
def make_provider(config: Config) -> LLMProvider:
|
||||
"""Create the LLM provider implied by config (legacy entrypoint)."""
|
||||
resolved = config.resolve_preset()
|
||||
return build_provider_for_preset(config, resolved)
|
||||
|
||||
|
||||
def make_provider_factory(config: Config):
|
||||
"""Build a cached factory that creates providers for preset names.
|
||||
|
||||
The factory looks up *preset_name* in ``config.model_presets`` and builds
|
||||
the provider from the preset's full configuration.
|
||||
"""
|
||||
cache: dict[str, LLMProvider] = {}
|
||||
presets = config.model_presets
|
||||
|
||||
def factory(preset_name: str) -> LLMProvider:
|
||||
preset = presets.get(preset_name)
|
||||
if preset is None:
|
||||
raise ValueError(f"Preset {preset_name!r} not found in model_presets")
|
||||
if preset_name not in cache:
|
||||
cache[preset_name] = build_provider_for_preset(config, preset)
|
||||
return cache[preset_name]
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
def provider_signature(config: Config) -> tuple[object, ...]:
|
||||
"""Return the config fields that affect the primary LLM provider."""
|
||||
model = config.agents.defaults.model
|
||||
resolved = config.resolve_preset()
|
||||
defaults = config.agents.defaults
|
||||
p = config.get_provider(model)
|
||||
return (
|
||||
model,
|
||||
defaults.provider,
|
||||
config.get_provider_name(model),
|
||||
config.get_api_key(model),
|
||||
config.get_api_base(model),
|
||||
p.extra_headers if p else None,
|
||||
p.extra_body if p else None,
|
||||
getattr(p, "region", None) if p else None,
|
||||
getattr(p, "profile", None) if p else None,
|
||||
defaults.max_tokens,
|
||||
defaults.temperature,
|
||||
defaults.reasoning_effort,
|
||||
defaults.context_window_tokens,
|
||||
resolved.model,
|
||||
resolved.provider,
|
||||
config.get_provider_name(resolved.model),
|
||||
config.get_api_key(resolved.model),
|
||||
config.get_api_base(resolved.model),
|
||||
resolved.max_tokens,
|
||||
resolved.temperature,
|
||||
resolved.reasoning_effort,
|
||||
resolved.context_window_tokens,
|
||||
tuple(defaults.fallback_presets),
|
||||
)
|
||||
|
||||
|
||||
def build_provider_snapshot(config: Config) -> ProviderSnapshot:
|
||||
resolved = config.resolve_preset()
|
||||
return ProviderSnapshot(
|
||||
provider=make_provider(config),
|
||||
model=config.agents.defaults.model,
|
||||
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||
model=resolved.model,
|
||||
context_window_tokens=resolved.context_window_tokens,
|
||||
signature=provider_signature(config),
|
||||
)
|
||||
|
||||
|
||||
@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
@ -12,68 +11,16 @@ from loguru import logger
|
||||
from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelCandidate:
|
||||
"""A lazily resolved model/provider candidate."""
|
||||
|
||||
label: str
|
||||
resolver: Callable[[], tuple[LLMProvider, str]]
|
||||
|
||||
|
||||
class ModelRouter(LLMProvider):
|
||||
"""Try fallback model candidates for eligible transient final errors."""
|
||||
|
||||
supports_progress_deltas = False
|
||||
|
||||
_BLOCKED_STATUS_CODES = frozenset({400, 401, 403, 404, 422})
|
||||
_QUOTA_MARKERS = (
|
||||
"insufficient_quota",
|
||||
"insufficient quota",
|
||||
"quota exceeded",
|
||||
"quota_exceeded",
|
||||
"quota exhausted",
|
||||
"quota_exhausted",
|
||||
"billing hard limit",
|
||||
"billing_hard_limit_reached",
|
||||
"billing not active",
|
||||
"insufficient balance",
|
||||
"insufficient_balance",
|
||||
"credit balance too low",
|
||||
"payment required",
|
||||
"out of credits",
|
||||
"out of quota",
|
||||
"exceeded your current quota",
|
||||
)
|
||||
_NON_FAILOVER_MARKERS = (
|
||||
"context length",
|
||||
"context_length",
|
||||
"maximum context",
|
||||
"max context",
|
||||
"token budget",
|
||||
"too many tokens",
|
||||
"schema",
|
||||
"invalid request",
|
||||
"invalid_request",
|
||||
"invalid parameter",
|
||||
"invalid_parameter",
|
||||
"unsupported",
|
||||
"unauthorized",
|
||||
"authentication",
|
||||
"permission",
|
||||
"forbidden",
|
||||
"refusal",
|
||||
"content policy",
|
||||
"content_filter",
|
||||
"policy violation",
|
||||
"safety",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
primary_provider: LLMProvider,
|
||||
primary_model: str,
|
||||
fallback_candidates: list[ModelCandidate],
|
||||
fallback_presets: list[str],
|
||||
provider_factory: Callable[[str], LLMProvider] | None = None,
|
||||
per_candidate_timeout_s: float | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
@ -82,7 +29,9 @@ class ModelRouter(LLMProvider):
|
||||
)
|
||||
self.primary_provider = primary_provider
|
||||
self.primary_model = primary_model
|
||||
self.fallback_candidates = list(fallback_candidates)
|
||||
self.fallback_presets = list(fallback_presets)
|
||||
self._provider_factory = provider_factory
|
||||
self._provider_cache: dict[str, LLMProvider] = {}
|
||||
self.per_candidate_timeout_s = per_candidate_timeout_s
|
||||
self.generation = getattr(primary_provider, "generation", GenerationSettings())
|
||||
|
||||
@ -90,41 +39,46 @@ class ModelRouter(LLMProvider):
|
||||
return self.primary_model
|
||||
|
||||
async def chat(self, **kwargs: Any) -> LLMResponse:
|
||||
return await self.primary_provider.chat(**kwargs)
|
||||
async def call(provider: LLMProvider, candidate_model: str, _unused_delta: Any) -> LLMResponse:
|
||||
return await provider.chat(**{**kwargs, "model": candidate_model})
|
||||
return await self._route(call)
|
||||
|
||||
async def chat_stream(self, **kwargs: Any) -> LLMResponse:
|
||||
return await self.primary_provider.chat_stream(**kwargs)
|
||||
async def call(provider: LLMProvider, candidate_model: str, content_delta: Any) -> LLMResponse:
|
||||
return await provider.chat_stream(
|
||||
**{**kwargs, "model": candidate_model, "on_content_delta": content_delta}
|
||||
)
|
||||
return await self._route(call, on_content_delta=kwargs.get("on_content_delta"))
|
||||
|
||||
@classmethod
|
||||
def _is_quota_error(cls, response: LLMResponse) -> bool:
|
||||
tokens = {
|
||||
cls._normalize_error_token(response.error_type),
|
||||
cls._normalize_error_token(response.error_code),
|
||||
}
|
||||
if any(token in cls._NON_RETRYABLE_429_ERROR_TOKENS for token in tokens if token):
|
||||
return True
|
||||
content = (response.content or "").lower()
|
||||
return any(marker in content for marker in cls._QUOTA_MARKERS)
|
||||
|
||||
@classmethod
|
||||
def _is_blocked_error(cls, response: LLMResponse) -> bool:
|
||||
status = response.error_status_code
|
||||
if status is not None and int(status) in cls._BLOCKED_STATUS_CODES:
|
||||
return True
|
||||
if response.finish_reason in {"refusal", "content_filter"}:
|
||||
return True
|
||||
content = (response.content or "").lower()
|
||||
return any(marker in content for marker in cls._NON_FAILOVER_MARKERS)
|
||||
@property
|
||||
def supports_progress_deltas(self) -> bool: # type: ignore[override]
|
||||
return getattr(self.primary_provider, "supports_progress_deltas", False)
|
||||
|
||||
@classmethod
|
||||
def _should_failover(cls, response: LLMResponse) -> bool:
|
||||
if response.finish_reason != "error":
|
||||
return False
|
||||
if cls._is_blocked_error(response):
|
||||
if response.error_should_retry is False:
|
||||
return False
|
||||
if cls._is_quota_error(response):
|
||||
if response.error_kind == "configuration":
|
||||
return False
|
||||
return cls._is_transient_response(response)
|
||||
return True
|
||||
|
||||
def _resolve(self, model: str) -> tuple[LLMProvider, str]:
|
||||
"""Return (provider, actual_model_name) for a preset name.
|
||||
|
||||
Caches results so factory is only invoked once per unique name.
|
||||
"""
|
||||
if model in self._provider_cache:
|
||||
cached_provider = self._provider_cache[model]
|
||||
return cached_provider, cached_provider.get_default_model()
|
||||
if self._provider_factory is None:
|
||||
raise ValueError(
|
||||
f"Cannot resolve fallback model {model!r}: no provider_factory configured"
|
||||
)
|
||||
provider = self._provider_factory(model)
|
||||
self._provider_cache[model] = provider
|
||||
return provider, provider.get_default_model()
|
||||
|
||||
async def _with_timeout(self, coro: Awaitable[LLMResponse]) -> LLMResponse:
|
||||
timeout_s = self.per_candidate_timeout_s
|
||||
@ -140,137 +94,90 @@ class ModelRouter(LLMProvider):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolver_error(candidate: ModelCandidate, exc: Exception) -> LLMResponse:
|
||||
logger.warning("Failed to resolve fallback model {}: {}", candidate.label, exc)
|
||||
def _resolver_error(label: str, exc: Exception) -> LLMResponse:
|
||||
logger.warning("Failed to resolve fallback model {}: {}", label, exc)
|
||||
return LLMResponse(
|
||||
content=f"Error configuring fallback model {candidate.label}: {exc}",
|
||||
content=f"Error configuring fallback model {label}: {exc}",
|
||||
finish_reason="error",
|
||||
error_kind="configuration",
|
||||
error_should_retry=False,
|
||||
)
|
||||
|
||||
def _candidate_chain(self) -> list[ModelCandidate]:
|
||||
return [
|
||||
ModelCandidate(
|
||||
label=self.primary_model,
|
||||
resolver=lambda: (self.primary_provider, self.primary_model),
|
||||
),
|
||||
*self.fallback_candidates,
|
||||
]
|
||||
|
||||
async def _route(
|
||||
self,
|
||||
call: Callable[[LLMProvider, str, Callable[[str], Awaitable[None]] | None], Awaitable[LLMResponse]],
|
||||
*,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
last_response: LLMResponse | None = None
|
||||
chain = self._candidate_chain()
|
||||
for index, candidate in enumerate(chain):
|
||||
"""Try primary then each fallback candidate, lazily resolving providers."""
|
||||
|
||||
async def _try_one(label: str, provider: LLMProvider, model: str) -> LLMResponse:
|
||||
try:
|
||||
provider, model = candidate.resolver()
|
||||
return await self._with_timeout(call(provider, model, on_content_delta))
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
response = self._resolver_error(candidate, exc)
|
||||
else:
|
||||
response = await self._with_timeout(call(provider, model, on_content_delta))
|
||||
return self._resolver_error(label, exc)
|
||||
|
||||
# Primary
|
||||
response = await _try_one("primary", self.primary_provider, self.primary_model)
|
||||
if response.finish_reason != "error":
|
||||
return response
|
||||
if not self._should_failover(response):
|
||||
return response
|
||||
|
||||
# Fallbacks
|
||||
for name in self.fallback_presets:
|
||||
try:
|
||||
provider, model = self._resolve(name)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to resolve fallback model {}: {}", name, exc)
|
||||
return self._resolver_error(name, exc)
|
||||
|
||||
response = await _try_one(name, provider, model)
|
||||
if response.finish_reason != "error":
|
||||
if index > 0:
|
||||
logger.info("LLM failover selected model={}", candidate.label)
|
||||
logger.info("LLM failover selected model={}", name)
|
||||
return response
|
||||
|
||||
last_response = response
|
||||
if not self._should_failover(response):
|
||||
return response
|
||||
if index + 1 >= len(chain):
|
||||
logger.warning("LLM failover exhausted after model={}", candidate.label)
|
||||
return response
|
||||
logger.warning(
|
||||
"LLM failover model={} next_model={} status={} kind={}",
|
||||
candidate.label,
|
||||
chain[index + 1].label,
|
||||
response.error_status_code,
|
||||
response.error_kind or response.error_type or response.error_code or "unknown",
|
||||
)
|
||||
|
||||
return last_response or LLMResponse(
|
||||
content="No available fallback model candidate.",
|
||||
finish_reason="error",
|
||||
error_kind="configuration",
|
||||
error_should_retry=False,
|
||||
)
|
||||
logger.warning("LLM failover exhausted after all candidates")
|
||||
return response
|
||||
|
||||
async def chat_with_retry(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: object = LLMProvider._SENTINEL,
|
||||
temperature: object = LLMProvider._SENTINEL,
|
||||
reasoning_effort: object = LLMProvider._SENTINEL,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
retry_mode: str = "standard",
|
||||
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
async def chat_with_retry(self, **kwargs: Any) -> LLMResponse:
|
||||
async def call(
|
||||
provider: LLMProvider,
|
||||
candidate_model: str,
|
||||
_delta: Callable[[str], Awaitable[None]] | None,
|
||||
provider: LLMProvider, candidate_model: str, _unused_delta: Any
|
||||
) -> LLMResponse:
|
||||
return await provider.chat_with_retry(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=candidate_model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
reasoning_effort=reasoning_effort,
|
||||
tool_choice=tool_choice,
|
||||
retry_mode=retry_mode,
|
||||
on_retry_wait=on_retry_wait,
|
||||
**{**kwargs, "model": candidate_model}
|
||||
)
|
||||
|
||||
return await self._route(call)
|
||||
|
||||
async def chat_stream_with_retry(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: object = LLMProvider._SENTINEL,
|
||||
temperature: object = LLMProvider._SENTINEL,
|
||||
reasoning_effort: object = LLMProvider._SENTINEL,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
retry_mode: str = "standard",
|
||||
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
async def chat_stream_with_retry(self, **kwargs: Any) -> LLMResponse:
|
||||
on_content_delta = kwargs.pop("on_content_delta", None)
|
||||
|
||||
async def call(
|
||||
provider: LLMProvider,
|
||||
candidate_model: str,
|
||||
external_delta: Callable[[str], Awaitable[None]] | None,
|
||||
content_delta: Callable[[str], Awaitable[None]] | None,
|
||||
) -> LLMResponse:
|
||||
buffered: list[str] = []
|
||||
|
||||
async def buffer_delta(delta: str) -> None:
|
||||
buffered.append(delta)
|
||||
|
||||
kwargs["on_content_delta"] = buffer_delta if content_delta else None
|
||||
response = await provider.chat_stream_with_retry(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=candidate_model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
reasoning_effort=reasoning_effort,
|
||||
tool_choice=tool_choice,
|
||||
on_content_delta=buffer_delta if external_delta else None,
|
||||
retry_mode=retry_mode,
|
||||
on_retry_wait=on_retry_wait,
|
||||
**{**kwargs, "model": candidate_model}
|
||||
)
|
||||
if response.finish_reason != "error" and external_delta:
|
||||
for delta in buffered:
|
||||
await external_delta(delta)
|
||||
if response.finish_reason != "error" and content_delta:
|
||||
try:
|
||||
for delta in buffered:
|
||||
await content_delta(delta)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failover delta callback failed for model={}", candidate_model)
|
||||
return response
|
||||
|
||||
return await self._route(call, on_content_delta=on_content_delta)
|
||||
|
||||
@ -8,7 +8,6 @@ from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from nanobot.cli import onboard as onboard_wizard
|
||||
@ -636,8 +635,8 @@ class TestValidateFieldConstraint:
|
||||
|
||||
def test_real_send_max_retries_field(self):
|
||||
"""Validate against the actual ChannelsConfig.send_max_retries field."""
|
||||
from nanobot.config.schema import ChannelsConfig
|
||||
from nanobot.cli.onboard import _validate_field_constraint
|
||||
from nanobot.config.schema import ChannelsConfig
|
||||
|
||||
field_info = ChannelsConfig.model_fields["send_max_retries"]
|
||||
assert _validate_field_constraint(3, field_info) is None
|
||||
@ -829,12 +828,11 @@ class TestMainMenuUpdate:
|
||||
|
||||
def test_main_menu_dispatch_includes_channel_common(self):
|
||||
"""Main menu dispatch should route [H] to Channel Common."""
|
||||
from nanobot.cli.onboard import run_onboard
|
||||
|
||||
# We verify by checking the dispatch table is set up correctly
|
||||
# The menu items are defined inline in run_onboard, so we test
|
||||
# that _configure_general_settings handles the new sections.
|
||||
from nanobot.cli.onboard import _SETTINGS_SECTIONS, _SETTINGS_GETTER, _SETTINGS_SETTER
|
||||
from nanobot.cli.onboard import _SETTINGS_GETTER, _SETTINGS_SECTIONS, _SETTINGS_SETTER
|
||||
|
||||
assert "Channel Common" in _SETTINGS_SECTIONS
|
||||
assert "Channel Common" in _SETTINGS_GETTER
|
||||
@ -842,7 +840,7 @@ class TestMainMenuUpdate:
|
||||
|
||||
def test_main_menu_dispatch_includes_api_server(self):
|
||||
"""Main menu dispatch should route [I] to API Server."""
|
||||
from nanobot.cli.onboard import _SETTINGS_SECTIONS, _SETTINGS_GETTER, _SETTINGS_SETTER
|
||||
from nanobot.cli.onboard import _SETTINGS_GETTER, _SETTINGS_SECTIONS, _SETTINGS_SETTER
|
||||
|
||||
assert "API Server" in _SETTINGS_SECTIONS
|
||||
assert "API Server" in _SETTINGS_GETTER
|
||||
@ -1074,3 +1072,346 @@ class TestConfigurePydanticModelEmptyString:
|
||||
result = _configure_pydantic_model(model, "Test")
|
||||
assert result is not None
|
||||
assert result.api_key == ""
|
||||
|
||||
|
||||
class TestModelPresetWizard:
|
||||
"""Tests for model preset CRUD in the onboard wizard."""
|
||||
|
||||
def test_sync_preset_cache(self):
|
||||
"""_sync_preset_cache should populate the module-level cache."""
|
||||
from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _sync_preset_cache
|
||||
from nanobot.config.schema import ModelPresetConfig
|
||||
|
||||
config = Config()
|
||||
config.model_presets = {
|
||||
"fast": ModelPresetConfig(model="gpt-4.1-mini"),
|
||||
"power": ModelPresetConfig(model="gpt-4.1"),
|
||||
}
|
||||
_sync_preset_cache(config)
|
||||
assert _MODEL_PRESET_CACHE == {"fast", "power"}
|
||||
|
||||
def test_model_preset_add(self, monkeypatch):
|
||||
"""_configure_model_presets should add a new preset."""
|
||||
from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _configure_model_presets
|
||||
from nanobot.config.schema import ModelPresetConfig
|
||||
|
||||
config = Config()
|
||||
_MODEL_PRESET_CACHE.clear()
|
||||
|
||||
responses = iter([
|
||||
"[+] Add new preset",
|
||||
"my-preset",
|
||||
"<- Back",
|
||||
])
|
||||
|
||||
class FakePrompt:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
def ask(self):
|
||||
if isinstance(self.response, BaseException):
|
||||
raise self.response
|
||||
return self.response
|
||||
|
||||
def fake_select(*_args, **_kwargs):
|
||||
return FakePrompt(next(responses))
|
||||
|
||||
def fake_text(*_args, **_kwargs):
|
||||
return FakePrompt(next(responses))
|
||||
|
||||
def fake_configure(*_model, **_kwargs):
|
||||
return ModelPresetConfig(model="gpt-test", temperature=0.5)
|
||||
|
||||
# _select_with_back returns a string/sentinel directly (not a prompt object)
|
||||
def fake_select_with_back(*_args, **_kwargs):
|
||||
return next(responses)
|
||||
|
||||
monkeypatch.setattr(onboard_wizard, "_select_with_back", fake_select_with_back)
|
||||
monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select, text=fake_text))
|
||||
monkeypatch.setattr(onboard_wizard, "_configure_pydantic_model", fake_configure)
|
||||
monkeypatch.setattr(onboard_wizard, "_show_section_header", lambda *a, **kw: None)
|
||||
monkeypatch.setattr(onboard_wizard, "console", SimpleNamespace(clear=lambda: None))
|
||||
|
||||
_configure_model_presets(config)
|
||||
|
||||
assert "my-preset" in config.model_presets
|
||||
assert config.model_presets["my-preset"].model == "gpt-test"
|
||||
assert config.model_presets["my-preset"].temperature == 0.5
|
||||
|
||||
def test_model_preset_delete(self, monkeypatch):
|
||||
"""_configure_model_presets should delete an existing preset."""
|
||||
from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _configure_model_presets
|
||||
from nanobot.config.schema import ModelPresetConfig
|
||||
|
||||
config = Config()
|
||||
config.model_presets = {"old": ModelPresetConfig(model="x")}
|
||||
_MODEL_PRESET_CACHE.clear()
|
||||
_MODEL_PRESET_CACHE.add("old")
|
||||
|
||||
responses = iter([
|
||||
"old (x)",
|
||||
"Delete",
|
||||
True,
|
||||
"<- Back",
|
||||
])
|
||||
|
||||
class FakePrompt:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
def ask(self):
|
||||
if isinstance(self.response, BaseException):
|
||||
raise self.response
|
||||
return self.response
|
||||
|
||||
def fake_select(*_args, **_kwargs):
|
||||
return FakePrompt(next(responses))
|
||||
|
||||
def fake_confirm(*_args, **_kwargs):
|
||||
return FakePrompt(next(responses))
|
||||
|
||||
def fake_select_with_back(*_args, **_kwargs):
|
||||
return next(responses)
|
||||
|
||||
monkeypatch.setattr(onboard_wizard, "_select_with_back", fake_select_with_back)
|
||||
monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select, confirm=fake_confirm))
|
||||
monkeypatch.setattr(onboard_wizard, "_show_section_header", lambda *a, **kw: None)
|
||||
monkeypatch.setattr(onboard_wizard, "console", SimpleNamespace(clear=lambda: None))
|
||||
|
||||
_configure_model_presets(config)
|
||||
|
||||
assert "old" not in config.model_presets
|
||||
assert "old" not in _MODEL_PRESET_CACHE
|
||||
|
||||
def test_model_preset_field_handler(self, monkeypatch):
|
||||
"""_handle_model_preset_field should set a preset name from choices."""
|
||||
from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _handle_model_preset_field
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
|
||||
_MODEL_PRESET_CACHE.clear()
|
||||
_MODEL_PRESET_CACHE.update({"fast", "power"})
|
||||
|
||||
monkeypatch.setattr(onboard_wizard, "_select_with_back", lambda *a, **kw: "fast")
|
||||
|
||||
defaults = AgentDefaults()
|
||||
_handle_model_preset_field(defaults, "model_preset", "Model Preset", None)
|
||||
assert defaults.model_preset == "fast"
|
||||
|
||||
def test_model_preset_field_handler_clear(self, monkeypatch):
|
||||
"""_handle_model_preset_field should clear preset when (clear/unset) chosen."""
|
||||
from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _handle_model_preset_field
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
|
||||
_MODEL_PRESET_CACHE.clear()
|
||||
_MODEL_PRESET_CACHE.add("fast")
|
||||
|
||||
monkeypatch.setattr(onboard_wizard, "_select_with_back", lambda *a, **kw: "(clear/unset)")
|
||||
|
||||
defaults = AgentDefaults(model_preset="fast")
|
||||
_handle_model_preset_field(defaults, "model_preset", "Model Preset", "fast")
|
||||
assert defaults.model_preset is None
|
||||
|
||||
def test_main_menu_dispatch_includes_model_presets(self):
|
||||
"""run_onboard dispatch should route [M] to Model Presets."""
|
||||
from nanobot.cli.onboard import _configure_model_presets
|
||||
|
||||
# The function should be importable and callable
|
||||
assert callable(_configure_model_presets)
|
||||
|
||||
def test_run_onboard_model_presets_edit(self, monkeypatch):
|
||||
"""run_onboard should handle [M] Model Presets correctly."""
|
||||
initial_config = Config()
|
||||
|
||||
responses = iter([
|
||||
"[M] Model Presets",
|
||||
KeyboardInterrupt(),
|
||||
"[S] Save and Exit",
|
||||
])
|
||||
|
||||
class FakePrompt:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
def ask(self):
|
||||
if isinstance(self.response, BaseException):
|
||||
raise self.response
|
||||
return self.response
|
||||
|
||||
def fake_select(*_args, **_kwargs):
|
||||
return FakePrompt(next(responses))
|
||||
|
||||
preset_mutated = {"n": 0}
|
||||
|
||||
def fake_configure_model_presets(config):
|
||||
preset_mutated["n"] += 1
|
||||
# Mutate config so unsaved changes are detected
|
||||
from nanobot.config.schema import ModelPresetConfig
|
||||
config.model_presets["test"] = ModelPresetConfig(model="x")
|
||||
|
||||
monkeypatch.setattr(onboard_wizard, "_show_main_menu_header", lambda: None)
|
||||
monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select))
|
||||
monkeypatch.setattr(onboard_wizard, "_configure_model_presets", fake_configure_model_presets)
|
||||
|
||||
result = run_onboard(initial_config=initial_config)
|
||||
|
||||
assert result.should_save is True
|
||||
assert preset_mutated["n"] == 1
|
||||
|
||||
def test_summary_shows_model_presets(self, monkeypatch):
|
||||
"""_show_summary should include model presets panel."""
|
||||
from nanobot.cli.onboard import _show_summary
|
||||
from nanobot.config.schema import ModelPresetConfig
|
||||
|
||||
config = Config()
|
||||
config.model_presets = {
|
||||
"fast": ModelPresetConfig(model="gpt-4.1-mini"),
|
||||
}
|
||||
|
||||
panels = []
|
||||
|
||||
def fake_print_summary(rows, title):
|
||||
panels.append(title)
|
||||
|
||||
monkeypatch.setattr(onboard_wizard, "_print_summary_panel", fake_print_summary)
|
||||
monkeypatch.setattr(onboard_wizard, "_get_provider_names", lambda: {})
|
||||
monkeypatch.setattr(onboard_wizard, "_get_channel_names", lambda: {})
|
||||
monkeypatch.setattr(onboard_wizard, "_pause", lambda: None)
|
||||
monkeypatch.setattr(onboard_wizard, "console", SimpleNamespace(print=lambda *a, **kw: None))
|
||||
|
||||
_show_summary(config)
|
||||
|
||||
assert "Model Presets" in panels
|
||||
|
||||
def test_provider_field_handler(self, monkeypatch):
|
||||
"""_handle_provider_field should set a provider from the registry list."""
|
||||
from nanobot.cli.onboard import _handle_provider_field
|
||||
from nanobot.config.schema import ModelPresetConfig
|
||||
|
||||
monkeypatch.setattr(
|
||||
onboard_wizard, "_get_provider_names", lambda: {"moonshot": "Moonshot", "openai": "OpenAI"}
|
||||
)
|
||||
monkeypatch.setattr(onboard_wizard, "_select_with_back", lambda *a, **kw: "moonshot")
|
||||
|
||||
preset = ModelPresetConfig(model="x")
|
||||
_handle_provider_field(preset, "provider", "Provider", "auto")
|
||||
assert preset.provider == "moonshot"
|
||||
|
||||
def test_provider_field_handler_back_pressed(self, monkeypatch):
|
||||
"""_handle_provider_field should not modify value when back is pressed."""
|
||||
from nanobot.cli.onboard import _BACK_PRESSED, _handle_provider_field
|
||||
from nanobot.config.schema import ModelPresetConfig
|
||||
|
||||
monkeypatch.setattr(
|
||||
onboard_wizard, "_get_provider_names", lambda: {"moonshot": "Moonshot"}
|
||||
)
|
||||
monkeypatch.setattr(onboard_wizard, "_select_with_back", lambda *a, **kw: _BACK_PRESSED)
|
||||
|
||||
preset = ModelPresetConfig(model="x", provider="auto")
|
||||
_handle_provider_field(preset, "provider", "Provider", "auto")
|
||||
assert preset.provider == "auto"
|
||||
|
||||
def test_fallback_presets_add_preset_and_done(self, monkeypatch):
|
||||
"""_handle_fallback_presets_field should add a preset and save on Done."""
|
||||
from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _handle_fallback_presets_field
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
|
||||
_MODEL_PRESET_CACHE.clear()
|
||||
_MODEL_PRESET_CACHE.update({"fast", "power"})
|
||||
|
||||
responses = iter(["[+] Add preset", "[Done]"])
|
||||
|
||||
class FakePrompt:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
def ask(self):
|
||||
if isinstance(self.response, BaseException):
|
||||
raise self.response
|
||||
return self.response
|
||||
|
||||
def fake_select(*_args, **_kwargs):
|
||||
return FakePrompt(next(responses))
|
||||
|
||||
monkeypatch.setattr(onboard_wizard, "_select_with_back", lambda *a, **kw: "fast")
|
||||
monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select))
|
||||
monkeypatch.setattr(onboard_wizard, "console", SimpleNamespace(clear=lambda: None, print=lambda *a, **kw: None))
|
||||
|
||||
defaults = AgentDefaults()
|
||||
_handle_fallback_presets_field(defaults, "fallback_presets", "Fallback Presets", [])
|
||||
assert defaults.fallback_presets == ["fast"]
|
||||
|
||||
def test_fallback_presets_back_preserves_existing(self, monkeypatch):
|
||||
"""_handle_fallback_presets_field should not modify value on Back."""
|
||||
from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _handle_fallback_presets_field
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
|
||||
_MODEL_PRESET_CACHE.clear()
|
||||
_MODEL_PRESET_CACHE.add("fast")
|
||||
|
||||
class FakePrompt:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
def ask(self):
|
||||
if isinstance(self.response, BaseException):
|
||||
raise self.response
|
||||
return self.response
|
||||
|
||||
def fake_select(*_args, **_kwargs):
|
||||
return FakePrompt("<- Back")
|
||||
|
||||
monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select))
|
||||
monkeypatch.setattr(onboard_wizard, "console", SimpleNamespace(clear=lambda: None, print=lambda *a, **kw: None))
|
||||
|
||||
defaults = AgentDefaults(fallback_presets=["existing"])
|
||||
_handle_fallback_presets_field(defaults, "fallback_presets", "Fallback Presets", ["existing"])
|
||||
assert defaults.fallback_presets == ["existing"]
|
||||
|
||||
def test_fallback_presets_remove_last(self, monkeypatch):
|
||||
"""_handle_fallback_presets_field should remove last item."""
|
||||
from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _handle_fallback_presets_field
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
|
||||
_MODEL_PRESET_CACHE.clear()
|
||||
|
||||
responses = iter(["[-] Remove last", "[Done]"])
|
||||
|
||||
class FakePrompt:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
def ask(self):
|
||||
if isinstance(self.response, BaseException):
|
||||
raise self.response
|
||||
return self.response
|
||||
|
||||
def fake_select(*_args, **_kwargs):
|
||||
return FakePrompt(next(responses))
|
||||
|
||||
monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select))
|
||||
monkeypatch.setattr(onboard_wizard, "console", SimpleNamespace(clear=lambda: None, print=lambda *a, **kw: None))
|
||||
|
||||
defaults = AgentDefaults(fallback_presets=["a", "b"])
|
||||
_handle_fallback_presets_field(defaults, "fallback_presets", "Fallback Presets", ["a", "b"])
|
||||
assert defaults.fallback_presets == ["a"]
|
||||
|
||||
def test_fallback_presets_no_presets_shows_warning(self, monkeypatch):
|
||||
"""_handle_fallback_presets_field should warn when no presets exist."""
|
||||
from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _handle_fallback_presets_field
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
|
||||
_MODEL_PRESET_CACHE.clear()
|
||||
|
||||
responses = iter(["[+] Add preset", "[Done]"])
|
||||
|
||||
class FakePrompt:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
def ask(self):
|
||||
if isinstance(self.response, BaseException):
|
||||
raise self.response
|
||||
return self.response
|
||||
|
||||
def fake_select(*_args, **_kwargs):
|
||||
return FakePrompt(next(responses))
|
||||
|
||||
monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select, press_any_key_to_continue=lambda: FakePrompt(None)))
|
||||
monkeypatch.setattr(onboard_wizard, "console", SimpleNamespace(clear=lambda: None, print=lambda *a, **kw: None))
|
||||
|
||||
defaults = AgentDefaults()
|
||||
_handle_fallback_presets_field(defaults, "fallback_presets", "Fallback Presets", [])
|
||||
assert defaults.fallback_presets == []
|
||||
|
||||
@ -1,269 +0,0 @@
|
||||
"""Tests for the provider fallback models feature in AgentRunner."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
|
||||
def _make_tools():
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
tools.execute = AsyncMock(return_value="ok")
|
||||
return tools
|
||||
|
||||
|
||||
def _make_provider(*, model_response: LLMResponse | None = None):
|
||||
p = MagicMock()
|
||||
if model_response is not None:
|
||||
p.chat_with_retry = AsyncMock(return_value=model_response)
|
||||
return p
|
||||
|
||||
|
||||
def _transient_error(content: str = "server unavailable") -> LLMResponse:
|
||||
return LLMResponse(content=content, finish_reason="error", error_status_code=503)
|
||||
|
||||
|
||||
def _base_spec(**overrides) -> AgentRunSpec:
|
||||
defaults = dict(
|
||||
initial_messages=[
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": "hello"},
|
||||
],
|
||||
tools=_make_tools(),
|
||||
model="primary-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=8000,
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return AgentRunSpec(**defaults)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_fallback_when_primary_succeeds():
|
||||
"""Primary succeeds -> fallback list never consulted."""
|
||||
ok = LLMResponse(content="done", tool_calls=[], usage={})
|
||||
provider = _make_provider(model_response=ok)
|
||||
factory = MagicMock()
|
||||
|
||||
runner = AgentRunner(provider, provider_factory=factory)
|
||||
result = await runner.run(_base_spec(fallback_models=["fb-1", "fb-2"]))
|
||||
|
||||
assert result.final_content == "done"
|
||||
factory.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_triggered_on_primary_error():
|
||||
"""Primary fails -> first fallback succeeds."""
|
||||
err = _transient_error()
|
||||
ok = LLMResponse(content="fallback-ok", tool_calls=[], usage={})
|
||||
|
||||
primary = _make_provider(model_response=err)
|
||||
|
||||
fb_provider = MagicMock()
|
||||
fb_provider.chat_with_retry = AsyncMock(return_value=ok)
|
||||
factory = MagicMock(return_value=fb_provider)
|
||||
|
||||
runner = AgentRunner(primary, provider_factory=factory)
|
||||
result = await runner.run(_base_spec(fallback_models=["fb-model"]))
|
||||
|
||||
assert result.final_content == "fallback-ok"
|
||||
factory.assert_called_once_with("fb-model")
|
||||
fb_provider.chat_with_retry.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_fallbacks_fail_returns_last_error():
|
||||
"""Primary + all fallbacks fail -> return last error response."""
|
||||
err = _transient_error()
|
||||
|
||||
primary = _make_provider(model_response=err)
|
||||
fb1 = _make_provider(model_response=err)
|
||||
fb2 = _make_provider(model_response=LLMResponse(
|
||||
content="last-error", finish_reason="error", error_status_code=500, usage={},
|
||||
))
|
||||
|
||||
providers = {"fb-1": fb1, "fb-2": fb2}
|
||||
factory = MagicMock(side_effect=lambda m: providers[m])
|
||||
|
||||
runner = AgentRunner(primary, provider_factory=factory)
|
||||
result = await runner.run(_base_spec(fallback_models=["fb-1", "fb-2"]))
|
||||
|
||||
assert result.error is not None or result.final_content is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_fallback_list_no_retry():
|
||||
"""Empty fallback_models -> no fallback attempted."""
|
||||
err = LLMResponse(content=None, finish_reason="error", usage={})
|
||||
primary = _make_provider(model_response=err)
|
||||
factory = MagicMock()
|
||||
|
||||
runner = AgentRunner(primary, provider_factory=factory)
|
||||
result = await runner.run(_base_spec(fallback_models=[]))
|
||||
|
||||
factory.assert_not_called()
|
||||
assert result.error is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cross_provider_fallback():
|
||||
"""Fallback uses a different provider instance (cross-provider)."""
|
||||
err = _transient_error()
|
||||
ok = LLMResponse(content="cross-provider-ok", tool_calls=[], usage={})
|
||||
|
||||
primary = _make_provider(model_response=err)
|
||||
anthropic_provider = MagicMock()
|
||||
anthropic_provider.chat_with_retry = AsyncMock(return_value=ok)
|
||||
|
||||
def cross_factory(model: str):
|
||||
if model == "anthropic/claude-sonnet":
|
||||
return anthropic_provider
|
||||
raise ValueError(f"unexpected model: {model}")
|
||||
|
||||
runner = AgentRunner(primary, provider_factory=cross_factory)
|
||||
result = await runner.run(_base_spec(fallback_models=["anthropic/claude-sonnet"]))
|
||||
|
||||
assert result.final_content == "cross-provider-ok"
|
||||
anthropic_provider.chat_with_retry.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_skips_to_second_on_first_error():
|
||||
"""First fallback also fails -> second fallback succeeds."""
|
||||
err = _transient_error()
|
||||
ok = LLMResponse(content="second-fb-ok", tool_calls=[], usage={})
|
||||
|
||||
primary = _make_provider(model_response=err)
|
||||
fb1 = _make_provider(model_response=err)
|
||||
fb2 = MagicMock()
|
||||
fb2.chat_with_retry = AsyncMock(return_value=ok)
|
||||
|
||||
providers = {"fb-1": fb1, "fb-2": fb2}
|
||||
factory = MagicMock(side_effect=lambda m: providers[m])
|
||||
|
||||
runner = AgentRunner(primary, provider_factory=factory)
|
||||
result = await runner.run(_base_spec(fallback_models=["fb-1", "fb-2"]))
|
||||
|
||||
assert result.final_content == "second-fb-ok"
|
||||
assert factory.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_reuses_same_provider_without_factory():
|
||||
"""No provider_factory -> fallback reuses primary provider with different model."""
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def chat_with_retry(*, messages, model, **kw):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
return _transient_error()
|
||||
return LLMResponse(content=f"ok-via-{model}", tool_calls=[], usage={})
|
||||
|
||||
primary = MagicMock()
|
||||
primary.chat_with_retry = chat_with_retry
|
||||
|
||||
runner = AgentRunner(primary, provider_factory=None)
|
||||
result = await runner.run(_base_spec(fallback_models=["fallback-model"]))
|
||||
|
||||
assert result.final_content == "ok-via-fallback-model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_provider_cached():
|
||||
"""Provider factory is called once per unique provider, not per attempt."""
|
||||
err = _transient_error()
|
||||
ok = LLMResponse(content="cached-ok", tool_calls=[], usage={})
|
||||
|
||||
primary = _make_provider(model_response=err)
|
||||
|
||||
fb_provider = MagicMock()
|
||||
call_seq = [err, ok]
|
||||
fb_provider.chat_with_retry = AsyncMock(side_effect=call_seq)
|
||||
|
||||
factory = MagicMock(return_value=fb_provider)
|
||||
|
||||
runner = AgentRunner(primary, provider_factory=factory)
|
||||
result = await runner.run(_base_spec(fallback_models=["same-provider-model-a", "same-provider-model-b"]))
|
||||
|
||||
assert result.final_content == "cached-ok"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_transient_error_does_not_fallback():
|
||||
"""Auth/config-style errors should surface instead of hiding bugs via fallback."""
|
||||
primary = _make_provider(model_response=LLMResponse(
|
||||
content="401 unauthorized",
|
||||
finish_reason="error",
|
||||
error_status_code=401,
|
||||
))
|
||||
fallback = _make_provider(model_response=LLMResponse(content="fallback-ok"))
|
||||
factory = MagicMock(return_value=fallback)
|
||||
|
||||
runner = AgentRunner(primary, provider_factory=factory)
|
||||
result = await runner.run(_base_spec(fallback_models=["fb-model"]))
|
||||
|
||||
factory.assert_not_called()
|
||||
assert result.error is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_quota_error_does_not_fallback_by_default():
|
||||
"""Quota/billing/payment 429s should not route by default."""
|
||||
primary = _make_provider(model_response=LLMResponse(
|
||||
content="insufficient quota",
|
||||
finish_reason="error",
|
||||
error_status_code=429,
|
||||
error_code="insufficient_quota",
|
||||
))
|
||||
fallback = _make_provider(model_response=LLMResponse(content="fallback-ok"))
|
||||
factory = MagicMock(return_value=fallback)
|
||||
|
||||
runner = AgentRunner(primary, provider_factory=factory)
|
||||
result = await runner.run(_base_spec(fallback_models=["fb-model"]))
|
||||
|
||||
factory.assert_not_called()
|
||||
assert result.error is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_fallback_discards_failed_primary_deltas():
|
||||
"""Buffered streaming prevents primary partial output from leaking on fallback."""
|
||||
streamed: list[str] = []
|
||||
|
||||
class StreamingHook(AgentHook):
|
||||
def wants_streaming(self) -> bool:
|
||||
return True
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
streamed.append(delta)
|
||||
|
||||
async def primary_stream(*, on_content_delta, **kwargs):
|
||||
await on_content_delta("bad partial")
|
||||
return _transient_error()
|
||||
|
||||
async def fallback_stream(*, on_content_delta, **kwargs):
|
||||
await on_content_delta("good")
|
||||
await on_content_delta(" answer")
|
||||
return LLMResponse(content="good answer", tool_calls=[], usage={})
|
||||
|
||||
primary = MagicMock()
|
||||
primary.chat_stream_with_retry = primary_stream
|
||||
fallback = MagicMock()
|
||||
fallback.chat_stream_with_retry = fallback_stream
|
||||
factory = MagicMock(return_value=fallback)
|
||||
|
||||
runner = AgentRunner(primary, provider_factory=factory)
|
||||
result = await runner.run(_base_spec(
|
||||
fallback_models=["fb-model"],
|
||||
hook=StreamingHook(),
|
||||
))
|
||||
|
||||
assert result.final_content == "good answer"
|
||||
assert streamed == ["good", " answer"]
|
||||
@ -1,6 +1,6 @@
|
||||
# tests/agent/test_self_model_preset.py
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
@ -8,10 +8,23 @@ from nanobot.config.schema import ModelPresetConfig, MyToolConfig, ToolsConfig
|
||||
from nanobot.providers.base import GenerationSettings
|
||||
|
||||
|
||||
def _make_loop(presets: dict | None = None) -> tuple[AgentLoop, "MyTool"]:
|
||||
def _make_loop(presets: dict | None = None) -> tuple[AgentLoop, Any]:
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.generation = GenerationSettings(temperature=0.1, max_tokens=8192)
|
||||
|
||||
def _factory(name: str):
|
||||
preset = (presets or {}).get(name)
|
||||
if preset:
|
||||
new_provider = MagicMock()
|
||||
new_provider.generation = GenerationSettings(
|
||||
temperature=preset.temperature,
|
||||
max_tokens=preset.max_tokens,
|
||||
reasoning_effort=preset.reasoning_effort,
|
||||
)
|
||||
return new_provider
|
||||
return provider
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=MagicMock(),
|
||||
provider=provider,
|
||||
@ -19,6 +32,7 @@ def _make_loop(presets: dict | None = None) -> tuple[AgentLoop, "MyTool"]:
|
||||
model="test-model",
|
||||
context_window_tokens=65536,
|
||||
model_presets=presets or {},
|
||||
provider_factory=_factory,
|
||||
tools_config=ToolsConfig(my=MyToolConfig(allow_set=True)),
|
||||
)
|
||||
tool = loop.tools.get("my")
|
||||
@ -36,7 +50,7 @@ async def test_set_model_preset_updates_all_fields() -> None:
|
||||
),
|
||||
}
|
||||
loop, tool = _make_loop(presets)
|
||||
result = await tool.execute(action="set", key="model_preset", value="gpt5")
|
||||
await tool.execute(action="set", key="model_preset", value="gpt5")
|
||||
|
||||
assert loop.model == "gpt-5"
|
||||
assert loop.context_window_tokens == 128000
|
||||
@ -73,12 +87,3 @@ async def test_check_model_presets_shows_available() -> None:
|
||||
assert "ds" in result
|
||||
|
||||
|
||||
async def test_set_model_directly_clears_preset() -> None:
|
||||
presets = {"gpt5": ModelPresetConfig(model="gpt-5", provider="openai")}
|
||||
loop, tool = _make_loop(presets)
|
||||
await tool.execute(action="set", key="model_preset", value="gpt5")
|
||||
assert loop._active_preset == "gpt5"
|
||||
|
||||
await tool.execute(action="set", key="model", value="other-model")
|
||||
assert loop._active_preset is None
|
||||
assert loop.model == "other-model"
|
||||
|
||||
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
@ -35,6 +35,7 @@ def _make_mock_loop(**overrides):
|
||||
loop._concurrency_gate = None
|
||||
loop._unified_session = False
|
||||
loop._extra_hooks = []
|
||||
loop.model_preset = None
|
||||
|
||||
# web_config mock — needed for check tests
|
||||
loop.web_config = MagicMock()
|
||||
@ -76,7 +77,7 @@ class TestInspectSummary:
|
||||
tool = _make_tool()
|
||||
result = await tool.execute(action="check")
|
||||
assert "max_iterations: 40" in result
|
||||
assert "context_window_tokens: 65536" in result
|
||||
assert "model_preset" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inspect_includes_runtime_vars(self):
|
||||
@ -92,8 +93,7 @@ class TestInspectSummary:
|
||||
tool = _make_tool()
|
||||
result = await tool.execute(action="check")
|
||||
assert "max_iterations" in result
|
||||
assert "context_window_tokens" in result
|
||||
assert "model" in result
|
||||
assert "model_preset" in result
|
||||
assert "workspace" in result
|
||||
assert "provider_retry_mode" in result
|
||||
assert "max_tool_result_chars" in result
|
||||
@ -231,13 +231,13 @@ class TestModifyRestricted:
|
||||
@pytest.mark.asyncio
|
||||
async def test_modify_string_int_coerced(self):
|
||||
tool = _make_tool()
|
||||
result = await tool.execute(action="set", key="max_iterations", value="80")
|
||||
await tool.execute(action="set", key="max_iterations", value="80")
|
||||
assert tool._loop.max_iterations == 80
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_modify_context_window_valid(self):
|
||||
tool = _make_tool()
|
||||
result = await tool.execute(action="set", key="context_window_tokens", value=131072)
|
||||
await tool.execute(action="set", key="context_window_tokens", value=131072)
|
||||
assert tool._loop.context_window_tokens == 131072
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -337,13 +337,13 @@ class TestModifyFree:
|
||||
@pytest.mark.asyncio
|
||||
async def test_modify_allows_list(self):
|
||||
tool = _make_tool()
|
||||
result = await tool.execute(action="set", key="items", value=[1, 2, 3])
|
||||
await tool.execute(action="set", key="items", value=[1, 2, 3])
|
||||
assert tool._loop._runtime_vars["items"] == [1, 2, 3]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_modify_allows_dict(self):
|
||||
tool = _make_tool()
|
||||
result = await tool.execute(action="set", key="data", value={"a": 1})
|
||||
await tool.execute(action="set", key="data", value={"a": 1})
|
||||
assert tool._loop._runtime_vars["data"] == {"a": 1}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -392,6 +392,26 @@ class TestModifyFree:
|
||||
assert "Error" in result
|
||||
assert tool._loop.max_tool_result_chars == 16000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_modify_model_clears_active_preset(self):
|
||||
"""Directly modifying model must clear _active_preset so state stays consistent."""
|
||||
tool = _make_tool()
|
||||
tool._loop._active_preset = "gpt5"
|
||||
result = await tool.execute(action="set", key="model", value="other-model")
|
||||
assert "Set model" in result
|
||||
assert tool._loop.model == "other-model"
|
||||
assert tool._loop._active_preset is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_modify_context_window_tokens_clears_active_preset(self):
|
||||
"""Directly modifying context_window_tokens must clear _active_preset."""
|
||||
tool = _make_tool()
|
||||
tool._loop._active_preset = "gpt5"
|
||||
result = await tool.execute(action="set", key="context_window_tokens", value=32768)
|
||||
assert "Set context_window_tokens" in result
|
||||
assert tool._loop.context_window_tokens == 32768
|
||||
assert tool._loop._active_preset is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# set — previously BLOCKED/READONLY now open
|
||||
@ -689,8 +709,8 @@ class TestSubagentHookStatus:
|
||||
@pytest.mark.asyncio
|
||||
async def test_after_iteration_updates_status(self):
|
||||
"""after_iteration should copy iteration, tool_events, usage to status."""
|
||||
from nanobot.agent.subagent import SubagentStatus, _SubagentHook
|
||||
from nanobot.agent.hook import AgentHookContext
|
||||
from nanobot.agent.subagent import SubagentStatus, _SubagentHook
|
||||
|
||||
status = SubagentStatus(
|
||||
task_id="test",
|
||||
@ -716,8 +736,8 @@ class TestSubagentHookStatus:
|
||||
@pytest.mark.asyncio
|
||||
async def test_after_iteration_with_error(self):
|
||||
"""after_iteration should set status.error when context has an error."""
|
||||
from nanobot.agent.subagent import SubagentStatus, _SubagentHook
|
||||
from nanobot.agent.hook import AgentHookContext
|
||||
from nanobot.agent.subagent import SubagentStatus, _SubagentHook
|
||||
|
||||
status = SubagentStatus(
|
||||
task_id="test",
|
||||
@ -739,8 +759,8 @@ class TestSubagentHookStatus:
|
||||
@pytest.mark.asyncio
|
||||
async def test_after_iteration_no_status_is_noop(self):
|
||||
"""after_iteration with no status should be a no-op."""
|
||||
from nanobot.agent.subagent import _SubagentHook
|
||||
from nanobot.agent.hook import AgentHookContext
|
||||
from nanobot.agent.subagent import _SubagentHook
|
||||
|
||||
hook = _SubagentHook("test")
|
||||
context = AgentHookContext(iteration=1, messages=[])
|
||||
@ -757,7 +777,6 @@ class TestCheckpointCallback:
|
||||
async def test_checkpoint_updates_phase_and_iteration(self):
|
||||
"""The _on_checkpoint callback should update status.phase and iteration."""
|
||||
from nanobot.agent.subagent import SubagentStatus
|
||||
import asyncio
|
||||
|
||||
status = SubagentStatus(
|
||||
task_id="cp",
|
||||
|
||||
@ -9,7 +9,7 @@ import pytest
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.cli.commands import _make_provider, app
|
||||
from nanobot.cli.commands import app
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.cron.types import CronJob, CronPayload
|
||||
from nanobot.providers.factory import ProviderSnapshot
|
||||
@ -488,8 +488,8 @@ def test_openai_compat_provider_passes_model_through():
|
||||
|
||||
|
||||
def test_make_provider_uses_github_copilot_backend():
|
||||
from nanobot.cli.commands import _make_provider
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.providers.factory import build_provider_for_preset
|
||||
|
||||
config = Config.model_validate(
|
||||
{
|
||||
@ -503,7 +503,7 @@ def test_make_provider_uses_github_copilot_backend():
|
||||
)
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = _make_provider(config)
|
||||
provider = build_provider_for_preset(config, config.resolve_preset())
|
||||
|
||||
assert provider.__class__.__name__ == "GitHubCopilotProvider"
|
||||
|
||||
@ -562,6 +562,8 @@ def test_openai_codex_strip_prefix_supports_hyphen_and_underscore():
|
||||
|
||||
|
||||
def test_make_provider_passes_extra_headers_to_custom_provider():
|
||||
from nanobot.providers.factory import build_provider_for_preset
|
||||
|
||||
config = Config.model_validate(
|
||||
{
|
||||
"agents": {"defaults": {"provider": "custom", "model": "gpt-4o-mini"}},
|
||||
@ -579,7 +581,7 @@ def test_make_provider_passes_extra_headers_to_custom_provider():
|
||||
)
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_async_openai:
|
||||
_make_provider(config)
|
||||
build_provider_for_preset(config, config.resolve_preset())
|
||||
|
||||
kwargs = mock_async_openai.call_args.kwargs
|
||||
assert kwargs["api_key"] == "test-key"
|
||||
@ -597,11 +599,11 @@ def mock_agent_runtime(tmp_path):
|
||||
with patch("nanobot.config.loader.load_config", return_value=config) as mock_load_config, \
|
||||
patch("nanobot.config.loader.resolve_config_env_vars", side_effect=lambda c: c), \
|
||||
patch("nanobot.cli.commands.sync_workspace_templates") as mock_sync_templates, \
|
||||
patch("nanobot.cli.commands._make_provider", return_value=object()), \
|
||||
patch("nanobot.providers.factory.build_provider_for_preset", return_value=MagicMock(generation=MagicMock(max_tokens=8192))), \
|
||||
patch("nanobot.cli.commands._print_agent_response") as mock_print_response, \
|
||||
patch("nanobot.bus.queue.MessageBus"), \
|
||||
patch("nanobot.cron.service.CronService"), \
|
||||
patch("nanobot.agent.loop.AgentLoop") as mock_agent_loop_cls:
|
||||
patch("nanobot.cli.commands.AgentLoop") as mock_agent_loop_cls:
|
||||
agent_loop = MagicMock()
|
||||
agent_loop.channels_config = None
|
||||
agent_loop.process_direct = AsyncMock(
|
||||
@ -609,6 +611,7 @@ def mock_agent_runtime(tmp_path):
|
||||
)
|
||||
agent_loop.close_mcp = AsyncMock(return_value=None)
|
||||
mock_agent_loop_cls.return_value = agent_loop
|
||||
mock_agent_loop_cls.from_config.return_value = agent_loop
|
||||
|
||||
yield {
|
||||
"config": config,
|
||||
@ -639,7 +642,7 @@ def test_agent_uses_default_config_when_no_workspace_or_config_flags(mock_agent_
|
||||
assert mock_agent_runtime["sync_templates"].call_args.args == (
|
||||
mock_agent_runtime["config"].workspace_path,
|
||||
)
|
||||
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == (
|
||||
assert mock_agent_runtime["agent_loop_cls"].from_config.call_args.args[0].workspace_path == (
|
||||
mock_agent_runtime["config"].workspace_path
|
||||
)
|
||||
mock_agent_runtime["agent_loop"].process_direct.assert_awaited_once()
|
||||
@ -672,7 +675,7 @@ def test_agent_config_sets_active_path(monkeypatch, tmp_path: Path) -> None:
|
||||
)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.providers.factory.build_provider_for_preset", lambda *a, **k: MagicMock(generation=MagicMock(max_tokens=8192)))
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", lambda _store: object())
|
||||
|
||||
@ -680,13 +683,17 @@ def test_agent_config_sets_active_path(monkeypatch, tmp_path: Path) -> None:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
async def process_direct(self, *_args, **_kwargs):
|
||||
return OutboundMessage(channel="cli", chat_id="direct", content="ok")
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
|
||||
|
||||
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])
|
||||
@ -707,7 +714,7 @@ def test_agent_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: Pa
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.providers.factory.build_provider_for_preset", lambda *a, **k: MagicMock(generation=MagicMock(max_tokens=8192)))
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
|
||||
class _FakeCron:
|
||||
@ -718,6 +725,10 @@ def test_agent_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: Pa
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
async def process_direct(self, *_args, **_kwargs):
|
||||
return OutboundMessage(channel="cli", chat_id="direct", content="ok")
|
||||
|
||||
@ -725,7 +736,7 @@ def test_agent_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: Pa
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
|
||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
|
||||
|
||||
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])
|
||||
@ -753,7 +764,7 @@ def test_agent_workspace_override_does_not_migrate_legacy_cron(
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.providers.factory.build_provider_for_preset", lambda *a, **k: MagicMock(generation=MagicMock(max_tokens=8192)))
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
|
||||
|
||||
@ -765,6 +776,10 @@ def test_agent_workspace_override_does_not_migrate_legacy_cron(
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
async def process_direct(self, *_args, **_kwargs):
|
||||
return OutboundMessage(channel="cli", chat_id="direct", content="ok")
|
||||
|
||||
@ -772,7 +787,7 @@ def test_agent_workspace_override_does_not_migrate_legacy_cron(
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
|
||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
|
||||
|
||||
result = runner.invoke(
|
||||
@ -806,7 +821,7 @@ def test_agent_custom_config_workspace_does_not_migrate_legacy_cron(
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.providers.factory.build_provider_for_preset", lambda *a, **k: MagicMock(generation=MagicMock(max_tokens=8192)))
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
|
||||
|
||||
@ -818,6 +833,10 @@ def test_agent_custom_config_workspace_does_not_migrate_legacy_cron(
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
async def process_direct(self, *_args, **_kwargs):
|
||||
return OutboundMessage(channel="cli", chat_id="direct", content="ok")
|
||||
|
||||
@ -825,7 +844,7 @@ def test_agent_custom_config_workspace_does_not_migrate_legacy_cron(
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
|
||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None
|
||||
)
|
||||
@ -846,7 +865,7 @@ def test_agent_overrides_workspace_path(mock_agent_runtime):
|
||||
assert result.exit_code == 0
|
||||
assert mock_agent_runtime["config"].agents.defaults.workspace == str(workspace_path)
|
||||
assert mock_agent_runtime["sync_templates"].call_args.args == (workspace_path,)
|
||||
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
|
||||
assert mock_agent_runtime["agent_loop_cls"].from_config.call_args.args[0].workspace_path == workspace_path
|
||||
|
||||
|
||||
def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime, tmp_path: Path):
|
||||
@ -863,7 +882,7 @@ def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime,
|
||||
assert mock_agent_runtime["load_config"].call_args.args == (config_path.resolve(),)
|
||||
assert mock_agent_runtime["config"].agents.defaults.workspace == str(workspace_path)
|
||||
assert mock_agent_runtime["sync_templates"].call_args.args == (workspace_path,)
|
||||
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
|
||||
assert mock_agent_runtime["agent_loop_cls"].from_config.call_args.args[0].workspace_path == workspace_path
|
||||
|
||||
|
||||
def test_agent_hints_about_deprecated_memory_window(mock_agent_runtime, tmp_path):
|
||||
@ -928,8 +947,8 @@ def _patch_cli_command_runtime(
|
||||
sync_templates or (lambda _path: None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
provider_factory,
|
||||
"nanobot.providers.factory.build_provider_for_preset",
|
||||
lambda *_a, **_k: provider_factory(Config()),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.providers.factory.build_provider_snapshot",
|
||||
@ -962,6 +981,10 @@ def _patch_serve_runtime(monkeypatch, config: Config, seen: dict[str, object]) -
|
||||
def __init__(self, **kwargs) -> None:
|
||||
seen["workspace"] = kwargs["workspace"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config, bus=None, **kwargs):
|
||||
return cls(workspace=config.workspace_path, **kwargs)
|
||||
|
||||
async def _connect_mcp(self) -> None:
|
||||
return None
|
||||
|
||||
@ -985,7 +1008,7 @@ def _patch_serve_runtime(monkeypatch, config: Config, seen: dict[str, object]) -
|
||||
message_bus=lambda: object(),
|
||||
session_manager=lambda _workspace: object(),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.api.server.create_app", _fake_create_app)
|
||||
monkeypatch.setattr("aiohttp.web.run_app", _fake_run_app)
|
||||
|
||||
@ -1077,7 +1100,7 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: provider)
|
||||
monkeypatch.setattr("nanobot.providers.factory.build_provider_for_preset", lambda *_a, **_k: provider)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.providers.factory.build_provider_snapshot",
|
||||
lambda _config: _test_provider_snapshot(provider, _config),
|
||||
@ -1117,8 +1140,13 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
|
||||
class _FakeAgentLoop:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
self.model = "test-model"
|
||||
self.provider = object()
|
||||
self.tools = {}
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
async def process_direct(self, *_args, **_kwargs):
|
||||
return OutboundMessage(
|
||||
channel="telegram",
|
||||
@ -1152,7 +1180,7 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
|
||||
return True
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
|
||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _StopAfterCronSetup)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.utils.evaluator.evaluate_response",
|
||||
@ -1181,7 +1209,7 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
|
||||
|
||||
assert response == "Time to stretch."
|
||||
assert seen["response"] == "Time to stretch."
|
||||
assert seen["provider"] is provider
|
||||
assert seen["provider"] is not None # provider resolved inside AgentLoop
|
||||
assert seen["model"] == "test-model"
|
||||
assert seen["task_context"] == (
|
||||
"The scheduled time has arrived. Deliver this reminder to the user now, "
|
||||
@ -1228,7 +1256,7 @@ def test_gateway_cron_job_suppresses_intermediate_progress(
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.providers.factory.build_provider_for_preset", lambda *a, **k: MagicMock(generation=MagicMock(max_tokens=8192)))
|
||||
monkeypatch.setattr(
|
||||
"nanobot.providers.factory.build_provider_snapshot",
|
||||
lambda _config: _test_provider_snapshot(object(), _config),
|
||||
@ -1248,8 +1276,13 @@ def test_gateway_cron_job_suppresses_intermediate_progress(
|
||||
class _FakeAgentLoop:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
self.model = "test-model"
|
||||
self.provider = object()
|
||||
self.tools = {}
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
async def process_direct(self, *_args, on_progress=None, **_kwargs):
|
||||
seen["on_progress"] = on_progress
|
||||
return OutboundMessage(
|
||||
@ -1275,7 +1308,7 @@ def test_gateway_cron_job_suppresses_intermediate_progress(
|
||||
return False
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
|
||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _StopAfterCronSetup)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.utils.evaluator.evaluate_response",
|
||||
@ -1480,9 +1513,14 @@ def test_gateway_health_endpoint_binds_and_serves_expected_responses(
|
||||
class _FakeAgentLoop:
|
||||
def __init__(self, **_kwargs) -> None:
|
||||
self.model = "test-model"
|
||||
self.provider = object()
|
||||
self.dream = _FakeDream()
|
||||
self.sessions = _FakeSessionManager()
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
return cls(**kwargs)
|
||||
|
||||
async def run(self) -> None:
|
||||
await asyncio.Event().wait()
|
||||
|
||||
@ -1571,7 +1609,7 @@ def test_gateway_health_endpoint_binds_and_serves_expected_responses(
|
||||
message_bus=lambda: object(),
|
||||
session_manager=lambda _workspace: object(),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _FakeChannelManager)
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCronService)
|
||||
monkeypatch.setattr("nanobot.heartbeat.service.HeartbeatService", _FakeHeartbeatService)
|
||||
|
||||
@ -99,11 +99,23 @@ def test_preset_not_found_raises_error() -> None:
|
||||
})
|
||||
|
||||
|
||||
def test_fallback_presets_invalid_preset_raises_error() -> None:
|
||||
import pytest
|
||||
with pytest.raises(Exception, match="fallback_presets.*not found"):
|
||||
Config.model_validate({
|
||||
"model_presets": {
|
||||
"valid": {"model": "gpt-4"},
|
||||
},
|
||||
"agents": {"defaults": {"fallback_presets": ["invalid_preset"]}},
|
||||
})
|
||||
|
||||
|
||||
def test_resolve_preset_without_preset_returns_defaults() -> None:
|
||||
"""Backward compat: no preset → resolve_preset returns individual field values."""
|
||||
"""Backward compat: no explicit preset → resolve_preset returns the auto-created 'default' preset."""
|
||||
cfg = Config.model_validate({
|
||||
"agents": {"defaults": {"model": "deepseek-chat"}},
|
||||
})
|
||||
assert cfg.agents.defaults.model_preset == "default"
|
||||
r = cfg.resolve_preset()
|
||||
assert r.model == "deepseek-chat"
|
||||
assert r.max_tokens == 8192
|
||||
@ -170,13 +182,14 @@ def test_preset_with_auto_provider_uses_keyword_matching() -> None:
|
||||
|
||||
|
||||
def test_backward_compat_no_preset() -> None:
|
||||
"""Existing configs without model_presets work exactly as before."""
|
||||
"""Existing configs without model_presets are automatically promoted to the 'default' preset."""
|
||||
cfg = Config.model_validate({
|
||||
"providers": {"anthropic": {"api_key": "test-key"}},
|
||||
"agents": {"defaults": {"model": "anthropic/claude-opus-4-5"}},
|
||||
})
|
||||
assert cfg.resolve_preset().model == "anthropic/claude-opus-4-5"
|
||||
assert cfg.agents.defaults.model_preset is None
|
||||
assert cfg.agents.defaults.model_preset == "default"
|
||||
assert "default" in cfg.model_presets
|
||||
assert cfg.get_provider_name() == "anthropic"
|
||||
|
||||
|
||||
@ -204,3 +217,48 @@ def test_resolve_preset_overrides_all_model_fields() -> None:
|
||||
def test_empty_model_presets_dict_is_harmless() -> None:
|
||||
cfg = Config.model_validate({"model_presets": {}})
|
||||
assert cfg.resolve_preset().model == "anthropic/claude-opus-4-5"
|
||||
|
||||
|
||||
def test_factory_uses_preset_provider_not_defaults() -> None:
|
||||
"""When creating a provider for a non-active preset, the preset's own provider must be used."""
|
||||
from nanobot.providers.factory import make_provider_factory
|
||||
|
||||
cfg = Config.model_validate({
|
||||
"model_presets": {
|
||||
"kimi": {"model": "kimi-k2.6", "provider": "moonshot"},
|
||||
"zhipu": {"model": "glm-5.1", "provider": "zhipu"},
|
||||
},
|
||||
"providers": {
|
||||
"moonshot": {"api_key": "moonshot-key", "api_base": "https://api.moonshot.ai/v1"},
|
||||
"zhipu": {"api_key": "zhipu-key", "api_base": "https://open.bigmodel.cn/api/paas/v4"},
|
||||
},
|
||||
"agents": {"defaults": {"model_preset": "kimi"}},
|
||||
})
|
||||
|
||||
factory = make_provider_factory(cfg)
|
||||
zhipu_provider = factory("zhipu")
|
||||
|
||||
assert zhipu_provider.api_base == "https://open.bigmodel.cn/api/paas/v4"
|
||||
assert getattr(zhipu_provider, "api_key", None) == "zhipu-key"
|
||||
|
||||
# Also verify the active preset provider is still correct
|
||||
moonshot_provider = factory("kimi")
|
||||
assert moonshot_provider.api_base == "https://api.moonshot.ai/v1"
|
||||
|
||||
|
||||
def test_factory_rejects_unknown_preset_name() -> None:
|
||||
"""Factory must raise ValueError when asked for a preset not in model_presets."""
|
||||
import pytest
|
||||
|
||||
from nanobot.providers.factory import make_provider_factory
|
||||
|
||||
cfg = Config.model_validate({
|
||||
"model_presets": {
|
||||
"known": {"model": "gpt-4", "provider": "openai"},
|
||||
},
|
||||
"providers": {"openai": {"api_key": "test-key"}},
|
||||
})
|
||||
|
||||
factory = make_provider_factory(cfg)
|
||||
with pytest.raises(ValueError, match="Preset 'unknown' not found"):
|
||||
factory("unknown")
|
||||
|
||||
@ -39,7 +39,7 @@ def test_from_config_default_path():
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
with patch("nanobot.config.loader.load_config") as mock_load, \
|
||||
patch("nanobot.nanobot._make_provider") as mock_prov:
|
||||
patch("nanobot.providers.factory.build_provider_for_preset") as mock_prov:
|
||||
mock_load.return_value = Config()
|
||||
mock_prov.return_value = MagicMock()
|
||||
mock_prov.return_value.get_default_model.return_value = "test"
|
||||
@ -127,7 +127,7 @@ def test_workspace_override(tmp_path):
|
||||
|
||||
def test_sdk_make_provider_uses_github_copilot_backend():
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.nanobot import _make_provider
|
||||
from nanobot.providers.factory import make_provider
|
||||
|
||||
config = Config.model_validate(
|
||||
{
|
||||
@ -141,7 +141,7 @@ def test_sdk_make_provider_uses_github_copilot_backend():
|
||||
)
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = _make_provider(config)
|
||||
provider = make_provider(config)
|
||||
|
||||
assert provider.__class__.__name__ == "GitHubCopilotProvider"
|
||||
|
||||
|
||||
467
tests/test_preset_failover_smoke.py
Normal file
467
tests/test_preset_failover_smoke.py
Normal file
@ -0,0 +1,467 @@
|
||||
"""End-to-end smoke tests for model presets + failover.
|
||||
|
||||
Uses a local aiohttp fake OpenAI server so requests are real HTTP,
|
||||
not mocked at the provider level.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.nanobot import Nanobot
|
||||
from nanobot.providers.base import GenerationSettings, LLMProvider
|
||||
from nanobot.providers.failover import ModelRouter
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
try:
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import TestServer
|
||||
|
||||
HAS_AIOHTTP = True
|
||||
except ImportError:
|
||||
HAS_AIOHTTP = False
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _disable_proxy_for_localhost_tests(monkeypatch):
|
||||
"""Prevent httpx from routing localhost requests through a system proxy."""
|
||||
monkeypatch.delenv("ALL_PROXY", raising=False)
|
||||
monkeypatch.delenv("HTTP_PROXY", raising=False)
|
||||
monkeypatch.delenv("HTTPS_PROXY", raising=False)
|
||||
monkeypatch.setenv("NO_PROXY", "127.0.0.1,localhost")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers (mock-level preset tests)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _write_config(tmp_path: Path, **overrides) -> Path:
|
||||
data = {
|
||||
"providers": {
|
||||
"openrouter": {"apiKey": "sk-test-key"},
|
||||
"openai": {"apiKey": "sk-openai-test"},
|
||||
},
|
||||
"agents": {"defaults": {"model": "openai/gpt-4.1"}},
|
||||
"tools": {"my": {"allowSet": True}},
|
||||
}
|
||||
data.update(overrides)
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(json.dumps(data))
|
||||
return config_path
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. Model Preset Mock Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preset_loaded_at_startup(tmp_path: Path) -> None:
|
||||
config_path = _write_config(
|
||||
tmp_path,
|
||||
model_presets={
|
||||
"fast": {
|
||||
"model": "gpt-4.1-mini",
|
||||
"provider": "openai",
|
||||
"max_tokens": 4096,
|
||||
"context_window_tokens": 128000,
|
||||
"temperature": 0.3,
|
||||
}
|
||||
},
|
||||
agents={"defaults": {"model_preset": "fast", "model": "ignored-model"}},
|
||||
)
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
bot = Nanobot.from_config(config_path, workspace=tmp_path)
|
||||
|
||||
loop = bot._loop
|
||||
assert loop.model == "gpt-4.1-mini"
|
||||
assert loop.context_window_tokens == 128000
|
||||
assert loop.provider.generation.temperature == 0.3
|
||||
assert loop.provider.generation.max_tokens == 4096
|
||||
assert loop.model_preset == "fast"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preset_runtime_switch_updates_all_fields(tmp_path: Path) -> None:
|
||||
config_path = _write_config(
|
||||
tmp_path,
|
||||
model_presets={
|
||||
"cheap": {
|
||||
"model": "gpt-4.1-mini",
|
||||
"provider": "openai",
|
||||
"max_tokens": 2048,
|
||||
"context_window_tokens": 64000,
|
||||
"temperature": 0.5,
|
||||
},
|
||||
"power": {
|
||||
"model": "gpt-4.1",
|
||||
"provider": "openai",
|
||||
"max_tokens": 8192,
|
||||
"context_window_tokens": 256000,
|
||||
"temperature": 0.1,
|
||||
},
|
||||
},
|
||||
agents={"defaults": {"model_preset": "cheap"}},
|
||||
)
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
bot = Nanobot.from_config(config_path, workspace=tmp_path)
|
||||
|
||||
loop = bot._loop
|
||||
assert loop.model == "gpt-4.1-mini"
|
||||
|
||||
my_tool = loop.tools.get("my")
|
||||
result = await my_tool.execute(action="set", key="model_preset", value="power")
|
||||
assert "Error" not in result
|
||||
|
||||
assert loop.model == "gpt-4.1"
|
||||
assert loop.context_window_tokens == 256000
|
||||
assert loop.provider.generation.temperature == 0.1
|
||||
assert loop.provider.generation.max_tokens == 8192
|
||||
assert loop.model_preset == "power"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preset_switch_unknown_returns_error(tmp_path: Path) -> None:
|
||||
config_path = _write_config(
|
||||
tmp_path,
|
||||
model_presets={"a": {"model": "model-a"}},
|
||||
agents={"defaults": {"model_preset": "a"}},
|
||||
)
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
bot = Nanobot.from_config(config_path, workspace=tmp_path)
|
||||
|
||||
loop = bot._loop
|
||||
original_model = loop.model
|
||||
|
||||
my_tool = loop.tools.get("my")
|
||||
result = await my_tool.execute(action="set", key="model_preset", value="nonexistent")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
assert loop.model == original_model
|
||||
assert loop.model_preset == "a"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preset_model_with_fallback_presets_in_config(tmp_path: Path) -> None:
|
||||
config_path = _write_config(
|
||||
tmp_path,
|
||||
model_presets={
|
||||
"prod": {
|
||||
"model": "gpt-4.1",
|
||||
"provider": "openai",
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.1,
|
||||
},
|
||||
"fallback": {
|
||||
"model": "gpt-4.1-mini",
|
||||
"provider": "openai",
|
||||
"max_tokens": 4096,
|
||||
"temperature": 0.2,
|
||||
},
|
||||
},
|
||||
agents={
|
||||
"defaults": {
|
||||
"model_preset": "prod",
|
||||
"fallback_presets": ["fallback"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
bot = Nanobot.from_config(config_path, workspace=tmp_path)
|
||||
|
||||
loop = bot._loop
|
||||
assert loop.model == "gpt-4.1"
|
||||
assert isinstance(loop.provider, ModelRouter)
|
||||
assert loop.provider.fallback_presets == ["fallback"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_presets_wired_to_all_subsystems(tmp_path: Path) -> None:
|
||||
"""When fallback_presets is configured, every subsystem that calls the LLM
|
||||
must use the same ModelRouter instance, not the raw primary provider."""
|
||||
config_path = _write_config(
|
||||
tmp_path,
|
||||
model_presets={
|
||||
"prod": {
|
||||
"model": "gpt-4.1",
|
||||
"provider": "openai",
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.1,
|
||||
},
|
||||
"fallback": {
|
||||
"model": "gpt-4.1-mini",
|
||||
"provider": "openai",
|
||||
"max_tokens": 4096,
|
||||
"temperature": 0.2,
|
||||
},
|
||||
},
|
||||
agents={
|
||||
"defaults": {
|
||||
"model_preset": "prod",
|
||||
"fallback_presets": ["fallback"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
bot = Nanobot.from_config(config_path, workspace=tmp_path)
|
||||
|
||||
loop = bot._loop
|
||||
router = loop.provider
|
||||
assert isinstance(router, ModelRouter)
|
||||
|
||||
# Every LLM-consuming subsystem must share the same router
|
||||
assert loop.runner.provider is router, "AgentRunner must use ModelRouter"
|
||||
assert loop.subagents.provider is router, "SubagentManager must use ModelRouter"
|
||||
assert loop.consolidator.provider is router, "Consolidator must use ModelRouter"
|
||||
assert loop.dream.provider is router, "Dream must use ModelRouter"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Real HTTP Smoke Tests (aiohttp fake OpenAI server)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_preset_generation_params_reach_http_request() -> None:
|
||||
"""Provider.generation settings must appear in the actual HTTP request body."""
|
||||
requests_log: list[dict] = []
|
||||
|
||||
async def handler(request: web.Request) -> web.Response:
|
||||
body = await request.json()
|
||||
requests_log.append(body)
|
||||
return web.json_response({
|
||||
"id": "chatcmpl-test",
|
||||
"object": "chat.completion",
|
||||
"model": body.get("model"),
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": "pong"},
|
||||
"finish_reason": "stop",
|
||||
}],
|
||||
})
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_post("/chat/completions", handler)
|
||||
server = TestServer(app)
|
||||
await server.start_server()
|
||||
try:
|
||||
base_url = str(server.make_url("/"))
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="test",
|
||||
api_base=base_url,
|
||||
default_model="test-model",
|
||||
)
|
||||
provider.generation = GenerationSettings(temperature=0.42, max_tokens=1024)
|
||||
|
||||
with patch.object(LLMProvider, "_CHAT_RETRY_DELAYS", (0,)):
|
||||
response = await provider.chat_with_retry(
|
||||
messages=[{"role": "user", "content": "ping"}],
|
||||
)
|
||||
|
||||
assert response.finish_reason != "error"
|
||||
assert len(requests_log) >= 1
|
||||
req = requests_log[0]
|
||||
assert req["model"] == "test-model"
|
||||
assert req["temperature"] == 0.42
|
||||
assert req["max_tokens"] == 1024
|
||||
finally:
|
||||
await server.close()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_failover_sends_second_request_to_fallback_model() -> None:
|
||||
"""Primary returns 503; after retry exhaustion ModelRouter hits fallback."""
|
||||
requests_log: list[dict] = []
|
||||
|
||||
async def handler(request: web.Request) -> web.Response:
|
||||
body = await request.json()
|
||||
requests_log.append(body)
|
||||
model = body.get("model")
|
||||
|
||||
if model == "primary-model":
|
||||
return web.Response(
|
||||
status=503,
|
||||
body=json.dumps({"error": {"message": "overloaded", "type": "server_error"}}),
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
return web.json_response({
|
||||
"id": "chatcmpl-test",
|
||||
"object": "chat.completion",
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": "fallback-ok"},
|
||||
"finish_reason": "stop",
|
||||
}],
|
||||
})
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_post("/chat/completions", handler)
|
||||
server = TestServer(app)
|
||||
await server.start_server()
|
||||
try:
|
||||
base_url = str(server.make_url("/"))
|
||||
primary = OpenAICompatProvider(
|
||||
api_key="test", api_base=base_url, default_model="primary-model"
|
||||
)
|
||||
fallback = OpenAICompatProvider(
|
||||
api_key="test", api_base=base_url, default_model="fallback-model"
|
||||
)
|
||||
|
||||
factory = MagicMock(return_value=fallback)
|
||||
|
||||
router = ModelRouter(
|
||||
primary_provider=primary,
|
||||
primary_model="primary-model",
|
||||
fallback_presets=["fallback-model"],
|
||||
provider_factory=factory,
|
||||
)
|
||||
|
||||
with patch.object(LLMProvider, "_CHAT_RETRY_DELAYS", (0,)):
|
||||
response = await router.chat_with_retry(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
)
|
||||
|
||||
assert response.finish_reason != "error"
|
||||
assert response.content == "fallback-ok"
|
||||
|
||||
models_requested = [r["model"] for r in requests_log]
|
||||
assert "primary-model" in models_requested
|
||||
assert "fallback-model" in models_requested
|
||||
factory.assert_called_once_with("fallback-model")
|
||||
finally:
|
||||
await server.close()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_failover_on_quota_429() -> None:
|
||||
"""Quota 429 on one provider may still work on a different provider."""
|
||||
requests_log: list[dict] = []
|
||||
|
||||
async def handler(request: web.Request) -> web.Response:
|
||||
body = await request.json()
|
||||
requests_log.append(body)
|
||||
return web.Response(
|
||||
status=429,
|
||||
body=json.dumps({
|
||||
"error": {
|
||||
"message": "insufficient quota",
|
||||
"type": "insufficient_quota",
|
||||
"code": "insufficient_quota",
|
||||
}
|
||||
}),
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_post("/chat/completions", handler)
|
||||
server = TestServer(app)
|
||||
await server.start_server()
|
||||
try:
|
||||
base_url = str(server.make_url("/"))
|
||||
primary = OpenAICompatProvider(
|
||||
api_key="test", api_base=base_url, default_model="primary-model"
|
||||
)
|
||||
fallback = OpenAICompatProvider(
|
||||
api_key="test", api_base=base_url, default_model="fallback-model"
|
||||
)
|
||||
|
||||
factory = MagicMock(return_value=fallback)
|
||||
|
||||
router = ModelRouter(
|
||||
primary_provider=primary,
|
||||
primary_model="primary-model",
|
||||
fallback_presets=["fallback-model"],
|
||||
provider_factory=factory,
|
||||
)
|
||||
|
||||
with patch.object(LLMProvider, "_CHAT_RETRY_DELAYS", (0,)):
|
||||
response = await router.chat_with_retry(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
)
|
||||
|
||||
# Quota 429 SHOULD trigger failover — another provider may still work.
|
||||
factory.assert_called_once_with("fallback-model")
|
||||
assert response.finish_reason == "error"
|
||||
# Both primary and fallback should have been requested.
|
||||
assert len(requests_log) == 2
|
||||
finally:
|
||||
await server.close()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_router_failover_integration() -> None:
|
||||
"""ModelRouter -> real HTTP failover chain (primary 503, fallback 200)."""
|
||||
requests_log: list[dict] = []
|
||||
|
||||
async def handler(request: web.Request) -> web.Response:
|
||||
body = await request.json()
|
||||
requests_log.append(body)
|
||||
model = body.get("model")
|
||||
|
||||
if model == "primary-model":
|
||||
return web.Response(
|
||||
status=503,
|
||||
body=json.dumps({"error": {"message": "overloaded", "type": "server_error"}}),
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
return web.json_response({
|
||||
"id": "chatcmpl-test",
|
||||
"object": "chat.completion",
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": "fallback-ok"},
|
||||
"finish_reason": "stop",
|
||||
}],
|
||||
})
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_post("/chat/completions", handler)
|
||||
server = TestServer(app)
|
||||
await server.start_server()
|
||||
try:
|
||||
base_url = str(server.make_url("/"))
|
||||
primary = OpenAICompatProvider(
|
||||
api_key="test", api_base=base_url, default_model="primary-model"
|
||||
)
|
||||
fallback = OpenAICompatProvider(
|
||||
api_key="test", api_base=base_url, default_model="fallback-model"
|
||||
)
|
||||
|
||||
factory = MagicMock(return_value=fallback)
|
||||
|
||||
router = ModelRouter(
|
||||
primary_provider=primary,
|
||||
primary_model="primary-model",
|
||||
fallback_presets=["fallback-model"],
|
||||
provider_factory=factory,
|
||||
)
|
||||
|
||||
with patch.object(LLMProvider, "_CHAT_RETRY_DELAYS", (0,)):
|
||||
response = await router.chat_with_retry(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
)
|
||||
|
||||
assert response.finish_reason != "error"
|
||||
assert response.content == "fallback-ok"
|
||||
models_requested = [r["model"] for r in requests_log]
|
||||
assert "primary-model" in models_requested
|
||||
assert "fallback-model" in models_requested
|
||||
factory.assert_called_once_with("fallback-model")
|
||||
finally:
|
||||
await server.close()
|
||||
Loading…
x
Reference in New Issue
Block a user