mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
Merge PR #3756: feat(runner): model failover with fallback_models
feat(runner): model failover with fallback_models
This commit is contained in:
commit
921fe259f4
@ -672,7 +672,8 @@ Existing configs do not need to change. If you do not set `modelPresets` or `age
|
||||
"maxTokens": 8192,
|
||||
"contextWindowTokens": 128000,
|
||||
"temperature": 0.1,
|
||||
"modelPreset": null
|
||||
"modelPreset": "fast",
|
||||
"fallbackModels": ["deep"]
|
||||
}
|
||||
},
|
||||
"modelPresets": {
|
||||
@ -708,6 +709,40 @@ Existing configs do not need to change. If you do not set `modelPresets` or `age
|
||||
|
||||
`default` is reserved and always means the implicit preset built from `agents.defaults.*`; do not define `modelPresets.default`. Use `/model default` to switch back to `agents.defaults.*`.
|
||||
|
||||
### Model Fallbacks
|
||||
|
||||
`agents.defaults.fallbackModels` defines an ordered failover chain for the active model configuration. The primary model is still selected by `agents.defaults.modelPreset` (or the implicit default config when no preset is active).
|
||||
|
||||
Each fallback candidate can be either:
|
||||
|
||||
- A preset name from `modelPresets`, such as `"deep"`. The preset's full model, provider, generation, and context-window config is used.
|
||||
- An inline fallback object with at least `provider` and `model`. Optional `maxTokens`, `contextWindowTokens`, and `temperature` fields inherit from the active primary config when omitted. `reasoningEffort` does not inherit; omit it to leave reasoning off for that fallback, or set it explicitly for models that support reasoning.
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"modelPreset": "fast",
|
||||
"fallbackModels": [
|
||||
"deep",
|
||||
{
|
||||
"provider": "deepseek",
|
||||
"model": "deepseek-v4-pro",
|
||||
"maxTokens": 4096,
|
||||
"contextWindowTokens": 262144
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
String entries are preset names, not raw model names. If you want to use a model that is not already a preset, use the inline object form.
|
||||
|
||||
Failover only runs when the primary provider returns a retryable model/provider error before any answer text has been streamed. Typical fallback cases include timeouts, connection errors, 5xx server errors, 429 rate limits, overloads, and quota/balance exhaustion. It does not run for malformed requests, authentication/permission errors, content filtering/refusals, or context-length/message-format errors.
|
||||
|
||||
If fallback candidates use smaller `contextWindowTokens` values, nanobot builds context using the smallest window in the active chain so every candidate can receive the same prompt.
|
||||
|
||||
Set `agents.defaults.modelPreset` to start with a named preset:
|
||||
|
||||
```json
|
||||
|
||||
@ -74,6 +74,20 @@ class DreamConfig(Base):
|
||||
return f"every {hours}h"
|
||||
|
||||
|
||||
class InlineFallbackConfig(Base):
|
||||
"""One inline fallback model configuration."""
|
||||
|
||||
model: str
|
||||
provider: str
|
||||
max_tokens: int | None = None
|
||||
context_window_tokens: int | None = None
|
||||
temperature: float | None = None
|
||||
reasoning_effort: str | None = None
|
||||
|
||||
|
||||
FallbackCandidate = str | InlineFallbackConfig
|
||||
|
||||
|
||||
class ModelPresetConfig(Base):
|
||||
"""A named set of model + generation parameters for quick switching."""
|
||||
|
||||
@ -106,6 +120,7 @@ class AgentDefaults(Base):
|
||||
context_window_tokens: int = 65_536
|
||||
context_block_limit: int | None = None
|
||||
temperature: float = 0.1
|
||||
fallback_models: list[FallbackCandidate] = Field(default_factory=list)
|
||||
max_tool_iterations: int = 200
|
||||
max_concurrent_subagents: int = Field(default=1, ge=1)
|
||||
max_tool_result_chars: int = 16_000
|
||||
@ -287,6 +302,9 @@ class Config(BaseSettings):
|
||||
name = self.agents.defaults.model_preset
|
||||
if name and name != "default" and name not in self.model_presets:
|
||||
raise ValueError(f"model_preset {name!r} not found in model_presets")
|
||||
for fallback in self.agents.defaults.fallback_models:
|
||||
if isinstance(fallback, str) and fallback not in self.model_presets:
|
||||
raise ValueError(f"fallback_models entry {fallback!r} not found in model_presets")
|
||||
return self
|
||||
|
||||
def resolve_default_preset(self) -> ModelPresetConfig:
|
||||
|
||||
@ -5,8 +5,9 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from nanobot.config.schema import Config, ModelPresetConfig
|
||||
from nanobot.config.schema import Config, InlineFallbackConfig, ModelPresetConfig
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.providers.fallback_provider import FallbackProvider
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
|
||||
@ -27,15 +28,16 @@ def _resolve_model_preset(
|
||||
return preset if preset is not None else config.resolve_preset(preset_name)
|
||||
|
||||
|
||||
def make_provider(
|
||||
def _make_provider_core(
|
||||
config: Config,
|
||||
*,
|
||||
preset_name: str | None = None,
|
||||
preset: ModelPresetConfig | None = None,
|
||||
model: str | None = None,
|
||||
) -> LLMProvider:
|
||||
"""Create the LLM provider implied by config."""
|
||||
"""Create a plain LLM provider without failover wrapping."""
|
||||
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
|
||||
model = resolved.model
|
||||
model = model or resolved.model
|
||||
provider_name = config.get_provider_name(model, preset=resolved)
|
||||
p = config.get_provider(model, preset=resolved)
|
||||
spec = find_by_name(provider_name) if provider_name else None
|
||||
@ -102,15 +104,93 @@ def make_provider(
|
||||
return provider
|
||||
|
||||
|
||||
def _inline_fallback_preset(
|
||||
primary: ModelPresetConfig,
|
||||
fallback: InlineFallbackConfig,
|
||||
) -> ModelPresetConfig:
|
||||
return ModelPresetConfig(
|
||||
model=fallback.model,
|
||||
provider=fallback.provider,
|
||||
max_tokens=fallback.max_tokens if fallback.max_tokens is not None else primary.max_tokens,
|
||||
context_window_tokens=(
|
||||
fallback.context_window_tokens
|
||||
if fallback.context_window_tokens is not None
|
||||
else primary.context_window_tokens
|
||||
),
|
||||
temperature=(
|
||||
fallback.temperature if fallback.temperature is not None else primary.temperature
|
||||
),
|
||||
reasoning_effort=fallback.reasoning_effort,
|
||||
)
|
||||
|
||||
|
||||
def _resolve_fallback_presets(config: Config, primary: ModelPresetConfig) -> list[ModelPresetConfig]:
|
||||
presets: list[ModelPresetConfig] = []
|
||||
for fallback in config.agents.defaults.fallback_models:
|
||||
if isinstance(fallback, str):
|
||||
presets.append(config.model_presets[fallback])
|
||||
else:
|
||||
presets.append(_inline_fallback_preset(primary, fallback))
|
||||
return presets
|
||||
|
||||
|
||||
def make_provider(
|
||||
config: Config,
|
||||
*,
|
||||
preset_name: str | None = None,
|
||||
preset: ModelPresetConfig | None = None,
|
||||
model: str | None = None,
|
||||
) -> LLMProvider:
|
||||
"""Create the LLM provider implied by config.
|
||||
|
||||
When *model* is given, it overrides the resolved/preset model — used by
|
||||
the failover path to create providers for fallback models.
|
||||
"""
|
||||
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
|
||||
provider = _make_provider_core(config, preset_name=preset_name, preset=preset, model=model)
|
||||
fallback_presets = _resolve_fallback_presets(config, resolved)
|
||||
|
||||
if fallback_presets:
|
||||
provider = FallbackProvider(
|
||||
primary=provider,
|
||||
fallback_presets=fallback_presets,
|
||||
provider_factory=lambda fb: _make_provider_core(
|
||||
config, preset_name=preset_name, preset=fb
|
||||
),
|
||||
)
|
||||
|
||||
return provider
|
||||
|
||||
|
||||
def provider_signature(
|
||||
config: Config,
|
||||
*,
|
||||
preset_name: str | None = None,
|
||||
preset: ModelPresetConfig | None = None,
|
||||
) -> tuple[object, ...]:
|
||||
"""Return the config fields that affect the primary LLM provider."""
|
||||
"""Return the config fields that affect the active provider chain."""
|
||||
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
|
||||
p = config.get_provider(resolved.model, preset=resolved)
|
||||
fallback_presets = _resolve_fallback_presets(config, resolved)
|
||||
|
||||
def _fallback_signature(fallback: ModelPresetConfig) -> tuple[object, ...]:
|
||||
fp = config.get_provider(fallback.model, preset=fallback)
|
||||
return (
|
||||
fallback.model,
|
||||
fallback.provider,
|
||||
config.get_provider_name(fallback.model, preset=fallback),
|
||||
config.get_api_key(fallback.model, preset=fallback),
|
||||
config.get_api_base(fallback.model, preset=fallback),
|
||||
fp.extra_headers if fp else None,
|
||||
fp.extra_body if fp else None,
|
||||
getattr(fp, "region", None) if fp else None,
|
||||
getattr(fp, "profile", None) if fp else None,
|
||||
fallback.max_tokens,
|
||||
fallback.temperature,
|
||||
fallback.reasoning_effort,
|
||||
fallback.context_window_tokens,
|
||||
)
|
||||
|
||||
return (
|
||||
resolved.model,
|
||||
resolved.provider,
|
||||
@ -125,6 +205,7 @@ def provider_signature(
|
||||
resolved.temperature,
|
||||
resolved.reasoning_effort,
|
||||
resolved.context_window_tokens,
|
||||
tuple(_fallback_signature(fallback) for fallback in fallback_presets),
|
||||
)
|
||||
|
||||
|
||||
@ -135,10 +216,14 @@ def build_provider_snapshot(
|
||||
preset: ModelPresetConfig | None = None,
|
||||
) -> ProviderSnapshot:
|
||||
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
|
||||
fallback_windows = [
|
||||
fallback.context_window_tokens
|
||||
for fallback in _resolve_fallback_presets(config, resolved)
|
||||
]
|
||||
return ProviderSnapshot(
|
||||
provider=make_provider(config, preset=resolved),
|
||||
model=resolved.model,
|
||||
context_window_tokens=resolved.context_window_tokens,
|
||||
context_window_tokens=min([resolved.context_window_tokens, *fallback_windows]),
|
||||
signature=provider_signature(config, preset=resolved),
|
||||
)
|
||||
|
||||
|
||||
273
nanobot/providers/fallback_provider.py
Normal file
273
nanobot/providers/fallback_provider.py
Normal file
@ -0,0 +1,273 @@
|
||||
"""Provider wrapper that transparently fails over to fallback models on error."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse
|
||||
|
||||
# Circuit breaker tuned to match OpenAICompatProvider's Responses API breaker.
|
||||
_PRIMARY_FAILURE_THRESHOLD = 3
|
||||
_PRIMARY_COOLDOWN_S = 60
|
||||
_MISSING = object()
|
||||
_FALLBACK_ERROR_KINDS = frozenset({
|
||||
"timeout",
|
||||
"connection",
|
||||
"server_error",
|
||||
"rate_limit",
|
||||
"overloaded",
|
||||
})
|
||||
_NON_FALLBACK_ERROR_KINDS = frozenset({
|
||||
"authentication",
|
||||
"auth",
|
||||
"permission",
|
||||
"content_filter",
|
||||
"refusal",
|
||||
"context_length",
|
||||
"invalid_request",
|
||||
})
|
||||
_FALLBACK_ERROR_TOKENS = (
|
||||
"rate_limit",
|
||||
"rate limit",
|
||||
"too_many_requests",
|
||||
"too many requests",
|
||||
"overloaded",
|
||||
"server_error",
|
||||
"server error",
|
||||
"temporarily unavailable",
|
||||
"timeout",
|
||||
"timed out",
|
||||
"connection",
|
||||
"insufficient_quota",
|
||||
"insufficient quota",
|
||||
"quota_exceeded",
|
||||
"quota exceeded",
|
||||
"quota_exhausted",
|
||||
"quota exhausted",
|
||||
"billing_hard_limit",
|
||||
"insufficient_balance",
|
||||
"balance",
|
||||
"out of credits",
|
||||
)
|
||||
|
||||
|
||||
class FallbackProvider(LLMProvider):
|
||||
"""Wrap a primary provider and transparently failover to fallback models.
|
||||
|
||||
When the primary model returns an error and no content has been streamed yet,
|
||||
the wrapper tries each fallback model in order. Each fallback model may
|
||||
reside on a different provider — a factory callable creates the underlying
|
||||
provider on-the-fly.
|
||||
|
||||
Key design:
|
||||
- Failover is request-scoped (the wrapper itself is stateless between turns).
|
||||
- Skipped when content was already streamed to avoid duplicate output.
|
||||
- Recursive failover is prevented by the factory returning plain providers.
|
||||
- Primary provider is circuit-broken after repeated failures to avoid
|
||||
wasting requests on a known-bad endpoint.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
primary: LLMProvider,
|
||||
fallback_presets: list[Any],
|
||||
provider_factory: Callable[[Any], LLMProvider],
|
||||
):
|
||||
self._primary = primary
|
||||
self._fallback_presets = list(fallback_presets)
|
||||
self._provider_factory = provider_factory
|
||||
self._has_fallbacks = bool(fallback_presets)
|
||||
self._primary_failures = 0
|
||||
self._primary_tripped_at: float | None = None
|
||||
|
||||
@property
|
||||
def generation(self):
|
||||
return self._primary.generation
|
||||
|
||||
@generation.setter
|
||||
def generation(self, value):
|
||||
self._primary.generation = value
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return self._primary.get_default_model()
|
||||
|
||||
@property
|
||||
def supports_progress_deltas(self) -> bool:
|
||||
return bool(getattr(self._primary, "supports_progress_deltas", False))
|
||||
|
||||
def _primary_available(self) -> bool:
|
||||
"""Return True if the primary provider is not currently tripped."""
|
||||
if self._primary_tripped_at is None:
|
||||
return True
|
||||
if time.monotonic() - self._primary_tripped_at >= _PRIMARY_COOLDOWN_S:
|
||||
# Half-open: allow one probe attempt.
|
||||
return True
|
||||
return False
|
||||
|
||||
async def chat(self, **kwargs: Any) -> LLMResponse:
|
||||
if not self._has_fallbacks:
|
||||
return await self._primary.chat(**kwargs)
|
||||
return await self._try_with_fallback(
|
||||
lambda p, kw: p.chat(**kw), kwargs, has_streamed=None
|
||||
)
|
||||
|
||||
async def chat_stream(self, **kwargs: Any) -> LLMResponse:
|
||||
if not self._has_fallbacks:
|
||||
return await self._primary.chat_stream(**kwargs)
|
||||
|
||||
has_streamed: list[bool] = [False]
|
||||
original_delta = kwargs.get("on_content_delta")
|
||||
|
||||
async def _tracking_delta(text: str) -> None:
|
||||
if text:
|
||||
has_streamed[0] = True
|
||||
if original_delta:
|
||||
await original_delta(text)
|
||||
|
||||
kwargs["on_content_delta"] = _tracking_delta
|
||||
return await self._try_with_fallback(
|
||||
lambda p, kw: p.chat_stream(**kw), kwargs, has_streamed=has_streamed
|
||||
)
|
||||
|
||||
async def _try_with_fallback(
|
||||
self,
|
||||
call: Callable[[LLMProvider, dict[str, Any]], Awaitable[LLMResponse]],
|
||||
kwargs: dict[str, Any],
|
||||
has_streamed: list[bool] | None,
|
||||
) -> LLMResponse:
|
||||
primary_model = kwargs.get("model") or self._primary.get_default_model()
|
||||
|
||||
if self._primary_available():
|
||||
response = await call(self._primary, kwargs)
|
||||
if response.finish_reason != "error":
|
||||
self._primary_failures = 0
|
||||
self._primary_tripped_at = None
|
||||
return response
|
||||
|
||||
if has_streamed is not None and has_streamed[0]:
|
||||
logger.warning(
|
||||
"Primary model error but content already streamed; skipping failover"
|
||||
)
|
||||
return response
|
||||
|
||||
if not self._should_fallback(response):
|
||||
logger.warning(
|
||||
"Primary model '{}' returned non-fallbackable error: {}",
|
||||
primary_model,
|
||||
(response.content or "")[:120],
|
||||
)
|
||||
return response
|
||||
|
||||
self._primary_failures += 1
|
||||
if self._primary_failures >= _PRIMARY_FAILURE_THRESHOLD:
|
||||
self._primary_tripped_at = time.monotonic()
|
||||
logger.warning(
|
||||
"Primary model '{}' circuit open after {} consecutive failures",
|
||||
primary_model, self._primary_failures,
|
||||
)
|
||||
else:
|
||||
logger.debug("Primary model '{}' circuit open; skipping", primary_model)
|
||||
|
||||
last_response: LLMResponse | None = None
|
||||
primary_skipped = not self._primary_available()
|
||||
for idx, fallback in enumerate(self._fallback_presets):
|
||||
fallback_model = fallback.model
|
||||
if has_streamed is not None and has_streamed[0]:
|
||||
break
|
||||
if idx == 0 and primary_skipped:
|
||||
logger.info(
|
||||
"Primary model '{}' circuit open, trying fallback '{}'",
|
||||
primary_model, fallback_model,
|
||||
)
|
||||
elif idx == 0:
|
||||
logger.info(
|
||||
"Primary model '{}' failed, trying fallback '{}'",
|
||||
primary_model, fallback_model,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Fallback '{}' also failed, trying next fallback '{}'",
|
||||
self._fallback_presets[idx - 1].model, fallback_model,
|
||||
)
|
||||
try:
|
||||
fallback_provider = self._provider_factory(fallback)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to create provider for fallback '{}': {}", fallback_model, exc
|
||||
)
|
||||
continue
|
||||
|
||||
original_values = {
|
||||
name: kwargs.get(name, _MISSING)
|
||||
for name in ("model", "max_tokens", "temperature", "reasoning_effort")
|
||||
}
|
||||
kwargs["model"] = fallback_model
|
||||
kwargs["max_tokens"] = fallback.max_tokens
|
||||
kwargs["temperature"] = fallback.temperature
|
||||
if fallback.reasoning_effort is None:
|
||||
kwargs.pop("reasoning_effort", None)
|
||||
else:
|
||||
kwargs["reasoning_effort"] = fallback.reasoning_effort
|
||||
try:
|
||||
fallback_response = await call(fallback_provider, kwargs)
|
||||
finally:
|
||||
for name, value in original_values.items():
|
||||
if value is _MISSING:
|
||||
kwargs.pop(name, None)
|
||||
else:
|
||||
kwargs[name] = value
|
||||
|
||||
if fallback_response.finish_reason != "error":
|
||||
logger.info(
|
||||
"Fallback '{}' succeeded after primary '{}' failed",
|
||||
fallback_model, primary_model,
|
||||
)
|
||||
return fallback_response
|
||||
|
||||
last_response = fallback_response
|
||||
logger.warning(
|
||||
"Fallback '{}' also failed: {}",
|
||||
fallback_model,
|
||||
(fallback_response.content or "")[:120],
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
"All {} fallback model(s) failed",
|
||||
len(self._fallback_presets),
|
||||
)
|
||||
# Return the last error response we saw (primary or last fallback).
|
||||
if last_response is not None:
|
||||
return last_response
|
||||
# Primary was tripped and we have no fallbacks — synthesize an error.
|
||||
return LLMResponse(
|
||||
content=f"Primary model '{primary_model}' circuit open and no fallbacks available",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _should_fallback(response: LLMResponse) -> bool:
|
||||
if response.error_should_retry is False:
|
||||
return False
|
||||
status = response.error_status_code
|
||||
kind = (response.error_kind or "").lower()
|
||||
error_type = (response.error_type or "").lower()
|
||||
code = (response.error_code or "").lower()
|
||||
text = (response.content or "").lower()
|
||||
|
||||
if status in {400, 401, 403, 404, 422}:
|
||||
return False
|
||||
if kind in _NON_FALLBACK_ERROR_KINDS:
|
||||
return False
|
||||
if any(token in value for value in (kind, error_type, code) for token in _NON_FALLBACK_ERROR_KINDS):
|
||||
return False
|
||||
if response.error_should_retry is True:
|
||||
return True
|
||||
if status is not None and (status in {408, 409, 429} or 500 <= status <= 599):
|
||||
return True
|
||||
if kind in _FALLBACK_ERROR_KINDS:
|
||||
return True
|
||||
return any(token in value for value in (kind, error_type, code, text) for token in _FALLBACK_ERROR_TOKENS)
|
||||
613
tests/agent/test_runner_fallback.py
Normal file
613
tests/agent/test_runner_fallback.py
Normal file
@ -0,0 +1,613 @@
|
||||
"""Tests for FallbackProvider model failover."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.config.schema import ModelPresetConfig
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse
|
||||
from nanobot.providers.fallback_provider import FallbackProvider
|
||||
|
||||
|
||||
def _make_response(
|
||||
content: str = "ok",
|
||||
finish_reason: str = "stop",
|
||||
*,
|
||||
error_kind: str | None = None,
|
||||
error_status_code: int | None = None,
|
||||
error_type: str | None = None,
|
||||
error_code: str | None = None,
|
||||
error_should_retry: bool | None = None,
|
||||
) -> LLMResponse:
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
finish_reason=finish_reason,
|
||||
error_kind=error_kind,
|
||||
error_status_code=error_status_code,
|
||||
error_type=error_type,
|
||||
error_code=error_code,
|
||||
error_should_retry=error_should_retry,
|
||||
)
|
||||
|
||||
|
||||
def _error_response(content: str = "api error") -> LLMResponse:
|
||||
return _make_response(content, finish_reason="error", error_kind="server_error")
|
||||
|
||||
|
||||
def _fallback(
|
||||
model: str,
|
||||
provider: str = "custom",
|
||||
*,
|
||||
max_tokens: int = 8192,
|
||||
context_window_tokens: int = 65_536,
|
||||
temperature: float = 0.1,
|
||||
reasoning_effort: str | None = None,
|
||||
) -> ModelPresetConfig:
|
||||
return ModelPresetConfig(
|
||||
model=model,
|
||||
provider=provider,
|
||||
max_tokens=max_tokens,
|
||||
context_window_tokens=context_window_tokens,
|
||||
temperature=temperature,
|
||||
reasoning_effort=reasoning_effort,
|
||||
)
|
||||
|
||||
|
||||
class _FakeProvider(LLMProvider):
|
||||
"""Fake provider for testing."""
|
||||
|
||||
def __init__(self, name: str = "fake", response: LLMResponse | None = None):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self._response = response or _make_response()
|
||||
self.chat_calls: list[dict[str, Any]] = []
|
||||
self.chat_stream_calls: list[dict[str, Any]] = []
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return f"{self.name}/model"
|
||||
|
||||
async def chat(self, **kwargs: Any) -> LLMResponse:
|
||||
self.chat_calls.append(dict(kwargs))
|
||||
return self._response
|
||||
|
||||
async def chat_stream(self, **kwargs: Any) -> LLMResponse:
|
||||
self.chat_stream_calls.append(dict(kwargs))
|
||||
on_delta = kwargs.get("on_content_delta")
|
||||
if on_delta and self._response.content:
|
||||
await on_delta(self._response.content)
|
||||
return self._response
|
||||
|
||||
|
||||
# -- config-level tests --
|
||||
|
||||
|
||||
def test_fallback_models_default_empty() -> None:
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
|
||||
defaults = AgentDefaults()
|
||||
|
||||
assert defaults.fallback_models == []
|
||||
|
||||
|
||||
def test_fallback_models_accept_preset_refs_and_inline_configs() -> None:
|
||||
from nanobot.config.schema import Config, InlineFallbackConfig
|
||||
|
||||
config = Config.model_validate({
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"fallbackModels": [
|
||||
"deep",
|
||||
{
|
||||
"provider": "openai",
|
||||
"model": "gpt-4.1",
|
||||
"maxTokens": 4096,
|
||||
},
|
||||
]
|
||||
}
|
||||
},
|
||||
"modelPresets": {
|
||||
"deep": {"provider": "anthropic", "model": "claude-opus-4-7"}
|
||||
},
|
||||
})
|
||||
|
||||
assert config.agents.defaults.fallback_models[0] == "deep"
|
||||
assert config.agents.defaults.fallback_models[1] == InlineFallbackConfig(
|
||||
provider="openai",
|
||||
model="gpt-4.1",
|
||||
max_tokens=4096,
|
||||
)
|
||||
|
||||
|
||||
def test_fallback_model_preset_ref_must_exist() -> None:
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
with pytest.raises(ValueError, match="fallback_models.*not found"):
|
||||
Config.model_validate({
|
||||
"agents": {"defaults": {"fallbackModels": ["missing"]}},
|
||||
"modelPresets": {},
|
||||
})
|
||||
|
||||
|
||||
def test_provider_signature_tracks_fallback_presets_and_provider_config() -> None:
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.providers.factory import provider_signature
|
||||
|
||||
base = {
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"modelPreset": "fast",
|
||||
"fallbackModels": ["deep"],
|
||||
}
|
||||
},
|
||||
"modelPresets": {
|
||||
"fast": {"model": "openai/gpt-4.1", "provider": "openai"},
|
||||
"deep": {"model": "anthropic/claude-sonnet-4-6", "provider": "anthropic"},
|
||||
},
|
||||
"providers": {
|
||||
"openai": {"apiKey": "primary-key"},
|
||||
"anthropic": {"apiKey": "fallback-key"},
|
||||
},
|
||||
}
|
||||
changed_fallback = {
|
||||
**base,
|
||||
"agents": {"defaults": {"modelPreset": "fast", "fallbackModels": ["backup"]}},
|
||||
"modelPresets": {
|
||||
**base["modelPresets"],
|
||||
"backup": {"model": "deepseek/deepseek-chat", "provider": "deepseek"},
|
||||
},
|
||||
"providers": {
|
||||
**base["providers"],
|
||||
"deepseek": {"apiKey": "deepseek-key"},
|
||||
},
|
||||
}
|
||||
changed_key = {
|
||||
**base,
|
||||
"providers": {
|
||||
"openai": {"apiKey": "primary-key"},
|
||||
"anthropic": {"apiKey": "new-fallback-key"},
|
||||
},
|
||||
}
|
||||
|
||||
signature = provider_signature(Config.model_validate(base))
|
||||
|
||||
assert signature != provider_signature(Config.model_validate(changed_fallback))
|
||||
assert signature != provider_signature(Config.model_validate(changed_key))
|
||||
|
||||
|
||||
def test_provider_snapshot_uses_smallest_fallback_context_window() -> None:
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.providers.factory import build_provider_snapshot
|
||||
|
||||
config = Config.model_validate({
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"modelPreset": "fast",
|
||||
"fallbackModels": ["deep"],
|
||||
}
|
||||
},
|
||||
"modelPresets": {
|
||||
"fast": {
|
||||
"model": "openai/gpt-4.1",
|
||||
"provider": "openai",
|
||||
"contextWindowTokens": 128000,
|
||||
},
|
||||
"deep": {
|
||||
"model": "deepseek/deepseek-chat",
|
||||
"provider": "deepseek",
|
||||
"contextWindowTokens": 64000,
|
||||
},
|
||||
},
|
||||
"providers": {
|
||||
"openai": {"apiKey": "primary-key"},
|
||||
"deepseek": {"apiKey": "fallback-key"},
|
||||
},
|
||||
})
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
snapshot = build_provider_snapshot(config)
|
||||
|
||||
assert snapshot.context_window_tokens == 64000
|
||||
|
||||
|
||||
def test_inline_fallback_reasoning_effort_does_not_inherit_primary() -> None:
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.providers.factory import provider_signature
|
||||
|
||||
config = Config.model_validate({
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"modelPreset": "fast",
|
||||
"fallbackModels": [
|
||||
{"provider": "openai", "model": "gpt-4.1"}
|
||||
],
|
||||
}
|
||||
},
|
||||
"modelPresets": {
|
||||
"fast": {
|
||||
"model": "anthropic/claude-opus-4-5",
|
||||
"provider": "anthropic",
|
||||
"reasoningEffort": "high",
|
||||
}
|
||||
},
|
||||
"providers": {
|
||||
"anthropic": {"apiKey": "primary-key"},
|
||||
"openai": {"apiKey": "fallback-key"},
|
||||
},
|
||||
})
|
||||
|
||||
signature = provider_signature(config)
|
||||
fallback_signatures = signature[-1]
|
||||
|
||||
assert fallback_signatures[0][11] is None
|
||||
|
||||
|
||||
# -- FallbackProvider tests --
|
||||
|
||||
|
||||
class TestNoFallbackWhenPrimarySucceeds:
|
||||
@pytest.mark.asyncio
|
||||
async def test(self) -> None:
|
||||
primary = _FakeProvider("primary", _make_response("primary ok"))
|
||||
factory = MagicMock()
|
||||
fb = FallbackProvider(
|
||||
primary=primary,
|
||||
fallback_presets=[_fallback("fallback-a")],
|
||||
provider_factory=factory,
|
||||
)
|
||||
|
||||
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
|
||||
assert result.content == "primary ok"
|
||||
assert result.finish_reason == "stop"
|
||||
factory.assert_not_called()
|
||||
|
||||
|
||||
class TestFallbackOnPrimaryError:
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_fallback_succeeds(self) -> None:
|
||||
primary = _FakeProvider("primary", _error_response())
|
||||
fallback = _FakeProvider("fallback", _make_response("fallback ok"))
|
||||
factory = MagicMock(return_value=fallback)
|
||||
|
||||
fb = FallbackProvider(
|
||||
primary=primary,
|
||||
fallback_presets=[_fallback("fallback-a")],
|
||||
provider_factory=factory,
|
||||
)
|
||||
|
||||
result = await fb.chat(messages=[{"role": "user", "content": "hi"}], model="primary-model")
|
||||
assert result.content == "fallback ok"
|
||||
assert result.finish_reason == "stop"
|
||||
factory.assert_called_once_with(_fallback("fallback-a"))
|
||||
assert primary.chat_calls[0]["model"] == "primary-model"
|
||||
assert fallback.chat_calls[0]["model"] == "fallback-a"
|
||||
|
||||
|
||||
class TestNoFallbackWhenContentStreamed:
|
||||
@pytest.mark.asyncio
|
||||
async def test(self) -> None:
|
||||
primary = _FakeProvider("primary", _error_response())
|
||||
factory = MagicMock()
|
||||
fb = FallbackProvider(
|
||||
primary=primary,
|
||||
fallback_presets=[_fallback("fallback-a")],
|
||||
provider_factory=factory,
|
||||
)
|
||||
|
||||
async def _delta(text: str) -> None:
|
||||
pass
|
||||
|
||||
result = await fb.chat_stream(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
on_content_delta=_delta,
|
||||
)
|
||||
# Primary returns error but content was "streamed" (FakeProvider calls delta)
|
||||
# so failover should be skipped
|
||||
assert result.finish_reason == "error"
|
||||
factory.assert_not_called()
|
||||
|
||||
|
||||
class TestFailoverOnTransientError:
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit(self) -> None:
|
||||
primary = _FakeProvider("primary", _error_response("rate limit exceeded"))
|
||||
fallback = _FakeProvider("fallback", _make_response("fallback ok"))
|
||||
factory = MagicMock(return_value=fallback)
|
||||
fb = FallbackProvider(
|
||||
primary=primary,
|
||||
fallback_presets=[_fallback("fallback-a")],
|
||||
provider_factory=factory,
|
||||
)
|
||||
|
||||
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
|
||||
assert result.content == "fallback ok"
|
||||
assert result.finish_reason == "stop"
|
||||
factory.assert_called_once_with(_fallback("fallback-a"))
|
||||
|
||||
|
||||
class TestNoFallbackOnNonRetryableError:
|
||||
@pytest.mark.asyncio
|
||||
async def test_bad_request(self) -> None:
|
||||
primary = _FakeProvider(
|
||||
"primary",
|
||||
_make_response(
|
||||
"invalid request",
|
||||
finish_reason="error",
|
||||
error_status_code=400,
|
||||
error_kind="invalid_request",
|
||||
),
|
||||
)
|
||||
factory = MagicMock()
|
||||
fb = FallbackProvider(
|
||||
primary=primary,
|
||||
fallback_presets=[_fallback("fallback-a")],
|
||||
provider_factory=factory,
|
||||
)
|
||||
|
||||
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
|
||||
|
||||
assert result.finish_reason == "error"
|
||||
factory.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_error(self) -> None:
|
||||
primary = _FakeProvider(
|
||||
"primary",
|
||||
_make_response(
|
||||
"unauthorized",
|
||||
finish_reason="error",
|
||||
error_status_code=401,
|
||||
error_kind="authentication",
|
||||
),
|
||||
)
|
||||
factory = MagicMock()
|
||||
fb = FallbackProvider(
|
||||
primary=primary,
|
||||
fallback_presets=[_fallback("fallback-a")],
|
||||
provider_factory=factory,
|
||||
)
|
||||
|
||||
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
|
||||
|
||||
assert result.finish_reason == "error"
|
||||
factory.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout(self) -> None:
|
||||
primary = _FakeProvider(
|
||||
"primary",
|
||||
_make_response("timed out", finish_reason="error", error_kind="timeout"),
|
||||
)
|
||||
fallback = _FakeProvider("fallback", _make_response("fallback ok"))
|
||||
factory = MagicMock(return_value=fallback)
|
||||
fb = FallbackProvider(
|
||||
primary=primary,
|
||||
fallback_presets=[_fallback("fallback-a")],
|
||||
provider_factory=factory,
|
||||
)
|
||||
|
||||
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
|
||||
assert result.content == "fallback ok"
|
||||
assert result.finish_reason == "stop"
|
||||
factory.assert_called_once_with(_fallback("fallback-a"))
|
||||
|
||||
|
||||
class TestFallbackTriesModelsInOrder:
|
||||
@pytest.mark.asyncio
|
||||
async def test(self) -> None:
|
||||
primary = _FakeProvider("primary", _error_response("primary fail"))
|
||||
fallback_a = _FakeProvider("a", _error_response("a fail"))
|
||||
fallback_b = _FakeProvider("b", _make_response("b ok"))
|
||||
factory = MagicMock(side_effect=[fallback_a, fallback_b])
|
||||
|
||||
fb = FallbackProvider(
|
||||
primary=primary,
|
||||
fallback_presets=[_fallback("fallback-a"), _fallback("fallback-b")],
|
||||
provider_factory=factory,
|
||||
)
|
||||
|
||||
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
|
||||
assert result.content == "b ok"
|
||||
assert factory.call_count == 2
|
||||
factory.assert_any_call(_fallback("fallback-a"))
|
||||
factory.assert_any_call(_fallback("fallback-b"))
|
||||
|
||||
|
||||
class TestAllFallbacksFail:
|
||||
@pytest.mark.asyncio
|
||||
async def test(self) -> None:
|
||||
primary = _FakeProvider("primary", _error_response("primary fail"))
|
||||
fallback = _FakeProvider("fallback", _error_response("all fail"))
|
||||
factory = MagicMock(return_value=fallback)
|
||||
|
||||
fb = FallbackProvider(
|
||||
primary=primary,
|
||||
fallback_presets=[_fallback("fallback-a")],
|
||||
provider_factory=factory,
|
||||
)
|
||||
|
||||
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
|
||||
assert result.finish_reason == "error"
|
||||
assert "all fail" in result.content
|
||||
|
||||
|
||||
class TestFactoryExceptionSkipsModel:
|
||||
@pytest.mark.asyncio
|
||||
async def test(self) -> None:
|
||||
primary = _FakeProvider("primary", _error_response())
|
||||
fallback_b = _FakeProvider("b", _make_response("b ok"))
|
||||
factory = MagicMock(side_effect=[ValueError("no key"), fallback_b])
|
||||
|
||||
fb = FallbackProvider(
|
||||
primary=primary,
|
||||
fallback_presets=[_fallback("fallback-a"), _fallback("fallback-b")],
|
||||
provider_factory=factory,
|
||||
)
|
||||
|
||||
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
|
||||
assert result.content == "b ok"
|
||||
assert factory.call_count == 2
|
||||
|
||||
|
||||
class TestFallbackModelParameter:
|
||||
@pytest.mark.asyncio
|
||||
async def test(self) -> None:
|
||||
"""Fallback calls should use the fallback model name."""
|
||||
primary = _FakeProvider("primary", _error_response())
|
||||
fallback = _FakeProvider("fallback", _make_response("ok"))
|
||||
factory = MagicMock(return_value=fallback)
|
||||
|
||||
fb = FallbackProvider(
|
||||
primary=primary,
|
||||
fallback_presets=[_fallback("fallback-model")],
|
||||
provider_factory=factory,
|
||||
)
|
||||
|
||||
await fb.chat(messages=[{"role": "user", "content": "hi"}], model="primary-model")
|
||||
assert fallback.chat_calls[0]["model"] == "fallback-model"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uses_fallback_generation_fields(self) -> None:
|
||||
primary = _FakeProvider("primary", _error_response())
|
||||
fallback = _FakeProvider("fallback", _make_response("ok"))
|
||||
fb = FallbackProvider(
|
||||
primary=primary,
|
||||
fallback_presets=[
|
||||
_fallback(
|
||||
"fallback-model",
|
||||
max_tokens=1234,
|
||||
temperature=0.4,
|
||||
reasoning_effort=None,
|
||||
)
|
||||
],
|
||||
provider_factory=MagicMock(return_value=fallback),
|
||||
)
|
||||
|
||||
await fb.chat(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
model="primary-model",
|
||||
max_tokens=8192,
|
||||
temperature=0.1,
|
||||
reasoning_effort="high",
|
||||
)
|
||||
|
||||
assert fallback.chat_calls[0]["model"] == "fallback-model"
|
||||
assert fallback.chat_calls[0]["max_tokens"] == 1234
|
||||
assert fallback.chat_calls[0]["temperature"] == 0.4
|
||||
assert "reasoning_effort" not in fallback.chat_calls[0]
|
||||
|
||||
|
||||
class TestNoFallbackWhenEmptyList:
|
||||
@pytest.mark.asyncio
|
||||
async def test(self) -> None:
|
||||
primary = _FakeProvider("primary", _error_response())
|
||||
factory = MagicMock()
|
||||
|
||||
fb = FallbackProvider(
|
||||
primary=primary,
|
||||
fallback_presets=[],
|
||||
provider_factory=factory,
|
||||
)
|
||||
|
||||
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
|
||||
assert result.finish_reason == "error"
|
||||
factory.assert_not_called()
|
||||
|
||||
|
||||
class TestChatStreamFailover:
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_succeeds(self) -> None:
|
||||
# Use empty content so on_content_delta is not triggered on the error
|
||||
primary = _FakeProvider("primary", _error_response(""))
|
||||
fallback = _FakeProvider("fallback", _make_response("stream ok"))
|
||||
factory = MagicMock(return_value=fallback)
|
||||
|
||||
fb = FallbackProvider(
|
||||
primary=primary,
|
||||
fallback_presets=[_fallback("fallback-a")],
|
||||
provider_factory=factory,
|
||||
)
|
||||
|
||||
result = await fb.chat_stream(messages=[{"role": "user", "content": "hi"}])
|
||||
assert result.content == "stream ok"
|
||||
assert result.finish_reason == "stop"
|
||||
|
||||
|
||||
class TestGetDefaultModel:
|
||||
def test(self) -> None:
|
||||
primary = _FakeProvider("primary")
|
||||
fb = FallbackProvider(
|
||||
primary=primary,
|
||||
fallback_presets=[_fallback("a")],
|
||||
provider_factory=MagicMock(),
|
||||
)
|
||||
assert fb.get_default_model() == "primary/model"
|
||||
|
||||
|
||||
class TestCircuitBreaker:
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_primary_after_three_failures(self) -> None:
|
||||
primary = _FakeProvider("primary", _error_response())
|
||||
fallback = _FakeProvider("fallback", _make_response("fallback ok"))
|
||||
factory = MagicMock(return_value=fallback)
|
||||
fb = FallbackProvider(
|
||||
primary=primary,
|
||||
fallback_presets=[_fallback("fallback-a")],
|
||||
provider_factory=factory,
|
||||
)
|
||||
|
||||
# 3 failures — primary should still be called each time
|
||||
for _ in range(3):
|
||||
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
|
||||
assert result.content == "fallback ok"
|
||||
|
||||
assert len(primary.chat_calls) == 3
|
||||
|
||||
# 4th call — primary circuit is open, should be skipped
|
||||
primary.chat_calls.clear()
|
||||
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
|
||||
assert result.content == "fallback ok"
|
||||
assert len(primary.chat_calls) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resets_on_success(self) -> None:
|
||||
primary = _FakeProvider("primary", _error_response())
|
||||
fallback = _FakeProvider("fallback", _make_response("fallback ok"))
|
||||
factory = MagicMock(return_value=fallback)
|
||||
fb = FallbackProvider(
|
||||
primary=primary,
|
||||
fallback_presets=[_fallback("fallback-a")],
|
||||
provider_factory=factory,
|
||||
)
|
||||
|
||||
# 2 failures
|
||||
for _ in range(2):
|
||||
await fb.chat(messages=[{"role": "user", "content": "hi"}])
|
||||
|
||||
# 3rd call: primary succeeds — circuit resets
|
||||
primary._response = _make_response("primary ok")
|
||||
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
|
||||
assert result.content == "primary ok"
|
||||
|
||||
# 4th call: primary fails again — should still be called (counter reset)
|
||||
primary._response = _error_response()
|
||||
primary.chat_calls.clear()
|
||||
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
|
||||
assert result.content == "fallback ok"
|
||||
assert len(primary.chat_calls) == 1
|
||||
|
||||
|
||||
class TestGenerationForwarded:
|
||||
def test(self) -> None:
|
||||
from nanobot.providers.base import GenerationSettings
|
||||
primary = _FakeProvider("primary")
|
||||
primary.generation = GenerationSettings(temperature=0.5, max_tokens=1024)
|
||||
fb = FallbackProvider(
|
||||
primary=primary,
|
||||
fallback_presets=[_fallback("a")],
|
||||
provider_factory=MagicMock(),
|
||||
)
|
||||
assert fb.generation.temperature == 0.5
|
||||
assert fb.generation.max_tokens == 1024
|
||||
Loading…
x
Reference in New Issue
Block a user