mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
refactor: restrict fallback_models to preset-only and clean up provider factory
- Restrict fallback_models to only reference preset names in model_presets. - Add schema validation to reject unknown preset names in fallback_models. - Remove build_provider_for_model() since bare model fallback is no longer supported. - Simplify make_provider_factory() to only look up presets by name. - Update onboard UI to remove "Add custom model" option from fallback chain. - Update tests to use preset names instead of bare model strings in fallback chains. - Fix test imports referencing deleted _make_provider function.
This commit is contained in:
parent
ecbe56dd92
commit
0bc42e2ab2
@ -123,6 +123,7 @@
|
|||||||
- **Ultra-lightweight**: stable long-running agent behavior with a small, readable core.
|
- **Ultra-lightweight**: stable long-running agent behavior with a small, readable core.
|
||||||
- **Research-ready**: the codebase is intentionally simple enough to study, modify, and extend.
|
- **Research-ready**: the codebase is intentionally simple enough to study, modify, and extend.
|
||||||
- **Practical**: chat channels, API, memory, MCP, and deployment paths are already built in.
|
- **Practical**: chat channels, API, memory, MCP, and deployment paths are already built in.
|
||||||
|
- **Runtime model switching**: define [model presets](docs/configuration.md#model-presets) and switch between cheap/fast and powerful models mid-conversation — no restart required.
|
||||||
- **Hackable**: you can start fast, then go deeper through repo docs instead of a monolithic landing page.
|
- **Hackable**: you can start fast, then go deeper through repo docs instead of a monolithic landing page.
|
||||||
|
|
||||||
## 📦 Install
|
## 📦 Install
|
||||||
|
|||||||
@ -656,6 +656,146 @@ That's it! Environment variables, model routing, config matching, and `nanobot s
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
## Agent Settings
|
||||||
|
|
||||||
|
### Model Presets
|
||||||
|
|
||||||
|
Model presets let you define **named bundles** of model + generation parameters and switch between them instantly — no restart required.
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> Config fields in `config.json` use **camelCase** (`modelPreset`, `contextWindowTokens`).
|
||||||
|
> The [`my` tool](./my-tool.md) uses **snake_case** (`model_preset`, `context_window_tokens`).
|
||||||
|
> Both refer to the same thing — just different naming conventions for config vs. runtime API.
|
||||||
|
|
||||||
|
**Why use presets?**
|
||||||
|
- Switch between a cheap/fast model and a powerful model mid-conversation.
|
||||||
|
- Share the same config across different tasks without manually editing `model`, `provider`, `temperature`, etc.
|
||||||
|
- Runtime switching via the [`my` tool](./my-tool.md).
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> The easiest way to set up presets and fallback models is through the interactive wizard:
|
||||||
|
> ```bash
|
||||||
|
> nanobot onboard --wizard
|
||||||
|
> ```
|
||||||
|
> Choose **"[M] Model Presets"** to create, edit, or delete presets interactively.
|
||||||
|
|
||||||
|
**Configuration example:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"modelPresets": {
|
||||||
|
"fast": {
|
||||||
|
"model": "gpt-4.1-mini",
|
||||||
|
"provider": "openai",
|
||||||
|
"maxTokens": 4096,
|
||||||
|
"contextWindowTokens": 128000,
|
||||||
|
"temperature": 0.3
|
||||||
|
},
|
||||||
|
"deep": {
|
||||||
|
"model": "claude-opus-4-7",
|
||||||
|
"provider": "anthropic",
|
||||||
|
"maxTokens": 8192,
|
||||||
|
"contextWindowTokens": 200000,
|
||||||
|
"temperature": 0.1,
|
||||||
|
"reasoningEffort": "high"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"agents": {
|
||||||
|
"defaults": {
|
||||||
|
"modelPreset": "fast"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Preset fields:**
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `model` | string | *(required)* | Model identifier, e.g. `anthropic/claude-opus-4-7` or `gpt-4.1` |
|
||||||
|
| `provider` | string | `"auto"` | Provider name or `"auto"` to infer from the model string |
|
||||||
|
| `maxTokens` | integer | `8192` | Max completion tokens per turn |
|
||||||
|
| `contextWindowTokens` | integer | `65536` | Context window size for token budgeting |
|
||||||
|
| `temperature` | float | `0.1` | Sampling temperature |
|
||||||
|
| `reasoningEffort` | string or null | `null` | Thinking mode: `low`, `medium`, `high`, `adaptive` |
|
||||||
|
|
||||||
|
**How it works:**
|
||||||
|
- When `modelPreset` is set, the preset **completely overrides** all model-specific fields in `agents.defaults`.
|
||||||
|
- When `modelPreset` is omitted, nanobot automatically creates an implicit `"default"` preset from your existing `agents.defaults.model`, `provider`, `temperature`, etc. — **zero migration required** for existing configs.
|
||||||
|
|
||||||
|
**Runtime switching** (requires `tools.my.allowSet: true`):
|
||||||
|
|
||||||
|
```text
|
||||||
|
my(action="set", key="model_preset", value="deep")
|
||||||
|
```
|
||||||
|
|
||||||
|
This atomically swaps the model, provider, generation parameters, and context window for the next turn.
|
||||||
|
|
||||||
|
If the preset name does not exist, the agent receives an error such as `model_preset 'unknown' not found. Available: fast, deep`.
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> Directly modifying `model` or `contextWindowTokens` via `my(action="set", key="model", ...)` still works, but it automatically clears the active preset because the live state no longer matches the preset bundle. Use `model_preset` for atomic switches instead.
|
||||||
|
|
||||||
|
See [`my-tool.md`](./my-tool.md) for more runtime examples.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Fallback Models
|
||||||
|
|
||||||
|
When the primary model returns a transient error (rate limit, server overload, quota exhausted), nanobot can automatically fail over to a chain of backup models.
|
||||||
|
|
||||||
|
**Configuration example:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"agents": {
|
||||||
|
"defaults": {
|
||||||
|
"modelPreset": "fast",
|
||||||
|
"fallbackModels": ["deep", "backup"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**How it works:**
|
||||||
|
1. nanobot tries the primary model first (the one from the active preset).
|
||||||
|
2. The provider retries transient errors internally (e.g. 3 attempts with exponential backoff for 503/429).
|
||||||
|
3. Only after the provider's own retries are exhausted and the final response still has `finish_reason == "error"` with a retryable error kind, nanobot moves to the next candidate in `fallbackModels`.
|
||||||
|
4. Each candidate must be a preset name defined in `modelPresets`. The preset's full config (model, provider, generation params) is used.
|
||||||
|
5. If all candidates are exhausted, the final error is returned to the user.
|
||||||
|
|
||||||
|
**Failover triggers on:**
|
||||||
|
- `server_error` (503, 502, 500)
|
||||||
|
- `rate_limit` (429)
|
||||||
|
- `insufficient_quota` / `quota_exhausted` (429)
|
||||||
|
|
||||||
|
**Failover does NOT trigger on:**
|
||||||
|
- Authentication errors (401) — rotating to another model with the same key won't help
|
||||||
|
- Invalid request errors (400) — the request itself is malformed
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> Fallback models must reference preset names defined in `modelPresets`. Define a preset for each fallback model you want to use: `["cheap-preset", "backup", "emergency"]`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Other Agent Defaults
|
||||||
|
|
||||||
|
| Option | Type | Default | Description |
|
||||||
|
|--------|------|---------|-------------|
|
||||||
|
| `agents.defaults.model` | string | `"anthropic/claude-opus-4-5"` | Default model when no preset is active |
|
||||||
|
| `agents.defaults.provider` | string | `"auto"` | Default provider when no preset is active |
|
||||||
|
| `agents.defaults.maxTokens` | integer | `8192` | Max completion tokens when no preset is active |
|
||||||
|
| `agents.defaults.temperature` | float | `0.1` | Sampling temperature when no preset is active |
|
||||||
|
| `agents.defaults.reasoningEffort` | string or null | `null` | Thinking mode when no preset is active |
|
||||||
|
| `agents.defaults.maxToolIterations` | integer | `200` | Max tool calls per conversation turn |
|
||||||
|
| `agents.defaults.maxToolResultChars` | integer | `16000` | Max characters per tool result |
|
||||||
|
| `agents.defaults.providerRetryMode` | string | `"standard"` | `"standard"` or `"persistent"` — how aggressively to retry provider-level errors |
|
||||||
|
| `agents.defaults.timezone` | string | `"UTC"` | IANA timezone for runtime context |
|
||||||
|
| `agents.defaults.unifiedSession` | boolean | `false` | Share one session across all channels |
|
||||||
|
| `agents.defaults.sessionTtlMinutes` | integer | `0` | Auto-compact idle threshold (0 = disabled) |
|
||||||
|
| `agents.defaults.maxMessages` | integer | `120` | Max messages to replay from session history |
|
||||||
|
| `agents.defaults.consolidationRatio` | float | `0.5` | Target ratio retained after context compression |
|
||||||
|
|
||||||
## Channel Settings
|
## Channel Settings
|
||||||
|
|
||||||
Global settings that apply to all channels. Configure under the `channels` section in `~/.nanobot/config.json`:
|
Global settings that apply to all channels. Configure under the `channels` section in `~/.nanobot/config.json`:
|
||||||
|
|||||||
@ -12,6 +12,11 @@ My tool fills this gap. With it, the agent can:
|
|||||||
- **Adapt on the fly**: Complex task? Expand the context window. Simple chat? Switch to a faster model.
|
- **Adapt on the fly**: Complex task? Expand the context window. Simple chat? Switch to a faster model.
|
||||||
- **Remember across turns**: Store notes in your scratchpad that persist into the next conversation turn.
|
- **Remember across turns**: Store notes in your scratchpad that persist into the next conversation turn.
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> This tool uses **snake_case** keys (`model_preset`, `context_window_tokens`).
|
||||||
|
> The matching config fields in `config.json` are **camelCase** (`modelPreset`, `contextWindowTokens`).
|
||||||
|
> See [`configuration.md`](./configuration.md#model-presets) for how to define presets in your config.
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
Enabled by default (read-only mode). The agent can check its state but not set it.
|
Enabled by default (read-only mode). The agent can check its state but not set it.
|
||||||
@ -39,8 +44,7 @@ Without parameters, returns a key config overview:
|
|||||||
```text
|
```text
|
||||||
my(action="check")
|
my(action="check")
|
||||||
# → max_iterations: 40
|
# → max_iterations: 40
|
||||||
# context_window_tokens: 65536
|
# model_preset: 'fast'
|
||||||
# model: 'anthropic/claude-sonnet-4-20250514'
|
|
||||||
# workspace: PosixPath('/tmp/workspace')
|
# workspace: PosixPath('/tmp/workspace')
|
||||||
# provider_retry_mode: 'standard'
|
# provider_retry_mode: 'standard'
|
||||||
# max_tool_result_chars: 16000
|
# max_tool_result_chars: 16000
|
||||||
@ -55,8 +59,13 @@ With a key parameter, drill into a specific config:
|
|||||||
my(action="check", key="_last_usage.prompt_tokens")
|
my(action="check", key="_last_usage.prompt_tokens")
|
||||||
# → How many prompt tokens I've used so far
|
# → How many prompt tokens I've used so far
|
||||||
|
|
||||||
my(action="check", key="model")
|
my(action="check", key="model_preset")
|
||||||
# → What model I'm currently running on
|
# → Current active preset name (e.g. 'fast')
|
||||||
|
|
||||||
|
my(action="check", key="model_presets")
|
||||||
|
# → Lists all preset names and their models, e.g.:
|
||||||
|
# fast → gpt-4.1-mini (openai)
|
||||||
|
# deep → claude-opus-4-7 (anthropic)
|
||||||
|
|
||||||
my(action="check", key="web_config.enable")
|
my(action="check", key="web_config.enable")
|
||||||
# → Whether web search is enabled
|
# → Whether web search is enabled
|
||||||
@ -66,7 +75,7 @@ my(action="check", key="web_config.enable")
|
|||||||
|
|
||||||
| Scenario | How |
|
| Scenario | How |
|
||||||
|----------|-----|
|
|----------|-----|
|
||||||
| "What model are you using?" | `check("model")` |
|
| "What model are you using?" | `check("model_preset")` |
|
||||||
| "How many more tool calls can you make?" | `check("max_iterations")` minus `check("_current_iteration")` |
|
| "How many more tool calls can you make?" | `check("max_iterations")` minus `check("_current_iteration")` |
|
||||||
| "How many tokens has this conversation used?" | `check("_last_usage")` — cumulative across all turns |
|
| "How many tokens has this conversation used?" | `check("_last_usage")` — cumulative across all turns |
|
||||||
| "Where is your working directory?" | `check("workspace")` |
|
| "Where is your working directory?" | `check("workspace")` |
|
||||||
@ -83,8 +92,11 @@ Changes take effect immediately, no restart required.
|
|||||||
my(action="set", key="max_iterations", value=80)
|
my(action="set", key="max_iterations", value=80)
|
||||||
# → Bump iteration limit from 40 to 80
|
# → Bump iteration limit from 40 to 80
|
||||||
|
|
||||||
my(action="set", key="model", value="fast-model")
|
my(action="set", key="model_preset", value="fast")
|
||||||
# → Switch to a faster model
|
# → Switch to the 'fast' preset (model, provider, temperature, etc. all at once)
|
||||||
|
#
|
||||||
|
# If the preset name does not exist:
|
||||||
|
# → Error: model_preset 'unknown' not found. Available: fast, deep
|
||||||
|
|
||||||
my(action="set", key="context_window_tokens", value=131072)
|
my(action="set", key="context_window_tokens", value=131072)
|
||||||
# → Expand context window for long documents
|
# → Expand context window for long documents
|
||||||
@ -101,15 +113,17 @@ my(action="set", key="task_complexity", value="high")
|
|||||||
|
|
||||||
### Protected parameters
|
### Protected parameters
|
||||||
|
|
||||||
These parameters have type and range validation — invalid values are rejected:
|
These parameters have validation — invalid values are rejected:
|
||||||
|
|
||||||
| Parameter | Type | Range | Purpose |
|
| Parameter | Type | Range / Constraint | Purpose |
|
||||||
|-----------|------|-------|---------|
|
|-----------|------|-------------------|---------|
|
||||||
| `max_iterations` | int | 1–100 | Max tool calls per conversation turn |
|
| `max_iterations` | int | 1–100 | Max tool calls per conversation turn |
|
||||||
| `context_window_tokens` | int | 4,096–1,000,000 | Context window size |
|
| `model_preset` | str | must exist in `model_presets` | Switch to a named preset bundle |
|
||||||
| `model` | str | non-empty | LLM model to use |
|
|
||||||
|
|
||||||
Other parameters (e.g. `workspace`, `provider_retry_mode`, `max_tool_result_chars`) can be set freely, as long as the value is JSON-safe.
|
Other parameters (e.g. `model`, `context_window_tokens`, `workspace`, `provider_retry_mode`, `max_tool_result_chars`) can be set freely, as long as the value is JSON-safe.
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> Setting `model` or `context_window_tokens` directly automatically clears the active `model_preset`, because the live state no longer matches the preset bundle. Use `model_preset` for atomic switches instead.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@ -125,8 +139,8 @@ Agent: This codebase is large, let me expand my context window to handle it.
|
|||||||
### "Simple question, don't waste compute"
|
### "Simple question, don't waste compute"
|
||||||
|
|
||||||
```text
|
```text
|
||||||
Agent: This is a straightforward question, let me switch to a faster model.
|
Agent: This is a straightforward question, let me switch to the fast preset.
|
||||||
→ my(action="set", key="model", value="fast-model")
|
→ my(action="set", key="model_preset", value="fast")
|
||||||
```
|
```
|
||||||
|
|
||||||
### "Remember user preferences across turns"
|
### "Remember user preferences across turns"
|
||||||
|
|||||||
@ -95,6 +95,8 @@ Configure these **two parts** in your config (other options have defaults).
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
*Want to switch models mid-conversation?* Define [`modelPresets`](./configuration.md#model-presets) and switch instantly with `my(action="set", key="model_preset", value="fast")`.
|
||||||
|
|
||||||
**3. Chat**
|
**3. Chat**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
@ -197,6 +197,50 @@ class AgentLoop:
|
|||||||
_RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
|
_RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
|
||||||
_PENDING_USER_TURN_KEY = "pending_user_turn"
|
_PENDING_USER_TURN_KEY = "pending_user_turn"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls,
|
||||||
|
config: Any,
|
||||||
|
bus: MessageBus | None = None,
|
||||||
|
**extra: Any,
|
||||||
|
) -> AgentLoop:
|
||||||
|
"""Create an AgentLoop from config with the common parameter set."""
|
||||||
|
from nanobot.providers.factory import build_provider_for_preset, make_provider_factory
|
||||||
|
|
||||||
|
if bus is None:
|
||||||
|
bus = MessageBus()
|
||||||
|
defaults = config.agents.defaults
|
||||||
|
resolved_preset = config.resolve_preset()
|
||||||
|
provider = build_provider_for_preset(config, resolved_preset)
|
||||||
|
return cls(
|
||||||
|
bus=bus,
|
||||||
|
provider=provider,
|
||||||
|
workspace=config.workspace_path,
|
||||||
|
model=resolved_preset.model,
|
||||||
|
max_iterations=defaults.max_tool_iterations,
|
||||||
|
context_window_tokens=resolved_preset.context_window_tokens,
|
||||||
|
context_block_limit=defaults.context_block_limit,
|
||||||
|
max_tool_result_chars=defaults.max_tool_result_chars,
|
||||||
|
provider_retry_mode=defaults.provider_retry_mode,
|
||||||
|
fallback_presets=defaults.fallback_presets,
|
||||||
|
provider_factory=make_provider_factory(config),
|
||||||
|
web_config=config.tools.web,
|
||||||
|
exec_config=config.tools.exec,
|
||||||
|
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||||
|
mcp_servers=config.tools.mcp_servers,
|
||||||
|
channels_config=config.channels,
|
||||||
|
timezone=defaults.timezone,
|
||||||
|
unified_session=defaults.unified_session,
|
||||||
|
disabled_skills=defaults.disabled_skills,
|
||||||
|
session_ttl_minutes=defaults.session_ttl_minutes,
|
||||||
|
consolidation_ratio=defaults.consolidation_ratio,
|
||||||
|
max_messages=defaults.max_messages,
|
||||||
|
tools_config=config.tools,
|
||||||
|
model_presets=config.model_presets,
|
||||||
|
model_preset=defaults.model_preset,
|
||||||
|
**extra,
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
bus: MessageBus,
|
bus: MessageBus,
|
||||||
@ -209,8 +253,8 @@ class AgentLoop:
|
|||||||
max_tool_result_chars: int | None = None,
|
max_tool_result_chars: int | None = None,
|
||||||
provider_retry_mode: str = "standard",
|
provider_retry_mode: str = "standard",
|
||||||
tool_hint_max_length: int | None = None,
|
tool_hint_max_length: int | None = None,
|
||||||
fallback_models: list[str] | None = None,
|
fallback_presets: list[str] | None = None,
|
||||||
provider_factory: Any | None = None,
|
provider_factory: Callable[[str], LLMProvider] | None = None,
|
||||||
web_config: WebToolsConfig | None = None,
|
web_config: WebToolsConfig | None = None,
|
||||||
exec_config: ExecToolConfig | None = None,
|
exec_config: ExecToolConfig | None = None,
|
||||||
cron_service: CronService | None = None,
|
cron_service: CronService | None = None,
|
||||||
@ -239,7 +283,12 @@ class AgentLoop:
|
|||||||
defaults = AgentDefaults()
|
defaults = AgentDefaults()
|
||||||
self.bus = bus
|
self.bus = bus
|
||||||
self.channels_config = channels_config
|
self.channels_config = channels_config
|
||||||
self.provider = provider
|
self.provider_factory = provider_factory
|
||||||
|
self.fallback_presets = fallback_presets or []
|
||||||
|
wrapped_provider = self._wrap_with_failover(
|
||||||
|
provider, model or provider.get_default_model()
|
||||||
|
)
|
||||||
|
self.provider = wrapped_provider
|
||||||
self._provider_snapshot_loader = provider_snapshot_loader
|
self._provider_snapshot_loader = provider_snapshot_loader
|
||||||
self._provider_signature = provider_signature
|
self._provider_signature = provider_signature
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
@ -263,7 +312,6 @@ class AgentLoop:
|
|||||||
tool_hint_max_length if tool_hint_max_length is not None
|
tool_hint_max_length if tool_hint_max_length is not None
|
||||||
else defaults.tool_hint_max_length
|
else defaults.tool_hint_max_length
|
||||||
)
|
)
|
||||||
self.fallback_models = fallback_models or []
|
|
||||||
self.web_config = web_config or WebToolsConfig()
|
self.web_config = web_config or WebToolsConfig()
|
||||||
self.exec_config = exec_config or ExecToolConfig()
|
self.exec_config = exec_config or ExecToolConfig()
|
||||||
self.tools_config = _tc
|
self.tools_config = _tc
|
||||||
@ -278,15 +326,16 @@ class AgentLoop:
|
|||||||
self._start_time = time.time()
|
self._start_time = time.time()
|
||||||
self._last_usage: dict[str, int] = {}
|
self._last_usage: dict[str, int] = {}
|
||||||
self._extra_hooks: list[AgentHook] = hooks or []
|
self._extra_hooks: list[AgentHook] = hooks or []
|
||||||
|
|
||||||
self.context = ContextBuilder(workspace, timezone=timezone, disabled_skills=disabled_skills)
|
self.context = ContextBuilder(workspace, timezone=timezone, disabled_skills=disabled_skills)
|
||||||
self.sessions = session_manager or SessionManager(workspace)
|
self.sessions = session_manager or SessionManager(workspace)
|
||||||
self.tools = ToolRegistry()
|
self.tools = ToolRegistry()
|
||||||
# One file-read/write tracker per logical session. The tool registry is
|
# One file-read/write tracker per logical session. The tool registry is
|
||||||
# shared by this loop, so tools resolve the active state via contextvars.
|
# shared by this loop, so tools resolve the active state via contextvars.
|
||||||
self._file_state_store = FileStateStore()
|
self._file_state_store = FileStateStore()
|
||||||
self.runner = AgentRunner(provider, provider_factory=provider_factory)
|
self.runner = AgentRunner(wrapped_provider)
|
||||||
self.subagents = SubagentManager(
|
self.subagents = SubagentManager(
|
||||||
provider=provider,
|
provider=wrapped_provider,
|
||||||
workspace=workspace,
|
workspace=workspace,
|
||||||
bus=bus,
|
bus=bus,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
@ -318,13 +367,13 @@ class AgentLoop:
|
|||||||
)
|
)
|
||||||
self.consolidator = Consolidator(
|
self.consolidator = Consolidator(
|
||||||
store=self.context.memory,
|
store=self.context.memory,
|
||||||
provider=provider,
|
provider=wrapped_provider,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
sessions=self.sessions,
|
sessions=self.sessions,
|
||||||
context_window_tokens=self.context_window_tokens,
|
context_window_tokens=self.context_window_tokens,
|
||||||
build_messages=self.context.build_messages,
|
build_messages=self.context.build_messages,
|
||||||
get_tool_definitions=self.tools.get_definitions,
|
get_tool_definitions=self.tools.get_definitions,
|
||||||
max_completion_tokens=provider.generation.max_tokens,
|
max_completion_tokens=wrapped_provider.generation.max_tokens,
|
||||||
consolidation_ratio=consolidation_ratio,
|
consolidation_ratio=consolidation_ratio,
|
||||||
)
|
)
|
||||||
self.auto_compact = AutoCompact(
|
self.auto_compact = AutoCompact(
|
||||||
@ -334,11 +383,13 @@ class AgentLoop:
|
|||||||
)
|
)
|
||||||
self.dream = Dream(
|
self.dream = Dream(
|
||||||
store=self.context.memory,
|
store=self.context.memory,
|
||||||
provider=provider,
|
provider=wrapped_provider,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
)
|
)
|
||||||
self.model_presets: dict[str, ModelPresetConfig] = model_presets or {}
|
self.model_presets: dict[str, ModelPresetConfig] = model_presets or {}
|
||||||
self._active_preset: str | None = model_preset if model_presets and model_preset in model_presets else None
|
self._active_preset: str | None = (
|
||||||
|
model_preset if model_preset in self.model_presets else None
|
||||||
|
)
|
||||||
self._register_default_tools()
|
self._register_default_tools()
|
||||||
if _tc.my.enable:
|
if _tc.my.enable:
|
||||||
self.tools.register(MyTool(loop=self, modify_allowed=_tc.my.allow_set))
|
self.tools.register(MyTool(loop=self, modify_allowed=_tc.my.allow_set))
|
||||||
@ -351,6 +402,38 @@ class AgentLoop:
|
|||||||
"""Keep subagent runtime limits aligned with mutable loop settings."""
|
"""Keep subagent runtime limits aligned with mutable loop settings."""
|
||||||
self.subagents.max_iterations = self.max_iterations
|
self.subagents.max_iterations = self.max_iterations
|
||||||
|
|
||||||
|
def _wrap_with_failover(self, provider: LLMProvider, model: str) -> LLMProvider:
|
||||||
|
"""Wrap provider with failover router when fallback_presets are configured."""
|
||||||
|
if not self.fallback_presets or not self.provider_factory:
|
||||||
|
return provider
|
||||||
|
from nanobot.providers.failover import ModelRouter
|
||||||
|
|
||||||
|
if isinstance(provider, ModelRouter):
|
||||||
|
return provider
|
||||||
|
|
||||||
|
return ModelRouter(
|
||||||
|
primary_provider=provider,
|
||||||
|
primary_model=model,
|
||||||
|
fallback_presets=self.fallback_presets,
|
||||||
|
provider_factory=self.provider_factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _apply_provider_state(
|
||||||
|
self,
|
||||||
|
provider: LLMProvider,
|
||||||
|
model: str,
|
||||||
|
context_window_tokens: int,
|
||||||
|
) -> None:
|
||||||
|
"""Push provider/model/context_window to all LLM-consuming subsystems."""
|
||||||
|
self.provider = provider
|
||||||
|
# Bypass property setters so internal updates don't clear _active_preset.
|
||||||
|
object.__setattr__(self, "_model", model)
|
||||||
|
object.__setattr__(self, "_context_window_tokens", context_window_tokens)
|
||||||
|
self.runner.provider = provider
|
||||||
|
self.subagents.set_provider(provider, model)
|
||||||
|
self.consolidator.set_provider(provider, model, context_window_tokens)
|
||||||
|
self.dream.set_provider(provider, model)
|
||||||
|
|
||||||
def _apply_provider_snapshot(self, snapshot: ProviderSnapshot) -> None:
|
def _apply_provider_snapshot(self, snapshot: ProviderSnapshot) -> None:
|
||||||
"""Swap model/provider for future turns without disturbing an active one."""
|
"""Swap model/provider for future turns without disturbing an active one."""
|
||||||
provider = snapshot.provider
|
provider = snapshot.provider
|
||||||
@ -359,14 +442,13 @@ class AgentLoop:
|
|||||||
if self.provider is provider and self.model == model:
|
if self.provider is provider and self.model == model:
|
||||||
return
|
return
|
||||||
old_model = self.model
|
old_model = self.model
|
||||||
self.provider = provider
|
provider = self._wrap_with_failover(provider, model)
|
||||||
self.model = model
|
self._apply_provider_state(provider, model, context_window_tokens)
|
||||||
self.context_window_tokens = context_window_tokens
|
|
||||||
self.runner.provider = provider
|
|
||||||
self.subagents.set_provider(provider, model)
|
|
||||||
self.consolidator.set_provider(provider, model, context_window_tokens)
|
|
||||||
self.dream.set_provider(provider, model)
|
|
||||||
self._provider_signature = snapshot.signature
|
self._provider_signature = snapshot.signature
|
||||||
|
if self._active_preset:
|
||||||
|
preset = self.model_presets.get(self._active_preset)
|
||||||
|
if preset and preset.model != model:
|
||||||
|
self._active_preset = None
|
||||||
logger.info("Runtime model switched for next turn: {} -> {}", old_model, model)
|
logger.info("Runtime model switched for next turn: {} -> {}", old_model, model)
|
||||||
|
|
||||||
def _refresh_provider_snapshot(self) -> None:
|
def _refresh_provider_snapshot(self) -> None:
|
||||||
@ -381,6 +463,28 @@ class AgentLoop:
|
|||||||
return
|
return
|
||||||
self._apply_provider_snapshot(snapshot)
|
self._apply_provider_snapshot(snapshot)
|
||||||
|
|
||||||
|
# -- model / context_window_tokens properties with preset invalidation --
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model(self) -> str:
|
||||||
|
return self._model
|
||||||
|
|
||||||
|
@model.setter
|
||||||
|
def model(self, value: str) -> None:
|
||||||
|
self._model = value
|
||||||
|
if hasattr(self, "_active_preset"):
|
||||||
|
self._active_preset = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def context_window_tokens(self) -> int:
|
||||||
|
return self._context_window_tokens
|
||||||
|
|
||||||
|
@context_window_tokens.setter
|
||||||
|
def context_window_tokens(self, value: int) -> None:
|
||||||
|
self._context_window_tokens = value
|
||||||
|
if hasattr(self, "_active_preset"):
|
||||||
|
self._active_preset = None
|
||||||
|
|
||||||
# -- model_preset property --
|
# -- model_preset property --
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -388,22 +492,27 @@ class AgentLoop:
|
|||||||
return self._active_preset
|
return self._active_preset
|
||||||
|
|
||||||
@model_preset.setter
|
@model_preset.setter
|
||||||
def model_preset(self, name: str | None) -> None:
|
def model_preset(self, name: str) -> None:
|
||||||
"""Resolve a preset by name and apply all fields atomically."""
|
"""Resolve a preset by name and apply all fields."""
|
||||||
from nanobot.providers.base import GenerationSettings
|
|
||||||
|
|
||||||
if not isinstance(name, str) or not name.strip():
|
if not isinstance(name, str) or not name.strip():
|
||||||
raise ValueError("model_preset must be a non-empty string")
|
raise ValueError("model_preset must be a non-empty string")
|
||||||
if name not in self.model_presets:
|
if name not in self.model_presets:
|
||||||
raise KeyError(f"model_preset {name!r} not found. Available: {', '.join(self.model_presets) or '(none)'}")
|
raise KeyError(
|
||||||
|
f"model_preset {name!r} not found. Available: {', '.join(self.model_presets) or '(none)'}"
|
||||||
|
)
|
||||||
|
if self.provider_factory is None:
|
||||||
|
raise ValueError("provider_factory is not configured; cannot switch model preset")
|
||||||
|
|
||||||
p = self.model_presets[name]
|
p = self.model_presets[name]
|
||||||
self.model = p.model
|
new_provider = self._wrap_with_failover(self.provider_factory(name), p.model)
|
||||||
self.context_window_tokens = p.context_window_tokens
|
|
||||||
self.provider.generation = GenerationSettings(
|
# Preserve dream model_override if it differs from the current loop model.
|
||||||
temperature=p.temperature,
|
old_dream_model = self.dream.model
|
||||||
max_tokens=p.max_tokens,
|
dream_had_override = old_dream_model != self.model
|
||||||
reasoning_effort=p.reasoning_effort,
|
|
||||||
)
|
self._apply_provider_state(new_provider, p.model, p.context_window_tokens)
|
||||||
|
if dream_had_override:
|
||||||
|
self.dream.model = old_dream_model
|
||||||
self._active_preset = name
|
self._active_preset = name
|
||||||
|
|
||||||
def _register_default_tools(self) -> None:
|
def _register_default_tools(self) -> None:
|
||||||
@ -710,7 +819,6 @@ class AgentLoop:
|
|||||||
context_window_tokens=self.context_window_tokens,
|
context_window_tokens=self.context_window_tokens,
|
||||||
context_block_limit=self.context_block_limit,
|
context_block_limit=self.context_block_limit,
|
||||||
provider_retry_mode=self.provider_retry_mode,
|
provider_retry_mode=self.provider_retry_mode,
|
||||||
fallback_models=self.fallback_models,
|
|
||||||
progress_callback=on_progress,
|
progress_callback=on_progress,
|
||||||
stream_progress_deltas=on_stream is not None,
|
stream_progress_deltas=on_stream is not None,
|
||||||
retry_wait_callback=on_retry_wait,
|
retry_wait_callback=on_retry_wait,
|
||||||
|
|||||||
@ -16,7 +16,6 @@ from nanobot.agent.hook import AgentHook, AgentHookContext
|
|||||||
from nanobot.agent.tools.ask import AskUserInterrupt
|
from nanobot.agent.tools.ask import AskUserInterrupt
|
||||||
from nanobot.agent.tools.registry import ToolRegistry
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
from nanobot.providers.failover import ModelCandidate, ModelRouter
|
|
||||||
from nanobot.utils.helpers import (
|
from nanobot.utils.helpers import (
|
||||||
build_assistant_message,
|
build_assistant_message,
|
||||||
estimate_message_tokens,
|
estimate_message_tokens,
|
||||||
@ -76,7 +75,6 @@ class AgentRunSpec:
|
|||||||
context_window_tokens: int | None = None
|
context_window_tokens: int | None = None
|
||||||
context_block_limit: int | None = None
|
context_block_limit: int | None = None
|
||||||
provider_retry_mode: str = "standard"
|
provider_retry_mode: str = "standard"
|
||||||
fallback_models: list[str] = field(default_factory=list)
|
|
||||||
progress_callback: Any | None = None
|
progress_callback: Any | None = None
|
||||||
stream_progress_deltas: bool = True
|
stream_progress_deltas: bool = True
|
||||||
retry_wait_callback: Any | None = None
|
retry_wait_callback: Any | None = None
|
||||||
@ -99,21 +97,11 @@ class AgentRunResult:
|
|||||||
had_injections: bool = False
|
had_injections: bool = False
|
||||||
|
|
||||||
|
|
||||||
ProviderFactory = Any # Callable[[str], LLMProvider] — avoids circular import
|
|
||||||
|
|
||||||
|
|
||||||
class AgentRunner:
|
class AgentRunner:
|
||||||
"""Run a tool-capable LLM loop without product-layer concerns."""
|
"""Run a tool-capable LLM loop without product-layer concerns."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, provider: LLMProvider):
|
||||||
self,
|
|
||||||
provider: LLMProvider,
|
|
||||||
*,
|
|
||||||
provider_factory: ProviderFactory | None = None,
|
|
||||||
):
|
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self._provider_factory = provider_factory
|
|
||||||
self._fallback_providers: dict[str, LLMProvider] = {}
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
|
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
|
||||||
@ -606,9 +594,12 @@ class AgentRunner:
|
|||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
hook: AgentHook,
|
hook: AgentHook,
|
||||||
context: AgentHookContext,
|
context: AgentHookContext,
|
||||||
) -> LLMResponse:
|
):
|
||||||
timeout_s: float | None = spec.llm_timeout_s
|
timeout_s: float | None = spec.llm_timeout_s
|
||||||
if timeout_s is None:
|
if timeout_s is None:
|
||||||
|
# Default to a finite timeout to avoid per-session lock starvation when an LLM
|
||||||
|
# request hangs indefinitely (e.g. gateway/network stall).
|
||||||
|
# Set NANOBOT_LLM_TIMEOUT_S=0 to disable.
|
||||||
raw = os.environ.get("NANOBOT_LLM_TIMEOUT_S", "300").strip()
|
raw = os.environ.get("NANOBOT_LLM_TIMEOUT_S", "300").strip()
|
||||||
try:
|
try:
|
||||||
timeout_s = float(raw)
|
timeout_s = float(raw)
|
||||||
@ -622,30 +613,12 @@ class AgentRunner:
|
|||||||
messages,
|
messages,
|
||||||
tools=spec.tools.get_definitions(),
|
tools=spec.tools.get_definitions(),
|
||||||
)
|
)
|
||||||
provider: LLMProvider = self.provider
|
|
||||||
request_timeout = timeout_s
|
|
||||||
if spec.fallback_models:
|
|
||||||
provider = self._build_model_router(spec, timeout_s)
|
|
||||||
# ModelRouter applies the same timeout per candidate, preserving
|
|
||||||
# fallback on primary timeouts instead of timing out the whole chain.
|
|
||||||
request_timeout = None
|
|
||||||
return await self._call_provider(provider, kwargs, hook, context, spec, request_timeout)
|
|
||||||
|
|
||||||
async def _call_provider(
|
|
||||||
self,
|
|
||||||
provider: LLMProvider,
|
|
||||||
kwargs: dict[str, Any],
|
|
||||||
hook: AgentHook,
|
|
||||||
context: AgentHookContext,
|
|
||||||
spec: AgentRunSpec,
|
|
||||||
timeout_s: float | None = None,
|
|
||||||
) -> LLMResponse:
|
|
||||||
wants_streaming = hook.wants_streaming()
|
wants_streaming = hook.wants_streaming()
|
||||||
wants_progress_streaming = (
|
wants_progress_streaming = (
|
||||||
not wants_streaming
|
not wants_streaming
|
||||||
and spec.stream_progress_deltas
|
and spec.stream_progress_deltas
|
||||||
and spec.progress_callback is not None
|
and spec.progress_callback is not None
|
||||||
and getattr(provider, "supports_progress_deltas", False) is True
|
and getattr(self.provider, "supports_progress_deltas", False) is True
|
||||||
)
|
)
|
||||||
|
|
||||||
if wants_streaming:
|
if wants_streaming:
|
||||||
@ -654,7 +627,7 @@ class AgentRunner:
|
|||||||
context.streamed_content = True
|
context.streamed_content = True
|
||||||
await hook.on_stream(context, delta)
|
await hook.on_stream(context, delta)
|
||||||
|
|
||||||
coro = provider.chat_stream_with_retry(
|
coro = self.provider.chat_stream_with_retry(
|
||||||
**kwargs,
|
**kwargs,
|
||||||
on_content_delta=_stream,
|
on_content_delta=_stream,
|
||||||
)
|
)
|
||||||
@ -673,12 +646,12 @@ class AgentRunner:
|
|||||||
context.streamed_content = True
|
context.streamed_content = True
|
||||||
await spec.progress_callback(incremental)
|
await spec.progress_callback(incremental)
|
||||||
|
|
||||||
coro = provider.chat_stream_with_retry(
|
coro = self.provider.chat_stream_with_retry(
|
||||||
**kwargs,
|
**kwargs,
|
||||||
on_content_delta=_stream_progress,
|
on_content_delta=_stream_progress,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
coro = provider.chat_with_retry(**kwargs)
|
coro = self.provider.chat_with_retry(**kwargs)
|
||||||
|
|
||||||
if timeout_s is None:
|
if timeout_s is None:
|
||||||
return await coro
|
return await coro
|
||||||
@ -691,41 +664,6 @@ class AgentRunner:
|
|||||||
error_kind="timeout",
|
error_kind="timeout",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _resolve_fallback_provider(self, model: str) -> tuple[LLMProvider, str]:
|
|
||||||
"""Return (provider, actual_model_name) for a fallback model.
|
|
||||||
|
|
||||||
When a provider_factory is available (and the model string may be a
|
|
||||||
preset name), the factory resolves the actual model; otherwise the
|
|
||||||
primary provider is reused with the raw model string.
|
|
||||||
"""
|
|
||||||
if model in self._fallback_providers:
|
|
||||||
p = self._fallback_providers[model]
|
|
||||||
return p, p.get_default_model()
|
|
||||||
if self._provider_factory:
|
|
||||||
provider = self._provider_factory(model)
|
|
||||||
self._fallback_providers[model] = provider
|
|
||||||
return provider, provider.get_default_model()
|
|
||||||
return self.provider, model
|
|
||||||
|
|
||||||
def _build_model_router(
|
|
||||||
self,
|
|
||||||
spec: AgentRunSpec,
|
|
||||||
timeout_s: float | None,
|
|
||||||
) -> ModelRouter:
|
|
||||||
candidates = [
|
|
||||||
ModelCandidate(
|
|
||||||
label=model,
|
|
||||||
resolver=lambda m=model: self._resolve_fallback_provider(m),
|
|
||||||
)
|
|
||||||
for model in spec.fallback_models
|
|
||||||
]
|
|
||||||
return ModelRouter(
|
|
||||||
primary_provider=self.provider,
|
|
||||||
primary_model=spec.model,
|
|
||||||
fallback_candidates=candidates,
|
|
||||||
per_candidate_timeout_s=timeout_s,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _request_finalization_retry(
|
async def _request_finalization_retry(
|
||||||
self,
|
self,
|
||||||
spec: AgentRunSpec,
|
spec: AgentRunSpec,
|
||||||
|
|||||||
@ -76,8 +76,6 @@ class MyTool(Tool):
|
|||||||
|
|
||||||
RESTRICTED: dict[str, dict[str, Any]] = {
|
RESTRICTED: dict[str, dict[str, Any]] = {
|
||||||
"max_iterations": {"type": int, "min": 1, "max": 100},
|
"max_iterations": {"type": int, "min": 1, "max": 100},
|
||||||
"context_window_tokens": {"type": int, "min": 4096, "max": 1_000_000},
|
|
||||||
"model": {"type": str, "min_len": 1},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_MAX_RUNTIME_KEYS = 64
|
_MAX_RUNTIME_KEYS = 64
|
||||||
@ -118,13 +116,14 @@ class MyTool(Tool):
|
|||||||
"Scratchpad keys persist across turns but not restarts.\n"
|
"Scratchpad keys persist across turns but not restarts.\n"
|
||||||
"Key values: _current_iteration (current progress), "
|
"Key values: _current_iteration (current progress), "
|
||||||
"max_iterations - _current_iteration = remaining iterations.\n"
|
"max_iterations - _current_iteration = remaining iterations.\n"
|
||||||
|
"Use 'model_preset' to switch the active model preset.\n"
|
||||||
"Note: web_config and exec_config are readable but read-only.\n"
|
"Note: web_config and exec_config are readable but read-only.\n"
|
||||||
"\n"
|
"\n"
|
||||||
"When to use:\n"
|
"When to use:\n"
|
||||||
"- User asks about your model, settings, or token usage → check that key.\n"
|
"- User asks about your model, settings, or token usage → check that key.\n"
|
||||||
"- A tool fails or behaves unexpectedly → check the related config to diagnose.\n"
|
"- A tool fails or behaves unexpectedly → check the related config to diagnose.\n"
|
||||||
"- User asks you to remember a preference for this session → set to store it in your scratchpad.\n"
|
"- User asks you to remember a preference for this session → set to store it in your scratchpad.\n"
|
||||||
"- About to start a large task → check context_window_tokens and max_iterations first."
|
"- About to start a large task → check max_iterations and model_preset first."
|
||||||
)
|
)
|
||||||
if not self._modify_allowed:
|
if not self._modify_allowed:
|
||||||
base += "\nREAD-ONLY MODE: set is disabled."
|
base += "\nREAD-ONLY MODE: set is disabled."
|
||||||
@ -132,7 +131,7 @@ class MyTool(Tool):
|
|||||||
base += (
|
base += (
|
||||||
"\nIMPORTANT: Before setting state, predict the potential impact. "
|
"\nIMPORTANT: Before setting state, predict the potential impact. "
|
||||||
"If the operation could cause crashes or instability "
|
"If the operation could cause crashes or instability "
|
||||||
"(e.g. changing model), warn the user first."
|
"(e.g. changing model_preset), warn the user first."
|
||||||
)
|
)
|
||||||
return base
|
return base
|
||||||
|
|
||||||
@ -148,7 +147,7 @@ class MyTool(Tool):
|
|||||||
},
|
},
|
||||||
"key": {
|
"key": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Dot-path for check/set. Examples: 'max_iterations', 'workspace', 'provider_retry_mode'. "
|
"description": "Dot-path for check/set. Examples: 'max_iterations', 'model_preset', 'provider_retry_mode'. "
|
||||||
"For check without key, shows all config values.",
|
"For check without key, shows all config values.",
|
||||||
},
|
},
|
||||||
"value": {"description": "New value (for set). Type must match target (int for max_iterations/context_window_tokens, str for model)."},
|
"value": {"description": "New value (for set). Type must match target (int for max_iterations/context_window_tokens, str for model)."},
|
||||||
@ -391,9 +390,6 @@ class MyTool(Tool):
|
|||||||
|
|
||||||
# --- existing restricted key logic ---
|
# --- existing restricted key logic ---
|
||||||
old = getattr(self._loop, key)
|
old = getattr(self._loop, key)
|
||||||
# When model is set directly, it no longer matches any preset
|
|
||||||
if key == "model":
|
|
||||||
self._loop._active_preset = None
|
|
||||||
if "min" in spec and value < spec["min"]:
|
if "min" in spec and value < spec["min"]:
|
||||||
return f"Error: '{key}' must be >= {spec['min']}"
|
return f"Error: '{key}' must be >= {spec['min']}"
|
||||||
if "max" in spec and value > spec["max"]:
|
if "max" in spec and value > spec["max"]:
|
||||||
@ -419,9 +415,12 @@ class MyTool(Tool):
|
|||||||
f"REJECTED type mismatch {key}: expects {old_t.__name__}, got {new_t.__name__}",
|
f"REJECTED type mismatch {key}: expects {old_t.__name__}, got {new_t.__name__}",
|
||||||
)
|
)
|
||||||
return f"Error: '{key}' expects {old_t.__name__}, got {new_t.__name__}"
|
return f"Error: '{key}' expects {old_t.__name__}, got {new_t.__name__}"
|
||||||
|
# When a model-specific field is set directly, it no longer matches any preset
|
||||||
|
if key in ("model", "context_window_tokens"):
|
||||||
|
self._loop._active_preset = None
|
||||||
try:
|
try:
|
||||||
setattr(self._loop, key, value)
|
setattr(self._loop, key, value)
|
||||||
except (ValueError, KeyError) as e:
|
except (AttributeError, TypeError, ValueError, KeyError) as e:
|
||||||
self._audit("modify", f"REJECTED {key}: {e}")
|
self._audit("modify", f"REJECTED {key}: {e}")
|
||||||
return f"Error: {e}"
|
return f"Error: {e}"
|
||||||
self._audit("modify", f"{key}: {old!r} -> {value!r}")
|
self._audit("modify", f"{key}: {old!r} -> {value!r}")
|
||||||
|
|||||||
@ -48,6 +48,7 @@ from rich.table import Table
|
|||||||
from rich.text import Text
|
from rich.text import Text
|
||||||
|
|
||||||
from nanobot import __logo__, __version__
|
from nanobot import __logo__, __version__
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
|
||||||
|
|
||||||
class SafeFileHistory(FileHistory):
|
class SafeFileHistory(FileHistory):
|
||||||
@ -437,104 +438,6 @@ def _onboard_plugins(config_path: Path) -> None:
|
|||||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
def _make_provider(config: Config):
|
|
||||||
"""Create the appropriate LLM provider from config.
|
|
||||||
|
|
||||||
Routing is driven by ``ProviderSpec.backend`` in the registry.
|
|
||||||
"""
|
|
||||||
from nanobot.providers.base import GenerationSettings
|
|
||||||
from nanobot.providers.factory import make_provider
|
|
||||||
from nanobot.providers.registry import find_by_name
|
|
||||||
|
|
||||||
resolved = config.resolve_preset()
|
|
||||||
model = resolved.model
|
|
||||||
provider_name = config.get_provider_name(model)
|
|
||||||
p = config.get_provider(model)
|
|
||||||
spec = find_by_name(provider_name) if provider_name else None
|
|
||||||
backend = spec.backend if spec else "openai_compat"
|
|
||||||
|
|
||||||
# --- validation ---
|
|
||||||
if backend == "azure_openai":
|
|
||||||
if not p or not p.api_key or not p.api_base:
|
|
||||||
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
|
|
||||||
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
|
|
||||||
console.print("Use the model field to specify the deployment name.")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
elif backend == "openai_compat" and not model.startswith("bedrock/"):
|
|
||||||
needs_key = not (p and p.api_key)
|
|
||||||
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
|
|
||||||
if needs_key and not exempt:
|
|
||||||
console.print("[red]Error: No API key configured.[/red]")
|
|
||||||
console.print("Set one in ~/.nanobot/config.json under providers section")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
# --- instantiation by backend ---
|
|
||||||
if backend == "openai_codex":
|
|
||||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
|
||||||
|
|
||||||
provider = OpenAICodexProvider(default_model=model)
|
|
||||||
elif backend == "azure_openai":
|
|
||||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
|
||||||
|
|
||||||
provider = AzureOpenAIProvider(
|
|
||||||
api_key=p.api_key,
|
|
||||||
api_base=p.api_base,
|
|
||||||
default_model=model,
|
|
||||||
)
|
|
||||||
elif backend == "github_copilot":
|
|
||||||
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
|
|
||||||
provider = GitHubCopilotProvider(default_model=model)
|
|
||||||
elif backend == "anthropic":
|
|
||||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
|
||||||
|
|
||||||
provider = AnthropicProvider(
|
|
||||||
api_key=p.api_key if p else None,
|
|
||||||
api_base=config.get_api_base(model),
|
|
||||||
default_model=model,
|
|
||||||
extra_headers=p.extra_headers if p else None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
|
||||||
|
|
||||||
provider = OpenAICompatProvider(
|
|
||||||
api_key=p.api_key if p else None,
|
|
||||||
api_base=config.get_api_base(model),
|
|
||||||
default_model=model,
|
|
||||||
extra_headers=p.extra_headers if p else None,
|
|
||||||
spec=spec,
|
|
||||||
extra_body=p.extra_body if p else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
provider.generation = GenerationSettings(
|
|
||||||
temperature=resolved.temperature,
|
|
||||||
max_tokens=resolved.max_tokens,
|
|
||||||
reasoning_effort=resolved.reasoning_effort,
|
|
||||||
)
|
|
||||||
return provider
|
|
||||||
|
|
||||||
|
|
||||||
def _make_cli_provider_factory(config: Config):
|
|
||||||
"""Build a cached factory for fallback model providers (CLI side).
|
|
||||||
|
|
||||||
Supports preset names: if a model string matches a preset, the preset's
|
|
||||||
full config is used for provider creation.
|
|
||||||
"""
|
|
||||||
from nanobot.nanobot import _make_provider_for_model
|
|
||||||
|
|
||||||
cache: dict[str, Any] = {}
|
|
||||||
presets = getattr(config, "model_presets", {}) or {}
|
|
||||||
|
|
||||||
def factory(model_or_preset: str):
|
|
||||||
preset = presets.get(model_or_preset)
|
|
||||||
actual_model = preset.model if preset else model_or_preset
|
|
||||||
key = actual_model
|
|
||||||
if key not in cache:
|
|
||||||
cache[key] = _make_provider_for_model(config, actual_model, preset=preset)
|
|
||||||
return cache[key]
|
|
||||||
|
|
||||||
return factory
|
|
||||||
|
|
||||||
|
|
||||||
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
|
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
|
||||||
"""Load config and optionally override the active workspace."""
|
"""Load config and optionally override the active workspace."""
|
||||||
from nanobot.config.loader import load_config, resolve_config_env_vars, set_config_path
|
from nanobot.config.loader import load_config, resolve_config_env_vars, set_config_path
|
||||||
@ -612,8 +515,6 @@ def serve(
|
|||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
|
||||||
from nanobot.api.server import create_app
|
from nanobot.api.server import create_app
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.session.manager import SessionManager
|
from nanobot.session.manager import SessionManager
|
||||||
@ -630,45 +531,19 @@ def serve(
|
|||||||
timeout = timeout if timeout is not None else api_cfg.timeout
|
timeout = timeout if timeout is not None else api_cfg.timeout
|
||||||
sync_workspace_templates(runtime_config.workspace_path)
|
sync_workspace_templates(runtime_config.workspace_path)
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
provider = _make_provider(runtime_config)
|
|
||||||
defaults = runtime_config.agents.defaults
|
defaults = runtime_config.agents.defaults
|
||||||
pf = _make_cli_provider_factory(runtime_config) if defaults.fallback_models else None
|
|
||||||
session_manager = SessionManager(runtime_config.workspace_path)
|
session_manager = SessionManager(runtime_config.workspace_path)
|
||||||
_resolved = runtime_config.resolve_preset()
|
resolved_preset = runtime_config.resolve_preset()
|
||||||
agent_loop = AgentLoop(
|
agent_loop = AgentLoop.from_config(
|
||||||
bus=bus,
|
runtime_config, bus,
|
||||||
provider=provider,
|
|
||||||
workspace=runtime_config.workspace_path,
|
|
||||||
model=_resolved.model,
|
|
||||||
max_iterations=defaults.max_tool_iterations,
|
|
||||||
context_window_tokens=_resolved.context_window_tokens,
|
|
||||||
context_block_limit=defaults.context_block_limit,
|
|
||||||
max_tool_result_chars=defaults.max_tool_result_chars,
|
|
||||||
provider_retry_mode=defaults.provider_retry_mode,
|
|
||||||
fallback_models=defaults.fallback_models,
|
|
||||||
provider_factory=pf,
|
|
||||||
web_config=runtime_config.tools.web,
|
|
||||||
exec_config=runtime_config.tools.exec,
|
|
||||||
restrict_to_workspace=runtime_config.tools.restrict_to_workspace,
|
|
||||||
session_manager=session_manager,
|
session_manager=session_manager,
|
||||||
mcp_servers=runtime_config.tools.mcp_servers,
|
|
||||||
channels_config=runtime_config.channels,
|
|
||||||
timezone=runtime_config.agents.defaults.timezone,
|
|
||||||
unified_session=runtime_config.agents.defaults.unified_session,
|
|
||||||
disabled_skills=runtime_config.agents.defaults.disabled_skills,
|
|
||||||
session_ttl_minutes=runtime_config.agents.defaults.session_ttl_minutes,
|
|
||||||
consolidation_ratio=runtime_config.agents.defaults.consolidation_ratio,
|
|
||||||
max_messages=runtime_config.agents.defaults.max_messages,
|
|
||||||
tools_config=runtime_config.tools,
|
|
||||||
image_generation_provider_configs={
|
image_generation_provider_configs={
|
||||||
"openrouter": runtime_config.providers.openrouter,
|
"openrouter": runtime_config.providers.openrouter,
|
||||||
"aihubmix": runtime_config.providers.aihubmix,
|
"aihubmix": runtime_config.providers.aihubmix,
|
||||||
},
|
},
|
||||||
model_presets=runtime_config.model_presets,
|
|
||||||
model_preset=runtime_config.agents.defaults.model_preset,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model_name = _resolved.model
|
model_name = resolved_preset.model
|
||||||
preset_name = defaults.model_preset
|
preset_name = defaults.model_preset
|
||||||
preset_tag = f" (preset: {preset_name})" if preset_name else ""
|
preset_tag = f" (preset: {preset_name})" if preset_name else ""
|
||||||
console.print(f"{__logo__} Starting OpenAI-compatible API server")
|
console.print(f"{__logo__} Starting OpenAI-compatible API server")
|
||||||
@ -735,7 +610,6 @@ def _run_gateway(
|
|||||||
open_browser_url: str | None = None,
|
open_browser_url: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Shared gateway runtime; ``open_browser_url`` opens a tab once channels are up."""
|
"""Shared gateway runtime; ``open_browser_url`` opens a tab once channels are up."""
|
||||||
from nanobot.agent.loop import AgentLoop
|
|
||||||
from nanobot.agent.tools.cron import CronTool
|
from nanobot.agent.tools.cron import CronTool
|
||||||
from nanobot.agent.tools.message import MessageTool
|
from nanobot.agent.tools.message import MessageTool
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
@ -751,9 +625,6 @@ def _run_gateway(
|
|||||||
console.print(f"{__logo__} Starting nanobot gateway version {__version__} on port {port}...")
|
console.print(f"{__logo__} Starting nanobot gateway version {__version__} on port {port}...")
|
||||||
sync_workspace_templates(config.workspace_path)
|
sync_workspace_templates(config.workspace_path)
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
provider = _make_provider(config)
|
|
||||||
gw_defaults = config.agents.defaults
|
|
||||||
gw_pf = _make_cli_provider_factory(config) if gw_defaults.fallback_models else None
|
|
||||||
try:
|
try:
|
||||||
provider_snapshot = build_provider_snapshot(config)
|
provider_snapshot = build_provider_snapshot(config)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
@ -770,41 +641,16 @@ def _run_gateway(
|
|||||||
cron = CronService(cron_store_path)
|
cron = CronService(cron_store_path)
|
||||||
|
|
||||||
# Create agent with cron service
|
# Create agent with cron service
|
||||||
_resolved = config.resolve_preset()
|
agent = AgentLoop.from_config(
|
||||||
agent = AgentLoop(
|
config, bus,
|
||||||
bus=bus,
|
|
||||||
provider=provider,
|
|
||||||
workspace=config.workspace_path,
|
|
||||||
model=_resolved.model,
|
|
||||||
max_iterations=gw_defaults.max_tool_iterations,
|
|
||||||
context_window_tokens=_resolved.context_window_tokens,
|
|
||||||
web_config=config.tools.web,
|
|
||||||
context_block_limit=gw_defaults.context_block_limit,
|
|
||||||
max_tool_result_chars=gw_defaults.max_tool_result_chars,
|
|
||||||
provider_retry_mode=gw_defaults.provider_retry_mode,
|
|
||||||
fallback_models=gw_defaults.fallback_models,
|
|
||||||
provider_factory=gw_pf,
|
|
||||||
exec_config=config.tools.exec,
|
|
||||||
cron_service=cron,
|
cron_service=cron,
|
||||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
|
||||||
session_manager=session_manager,
|
session_manager=session_manager,
|
||||||
mcp_servers=config.tools.mcp_servers,
|
|
||||||
channels_config=config.channels,
|
|
||||||
timezone=config.agents.defaults.timezone,
|
|
||||||
unified_session=config.agents.defaults.unified_session,
|
|
||||||
disabled_skills=config.agents.defaults.disabled_skills,
|
|
||||||
session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
|
|
||||||
consolidation_ratio=config.agents.defaults.consolidation_ratio,
|
|
||||||
max_messages=config.agents.defaults.max_messages,
|
|
||||||
tools_config=config.tools,
|
|
||||||
image_generation_provider_configs={
|
image_generation_provider_configs={
|
||||||
"openrouter": config.providers.openrouter,
|
"openrouter": config.providers.openrouter,
|
||||||
"aihubmix": config.providers.aihubmix,
|
"aihubmix": config.providers.aihubmix,
|
||||||
},
|
},
|
||||||
provider_snapshot_loader=load_provider_snapshot,
|
provider_snapshot_loader=load_provider_snapshot,
|
||||||
provider_signature=provider_snapshot.signature,
|
provider_signature=provider_snapshot.signature,
|
||||||
model_presets=config.model_presets,
|
|
||||||
model_preset=config.agents.defaults.model_preset,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
||||||
@ -908,7 +754,7 @@ def _run_gateway(
|
|||||||
|
|
||||||
if job.payload.deliver and job.payload.to and response:
|
if job.payload.deliver and job.payload.to and response:
|
||||||
should_notify = await evaluate_response(
|
should_notify = await evaluate_response(
|
||||||
response, reminder_note, provider, agent.model,
|
response, reminder_note, agent.provider, agent.model,
|
||||||
)
|
)
|
||||||
if should_notify:
|
if should_notify:
|
||||||
await _deliver_to_channel(
|
await _deliver_to_channel(
|
||||||
@ -998,7 +844,7 @@ def _run_gateway(
|
|||||||
hb_cfg = config.gateway.heartbeat
|
hb_cfg = config.gateway.heartbeat
|
||||||
heartbeat = HeartbeatService(
|
heartbeat = HeartbeatService(
|
||||||
workspace=config.workspace_path,
|
workspace=config.workspace_path,
|
||||||
provider=provider,
|
provider=agent.provider,
|
||||||
model=agent.model,
|
model=agent.model,
|
||||||
on_execute=on_heartbeat_execute,
|
on_execute=on_heartbeat_execute,
|
||||||
on_notify=on_heartbeat_notify,
|
on_notify=on_heartbeat_notify,
|
||||||
@ -1151,7 +997,6 @@ def agent(
|
|||||||
"""Interact with the agent directly."""
|
"""Interact with the agent directly."""
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.cron.service import CronService
|
from nanobot.cron.service import CronService
|
||||||
|
|
||||||
@ -1159,10 +1004,6 @@ def agent(
|
|||||||
sync_workspace_templates(config.workspace_path)
|
sync_workspace_templates(config.workspace_path)
|
||||||
|
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
provider = _make_provider(config)
|
|
||||||
chat_defaults = config.agents.defaults
|
|
||||||
chat_pf = _make_cli_provider_factory(config) if chat_defaults.fallback_models else None
|
|
||||||
|
|
||||||
# Preserve existing single-workspace installs, but keep custom workspaces clean.
|
# Preserve existing single-workspace installs, but keep custom workspaces clean.
|
||||||
if is_default_workspace(config.workspace_path):
|
if is_default_workspace(config.workspace_path):
|
||||||
_migrate_cron_store(config)
|
_migrate_cron_store(config)
|
||||||
@ -1176,34 +1017,10 @@ def agent(
|
|||||||
else:
|
else:
|
||||||
logger.disable("nanobot")
|
logger.disable("nanobot")
|
||||||
|
|
||||||
_resolved = config.resolve_preset()
|
resolved_preset = config.resolve_preset()
|
||||||
agent_loop = AgentLoop(
|
agent_loop = AgentLoop.from_config(
|
||||||
bus=bus,
|
config, bus,
|
||||||
provider=provider,
|
|
||||||
workspace=config.workspace_path,
|
|
||||||
model=_resolved.model,
|
|
||||||
max_iterations=chat_defaults.max_tool_iterations,
|
|
||||||
context_window_tokens=_resolved.context_window_tokens,
|
|
||||||
web_config=config.tools.web,
|
|
||||||
context_block_limit=chat_defaults.context_block_limit,
|
|
||||||
max_tool_result_chars=chat_defaults.max_tool_result_chars,
|
|
||||||
provider_retry_mode=chat_defaults.provider_retry_mode,
|
|
||||||
fallback_models=chat_defaults.fallback_models,
|
|
||||||
provider_factory=chat_pf,
|
|
||||||
exec_config=config.tools.exec,
|
|
||||||
cron_service=cron,
|
cron_service=cron,
|
||||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
|
||||||
mcp_servers=config.tools.mcp_servers,
|
|
||||||
channels_config=config.channels,
|
|
||||||
timezone=config.agents.defaults.timezone,
|
|
||||||
unified_session=config.agents.defaults.unified_session,
|
|
||||||
disabled_skills=config.agents.defaults.disabled_skills,
|
|
||||||
session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
|
|
||||||
consolidation_ratio=config.agents.defaults.consolidation_ratio,
|
|
||||||
max_messages=config.agents.defaults.max_messages,
|
|
||||||
tools_config=config.tools,
|
|
||||||
model_presets=config.model_presets,
|
|
||||||
model_preset=config.agents.defaults.model_preset,
|
|
||||||
)
|
)
|
||||||
restart_notice = consume_restart_notice_from_env()
|
restart_notice = consume_restart_notice_from_env()
|
||||||
if restart_notice and should_show_cli_restart_notice(restart_notice, session_id):
|
if restart_notice and should_show_cli_restart_notice(restart_notice, session_id):
|
||||||
@ -1247,7 +1064,7 @@ def agent(
|
|||||||
# Interactive mode — route through bus like other channels
|
# Interactive mode — route through bus like other channels
|
||||||
from nanobot.bus.events import InboundMessage
|
from nanobot.bus.events import InboundMessage
|
||||||
_init_prompt_session()
|
_init_prompt_session()
|
||||||
console.print(f"{__logo__} Interactive mode [bold blue]({_resolved.model})[/bold blue] — type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit\n")
|
console.print(f"{__logo__} Interactive mode [bold blue]({resolved_preset.model})[/bold blue] — type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit\n")
|
||||||
|
|
||||||
if ":" in session_id:
|
if ":" in session_id:
|
||||||
cli_channel, cli_chat_id = session_id.split(":", 1)
|
cli_channel, cli_chat_id = session_id.split(":", 1)
|
||||||
@ -1605,10 +1422,10 @@ def status():
|
|||||||
if config_path.exists():
|
if config_path.exists():
|
||||||
from nanobot.providers.registry import PROVIDERS
|
from nanobot.providers.registry import PROVIDERS
|
||||||
|
|
||||||
_resolved = config.resolve_preset()
|
resolved_preset = config.resolve_preset()
|
||||||
_preset = config.agents.defaults.model_preset
|
preset = config.agents.defaults.model_preset
|
||||||
_preset_tag = f" (preset: {_preset})" if _preset else ""
|
preset_tag = f" (preset: {preset})" if preset else ""
|
||||||
console.print(f"Model: {_resolved.model}{_preset_tag}")
|
console.print(f"Model: {resolved_preset.model}{preset_tag}")
|
||||||
|
|
||||||
# Check API keys from registry
|
# Check API keys from registry
|
||||||
for spec in PROVIDERS:
|
for spec in PROVIDERS:
|
||||||
|
|||||||
@ -22,7 +22,7 @@ from nanobot.cli.models import (
|
|||||||
get_model_suggestions,
|
get_model_suggestions,
|
||||||
)
|
)
|
||||||
from nanobot.config.loader import get_config_path, load_config
|
from nanobot.config.loader import get_config_path, load_config
|
||||||
from nanobot.config.schema import Config
|
from nanobot.config.schema import Config, ModelPresetConfig
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
@ -49,6 +49,16 @@ _SELECT_FIELD_HINTS: dict[str, tuple[list[str], str]] = {
|
|||||||
|
|
||||||
_BACK_PRESSED = object() # Sentinel value for back navigation
|
_BACK_PRESSED = object() # Sentinel value for back navigation
|
||||||
|
|
||||||
|
# Cache of model-preset names populated at runtime so that field handlers can
|
||||||
|
# offer existing presets as choices (e.g. AgentDefaults.model_preset).
|
||||||
|
#
|
||||||
|
# Lifecycle: populated by _sync_preset_cache(config), which must be called
|
||||||
|
# after every config mutation that changes model_presets (add, delete, edit).
|
||||||
|
# Cleared between tests via _MODEL_PRESET_CACHE.clear(). In long-running
|
||||||
|
# processes (gateway) the cache is refreshed each time the preset management
|
||||||
|
# screen is entered, so staleness is bounded by user interaction.
|
||||||
|
_MODEL_PRESET_CACHE: set[str] = set()
|
||||||
|
|
||||||
|
|
||||||
def _get_questionary():
|
def _get_questionary():
|
||||||
"""Return questionary or raise a clear error when wizard deps are unavailable."""
|
"""Return questionary or raise a clear error when wizard deps are unavailable."""
|
||||||
@ -588,9 +598,100 @@ def _handle_context_window_field(
|
|||||||
setattr(working_model, field_name, new_value)
|
setattr(working_model, field_name, new_value)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_model_preset_field(
|
||||||
|
working_model: BaseModel, field_name: str, field_display: str, current_value: Any
|
||||||
|
) -> None:
|
||||||
|
"""Handle the 'model_preset' field with a list of existing presets."""
|
||||||
|
# model_preset lives on AgentDefaults, but the preset list is on Config.
|
||||||
|
# We can't easily access Config here, so we read from the global config
|
||||||
|
# via a module-level cache set by _configure_model_presets / run_onboard.
|
||||||
|
preset_names = sorted(_MODEL_PRESET_CACHE)
|
||||||
|
choices = ["(clear/unset)"] + preset_names
|
||||||
|
default_choice = str(current_value) if current_value else "(clear/unset)"
|
||||||
|
new_value = _select_with_back(field_display, choices, default=default_choice)
|
||||||
|
if new_value is _BACK_PRESSED:
|
||||||
|
return
|
||||||
|
if new_value == "(clear/unset)":
|
||||||
|
setattr(working_model, field_name, None)
|
||||||
|
elif new_value is not None:
|
||||||
|
setattr(working_model, field_name, new_value)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_provider_field(
|
||||||
|
working_model: BaseModel, field_name: str, field_display: str, current_value: Any
|
||||||
|
) -> None:
|
||||||
|
"""Handle the 'provider' field with a list of registered providers."""
|
||||||
|
provider_names = sorted(_get_provider_names().keys())
|
||||||
|
choices = ["auto"] + provider_names
|
||||||
|
default_choice = str(current_value) if current_value else "auto"
|
||||||
|
new_value = _select_with_back(field_display, choices, default=default_choice)
|
||||||
|
if new_value is _BACK_PRESSED:
|
||||||
|
return
|
||||||
|
if new_value is not None:
|
||||||
|
setattr(working_model, field_name, new_value)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_fallback_presets_field(
|
||||||
|
working_model: BaseModel, field_name: str, field_display: str, current_value: Any
|
||||||
|
) -> None:
|
||||||
|
"""Handle the 'fallback_presets' field with preset-aware multi-select."""
|
||||||
|
items: list[str] = list(current_value) if isinstance(current_value, list) else []
|
||||||
|
preset_names = sorted(_MODEL_PRESET_CACHE)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
console.clear()
|
||||||
|
console.print(f"[bold]{field_display}[/bold]")
|
||||||
|
if items:
|
||||||
|
for idx, item in enumerate(items, 1):
|
||||||
|
console.print(f" {idx}. {item}")
|
||||||
|
else:
|
||||||
|
console.print(" [dim](empty)[/dim]")
|
||||||
|
console.print()
|
||||||
|
|
||||||
|
choices = ["[+] Add preset"]
|
||||||
|
if items:
|
||||||
|
choices.append("[-] Remove last")
|
||||||
|
choices.append("[X] Clear all")
|
||||||
|
choices.append("[Done]")
|
||||||
|
choices.append("<- Back")
|
||||||
|
|
||||||
|
answer = _get_questionary().select(
|
||||||
|
"Manage fallback chain:",
|
||||||
|
choices=choices,
|
||||||
|
qmark=">",
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
if answer is None or answer == "<- Back":
|
||||||
|
return
|
||||||
|
if answer == "[Done]":
|
||||||
|
setattr(working_model, field_name, items)
|
||||||
|
return
|
||||||
|
if answer == "[+] Add preset":
|
||||||
|
if not preset_names:
|
||||||
|
console.print("[yellow]! No presets defined yet.[/yellow]")
|
||||||
|
_get_questionary().press_any_key_to_continue().ask()
|
||||||
|
continue
|
||||||
|
add_choices = [p for p in preset_names if p not in items]
|
||||||
|
if not add_choices:
|
||||||
|
console.print("[yellow]! All presets already added.[/yellow]")
|
||||||
|
_get_questionary().press_any_key_to_continue().ask()
|
||||||
|
continue
|
||||||
|
picked = _select_with_back("Select preset:", add_choices)
|
||||||
|
if picked is _BACK_PRESSED or picked is None:
|
||||||
|
continue
|
||||||
|
items.append(picked)
|
||||||
|
elif answer == "[-] Remove last" and items:
|
||||||
|
items.pop()
|
||||||
|
elif answer == "[X] Clear all" and items:
|
||||||
|
items.clear()
|
||||||
|
|
||||||
|
|
||||||
_FIELD_HANDLERS: dict[str, Any] = {
|
_FIELD_HANDLERS: dict[str, Any] = {
|
||||||
"model": _handle_model_field,
|
"model": _handle_model_field,
|
||||||
"context_window_tokens": _handle_context_window_field,
|
"context_window_tokens": _handle_context_window_field,
|
||||||
|
"model_preset": _handle_model_preset_field,
|
||||||
|
"provider": _handle_provider_field,
|
||||||
|
"fallback_presets": _handle_fallback_presets_field,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -757,6 +858,113 @@ def _try_auto_fill_context_window(model: BaseModel, new_model_name: str) -> None
|
|||||||
console.print("[dim](i) Could not auto-fill context window (model not in database)[/dim]")
|
console.print("[dim](i) Could not auto-fill context window (model not in database)[/dim]")
|
||||||
|
|
||||||
|
|
||||||
|
# --- Model Preset Configuration ---
|
||||||
|
|
||||||
|
|
||||||
|
def _sync_preset_cache(config: Config) -> None:
|
||||||
|
"""Synchronise the module-level preset name cache from config."""
|
||||||
|
_MODEL_PRESET_CACHE.clear()
|
||||||
|
_MODEL_PRESET_CACHE.update(config.model_presets.keys())
|
||||||
|
|
||||||
|
|
||||||
|
def _configure_model_presets(config: Config) -> None:
|
||||||
|
"""Configure model presets (CRUD)."""
|
||||||
|
_sync_preset_cache(config)
|
||||||
|
|
||||||
|
def get_preset_choices() -> list[str]:
|
||||||
|
choices: list[str] = []
|
||||||
|
for name, preset in config.model_presets.items():
|
||||||
|
choices.append(f"{name} ({preset.model})")
|
||||||
|
choices.append("[+] Add new preset")
|
||||||
|
choices.append("<- Back")
|
||||||
|
return choices
|
||||||
|
|
||||||
|
last_preset_name: str | None = None
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
console.clear()
|
||||||
|
_show_section_header(
|
||||||
|
"Model Presets",
|
||||||
|
"Create, edit or delete named model presets for quick switching",
|
||||||
|
)
|
||||||
|
choices = get_preset_choices()
|
||||||
|
default_choice = None
|
||||||
|
if last_preset_name:
|
||||||
|
for c in choices:
|
||||||
|
if c.startswith(last_preset_name + " ("):
|
||||||
|
default_choice = c
|
||||||
|
break
|
||||||
|
answer = _select_with_back(
|
||||||
|
"Select preset:", choices, default=default_choice
|
||||||
|
)
|
||||||
|
|
||||||
|
if answer is _BACK_PRESSED or answer is None or answer == "<- Back":
|
||||||
|
break
|
||||||
|
|
||||||
|
assert isinstance(answer, str)
|
||||||
|
|
||||||
|
if answer == "[+] Add new preset":
|
||||||
|
name_input = _get_questionary().text(
|
||||||
|
"Preset name:",
|
||||||
|
validate=lambda t: True if t and t.strip() else "Name cannot be empty",
|
||||||
|
).ask()
|
||||||
|
if not name_input:
|
||||||
|
continue
|
||||||
|
name = name_input.strip()
|
||||||
|
if name in config.model_presets:
|
||||||
|
console.print(f"[yellow]! Preset '{name}' already exists[/yellow]")
|
||||||
|
_pause()
|
||||||
|
continue
|
||||||
|
new_preset = ModelPresetConfig(model="")
|
||||||
|
updated = _configure_pydantic_model(new_preset, f"New Preset: {name}")
|
||||||
|
if updated is not None:
|
||||||
|
config.model_presets[name] = updated
|
||||||
|
_sync_preset_cache(config)
|
||||||
|
last_preset_name = name
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Editing / deleting an existing preset
|
||||||
|
# Extract preset name from "name (model)" format
|
||||||
|
preset_name = answer.split(" (", 1)[0]
|
||||||
|
preset = config.model_presets.get(preset_name)
|
||||||
|
if preset is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
last_preset_name = preset_name
|
||||||
|
|
||||||
|
choices = ["Edit", "Cancel"]
|
||||||
|
if preset_name != "default":
|
||||||
|
choices.insert(1, "Delete")
|
||||||
|
action = _select_with_back(
|
||||||
|
f"Preset: {preset_name}",
|
||||||
|
choices,
|
||||||
|
default="Edit",
|
||||||
|
)
|
||||||
|
if action is _BACK_PRESSED or action == "Cancel" or action is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if action == "Delete":
|
||||||
|
confirm = _get_questionary().confirm(
|
||||||
|
f"Delete preset '{preset_name}'?",
|
||||||
|
default=False,
|
||||||
|
).ask()
|
||||||
|
if confirm:
|
||||||
|
del config.model_presets[preset_name]
|
||||||
|
_sync_preset_cache(config)
|
||||||
|
last_preset_name = None
|
||||||
|
continue
|
||||||
|
|
||||||
|
if action == "Edit":
|
||||||
|
updated = _configure_pydantic_model(preset, f"Edit Preset: {preset_name}")
|
||||||
|
if updated is not None:
|
||||||
|
config.model_presets[preset_name] = updated
|
||||||
|
_sync_preset_cache(config)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
console.print("\n[dim]Returning to main menu...[/dim]")
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
# --- Provider Configuration ---
|
# --- Provider Configuration ---
|
||||||
|
|
||||||
|
|
||||||
@ -1043,6 +1251,12 @@ def _show_summary(config: Config) -> None:
|
|||||||
channel_rows.append((display, status))
|
channel_rows.append((display, status))
|
||||||
_print_summary_panel(channel_rows, "Chat Channels")
|
_print_summary_panel(channel_rows, "Chat Channels")
|
||||||
|
|
||||||
|
# Model Presets
|
||||||
|
preset_rows = []
|
||||||
|
for name, preset in config.model_presets.items():
|
||||||
|
preset_rows.append((name, f"{preset.model} (ctx={preset.context_window_tokens})"))
|
||||||
|
_print_summary_panel(preset_rows, "Model Presets")
|
||||||
|
|
||||||
# Settings sections
|
# Settings sections
|
||||||
for title, model in [
|
for title, model in [
|
||||||
("Agent Settings", config.agents.defaults),
|
("Agent Settings", config.agents.defaults),
|
||||||
@ -1112,6 +1326,7 @@ def run_onboard(initial_config: Config | None = None) -> OnboardResult:
|
|||||||
|
|
||||||
original_config = base_config.model_copy(deep=True)
|
original_config = base_config.model_copy(deep=True)
|
||||||
config = base_config.model_copy(deep=True)
|
config = base_config.model_copy(deep=True)
|
||||||
|
_sync_preset_cache(config)
|
||||||
|
|
||||||
last_main_choice: str | None = None
|
last_main_choice: str | None = None
|
||||||
while True:
|
while True:
|
||||||
@ -1123,6 +1338,7 @@ def run_onboard(initial_config: Config | None = None) -> OnboardResult:
|
|||||||
"What would you like to configure?",
|
"What would you like to configure?",
|
||||||
choices=[
|
choices=[
|
||||||
"[P] LLM Provider",
|
"[P] LLM Provider",
|
||||||
|
"[M] Model Presets",
|
||||||
"[C] Chat Channel",
|
"[C] Chat Channel",
|
||||||
"[H] Channel Common",
|
"[H] Channel Common",
|
||||||
"[A] Agent Settings",
|
"[A] Agent Settings",
|
||||||
@ -1149,6 +1365,7 @@ def run_onboard(initial_config: Config | None = None) -> OnboardResult:
|
|||||||
|
|
||||||
_menu_dispatch = {
|
_menu_dispatch = {
|
||||||
"[P] LLM Provider": lambda: _configure_providers(config),
|
"[P] LLM Provider": lambda: _configure_providers(config),
|
||||||
|
"[M] Model Presets": lambda: _configure_model_presets(config),
|
||||||
"[C] Chat Channel": lambda: _configure_channels(config),
|
"[C] Chat Channel": lambda: _configure_channels(config),
|
||||||
"[H] Channel Common": lambda: _configure_general_settings(config, "Channel Common"),
|
"[H] Channel Common": lambda: _configure_general_settings(config, "Channel Common"),
|
||||||
"[A] Agent Settings": lambda: _configure_general_settings(config, "Agent Settings"),
|
"[A] Agent Settings": lambda: _configure_general_settings(config, "Agent Settings"),
|
||||||
|
|||||||
@ -104,8 +104,9 @@ class AgentDefaults(Base):
|
|||||||
validation_alias=AliasChoices("toolHintMaxLength"),
|
validation_alias=AliasChoices("toolHintMaxLength"),
|
||||||
serialization_alias="toolHintMaxLength",
|
serialization_alias="toolHintMaxLength",
|
||||||
) # Max characters for tool hint display (e.g. "$ cd …/project && npm test")
|
) # Max characters for tool hint display (e.g. "$ cd …/project && npm test")
|
||||||
reasoning_effort: str | None = None # low / medium / high / adaptive - enables LLM thinking mode
|
fallback_presets: list[str] = Field(
|
||||||
fallback_models: list[str] = Field(default_factory=list)
|
default_factory=list
|
||||||
|
) # Ordered fallback chain. Each item must be a preset name defined in model_presets.
|
||||||
timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"
|
timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"
|
||||||
unified_session: bool = False # Share one session across all channels (single-user multi-device)
|
unified_session: bool = False # Share one session across all channels (single-user multi-device)
|
||||||
disabled_skills: list[str] = Field(default_factory=list) # Skill names to exclude from loading (e.g. ["summarize", "skill-creator"])
|
disabled_skills: list[str] = Field(default_factory=list) # Skill names to exclude from loading (e.g. ["summarize", "skill-creator"])
|
||||||
@ -306,24 +307,52 @@ class Config(BaseSettings):
|
|||||||
model_presets: dict[str, ModelPresetConfig] = Field(default_factory=dict)
|
model_presets: dict[str, ModelPresetConfig] = Field(default_factory=dict)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def _validate_model_preset(self) -> "Config":
|
def _sync_and_validate_preset(self) -> "Config":
|
||||||
name = self.agents.defaults.model_preset
|
"""Expose agents.defaults model fields as the implicit 'default' preset
|
||||||
if name and name not in self.model_presets:
|
and validate the active preset reference.
|
||||||
raise ValueError(f"model_preset {name!r} not found in model_presets")
|
|
||||||
|
This guarantees that ``model_presets`` is never empty and that legacy
|
||||||
|
configs (which only set ``agents.defaults.model`` etc.) continue to work
|
||||||
|
without explicitly declaring a preset.
|
||||||
|
"""
|
||||||
|
self._refresh_default_preset()
|
||||||
|
defaults = self.agents.defaults
|
||||||
|
if defaults.model_preset is None:
|
||||||
|
defaults.model_preset = "default"
|
||||||
|
if defaults.model_preset not in self.model_presets:
|
||||||
|
raise ValueError(f"model_preset {defaults.model_preset!r} not found in model_presets")
|
||||||
|
for fb in defaults.fallback_presets:
|
||||||
|
if fb not in self.model_presets:
|
||||||
|
raise ValueError(f"fallback_presets entry {fb!r} not found in model_presets")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def resolve_preset(self) -> ModelPresetConfig:
|
def _refresh_default_preset(self) -> None:
|
||||||
"""Return effective model params: from active preset, or individual defaults."""
|
"""Rebuild the implicit 'default' preset from current agents.defaults.
|
||||||
name = self.agents.defaults.model_preset
|
|
||||||
if name:
|
Called inside ``_sync_and_validate_preset`` (model validator) and
|
||||||
return self.model_presets[name]
|
``resolve_preset()`` so that runtime mutations (e.g. tests directly
|
||||||
|
setting ``defaults.model``) are reflected.
|
||||||
|
"""
|
||||||
d = self.agents.defaults
|
d = self.agents.defaults
|
||||||
return ModelPresetConfig(
|
self.model_presets["default"] = ModelPresetConfig(
|
||||||
model=d.model, provider=d.provider, max_tokens=d.max_tokens,
|
model=d.model,
|
||||||
|
provider=d.provider,
|
||||||
|
max_tokens=d.max_tokens,
|
||||||
context_window_tokens=d.context_window_tokens,
|
context_window_tokens=d.context_window_tokens,
|
||||||
temperature=d.temperature, reasoning_effort=d.reasoning_effort,
|
temperature=d.temperature,
|
||||||
|
reasoning_effort=d.reasoning_effort,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def resolve_preset(self) -> ModelPresetConfig:
|
||||||
|
"""Return the active preset.
|
||||||
|
|
||||||
|
The implicit ``"default"`` preset is rebuilt from current defaults every
|
||||||
|
time so that runtime mutations (e.g. tests setting ``defaults.model``)
|
||||||
|
are always reflected.
|
||||||
|
"""
|
||||||
|
self._refresh_default_preset()
|
||||||
|
return self.model_presets[self.agents.defaults.model_preset]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def workspace_path(self) -> Path:
|
def workspace_path(self) -> Path:
|
||||||
"""Get expanded workspace path."""
|
"""Get expanded workspace path."""
|
||||||
@ -335,15 +364,16 @@ class Config(BaseSettings):
|
|||||||
"""Match provider config and its registry name. Returns (config, spec_name)."""
|
"""Match provider config and its registry name. Returns (config, spec_name)."""
|
||||||
from nanobot.providers.registry import PROVIDERS, find_by_name
|
from nanobot.providers.registry import PROVIDERS, find_by_name
|
||||||
|
|
||||||
forced = self.resolve_preset().provider
|
resolved = self.resolve_preset()
|
||||||
|
forced = resolved.provider
|
||||||
if forced != "auto":
|
if forced != "auto":
|
||||||
spec = find_by_name(forced)
|
spec = find_by_name(forced)
|
||||||
if spec:
|
if spec:
|
||||||
p = getattr(self.providers, spec.name, None)
|
provider_cfg = getattr(self.providers, spec.name, None)
|
||||||
return (p, spec.name) if p else (None, None)
|
return (provider_cfg, spec.name) if provider_cfg else (None, None)
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
model_lower = (model or self.resolve_preset().model).lower()
|
model_lower = (model or resolved.model).lower()
|
||||||
model_normalized = model_lower.replace("-", "_")
|
model_normalized = model_lower.replace("-", "_")
|
||||||
model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
|
model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
|
||||||
normalized_prefix = model_prefix.replace("-", "_")
|
normalized_prefix = model_prefix.replace("-", "_")
|
||||||
|
|||||||
@ -8,7 +8,6 @@ from typing import Any
|
|||||||
|
|
||||||
from nanobot.agent.hook import AgentHook, SDKCaptureHook
|
from nanobot.agent.hook import AgentHook, SDKCaptureHook
|
||||||
from nanobot.agent.loop import AgentLoop
|
from nanobot.agent.loop import AgentLoop
|
||||||
from nanobot.bus.queue import MessageBus
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
@ -62,41 +61,12 @@ class Nanobot:
|
|||||||
Path(workspace).expanduser().resolve()
|
Path(workspace).expanduser().resolve()
|
||||||
)
|
)
|
||||||
|
|
||||||
provider = _make_provider(config)
|
loop = AgentLoop.from_config(
|
||||||
bus = MessageBus()
|
config,
|
||||||
defaults = config.agents.defaults
|
|
||||||
_resolved = config.resolve_preset()
|
|
||||||
pf = _make_provider_factory(config) if defaults.fallback_models else None
|
|
||||||
|
|
||||||
loop = AgentLoop(
|
|
||||||
bus=bus,
|
|
||||||
provider=provider,
|
|
||||||
workspace=config.workspace_path,
|
|
||||||
model=_resolved.model,
|
|
||||||
max_iterations=defaults.max_tool_iterations,
|
|
||||||
context_window_tokens=_resolved.context_window_tokens,
|
|
||||||
context_block_limit=defaults.context_block_limit,
|
|
||||||
max_tool_result_chars=defaults.max_tool_result_chars,
|
|
||||||
provider_retry_mode=defaults.provider_retry_mode,
|
|
||||||
tool_hint_max_length=defaults.tool_hint_max_length,
|
|
||||||
fallback_models=defaults.fallback_models,
|
|
||||||
provider_factory=pf,
|
|
||||||
web_config=config.tools.web,
|
|
||||||
exec_config=config.tools.exec,
|
|
||||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
|
||||||
mcp_servers=config.tools.mcp_servers,
|
|
||||||
timezone=defaults.timezone,
|
|
||||||
unified_session=defaults.unified_session,
|
|
||||||
disabled_skills=defaults.disabled_skills,
|
|
||||||
session_ttl_minutes=defaults.session_ttl_minutes,
|
|
||||||
consolidation_ratio=defaults.consolidation_ratio,
|
|
||||||
tools_config=config.tools,
|
|
||||||
image_generation_provider_configs={
|
image_generation_provider_configs={
|
||||||
"openrouter": config.providers.openrouter,
|
"openrouter": config.providers.openrouter,
|
||||||
"aihubmix": config.providers.aihubmix,
|
"aihubmix": config.providers.aihubmix,
|
||||||
},
|
},
|
||||||
model_presets=config.model_presets,
|
|
||||||
model_preset=defaults.model_preset,
|
|
||||||
)
|
)
|
||||||
return cls(loop)
|
return cls(loop)
|
||||||
|
|
||||||
@ -134,99 +104,3 @@ class Nanobot:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _make_provider_for_model(
|
|
||||||
config: Any,
|
|
||||||
model: str,
|
|
||||||
*,
|
|
||||||
preset: Any | None = None,
|
|
||||||
) -> Any:
|
|
||||||
"""Create an LLM provider instance for a specific model string.
|
|
||||||
|
|
||||||
When *preset* is given, its generation settings (temperature, max_tokens,
|
|
||||||
reasoning_effort) override the active preset defaults.
|
|
||||||
"""
|
|
||||||
from nanobot.providers.base import GenerationSettings
|
|
||||||
from nanobot.providers.factory import make_provider
|
|
||||||
from nanobot.providers.registry import find_by_name
|
|
||||||
|
|
||||||
gen_src = preset or config.resolve_preset()
|
|
||||||
provider_name = config.get_provider_name(model)
|
|
||||||
p = config.get_provider(model)
|
|
||||||
spec = find_by_name(provider_name) if provider_name else None
|
|
||||||
backend = spec.backend if spec else "openai_compat"
|
|
||||||
|
|
||||||
if backend == "azure_openai":
|
|
||||||
if not p or not p.api_key or not p.api_base:
|
|
||||||
raise ValueError("Azure OpenAI requires api_key and api_base in config.")
|
|
||||||
elif backend == "openai_compat" and not model.startswith("bedrock/"):
|
|
||||||
needs_key = not (p and p.api_key)
|
|
||||||
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
|
|
||||||
if needs_key and not exempt:
|
|
||||||
raise ValueError(f"No API key configured for provider '{provider_name}'.")
|
|
||||||
|
|
||||||
if backend == "openai_codex":
|
|
||||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
|
||||||
|
|
||||||
provider = OpenAICodexProvider(default_model=model)
|
|
||||||
elif backend == "github_copilot":
|
|
||||||
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
|
|
||||||
|
|
||||||
provider = GitHubCopilotProvider(default_model=model)
|
|
||||||
elif backend == "azure_openai":
|
|
||||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
|
||||||
|
|
||||||
provider = AzureOpenAIProvider(
|
|
||||||
api_key=p.api_key, api_base=p.api_base, default_model=model
|
|
||||||
)
|
|
||||||
elif backend == "anthropic":
|
|
||||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
|
||||||
|
|
||||||
provider = AnthropicProvider(
|
|
||||||
api_key=p.api_key if p else None,
|
|
||||||
api_base=config.get_api_base(model),
|
|
||||||
default_model=model,
|
|
||||||
extra_headers=p.extra_headers if p else None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
|
||||||
|
|
||||||
provider = OpenAICompatProvider(
|
|
||||||
api_key=p.api_key if p else None,
|
|
||||||
api_base=config.get_api_base(model),
|
|
||||||
default_model=model,
|
|
||||||
extra_headers=p.extra_headers if p else None,
|
|
||||||
spec=spec,
|
|
||||||
extra_body=p.extra_body if p else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
provider.generation = GenerationSettings(
|
|
||||||
temperature=gen_src.temperature,
|
|
||||||
max_tokens=gen_src.max_tokens,
|
|
||||||
reasoning_effort=gen_src.reasoning_effort,
|
|
||||||
)
|
|
||||||
return provider
|
|
||||||
|
|
||||||
|
|
||||||
def _make_provider(config: Any) -> Any:
|
|
||||||
"""Create the LLM provider for the primary model from config."""
|
|
||||||
return _make_provider_for_model(config, config.resolve_preset().model)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_provider_factory(config: Any):
|
|
||||||
"""Build a cached factory that creates providers for arbitrary model strings.
|
|
||||||
|
|
||||||
If a model string matches a preset name in ``config.model_presets``, the
|
|
||||||
preset's full config (model, temperature, max_tokens, …) is used.
|
|
||||||
"""
|
|
||||||
cache: dict[str, Any] = {}
|
|
||||||
presets = getattr(config, "model_presets", {}) or {}
|
|
||||||
|
|
||||||
def factory(model_or_preset: str):
|
|
||||||
preset = presets.get(model_or_preset)
|
|
||||||
actual_model = preset.model if preset else model_or_preset
|
|
||||||
key = actual_model
|
|
||||||
if key not in cache:
|
|
||||||
cache[key] = _make_provider_for_model(config, actual_model, preset=preset)
|
|
||||||
return cache[key]
|
|
||||||
|
|
||||||
return factory
|
|
||||||
|
|||||||
@ -137,7 +137,9 @@ class LLMProvider(ABC):
|
|||||||
"insufficient_quota",
|
"insufficient_quota",
|
||||||
"insufficient quota",
|
"insufficient quota",
|
||||||
"quota exceeded",
|
"quota exceeded",
|
||||||
|
"quota_exceeded",
|
||||||
"quota exhausted",
|
"quota exhausted",
|
||||||
|
"quota_exhausted",
|
||||||
"billing hard limit",
|
"billing hard limit",
|
||||||
"billing_hard_limit_reached",
|
"billing_hard_limit_reached",
|
||||||
"billing not active",
|
"billing not active",
|
||||||
|
|||||||
@ -4,11 +4,16 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from nanobot.config.schema import Config
|
from nanobot.config.schema import Config
|
||||||
from nanobot.providers.base import GenerationSettings, LLMProvider
|
from nanobot.providers.base import GenerationSettings, LLMProvider
|
||||||
from nanobot.providers.registry import find_by_name
|
from nanobot.providers.registry import find_by_name
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from nanobot.config.schema import ModelPresetConfig, ProviderConfig
|
||||||
|
from nanobot.providers.registry import ProviderSpec
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ProviderSnapshot:
|
class ProviderSnapshot:
|
||||||
@ -18,22 +23,62 @@ class ProviderSnapshot:
|
|||||||
signature: tuple[object, ...]
|
signature: tuple[object, ...]
|
||||||
|
|
||||||
|
|
||||||
def make_provider(config: Config) -> LLMProvider:
|
@dataclass(frozen=True)
|
||||||
"""Create the LLM provider implied by config."""
|
class _ProviderInfo:
|
||||||
model = config.agents.defaults.model
|
"""Resolved metadata needed to build and validate an LLM provider."""
|
||||||
provider_name = config.get_provider_name(model)
|
|
||||||
p = config.get_provider(model)
|
name: str | None
|
||||||
spec = find_by_name(provider_name) if provider_name else None
|
cfg: ProviderConfig | None
|
||||||
|
spec: ProviderSpec | None
|
||||||
|
api_base: str | None
|
||||||
|
backend: str
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_provider_info(
|
||||||
|
config: Config,
|
||||||
|
model: str,
|
||||||
|
preset: ModelPresetConfig,
|
||||||
|
) -> _ProviderInfo:
|
||||||
|
"""Derive provider name, config, spec and api_base from preset or auto-detection."""
|
||||||
|
if preset.provider != "auto":
|
||||||
|
name = preset.provider
|
||||||
|
cfg = getattr(config.providers, name, None)
|
||||||
|
spec = find_by_name(name)
|
||||||
|
api_base = (
|
||||||
|
cfg.api_base
|
||||||
|
if cfg and cfg.api_base
|
||||||
|
else (spec.default_api_base if spec and spec.default_api_base else None)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
name = config.get_provider_name(model)
|
||||||
|
cfg = config.get_provider(model)
|
||||||
|
spec = find_by_name(name) if name else None
|
||||||
|
api_base = config.get_api_base(model)
|
||||||
|
|
||||||
backend = spec.backend if spec else "openai_compat"
|
backend = spec.backend if spec else "openai_compat"
|
||||||
|
return _ProviderInfo(name=name, cfg=cfg, spec=spec, api_base=api_base, backend=backend)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_provider(info: _ProviderInfo, model: str) -> None:
|
||||||
|
"""Ensure credentials / endpoints are present before instantiation."""
|
||||||
|
cfg = info.cfg
|
||||||
|
backend = info.backend
|
||||||
|
name = info.name
|
||||||
|
|
||||||
if backend == "azure_openai":
|
if backend == "azure_openai":
|
||||||
if not p or not p.api_key or not p.api_base:
|
if not cfg or not cfg.api_key or not cfg.api_base:
|
||||||
raise ValueError("Azure OpenAI requires api_key and api_base in config.")
|
raise ValueError("Azure OpenAI requires api_key and api_base in config.")
|
||||||
elif backend == "openai_compat" and not model.startswith("bedrock/"):
|
elif backend == "openai_compat" and not model.startswith("bedrock/"):
|
||||||
needs_key = not (p and p.api_key)
|
needs_key = not (cfg and cfg.api_key)
|
||||||
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
|
exempt = info.spec and (info.spec.is_oauth or info.spec.is_local or info.spec.is_direct)
|
||||||
if needs_key and not exempt:
|
if needs_key and not exempt:
|
||||||
raise ValueError(f"No API key configured for provider '{provider_name}'.")
|
raise ValueError(f"No API key configured for provider '{name}'.")
|
||||||
|
|
||||||
|
|
||||||
|
def _create_provider(model: str, info: _ProviderInfo) -> LLMProvider:
|
||||||
|
"""Instantiate the concrete provider class for *backend*."""
|
||||||
|
cfg = info.cfg
|
||||||
|
backend = info.backend
|
||||||
|
|
||||||
if backend == "openai_codex":
|
if backend == "openai_codex":
|
||||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||||
@ -43,8 +88,8 @@ def make_provider(config: Config) -> LLMProvider:
|
|||||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||||
|
|
||||||
provider = AzureOpenAIProvider(
|
provider = AzureOpenAIProvider(
|
||||||
api_key=p.api_key,
|
api_key=cfg.api_key if cfg else None,
|
||||||
api_base=p.api_base,
|
api_base=info.api_base,
|
||||||
default_model=model,
|
default_model=model,
|
||||||
)
|
)
|
||||||
elif backend == "github_copilot":
|
elif backend == "github_copilot":
|
||||||
@ -55,70 +100,103 @@ def make_provider(config: Config) -> LLMProvider:
|
|||||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||||
|
|
||||||
provider = AnthropicProvider(
|
provider = AnthropicProvider(
|
||||||
api_key=p.api_key if p else None,
|
api_key=cfg.api_key if cfg else None,
|
||||||
api_base=config.get_api_base(model),
|
api_base=info.api_base,
|
||||||
default_model=model,
|
default_model=model,
|
||||||
extra_headers=p.extra_headers if p else None,
|
extra_headers=cfg.extra_headers if cfg else None,
|
||||||
)
|
)
|
||||||
elif backend == "bedrock":
|
elif backend == "bedrock":
|
||||||
from nanobot.providers.bedrock_provider import BedrockProvider
|
from nanobot.providers.bedrock_provider import BedrockProvider
|
||||||
|
|
||||||
provider = BedrockProvider(
|
provider = BedrockProvider(
|
||||||
api_key=p.api_key if p else None,
|
api_key=cfg.api_key if cfg else None,
|
||||||
api_base=p.api_base if p else None,
|
api_base=info.api_base if cfg else None,
|
||||||
default_model=model,
|
default_model=model,
|
||||||
region=getattr(p, "region", None) if p else None,
|
region=getattr(cfg, "region", None) if cfg else None,
|
||||||
profile=getattr(p, "profile", None) if p else None,
|
profile=getattr(cfg, "profile", None) if cfg else None,
|
||||||
extra_body=p.extra_body if p else None,
|
extra_body=cfg.extra_body if cfg else None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||||
|
|
||||||
provider = OpenAICompatProvider(
|
provider = OpenAICompatProvider(
|
||||||
api_key=p.api_key if p else None,
|
api_key=cfg.api_key if cfg else None,
|
||||||
api_base=config.get_api_base(model),
|
api_base=info.api_base,
|
||||||
default_model=model,
|
default_model=model,
|
||||||
extra_headers=p.extra_headers if p else None,
|
extra_headers=cfg.extra_headers if cfg else None,
|
||||||
spec=spec,
|
spec=info.spec,
|
||||||
extra_body=p.extra_body if p else None,
|
extra_body=cfg.extra_body if cfg else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
defaults = config.agents.defaults
|
|
||||||
provider.generation = GenerationSettings(
|
|
||||||
temperature=defaults.temperature,
|
|
||||||
max_tokens=defaults.max_tokens,
|
|
||||||
reasoning_effort=defaults.reasoning_effort,
|
|
||||||
)
|
|
||||||
return provider
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_generation(provider: LLMProvider, preset: ModelPresetConfig) -> None:
|
||||||
|
provider.generation = GenerationSettings(
|
||||||
|
temperature=preset.temperature,
|
||||||
|
max_tokens=preset.max_tokens,
|
||||||
|
reasoning_effort=preset.reasoning_effort,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_provider_for_preset(config: Config, preset: ModelPresetConfig) -> LLMProvider:
|
||||||
|
"""Create an LLM provider from a full *preset* (model + provider + generation)."""
|
||||||
|
info = _resolve_provider_info(config, preset.model, preset)
|
||||||
|
_validate_provider(info, preset.model)
|
||||||
|
provider = _create_provider(preset.model, info)
|
||||||
|
_apply_generation(provider, preset)
|
||||||
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
def make_provider(config: Config) -> LLMProvider:
|
||||||
|
"""Create the LLM provider implied by config (legacy entrypoint)."""
|
||||||
|
resolved = config.resolve_preset()
|
||||||
|
return build_provider_for_preset(config, resolved)
|
||||||
|
|
||||||
|
|
||||||
|
def make_provider_factory(config: Config):
|
||||||
|
"""Build a cached factory that creates providers for preset names.
|
||||||
|
|
||||||
|
The factory looks up *preset_name* in ``config.model_presets`` and builds
|
||||||
|
the provider from the preset's full configuration.
|
||||||
|
"""
|
||||||
|
cache: dict[str, LLMProvider] = {}
|
||||||
|
presets = config.model_presets
|
||||||
|
|
||||||
|
def factory(preset_name: str) -> LLMProvider:
|
||||||
|
preset = presets.get(preset_name)
|
||||||
|
if preset is None:
|
||||||
|
raise ValueError(f"Preset {preset_name!r} not found in model_presets")
|
||||||
|
if preset_name not in cache:
|
||||||
|
cache[preset_name] = build_provider_for_preset(config, preset)
|
||||||
|
return cache[preset_name]
|
||||||
|
|
||||||
|
return factory
|
||||||
|
|
||||||
|
|
||||||
def provider_signature(config: Config) -> tuple[object, ...]:
|
def provider_signature(config: Config) -> tuple[object, ...]:
|
||||||
"""Return the config fields that affect the primary LLM provider."""
|
"""Return the config fields that affect the primary LLM provider."""
|
||||||
model = config.agents.defaults.model
|
resolved = config.resolve_preset()
|
||||||
defaults = config.agents.defaults
|
defaults = config.agents.defaults
|
||||||
p = config.get_provider(model)
|
|
||||||
return (
|
return (
|
||||||
model,
|
resolved.model,
|
||||||
defaults.provider,
|
resolved.provider,
|
||||||
config.get_provider_name(model),
|
config.get_provider_name(resolved.model),
|
||||||
config.get_api_key(model),
|
config.get_api_key(resolved.model),
|
||||||
config.get_api_base(model),
|
config.get_api_base(resolved.model),
|
||||||
p.extra_headers if p else None,
|
resolved.max_tokens,
|
||||||
p.extra_body if p else None,
|
resolved.temperature,
|
||||||
getattr(p, "region", None) if p else None,
|
resolved.reasoning_effort,
|
||||||
getattr(p, "profile", None) if p else None,
|
resolved.context_window_tokens,
|
||||||
defaults.max_tokens,
|
tuple(defaults.fallback_presets),
|
||||||
defaults.temperature,
|
|
||||||
defaults.reasoning_effort,
|
|
||||||
defaults.context_window_tokens,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_provider_snapshot(config: Config) -> ProviderSnapshot:
|
def build_provider_snapshot(config: Config) -> ProviderSnapshot:
|
||||||
|
resolved = config.resolve_preset()
|
||||||
return ProviderSnapshot(
|
return ProviderSnapshot(
|
||||||
provider=make_provider(config),
|
provider=make_provider(config),
|
||||||
model=config.agents.defaults.model,
|
model=resolved.model,
|
||||||
context_window_tokens=config.agents.defaults.context_window_tokens,
|
context_window_tokens=resolved.context_window_tokens,
|
||||||
signature=provider_signature(config),
|
signature=provider_signature(config),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@ -12,68 +11,16 @@ from loguru import logger
|
|||||||
from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse
|
from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class ModelCandidate:
|
|
||||||
"""A lazily resolved model/provider candidate."""
|
|
||||||
|
|
||||||
label: str
|
|
||||||
resolver: Callable[[], tuple[LLMProvider, str]]
|
|
||||||
|
|
||||||
|
|
||||||
class ModelRouter(LLMProvider):
|
class ModelRouter(LLMProvider):
|
||||||
"""Try fallback model candidates for eligible transient final errors."""
|
"""Try fallback model candidates for eligible transient final errors."""
|
||||||
|
|
||||||
supports_progress_deltas = False
|
|
||||||
|
|
||||||
_BLOCKED_STATUS_CODES = frozenset({400, 401, 403, 404, 422})
|
|
||||||
_QUOTA_MARKERS = (
|
|
||||||
"insufficient_quota",
|
|
||||||
"insufficient quota",
|
|
||||||
"quota exceeded",
|
|
||||||
"quota_exceeded",
|
|
||||||
"quota exhausted",
|
|
||||||
"quota_exhausted",
|
|
||||||
"billing hard limit",
|
|
||||||
"billing_hard_limit_reached",
|
|
||||||
"billing not active",
|
|
||||||
"insufficient balance",
|
|
||||||
"insufficient_balance",
|
|
||||||
"credit balance too low",
|
|
||||||
"payment required",
|
|
||||||
"out of credits",
|
|
||||||
"out of quota",
|
|
||||||
"exceeded your current quota",
|
|
||||||
)
|
|
||||||
_NON_FAILOVER_MARKERS = (
|
|
||||||
"context length",
|
|
||||||
"context_length",
|
|
||||||
"maximum context",
|
|
||||||
"max context",
|
|
||||||
"token budget",
|
|
||||||
"too many tokens",
|
|
||||||
"schema",
|
|
||||||
"invalid request",
|
|
||||||
"invalid_request",
|
|
||||||
"invalid parameter",
|
|
||||||
"invalid_parameter",
|
|
||||||
"unsupported",
|
|
||||||
"unauthorized",
|
|
||||||
"authentication",
|
|
||||||
"permission",
|
|
||||||
"forbidden",
|
|
||||||
"refusal",
|
|
||||||
"content policy",
|
|
||||||
"content_filter",
|
|
||||||
"policy violation",
|
|
||||||
"safety",
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
primary_provider: LLMProvider,
|
primary_provider: LLMProvider,
|
||||||
primary_model: str,
|
primary_model: str,
|
||||||
fallback_candidates: list[ModelCandidate],
|
fallback_presets: list[str],
|
||||||
|
provider_factory: Callable[[str], LLMProvider] | None = None,
|
||||||
per_candidate_timeout_s: float | None = None,
|
per_candidate_timeout_s: float | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@ -82,7 +29,9 @@ class ModelRouter(LLMProvider):
|
|||||||
)
|
)
|
||||||
self.primary_provider = primary_provider
|
self.primary_provider = primary_provider
|
||||||
self.primary_model = primary_model
|
self.primary_model = primary_model
|
||||||
self.fallback_candidates = list(fallback_candidates)
|
self.fallback_presets = list(fallback_presets)
|
||||||
|
self._provider_factory = provider_factory
|
||||||
|
self._provider_cache: dict[str, LLMProvider] = {}
|
||||||
self.per_candidate_timeout_s = per_candidate_timeout_s
|
self.per_candidate_timeout_s = per_candidate_timeout_s
|
||||||
self.generation = getattr(primary_provider, "generation", GenerationSettings())
|
self.generation = getattr(primary_provider, "generation", GenerationSettings())
|
||||||
|
|
||||||
@ -90,41 +39,46 @@ class ModelRouter(LLMProvider):
|
|||||||
return self.primary_model
|
return self.primary_model
|
||||||
|
|
||||||
async def chat(self, **kwargs: Any) -> LLMResponse:
|
async def chat(self, **kwargs: Any) -> LLMResponse:
|
||||||
return await self.primary_provider.chat(**kwargs)
|
async def call(provider: LLMProvider, candidate_model: str, _unused_delta: Any) -> LLMResponse:
|
||||||
|
return await provider.chat(**{**kwargs, "model": candidate_model})
|
||||||
|
return await self._route(call)
|
||||||
|
|
||||||
async def chat_stream(self, **kwargs: Any) -> LLMResponse:
|
async def chat_stream(self, **kwargs: Any) -> LLMResponse:
|
||||||
return await self.primary_provider.chat_stream(**kwargs)
|
async def call(provider: LLMProvider, candidate_model: str, content_delta: Any) -> LLMResponse:
|
||||||
|
return await provider.chat_stream(
|
||||||
|
**{**kwargs, "model": candidate_model, "on_content_delta": content_delta}
|
||||||
|
)
|
||||||
|
return await self._route(call, on_content_delta=kwargs.get("on_content_delta"))
|
||||||
|
|
||||||
@classmethod
|
@property
|
||||||
def _is_quota_error(cls, response: LLMResponse) -> bool:
|
def supports_progress_deltas(self) -> bool: # type: ignore[override]
|
||||||
tokens = {
|
return getattr(self.primary_provider, "supports_progress_deltas", False)
|
||||||
cls._normalize_error_token(response.error_type),
|
|
||||||
cls._normalize_error_token(response.error_code),
|
|
||||||
}
|
|
||||||
if any(token in cls._NON_RETRYABLE_429_ERROR_TOKENS for token in tokens if token):
|
|
||||||
return True
|
|
||||||
content = (response.content or "").lower()
|
|
||||||
return any(marker in content for marker in cls._QUOTA_MARKERS)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _is_blocked_error(cls, response: LLMResponse) -> bool:
|
|
||||||
status = response.error_status_code
|
|
||||||
if status is not None and int(status) in cls._BLOCKED_STATUS_CODES:
|
|
||||||
return True
|
|
||||||
if response.finish_reason in {"refusal", "content_filter"}:
|
|
||||||
return True
|
|
||||||
content = (response.content or "").lower()
|
|
||||||
return any(marker in content for marker in cls._NON_FAILOVER_MARKERS)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _should_failover(cls, response: LLMResponse) -> bool:
|
def _should_failover(cls, response: LLMResponse) -> bool:
|
||||||
if response.finish_reason != "error":
|
if response.finish_reason != "error":
|
||||||
return False
|
return False
|
||||||
if cls._is_blocked_error(response):
|
if response.error_should_retry is False:
|
||||||
return False
|
return False
|
||||||
if cls._is_quota_error(response):
|
if response.error_kind == "configuration":
|
||||||
return False
|
return False
|
||||||
return cls._is_transient_response(response)
|
return True
|
||||||
|
|
||||||
|
def _resolve(self, model: str) -> tuple[LLMProvider, str]:
|
||||||
|
"""Return (provider, actual_model_name) for a preset name.
|
||||||
|
|
||||||
|
Caches results so factory is only invoked once per unique name.
|
||||||
|
"""
|
||||||
|
if model in self._provider_cache:
|
||||||
|
cached_provider = self._provider_cache[model]
|
||||||
|
return cached_provider, cached_provider.get_default_model()
|
||||||
|
if self._provider_factory is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot resolve fallback model {model!r}: no provider_factory configured"
|
||||||
|
)
|
||||||
|
provider = self._provider_factory(model)
|
||||||
|
self._provider_cache[model] = provider
|
||||||
|
return provider, provider.get_default_model()
|
||||||
|
|
||||||
async def _with_timeout(self, coro: Awaitable[LLMResponse]) -> LLMResponse:
|
async def _with_timeout(self, coro: Awaitable[LLMResponse]) -> LLMResponse:
|
||||||
timeout_s = self.per_candidate_timeout_s
|
timeout_s = self.per_candidate_timeout_s
|
||||||
@ -140,137 +94,90 @@ class ModelRouter(LLMProvider):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _resolver_error(candidate: ModelCandidate, exc: Exception) -> LLMResponse:
|
def _resolver_error(label: str, exc: Exception) -> LLMResponse:
|
||||||
logger.warning("Failed to resolve fallback model {}: {}", candidate.label, exc)
|
logger.warning("Failed to resolve fallback model {}: {}", label, exc)
|
||||||
return LLMResponse(
|
return LLMResponse(
|
||||||
content=f"Error configuring fallback model {candidate.label}: {exc}",
|
content=f"Error configuring fallback model {label}: {exc}",
|
||||||
finish_reason="error",
|
finish_reason="error",
|
||||||
error_kind="configuration",
|
error_kind="configuration",
|
||||||
error_should_retry=False,
|
error_should_retry=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _candidate_chain(self) -> list[ModelCandidate]:
|
|
||||||
return [
|
|
||||||
ModelCandidate(
|
|
||||||
label=self.primary_model,
|
|
||||||
resolver=lambda: (self.primary_provider, self.primary_model),
|
|
||||||
),
|
|
||||||
*self.fallback_candidates,
|
|
||||||
]
|
|
||||||
|
|
||||||
async def _route(
|
async def _route(
|
||||||
self,
|
self,
|
||||||
call: Callable[[LLMProvider, str, Callable[[str], Awaitable[None]] | None], Awaitable[LLMResponse]],
|
call: Callable[[LLMProvider, str, Callable[[str], Awaitable[None]] | None], Awaitable[LLMResponse]],
|
||||||
*,
|
*,
|
||||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
last_response: LLMResponse | None = None
|
"""Try primary then each fallback candidate, lazily resolving providers."""
|
||||||
chain = self._candidate_chain()
|
|
||||||
for index, candidate in enumerate(chain):
|
async def _try_one(label: str, provider: LLMProvider, model: str) -> LLMResponse:
|
||||||
try:
|
try:
|
||||||
provider, model = candidate.resolver()
|
return await self._with_timeout(call(provider, model, on_content_delta))
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
raise
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
response = self._resolver_error(candidate, exc)
|
return self._resolver_error(label, exc)
|
||||||
else:
|
|
||||||
response = await self._with_timeout(call(provider, model, on_content_delta))
|
|
||||||
|
|
||||||
|
# Primary
|
||||||
|
response = await _try_one("primary", self.primary_provider, self.primary_model)
|
||||||
|
if response.finish_reason != "error":
|
||||||
|
return response
|
||||||
|
if not self._should_failover(response):
|
||||||
|
return response
|
||||||
|
|
||||||
|
# Fallbacks
|
||||||
|
for name in self.fallback_presets:
|
||||||
|
try:
|
||||||
|
provider, model = self._resolve(name)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Failed to resolve fallback model {}: {}", name, exc)
|
||||||
|
return self._resolver_error(name, exc)
|
||||||
|
|
||||||
|
response = await _try_one(name, provider, model)
|
||||||
if response.finish_reason != "error":
|
if response.finish_reason != "error":
|
||||||
if index > 0:
|
logger.info("LLM failover selected model={}", name)
|
||||||
logger.info("LLM failover selected model={}", candidate.label)
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
last_response = response
|
|
||||||
if not self._should_failover(response):
|
if not self._should_failover(response):
|
||||||
return response
|
return response
|
||||||
if index + 1 >= len(chain):
|
|
||||||
logger.warning("LLM failover exhausted after model={}", candidate.label)
|
|
||||||
return response
|
|
||||||
logger.warning(
|
|
||||||
"LLM failover model={} next_model={} status={} kind={}",
|
|
||||||
candidate.label,
|
|
||||||
chain[index + 1].label,
|
|
||||||
response.error_status_code,
|
|
||||||
response.error_kind or response.error_type or response.error_code or "unknown",
|
|
||||||
)
|
|
||||||
|
|
||||||
return last_response or LLMResponse(
|
logger.warning("LLM failover exhausted after all candidates")
|
||||||
content="No available fallback model candidate.",
|
return response
|
||||||
finish_reason="error",
|
|
||||||
error_kind="configuration",
|
|
||||||
error_should_retry=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def chat_with_retry(
|
async def chat_with_retry(self, **kwargs: Any) -> LLMResponse:
|
||||||
self,
|
|
||||||
messages: list[dict[str, Any]],
|
|
||||||
tools: list[dict[str, Any]] | None = None,
|
|
||||||
model: str | None = None,
|
|
||||||
max_tokens: object = LLMProvider._SENTINEL,
|
|
||||||
temperature: object = LLMProvider._SENTINEL,
|
|
||||||
reasoning_effort: object = LLMProvider._SENTINEL,
|
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
|
||||||
retry_mode: str = "standard",
|
|
||||||
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
|
|
||||||
) -> LLMResponse:
|
|
||||||
async def call(
|
async def call(
|
||||||
provider: LLMProvider,
|
provider: LLMProvider, candidate_model: str, _unused_delta: Any
|
||||||
candidate_model: str,
|
|
||||||
_delta: Callable[[str], Awaitable[None]] | None,
|
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
return await provider.chat_with_retry(
|
return await provider.chat_with_retry(
|
||||||
messages=messages,
|
**{**kwargs, "model": candidate_model}
|
||||||
tools=tools,
|
|
||||||
model=candidate_model,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
temperature=temperature,
|
|
||||||
reasoning_effort=reasoning_effort,
|
|
||||||
tool_choice=tool_choice,
|
|
||||||
retry_mode=retry_mode,
|
|
||||||
on_retry_wait=on_retry_wait,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return await self._route(call)
|
return await self._route(call)
|
||||||
|
|
||||||
async def chat_stream_with_retry(
|
async def chat_stream_with_retry(self, **kwargs: Any) -> LLMResponse:
|
||||||
self,
|
on_content_delta = kwargs.pop("on_content_delta", None)
|
||||||
messages: list[dict[str, Any]],
|
|
||||||
tools: list[dict[str, Any]] | None = None,
|
|
||||||
model: str | None = None,
|
|
||||||
max_tokens: object = LLMProvider._SENTINEL,
|
|
||||||
temperature: object = LLMProvider._SENTINEL,
|
|
||||||
reasoning_effort: object = LLMProvider._SENTINEL,
|
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
|
||||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
|
||||||
retry_mode: str = "standard",
|
|
||||||
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
|
|
||||||
) -> LLMResponse:
|
|
||||||
async def call(
|
async def call(
|
||||||
provider: LLMProvider,
|
provider: LLMProvider,
|
||||||
candidate_model: str,
|
candidate_model: str,
|
||||||
external_delta: Callable[[str], Awaitable[None]] | None,
|
content_delta: Callable[[str], Awaitable[None]] | None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
buffered: list[str] = []
|
buffered: list[str] = []
|
||||||
|
|
||||||
async def buffer_delta(delta: str) -> None:
|
async def buffer_delta(delta: str) -> None:
|
||||||
buffered.append(delta)
|
buffered.append(delta)
|
||||||
|
|
||||||
|
kwargs["on_content_delta"] = buffer_delta if content_delta else None
|
||||||
response = await provider.chat_stream_with_retry(
|
response = await provider.chat_stream_with_retry(
|
||||||
messages=messages,
|
**{**kwargs, "model": candidate_model}
|
||||||
tools=tools,
|
|
||||||
model=candidate_model,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
temperature=temperature,
|
|
||||||
reasoning_effort=reasoning_effort,
|
|
||||||
tool_choice=tool_choice,
|
|
||||||
on_content_delta=buffer_delta if external_delta else None,
|
|
||||||
retry_mode=retry_mode,
|
|
||||||
on_retry_wait=on_retry_wait,
|
|
||||||
)
|
)
|
||||||
if response.finish_reason != "error" and external_delta:
|
if response.finish_reason != "error" and content_delta:
|
||||||
for delta in buffered:
|
try:
|
||||||
await external_delta(delta)
|
for delta in buffered:
|
||||||
|
await content_delta(delta)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failover delta callback failed for model={}", candidate_model)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
return await self._route(call, on_content_delta=on_content_delta)
|
return await self._route(call, on_content_delta=on_content_delta)
|
||||||
|
|||||||
@ -8,7 +8,6 @@ from pathlib import Path
|
|||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from nanobot.cli import onboard as onboard_wizard
|
from nanobot.cli import onboard as onboard_wizard
|
||||||
@ -636,8 +635,8 @@ class TestValidateFieldConstraint:
|
|||||||
|
|
||||||
def test_real_send_max_retries_field(self):
|
def test_real_send_max_retries_field(self):
|
||||||
"""Validate against the actual ChannelsConfig.send_max_retries field."""
|
"""Validate against the actual ChannelsConfig.send_max_retries field."""
|
||||||
from nanobot.config.schema import ChannelsConfig
|
|
||||||
from nanobot.cli.onboard import _validate_field_constraint
|
from nanobot.cli.onboard import _validate_field_constraint
|
||||||
|
from nanobot.config.schema import ChannelsConfig
|
||||||
|
|
||||||
field_info = ChannelsConfig.model_fields["send_max_retries"]
|
field_info = ChannelsConfig.model_fields["send_max_retries"]
|
||||||
assert _validate_field_constraint(3, field_info) is None
|
assert _validate_field_constraint(3, field_info) is None
|
||||||
@ -829,12 +828,11 @@ class TestMainMenuUpdate:
|
|||||||
|
|
||||||
def test_main_menu_dispatch_includes_channel_common(self):
|
def test_main_menu_dispatch_includes_channel_common(self):
|
||||||
"""Main menu dispatch should route [H] to Channel Common."""
|
"""Main menu dispatch should route [H] to Channel Common."""
|
||||||
from nanobot.cli.onboard import run_onboard
|
|
||||||
|
|
||||||
# We verify by checking the dispatch table is set up correctly
|
# We verify by checking the dispatch table is set up correctly
|
||||||
# The menu items are defined inline in run_onboard, so we test
|
# The menu items are defined inline in run_onboard, so we test
|
||||||
# that _configure_general_settings handles the new sections.
|
# that _configure_general_settings handles the new sections.
|
||||||
from nanobot.cli.onboard import _SETTINGS_SECTIONS, _SETTINGS_GETTER, _SETTINGS_SETTER
|
from nanobot.cli.onboard import _SETTINGS_GETTER, _SETTINGS_SECTIONS, _SETTINGS_SETTER
|
||||||
|
|
||||||
assert "Channel Common" in _SETTINGS_SECTIONS
|
assert "Channel Common" in _SETTINGS_SECTIONS
|
||||||
assert "Channel Common" in _SETTINGS_GETTER
|
assert "Channel Common" in _SETTINGS_GETTER
|
||||||
@ -842,7 +840,7 @@ class TestMainMenuUpdate:
|
|||||||
|
|
||||||
def test_main_menu_dispatch_includes_api_server(self):
|
def test_main_menu_dispatch_includes_api_server(self):
|
||||||
"""Main menu dispatch should route [I] to API Server."""
|
"""Main menu dispatch should route [I] to API Server."""
|
||||||
from nanobot.cli.onboard import _SETTINGS_SECTIONS, _SETTINGS_GETTER, _SETTINGS_SETTER
|
from nanobot.cli.onboard import _SETTINGS_GETTER, _SETTINGS_SECTIONS, _SETTINGS_SETTER
|
||||||
|
|
||||||
assert "API Server" in _SETTINGS_SECTIONS
|
assert "API Server" in _SETTINGS_SECTIONS
|
||||||
assert "API Server" in _SETTINGS_GETTER
|
assert "API Server" in _SETTINGS_GETTER
|
||||||
@ -1074,3 +1072,346 @@ class TestConfigurePydanticModelEmptyString:
|
|||||||
result = _configure_pydantic_model(model, "Test")
|
result = _configure_pydantic_model(model, "Test")
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.api_key == ""
|
assert result.api_key == ""
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelPresetWizard:
|
||||||
|
"""Tests for model preset CRUD in the onboard wizard."""
|
||||||
|
|
||||||
|
def test_sync_preset_cache(self):
|
||||||
|
"""_sync_preset_cache should populate the module-level cache."""
|
||||||
|
from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _sync_preset_cache
|
||||||
|
from nanobot.config.schema import ModelPresetConfig
|
||||||
|
|
||||||
|
config = Config()
|
||||||
|
config.model_presets = {
|
||||||
|
"fast": ModelPresetConfig(model="gpt-4.1-mini"),
|
||||||
|
"power": ModelPresetConfig(model="gpt-4.1"),
|
||||||
|
}
|
||||||
|
_sync_preset_cache(config)
|
||||||
|
assert _MODEL_PRESET_CACHE == {"fast", "power"}
|
||||||
|
|
||||||
|
def test_model_preset_add(self, monkeypatch):
|
||||||
|
"""_configure_model_presets should add a new preset."""
|
||||||
|
from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _configure_model_presets
|
||||||
|
from nanobot.config.schema import ModelPresetConfig
|
||||||
|
|
||||||
|
config = Config()
|
||||||
|
_MODEL_PRESET_CACHE.clear()
|
||||||
|
|
||||||
|
responses = iter([
|
||||||
|
"[+] Add new preset",
|
||||||
|
"my-preset",
|
||||||
|
"<- Back",
|
||||||
|
])
|
||||||
|
|
||||||
|
class FakePrompt:
|
||||||
|
def __init__(self, response):
|
||||||
|
self.response = response
|
||||||
|
def ask(self):
|
||||||
|
if isinstance(self.response, BaseException):
|
||||||
|
raise self.response
|
||||||
|
return self.response
|
||||||
|
|
||||||
|
def fake_select(*_args, **_kwargs):
|
||||||
|
return FakePrompt(next(responses))
|
||||||
|
|
||||||
|
def fake_text(*_args, **_kwargs):
|
||||||
|
return FakePrompt(next(responses))
|
||||||
|
|
||||||
|
def fake_configure(*_model, **_kwargs):
|
||||||
|
return ModelPresetConfig(model="gpt-test", temperature=0.5)
|
||||||
|
|
||||||
|
# _select_with_back returns a string/sentinel directly (not a prompt object)
|
||||||
|
def fake_select_with_back(*_args, **_kwargs):
|
||||||
|
return next(responses)
|
||||||
|
|
||||||
|
monkeypatch.setattr(onboard_wizard, "_select_with_back", fake_select_with_back)
|
||||||
|
monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select, text=fake_text))
|
||||||
|
monkeypatch.setattr(onboard_wizard, "_configure_pydantic_model", fake_configure)
|
||||||
|
monkeypatch.setattr(onboard_wizard, "_show_section_header", lambda *a, **kw: None)
|
||||||
|
monkeypatch.setattr(onboard_wizard, "console", SimpleNamespace(clear=lambda: None))
|
||||||
|
|
||||||
|
_configure_model_presets(config)
|
||||||
|
|
||||||
|
assert "my-preset" in config.model_presets
|
||||||
|
assert config.model_presets["my-preset"].model == "gpt-test"
|
||||||
|
assert config.model_presets["my-preset"].temperature == 0.5
|
||||||
|
|
||||||
|
def test_model_preset_delete(self, monkeypatch):
|
||||||
|
"""_configure_model_presets should delete an existing preset."""
|
||||||
|
from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _configure_model_presets
|
||||||
|
from nanobot.config.schema import ModelPresetConfig
|
||||||
|
|
||||||
|
config = Config()
|
||||||
|
config.model_presets = {"old": ModelPresetConfig(model="x")}
|
||||||
|
_MODEL_PRESET_CACHE.clear()
|
||||||
|
_MODEL_PRESET_CACHE.add("old")
|
||||||
|
|
||||||
|
responses = iter([
|
||||||
|
"old (x)",
|
||||||
|
"Delete",
|
||||||
|
True,
|
||||||
|
"<- Back",
|
||||||
|
])
|
||||||
|
|
||||||
|
class FakePrompt:
|
||||||
|
def __init__(self, response):
|
||||||
|
self.response = response
|
||||||
|
def ask(self):
|
||||||
|
if isinstance(self.response, BaseException):
|
||||||
|
raise self.response
|
||||||
|
return self.response
|
||||||
|
|
||||||
|
def fake_select(*_args, **_kwargs):
|
||||||
|
return FakePrompt(next(responses))
|
||||||
|
|
||||||
|
def fake_confirm(*_args, **_kwargs):
|
||||||
|
return FakePrompt(next(responses))
|
||||||
|
|
||||||
|
def fake_select_with_back(*_args, **_kwargs):
|
||||||
|
return next(responses)
|
||||||
|
|
||||||
|
monkeypatch.setattr(onboard_wizard, "_select_with_back", fake_select_with_back)
|
||||||
|
monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select, confirm=fake_confirm))
|
||||||
|
monkeypatch.setattr(onboard_wizard, "_show_section_header", lambda *a, **kw: None)
|
||||||
|
monkeypatch.setattr(onboard_wizard, "console", SimpleNamespace(clear=lambda: None))
|
||||||
|
|
||||||
|
_configure_model_presets(config)
|
||||||
|
|
||||||
|
assert "old" not in config.model_presets
|
||||||
|
assert "old" not in _MODEL_PRESET_CACHE
|
||||||
|
|
||||||
|
def test_model_preset_field_handler(self, monkeypatch):
|
||||||
|
"""_handle_model_preset_field should set a preset name from choices."""
|
||||||
|
from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _handle_model_preset_field
|
||||||
|
from nanobot.config.schema import AgentDefaults
|
||||||
|
|
||||||
|
_MODEL_PRESET_CACHE.clear()
|
||||||
|
_MODEL_PRESET_CACHE.update({"fast", "power"})
|
||||||
|
|
||||||
|
monkeypatch.setattr(onboard_wizard, "_select_with_back", lambda *a, **kw: "fast")
|
||||||
|
|
||||||
|
defaults = AgentDefaults()
|
||||||
|
_handle_model_preset_field(defaults, "model_preset", "Model Preset", None)
|
||||||
|
assert defaults.model_preset == "fast"
|
||||||
|
|
||||||
|
def test_model_preset_field_handler_clear(self, monkeypatch):
|
||||||
|
"""_handle_model_preset_field should clear preset when (clear/unset) chosen."""
|
||||||
|
from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _handle_model_preset_field
|
||||||
|
from nanobot.config.schema import AgentDefaults
|
||||||
|
|
||||||
|
_MODEL_PRESET_CACHE.clear()
|
||||||
|
_MODEL_PRESET_CACHE.add("fast")
|
||||||
|
|
||||||
|
monkeypatch.setattr(onboard_wizard, "_select_with_back", lambda *a, **kw: "(clear/unset)")
|
||||||
|
|
||||||
|
defaults = AgentDefaults(model_preset="fast")
|
||||||
|
_handle_model_preset_field(defaults, "model_preset", "Model Preset", "fast")
|
||||||
|
assert defaults.model_preset is None
|
||||||
|
|
||||||
|
def test_main_menu_dispatch_includes_model_presets(self):
|
||||||
|
"""run_onboard dispatch should route [M] to Model Presets."""
|
||||||
|
from nanobot.cli.onboard import _configure_model_presets
|
||||||
|
|
||||||
|
# The function should be importable and callable
|
||||||
|
assert callable(_configure_model_presets)
|
||||||
|
|
||||||
|
def test_run_onboard_model_presets_edit(self, monkeypatch):
|
||||||
|
"""run_onboard should handle [M] Model Presets correctly."""
|
||||||
|
initial_config = Config()
|
||||||
|
|
||||||
|
responses = iter([
|
||||||
|
"[M] Model Presets",
|
||||||
|
KeyboardInterrupt(),
|
||||||
|
"[S] Save and Exit",
|
||||||
|
])
|
||||||
|
|
||||||
|
class FakePrompt:
|
||||||
|
def __init__(self, response):
|
||||||
|
self.response = response
|
||||||
|
def ask(self):
|
||||||
|
if isinstance(self.response, BaseException):
|
||||||
|
raise self.response
|
||||||
|
return self.response
|
||||||
|
|
||||||
|
def fake_select(*_args, **_kwargs):
|
||||||
|
return FakePrompt(next(responses))
|
||||||
|
|
||||||
|
preset_mutated = {"n": 0}
|
||||||
|
|
||||||
|
def fake_configure_model_presets(config):
|
||||||
|
preset_mutated["n"] += 1
|
||||||
|
# Mutate config so unsaved changes are detected
|
||||||
|
from nanobot.config.schema import ModelPresetConfig
|
||||||
|
config.model_presets["test"] = ModelPresetConfig(model="x")
|
||||||
|
|
||||||
|
monkeypatch.setattr(onboard_wizard, "_show_main_menu_header", lambda: None)
|
||||||
|
monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select))
|
||||||
|
monkeypatch.setattr(onboard_wizard, "_configure_model_presets", fake_configure_model_presets)
|
||||||
|
|
||||||
|
result = run_onboard(initial_config=initial_config)
|
||||||
|
|
||||||
|
assert result.should_save is True
|
||||||
|
assert preset_mutated["n"] == 1
|
||||||
|
|
||||||
|
def test_summary_shows_model_presets(self, monkeypatch):
|
||||||
|
"""_show_summary should include model presets panel."""
|
||||||
|
from nanobot.cli.onboard import _show_summary
|
||||||
|
from nanobot.config.schema import ModelPresetConfig
|
||||||
|
|
||||||
|
config = Config()
|
||||||
|
config.model_presets = {
|
||||||
|
"fast": ModelPresetConfig(model="gpt-4.1-mini"),
|
||||||
|
}
|
||||||
|
|
||||||
|
panels = []
|
||||||
|
|
||||||
|
def fake_print_summary(rows, title):
|
||||||
|
panels.append(title)
|
||||||
|
|
||||||
|
monkeypatch.setattr(onboard_wizard, "_print_summary_panel", fake_print_summary)
|
||||||
|
monkeypatch.setattr(onboard_wizard, "_get_provider_names", lambda: {})
|
||||||
|
monkeypatch.setattr(onboard_wizard, "_get_channel_names", lambda: {})
|
||||||
|
monkeypatch.setattr(onboard_wizard, "_pause", lambda: None)
|
||||||
|
monkeypatch.setattr(onboard_wizard, "console", SimpleNamespace(print=lambda *a, **kw: None))
|
||||||
|
|
||||||
|
_show_summary(config)
|
||||||
|
|
||||||
|
assert "Model Presets" in panels
|
||||||
|
|
||||||
|
def test_provider_field_handler(self, monkeypatch):
|
||||||
|
"""_handle_provider_field should set a provider from the registry list."""
|
||||||
|
from nanobot.cli.onboard import _handle_provider_field
|
||||||
|
from nanobot.config.schema import ModelPresetConfig
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
onboard_wizard, "_get_provider_names", lambda: {"moonshot": "Moonshot", "openai": "OpenAI"}
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(onboard_wizard, "_select_with_back", lambda *a, **kw: "moonshot")
|
||||||
|
|
||||||
|
preset = ModelPresetConfig(model="x")
|
||||||
|
_handle_provider_field(preset, "provider", "Provider", "auto")
|
||||||
|
assert preset.provider == "moonshot"
|
||||||
|
|
||||||
|
def test_provider_field_handler_back_pressed(self, monkeypatch):
|
||||||
|
"""_handle_provider_field should not modify value when back is pressed."""
|
||||||
|
from nanobot.cli.onboard import _BACK_PRESSED, _handle_provider_field
|
||||||
|
from nanobot.config.schema import ModelPresetConfig
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
onboard_wizard, "_get_provider_names", lambda: {"moonshot": "Moonshot"}
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(onboard_wizard, "_select_with_back", lambda *a, **kw: _BACK_PRESSED)
|
||||||
|
|
||||||
|
preset = ModelPresetConfig(model="x", provider="auto")
|
||||||
|
_handle_provider_field(preset, "provider", "Provider", "auto")
|
||||||
|
assert preset.provider == "auto"
|
||||||
|
|
||||||
|
def test_fallback_presets_add_preset_and_done(self, monkeypatch):
|
||||||
|
"""_handle_fallback_presets_field should add a preset and save on Done."""
|
||||||
|
from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _handle_fallback_presets_field
|
||||||
|
from nanobot.config.schema import AgentDefaults
|
||||||
|
|
||||||
|
_MODEL_PRESET_CACHE.clear()
|
||||||
|
_MODEL_PRESET_CACHE.update({"fast", "power"})
|
||||||
|
|
||||||
|
responses = iter(["[+] Add preset", "[Done]"])
|
||||||
|
|
||||||
|
class FakePrompt:
|
||||||
|
def __init__(self, response):
|
||||||
|
self.response = response
|
||||||
|
def ask(self):
|
||||||
|
if isinstance(self.response, BaseException):
|
||||||
|
raise self.response
|
||||||
|
return self.response
|
||||||
|
|
||||||
|
def fake_select(*_args, **_kwargs):
|
||||||
|
return FakePrompt(next(responses))
|
||||||
|
|
||||||
|
monkeypatch.setattr(onboard_wizard, "_select_with_back", lambda *a, **kw: "fast")
|
||||||
|
monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select))
|
||||||
|
monkeypatch.setattr(onboard_wizard, "console", SimpleNamespace(clear=lambda: None, print=lambda *a, **kw: None))
|
||||||
|
|
||||||
|
defaults = AgentDefaults()
|
||||||
|
_handle_fallback_presets_field(defaults, "fallback_presets", "Fallback Presets", [])
|
||||||
|
assert defaults.fallback_presets == ["fast"]
|
||||||
|
|
||||||
|
def test_fallback_presets_back_preserves_existing(self, monkeypatch):
|
||||||
|
"""_handle_fallback_presets_field should not modify value on Back."""
|
||||||
|
from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _handle_fallback_presets_field
|
||||||
|
from nanobot.config.schema import AgentDefaults
|
||||||
|
|
||||||
|
_MODEL_PRESET_CACHE.clear()
|
||||||
|
_MODEL_PRESET_CACHE.add("fast")
|
||||||
|
|
||||||
|
class FakePrompt:
|
||||||
|
def __init__(self, response):
|
||||||
|
self.response = response
|
||||||
|
def ask(self):
|
||||||
|
if isinstance(self.response, BaseException):
|
||||||
|
raise self.response
|
||||||
|
return self.response
|
||||||
|
|
||||||
|
def fake_select(*_args, **_kwargs):
|
||||||
|
return FakePrompt("<- Back")
|
||||||
|
|
||||||
|
monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select))
|
||||||
|
monkeypatch.setattr(onboard_wizard, "console", SimpleNamespace(clear=lambda: None, print=lambda *a, **kw: None))
|
||||||
|
|
||||||
|
defaults = AgentDefaults(fallback_presets=["existing"])
|
||||||
|
_handle_fallback_presets_field(defaults, "fallback_presets", "Fallback Presets", ["existing"])
|
||||||
|
assert defaults.fallback_presets == ["existing"]
|
||||||
|
|
||||||
|
def test_fallback_presets_remove_last(self, monkeypatch):
|
||||||
|
"""_handle_fallback_presets_field should remove last item."""
|
||||||
|
from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _handle_fallback_presets_field
|
||||||
|
from nanobot.config.schema import AgentDefaults
|
||||||
|
|
||||||
|
_MODEL_PRESET_CACHE.clear()
|
||||||
|
|
||||||
|
responses = iter(["[-] Remove last", "[Done]"])
|
||||||
|
|
||||||
|
class FakePrompt:
|
||||||
|
def __init__(self, response):
|
||||||
|
self.response = response
|
||||||
|
def ask(self):
|
||||||
|
if isinstance(self.response, BaseException):
|
||||||
|
raise self.response
|
||||||
|
return self.response
|
||||||
|
|
||||||
|
def fake_select(*_args, **_kwargs):
|
||||||
|
return FakePrompt(next(responses))
|
||||||
|
|
||||||
|
monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select))
|
||||||
|
monkeypatch.setattr(onboard_wizard, "console", SimpleNamespace(clear=lambda: None, print=lambda *a, **kw: None))
|
||||||
|
|
||||||
|
defaults = AgentDefaults(fallback_presets=["a", "b"])
|
||||||
|
_handle_fallback_presets_field(defaults, "fallback_presets", "Fallback Presets", ["a", "b"])
|
||||||
|
assert defaults.fallback_presets == ["a"]
|
||||||
|
|
||||||
|
def test_fallback_presets_no_presets_shows_warning(self, monkeypatch):
|
||||||
|
"""_handle_fallback_presets_field should warn when no presets exist."""
|
||||||
|
from nanobot.cli.onboard import _MODEL_PRESET_CACHE, _handle_fallback_presets_field
|
||||||
|
from nanobot.config.schema import AgentDefaults
|
||||||
|
|
||||||
|
_MODEL_PRESET_CACHE.clear()
|
||||||
|
|
||||||
|
responses = iter(["[+] Add preset", "[Done]"])
|
||||||
|
|
||||||
|
class FakePrompt:
|
||||||
|
def __init__(self, response):
|
||||||
|
self.response = response
|
||||||
|
def ask(self):
|
||||||
|
if isinstance(self.response, BaseException):
|
||||||
|
raise self.response
|
||||||
|
return self.response
|
||||||
|
|
||||||
|
def fake_select(*_args, **_kwargs):
|
||||||
|
return FakePrompt(next(responses))
|
||||||
|
|
||||||
|
monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select, press_any_key_to_continue=lambda: FakePrompt(None)))
|
||||||
|
monkeypatch.setattr(onboard_wizard, "console", SimpleNamespace(clear=lambda: None, print=lambda *a, **kw: None))
|
||||||
|
|
||||||
|
defaults = AgentDefaults()
|
||||||
|
_handle_fallback_presets_field(defaults, "fallback_presets", "Fallback Presets", [])
|
||||||
|
assert defaults.fallback_presets == []
|
||||||
|
|||||||
@ -1,269 +0,0 @@
|
|||||||
"""Tests for the provider fallback models feature in AgentRunner."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
|
||||||
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
|
||||||
from nanobot.providers.base import LLMResponse
|
|
||||||
|
|
||||||
|
|
||||||
def _make_tools():
|
|
||||||
tools = MagicMock()
|
|
||||||
tools.get_definitions.return_value = []
|
|
||||||
tools.execute = AsyncMock(return_value="ok")
|
|
||||||
return tools
|
|
||||||
|
|
||||||
|
|
||||||
def _make_provider(*, model_response: LLMResponse | None = None):
|
|
||||||
p = MagicMock()
|
|
||||||
if model_response is not None:
|
|
||||||
p.chat_with_retry = AsyncMock(return_value=model_response)
|
|
||||||
return p
|
|
||||||
|
|
||||||
|
|
||||||
def _transient_error(content: str = "server unavailable") -> LLMResponse:
|
|
||||||
return LLMResponse(content=content, finish_reason="error", error_status_code=503)
|
|
||||||
|
|
||||||
|
|
||||||
def _base_spec(**overrides) -> AgentRunSpec:
|
|
||||||
defaults = dict(
|
|
||||||
initial_messages=[
|
|
||||||
{"role": "system", "content": "sys"},
|
|
||||||
{"role": "user", "content": "hello"},
|
|
||||||
],
|
|
||||||
tools=_make_tools(),
|
|
||||||
model="primary-model",
|
|
||||||
max_iterations=1,
|
|
||||||
max_tool_result_chars=8000,
|
|
||||||
)
|
|
||||||
defaults.update(overrides)
|
|
||||||
return AgentRunSpec(**defaults)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_no_fallback_when_primary_succeeds():
|
|
||||||
"""Primary succeeds -> fallback list never consulted."""
|
|
||||||
ok = LLMResponse(content="done", tool_calls=[], usage={})
|
|
||||||
provider = _make_provider(model_response=ok)
|
|
||||||
factory = MagicMock()
|
|
||||||
|
|
||||||
runner = AgentRunner(provider, provider_factory=factory)
|
|
||||||
result = await runner.run(_base_spec(fallback_models=["fb-1", "fb-2"]))
|
|
||||||
|
|
||||||
assert result.final_content == "done"
|
|
||||||
factory.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_fallback_triggered_on_primary_error():
|
|
||||||
"""Primary fails -> first fallback succeeds."""
|
|
||||||
err = _transient_error()
|
|
||||||
ok = LLMResponse(content="fallback-ok", tool_calls=[], usage={})
|
|
||||||
|
|
||||||
primary = _make_provider(model_response=err)
|
|
||||||
|
|
||||||
fb_provider = MagicMock()
|
|
||||||
fb_provider.chat_with_retry = AsyncMock(return_value=ok)
|
|
||||||
factory = MagicMock(return_value=fb_provider)
|
|
||||||
|
|
||||||
runner = AgentRunner(primary, provider_factory=factory)
|
|
||||||
result = await runner.run(_base_spec(fallback_models=["fb-model"]))
|
|
||||||
|
|
||||||
assert result.final_content == "fallback-ok"
|
|
||||||
factory.assert_called_once_with("fb-model")
|
|
||||||
fb_provider.chat_with_retry.assert_awaited_once()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_all_fallbacks_fail_returns_last_error():
|
|
||||||
"""Primary + all fallbacks fail -> return last error response."""
|
|
||||||
err = _transient_error()
|
|
||||||
|
|
||||||
primary = _make_provider(model_response=err)
|
|
||||||
fb1 = _make_provider(model_response=err)
|
|
||||||
fb2 = _make_provider(model_response=LLMResponse(
|
|
||||||
content="last-error", finish_reason="error", error_status_code=500, usage={},
|
|
||||||
))
|
|
||||||
|
|
||||||
providers = {"fb-1": fb1, "fb-2": fb2}
|
|
||||||
factory = MagicMock(side_effect=lambda m: providers[m])
|
|
||||||
|
|
||||||
runner = AgentRunner(primary, provider_factory=factory)
|
|
||||||
result = await runner.run(_base_spec(fallback_models=["fb-1", "fb-2"]))
|
|
||||||
|
|
||||||
assert result.error is not None or result.final_content is not None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_empty_fallback_list_no_retry():
|
|
||||||
"""Empty fallback_models -> no fallback attempted."""
|
|
||||||
err = LLMResponse(content=None, finish_reason="error", usage={})
|
|
||||||
primary = _make_provider(model_response=err)
|
|
||||||
factory = MagicMock()
|
|
||||||
|
|
||||||
runner = AgentRunner(primary, provider_factory=factory)
|
|
||||||
result = await runner.run(_base_spec(fallback_models=[]))
|
|
||||||
|
|
||||||
factory.assert_not_called()
|
|
||||||
assert result.error is not None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_cross_provider_fallback():
|
|
||||||
"""Fallback uses a different provider instance (cross-provider)."""
|
|
||||||
err = _transient_error()
|
|
||||||
ok = LLMResponse(content="cross-provider-ok", tool_calls=[], usage={})
|
|
||||||
|
|
||||||
primary = _make_provider(model_response=err)
|
|
||||||
anthropic_provider = MagicMock()
|
|
||||||
anthropic_provider.chat_with_retry = AsyncMock(return_value=ok)
|
|
||||||
|
|
||||||
def cross_factory(model: str):
|
|
||||||
if model == "anthropic/claude-sonnet":
|
|
||||||
return anthropic_provider
|
|
||||||
raise ValueError(f"unexpected model: {model}")
|
|
||||||
|
|
||||||
runner = AgentRunner(primary, provider_factory=cross_factory)
|
|
||||||
result = await runner.run(_base_spec(fallback_models=["anthropic/claude-sonnet"]))
|
|
||||||
|
|
||||||
assert result.final_content == "cross-provider-ok"
|
|
||||||
anthropic_provider.chat_with_retry.assert_awaited_once()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_fallback_skips_to_second_on_first_error():
|
|
||||||
"""First fallback also fails -> second fallback succeeds."""
|
|
||||||
err = _transient_error()
|
|
||||||
ok = LLMResponse(content="second-fb-ok", tool_calls=[], usage={})
|
|
||||||
|
|
||||||
primary = _make_provider(model_response=err)
|
|
||||||
fb1 = _make_provider(model_response=err)
|
|
||||||
fb2 = MagicMock()
|
|
||||||
fb2.chat_with_retry = AsyncMock(return_value=ok)
|
|
||||||
|
|
||||||
providers = {"fb-1": fb1, "fb-2": fb2}
|
|
||||||
factory = MagicMock(side_effect=lambda m: providers[m])
|
|
||||||
|
|
||||||
runner = AgentRunner(primary, provider_factory=factory)
|
|
||||||
result = await runner.run(_base_spec(fallback_models=["fb-1", "fb-2"]))
|
|
||||||
|
|
||||||
assert result.final_content == "second-fb-ok"
|
|
||||||
assert factory.call_count == 2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_fallback_reuses_same_provider_without_factory():
|
|
||||||
"""No provider_factory -> fallback reuses primary provider with different model."""
|
|
||||||
call_count = {"n": 0}
|
|
||||||
|
|
||||||
async def chat_with_retry(*, messages, model, **kw):
|
|
||||||
call_count["n"] += 1
|
|
||||||
if call_count["n"] == 1:
|
|
||||||
return _transient_error()
|
|
||||||
return LLMResponse(content=f"ok-via-{model}", tool_calls=[], usage={})
|
|
||||||
|
|
||||||
primary = MagicMock()
|
|
||||||
primary.chat_with_retry = chat_with_retry
|
|
||||||
|
|
||||||
runner = AgentRunner(primary, provider_factory=None)
|
|
||||||
result = await runner.run(_base_spec(fallback_models=["fallback-model"]))
|
|
||||||
|
|
||||||
assert result.final_content == "ok-via-fallback-model"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_fallback_provider_cached():
|
|
||||||
"""Provider factory is called once per unique provider, not per attempt."""
|
|
||||||
err = _transient_error()
|
|
||||||
ok = LLMResponse(content="cached-ok", tool_calls=[], usage={})
|
|
||||||
|
|
||||||
primary = _make_provider(model_response=err)
|
|
||||||
|
|
||||||
fb_provider = MagicMock()
|
|
||||||
call_seq = [err, ok]
|
|
||||||
fb_provider.chat_with_retry = AsyncMock(side_effect=call_seq)
|
|
||||||
|
|
||||||
factory = MagicMock(return_value=fb_provider)
|
|
||||||
|
|
||||||
runner = AgentRunner(primary, provider_factory=factory)
|
|
||||||
result = await runner.run(_base_spec(fallback_models=["same-provider-model-a", "same-provider-model-b"]))
|
|
||||||
|
|
||||||
assert result.final_content == "cached-ok"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_non_transient_error_does_not_fallback():
|
|
||||||
"""Auth/config-style errors should surface instead of hiding bugs via fallback."""
|
|
||||||
primary = _make_provider(model_response=LLMResponse(
|
|
||||||
content="401 unauthorized",
|
|
||||||
finish_reason="error",
|
|
||||||
error_status_code=401,
|
|
||||||
))
|
|
||||||
fallback = _make_provider(model_response=LLMResponse(content="fallback-ok"))
|
|
||||||
factory = MagicMock(return_value=fallback)
|
|
||||||
|
|
||||||
runner = AgentRunner(primary, provider_factory=factory)
|
|
||||||
result = await runner.run(_base_spec(fallback_models=["fb-model"]))
|
|
||||||
|
|
||||||
factory.assert_not_called()
|
|
||||||
assert result.error is not None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_quota_error_does_not_fallback_by_default():
|
|
||||||
"""Quota/billing/payment 429s should not route by default."""
|
|
||||||
primary = _make_provider(model_response=LLMResponse(
|
|
||||||
content="insufficient quota",
|
|
||||||
finish_reason="error",
|
|
||||||
error_status_code=429,
|
|
||||||
error_code="insufficient_quota",
|
|
||||||
))
|
|
||||||
fallback = _make_provider(model_response=LLMResponse(content="fallback-ok"))
|
|
||||||
factory = MagicMock(return_value=fallback)
|
|
||||||
|
|
||||||
runner = AgentRunner(primary, provider_factory=factory)
|
|
||||||
result = await runner.run(_base_spec(fallback_models=["fb-model"]))
|
|
||||||
|
|
||||||
factory.assert_not_called()
|
|
||||||
assert result.error is not None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_streaming_fallback_discards_failed_primary_deltas():
|
|
||||||
"""Buffered streaming prevents primary partial output from leaking on fallback."""
|
|
||||||
streamed: list[str] = []
|
|
||||||
|
|
||||||
class StreamingHook(AgentHook):
|
|
||||||
def wants_streaming(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
|
||||||
streamed.append(delta)
|
|
||||||
|
|
||||||
async def primary_stream(*, on_content_delta, **kwargs):
|
|
||||||
await on_content_delta("bad partial")
|
|
||||||
return _transient_error()
|
|
||||||
|
|
||||||
async def fallback_stream(*, on_content_delta, **kwargs):
|
|
||||||
await on_content_delta("good")
|
|
||||||
await on_content_delta(" answer")
|
|
||||||
return LLMResponse(content="good answer", tool_calls=[], usage={})
|
|
||||||
|
|
||||||
primary = MagicMock()
|
|
||||||
primary.chat_stream_with_retry = primary_stream
|
|
||||||
fallback = MagicMock()
|
|
||||||
fallback.chat_stream_with_retry = fallback_stream
|
|
||||||
factory = MagicMock(return_value=fallback)
|
|
||||||
|
|
||||||
runner = AgentRunner(primary, provider_factory=factory)
|
|
||||||
result = await runner.run(_base_spec(
|
|
||||||
fallback_models=["fb-model"],
|
|
||||||
hook=StreamingHook(),
|
|
||||||
))
|
|
||||||
|
|
||||||
assert result.final_content == "good answer"
|
|
||||||
assert streamed == ["good", " answer"]
|
|
||||||
@ -1,6 +1,6 @@
|
|||||||
# tests/agent/test_self_model_preset.py
|
# tests/agent/test_self_model_preset.py
|
||||||
import asyncio
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
from nanobot.agent.loop import AgentLoop
|
||||||
@ -8,10 +8,23 @@ from nanobot.config.schema import ModelPresetConfig, MyToolConfig, ToolsConfig
|
|||||||
from nanobot.providers.base import GenerationSettings
|
from nanobot.providers.base import GenerationSettings
|
||||||
|
|
||||||
|
|
||||||
def _make_loop(presets: dict | None = None) -> tuple[AgentLoop, "MyTool"]:
|
def _make_loop(presets: dict | None = None) -> tuple[AgentLoop, Any]:
|
||||||
provider = MagicMock()
|
provider = MagicMock()
|
||||||
provider.get_default_model.return_value = "test-model"
|
provider.get_default_model.return_value = "test-model"
|
||||||
provider.generation = GenerationSettings(temperature=0.1, max_tokens=8192)
|
provider.generation = GenerationSettings(temperature=0.1, max_tokens=8192)
|
||||||
|
|
||||||
|
def _factory(name: str):
|
||||||
|
preset = (presets or {}).get(name)
|
||||||
|
if preset:
|
||||||
|
new_provider = MagicMock()
|
||||||
|
new_provider.generation = GenerationSettings(
|
||||||
|
temperature=preset.temperature,
|
||||||
|
max_tokens=preset.max_tokens,
|
||||||
|
reasoning_effort=preset.reasoning_effort,
|
||||||
|
)
|
||||||
|
return new_provider
|
||||||
|
return provider
|
||||||
|
|
||||||
loop = AgentLoop(
|
loop = AgentLoop(
|
||||||
bus=MagicMock(),
|
bus=MagicMock(),
|
||||||
provider=provider,
|
provider=provider,
|
||||||
@ -19,6 +32,7 @@ def _make_loop(presets: dict | None = None) -> tuple[AgentLoop, "MyTool"]:
|
|||||||
model="test-model",
|
model="test-model",
|
||||||
context_window_tokens=65536,
|
context_window_tokens=65536,
|
||||||
model_presets=presets or {},
|
model_presets=presets or {},
|
||||||
|
provider_factory=_factory,
|
||||||
tools_config=ToolsConfig(my=MyToolConfig(allow_set=True)),
|
tools_config=ToolsConfig(my=MyToolConfig(allow_set=True)),
|
||||||
)
|
)
|
||||||
tool = loop.tools.get("my")
|
tool = loop.tools.get("my")
|
||||||
@ -36,7 +50,7 @@ async def test_set_model_preset_updates_all_fields() -> None:
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
loop, tool = _make_loop(presets)
|
loop, tool = _make_loop(presets)
|
||||||
result = await tool.execute(action="set", key="model_preset", value="gpt5")
|
await tool.execute(action="set", key="model_preset", value="gpt5")
|
||||||
|
|
||||||
assert loop.model == "gpt-5"
|
assert loop.model == "gpt-5"
|
||||||
assert loop.context_window_tokens == 128000
|
assert loop.context_window_tokens == 128000
|
||||||
@ -73,12 +87,3 @@ async def test_check_model_presets_shows_available() -> None:
|
|||||||
assert "ds" in result
|
assert "ds" in result
|
||||||
|
|
||||||
|
|
||||||
async def test_set_model_directly_clears_preset() -> None:
|
|
||||||
presets = {"gpt5": ModelPresetConfig(model="gpt-5", provider="openai")}
|
|
||||||
loop, tool = _make_loop(presets)
|
|
||||||
await tool.execute(action="set", key="model_preset", value="gpt5")
|
|
||||||
assert loop._active_preset == "gpt5"
|
|
||||||
|
|
||||||
await tool.execute(action="set", key="model", value="other-model")
|
|
||||||
assert loop._active_preset is None
|
|
||||||
assert loop.model == "other-model"
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -35,6 +35,7 @@ def _make_mock_loop(**overrides):
|
|||||||
loop._concurrency_gate = None
|
loop._concurrency_gate = None
|
||||||
loop._unified_session = False
|
loop._unified_session = False
|
||||||
loop._extra_hooks = []
|
loop._extra_hooks = []
|
||||||
|
loop.model_preset = None
|
||||||
|
|
||||||
# web_config mock — needed for check tests
|
# web_config mock — needed for check tests
|
||||||
loop.web_config = MagicMock()
|
loop.web_config = MagicMock()
|
||||||
@ -76,7 +77,7 @@ class TestInspectSummary:
|
|||||||
tool = _make_tool()
|
tool = _make_tool()
|
||||||
result = await tool.execute(action="check")
|
result = await tool.execute(action="check")
|
||||||
assert "max_iterations: 40" in result
|
assert "max_iterations: 40" in result
|
||||||
assert "context_window_tokens: 65536" in result
|
assert "model_preset" in result
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_inspect_includes_runtime_vars(self):
|
async def test_inspect_includes_runtime_vars(self):
|
||||||
@ -92,8 +93,7 @@ class TestInspectSummary:
|
|||||||
tool = _make_tool()
|
tool = _make_tool()
|
||||||
result = await tool.execute(action="check")
|
result = await tool.execute(action="check")
|
||||||
assert "max_iterations" in result
|
assert "max_iterations" in result
|
||||||
assert "context_window_tokens" in result
|
assert "model_preset" in result
|
||||||
assert "model" in result
|
|
||||||
assert "workspace" in result
|
assert "workspace" in result
|
||||||
assert "provider_retry_mode" in result
|
assert "provider_retry_mode" in result
|
||||||
assert "max_tool_result_chars" in result
|
assert "max_tool_result_chars" in result
|
||||||
@ -231,13 +231,13 @@ class TestModifyRestricted:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_modify_string_int_coerced(self):
|
async def test_modify_string_int_coerced(self):
|
||||||
tool = _make_tool()
|
tool = _make_tool()
|
||||||
result = await tool.execute(action="set", key="max_iterations", value="80")
|
await tool.execute(action="set", key="max_iterations", value="80")
|
||||||
assert tool._loop.max_iterations == 80
|
assert tool._loop.max_iterations == 80
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_modify_context_window_valid(self):
|
async def test_modify_context_window_valid(self):
|
||||||
tool = _make_tool()
|
tool = _make_tool()
|
||||||
result = await tool.execute(action="set", key="context_window_tokens", value=131072)
|
await tool.execute(action="set", key="context_window_tokens", value=131072)
|
||||||
assert tool._loop.context_window_tokens == 131072
|
assert tool._loop.context_window_tokens == 131072
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -337,13 +337,13 @@ class TestModifyFree:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_modify_allows_list(self):
|
async def test_modify_allows_list(self):
|
||||||
tool = _make_tool()
|
tool = _make_tool()
|
||||||
result = await tool.execute(action="set", key="items", value=[1, 2, 3])
|
await tool.execute(action="set", key="items", value=[1, 2, 3])
|
||||||
assert tool._loop._runtime_vars["items"] == [1, 2, 3]
|
assert tool._loop._runtime_vars["items"] == [1, 2, 3]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_modify_allows_dict(self):
|
async def test_modify_allows_dict(self):
|
||||||
tool = _make_tool()
|
tool = _make_tool()
|
||||||
result = await tool.execute(action="set", key="data", value={"a": 1})
|
await tool.execute(action="set", key="data", value={"a": 1})
|
||||||
assert tool._loop._runtime_vars["data"] == {"a": 1}
|
assert tool._loop._runtime_vars["data"] == {"a": 1}
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -392,6 +392,26 @@ class TestModifyFree:
|
|||||||
assert "Error" in result
|
assert "Error" in result
|
||||||
assert tool._loop.max_tool_result_chars == 16000
|
assert tool._loop.max_tool_result_chars == 16000
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_modify_model_clears_active_preset(self):
|
||||||
|
"""Directly modifying model must clear _active_preset so state stays consistent."""
|
||||||
|
tool = _make_tool()
|
||||||
|
tool._loop._active_preset = "gpt5"
|
||||||
|
result = await tool.execute(action="set", key="model", value="other-model")
|
||||||
|
assert "Set model" in result
|
||||||
|
assert tool._loop.model == "other-model"
|
||||||
|
assert tool._loop._active_preset is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_modify_context_window_tokens_clears_active_preset(self):
|
||||||
|
"""Directly modifying context_window_tokens must clear _active_preset."""
|
||||||
|
tool = _make_tool()
|
||||||
|
tool._loop._active_preset = "gpt5"
|
||||||
|
result = await tool.execute(action="set", key="context_window_tokens", value=32768)
|
||||||
|
assert "Set context_window_tokens" in result
|
||||||
|
assert tool._loop.context_window_tokens == 32768
|
||||||
|
assert tool._loop._active_preset is None
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# set — previously BLOCKED/READONLY now open
|
# set — previously BLOCKED/READONLY now open
|
||||||
@ -689,8 +709,8 @@ class TestSubagentHookStatus:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_after_iteration_updates_status(self):
|
async def test_after_iteration_updates_status(self):
|
||||||
"""after_iteration should copy iteration, tool_events, usage to status."""
|
"""after_iteration should copy iteration, tool_events, usage to status."""
|
||||||
from nanobot.agent.subagent import SubagentStatus, _SubagentHook
|
|
||||||
from nanobot.agent.hook import AgentHookContext
|
from nanobot.agent.hook import AgentHookContext
|
||||||
|
from nanobot.agent.subagent import SubagentStatus, _SubagentHook
|
||||||
|
|
||||||
status = SubagentStatus(
|
status = SubagentStatus(
|
||||||
task_id="test",
|
task_id="test",
|
||||||
@ -716,8 +736,8 @@ class TestSubagentHookStatus:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_after_iteration_with_error(self):
|
async def test_after_iteration_with_error(self):
|
||||||
"""after_iteration should set status.error when context has an error."""
|
"""after_iteration should set status.error when context has an error."""
|
||||||
from nanobot.agent.subagent import SubagentStatus, _SubagentHook
|
|
||||||
from nanobot.agent.hook import AgentHookContext
|
from nanobot.agent.hook import AgentHookContext
|
||||||
|
from nanobot.agent.subagent import SubagentStatus, _SubagentHook
|
||||||
|
|
||||||
status = SubagentStatus(
|
status = SubagentStatus(
|
||||||
task_id="test",
|
task_id="test",
|
||||||
@ -739,8 +759,8 @@ class TestSubagentHookStatus:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_after_iteration_no_status_is_noop(self):
|
async def test_after_iteration_no_status_is_noop(self):
|
||||||
"""after_iteration with no status should be a no-op."""
|
"""after_iteration with no status should be a no-op."""
|
||||||
from nanobot.agent.subagent import _SubagentHook
|
|
||||||
from nanobot.agent.hook import AgentHookContext
|
from nanobot.agent.hook import AgentHookContext
|
||||||
|
from nanobot.agent.subagent import _SubagentHook
|
||||||
|
|
||||||
hook = _SubagentHook("test")
|
hook = _SubagentHook("test")
|
||||||
context = AgentHookContext(iteration=1, messages=[])
|
context = AgentHookContext(iteration=1, messages=[])
|
||||||
@ -757,7 +777,6 @@ class TestCheckpointCallback:
|
|||||||
async def test_checkpoint_updates_phase_and_iteration(self):
|
async def test_checkpoint_updates_phase_and_iteration(self):
|
||||||
"""The _on_checkpoint callback should update status.phase and iteration."""
|
"""The _on_checkpoint callback should update status.phase and iteration."""
|
||||||
from nanobot.agent.subagent import SubagentStatus
|
from nanobot.agent.subagent import SubagentStatus
|
||||||
import asyncio
|
|
||||||
|
|
||||||
status = SubagentStatus(
|
status = SubagentStatus(
|
||||||
task_id="cp",
|
task_id="cp",
|
||||||
|
|||||||
@ -9,7 +9,7 @@ import pytest
|
|||||||
from typer.testing import CliRunner
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
from nanobot.cli.commands import _make_provider, app
|
from nanobot.cli.commands import app
|
||||||
from nanobot.config.schema import Config
|
from nanobot.config.schema import Config
|
||||||
from nanobot.cron.types import CronJob, CronPayload
|
from nanobot.cron.types import CronJob, CronPayload
|
||||||
from nanobot.providers.factory import ProviderSnapshot
|
from nanobot.providers.factory import ProviderSnapshot
|
||||||
@ -488,8 +488,8 @@ def test_openai_compat_provider_passes_model_through():
|
|||||||
|
|
||||||
|
|
||||||
def test_make_provider_uses_github_copilot_backend():
|
def test_make_provider_uses_github_copilot_backend():
|
||||||
from nanobot.cli.commands import _make_provider
|
|
||||||
from nanobot.config.schema import Config
|
from nanobot.config.schema import Config
|
||||||
|
from nanobot.providers.factory import build_provider_for_preset
|
||||||
|
|
||||||
config = Config.model_validate(
|
config = Config.model_validate(
|
||||||
{
|
{
|
||||||
@ -503,7 +503,7 @@ def test_make_provider_uses_github_copilot_backend():
|
|||||||
)
|
)
|
||||||
|
|
||||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
provider = _make_provider(config)
|
provider = build_provider_for_preset(config, config.resolve_preset())
|
||||||
|
|
||||||
assert provider.__class__.__name__ == "GitHubCopilotProvider"
|
assert provider.__class__.__name__ == "GitHubCopilotProvider"
|
||||||
|
|
||||||
@ -562,6 +562,8 @@ def test_openai_codex_strip_prefix_supports_hyphen_and_underscore():
|
|||||||
|
|
||||||
|
|
||||||
def test_make_provider_passes_extra_headers_to_custom_provider():
|
def test_make_provider_passes_extra_headers_to_custom_provider():
|
||||||
|
from nanobot.providers.factory import build_provider_for_preset
|
||||||
|
|
||||||
config = Config.model_validate(
|
config = Config.model_validate(
|
||||||
{
|
{
|
||||||
"agents": {"defaults": {"provider": "custom", "model": "gpt-4o-mini"}},
|
"agents": {"defaults": {"provider": "custom", "model": "gpt-4o-mini"}},
|
||||||
@ -579,7 +581,7 @@ def test_make_provider_passes_extra_headers_to_custom_provider():
|
|||||||
)
|
)
|
||||||
|
|
||||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_async_openai:
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_async_openai:
|
||||||
_make_provider(config)
|
build_provider_for_preset(config, config.resolve_preset())
|
||||||
|
|
||||||
kwargs = mock_async_openai.call_args.kwargs
|
kwargs = mock_async_openai.call_args.kwargs
|
||||||
assert kwargs["api_key"] == "test-key"
|
assert kwargs["api_key"] == "test-key"
|
||||||
@ -597,11 +599,11 @@ def mock_agent_runtime(tmp_path):
|
|||||||
with patch("nanobot.config.loader.load_config", return_value=config) as mock_load_config, \
|
with patch("nanobot.config.loader.load_config", return_value=config) as mock_load_config, \
|
||||||
patch("nanobot.config.loader.resolve_config_env_vars", side_effect=lambda c: c), \
|
patch("nanobot.config.loader.resolve_config_env_vars", side_effect=lambda c: c), \
|
||||||
patch("nanobot.cli.commands.sync_workspace_templates") as mock_sync_templates, \
|
patch("nanobot.cli.commands.sync_workspace_templates") as mock_sync_templates, \
|
||||||
patch("nanobot.cli.commands._make_provider", return_value=object()), \
|
patch("nanobot.providers.factory.build_provider_for_preset", return_value=MagicMock(generation=MagicMock(max_tokens=8192))), \
|
||||||
patch("nanobot.cli.commands._print_agent_response") as mock_print_response, \
|
patch("nanobot.cli.commands._print_agent_response") as mock_print_response, \
|
||||||
patch("nanobot.bus.queue.MessageBus"), \
|
patch("nanobot.bus.queue.MessageBus"), \
|
||||||
patch("nanobot.cron.service.CronService"), \
|
patch("nanobot.cron.service.CronService"), \
|
||||||
patch("nanobot.agent.loop.AgentLoop") as mock_agent_loop_cls:
|
patch("nanobot.cli.commands.AgentLoop") as mock_agent_loop_cls:
|
||||||
agent_loop = MagicMock()
|
agent_loop = MagicMock()
|
||||||
agent_loop.channels_config = None
|
agent_loop.channels_config = None
|
||||||
agent_loop.process_direct = AsyncMock(
|
agent_loop.process_direct = AsyncMock(
|
||||||
@ -609,6 +611,7 @@ def mock_agent_runtime(tmp_path):
|
|||||||
)
|
)
|
||||||
agent_loop.close_mcp = AsyncMock(return_value=None)
|
agent_loop.close_mcp = AsyncMock(return_value=None)
|
||||||
mock_agent_loop_cls.return_value = agent_loop
|
mock_agent_loop_cls.return_value = agent_loop
|
||||||
|
mock_agent_loop_cls.from_config.return_value = agent_loop
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"config": config,
|
"config": config,
|
||||||
@ -639,7 +642,7 @@ def test_agent_uses_default_config_when_no_workspace_or_config_flags(mock_agent_
|
|||||||
assert mock_agent_runtime["sync_templates"].call_args.args == (
|
assert mock_agent_runtime["sync_templates"].call_args.args == (
|
||||||
mock_agent_runtime["config"].workspace_path,
|
mock_agent_runtime["config"].workspace_path,
|
||||||
)
|
)
|
||||||
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == (
|
assert mock_agent_runtime["agent_loop_cls"].from_config.call_args.args[0].workspace_path == (
|
||||||
mock_agent_runtime["config"].workspace_path
|
mock_agent_runtime["config"].workspace_path
|
||||||
)
|
)
|
||||||
mock_agent_runtime["agent_loop"].process_direct.assert_awaited_once()
|
mock_agent_runtime["agent_loop"].process_direct.assert_awaited_once()
|
||||||
@ -672,7 +675,7 @@ def test_agent_config_sets_active_path(monkeypatch, tmp_path: Path) -> None:
|
|||||||
)
|
)
|
||||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
monkeypatch.setattr("nanobot.providers.factory.build_provider_for_preset", lambda *a, **k: MagicMock(generation=MagicMock(max_tokens=8192)))
|
||||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||||
monkeypatch.setattr("nanobot.cron.service.CronService", lambda _store: object())
|
monkeypatch.setattr("nanobot.cron.service.CronService", lambda _store: object())
|
||||||
|
|
||||||
@ -680,13 +683,17 @@ def test_agent_config_sets_active_path(monkeypatch, tmp_path: Path) -> None:
|
|||||||
def __init__(self, *args, **kwargs) -> None:
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
return cls(*args, **kwargs)
|
||||||
|
|
||||||
async def process_direct(self, *_args, **_kwargs):
|
async def process_direct(self, *_args, **_kwargs):
|
||||||
return OutboundMessage(channel="cli", chat_id="direct", content="ok")
|
return OutboundMessage(channel="cli", chat_id="direct", content="ok")
|
||||||
|
|
||||||
async def close_mcp(self) -> None:
|
async def close_mcp(self) -> None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop)
|
||||||
monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
|
monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
|
||||||
|
|
||||||
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])
|
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])
|
||||||
@ -707,7 +714,7 @@ def test_agent_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: Pa
|
|||||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
monkeypatch.setattr("nanobot.providers.factory.build_provider_for_preset", lambda *a, **k: MagicMock(generation=MagicMock(max_tokens=8192)))
|
||||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||||
|
|
||||||
class _FakeCron:
|
class _FakeCron:
|
||||||
@ -718,6 +725,10 @@ def test_agent_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: Pa
|
|||||||
def __init__(self, *args, **kwargs) -> None:
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
return cls(*args, **kwargs)
|
||||||
|
|
||||||
async def process_direct(self, *_args, **_kwargs):
|
async def process_direct(self, *_args, **_kwargs):
|
||||||
return OutboundMessage(channel="cli", chat_id="direct", content="ok")
|
return OutboundMessage(channel="cli", chat_id="direct", content="ok")
|
||||||
|
|
||||||
@ -725,7 +736,7 @@ def test_agent_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: Pa
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
|
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
|
||||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop)
|
||||||
monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
|
monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
|
||||||
|
|
||||||
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])
|
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])
|
||||||
@ -753,7 +764,7 @@ def test_agent_workspace_override_does_not_migrate_legacy_cron(
|
|||||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
monkeypatch.setattr("nanobot.providers.factory.build_provider_for_preset", lambda *a, **k: MagicMock(generation=MagicMock(max_tokens=8192)))
|
||||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||||
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
|
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
|
||||||
|
|
||||||
@ -765,6 +776,10 @@ def test_agent_workspace_override_does_not_migrate_legacy_cron(
|
|||||||
def __init__(self, *args, **kwargs) -> None:
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
return cls(*args, **kwargs)
|
||||||
|
|
||||||
async def process_direct(self, *_args, **_kwargs):
|
async def process_direct(self, *_args, **_kwargs):
|
||||||
return OutboundMessage(channel="cli", chat_id="direct", content="ok")
|
return OutboundMessage(channel="cli", chat_id="direct", content="ok")
|
||||||
|
|
||||||
@ -772,7 +787,7 @@ def test_agent_workspace_override_does_not_migrate_legacy_cron(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
|
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
|
||||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop)
|
||||||
monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
|
monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
|
||||||
|
|
||||||
result = runner.invoke(
|
result = runner.invoke(
|
||||||
@ -806,7 +821,7 @@ def test_agent_custom_config_workspace_does_not_migrate_legacy_cron(
|
|||||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
monkeypatch.setattr("nanobot.providers.factory.build_provider_for_preset", lambda *a, **k: MagicMock(generation=MagicMock(max_tokens=8192)))
|
||||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||||
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
|
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
|
||||||
|
|
||||||
@ -818,6 +833,10 @@ def test_agent_custom_config_workspace_does_not_migrate_legacy_cron(
|
|||||||
def __init__(self, *args, **kwargs) -> None:
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
return cls(*args, **kwargs)
|
||||||
|
|
||||||
async def process_direct(self, *_args, **_kwargs):
|
async def process_direct(self, *_args, **_kwargs):
|
||||||
return OutboundMessage(channel="cli", chat_id="direct", content="ok")
|
return OutboundMessage(channel="cli", chat_id="direct", content="ok")
|
||||||
|
|
||||||
@ -825,7 +844,7 @@ def test_agent_custom_config_workspace_does_not_migrate_legacy_cron(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
|
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
|
||||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None
|
"nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None
|
||||||
)
|
)
|
||||||
@ -846,7 +865,7 @@ def test_agent_overrides_workspace_path(mock_agent_runtime):
|
|||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert mock_agent_runtime["config"].agents.defaults.workspace == str(workspace_path)
|
assert mock_agent_runtime["config"].agents.defaults.workspace == str(workspace_path)
|
||||||
assert mock_agent_runtime["sync_templates"].call_args.args == (workspace_path,)
|
assert mock_agent_runtime["sync_templates"].call_args.args == (workspace_path,)
|
||||||
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
|
assert mock_agent_runtime["agent_loop_cls"].from_config.call_args.args[0].workspace_path == workspace_path
|
||||||
|
|
||||||
|
|
||||||
def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime, tmp_path: Path):
|
def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime, tmp_path: Path):
|
||||||
@ -863,7 +882,7 @@ def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime,
|
|||||||
assert mock_agent_runtime["load_config"].call_args.args == (config_path.resolve(),)
|
assert mock_agent_runtime["load_config"].call_args.args == (config_path.resolve(),)
|
||||||
assert mock_agent_runtime["config"].agents.defaults.workspace == str(workspace_path)
|
assert mock_agent_runtime["config"].agents.defaults.workspace == str(workspace_path)
|
||||||
assert mock_agent_runtime["sync_templates"].call_args.args == (workspace_path,)
|
assert mock_agent_runtime["sync_templates"].call_args.args == (workspace_path,)
|
||||||
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
|
assert mock_agent_runtime["agent_loop_cls"].from_config.call_args.args[0].workspace_path == workspace_path
|
||||||
|
|
||||||
|
|
||||||
def test_agent_hints_about_deprecated_memory_window(mock_agent_runtime, tmp_path):
|
def test_agent_hints_about_deprecated_memory_window(mock_agent_runtime, tmp_path):
|
||||||
@ -928,8 +947,8 @@ def _patch_cli_command_runtime(
|
|||||||
sync_templates or (lambda _path: None),
|
sync_templates or (lambda _path: None),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.cli.commands._make_provider",
|
"nanobot.providers.factory.build_provider_for_preset",
|
||||||
provider_factory,
|
lambda *_a, **_k: provider_factory(Config()),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.providers.factory.build_provider_snapshot",
|
"nanobot.providers.factory.build_provider_snapshot",
|
||||||
@ -962,6 +981,10 @@ def _patch_serve_runtime(monkeypatch, config: Config, seen: dict[str, object]) -
|
|||||||
def __init__(self, **kwargs) -> None:
|
def __init__(self, **kwargs) -> None:
|
||||||
seen["workspace"] = kwargs["workspace"]
|
seen["workspace"] = kwargs["workspace"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config, bus=None, **kwargs):
|
||||||
|
return cls(workspace=config.workspace_path, **kwargs)
|
||||||
|
|
||||||
async def _connect_mcp(self) -> None:
|
async def _connect_mcp(self) -> None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -985,7 +1008,7 @@ def _patch_serve_runtime(monkeypatch, config: Config, seen: dict[str, object]) -
|
|||||||
message_bus=lambda: object(),
|
message_bus=lambda: object(),
|
||||||
session_manager=lambda _workspace: object(),
|
session_manager=lambda _workspace: object(),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop)
|
||||||
monkeypatch.setattr("nanobot.api.server.create_app", _fake_create_app)
|
monkeypatch.setattr("nanobot.api.server.create_app", _fake_create_app)
|
||||||
monkeypatch.setattr("aiohttp.web.run_app", _fake_run_app)
|
monkeypatch.setattr("aiohttp.web.run_app", _fake_run_app)
|
||||||
|
|
||||||
@ -1077,7 +1100,7 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
|
|||||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: provider)
|
monkeypatch.setattr("nanobot.providers.factory.build_provider_for_preset", lambda *_a, **_k: provider)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.providers.factory.build_provider_snapshot",
|
"nanobot.providers.factory.build_provider_snapshot",
|
||||||
lambda _config: _test_provider_snapshot(provider, _config),
|
lambda _config: _test_provider_snapshot(provider, _config),
|
||||||
@ -1117,8 +1140,13 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
|
|||||||
class _FakeAgentLoop:
|
class _FakeAgentLoop:
|
||||||
def __init__(self, *args, **kwargs) -> None:
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
self.model = "test-model"
|
self.model = "test-model"
|
||||||
|
self.provider = object()
|
||||||
self.tools = {}
|
self.tools = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
return cls(*args, **kwargs)
|
||||||
|
|
||||||
async def process_direct(self, *_args, **_kwargs):
|
async def process_direct(self, *_args, **_kwargs):
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel="telegram",
|
channel="telegram",
|
||||||
@ -1152,7 +1180,7 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
|
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
|
||||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop)
|
||||||
monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _StopAfterCronSetup)
|
monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _StopAfterCronSetup)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.utils.evaluator.evaluate_response",
|
"nanobot.utils.evaluator.evaluate_response",
|
||||||
@ -1181,7 +1209,7 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
|
|||||||
|
|
||||||
assert response == "Time to stretch."
|
assert response == "Time to stretch."
|
||||||
assert seen["response"] == "Time to stretch."
|
assert seen["response"] == "Time to stretch."
|
||||||
assert seen["provider"] is provider
|
assert seen["provider"] is not None # provider resolved inside AgentLoop
|
||||||
assert seen["model"] == "test-model"
|
assert seen["model"] == "test-model"
|
||||||
assert seen["task_context"] == (
|
assert seen["task_context"] == (
|
||||||
"The scheduled time has arrived. Deliver this reminder to the user now, "
|
"The scheduled time has arrived. Deliver this reminder to the user now, "
|
||||||
@ -1228,7 +1256,7 @@ def test_gateway_cron_job_suppresses_intermediate_progress(
|
|||||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
monkeypatch.setattr("nanobot.providers.factory.build_provider_for_preset", lambda *a, **k: MagicMock(generation=MagicMock(max_tokens=8192)))
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.providers.factory.build_provider_snapshot",
|
"nanobot.providers.factory.build_provider_snapshot",
|
||||||
lambda _config: _test_provider_snapshot(object(), _config),
|
lambda _config: _test_provider_snapshot(object(), _config),
|
||||||
@ -1248,8 +1276,13 @@ def test_gateway_cron_job_suppresses_intermediate_progress(
|
|||||||
class _FakeAgentLoop:
|
class _FakeAgentLoop:
|
||||||
def __init__(self, *args, **kwargs) -> None:
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
self.model = "test-model"
|
self.model = "test-model"
|
||||||
|
self.provider = object()
|
||||||
self.tools = {}
|
self.tools = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
return cls(*args, **kwargs)
|
||||||
|
|
||||||
async def process_direct(self, *_args, on_progress=None, **_kwargs):
|
async def process_direct(self, *_args, on_progress=None, **_kwargs):
|
||||||
seen["on_progress"] = on_progress
|
seen["on_progress"] = on_progress
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
@ -1275,7 +1308,7 @@ def test_gateway_cron_job_suppresses_intermediate_progress(
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
|
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
|
||||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop)
|
||||||
monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _StopAfterCronSetup)
|
monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _StopAfterCronSetup)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.utils.evaluator.evaluate_response",
|
"nanobot.utils.evaluator.evaluate_response",
|
||||||
@ -1480,9 +1513,14 @@ def test_gateway_health_endpoint_binds_and_serves_expected_responses(
|
|||||||
class _FakeAgentLoop:
|
class _FakeAgentLoop:
|
||||||
def __init__(self, **_kwargs) -> None:
|
def __init__(self, **_kwargs) -> None:
|
||||||
self.model = "test-model"
|
self.model = "test-model"
|
||||||
|
self.provider = object()
|
||||||
self.dream = _FakeDream()
|
self.dream = _FakeDream()
|
||||||
self.sessions = _FakeSessionManager()
|
self.sessions = _FakeSessionManager()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
return cls(**kwargs)
|
||||||
|
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
await asyncio.Event().wait()
|
await asyncio.Event().wait()
|
||||||
|
|
||||||
@ -1571,7 +1609,7 @@ def test_gateway_health_endpoint_binds_and_serves_expected_responses(
|
|||||||
message_bus=lambda: object(),
|
message_bus=lambda: object(),
|
||||||
session_manager=lambda _workspace: object(),
|
session_manager=lambda _workspace: object(),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop)
|
||||||
monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _FakeChannelManager)
|
monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _FakeChannelManager)
|
||||||
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCronService)
|
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCronService)
|
||||||
monkeypatch.setattr("nanobot.heartbeat.service.HeartbeatService", _FakeHeartbeatService)
|
monkeypatch.setattr("nanobot.heartbeat.service.HeartbeatService", _FakeHeartbeatService)
|
||||||
|
|||||||
@ -99,11 +99,23 @@ def test_preset_not_found_raises_error() -> None:
|
|||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def test_fallback_presets_invalid_preset_raises_error() -> None:
|
||||||
|
import pytest
|
||||||
|
with pytest.raises(Exception, match="fallback_presets.*not found"):
|
||||||
|
Config.model_validate({
|
||||||
|
"model_presets": {
|
||||||
|
"valid": {"model": "gpt-4"},
|
||||||
|
},
|
||||||
|
"agents": {"defaults": {"fallback_presets": ["invalid_preset"]}},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_preset_without_preset_returns_defaults() -> None:
|
def test_resolve_preset_without_preset_returns_defaults() -> None:
|
||||||
"""Backward compat: no preset → resolve_preset returns individual field values."""
|
"""Backward compat: no explicit preset → resolve_preset returns the auto-created 'default' preset."""
|
||||||
cfg = Config.model_validate({
|
cfg = Config.model_validate({
|
||||||
"agents": {"defaults": {"model": "deepseek-chat"}},
|
"agents": {"defaults": {"model": "deepseek-chat"}},
|
||||||
})
|
})
|
||||||
|
assert cfg.agents.defaults.model_preset == "default"
|
||||||
r = cfg.resolve_preset()
|
r = cfg.resolve_preset()
|
||||||
assert r.model == "deepseek-chat"
|
assert r.model == "deepseek-chat"
|
||||||
assert r.max_tokens == 8192
|
assert r.max_tokens == 8192
|
||||||
@ -170,13 +182,14 @@ def test_preset_with_auto_provider_uses_keyword_matching() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_backward_compat_no_preset() -> None:
|
def test_backward_compat_no_preset() -> None:
|
||||||
"""Existing configs without model_presets work exactly as before."""
|
"""Existing configs without model_presets are automatically promoted to the 'default' preset."""
|
||||||
cfg = Config.model_validate({
|
cfg = Config.model_validate({
|
||||||
"providers": {"anthropic": {"api_key": "test-key"}},
|
"providers": {"anthropic": {"api_key": "test-key"}},
|
||||||
"agents": {"defaults": {"model": "anthropic/claude-opus-4-5"}},
|
"agents": {"defaults": {"model": "anthropic/claude-opus-4-5"}},
|
||||||
})
|
})
|
||||||
assert cfg.resolve_preset().model == "anthropic/claude-opus-4-5"
|
assert cfg.resolve_preset().model == "anthropic/claude-opus-4-5"
|
||||||
assert cfg.agents.defaults.model_preset is None
|
assert cfg.agents.defaults.model_preset == "default"
|
||||||
|
assert "default" in cfg.model_presets
|
||||||
assert cfg.get_provider_name() == "anthropic"
|
assert cfg.get_provider_name() == "anthropic"
|
||||||
|
|
||||||
|
|
||||||
@ -204,3 +217,48 @@ def test_resolve_preset_overrides_all_model_fields() -> None:
|
|||||||
def test_empty_model_presets_dict_is_harmless() -> None:
|
def test_empty_model_presets_dict_is_harmless() -> None:
|
||||||
cfg = Config.model_validate({"model_presets": {}})
|
cfg = Config.model_validate({"model_presets": {}})
|
||||||
assert cfg.resolve_preset().model == "anthropic/claude-opus-4-5"
|
assert cfg.resolve_preset().model == "anthropic/claude-opus-4-5"
|
||||||
|
|
||||||
|
|
||||||
|
def test_factory_uses_preset_provider_not_defaults() -> None:
|
||||||
|
"""When creating a provider for a non-active preset, the preset's own provider must be used."""
|
||||||
|
from nanobot.providers.factory import make_provider_factory
|
||||||
|
|
||||||
|
cfg = Config.model_validate({
|
||||||
|
"model_presets": {
|
||||||
|
"kimi": {"model": "kimi-k2.6", "provider": "moonshot"},
|
||||||
|
"zhipu": {"model": "glm-5.1", "provider": "zhipu"},
|
||||||
|
},
|
||||||
|
"providers": {
|
||||||
|
"moonshot": {"api_key": "moonshot-key", "api_base": "https://api.moonshot.ai/v1"},
|
||||||
|
"zhipu": {"api_key": "zhipu-key", "api_base": "https://open.bigmodel.cn/api/paas/v4"},
|
||||||
|
},
|
||||||
|
"agents": {"defaults": {"model_preset": "kimi"}},
|
||||||
|
})
|
||||||
|
|
||||||
|
factory = make_provider_factory(cfg)
|
||||||
|
zhipu_provider = factory("zhipu")
|
||||||
|
|
||||||
|
assert zhipu_provider.api_base == "https://open.bigmodel.cn/api/paas/v4"
|
||||||
|
assert getattr(zhipu_provider, "api_key", None) == "zhipu-key"
|
||||||
|
|
||||||
|
# Also verify the active preset provider is still correct
|
||||||
|
moonshot_provider = factory("kimi")
|
||||||
|
assert moonshot_provider.api_base == "https://api.moonshot.ai/v1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_factory_rejects_unknown_preset_name() -> None:
|
||||||
|
"""Factory must raise ValueError when asked for a preset not in model_presets."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.providers.factory import make_provider_factory
|
||||||
|
|
||||||
|
cfg = Config.model_validate({
|
||||||
|
"model_presets": {
|
||||||
|
"known": {"model": "gpt-4", "provider": "openai"},
|
||||||
|
},
|
||||||
|
"providers": {"openai": {"api_key": "test-key"}},
|
||||||
|
})
|
||||||
|
|
||||||
|
factory = make_provider_factory(cfg)
|
||||||
|
with pytest.raises(ValueError, match="Preset 'unknown' not found"):
|
||||||
|
factory("unknown")
|
||||||
|
|||||||
@ -39,7 +39,7 @@ def test_from_config_default_path():
|
|||||||
from nanobot.config.schema import Config
|
from nanobot.config.schema import Config
|
||||||
|
|
||||||
with patch("nanobot.config.loader.load_config") as mock_load, \
|
with patch("nanobot.config.loader.load_config") as mock_load, \
|
||||||
patch("nanobot.nanobot._make_provider") as mock_prov:
|
patch("nanobot.providers.factory.build_provider_for_preset") as mock_prov:
|
||||||
mock_load.return_value = Config()
|
mock_load.return_value = Config()
|
||||||
mock_prov.return_value = MagicMock()
|
mock_prov.return_value = MagicMock()
|
||||||
mock_prov.return_value.get_default_model.return_value = "test"
|
mock_prov.return_value.get_default_model.return_value = "test"
|
||||||
@ -127,7 +127,7 @@ def test_workspace_override(tmp_path):
|
|||||||
|
|
||||||
def test_sdk_make_provider_uses_github_copilot_backend():
|
def test_sdk_make_provider_uses_github_copilot_backend():
|
||||||
from nanobot.config.schema import Config
|
from nanobot.config.schema import Config
|
||||||
from nanobot.nanobot import _make_provider
|
from nanobot.providers.factory import make_provider
|
||||||
|
|
||||||
config = Config.model_validate(
|
config = Config.model_validate(
|
||||||
{
|
{
|
||||||
@ -141,7 +141,7 @@ def test_sdk_make_provider_uses_github_copilot_backend():
|
|||||||
)
|
)
|
||||||
|
|
||||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
provider = _make_provider(config)
|
provider = make_provider(config)
|
||||||
|
|
||||||
assert provider.__class__.__name__ == "GitHubCopilotProvider"
|
assert provider.__class__.__name__ == "GitHubCopilotProvider"
|
||||||
|
|
||||||
|
|||||||
467
tests/test_preset_failover_smoke.py
Normal file
467
tests/test_preset_failover_smoke.py
Normal file
@ -0,0 +1,467 @@
|
|||||||
|
"""End-to-end smoke tests for model presets + failover.
|
||||||
|
|
||||||
|
Uses a local aiohttp fake OpenAI server so requests are real HTTP,
|
||||||
|
not mocked at the provider level.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.nanobot import Nanobot
|
||||||
|
from nanobot.providers.base import GenerationSettings, LLMProvider
|
||||||
|
from nanobot.providers.failover import ModelRouter
|
||||||
|
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||||
|
|
||||||
|
try:
|
||||||
|
from aiohttp import web
|
||||||
|
from aiohttp.test_utils import TestServer
|
||||||
|
|
||||||
|
HAS_AIOHTTP = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_AIOHTTP = False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _disable_proxy_for_localhost_tests(monkeypatch):
|
||||||
|
"""Prevent httpx from routing localhost requests through a system proxy."""
|
||||||
|
monkeypatch.delenv("ALL_PROXY", raising=False)
|
||||||
|
monkeypatch.delenv("HTTP_PROXY", raising=False)
|
||||||
|
monkeypatch.delenv("HTTPS_PROXY", raising=False)
|
||||||
|
monkeypatch.setenv("NO_PROXY", "127.0.0.1,localhost")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers (mock-level preset tests)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _write_config(tmp_path: Path, **overrides) -> Path:
|
||||||
|
data = {
|
||||||
|
"providers": {
|
||||||
|
"openrouter": {"apiKey": "sk-test-key"},
|
||||||
|
"openai": {"apiKey": "sk-openai-test"},
|
||||||
|
},
|
||||||
|
"agents": {"defaults": {"model": "openai/gpt-4.1"}},
|
||||||
|
"tools": {"my": {"allowSet": True}},
|
||||||
|
}
|
||||||
|
data.update(overrides)
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
config_path.write_text(json.dumps(data))
|
||||||
|
return config_path
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 1. Model Preset Mock Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_preset_loaded_at_startup(tmp_path: Path) -> None:
|
||||||
|
config_path = _write_config(
|
||||||
|
tmp_path,
|
||||||
|
model_presets={
|
||||||
|
"fast": {
|
||||||
|
"model": "gpt-4.1-mini",
|
||||||
|
"provider": "openai",
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"context_window_tokens": 128000,
|
||||||
|
"temperature": 0.3,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
agents={"defaults": {"model_preset": "fast", "model": "ignored-model"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
bot = Nanobot.from_config(config_path, workspace=tmp_path)
|
||||||
|
|
||||||
|
loop = bot._loop
|
||||||
|
assert loop.model == "gpt-4.1-mini"
|
||||||
|
assert loop.context_window_tokens == 128000
|
||||||
|
assert loop.provider.generation.temperature == 0.3
|
||||||
|
assert loop.provider.generation.max_tokens == 4096
|
||||||
|
assert loop.model_preset == "fast"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_preset_runtime_switch_updates_all_fields(tmp_path: Path) -> None:
|
||||||
|
config_path = _write_config(
|
||||||
|
tmp_path,
|
||||||
|
model_presets={
|
||||||
|
"cheap": {
|
||||||
|
"model": "gpt-4.1-mini",
|
||||||
|
"provider": "openai",
|
||||||
|
"max_tokens": 2048,
|
||||||
|
"context_window_tokens": 64000,
|
||||||
|
"temperature": 0.5,
|
||||||
|
},
|
||||||
|
"power": {
|
||||||
|
"model": "gpt-4.1",
|
||||||
|
"provider": "openai",
|
||||||
|
"max_tokens": 8192,
|
||||||
|
"context_window_tokens": 256000,
|
||||||
|
"temperature": 0.1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
agents={"defaults": {"model_preset": "cheap"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
bot = Nanobot.from_config(config_path, workspace=tmp_path)
|
||||||
|
|
||||||
|
loop = bot._loop
|
||||||
|
assert loop.model == "gpt-4.1-mini"
|
||||||
|
|
||||||
|
my_tool = loop.tools.get("my")
|
||||||
|
result = await my_tool.execute(action="set", key="model_preset", value="power")
|
||||||
|
assert "Error" not in result
|
||||||
|
|
||||||
|
assert loop.model == "gpt-4.1"
|
||||||
|
assert loop.context_window_tokens == 256000
|
||||||
|
assert loop.provider.generation.temperature == 0.1
|
||||||
|
assert loop.provider.generation.max_tokens == 8192
|
||||||
|
assert loop.model_preset == "power"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_preset_switch_unknown_returns_error(tmp_path: Path) -> None:
|
||||||
|
config_path = _write_config(
|
||||||
|
tmp_path,
|
||||||
|
model_presets={"a": {"model": "model-a"}},
|
||||||
|
agents={"defaults": {"model_preset": "a"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
bot = Nanobot.from_config(config_path, workspace=tmp_path)
|
||||||
|
|
||||||
|
loop = bot._loop
|
||||||
|
original_model = loop.model
|
||||||
|
|
||||||
|
my_tool = loop.tools.get("my")
|
||||||
|
result = await my_tool.execute(action="set", key="model_preset", value="nonexistent")
|
||||||
|
assert "not found" in result.lower()
|
||||||
|
|
||||||
|
assert loop.model == original_model
|
||||||
|
assert loop.model_preset == "a"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_preset_model_with_fallback_presets_in_config(tmp_path: Path) -> None:
|
||||||
|
config_path = _write_config(
|
||||||
|
tmp_path,
|
||||||
|
model_presets={
|
||||||
|
"prod": {
|
||||||
|
"model": "gpt-4.1",
|
||||||
|
"provider": "openai",
|
||||||
|
"max_tokens": 8192,
|
||||||
|
"temperature": 0.1,
|
||||||
|
},
|
||||||
|
"fallback": {
|
||||||
|
"model": "gpt-4.1-mini",
|
||||||
|
"provider": "openai",
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"temperature": 0.2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
agents={
|
||||||
|
"defaults": {
|
||||||
|
"model_preset": "prod",
|
||||||
|
"fallback_presets": ["fallback"],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
bot = Nanobot.from_config(config_path, workspace=tmp_path)
|
||||||
|
|
||||||
|
loop = bot._loop
|
||||||
|
assert loop.model == "gpt-4.1"
|
||||||
|
assert isinstance(loop.provider, ModelRouter)
|
||||||
|
assert loop.provider.fallback_presets == ["fallback"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fallback_presets_wired_to_all_subsystems(tmp_path: Path) -> None:
|
||||||
|
"""When fallback_presets is configured, every subsystem that calls the LLM
|
||||||
|
must use the same ModelRouter instance, not the raw primary provider."""
|
||||||
|
config_path = _write_config(
|
||||||
|
tmp_path,
|
||||||
|
model_presets={
|
||||||
|
"prod": {
|
||||||
|
"model": "gpt-4.1",
|
||||||
|
"provider": "openai",
|
||||||
|
"max_tokens": 8192,
|
||||||
|
"temperature": 0.1,
|
||||||
|
},
|
||||||
|
"fallback": {
|
||||||
|
"model": "gpt-4.1-mini",
|
||||||
|
"provider": "openai",
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"temperature": 0.2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
agents={
|
||||||
|
"defaults": {
|
||||||
|
"model_preset": "prod",
|
||||||
|
"fallback_presets": ["fallback"],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
bot = Nanobot.from_config(config_path, workspace=tmp_path)
|
||||||
|
|
||||||
|
loop = bot._loop
|
||||||
|
router = loop.provider
|
||||||
|
assert isinstance(router, ModelRouter)
|
||||||
|
|
||||||
|
# Every LLM-consuming subsystem must share the same router
|
||||||
|
assert loop.runner.provider is router, "AgentRunner must use ModelRouter"
|
||||||
|
assert loop.subagents.provider is router, "SubagentManager must use ModelRouter"
|
||||||
|
assert loop.consolidator.provider is router, "Consolidator must use ModelRouter"
|
||||||
|
assert loop.dream.provider is router, "Dream must use ModelRouter"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 2. Real HTTP Smoke Tests (aiohttp fake OpenAI server)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_preset_generation_params_reach_http_request() -> None:
|
||||||
|
"""Provider.generation settings must appear in the actual HTTP request body."""
|
||||||
|
requests_log: list[dict] = []
|
||||||
|
|
||||||
|
async def handler(request: web.Request) -> web.Response:
|
||||||
|
body = await request.json()
|
||||||
|
requests_log.append(body)
|
||||||
|
return web.json_response({
|
||||||
|
"id": "chatcmpl-test",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"model": body.get("model"),
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"message": {"role": "assistant", "content": "pong"},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}],
|
||||||
|
})
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_post("/chat/completions", handler)
|
||||||
|
server = TestServer(app)
|
||||||
|
await server.start_server()
|
||||||
|
try:
|
||||||
|
base_url = str(server.make_url("/"))
|
||||||
|
provider = OpenAICompatProvider(
|
||||||
|
api_key="test",
|
||||||
|
api_base=base_url,
|
||||||
|
default_model="test-model",
|
||||||
|
)
|
||||||
|
provider.generation = GenerationSettings(temperature=0.42, max_tokens=1024)
|
||||||
|
|
||||||
|
with patch.object(LLMProvider, "_CHAT_RETRY_DELAYS", (0,)):
|
||||||
|
response = await provider.chat_with_retry(
|
||||||
|
messages=[{"role": "user", "content": "ping"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.finish_reason != "error"
|
||||||
|
assert len(requests_log) >= 1
|
||||||
|
req = requests_log[0]
|
||||||
|
assert req["model"] == "test-model"
|
||||||
|
assert req["temperature"] == 0.42
|
||||||
|
assert req["max_tokens"] == 1024
|
||||||
|
finally:
|
||||||
|
await server.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_failover_sends_second_request_to_fallback_model() -> None:
|
||||||
|
"""Primary returns 503; after retry exhaustion ModelRouter hits fallback."""
|
||||||
|
requests_log: list[dict] = []
|
||||||
|
|
||||||
|
async def handler(request: web.Request) -> web.Response:
|
||||||
|
body = await request.json()
|
||||||
|
requests_log.append(body)
|
||||||
|
model = body.get("model")
|
||||||
|
|
||||||
|
if model == "primary-model":
|
||||||
|
return web.Response(
|
||||||
|
status=503,
|
||||||
|
body=json.dumps({"error": {"message": "overloaded", "type": "server_error"}}),
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
"id": "chatcmpl-test",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"model": model,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"message": {"role": "assistant", "content": "fallback-ok"},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}],
|
||||||
|
})
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_post("/chat/completions", handler)
|
||||||
|
server = TestServer(app)
|
||||||
|
await server.start_server()
|
||||||
|
try:
|
||||||
|
base_url = str(server.make_url("/"))
|
||||||
|
primary = OpenAICompatProvider(
|
||||||
|
api_key="test", api_base=base_url, default_model="primary-model"
|
||||||
|
)
|
||||||
|
fallback = OpenAICompatProvider(
|
||||||
|
api_key="test", api_base=base_url, default_model="fallback-model"
|
||||||
|
)
|
||||||
|
|
||||||
|
factory = MagicMock(return_value=fallback)
|
||||||
|
|
||||||
|
router = ModelRouter(
|
||||||
|
primary_provider=primary,
|
||||||
|
primary_model="primary-model",
|
||||||
|
fallback_presets=["fallback-model"],
|
||||||
|
provider_factory=factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(LLMProvider, "_CHAT_RETRY_DELAYS", (0,)):
|
||||||
|
response = await router.chat_with_retry(
|
||||||
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.finish_reason != "error"
|
||||||
|
assert response.content == "fallback-ok"
|
||||||
|
|
||||||
|
models_requested = [r["model"] for r in requests_log]
|
||||||
|
assert "primary-model" in models_requested
|
||||||
|
assert "fallback-model" in models_requested
|
||||||
|
factory.assert_called_once_with("fallback-model")
|
||||||
|
finally:
|
||||||
|
await server.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_failover_on_quota_429() -> None:
|
||||||
|
"""Quota 429 on one provider may still work on a different provider."""
|
||||||
|
requests_log: list[dict] = []
|
||||||
|
|
||||||
|
async def handler(request: web.Request) -> web.Response:
|
||||||
|
body = await request.json()
|
||||||
|
requests_log.append(body)
|
||||||
|
return web.Response(
|
||||||
|
status=429,
|
||||||
|
body=json.dumps({
|
||||||
|
"error": {
|
||||||
|
"message": "insufficient quota",
|
||||||
|
"type": "insufficient_quota",
|
||||||
|
"code": "insufficient_quota",
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_post("/chat/completions", handler)
|
||||||
|
server = TestServer(app)
|
||||||
|
await server.start_server()
|
||||||
|
try:
|
||||||
|
base_url = str(server.make_url("/"))
|
||||||
|
primary = OpenAICompatProvider(
|
||||||
|
api_key="test", api_base=base_url, default_model="primary-model"
|
||||||
|
)
|
||||||
|
fallback = OpenAICompatProvider(
|
||||||
|
api_key="test", api_base=base_url, default_model="fallback-model"
|
||||||
|
)
|
||||||
|
|
||||||
|
factory = MagicMock(return_value=fallback)
|
||||||
|
|
||||||
|
router = ModelRouter(
|
||||||
|
primary_provider=primary,
|
||||||
|
primary_model="primary-model",
|
||||||
|
fallback_presets=["fallback-model"],
|
||||||
|
provider_factory=factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(LLMProvider, "_CHAT_RETRY_DELAYS", (0,)):
|
||||||
|
response = await router.chat_with_retry(
|
||||||
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Quota 429 SHOULD trigger failover — another provider may still work.
|
||||||
|
factory.assert_called_once_with("fallback-model")
|
||||||
|
assert response.finish_reason == "error"
|
||||||
|
# Both primary and fallback should have been requested.
|
||||||
|
assert len(requests_log) == 2
|
||||||
|
finally:
|
||||||
|
await server.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_model_router_failover_integration() -> None:
|
||||||
|
"""ModelRouter -> real HTTP failover chain (primary 503, fallback 200)."""
|
||||||
|
requests_log: list[dict] = []
|
||||||
|
|
||||||
|
async def handler(request: web.Request) -> web.Response:
|
||||||
|
body = await request.json()
|
||||||
|
requests_log.append(body)
|
||||||
|
model = body.get("model")
|
||||||
|
|
||||||
|
if model == "primary-model":
|
||||||
|
return web.Response(
|
||||||
|
status=503,
|
||||||
|
body=json.dumps({"error": {"message": "overloaded", "type": "server_error"}}),
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
"id": "chatcmpl-test",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"model": model,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"message": {"role": "assistant", "content": "fallback-ok"},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}],
|
||||||
|
})
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_post("/chat/completions", handler)
|
||||||
|
server = TestServer(app)
|
||||||
|
await server.start_server()
|
||||||
|
try:
|
||||||
|
base_url = str(server.make_url("/"))
|
||||||
|
primary = OpenAICompatProvider(
|
||||||
|
api_key="test", api_base=base_url, default_model="primary-model"
|
||||||
|
)
|
||||||
|
fallback = OpenAICompatProvider(
|
||||||
|
api_key="test", api_base=base_url, default_model="fallback-model"
|
||||||
|
)
|
||||||
|
|
||||||
|
factory = MagicMock(return_value=fallback)
|
||||||
|
|
||||||
|
router = ModelRouter(
|
||||||
|
primary_provider=primary,
|
||||||
|
primary_model="primary-model",
|
||||||
|
fallback_presets=["fallback-model"],
|
||||||
|
provider_factory=factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(LLMProvider, "_CHAT_RETRY_DELAYS", (0,)):
|
||||||
|
response = await router.chat_with_retry(
|
||||||
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.finish_reason != "error"
|
||||||
|
assert response.content == "fallback-ok"
|
||||||
|
models_requested = [r["model"] for r in requests_log]
|
||||||
|
assert "primary-model" in models_requested
|
||||||
|
assert "fallback-model" in models_requested
|
||||||
|
factory.assert_called_once_with("fallback-model")
|
||||||
|
finally:
|
||||||
|
await server.close()
|
||||||
Loading…
x
Reference in New Issue
Block a user