diff --git a/docs/configuration.md b/docs/configuration.md index 0123017d2..3f7f39709 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -672,7 +672,8 @@ Existing configs do not need to change. If you do not set `modelPresets` or `age "maxTokens": 8192, "contextWindowTokens": 128000, "temperature": 0.1, - "modelPreset": null + "modelPreset": "fast", + "fallbackModels": ["deep"] } }, "modelPresets": { @@ -708,6 +709,40 @@ Existing configs do not need to change. If you do not set `modelPresets` or `age `default` is reserved and always means the implicit preset built from `agents.defaults.*`; do not define `modelPresets.default`. Use `/model default` to switch back to `agents.defaults.*`. +### Model Fallbacks + +`agents.defaults.fallbackModels` defines an ordered failover chain for the active model configuration. The primary model is still selected by `agents.defaults.modelPreset` (or the implicit default config when no preset is active). + +Each fallback candidate can be either: + +- A preset name from `modelPresets`, such as `"deep"`. The preset's full model, provider, generation, and context-window config is used. +- An inline fallback object with at least `provider` and `model`. Optional `maxTokens`, `contextWindowTokens`, and `temperature` fields inherit from the active primary config when omitted. `reasoningEffort` does not inherit; omit it to leave reasoning off for that fallback, or set it explicitly for models that support reasoning. + +```json +{ + "agents": { + "defaults": { + "modelPreset": "fast", + "fallbackModels": [ + "deep", + { + "provider": "deepseek", + "model": "deepseek-v4-pro", + "maxTokens": 4096, + "contextWindowTokens": 262144 + } + ] + } + } +} +``` + +String entries are preset names, not raw model names. If you want to use a model that is not already a preset, use the inline object form. + +Failover only runs when the primary provider returns a retryable model/provider error before any answer text has been streamed. Typical fallback cases include timeouts, connection errors, 5xx server errors, 429 rate limits, overloads, and quota/balance exhaustion. It does not run for malformed requests, authentication/permission errors, content filtering/refusals, or context-length/message-format errors. + +If fallback candidates use smaller `contextWindowTokens` values, nanobot builds context using the smallest window in the active chain so every candidate can receive the same prompt. + Set `agents.defaults.modelPreset` to start with a named preset: ```json diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index ff7454d71..c8556ec9f 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -74,6 +74,20 @@ class DreamConfig(Base): return f"every {hours}h" +class InlineFallbackConfig(Base): + """One inline fallback model configuration.""" + + model: str + provider: str + max_tokens: int | None = None + context_window_tokens: int | None = None + temperature: float | None = None + reasoning_effort: str | None = None + + +FallbackCandidate = str | InlineFallbackConfig + + class ModelPresetConfig(Base): """A named set of model + generation parameters for quick switching.""" @@ -106,6 +120,7 @@ class AgentDefaults(Base): context_window_tokens: int = 65_536 context_block_limit: int | None = None temperature: float = 0.1 + fallback_models: list[FallbackCandidate] = Field(default_factory=list) max_tool_iterations: int = 200 max_concurrent_subagents: int = Field(default=1, ge=1) max_tool_result_chars: int = 16_000 @@ -287,6 +302,9 @@ class Config(BaseSettings): name = self.agents.defaults.model_preset if name and name != "default" and name not in self.model_presets: raise ValueError(f"model_preset {name!r} not found in model_presets") + for fallback in self.agents.defaults.fallback_models: + if isinstance(fallback, str) and fallback not in self.model_presets: + raise ValueError(f"fallback_models entry {fallback!r} not found in model_presets") return self def resolve_default_preset(self) -> ModelPresetConfig: diff --git a/nanobot/providers/factory.py b/nanobot/providers/factory.py index 3473afff3..288611392 100644 --- a/nanobot/providers/factory.py +++ b/nanobot/providers/factory.py @@ -5,8 +5,9 @@ from __future__ import annotations from dataclasses import dataclass from pathlib import Path -from nanobot.config.schema import Config, ModelPresetConfig +from nanobot.config.schema import Config, InlineFallbackConfig, ModelPresetConfig from nanobot.providers.base import LLMProvider +from nanobot.providers.fallback_provider import FallbackProvider from nanobot.providers.registry import find_by_name @@ -27,15 +28,16 @@ def _resolve_model_preset( return preset if preset is not None else config.resolve_preset(preset_name) -def make_provider( +def _make_provider_core( config: Config, *, preset_name: str | None = None, preset: ModelPresetConfig | None = None, + model: str | None = None, ) -> LLMProvider: - """Create the LLM provider implied by config.""" + """Create a plain LLM provider without failover wrapping.""" resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset) - model = resolved.model + model = model or resolved.model provider_name = config.get_provider_name(model, preset=resolved) p = config.get_provider(model, preset=resolved) spec = find_by_name(provider_name) if provider_name else None @@ -102,15 +104,93 @@ def make_provider( return provider +def _inline_fallback_preset( + primary: ModelPresetConfig, + fallback: InlineFallbackConfig, +) -> ModelPresetConfig: + return ModelPresetConfig( + model=fallback.model, + provider=fallback.provider, + max_tokens=fallback.max_tokens if fallback.max_tokens is not None else primary.max_tokens, + context_window_tokens=( + fallback.context_window_tokens + if fallback.context_window_tokens is not None + else primary.context_window_tokens + ), + temperature=( + fallback.temperature if fallback.temperature is not None else primary.temperature + ), + reasoning_effort=fallback.reasoning_effort, + ) + + +def _resolve_fallback_presets(config: Config, primary: ModelPresetConfig) -> list[ModelPresetConfig]: + presets: list[ModelPresetConfig] = [] + for fallback in config.agents.defaults.fallback_models: + if isinstance(fallback, str): + presets.append(config.model_presets[fallback]) + else: + presets.append(_inline_fallback_preset(primary, fallback)) + return presets + + +def make_provider( + config: Config, + *, + preset_name: str | None = None, + preset: ModelPresetConfig | None = None, + model: str | None = None, +) -> LLMProvider: + """Create the LLM provider implied by config. + + When *model* is given, it overrides the resolved/preset model — used by + the failover path to create providers for fallback models. + """ + resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset) + provider = _make_provider_core(config, preset_name=preset_name, preset=preset, model=model) + fallback_presets = _resolve_fallback_presets(config, resolved) + + if fallback_presets: + provider = FallbackProvider( + primary=provider, + fallback_presets=fallback_presets, + provider_factory=lambda fb: _make_provider_core( + config, preset_name=preset_name, preset=fb + ), + ) + + return provider + + def provider_signature( config: Config, *, preset_name: str | None = None, preset: ModelPresetConfig | None = None, ) -> tuple[object, ...]: - """Return the config fields that affect the primary LLM provider.""" + """Return the config fields that affect the active provider chain.""" resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset) p = config.get_provider(resolved.model, preset=resolved) + fallback_presets = _resolve_fallback_presets(config, resolved) + + def _fallback_signature(fallback: ModelPresetConfig) -> tuple[object, ...]: + fp = config.get_provider(fallback.model, preset=fallback) + return ( + fallback.model, + fallback.provider, + config.get_provider_name(fallback.model, preset=fallback), + config.get_api_key(fallback.model, preset=fallback), + config.get_api_base(fallback.model, preset=fallback), + fp.extra_headers if fp else None, + fp.extra_body if fp else None, + getattr(fp, "region", None) if fp else None, + getattr(fp, "profile", None) if fp else None, + fallback.max_tokens, + fallback.temperature, + fallback.reasoning_effort, + fallback.context_window_tokens, + ) + return ( resolved.model, resolved.provider, @@ -125,6 +205,7 @@ def provider_signature( resolved.temperature, resolved.reasoning_effort, resolved.context_window_tokens, + tuple(_fallback_signature(fallback) for fallback in fallback_presets), ) @@ -135,10 +216,14 @@ def build_provider_snapshot( preset: ModelPresetConfig | None = None, ) -> ProviderSnapshot: resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset) + fallback_windows = [ + fallback.context_window_tokens + for fallback in _resolve_fallback_presets(config, resolved) + ] return ProviderSnapshot( provider=make_provider(config, preset=resolved), model=resolved.model, - context_window_tokens=resolved.context_window_tokens, + context_window_tokens=min([resolved.context_window_tokens, *fallback_windows]), signature=provider_signature(config, preset=resolved), ) diff --git a/nanobot/providers/fallback_provider.py b/nanobot/providers/fallback_provider.py new file mode 100644 index 000000000..c082c2361 --- /dev/null +++ b/nanobot/providers/fallback_provider.py @@ -0,0 +1,273 @@ +"""Provider wrapper that transparently fails over to fallback models on error.""" + +from __future__ import annotations + +import time +from collections.abc import Awaitable, Callable +from typing import Any + +from loguru import logger + +from nanobot.providers.base import LLMProvider, LLMResponse + +# Circuit breaker tuned to match OpenAICompatProvider's Responses API breaker. +_PRIMARY_FAILURE_THRESHOLD = 3 +_PRIMARY_COOLDOWN_S = 60 +_MISSING = object() +_FALLBACK_ERROR_KINDS = frozenset({ + "timeout", + "connection", + "server_error", + "rate_limit", + "overloaded", +}) +_NON_FALLBACK_ERROR_KINDS = frozenset({ + "authentication", + "auth", + "permission", + "content_filter", + "refusal", + "context_length", + "invalid_request", +}) +_FALLBACK_ERROR_TOKENS = ( + "rate_limit", + "rate limit", + "too_many_requests", + "too many requests", + "overloaded", + "server_error", + "server error", + "temporarily unavailable", + "timeout", + "timed out", + "connection", + "insufficient_quota", + "insufficient quota", + "quota_exceeded", + "quota exceeded", + "quota_exhausted", + "quota exhausted", + "billing_hard_limit", + "insufficient_balance", + "balance", + "out of credits", +) + + +class FallbackProvider(LLMProvider): + """Wrap a primary provider and transparently failover to fallback models. + + When the primary model returns an error and no content has been streamed yet, + the wrapper tries each fallback model in order. Each fallback model may + reside on a different provider — a factory callable creates the underlying + provider on-the-fly. + + Key design: + - Failover is request-scoped (the wrapper itself is stateless between turns). + - Skipped when content was already streamed to avoid duplicate output. + - Recursive failover is prevented by the factory returning plain providers. + - Primary provider is circuit-broken after repeated failures to avoid + wasting requests on a known-bad endpoint. + """ + + def __init__( + self, + primary: LLMProvider, + fallback_presets: list[Any], + provider_factory: Callable[[Any], LLMProvider], + ): + self._primary = primary + self._fallback_presets = list(fallback_presets) + self._provider_factory = provider_factory + self._has_fallbacks = bool(fallback_presets) + self._primary_failures = 0 + self._primary_tripped_at: float | None = None + + @property + def generation(self): + return self._primary.generation + + @generation.setter + def generation(self, value): + self._primary.generation = value + + def get_default_model(self) -> str: + return self._primary.get_default_model() + + @property + def supports_progress_deltas(self) -> bool: + return bool(getattr(self._primary, "supports_progress_deltas", False)) + + def _primary_available(self) -> bool: + """Return True if the primary provider is not currently tripped.""" + if self._primary_tripped_at is None: + return True + if time.monotonic() - self._primary_tripped_at >= _PRIMARY_COOLDOWN_S: + # Half-open: allow one probe attempt. + return True + return False + + async def chat(self, **kwargs: Any) -> LLMResponse: + if not self._has_fallbacks: + return await self._primary.chat(**kwargs) + return await self._try_with_fallback( + lambda p, kw: p.chat(**kw), kwargs, has_streamed=None + ) + + async def chat_stream(self, **kwargs: Any) -> LLMResponse: + if not self._has_fallbacks: + return await self._primary.chat_stream(**kwargs) + + has_streamed: list[bool] = [False] + original_delta = kwargs.get("on_content_delta") + + async def _tracking_delta(text: str) -> None: + if text: + has_streamed[0] = True + if original_delta: + await original_delta(text) + + kwargs["on_content_delta"] = _tracking_delta + return await self._try_with_fallback( + lambda p, kw: p.chat_stream(**kw), kwargs, has_streamed=has_streamed + ) + + async def _try_with_fallback( + self, + call: Callable[[LLMProvider, dict[str, Any]], Awaitable[LLMResponse]], + kwargs: dict[str, Any], + has_streamed: list[bool] | None, + ) -> LLMResponse: + primary_model = kwargs.get("model") or self._primary.get_default_model() + + if self._primary_available(): + response = await call(self._primary, kwargs) + if response.finish_reason != "error": + self._primary_failures = 0 + self._primary_tripped_at = None + return response + + if has_streamed is not None and has_streamed[0]: + logger.warning( + "Primary model error but content already streamed; skipping failover" + ) + return response + + if not self._should_fallback(response): + logger.warning( + "Primary model '{}' returned non-fallbackable error: {}", + primary_model, + (response.content or "")[:120], + ) + return response + + self._primary_failures += 1 + if self._primary_failures >= _PRIMARY_FAILURE_THRESHOLD: + self._primary_tripped_at = time.monotonic() + logger.warning( + "Primary model '{}' circuit open after {} consecutive failures", + primary_model, self._primary_failures, + ) + else: + logger.debug("Primary model '{}' circuit open; skipping", primary_model) + + last_response: LLMResponse | None = None + primary_skipped = not self._primary_available() + for idx, fallback in enumerate(self._fallback_presets): + fallback_model = fallback.model + if has_streamed is not None and has_streamed[0]: + break + if idx == 0 and primary_skipped: + logger.info( + "Primary model '{}' circuit open, trying fallback '{}'", + primary_model, fallback_model, + ) + elif idx == 0: + logger.info( + "Primary model '{}' failed, trying fallback '{}'", + primary_model, fallback_model, + ) + else: + logger.info( + "Fallback '{}' also failed, trying next fallback '{}'", + self._fallback_presets[idx - 1].model, fallback_model, + ) + try: + fallback_provider = self._provider_factory(fallback) + except Exception as exc: + logger.warning( + "Failed to create provider for fallback '{}': {}", fallback_model, exc + ) + continue + + original_values = { + name: kwargs.get(name, _MISSING) + for name in ("model", "max_tokens", "temperature", "reasoning_effort") + } + kwargs["model"] = fallback_model + kwargs["max_tokens"] = fallback.max_tokens + kwargs["temperature"] = fallback.temperature + if fallback.reasoning_effort is None: + kwargs.pop("reasoning_effort", None) + else: + kwargs["reasoning_effort"] = fallback.reasoning_effort + try: + fallback_response = await call(fallback_provider, kwargs) + finally: + for name, value in original_values.items(): + if value is _MISSING: + kwargs.pop(name, None) + else: + kwargs[name] = value + + if fallback_response.finish_reason != "error": + logger.info( + "Fallback '{}' succeeded after primary '{}' failed", + fallback_model, primary_model, + ) + return fallback_response + + last_response = fallback_response + logger.warning( + "Fallback '{}' also failed: {}", + fallback_model, + (fallback_response.content or "")[:120], + ) + + logger.warning( + "All {} fallback model(s) failed", + len(self._fallback_presets), + ) + # Return the last error response we saw (primary or last fallback). + if last_response is not None: + return last_response + # Primary was tripped and we have no fallbacks — synthesize an error. + return LLMResponse( + content=f"Primary model '{primary_model}' circuit open and no fallbacks available", + finish_reason="error", + ) + + @staticmethod + def _should_fallback(response: LLMResponse) -> bool: + if response.error_should_retry is False: + return False + status = response.error_status_code + kind = (response.error_kind or "").lower() + error_type = (response.error_type or "").lower() + code = (response.error_code or "").lower() + text = (response.content or "").lower() + + if status in {400, 401, 403, 404, 422}: + return False + if kind in _NON_FALLBACK_ERROR_KINDS: + return False + if any(token in value for value in (kind, error_type, code) for token in _NON_FALLBACK_ERROR_KINDS): + return False + if response.error_should_retry is True: + return True + if status is not None and (status in {408, 409, 429} or 500 <= status <= 599): + return True + if kind in _FALLBACK_ERROR_KINDS: + return True + return any(token in value for value in (kind, error_type, code, text) for token in _FALLBACK_ERROR_TOKENS) diff --git a/tests/agent/test_runner_fallback.py b/tests/agent/test_runner_fallback.py new file mode 100644 index 000000000..0e36fb02a --- /dev/null +++ b/tests/agent/test_runner_fallback.py @@ -0,0 +1,613 @@ +"""Tests for FallbackProvider model failover.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from nanobot.config.schema import ModelPresetConfig +from nanobot.providers.base import LLMProvider, LLMResponse +from nanobot.providers.fallback_provider import FallbackProvider + + +def _make_response( + content: str = "ok", + finish_reason: str = "stop", + *, + error_kind: str | None = None, + error_status_code: int | None = None, + error_type: str | None = None, + error_code: str | None = None, + error_should_retry: bool | None = None, +) -> LLMResponse: + return LLMResponse( + content=content, + finish_reason=finish_reason, + error_kind=error_kind, + error_status_code=error_status_code, + error_type=error_type, + error_code=error_code, + error_should_retry=error_should_retry, + ) + + +def _error_response(content: str = "api error") -> LLMResponse: + return _make_response(content, finish_reason="error", error_kind="server_error") + + +def _fallback( + model: str, + provider: str = "custom", + *, + max_tokens: int = 8192, + context_window_tokens: int = 65_536, + temperature: float = 0.1, + reasoning_effort: str | None = None, +) -> ModelPresetConfig: + return ModelPresetConfig( + model=model, + provider=provider, + max_tokens=max_tokens, + context_window_tokens=context_window_tokens, + temperature=temperature, + reasoning_effort=reasoning_effort, + ) + + +class _FakeProvider(LLMProvider): + """Fake provider for testing.""" + + def __init__(self, name: str = "fake", response: LLMResponse | None = None): + super().__init__() + self.name = name + self._response = response or _make_response() + self.chat_calls: list[dict[str, Any]] = [] + self.chat_stream_calls: list[dict[str, Any]] = [] + + def get_default_model(self) -> str: + return f"{self.name}/model" + + async def chat(self, **kwargs: Any) -> LLMResponse: + self.chat_calls.append(dict(kwargs)) + return self._response + + async def chat_stream(self, **kwargs: Any) -> LLMResponse: + self.chat_stream_calls.append(dict(kwargs)) + on_delta = kwargs.get("on_content_delta") + if on_delta and self._response.content: + await on_delta(self._response.content) + return self._response + + +# -- config-level tests -- + + +def test_fallback_models_default_empty() -> None: + from nanobot.config.schema import AgentDefaults + + defaults = AgentDefaults() + + assert defaults.fallback_models == [] + + +def test_fallback_models_accept_preset_refs_and_inline_configs() -> None: + from nanobot.config.schema import Config, InlineFallbackConfig + + config = Config.model_validate({ + "agents": { + "defaults": { + "fallbackModels": [ + "deep", + { + "provider": "openai", + "model": "gpt-4.1", + "maxTokens": 4096, + }, + ] + } + }, + "modelPresets": { + "deep": {"provider": "anthropic", "model": "claude-opus-4-7"} + }, + }) + + assert config.agents.defaults.fallback_models[0] == "deep" + assert config.agents.defaults.fallback_models[1] == InlineFallbackConfig( + provider="openai", + model="gpt-4.1", + max_tokens=4096, + ) + + +def test_fallback_model_preset_ref_must_exist() -> None: + from nanobot.config.schema import Config + + with pytest.raises(ValueError, match="fallback_models.*not found"): + Config.model_validate({ + "agents": {"defaults": {"fallbackModels": ["missing"]}}, + "modelPresets": {}, + }) + + +def test_provider_signature_tracks_fallback_presets_and_provider_config() -> None: + from nanobot.config.schema import Config + from nanobot.providers.factory import provider_signature + + base = { + "agents": { + "defaults": { + "modelPreset": "fast", + "fallbackModels": ["deep"], + } + }, + "modelPresets": { + "fast": {"model": "openai/gpt-4.1", "provider": "openai"}, + "deep": {"model": "anthropic/claude-sonnet-4-6", "provider": "anthropic"}, + }, + "providers": { + "openai": {"apiKey": "primary-key"}, + "anthropic": {"apiKey": "fallback-key"}, + }, + } + changed_fallback = { + **base, + "agents": {"defaults": {"modelPreset": "fast", "fallbackModels": ["backup"]}}, + "modelPresets": { + **base["modelPresets"], + "backup": {"model": "deepseek/deepseek-chat", "provider": "deepseek"}, + }, + "providers": { + **base["providers"], + "deepseek": {"apiKey": "deepseek-key"}, + }, + } + changed_key = { + **base, + "providers": { + "openai": {"apiKey": "primary-key"}, + "anthropic": {"apiKey": "new-fallback-key"}, + }, + } + + signature = provider_signature(Config.model_validate(base)) + + assert signature != provider_signature(Config.model_validate(changed_fallback)) + assert signature != provider_signature(Config.model_validate(changed_key)) + + +def test_provider_snapshot_uses_smallest_fallback_context_window() -> None: + from nanobot.config.schema import Config + from nanobot.providers.factory import build_provider_snapshot + + config = Config.model_validate({ + "agents": { + "defaults": { + "modelPreset": "fast", + "fallbackModels": ["deep"], + } + }, + "modelPresets": { + "fast": { + "model": "openai/gpt-4.1", + "provider": "openai", + "contextWindowTokens": 128000, + }, + "deep": { + "model": "deepseek/deepseek-chat", + "provider": "deepseek", + "contextWindowTokens": 64000, + }, + }, + "providers": { + "openai": {"apiKey": "primary-key"}, + "deepseek": {"apiKey": "fallback-key"}, + }, + }) + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + snapshot = build_provider_snapshot(config) + + assert snapshot.context_window_tokens == 64000 + + +def test_inline_fallback_reasoning_effort_does_not_inherit_primary() -> None: + from nanobot.config.schema import Config + from nanobot.providers.factory import provider_signature + + config = Config.model_validate({ + "agents": { + "defaults": { + "modelPreset": "fast", + "fallbackModels": [ + {"provider": "openai", "model": "gpt-4.1"} + ], + } + }, + "modelPresets": { + "fast": { + "model": "anthropic/claude-opus-4-5", + "provider": "anthropic", + "reasoningEffort": "high", + } + }, + "providers": { + "anthropic": {"apiKey": "primary-key"}, + "openai": {"apiKey": "fallback-key"}, + }, + }) + + signature = provider_signature(config) + fallback_signatures = signature[-1] + + assert fallback_signatures[0][11] is None + + +# -- FallbackProvider tests -- + + +class TestNoFallbackWhenPrimarySucceeds: + @pytest.mark.asyncio + async def test(self) -> None: + primary = _FakeProvider("primary", _make_response("primary ok")) + factory = MagicMock() + fb = FallbackProvider( + primary=primary, + fallback_presets=[_fallback("fallback-a")], + provider_factory=factory, + ) + + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + assert result.content == "primary ok" + assert result.finish_reason == "stop" + factory.assert_not_called() + + +class TestFallbackOnPrimaryError: + @pytest.mark.asyncio + async def test_first_fallback_succeeds(self) -> None: + primary = _FakeProvider("primary", _error_response()) + fallback = _FakeProvider("fallback", _make_response("fallback ok")) + factory = MagicMock(return_value=fallback) + + fb = FallbackProvider( + primary=primary, + fallback_presets=[_fallback("fallback-a")], + provider_factory=factory, + ) + + result = await fb.chat(messages=[{"role": "user", "content": "hi"}], model="primary-model") + assert result.content == "fallback ok" + assert result.finish_reason == "stop" + factory.assert_called_once_with(_fallback("fallback-a")) + assert primary.chat_calls[0]["model"] == "primary-model" + assert fallback.chat_calls[0]["model"] == "fallback-a" + + +class TestNoFallbackWhenContentStreamed: + @pytest.mark.asyncio + async def test(self) -> None: + primary = _FakeProvider("primary", _error_response()) + factory = MagicMock() + fb = FallbackProvider( + primary=primary, + fallback_presets=[_fallback("fallback-a")], + provider_factory=factory, + ) + + async def _delta(text: str) -> None: + pass + + result = await fb.chat_stream( + messages=[{"role": "user", "content": "hi"}], + on_content_delta=_delta, + ) + # Primary returns error but content was "streamed" (FakeProvider calls delta) + # so failover should be skipped + assert result.finish_reason == "error" + factory.assert_not_called() + + +class TestFailoverOnTransientError: + @pytest.mark.asyncio + async def test_rate_limit(self) -> None: + primary = _FakeProvider("primary", _error_response("rate limit exceeded")) + fallback = _FakeProvider("fallback", _make_response("fallback ok")) + factory = MagicMock(return_value=fallback) + fb = FallbackProvider( + primary=primary, + fallback_presets=[_fallback("fallback-a")], + provider_factory=factory, + ) + + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + assert result.content == "fallback ok" + assert result.finish_reason == "stop" + factory.assert_called_once_with(_fallback("fallback-a")) + + +class TestNoFallbackOnNonRetryableError: + @pytest.mark.asyncio + async def test_bad_request(self) -> None: + primary = _FakeProvider( + "primary", + _make_response( + "invalid request", + finish_reason="error", + error_status_code=400, + error_kind="invalid_request", + ), + ) + factory = MagicMock() + fb = FallbackProvider( + primary=primary, + fallback_presets=[_fallback("fallback-a")], + provider_factory=factory, + ) + + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + + assert result.finish_reason == "error" + factory.assert_not_called() + + @pytest.mark.asyncio + async def test_auth_error(self) -> None: + primary = _FakeProvider( + "primary", + _make_response( + "unauthorized", + finish_reason="error", + error_status_code=401, + error_kind="authentication", + ), + ) + factory = MagicMock() + fb = FallbackProvider( + primary=primary, + fallback_presets=[_fallback("fallback-a")], + provider_factory=factory, + ) + + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + + assert result.finish_reason == "error" + factory.assert_not_called() + + @pytest.mark.asyncio + async def test_timeout(self) -> None: + primary = _FakeProvider( + "primary", + _make_response("timed out", finish_reason="error", error_kind="timeout"), + ) + fallback = _FakeProvider("fallback", _make_response("fallback ok")) + factory = MagicMock(return_value=fallback) + fb = FallbackProvider( + primary=primary, + fallback_presets=[_fallback("fallback-a")], + provider_factory=factory, + ) + + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + assert result.content == "fallback ok" + assert result.finish_reason == "stop" + factory.assert_called_once_with(_fallback("fallback-a")) + + +class TestFallbackTriesModelsInOrder: + @pytest.mark.asyncio + async def test(self) -> None: + primary = _FakeProvider("primary", _error_response("primary fail")) + fallback_a = _FakeProvider("a", _error_response("a fail")) + fallback_b = _FakeProvider("b", _make_response("b ok")) + factory = MagicMock(side_effect=[fallback_a, fallback_b]) + + fb = FallbackProvider( + primary=primary, + fallback_presets=[_fallback("fallback-a"), _fallback("fallback-b")], + provider_factory=factory, + ) + + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + assert result.content == "b ok" + assert factory.call_count == 2 + factory.assert_any_call(_fallback("fallback-a")) + factory.assert_any_call(_fallback("fallback-b")) + + +class TestAllFallbacksFail: + @pytest.mark.asyncio + async def test(self) -> None: + primary = _FakeProvider("primary", _error_response("primary fail")) + fallback = _FakeProvider("fallback", _error_response("all fail")) + factory = MagicMock(return_value=fallback) + + fb = FallbackProvider( + primary=primary, + fallback_presets=[_fallback("fallback-a")], + provider_factory=factory, + ) + + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + assert result.finish_reason == "error" + assert "all fail" in result.content + + +class TestFactoryExceptionSkipsModel: + @pytest.mark.asyncio + async def test(self) -> None: + primary = _FakeProvider("primary", _error_response()) + fallback_b = _FakeProvider("b", _make_response("b ok")) + factory = MagicMock(side_effect=[ValueError("no key"), fallback_b]) + + fb = FallbackProvider( + primary=primary, + fallback_presets=[_fallback("fallback-a"), _fallback("fallback-b")], + provider_factory=factory, + ) + + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + assert result.content == "b ok" + assert factory.call_count == 2 + + +class TestFallbackModelParameter: + @pytest.mark.asyncio + async def test(self) -> None: + """Fallback calls should use the fallback model name.""" + primary = _FakeProvider("primary", _error_response()) + fallback = _FakeProvider("fallback", _make_response("ok")) + factory = MagicMock(return_value=fallback) + + fb = FallbackProvider( + primary=primary, + fallback_presets=[_fallback("fallback-model")], + provider_factory=factory, + ) + + await fb.chat(messages=[{"role": "user", "content": "hi"}], model="primary-model") + assert fallback.chat_calls[0]["model"] == "fallback-model" + + @pytest.mark.asyncio + async def test_uses_fallback_generation_fields(self) -> None: + primary = _FakeProvider("primary", _error_response()) + fallback = _FakeProvider("fallback", _make_response("ok")) + fb = FallbackProvider( + primary=primary, + fallback_presets=[ + _fallback( + "fallback-model", + max_tokens=1234, + temperature=0.4, + reasoning_effort=None, + ) + ], + provider_factory=MagicMock(return_value=fallback), + ) + + await fb.chat( + messages=[{"role": "user", "content": "hi"}], + model="primary-model", + max_tokens=8192, + temperature=0.1, + reasoning_effort="high", + ) + + assert fallback.chat_calls[0]["model"] == "fallback-model" + assert fallback.chat_calls[0]["max_tokens"] == 1234 + assert fallback.chat_calls[0]["temperature"] == 0.4 + assert "reasoning_effort" not in fallback.chat_calls[0] + + +class TestNoFallbackWhenEmptyList: + @pytest.mark.asyncio + async def test(self) -> None: + primary = _FakeProvider("primary", _error_response()) + factory = MagicMock() + + fb = FallbackProvider( + primary=primary, + fallback_presets=[], + provider_factory=factory, + ) + + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + assert result.finish_reason == "error" + factory.assert_not_called() + + +class TestChatStreamFailover: + @pytest.mark.asyncio + async def test_fallback_succeeds(self) -> None: + # Use empty content so on_content_delta is not triggered on the error + primary = _FakeProvider("primary", _error_response("")) + fallback = _FakeProvider("fallback", _make_response("stream ok")) + factory = MagicMock(return_value=fallback) + + fb = FallbackProvider( + primary=primary, + fallback_presets=[_fallback("fallback-a")], + provider_factory=factory, + ) + + result = await fb.chat_stream(messages=[{"role": "user", "content": "hi"}]) + assert result.content == "stream ok" + assert result.finish_reason == "stop" + + +class TestGetDefaultModel: + def test(self) -> None: + primary = _FakeProvider("primary") + fb = FallbackProvider( + primary=primary, + fallback_presets=[_fallback("a")], + provider_factory=MagicMock(), + ) + assert fb.get_default_model() == "primary/model" + + +class TestCircuitBreaker: + @pytest.mark.asyncio + async def test_skips_primary_after_three_failures(self) -> None: + primary = _FakeProvider("primary", _error_response()) + fallback = _FakeProvider("fallback", _make_response("fallback ok")) + factory = MagicMock(return_value=fallback) + fb = FallbackProvider( + primary=primary, + fallback_presets=[_fallback("fallback-a")], + provider_factory=factory, + ) + + # 3 failures — primary should still be called each time + for _ in range(3): + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + assert result.content == "fallback ok" + + assert len(primary.chat_calls) == 3 + + # 4th call — primary circuit is open, should be skipped + primary.chat_calls.clear() + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + assert result.content == "fallback ok" + assert len(primary.chat_calls) == 0 + + @pytest.mark.asyncio + async def test_resets_on_success(self) -> None: + primary = _FakeProvider("primary", _error_response()) + fallback = _FakeProvider("fallback", _make_response("fallback ok")) + factory = MagicMock(return_value=fallback) + fb = FallbackProvider( + primary=primary, + fallback_presets=[_fallback("fallback-a")], + provider_factory=factory, + ) + + # 2 failures + for _ in range(2): + await fb.chat(messages=[{"role": "user", "content": "hi"}]) + + # 3rd call: primary succeeds — circuit resets + primary._response = _make_response("primary ok") + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + assert result.content == "primary ok" + + # 4th call: primary fails again — should still be called (counter reset) + primary._response = _error_response() + primary.chat_calls.clear() + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + assert result.content == "fallback ok" + assert len(primary.chat_calls) == 1 + + +class TestGenerationForwarded: + def test(self) -> None: + from nanobot.providers.base import GenerationSettings + primary = _FakeProvider("primary") + primary.generation = GenerationSettings(temperature=0.5, max_tokens=1024) + fb = FallbackProvider( + primary=primary, + fallback_presets=[_fallback("a")], + provider_factory=MagicMock(), + ) + assert fb.generation.temperature == 0.5 + assert fb.generation.max_tokens == 1024