From 913b0774d864bf575c1f561f764930574a23d9ab Mon Sep 17 00:00:00 2001 From: chengyongru Date: Tue, 12 May 2026 16:51:48 +0800 Subject: [PATCH 1/4] feat(runner): add model failover with fallback_models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When the primary model returns a non-transient error and no content has been streamed yet, the runner now tries each model listed in the active preset's fallback_models in order. Each fallback model may reside on a different provider — a temporary provider instance is created on-the-fly via make_provider(config, model=...). Key design: - Failover is request-scoped (does not affect subagents/dream/consolidator) - Provider is restored via try/finally after each fallback attempt - Skipped when content was already streamed to avoid duplicate output - Recursive failover prevented by clearing fallback_models on fallback spec - Circuit breaker trips open after 3 consecutive primary failures (60s cooldown) - Cross-provider routing: fallback model prefix (e.g. groq/) determines provider Fixes: cross-provider fallback was broken because the factory passed the original preset (with provider forced to primary's provider) when creating fallback providers. Now uses provider="auto" so the model string prefix correctly routes to the right provider. Also fixes: log messages now distinguish between primary-failed, previous-fallback-failed, and circuit-open scenarios. closes: https://github.com/HKUDS/nanobot/issues/3376 --- nanobot/config/schema.py | 1 + nanobot/providers/factory.py | 36 ++- nanobot/providers/fallback_provider.py | 186 +++++++++++++ tests/agent/test_runner_fallback.py | 364 +++++++++++++++++++++++++ 4 files changed, 584 insertions(+), 3 deletions(-) create mode 100644 nanobot/providers/fallback_provider.py create mode 100644 tests/agent/test_runner_fallback.py diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 0f1f06c69..1cab02763 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -82,6 +82,7 @@ class ModelPresetConfig(Base): context_window_tokens: int = 65_536 temperature: float = 0.1 reasoning_effort: str | None = None + fallback_models: list[str] = Field(default_factory=list) def to_generation_settings(self) -> Any: from nanobot.providers.base import GenerationSettings diff --git a/nanobot/providers/factory.py b/nanobot/providers/factory.py index 3473afff3..e4822b7f8 100644 --- a/nanobot/providers/factory.py +++ b/nanobot/providers/factory.py @@ -7,6 +7,7 @@ from pathlib import Path from nanobot.config.schema import Config, ModelPresetConfig from nanobot.providers.base import LLMProvider +from nanobot.providers.fallback_provider import FallbackProvider from nanobot.providers.registry import find_by_name @@ -27,15 +28,16 @@ def _resolve_model_preset( return preset if preset is not None else config.resolve_preset(preset_name) -def make_provider( +def _make_provider_core( config: Config, *, preset_name: str | None = None, preset: ModelPresetConfig | None = None, + model: str | None = None, ) -> LLMProvider: - """Create the LLM provider implied by config.""" + """Create a plain LLM provider without failover wrapping.""" resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset) - model = resolved.model + model = model or resolved.model provider_name = config.get_provider_name(model, preset=resolved) p = config.get_provider(model, preset=resolved) spec = find_by_name(provider_name) if provider_name else None @@ -102,6 +104,34 @@ def make_provider( return provider +def make_provider( + config: Config, + *, + preset_name: str | None = None, + preset: ModelPresetConfig | None = None, + model: str | None = None, +) -> LLMProvider: + """Create the LLM provider implied by config. + + When *model* is given, it overrides the resolved/preset model — used by + the failover path to create providers for fallback models. + """ + resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset) + provider = _make_provider_core(config, preset_name=preset_name, preset=preset, model=model) + + if resolved.fallback_models: + fb_preset = resolved.model_copy(update={"provider": "auto", "fallback_models": []}) + provider = FallbackProvider( + primary=provider, + fallback_models=resolved.fallback_models, + provider_factory=lambda m: _make_provider_core( + config, preset_name=preset_name, preset=fb_preset, model=m + ), + ) + + return provider + + def provider_signature( config: Config, *, diff --git a/nanobot/providers/fallback_provider.py b/nanobot/providers/fallback_provider.py new file mode 100644 index 000000000..c0b137890 --- /dev/null +++ b/nanobot/providers/fallback_provider.py @@ -0,0 +1,186 @@ +"""Provider wrapper that transparently fails over to fallback models on error.""" + +from __future__ import annotations + +import time +from collections.abc import Awaitable, Callable +from typing import Any + +from loguru import logger + +from nanobot.providers.base import LLMProvider, LLMResponse + +# Circuit breaker tuned to match OpenAICompatProvider's Responses API breaker. +_PRIMARY_FAILURE_THRESHOLD = 3 +_PRIMARY_COOLDOWN_S = 60 + + +class FallbackProvider(LLMProvider): + """Wrap a primary provider and transparently failover to fallback models. + + When the primary model returns an error and no content has been streamed yet, + the wrapper tries each fallback model in order. Each fallback model may + reside on a different provider — a factory callable creates the underlying + provider on-the-fly. + + Key design: + - Failover is request-scoped (the wrapper itself is stateless between turns). + - Skipped when content was already streamed to avoid duplicate output. + - Recursive failover is prevented by the factory returning plain providers. + - Primary provider is circuit-broken after repeated failures to avoid + wasting requests on a known-bad endpoint. + """ + + def __init__( + self, + primary: LLMProvider, + fallback_models: list[str], + provider_factory: Callable[[str], LLMProvider], + ): + self._primary = primary + self._fallback_models = list(fallback_models) + self._provider_factory = provider_factory + self._has_fallbacks = bool(fallback_models) + self._primary_failures = 0 + self._primary_tripped_at: float | None = None + + @property + def generation(self): + return self._primary.generation + + @generation.setter + def generation(self, value): + self._primary.generation = value + + def get_default_model(self) -> str: + return self._primary.get_default_model() + + def _primary_available(self) -> bool: + """Return True if the primary provider is not currently tripped.""" + if self._primary_tripped_at is None: + return True + if time.monotonic() - self._primary_tripped_at >= _PRIMARY_COOLDOWN_S: + # Half-open: allow one probe attempt. + return True + return False + + async def chat(self, **kwargs: Any) -> LLMResponse: + if not self._has_fallbacks: + return await self._primary.chat(**kwargs) + return await self._try_with_fallback( + lambda p, kw: p.chat(**kw), kwargs, has_streamed=None + ) + + async def chat_stream(self, **kwargs: Any) -> LLMResponse: + if not self._has_fallbacks: + return await self._primary.chat_stream(**kwargs) + + has_streamed: list[bool] = [False] + original_delta = kwargs.get("on_content_delta") + + async def _tracking_delta(text: str) -> None: + if text: + has_streamed[0] = True + if original_delta: + await original_delta(text) + + kwargs["on_content_delta"] = _tracking_delta + return await self._try_with_fallback( + lambda p, kw: p.chat_stream(**kw), kwargs, has_streamed=has_streamed + ) + + async def _try_with_fallback( + self, + call: Callable[[LLMProvider, dict[str, Any]], Awaitable[LLMResponse]], + kwargs: dict[str, Any], + has_streamed: list[bool] | None, + ) -> LLMResponse: + primary_model = kwargs.get("model") or self._primary.get_default_model() + + if self._primary_available(): + response = await call(self._primary, kwargs) + if response.finish_reason != "error": + self._primary_failures = 0 + self._primary_tripped_at = None + return response + + if has_streamed is not None and has_streamed[0]: + logger.warning( + "Primary model error but content already streamed; skipping failover" + ) + return response + + self._primary_failures += 1 + if self._primary_failures >= _PRIMARY_FAILURE_THRESHOLD: + self._primary_tripped_at = time.monotonic() + logger.warning( + "Primary model '{}' circuit open after {} consecutive failures", + primary_model, self._primary_failures, + ) + else: + logger.debug("Primary model '{}' circuit open; skipping", primary_model) + + last_response: LLMResponse | None = None + primary_skipped = not self._primary_available() + for idx, fallback_model in enumerate(self._fallback_models): + if has_streamed is not None and has_streamed[0]: + break + if idx == 0 and primary_skipped: + logger.info( + "Primary model '{}' circuit open, trying fallback '{}'", + primary_model, fallback_model, + ) + elif idx == 0: + logger.info( + "Primary model '{}' failed, trying fallback '{}'", + primary_model, fallback_model, + ) + else: + logger.info( + "Fallback '{}' also failed, trying next fallback '{}'", + self._fallback_models[idx - 1], fallback_model, + ) + try: + fallback_provider = self._provider_factory(fallback_model) + except Exception as exc: + logger.warning( + "Failed to create provider for fallback '{}': {}", fallback_model, exc + ) + continue + + original_model = kwargs.get("model") + kwargs["model"] = fallback_model + try: + fallback_response = await call(fallback_provider, kwargs) + finally: + if original_model is not None: + kwargs["model"] = original_model + else: + kwargs.pop("model", None) + + if fallback_response.finish_reason != "error": + logger.info( + "Fallback '{}' succeeded after primary '{}' failed", + fallback_model, primary_model, + ) + return fallback_response + + last_response = fallback_response + logger.warning( + "Fallback '{}' also failed: {}", + fallback_model, + (fallback_response.content or "")[:120], + ) + + logger.warning( + "All {} fallback model(s) failed", + len(self._fallback_models), + ) + # Return the last error response we saw (primary or last fallback). + if last_response is not None: + return last_response + # Primary was tripped and we have no fallbacks — synthesize an error. + return LLMResponse( + content=f"Primary model '{primary_model}' circuit open and no fallbacks available", + finish_reason="error", + ) diff --git a/tests/agent/test_runner_fallback.py b/tests/agent/test_runner_fallback.py new file mode 100644 index 000000000..273bd6d6d --- /dev/null +++ b/tests/agent/test_runner_fallback.py @@ -0,0 +1,364 @@ +"""Tests for FallbackProvider model failover.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from nanobot.providers.base import LLMProvider, LLMResponse +from nanobot.providers.fallback_provider import FallbackProvider + + +def _make_response( + content: str = "ok", + finish_reason: str = "stop", + *, + error_kind: str | None = None, +) -> LLMResponse: + return LLMResponse(content=content, finish_reason=finish_reason, error_kind=error_kind) + + +def _error_response(content: str = "api error") -> LLMResponse: + return _make_response(content, finish_reason="error", error_kind="server_error") + + +class _FakeProvider(LLMProvider): + """Fake provider for testing.""" + + def __init__(self, name: str = "fake", response: LLMResponse | None = None): + super().__init__() + self.name = name + self._response = response or _make_response() + self.chat_calls: list[dict[str, Any]] = [] + self.chat_stream_calls: list[dict[str, Any]] = [] + + def get_default_model(self) -> str: + return f"{self.name}/model" + + async def chat(self, **kwargs: Any) -> LLMResponse: + self.chat_calls.append(dict(kwargs)) + return self._response + + async def chat_stream(self, **kwargs: Any) -> LLMResponse: + self.chat_stream_calls.append(dict(kwargs)) + on_delta = kwargs.get("on_content_delta") + if on_delta and self._response.content: + await on_delta(self._response.content) + return self._response + + +# -- config-level tests -- + + +def test_fallback_models_default_empty() -> None: + from nanobot.config.schema import ModelPresetConfig + p = ModelPresetConfig(model="test/model") + assert p.fallback_models == [] + + +def test_fallback_models_accepts_list() -> None: + from nanobot.config.schema import ModelPresetConfig + p = ModelPresetConfig(model="test/primary", fallback_models=["test/a", "test/b"]) + assert p.fallback_models == ["test/a", "test/b"] + + +def test_fallback_models_from_camel_case() -> None: + from nanobot.config.schema import ModelPresetConfig + p = ModelPresetConfig.model_validate({ + "model": "test/primary", + "fallbackModels": ["test/a"], + }) + assert p.fallback_models == ["test/a"] + + +# -- FallbackProvider tests -- + + +class TestNoFallbackWhenPrimarySucceeds: + @pytest.mark.asyncio + async def test(self) -> None: + primary = _FakeProvider("primary", _make_response("primary ok")) + factory = MagicMock() + fb = FallbackProvider( + primary=primary, + fallback_models=["fallback-a"], + provider_factory=factory, + ) + + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + assert result.content == "primary ok" + assert result.finish_reason == "stop" + factory.assert_not_called() + + +class TestFallbackOnPrimaryError: + @pytest.mark.asyncio + async def test_first_fallback_succeeds(self) -> None: + primary = _FakeProvider("primary", _error_response()) + fallback = _FakeProvider("fallback", _make_response("fallback ok")) + factory = MagicMock(return_value=fallback) + + fb = FallbackProvider( + primary=primary, + fallback_models=["fallback-a"], + provider_factory=factory, + ) + + result = await fb.chat(messages=[{"role": "user", "content": "hi"}], model="primary-model") + assert result.content == "fallback ok" + assert result.finish_reason == "stop" + factory.assert_called_once_with("fallback-a") + assert primary.chat_calls[0]["model"] == "primary-model" + assert fallback.chat_calls[0]["model"] == "fallback-a" + + +class TestNoFallbackWhenContentStreamed: + @pytest.mark.asyncio + async def test(self) -> None: + primary = _FakeProvider("primary", _error_response()) + factory = MagicMock() + fb = FallbackProvider( + primary=primary, + fallback_models=["fallback-a"], + provider_factory=factory, + ) + + async def _delta(text: str) -> None: + pass + + result = await fb.chat_stream( + messages=[{"role": "user", "content": "hi"}], + on_content_delta=_delta, + ) + # Primary returns error but content was "streamed" (FakeProvider calls delta) + # so failover should be skipped + assert result.finish_reason == "error" + factory.assert_not_called() + + +class TestFailoverOnTransientError: + @pytest.mark.asyncio + async def test_rate_limit(self) -> None: + primary = _FakeProvider("primary", _error_response("rate limit exceeded")) + fallback = _FakeProvider("fallback", _make_response("fallback ok")) + factory = MagicMock(return_value=fallback) + fb = FallbackProvider( + primary=primary, + fallback_models=["fallback-a"], + provider_factory=factory, + ) + + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + assert result.content == "fallback ok" + assert result.finish_reason == "stop" + factory.assert_called_once_with("fallback-a") + + @pytest.mark.asyncio + async def test_timeout(self) -> None: + primary = _FakeProvider( + "primary", + _make_response("timed out", finish_reason="error", error_kind="timeout"), + ) + fallback = _FakeProvider("fallback", _make_response("fallback ok")) + factory = MagicMock(return_value=fallback) + fb = FallbackProvider( + primary=primary, + fallback_models=["fallback-a"], + provider_factory=factory, + ) + + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + assert result.content == "fallback ok" + assert result.finish_reason == "stop" + factory.assert_called_once_with("fallback-a") + + +class TestFallbackTriesModelsInOrder: + @pytest.mark.asyncio + async def test(self) -> None: + primary = _FakeProvider("primary", _error_response("primary fail")) + fallback_a = _FakeProvider("a", _error_response("a fail")) + fallback_b = _FakeProvider("b", _make_response("b ok")) + factory = MagicMock(side_effect=[fallback_a, fallback_b]) + + fb = FallbackProvider( + primary=primary, + fallback_models=["fallback-a", "fallback-b"], + provider_factory=factory, + ) + + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + assert result.content == "b ok" + assert factory.call_count == 2 + factory.assert_any_call("fallback-a") + factory.assert_any_call("fallback-b") + + +class TestAllFallbacksFail: + @pytest.mark.asyncio + async def test(self) -> None: + primary = _FakeProvider("primary", _error_response("primary fail")) + fallback = _FakeProvider("fallback", _error_response("all fail")) + factory = MagicMock(return_value=fallback) + + fb = FallbackProvider( + primary=primary, + fallback_models=["fallback-a"], + provider_factory=factory, + ) + + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + assert result.finish_reason == "error" + assert "all fail" in result.content + + +class TestFactoryExceptionSkipsModel: + @pytest.mark.asyncio + async def test(self) -> None: + primary = _FakeProvider("primary", _error_response()) + fallback_b = _FakeProvider("b", _make_response("b ok")) + factory = MagicMock(side_effect=[ValueError("no key"), fallback_b]) + + fb = FallbackProvider( + primary=primary, + fallback_models=["fallback-a", "fallback-b"], + provider_factory=factory, + ) + + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + assert result.content == "b ok" + assert factory.call_count == 2 + + +class TestFallbackModelParameter: + @pytest.mark.asyncio + async def test(self) -> None: + """Fallback calls should use the fallback model name.""" + primary = _FakeProvider("primary", _error_response()) + fallback = _FakeProvider("fallback", _make_response("ok")) + factory = MagicMock(return_value=fallback) + + fb = FallbackProvider( + primary=primary, + fallback_models=["fallback-model"], + provider_factory=factory, + ) + + await fb.chat(messages=[{"role": "user", "content": "hi"}], model="primary-model") + assert fallback.chat_calls[0]["model"] == "fallback-model" + + +class TestNoFallbackWhenEmptyList: + @pytest.mark.asyncio + async def test(self) -> None: + primary = _FakeProvider("primary", _error_response()) + factory = MagicMock() + + fb = FallbackProvider( + primary=primary, + fallback_models=[], + provider_factory=factory, + ) + + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + assert result.finish_reason == "error" + factory.assert_not_called() + + +class TestChatStreamFailover: + @pytest.mark.asyncio + async def test_fallback_succeeds(self) -> None: + # Use empty content so on_content_delta is not triggered on the error + primary = _FakeProvider("primary", _error_response("")) + fallback = _FakeProvider("fallback", _make_response("stream ok")) + factory = MagicMock(return_value=fallback) + + fb = FallbackProvider( + primary=primary, + fallback_models=["fallback-a"], + provider_factory=factory, + ) + + result = await fb.chat_stream(messages=[{"role": "user", "content": "hi"}]) + assert result.content == "stream ok" + assert result.finish_reason == "stop" + + +class TestGetDefaultModel: + def test(self) -> None: + primary = _FakeProvider("primary") + fb = FallbackProvider( + primary=primary, + fallback_models=["a"], + provider_factory=MagicMock(), + ) + assert fb.get_default_model() == "primary/model" + + +class TestCircuitBreaker: + @pytest.mark.asyncio + async def test_skips_primary_after_three_failures(self) -> None: + primary = _FakeProvider("primary", _error_response()) + fallback = _FakeProvider("fallback", _make_response("fallback ok")) + factory = MagicMock(return_value=fallback) + fb = FallbackProvider( + primary=primary, + fallback_models=["fallback-a"], + provider_factory=factory, + ) + + # 3 failures — primary should still be called each time + for _ in range(3): + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + assert result.content == "fallback ok" + + assert len(primary.chat_calls) == 3 + + # 4th call — primary circuit is open, should be skipped + primary.chat_calls.clear() + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + assert result.content == "fallback ok" + assert len(primary.chat_calls) == 0 + + @pytest.mark.asyncio + async def test_resets_on_success(self) -> None: + primary = _FakeProvider("primary", _error_response()) + fallback = _FakeProvider("fallback", _make_response("fallback ok")) + factory = MagicMock(return_value=fallback) + fb = FallbackProvider( + primary=primary, + fallback_models=["fallback-a"], + provider_factory=factory, + ) + + # 2 failures + for _ in range(2): + await fb.chat(messages=[{"role": "user", "content": "hi"}]) + + # 3rd call: primary succeeds — circuit resets + primary._response = _make_response("primary ok") + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + assert result.content == "primary ok" + + # 4th call: primary fails again — should still be called (counter reset) + primary._response = _error_response() + primary.chat_calls.clear() + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + assert result.content == "fallback ok" + assert len(primary.chat_calls) == 1 + + +class TestGenerationForwarded: + def test(self) -> None: + from nanobot.providers.base import GenerationSettings + primary = _FakeProvider("primary") + primary.generation = GenerationSettings(temperature=0.5, max_tokens=1024) + fb = FallbackProvider( + primary=primary, + fallback_models=["a"], + provider_factory=MagicMock(), + ) + assert fb.generation.temperature == 0.5 + assert fb.generation.max_tokens == 1024 From 02b059a616dc6dc82ad15282102c7b27a5a34e40 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Wed, 13 May 2026 13:57:30 +0000 Subject: [PATCH 2/4] feat(runner): support structured fallback models Bind fallback model chains to the active model configuration so defaults and presets do not inherit or merge fallback behavior implicitly. Require explicit fallback providers while preserving per-fallback generation overrides and context-window safety. Co-authored-by: Cursor --- docs/configuration.md | 62 +++++++- nanobot/config/schema.py | 15 +- nanobot/providers/factory.py | 61 +++++++- nanobot/providers/fallback_provider.py | 37 +++-- tests/agent/test_runner_fallback.py | 192 ++++++++++++++++++++++--- 5 files changed, 325 insertions(+), 42 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 0123017d2..e208212cf 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -672,6 +672,12 @@ Existing configs do not need to change. If you do not set `modelPresets` or `age "maxTokens": 8192, "contextWindowTokens": 128000, "temperature": 0.1, + "fallbackModels": [ + { + "provider": "anthropic", + "model": "anthropic/claude-sonnet-4-6" + } + ], "modelPreset": null } }, @@ -682,7 +688,17 @@ Existing configs do not need to change. If you do not set `modelPresets` or `age "maxTokens": 4096, "contextWindowTokens": 128000, "temperature": 0.2, - "reasoningEffort": "low" + "reasoningEffort": "low", + "fallbackModels": [ + { + "provider": "deepseek", + "model": "deepseek/deepseek-chat", + "maxTokens": 4096, + "contextWindowTokens": 64000, + "temperature": 0.1, + "reasoningEffort": null + } + ] }, "deep": { "model": "anthropic/claude-opus-4-5", @@ -705,9 +721,53 @@ Existing configs do not need to change. If you do not set `modelPresets` or `age | `contextWindowTokens` | Context window size used by prompt building and consolidation decisions. | | `temperature` | Sampling temperature. | | `reasoningEffort` | Optional reasoning/thinking setting. Provider support varies. | +| `fallbackModels` | Optional ordered fallback models for this active configuration only. | `default` is reserved and always means the implicit preset built from `agents.defaults.*`; do not define `modelPresets.default`. Use `/model default` to switch back to `agents.defaults.*`. +### Model Fallbacks + +`fallbackModels` belongs to the currently active model configuration. If the active configuration is `agents.defaults`, only `agents.defaults.fallbackModels` is used. If the active configuration is `modelPresets.fast`, only `modelPresets.fast.fallbackModels` is used. nanobot does not inherit or merge fallbacks between defaults and presets. + +Each fallback entry must include at least `provider` and `model`. The other fields are optional; omitted values inherit from the active primary configuration for that request. + +```json +{ + "modelPresets": { + "fast": { + "model": "MiniMax-M2.7-highspeed", + "provider": "minimaxAnthropic", + "maxTokens": 4096, + "contextWindowTokens": 262144, + "temperature": 0.1, + "reasoningEffort": null, + "fallbackModels": [ + { + "provider": "deepseek", + "model": "deepseek-v4-pro", + "maxTokens": 4096, + "contextWindowTokens": 262144, + "temperature": 0.1, + "reasoningEffort": null + } + ] + }, + "deep": { + "model": "deepseek-v4-pro", + "provider": "deepseek", + "maxTokens": 4096, + "contextWindowTokens": 262144, + "temperature": 0.1, + "reasoningEffort": null + } + } +} +``` + +In this example, `/model fast` can fail over to DeepSeek, but `/model deep` has no fallback because the `deep` preset does not define `fallbackModels`. + +Failover only runs when the primary model returns an error before any answer text has been streamed. Fallback models are tried in order. If a fallback has a smaller `contextWindowTokens`, nanobot uses the smallest window in the active chain when building context so the fallback can receive the same prompt. + Set `agents.defaults.modelPreset` to start with a named preset: ```json diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index a112b932d..bdae26008 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -74,6 +74,17 @@ class DreamConfig(Base): return f"every {hours}h" +class ModelFallbackConfig(Base): + """A fallback model tied to one active model configuration.""" + + model: str + provider: str + max_tokens: int | None = None + context_window_tokens: int | None = None + temperature: float | None = None + reasoning_effort: str | None = None + + class ModelPresetConfig(Base): """A named set of model + generation parameters for quick switching.""" @@ -83,7 +94,7 @@ class ModelPresetConfig(Base): context_window_tokens: int = 65_536 temperature: float = 0.1 reasoning_effort: str | None = None - fallback_models: list[str] = Field(default_factory=list) + fallback_models: list[ModelFallbackConfig] = Field(default_factory=list) def to_generation_settings(self) -> Any: from nanobot.providers.base import GenerationSettings @@ -107,6 +118,7 @@ class AgentDefaults(Base): context_window_tokens: int = 65_536 context_block_limit: int | None = None temperature: float = 0.1 + fallback_models: list[ModelFallbackConfig] = Field(default_factory=list) max_tool_iterations: int = 200 max_concurrent_subagents: int = Field(default=1, ge=1) max_tool_result_chars: int = 16_000 @@ -297,6 +309,7 @@ class Config(BaseSettings): model=d.model, provider=d.provider, max_tokens=d.max_tokens, context_window_tokens=d.context_window_tokens, temperature=d.temperature, reasoning_effort=d.reasoning_effort, + fallback_models=d.fallback_models, ) def resolve_preset(self, name: str | None = None) -> ModelPresetConfig: diff --git a/nanobot/providers/factory.py b/nanobot/providers/factory.py index e4822b7f8..a3ae57daf 100644 --- a/nanobot/providers/factory.py +++ b/nanobot/providers/factory.py @@ -5,7 +5,7 @@ from __future__ import annotations from dataclasses import dataclass from pathlib import Path -from nanobot.config.schema import Config, ModelPresetConfig +from nanobot.config.schema import Config, ModelFallbackConfig, ModelPresetConfig from nanobot.providers.base import LLMProvider from nanobot.providers.fallback_provider import FallbackProvider from nanobot.providers.registry import find_by_name @@ -104,6 +104,28 @@ def _make_provider_core( return provider +def _fallback_preset(primary: ModelPresetConfig, fallback: ModelFallbackConfig) -> ModelPresetConfig: + """Build the effective provider/generation config for one fallback model.""" + return ModelPresetConfig( + model=fallback.model, + provider=fallback.provider, + max_tokens=fallback.max_tokens if fallback.max_tokens is not None else primary.max_tokens, + context_window_tokens=( + fallback.context_window_tokens + if fallback.context_window_tokens is not None + else primary.context_window_tokens + ), + temperature=( + fallback.temperature if fallback.temperature is not None else primary.temperature + ), + reasoning_effort=( + fallback.reasoning_effort + if fallback.reasoning_effort is not None + else primary.reasoning_effort + ), + ) + + def make_provider( config: Config, *, @@ -120,12 +142,11 @@ def make_provider( provider = _make_provider_core(config, preset_name=preset_name, preset=preset, model=model) if resolved.fallback_models: - fb_preset = resolved.model_copy(update={"provider": "auto", "fallback_models": []}) provider = FallbackProvider( primary=provider, fallback_models=resolved.fallback_models, - provider_factory=lambda m: _make_provider_core( - config, preset_name=preset_name, preset=fb_preset, model=m + provider_factory=lambda fb: _make_provider_core( + config, preset_name=preset_name, preset=_fallback_preset(resolved, fb) ), ) @@ -138,9 +159,32 @@ def provider_signature( preset_name: str | None = None, preset: ModelPresetConfig | None = None, ) -> tuple[object, ...]: - """Return the config fields that affect the primary LLM provider.""" + """Return the config fields that affect the active provider chain.""" resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset) p = config.get_provider(resolved.model, preset=resolved) + + def _fallback_signature(fallback: ModelFallbackConfig) -> tuple[object, ...]: + fallback_preset = _fallback_preset(resolved, fallback) + fp = config.get_provider(fallback.model, preset=fallback_preset) + return ( + fallback.model, + fallback.provider, + fallback_preset.max_tokens, + fallback_preset.temperature, + fallback_preset.reasoning_effort, + fallback_preset.context_window_tokens, + config.get_provider_name(fallback.model, preset=fallback_preset), + config.get_api_key(fallback.model, preset=fallback_preset), + config.get_api_base(fallback.model, preset=fallback_preset), + fp.extra_headers if fp else None, + fp.extra_body if fp else None, + getattr(fp, "region", None) if fp else None, + getattr(fp, "profile", None) if fp else None, + ) + + fallback_signatures = tuple( + _fallback_signature(fallback) for fallback in resolved.fallback_models + ) return ( resolved.model, resolved.provider, @@ -155,6 +199,7 @@ def provider_signature( resolved.temperature, resolved.reasoning_effort, resolved.context_window_tokens, + fallback_signatures, ) @@ -165,10 +210,14 @@ def build_provider_snapshot( preset: ModelPresetConfig | None = None, ) -> ProviderSnapshot: resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset) + fallback_windows = [ + _fallback_preset(resolved, fallback).context_window_tokens + for fallback in resolved.fallback_models + ] return ProviderSnapshot( provider=make_provider(config, preset=resolved), model=resolved.model, - context_window_tokens=resolved.context_window_tokens, + context_window_tokens=min([resolved.context_window_tokens, *fallback_windows]), signature=provider_signature(config, preset=resolved), ) diff --git a/nanobot/providers/fallback_provider.py b/nanobot/providers/fallback_provider.py index c0b137890..a62b619a0 100644 --- a/nanobot/providers/fallback_provider.py +++ b/nanobot/providers/fallback_provider.py @@ -24,7 +24,7 @@ class FallbackProvider(LLMProvider): provider on-the-fly. Key design: - - Failover is request-scoped (the wrapper itself is stateless between turns). + - Failover attempts are request-scoped; primary circuit state persists. - Skipped when content was already streamed to avoid duplicate output. - Recursive failover is prevented by the factory returning plain providers. - Primary provider is circuit-broken after repeated failures to avoid @@ -34,8 +34,8 @@ class FallbackProvider(LLMProvider): def __init__( self, primary: LLMProvider, - fallback_models: list[str], - provider_factory: Callable[[str], LLMProvider], + fallback_models: list[Any], + provider_factory: Callable[[Any], LLMProvider], ): self._primary = primary self._fallback_models = list(fallback_models) @@ -52,6 +52,10 @@ class FallbackProvider(LLMProvider): def generation(self, value): self._primary.generation = value + @property + def supports_progress_deltas(self) -> bool: + return bool(getattr(self._primary, "supports_progress_deltas", False)) + def get_default_model(self) -> str: return self._primary.get_default_model() @@ -122,7 +126,8 @@ class FallbackProvider(LLMProvider): last_response: LLMResponse | None = None primary_skipped = not self._primary_available() - for idx, fallback_model in enumerate(self._fallback_models): + for idx, fallback in enumerate(self._fallback_models): + fallback_model = fallback.model if has_streamed is not None and has_streamed[0]: break if idx == 0 and primary_skipped: @@ -138,25 +143,35 @@ class FallbackProvider(LLMProvider): else: logger.info( "Fallback '{}' also failed, trying next fallback '{}'", - self._fallback_models[idx - 1], fallback_model, + self._fallback_models[idx - 1].model, fallback_model, ) try: - fallback_provider = self._provider_factory(fallback_model) + fallback_provider = self._provider_factory(fallback) except Exception as exc: logger.warning( "Failed to create provider for fallback '{}': {}", fallback_model, exc ) continue - original_model = kwargs.get("model") + original_values = { + name: kwargs.get(name, LLMProvider._SENTINEL) + for name in ("model", "max_tokens", "temperature", "reasoning_effort") + } kwargs["model"] = fallback_model + if fallback.max_tokens is not None: + kwargs["max_tokens"] = fallback.max_tokens + if fallback.temperature is not None: + kwargs["temperature"] = fallback.temperature + if fallback.reasoning_effort is not None: + kwargs["reasoning_effort"] = fallback.reasoning_effort try: fallback_response = await call(fallback_provider, kwargs) finally: - if original_model is not None: - kwargs["model"] = original_model - else: - kwargs.pop("model", None) + for name, value in original_values.items(): + if value is LLMProvider._SENTINEL: + kwargs.pop(name, None) + else: + kwargs[name] = value if fallback_response.finish_reason != "error": logger.info( diff --git a/tests/agent/test_runner_fallback.py b/tests/agent/test_runner_fallback.py index 273bd6d6d..e15a29848 100644 --- a/tests/agent/test_runner_fallback.py +++ b/tests/agent/test_runner_fallback.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock import pytest +from nanobot.config.schema import ModelFallbackConfig from nanobot.providers.base import LLMProvider, LLMResponse from nanobot.providers.fallback_provider import FallbackProvider @@ -24,6 +25,25 @@ def _error_response(content: str = "api error") -> LLMResponse: return _make_response(content, finish_reason="error", error_kind="server_error") +def _fallback( + model: str, + provider: str = "fallback", + *, + max_tokens: int | None = None, + context_window_tokens: int | None = None, + temperature: float | None = None, + reasoning_effort: str | None = None, +) -> ModelFallbackConfig: + return ModelFallbackConfig( + model=model, + provider=provider, + max_tokens=max_tokens, + context_window_tokens=context_window_tokens, + temperature=temperature, + reasoning_effort=reasoning_effort, + ) + + class _FakeProvider(LLMProvider): """Fake provider for testing.""" @@ -60,17 +80,113 @@ def test_fallback_models_default_empty() -> None: def test_fallback_models_accepts_list() -> None: from nanobot.config.schema import ModelPresetConfig - p = ModelPresetConfig(model="test/primary", fallback_models=["test/a", "test/b"]) - assert p.fallback_models == ["test/a", "test/b"] + p = ModelPresetConfig( + model="test/primary", + fallback_models=[{"provider": "test", "model": "test/a"}], + ) + assert p.fallback_models == [_fallback("test/a", provider="test")] def test_fallback_models_from_camel_case() -> None: from nanobot.config.schema import ModelPresetConfig p = ModelPresetConfig.model_validate({ "model": "test/primary", - "fallbackModels": ["test/a"], + "fallbackModels": [{"provider": "test", "model": "test/a"}], }) - assert p.fallback_models == ["test/a"] + assert p.fallback_models == [_fallback("test/a", provider="test")] + + +def test_provider_signature_tracks_fallback_models_and_provider_config() -> None: + from nanobot.config.schema import Config + from nanobot.providers.factory import provider_signature + + base = { + "modelPresets": { + "prod": { + "model": "openai/gpt-4.1", + "fallbackModels": [ + {"provider": "anthropic", "model": "anthropic/claude-sonnet-4-6"} + ], + } + }, + "providers": { + "openai": {"apiKey": "primary-key"}, + "anthropic": {"apiKey": "fallback-key"}, + }, + } + changed_fallback = { + **base, + "modelPresets": { + "prod": { + "model": "openai/gpt-4.1", + "fallbackModels": [{"provider": "deepseek", "model": "deepseek/deepseek-chat"}], + } + }, + "providers": { + **base["providers"], + "deepseek": {"apiKey": "deepseek-key"}, + }, + } + changed_key = { + **base, + "providers": { + "openai": {"apiKey": "primary-key"}, + "anthropic": {"apiKey": "new-fallback-key"}, + }, + } + + signature = provider_signature(Config.model_validate(base), preset_name="prod") + + assert signature != provider_signature(Config.model_validate(changed_fallback), preset_name="prod") + assert signature != provider_signature(Config.model_validate(changed_key), preset_name="prod") + + +def test_agent_defaults_can_define_fallback_models() -> None: + from nanobot.config.schema import Config + + config = Config.model_validate({ + "agents": { + "defaults": { + "model": "primary-model", + "provider": "custom", + "fallbackModels": [{"provider": "deepseek", "model": "deepseek-v4-pro"}], + } + } + }) + + assert config.resolve_preset().fallback_models == [ + _fallback("deepseek-v4-pro", provider="deepseek") + ] + + +def test_provider_snapshot_uses_smallest_fallback_context_window() -> None: + from nanobot.config.schema import Config + from nanobot.providers.factory import build_provider_snapshot + + config = Config.model_validate({ + "modelPresets": { + "prod": { + "model": "openai/gpt-4.1", + "provider": "openai", + "contextWindowTokens": 128000, + "fallbackModels": [ + { + "provider": "deepseek", + "model": "deepseek/deepseek-chat", + "contextWindowTokens": 64000, + } + ], + } + }, + "providers": { + "openai": {"apiKey": "primary-key"}, + "deepseek": {"apiKey": "fallback-key"}, + }, + }) + + snapshot = build_provider_snapshot(config, preset_name="prod") + + assert snapshot.context_window_tokens == 64000 # -- FallbackProvider tests -- @@ -83,7 +199,7 @@ class TestNoFallbackWhenPrimarySucceeds: factory = MagicMock() fb = FallbackProvider( primary=primary, - fallback_models=["fallback-a"], + fallback_models=[_fallback("fallback-a")], provider_factory=factory, ) @@ -102,14 +218,14 @@ class TestFallbackOnPrimaryError: fb = FallbackProvider( primary=primary, - fallback_models=["fallback-a"], + fallback_models=[_fallback("fallback-a")], provider_factory=factory, ) result = await fb.chat(messages=[{"role": "user", "content": "hi"}], model="primary-model") assert result.content == "fallback ok" assert result.finish_reason == "stop" - factory.assert_called_once_with("fallback-a") + factory.assert_called_once_with(_fallback("fallback-a")) assert primary.chat_calls[0]["model"] == "primary-model" assert fallback.chat_calls[0]["model"] == "fallback-a" @@ -121,7 +237,7 @@ class TestNoFallbackWhenContentStreamed: factory = MagicMock() fb = FallbackProvider( primary=primary, - fallback_models=["fallback-a"], + fallback_models=[_fallback("fallback-a")], provider_factory=factory, ) @@ -146,14 +262,14 @@ class TestFailoverOnTransientError: factory = MagicMock(return_value=fallback) fb = FallbackProvider( primary=primary, - fallback_models=["fallback-a"], + fallback_models=[_fallback("fallback-a")], provider_factory=factory, ) result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) assert result.content == "fallback ok" assert result.finish_reason == "stop" - factory.assert_called_once_with("fallback-a") + factory.assert_called_once_with(_fallback("fallback-a")) @pytest.mark.asyncio async def test_timeout(self) -> None: @@ -165,14 +281,14 @@ class TestFailoverOnTransientError: factory = MagicMock(return_value=fallback) fb = FallbackProvider( primary=primary, - fallback_models=["fallback-a"], + fallback_models=[_fallback("fallback-a")], provider_factory=factory, ) result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) assert result.content == "fallback ok" assert result.finish_reason == "stop" - factory.assert_called_once_with("fallback-a") + factory.assert_called_once_with(_fallback("fallback-a")) class TestFallbackTriesModelsInOrder: @@ -185,15 +301,15 @@ class TestFallbackTriesModelsInOrder: fb = FallbackProvider( primary=primary, - fallback_models=["fallback-a", "fallback-b"], + fallback_models=[_fallback("fallback-a"), _fallback("fallback-b")], provider_factory=factory, ) result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) assert result.content == "b ok" assert factory.call_count == 2 - factory.assert_any_call("fallback-a") - factory.assert_any_call("fallback-b") + factory.assert_any_call(_fallback("fallback-a")) + factory.assert_any_call(_fallback("fallback-b")) class TestAllFallbacksFail: @@ -205,7 +321,7 @@ class TestAllFallbacksFail: fb = FallbackProvider( primary=primary, - fallback_models=["fallback-a"], + fallback_models=[_fallback("fallback-a")], provider_factory=factory, ) @@ -223,7 +339,7 @@ class TestFactoryExceptionSkipsModel: fb = FallbackProvider( primary=primary, - fallback_models=["fallback-a", "fallback-b"], + fallback_models=[_fallback("fallback-a"), _fallback("fallback-b")], provider_factory=factory, ) @@ -242,13 +358,43 @@ class TestFallbackModelParameter: fb = FallbackProvider( primary=primary, - fallback_models=["fallback-model"], + fallback_models=[_fallback("fallback-model")], provider_factory=factory, ) await fb.chat(messages=[{"role": "user", "content": "hi"}], model="primary-model") assert fallback.chat_calls[0]["model"] == "fallback-model" + @pytest.mark.asyncio + async def test_overrides_generation_fields_when_configured(self) -> None: + primary = _FakeProvider("primary", _error_response()) + fallback = _FakeProvider("fallback", _make_response("ok")) + fb = FallbackProvider( + primary=primary, + fallback_models=[ + _fallback( + "fallback-model", + max_tokens=1234, + temperature=0.4, + reasoning_effort="low", + ) + ], + provider_factory=MagicMock(return_value=fallback), + ) + + await fb.chat( + messages=[{"role": "user", "content": "hi"}], + model="primary-model", + max_tokens=8192, + temperature=0.1, + reasoning_effort="high", + ) + + assert fallback.chat_calls[0]["model"] == "fallback-model" + assert fallback.chat_calls[0]["max_tokens"] == 1234 + assert fallback.chat_calls[0]["temperature"] == 0.4 + assert fallback.chat_calls[0]["reasoning_effort"] == "low" + class TestNoFallbackWhenEmptyList: @pytest.mark.asyncio @@ -277,7 +423,7 @@ class TestChatStreamFailover: fb = FallbackProvider( primary=primary, - fallback_models=["fallback-a"], + fallback_models=[_fallback("fallback-a")], provider_factory=factory, ) @@ -291,7 +437,7 @@ class TestGetDefaultModel: primary = _FakeProvider("primary") fb = FallbackProvider( primary=primary, - fallback_models=["a"], + fallback_models=[_fallback("a")], provider_factory=MagicMock(), ) assert fb.get_default_model() == "primary/model" @@ -305,7 +451,7 @@ class TestCircuitBreaker: factory = MagicMock(return_value=fallback) fb = FallbackProvider( primary=primary, - fallback_models=["fallback-a"], + fallback_models=[_fallback("fallback-a")], provider_factory=factory, ) @@ -329,7 +475,7 @@ class TestCircuitBreaker: factory = MagicMock(return_value=fallback) fb = FallbackProvider( primary=primary, - fallback_models=["fallback-a"], + fallback_models=[_fallback("fallback-a")], provider_factory=factory, ) @@ -357,7 +503,7 @@ class TestGenerationForwarded: primary.generation = GenerationSettings(temperature=0.5, max_tokens=1024) fb = FallbackProvider( primary=primary, - fallback_models=["a"], + fallback_models=[_fallback("a")], provider_factory=MagicMock(), ) assert fb.generation.temperature == 0.5 From 43db848db0f62305ade8353af380a2ffff296074 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Wed, 13 May 2026 14:11:08 +0000 Subject: [PATCH 3/4] Revert "feat(runner): support structured fallback models" This reverts commit 02b059a616dc6dc82ad15282102c7b27a5a34e40. --- docs/configuration.md | 62 +------- nanobot/config/schema.py | 15 +- nanobot/providers/factory.py | 61 +------- nanobot/providers/fallback_provider.py | 37 ++--- tests/agent/test_runner_fallback.py | 192 +++---------------------- 5 files changed, 42 insertions(+), 325 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index e208212cf..0123017d2 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -672,12 +672,6 @@ Existing configs do not need to change. If you do not set `modelPresets` or `age "maxTokens": 8192, "contextWindowTokens": 128000, "temperature": 0.1, - "fallbackModels": [ - { - "provider": "anthropic", - "model": "anthropic/claude-sonnet-4-6" - } - ], "modelPreset": null } }, @@ -688,17 +682,7 @@ Existing configs do not need to change. If you do not set `modelPresets` or `age "maxTokens": 4096, "contextWindowTokens": 128000, "temperature": 0.2, - "reasoningEffort": "low", - "fallbackModels": [ - { - "provider": "deepseek", - "model": "deepseek/deepseek-chat", - "maxTokens": 4096, - "contextWindowTokens": 64000, - "temperature": 0.1, - "reasoningEffort": null - } - ] + "reasoningEffort": "low" }, "deep": { "model": "anthropic/claude-opus-4-5", @@ -721,53 +705,9 @@ Existing configs do not need to change. If you do not set `modelPresets` or `age | `contextWindowTokens` | Context window size used by prompt building and consolidation decisions. | | `temperature` | Sampling temperature. | | `reasoningEffort` | Optional reasoning/thinking setting. Provider support varies. | -| `fallbackModels` | Optional ordered fallback models for this active configuration only. | `default` is reserved and always means the implicit preset built from `agents.defaults.*`; do not define `modelPresets.default`. Use `/model default` to switch back to `agents.defaults.*`. -### Model Fallbacks - -`fallbackModels` belongs to the currently active model configuration. If the active configuration is `agents.defaults`, only `agents.defaults.fallbackModels` is used. If the active configuration is `modelPresets.fast`, only `modelPresets.fast.fallbackModels` is used. nanobot does not inherit or merge fallbacks between defaults and presets. - -Each fallback entry must include at least `provider` and `model`. The other fields are optional; omitted values inherit from the active primary configuration for that request. - -```json -{ - "modelPresets": { - "fast": { - "model": "MiniMax-M2.7-highspeed", - "provider": "minimaxAnthropic", - "maxTokens": 4096, - "contextWindowTokens": 262144, - "temperature": 0.1, - "reasoningEffort": null, - "fallbackModels": [ - { - "provider": "deepseek", - "model": "deepseek-v4-pro", - "maxTokens": 4096, - "contextWindowTokens": 262144, - "temperature": 0.1, - "reasoningEffort": null - } - ] - }, - "deep": { - "model": "deepseek-v4-pro", - "provider": "deepseek", - "maxTokens": 4096, - "contextWindowTokens": 262144, - "temperature": 0.1, - "reasoningEffort": null - } - } -} -``` - -In this example, `/model fast` can fail over to DeepSeek, but `/model deep` has no fallback because the `deep` preset does not define `fallbackModels`. - -Failover only runs when the primary model returns an error before any answer text has been streamed. Fallback models are tried in order. If a fallback has a smaller `contextWindowTokens`, nanobot uses the smallest window in the active chain when building context so the fallback can receive the same prompt. - Set `agents.defaults.modelPreset` to start with a named preset: ```json diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index bdae26008..a112b932d 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -74,17 +74,6 @@ class DreamConfig(Base): return f"every {hours}h" -class ModelFallbackConfig(Base): - """A fallback model tied to one active model configuration.""" - - model: str - provider: str - max_tokens: int | None = None - context_window_tokens: int | None = None - temperature: float | None = None - reasoning_effort: str | None = None - - class ModelPresetConfig(Base): """A named set of model + generation parameters for quick switching.""" @@ -94,7 +83,7 @@ class ModelPresetConfig(Base): context_window_tokens: int = 65_536 temperature: float = 0.1 reasoning_effort: str | None = None - fallback_models: list[ModelFallbackConfig] = Field(default_factory=list) + fallback_models: list[str] = Field(default_factory=list) def to_generation_settings(self) -> Any: from nanobot.providers.base import GenerationSettings @@ -118,7 +107,6 @@ class AgentDefaults(Base): context_window_tokens: int = 65_536 context_block_limit: int | None = None temperature: float = 0.1 - fallback_models: list[ModelFallbackConfig] = Field(default_factory=list) max_tool_iterations: int = 200 max_concurrent_subagents: int = Field(default=1, ge=1) max_tool_result_chars: int = 16_000 @@ -309,7 +297,6 @@ class Config(BaseSettings): model=d.model, provider=d.provider, max_tokens=d.max_tokens, context_window_tokens=d.context_window_tokens, temperature=d.temperature, reasoning_effort=d.reasoning_effort, - fallback_models=d.fallback_models, ) def resolve_preset(self, name: str | None = None) -> ModelPresetConfig: diff --git a/nanobot/providers/factory.py b/nanobot/providers/factory.py index a3ae57daf..e4822b7f8 100644 --- a/nanobot/providers/factory.py +++ b/nanobot/providers/factory.py @@ -5,7 +5,7 @@ from __future__ import annotations from dataclasses import dataclass from pathlib import Path -from nanobot.config.schema import Config, ModelFallbackConfig, ModelPresetConfig +from nanobot.config.schema import Config, ModelPresetConfig from nanobot.providers.base import LLMProvider from nanobot.providers.fallback_provider import FallbackProvider from nanobot.providers.registry import find_by_name @@ -104,28 +104,6 @@ def _make_provider_core( return provider -def _fallback_preset(primary: ModelPresetConfig, fallback: ModelFallbackConfig) -> ModelPresetConfig: - """Build the effective provider/generation config for one fallback model.""" - return ModelPresetConfig( - model=fallback.model, - provider=fallback.provider, - max_tokens=fallback.max_tokens if fallback.max_tokens is not None else primary.max_tokens, - context_window_tokens=( - fallback.context_window_tokens - if fallback.context_window_tokens is not None - else primary.context_window_tokens - ), - temperature=( - fallback.temperature if fallback.temperature is not None else primary.temperature - ), - reasoning_effort=( - fallback.reasoning_effort - if fallback.reasoning_effort is not None - else primary.reasoning_effort - ), - ) - - def make_provider( config: Config, *, @@ -142,11 +120,12 @@ def make_provider( provider = _make_provider_core(config, preset_name=preset_name, preset=preset, model=model) if resolved.fallback_models: + fb_preset = resolved.model_copy(update={"provider": "auto", "fallback_models": []}) provider = FallbackProvider( primary=provider, fallback_models=resolved.fallback_models, - provider_factory=lambda fb: _make_provider_core( - config, preset_name=preset_name, preset=_fallback_preset(resolved, fb) + provider_factory=lambda m: _make_provider_core( + config, preset_name=preset_name, preset=fb_preset, model=m ), ) @@ -159,32 +138,9 @@ def provider_signature( preset_name: str | None = None, preset: ModelPresetConfig | None = None, ) -> tuple[object, ...]: - """Return the config fields that affect the active provider chain.""" + """Return the config fields that affect the primary LLM provider.""" resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset) p = config.get_provider(resolved.model, preset=resolved) - - def _fallback_signature(fallback: ModelFallbackConfig) -> tuple[object, ...]: - fallback_preset = _fallback_preset(resolved, fallback) - fp = config.get_provider(fallback.model, preset=fallback_preset) - return ( - fallback.model, - fallback.provider, - fallback_preset.max_tokens, - fallback_preset.temperature, - fallback_preset.reasoning_effort, - fallback_preset.context_window_tokens, - config.get_provider_name(fallback.model, preset=fallback_preset), - config.get_api_key(fallback.model, preset=fallback_preset), - config.get_api_base(fallback.model, preset=fallback_preset), - fp.extra_headers if fp else None, - fp.extra_body if fp else None, - getattr(fp, "region", None) if fp else None, - getattr(fp, "profile", None) if fp else None, - ) - - fallback_signatures = tuple( - _fallback_signature(fallback) for fallback in resolved.fallback_models - ) return ( resolved.model, resolved.provider, @@ -199,7 +155,6 @@ def provider_signature( resolved.temperature, resolved.reasoning_effort, resolved.context_window_tokens, - fallback_signatures, ) @@ -210,14 +165,10 @@ def build_provider_snapshot( preset: ModelPresetConfig | None = None, ) -> ProviderSnapshot: resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset) - fallback_windows = [ - _fallback_preset(resolved, fallback).context_window_tokens - for fallback in resolved.fallback_models - ] return ProviderSnapshot( provider=make_provider(config, preset=resolved), model=resolved.model, - context_window_tokens=min([resolved.context_window_tokens, *fallback_windows]), + context_window_tokens=resolved.context_window_tokens, signature=provider_signature(config, preset=resolved), ) diff --git a/nanobot/providers/fallback_provider.py b/nanobot/providers/fallback_provider.py index a62b619a0..c0b137890 100644 --- a/nanobot/providers/fallback_provider.py +++ b/nanobot/providers/fallback_provider.py @@ -24,7 +24,7 @@ class FallbackProvider(LLMProvider): provider on-the-fly. Key design: - - Failover attempts are request-scoped; primary circuit state persists. + - Failover is request-scoped (the wrapper itself is stateless between turns). - Skipped when content was already streamed to avoid duplicate output. - Recursive failover is prevented by the factory returning plain providers. - Primary provider is circuit-broken after repeated failures to avoid @@ -34,8 +34,8 @@ class FallbackProvider(LLMProvider): def __init__( self, primary: LLMProvider, - fallback_models: list[Any], - provider_factory: Callable[[Any], LLMProvider], + fallback_models: list[str], + provider_factory: Callable[[str], LLMProvider], ): self._primary = primary self._fallback_models = list(fallback_models) @@ -52,10 +52,6 @@ class FallbackProvider(LLMProvider): def generation(self, value): self._primary.generation = value - @property - def supports_progress_deltas(self) -> bool: - return bool(getattr(self._primary, "supports_progress_deltas", False)) - def get_default_model(self) -> str: return self._primary.get_default_model() @@ -126,8 +122,7 @@ class FallbackProvider(LLMProvider): last_response: LLMResponse | None = None primary_skipped = not self._primary_available() - for idx, fallback in enumerate(self._fallback_models): - fallback_model = fallback.model + for idx, fallback_model in enumerate(self._fallback_models): if has_streamed is not None and has_streamed[0]: break if idx == 0 and primary_skipped: @@ -143,35 +138,25 @@ class FallbackProvider(LLMProvider): else: logger.info( "Fallback '{}' also failed, trying next fallback '{}'", - self._fallback_models[idx - 1].model, fallback_model, + self._fallback_models[idx - 1], fallback_model, ) try: - fallback_provider = self._provider_factory(fallback) + fallback_provider = self._provider_factory(fallback_model) except Exception as exc: logger.warning( "Failed to create provider for fallback '{}': {}", fallback_model, exc ) continue - original_values = { - name: kwargs.get(name, LLMProvider._SENTINEL) - for name in ("model", "max_tokens", "temperature", "reasoning_effort") - } + original_model = kwargs.get("model") kwargs["model"] = fallback_model - if fallback.max_tokens is not None: - kwargs["max_tokens"] = fallback.max_tokens - if fallback.temperature is not None: - kwargs["temperature"] = fallback.temperature - if fallback.reasoning_effort is not None: - kwargs["reasoning_effort"] = fallback.reasoning_effort try: fallback_response = await call(fallback_provider, kwargs) finally: - for name, value in original_values.items(): - if value is LLMProvider._SENTINEL: - kwargs.pop(name, None) - else: - kwargs[name] = value + if original_model is not None: + kwargs["model"] = original_model + else: + kwargs.pop("model", None) if fallback_response.finish_reason != "error": logger.info( diff --git a/tests/agent/test_runner_fallback.py b/tests/agent/test_runner_fallback.py index e15a29848..273bd6d6d 100644 --- a/tests/agent/test_runner_fallback.py +++ b/tests/agent/test_runner_fallback.py @@ -7,7 +7,6 @@ from unittest.mock import MagicMock import pytest -from nanobot.config.schema import ModelFallbackConfig from nanobot.providers.base import LLMProvider, LLMResponse from nanobot.providers.fallback_provider import FallbackProvider @@ -25,25 +24,6 @@ def _error_response(content: str = "api error") -> LLMResponse: return _make_response(content, finish_reason="error", error_kind="server_error") -def _fallback( - model: str, - provider: str = "fallback", - *, - max_tokens: int | None = None, - context_window_tokens: int | None = None, - temperature: float | None = None, - reasoning_effort: str | None = None, -) -> ModelFallbackConfig: - return ModelFallbackConfig( - model=model, - provider=provider, - max_tokens=max_tokens, - context_window_tokens=context_window_tokens, - temperature=temperature, - reasoning_effort=reasoning_effort, - ) - - class _FakeProvider(LLMProvider): """Fake provider for testing.""" @@ -80,113 +60,17 @@ def test_fallback_models_default_empty() -> None: def test_fallback_models_accepts_list() -> None: from nanobot.config.schema import ModelPresetConfig - p = ModelPresetConfig( - model="test/primary", - fallback_models=[{"provider": "test", "model": "test/a"}], - ) - assert p.fallback_models == [_fallback("test/a", provider="test")] + p = ModelPresetConfig(model="test/primary", fallback_models=["test/a", "test/b"]) + assert p.fallback_models == ["test/a", "test/b"] def test_fallback_models_from_camel_case() -> None: from nanobot.config.schema import ModelPresetConfig p = ModelPresetConfig.model_validate({ "model": "test/primary", - "fallbackModels": [{"provider": "test", "model": "test/a"}], + "fallbackModels": ["test/a"], }) - assert p.fallback_models == [_fallback("test/a", provider="test")] - - -def test_provider_signature_tracks_fallback_models_and_provider_config() -> None: - from nanobot.config.schema import Config - from nanobot.providers.factory import provider_signature - - base = { - "modelPresets": { - "prod": { - "model": "openai/gpt-4.1", - "fallbackModels": [ - {"provider": "anthropic", "model": "anthropic/claude-sonnet-4-6"} - ], - } - }, - "providers": { - "openai": {"apiKey": "primary-key"}, - "anthropic": {"apiKey": "fallback-key"}, - }, - } - changed_fallback = { - **base, - "modelPresets": { - "prod": { - "model": "openai/gpt-4.1", - "fallbackModels": [{"provider": "deepseek", "model": "deepseek/deepseek-chat"}], - } - }, - "providers": { - **base["providers"], - "deepseek": {"apiKey": "deepseek-key"}, - }, - } - changed_key = { - **base, - "providers": { - "openai": {"apiKey": "primary-key"}, - "anthropic": {"apiKey": "new-fallback-key"}, - }, - } - - signature = provider_signature(Config.model_validate(base), preset_name="prod") - - assert signature != provider_signature(Config.model_validate(changed_fallback), preset_name="prod") - assert signature != provider_signature(Config.model_validate(changed_key), preset_name="prod") - - -def test_agent_defaults_can_define_fallback_models() -> None: - from nanobot.config.schema import Config - - config = Config.model_validate({ - "agents": { - "defaults": { - "model": "primary-model", - "provider": "custom", - "fallbackModels": [{"provider": "deepseek", "model": "deepseek-v4-pro"}], - } - } - }) - - assert config.resolve_preset().fallback_models == [ - _fallback("deepseek-v4-pro", provider="deepseek") - ] - - -def test_provider_snapshot_uses_smallest_fallback_context_window() -> None: - from nanobot.config.schema import Config - from nanobot.providers.factory import build_provider_snapshot - - config = Config.model_validate({ - "modelPresets": { - "prod": { - "model": "openai/gpt-4.1", - "provider": "openai", - "contextWindowTokens": 128000, - "fallbackModels": [ - { - "provider": "deepseek", - "model": "deepseek/deepseek-chat", - "contextWindowTokens": 64000, - } - ], - } - }, - "providers": { - "openai": {"apiKey": "primary-key"}, - "deepseek": {"apiKey": "fallback-key"}, - }, - }) - - snapshot = build_provider_snapshot(config, preset_name="prod") - - assert snapshot.context_window_tokens == 64000 + assert p.fallback_models == ["test/a"] # -- FallbackProvider tests -- @@ -199,7 +83,7 @@ class TestNoFallbackWhenPrimarySucceeds: factory = MagicMock() fb = FallbackProvider( primary=primary, - fallback_models=[_fallback("fallback-a")], + fallback_models=["fallback-a"], provider_factory=factory, ) @@ -218,14 +102,14 @@ class TestFallbackOnPrimaryError: fb = FallbackProvider( primary=primary, - fallback_models=[_fallback("fallback-a")], + fallback_models=["fallback-a"], provider_factory=factory, ) result = await fb.chat(messages=[{"role": "user", "content": "hi"}], model="primary-model") assert result.content == "fallback ok" assert result.finish_reason == "stop" - factory.assert_called_once_with(_fallback("fallback-a")) + factory.assert_called_once_with("fallback-a") assert primary.chat_calls[0]["model"] == "primary-model" assert fallback.chat_calls[0]["model"] == "fallback-a" @@ -237,7 +121,7 @@ class TestNoFallbackWhenContentStreamed: factory = MagicMock() fb = FallbackProvider( primary=primary, - fallback_models=[_fallback("fallback-a")], + fallback_models=["fallback-a"], provider_factory=factory, ) @@ -262,14 +146,14 @@ class TestFailoverOnTransientError: factory = MagicMock(return_value=fallback) fb = FallbackProvider( primary=primary, - fallback_models=[_fallback("fallback-a")], + fallback_models=["fallback-a"], provider_factory=factory, ) result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) assert result.content == "fallback ok" assert result.finish_reason == "stop" - factory.assert_called_once_with(_fallback("fallback-a")) + factory.assert_called_once_with("fallback-a") @pytest.mark.asyncio async def test_timeout(self) -> None: @@ -281,14 +165,14 @@ class TestFailoverOnTransientError: factory = MagicMock(return_value=fallback) fb = FallbackProvider( primary=primary, - fallback_models=[_fallback("fallback-a")], + fallback_models=["fallback-a"], provider_factory=factory, ) result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) assert result.content == "fallback ok" assert result.finish_reason == "stop" - factory.assert_called_once_with(_fallback("fallback-a")) + factory.assert_called_once_with("fallback-a") class TestFallbackTriesModelsInOrder: @@ -301,15 +185,15 @@ class TestFallbackTriesModelsInOrder: fb = FallbackProvider( primary=primary, - fallback_models=[_fallback("fallback-a"), _fallback("fallback-b")], + fallback_models=["fallback-a", "fallback-b"], provider_factory=factory, ) result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) assert result.content == "b ok" assert factory.call_count == 2 - factory.assert_any_call(_fallback("fallback-a")) - factory.assert_any_call(_fallback("fallback-b")) + factory.assert_any_call("fallback-a") + factory.assert_any_call("fallback-b") class TestAllFallbacksFail: @@ -321,7 +205,7 @@ class TestAllFallbacksFail: fb = FallbackProvider( primary=primary, - fallback_models=[_fallback("fallback-a")], + fallback_models=["fallback-a"], provider_factory=factory, ) @@ -339,7 +223,7 @@ class TestFactoryExceptionSkipsModel: fb = FallbackProvider( primary=primary, - fallback_models=[_fallback("fallback-a"), _fallback("fallback-b")], + fallback_models=["fallback-a", "fallback-b"], provider_factory=factory, ) @@ -358,43 +242,13 @@ class TestFallbackModelParameter: fb = FallbackProvider( primary=primary, - fallback_models=[_fallback("fallback-model")], + fallback_models=["fallback-model"], provider_factory=factory, ) await fb.chat(messages=[{"role": "user", "content": "hi"}], model="primary-model") assert fallback.chat_calls[0]["model"] == "fallback-model" - @pytest.mark.asyncio - async def test_overrides_generation_fields_when_configured(self) -> None: - primary = _FakeProvider("primary", _error_response()) - fallback = _FakeProvider("fallback", _make_response("ok")) - fb = FallbackProvider( - primary=primary, - fallback_models=[ - _fallback( - "fallback-model", - max_tokens=1234, - temperature=0.4, - reasoning_effort="low", - ) - ], - provider_factory=MagicMock(return_value=fallback), - ) - - await fb.chat( - messages=[{"role": "user", "content": "hi"}], - model="primary-model", - max_tokens=8192, - temperature=0.1, - reasoning_effort="high", - ) - - assert fallback.chat_calls[0]["model"] == "fallback-model" - assert fallback.chat_calls[0]["max_tokens"] == 1234 - assert fallback.chat_calls[0]["temperature"] == 0.4 - assert fallback.chat_calls[0]["reasoning_effort"] == "low" - class TestNoFallbackWhenEmptyList: @pytest.mark.asyncio @@ -423,7 +277,7 @@ class TestChatStreamFailover: fb = FallbackProvider( primary=primary, - fallback_models=[_fallback("fallback-a")], + fallback_models=["fallback-a"], provider_factory=factory, ) @@ -437,7 +291,7 @@ class TestGetDefaultModel: primary = _FakeProvider("primary") fb = FallbackProvider( primary=primary, - fallback_models=[_fallback("a")], + fallback_models=["a"], provider_factory=MagicMock(), ) assert fb.get_default_model() == "primary/model" @@ -451,7 +305,7 @@ class TestCircuitBreaker: factory = MagicMock(return_value=fallback) fb = FallbackProvider( primary=primary, - fallback_models=[_fallback("fallback-a")], + fallback_models=["fallback-a"], provider_factory=factory, ) @@ -475,7 +329,7 @@ class TestCircuitBreaker: factory = MagicMock(return_value=fallback) fb = FallbackProvider( primary=primary, - fallback_models=[_fallback("fallback-a")], + fallback_models=["fallback-a"], provider_factory=factory, ) @@ -503,7 +357,7 @@ class TestGenerationForwarded: primary.generation = GenerationSettings(temperature=0.5, max_tokens=1024) fb = FallbackProvider( primary=primary, - fallback_models=[_fallback("a")], + fallback_models=["a"], provider_factory=MagicMock(), ) assert fb.generation.temperature == 0.5 From 5efd67919bf4e65f6ff9231e830e5b76567b6371 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Wed, 13 May 2026 15:34:03 +0000 Subject: [PATCH 4/4] feat(runner): support fallback candidates Resolve fallbackModels as preset references or explicit inline provider configs so failover uses complete model settings without exposing fallback logic to the agent loop. Co-authored-by: Cursor --- docs/configuration.md | 37 ++- nanobot/config/schema.py | 19 +- nanobot/providers/factory.py | 71 +++++- nanobot/providers/fallback_provider.py | 113 ++++++++- tests/agent/test_runner_fallback.py | 321 ++++++++++++++++++++++--- 5 files changed, 502 insertions(+), 59 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 0123017d2..3f7f39709 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -672,7 +672,8 @@ Existing configs do not need to change. If you do not set `modelPresets` or `age "maxTokens": 8192, "contextWindowTokens": 128000, "temperature": 0.1, - "modelPreset": null + "modelPreset": "fast", + "fallbackModels": ["deep"] } }, "modelPresets": { @@ -708,6 +709,40 @@ Existing configs do not need to change. If you do not set `modelPresets` or `age `default` is reserved and always means the implicit preset built from `agents.defaults.*`; do not define `modelPresets.default`. Use `/model default` to switch back to `agents.defaults.*`. +### Model Fallbacks + +`agents.defaults.fallbackModels` defines an ordered failover chain for the active model configuration. The primary model is still selected by `agents.defaults.modelPreset` (or the implicit default config when no preset is active). + +Each fallback candidate can be either: + +- A preset name from `modelPresets`, such as `"deep"`. The preset's full model, provider, generation, and context-window config is used. +- An inline fallback object with at least `provider` and `model`. Optional `maxTokens`, `contextWindowTokens`, and `temperature` fields inherit from the active primary config when omitted. `reasoningEffort` does not inherit; omit it to leave reasoning off for that fallback, or set it explicitly for models that support reasoning. + +```json +{ + "agents": { + "defaults": { + "modelPreset": "fast", + "fallbackModels": [ + "deep", + { + "provider": "deepseek", + "model": "deepseek-v4-pro", + "maxTokens": 4096, + "contextWindowTokens": 262144 + } + ] + } + } +} +``` + +String entries are preset names, not raw model names. If you want to use a model that is not already a preset, use the inline object form. + +Failover only runs when the primary provider returns a retryable model/provider error before any answer text has been streamed. Typical fallback cases include timeouts, connection errors, 5xx server errors, 429 rate limits, overloads, and quota/balance exhaustion. It does not run for malformed requests, authentication/permission errors, content filtering/refusals, or context-length/message-format errors. + +If fallback candidates use smaller `contextWindowTokens` values, nanobot builds context using the smallest window in the active chain so every candidate can receive the same prompt. + Set `agents.defaults.modelPreset` to start with a named preset: ```json diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index a112b932d..c8556ec9f 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -74,6 +74,20 @@ class DreamConfig(Base): return f"every {hours}h" +class InlineFallbackConfig(Base): + """One inline fallback model configuration.""" + + model: str + provider: str + max_tokens: int | None = None + context_window_tokens: int | None = None + temperature: float | None = None + reasoning_effort: str | None = None + + +FallbackCandidate = str | InlineFallbackConfig + + class ModelPresetConfig(Base): """A named set of model + generation parameters for quick switching.""" @@ -83,7 +97,6 @@ class ModelPresetConfig(Base): context_window_tokens: int = 65_536 temperature: float = 0.1 reasoning_effort: str | None = None - fallback_models: list[str] = Field(default_factory=list) def to_generation_settings(self) -> Any: from nanobot.providers.base import GenerationSettings @@ -107,6 +120,7 @@ class AgentDefaults(Base): context_window_tokens: int = 65_536 context_block_limit: int | None = None temperature: float = 0.1 + fallback_models: list[FallbackCandidate] = Field(default_factory=list) max_tool_iterations: int = 200 max_concurrent_subagents: int = Field(default=1, ge=1) max_tool_result_chars: int = 16_000 @@ -288,6 +302,9 @@ class Config(BaseSettings): name = self.agents.defaults.model_preset if name and name != "default" and name not in self.model_presets: raise ValueError(f"model_preset {name!r} not found in model_presets") + for fallback in self.agents.defaults.fallback_models: + if isinstance(fallback, str) and fallback not in self.model_presets: + raise ValueError(f"fallback_models entry {fallback!r} not found in model_presets") return self def resolve_default_preset(self) -> ModelPresetConfig: diff --git a/nanobot/providers/factory.py b/nanobot/providers/factory.py index e4822b7f8..288611392 100644 --- a/nanobot/providers/factory.py +++ b/nanobot/providers/factory.py @@ -5,7 +5,7 @@ from __future__ import annotations from dataclasses import dataclass from pathlib import Path -from nanobot.config.schema import Config, ModelPresetConfig +from nanobot.config.schema import Config, InlineFallbackConfig, ModelPresetConfig from nanobot.providers.base import LLMProvider from nanobot.providers.fallback_provider import FallbackProvider from nanobot.providers.registry import find_by_name @@ -104,6 +104,36 @@ def _make_provider_core( return provider +def _inline_fallback_preset( + primary: ModelPresetConfig, + fallback: InlineFallbackConfig, +) -> ModelPresetConfig: + return ModelPresetConfig( + model=fallback.model, + provider=fallback.provider, + max_tokens=fallback.max_tokens if fallback.max_tokens is not None else primary.max_tokens, + context_window_tokens=( + fallback.context_window_tokens + if fallback.context_window_tokens is not None + else primary.context_window_tokens + ), + temperature=( + fallback.temperature if fallback.temperature is not None else primary.temperature + ), + reasoning_effort=fallback.reasoning_effort, + ) + + +def _resolve_fallback_presets(config: Config, primary: ModelPresetConfig) -> list[ModelPresetConfig]: + presets: list[ModelPresetConfig] = [] + for fallback in config.agents.defaults.fallback_models: + if isinstance(fallback, str): + presets.append(config.model_presets[fallback]) + else: + presets.append(_inline_fallback_preset(primary, fallback)) + return presets + + def make_provider( config: Config, *, @@ -118,14 +148,14 @@ def make_provider( """ resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset) provider = _make_provider_core(config, preset_name=preset_name, preset=preset, model=model) + fallback_presets = _resolve_fallback_presets(config, resolved) - if resolved.fallback_models: - fb_preset = resolved.model_copy(update={"provider": "auto", "fallback_models": []}) + if fallback_presets: provider = FallbackProvider( primary=provider, - fallback_models=resolved.fallback_models, - provider_factory=lambda m: _make_provider_core( - config, preset_name=preset_name, preset=fb_preset, model=m + fallback_presets=fallback_presets, + provider_factory=lambda fb: _make_provider_core( + config, preset_name=preset_name, preset=fb ), ) @@ -138,9 +168,29 @@ def provider_signature( preset_name: str | None = None, preset: ModelPresetConfig | None = None, ) -> tuple[object, ...]: - """Return the config fields that affect the primary LLM provider.""" + """Return the config fields that affect the active provider chain.""" resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset) p = config.get_provider(resolved.model, preset=resolved) + fallback_presets = _resolve_fallback_presets(config, resolved) + + def _fallback_signature(fallback: ModelPresetConfig) -> tuple[object, ...]: + fp = config.get_provider(fallback.model, preset=fallback) + return ( + fallback.model, + fallback.provider, + config.get_provider_name(fallback.model, preset=fallback), + config.get_api_key(fallback.model, preset=fallback), + config.get_api_base(fallback.model, preset=fallback), + fp.extra_headers if fp else None, + fp.extra_body if fp else None, + getattr(fp, "region", None) if fp else None, + getattr(fp, "profile", None) if fp else None, + fallback.max_tokens, + fallback.temperature, + fallback.reasoning_effort, + fallback.context_window_tokens, + ) + return ( resolved.model, resolved.provider, @@ -155,6 +205,7 @@ def provider_signature( resolved.temperature, resolved.reasoning_effort, resolved.context_window_tokens, + tuple(_fallback_signature(fallback) for fallback in fallback_presets), ) @@ -165,10 +216,14 @@ def build_provider_snapshot( preset: ModelPresetConfig | None = None, ) -> ProviderSnapshot: resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset) + fallback_windows = [ + fallback.context_window_tokens + for fallback in _resolve_fallback_presets(config, resolved) + ] return ProviderSnapshot( provider=make_provider(config, preset=resolved), model=resolved.model, - context_window_tokens=resolved.context_window_tokens, + context_window_tokens=min([resolved.context_window_tokens, *fallback_windows]), signature=provider_signature(config, preset=resolved), ) diff --git a/nanobot/providers/fallback_provider.py b/nanobot/providers/fallback_provider.py index c0b137890..c082c2361 100644 --- a/nanobot/providers/fallback_provider.py +++ b/nanobot/providers/fallback_provider.py @@ -13,6 +13,46 @@ from nanobot.providers.base import LLMProvider, LLMResponse # Circuit breaker tuned to match OpenAICompatProvider's Responses API breaker. _PRIMARY_FAILURE_THRESHOLD = 3 _PRIMARY_COOLDOWN_S = 60 +_MISSING = object() +_FALLBACK_ERROR_KINDS = frozenset({ + "timeout", + "connection", + "server_error", + "rate_limit", + "overloaded", +}) +_NON_FALLBACK_ERROR_KINDS = frozenset({ + "authentication", + "auth", + "permission", + "content_filter", + "refusal", + "context_length", + "invalid_request", +}) +_FALLBACK_ERROR_TOKENS = ( + "rate_limit", + "rate limit", + "too_many_requests", + "too many requests", + "overloaded", + "server_error", + "server error", + "temporarily unavailable", + "timeout", + "timed out", + "connection", + "insufficient_quota", + "insufficient quota", + "quota_exceeded", + "quota exceeded", + "quota_exhausted", + "quota exhausted", + "billing_hard_limit", + "insufficient_balance", + "balance", + "out of credits", +) class FallbackProvider(LLMProvider): @@ -34,13 +74,13 @@ class FallbackProvider(LLMProvider): def __init__( self, primary: LLMProvider, - fallback_models: list[str], - provider_factory: Callable[[str], LLMProvider], + fallback_presets: list[Any], + provider_factory: Callable[[Any], LLMProvider], ): self._primary = primary - self._fallback_models = list(fallback_models) + self._fallback_presets = list(fallback_presets) self._provider_factory = provider_factory - self._has_fallbacks = bool(fallback_models) + self._has_fallbacks = bool(fallback_presets) self._primary_failures = 0 self._primary_tripped_at: float | None = None @@ -55,6 +95,10 @@ class FallbackProvider(LLMProvider): def get_default_model(self) -> str: return self._primary.get_default_model() + @property + def supports_progress_deltas(self) -> bool: + return bool(getattr(self._primary, "supports_progress_deltas", False)) + def _primary_available(self) -> bool: """Return True if the primary provider is not currently tripped.""" if self._primary_tripped_at is None: @@ -110,6 +154,14 @@ class FallbackProvider(LLMProvider): ) return response + if not self._should_fallback(response): + logger.warning( + "Primary model '{}' returned non-fallbackable error: {}", + primary_model, + (response.content or "")[:120], + ) + return response + self._primary_failures += 1 if self._primary_failures >= _PRIMARY_FAILURE_THRESHOLD: self._primary_tripped_at = time.monotonic() @@ -122,7 +174,8 @@ class FallbackProvider(LLMProvider): last_response: LLMResponse | None = None primary_skipped = not self._primary_available() - for idx, fallback_model in enumerate(self._fallback_models): + for idx, fallback in enumerate(self._fallback_presets): + fallback_model = fallback.model if has_streamed is not None and has_streamed[0]: break if idx == 0 and primary_skipped: @@ -138,25 +191,35 @@ class FallbackProvider(LLMProvider): else: logger.info( "Fallback '{}' also failed, trying next fallback '{}'", - self._fallback_models[idx - 1], fallback_model, + self._fallback_presets[idx - 1].model, fallback_model, ) try: - fallback_provider = self._provider_factory(fallback_model) + fallback_provider = self._provider_factory(fallback) except Exception as exc: logger.warning( "Failed to create provider for fallback '{}': {}", fallback_model, exc ) continue - original_model = kwargs.get("model") + original_values = { + name: kwargs.get(name, _MISSING) + for name in ("model", "max_tokens", "temperature", "reasoning_effort") + } kwargs["model"] = fallback_model + kwargs["max_tokens"] = fallback.max_tokens + kwargs["temperature"] = fallback.temperature + if fallback.reasoning_effort is None: + kwargs.pop("reasoning_effort", None) + else: + kwargs["reasoning_effort"] = fallback.reasoning_effort try: fallback_response = await call(fallback_provider, kwargs) finally: - if original_model is not None: - kwargs["model"] = original_model - else: - kwargs.pop("model", None) + for name, value in original_values.items(): + if value is _MISSING: + kwargs.pop(name, None) + else: + kwargs[name] = value if fallback_response.finish_reason != "error": logger.info( @@ -174,7 +237,7 @@ class FallbackProvider(LLMProvider): logger.warning( "All {} fallback model(s) failed", - len(self._fallback_models), + len(self._fallback_presets), ) # Return the last error response we saw (primary or last fallback). if last_response is not None: @@ -184,3 +247,27 @@ class FallbackProvider(LLMProvider): content=f"Primary model '{primary_model}' circuit open and no fallbacks available", finish_reason="error", ) + + @staticmethod + def _should_fallback(response: LLMResponse) -> bool: + if response.error_should_retry is False: + return False + status = response.error_status_code + kind = (response.error_kind or "").lower() + error_type = (response.error_type or "").lower() + code = (response.error_code or "").lower() + text = (response.content or "").lower() + + if status in {400, 401, 403, 404, 422}: + return False + if kind in _NON_FALLBACK_ERROR_KINDS: + return False + if any(token in value for value in (kind, error_type, code) for token in _NON_FALLBACK_ERROR_KINDS): + return False + if response.error_should_retry is True: + return True + if status is not None and (status in {408, 409, 429} or 500 <= status <= 599): + return True + if kind in _FALLBACK_ERROR_KINDS: + return True + return any(token in value for value in (kind, error_type, code, text) for token in _FALLBACK_ERROR_TOKENS) diff --git a/tests/agent/test_runner_fallback.py b/tests/agent/test_runner_fallback.py index 273bd6d6d..0e36fb02a 100644 --- a/tests/agent/test_runner_fallback.py +++ b/tests/agent/test_runner_fallback.py @@ -3,10 +3,11 @@ from __future__ import annotations from typing import Any -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest +from nanobot.config.schema import ModelPresetConfig from nanobot.providers.base import LLMProvider, LLMResponse from nanobot.providers.fallback_provider import FallbackProvider @@ -16,14 +17,45 @@ def _make_response( finish_reason: str = "stop", *, error_kind: str | None = None, + error_status_code: int | None = None, + error_type: str | None = None, + error_code: str | None = None, + error_should_retry: bool | None = None, ) -> LLMResponse: - return LLMResponse(content=content, finish_reason=finish_reason, error_kind=error_kind) + return LLMResponse( + content=content, + finish_reason=finish_reason, + error_kind=error_kind, + error_status_code=error_status_code, + error_type=error_type, + error_code=error_code, + error_should_retry=error_should_retry, + ) def _error_response(content: str = "api error") -> LLMResponse: return _make_response(content, finish_reason="error", error_kind="server_error") +def _fallback( + model: str, + provider: str = "custom", + *, + max_tokens: int = 8192, + context_window_tokens: int = 65_536, + temperature: float = 0.1, + reasoning_effort: str | None = None, +) -> ModelPresetConfig: + return ModelPresetConfig( + model=model, + provider=provider, + max_tokens=max_tokens, + context_window_tokens=context_window_tokens, + temperature=temperature, + reasoning_effort=reasoning_effort, + ) + + class _FakeProvider(LLMProvider): """Fake provider for testing.""" @@ -53,24 +85,163 @@ class _FakeProvider(LLMProvider): def test_fallback_models_default_empty() -> None: - from nanobot.config.schema import ModelPresetConfig - p = ModelPresetConfig(model="test/model") - assert p.fallback_models == [] + from nanobot.config.schema import AgentDefaults + + defaults = AgentDefaults() + + assert defaults.fallback_models == [] -def test_fallback_models_accepts_list() -> None: - from nanobot.config.schema import ModelPresetConfig - p = ModelPresetConfig(model="test/primary", fallback_models=["test/a", "test/b"]) - assert p.fallback_models == ["test/a", "test/b"] +def test_fallback_models_accept_preset_refs_and_inline_configs() -> None: + from nanobot.config.schema import Config, InlineFallbackConfig - -def test_fallback_models_from_camel_case() -> None: - from nanobot.config.schema import ModelPresetConfig - p = ModelPresetConfig.model_validate({ - "model": "test/primary", - "fallbackModels": ["test/a"], + config = Config.model_validate({ + "agents": { + "defaults": { + "fallbackModels": [ + "deep", + { + "provider": "openai", + "model": "gpt-4.1", + "maxTokens": 4096, + }, + ] + } + }, + "modelPresets": { + "deep": {"provider": "anthropic", "model": "claude-opus-4-7"} + }, }) - assert p.fallback_models == ["test/a"] + + assert config.agents.defaults.fallback_models[0] == "deep" + assert config.agents.defaults.fallback_models[1] == InlineFallbackConfig( + provider="openai", + model="gpt-4.1", + max_tokens=4096, + ) + + +def test_fallback_model_preset_ref_must_exist() -> None: + from nanobot.config.schema import Config + + with pytest.raises(ValueError, match="fallback_models.*not found"): + Config.model_validate({ + "agents": {"defaults": {"fallbackModels": ["missing"]}}, + "modelPresets": {}, + }) + + +def test_provider_signature_tracks_fallback_presets_and_provider_config() -> None: + from nanobot.config.schema import Config + from nanobot.providers.factory import provider_signature + + base = { + "agents": { + "defaults": { + "modelPreset": "fast", + "fallbackModels": ["deep"], + } + }, + "modelPresets": { + "fast": {"model": "openai/gpt-4.1", "provider": "openai"}, + "deep": {"model": "anthropic/claude-sonnet-4-6", "provider": "anthropic"}, + }, + "providers": { + "openai": {"apiKey": "primary-key"}, + "anthropic": {"apiKey": "fallback-key"}, + }, + } + changed_fallback = { + **base, + "agents": {"defaults": {"modelPreset": "fast", "fallbackModels": ["backup"]}}, + "modelPresets": { + **base["modelPresets"], + "backup": {"model": "deepseek/deepseek-chat", "provider": "deepseek"}, + }, + "providers": { + **base["providers"], + "deepseek": {"apiKey": "deepseek-key"}, + }, + } + changed_key = { + **base, + "providers": { + "openai": {"apiKey": "primary-key"}, + "anthropic": {"apiKey": "new-fallback-key"}, + }, + } + + signature = provider_signature(Config.model_validate(base)) + + assert signature != provider_signature(Config.model_validate(changed_fallback)) + assert signature != provider_signature(Config.model_validate(changed_key)) + + +def test_provider_snapshot_uses_smallest_fallback_context_window() -> None: + from nanobot.config.schema import Config + from nanobot.providers.factory import build_provider_snapshot + + config = Config.model_validate({ + "agents": { + "defaults": { + "modelPreset": "fast", + "fallbackModels": ["deep"], + } + }, + "modelPresets": { + "fast": { + "model": "openai/gpt-4.1", + "provider": "openai", + "contextWindowTokens": 128000, + }, + "deep": { + "model": "deepseek/deepseek-chat", + "provider": "deepseek", + "contextWindowTokens": 64000, + }, + }, + "providers": { + "openai": {"apiKey": "primary-key"}, + "deepseek": {"apiKey": "fallback-key"}, + }, + }) + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + snapshot = build_provider_snapshot(config) + + assert snapshot.context_window_tokens == 64000 + + +def test_inline_fallback_reasoning_effort_does_not_inherit_primary() -> None: + from nanobot.config.schema import Config + from nanobot.providers.factory import provider_signature + + config = Config.model_validate({ + "agents": { + "defaults": { + "modelPreset": "fast", + "fallbackModels": [ + {"provider": "openai", "model": "gpt-4.1"} + ], + } + }, + "modelPresets": { + "fast": { + "model": "anthropic/claude-opus-4-5", + "provider": "anthropic", + "reasoningEffort": "high", + } + }, + "providers": { + "anthropic": {"apiKey": "primary-key"}, + "openai": {"apiKey": "fallback-key"}, + }, + }) + + signature = provider_signature(config) + fallback_signatures = signature[-1] + + assert fallback_signatures[0][11] is None # -- FallbackProvider tests -- @@ -83,7 +254,7 @@ class TestNoFallbackWhenPrimarySucceeds: factory = MagicMock() fb = FallbackProvider( primary=primary, - fallback_models=["fallback-a"], + fallback_presets=[_fallback("fallback-a")], provider_factory=factory, ) @@ -102,14 +273,14 @@ class TestFallbackOnPrimaryError: fb = FallbackProvider( primary=primary, - fallback_models=["fallback-a"], + fallback_presets=[_fallback("fallback-a")], provider_factory=factory, ) result = await fb.chat(messages=[{"role": "user", "content": "hi"}], model="primary-model") assert result.content == "fallback ok" assert result.finish_reason == "stop" - factory.assert_called_once_with("fallback-a") + factory.assert_called_once_with(_fallback("fallback-a")) assert primary.chat_calls[0]["model"] == "primary-model" assert fallback.chat_calls[0]["model"] == "fallback-a" @@ -121,7 +292,7 @@ class TestNoFallbackWhenContentStreamed: factory = MagicMock() fb = FallbackProvider( primary=primary, - fallback_models=["fallback-a"], + fallback_presets=[_fallback("fallback-a")], provider_factory=factory, ) @@ -146,14 +317,62 @@ class TestFailoverOnTransientError: factory = MagicMock(return_value=fallback) fb = FallbackProvider( primary=primary, - fallback_models=["fallback-a"], + fallback_presets=[_fallback("fallback-a")], provider_factory=factory, ) result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) assert result.content == "fallback ok" assert result.finish_reason == "stop" - factory.assert_called_once_with("fallback-a") + factory.assert_called_once_with(_fallback("fallback-a")) + + +class TestNoFallbackOnNonRetryableError: + @pytest.mark.asyncio + async def test_bad_request(self) -> None: + primary = _FakeProvider( + "primary", + _make_response( + "invalid request", + finish_reason="error", + error_status_code=400, + error_kind="invalid_request", + ), + ) + factory = MagicMock() + fb = FallbackProvider( + primary=primary, + fallback_presets=[_fallback("fallback-a")], + provider_factory=factory, + ) + + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + + assert result.finish_reason == "error" + factory.assert_not_called() + + @pytest.mark.asyncio + async def test_auth_error(self) -> None: + primary = _FakeProvider( + "primary", + _make_response( + "unauthorized", + finish_reason="error", + error_status_code=401, + error_kind="authentication", + ), + ) + factory = MagicMock() + fb = FallbackProvider( + primary=primary, + fallback_presets=[_fallback("fallback-a")], + provider_factory=factory, + ) + + result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) + + assert result.finish_reason == "error" + factory.assert_not_called() @pytest.mark.asyncio async def test_timeout(self) -> None: @@ -165,14 +384,14 @@ class TestFailoverOnTransientError: factory = MagicMock(return_value=fallback) fb = FallbackProvider( primary=primary, - fallback_models=["fallback-a"], + fallback_presets=[_fallback("fallback-a")], provider_factory=factory, ) result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) assert result.content == "fallback ok" assert result.finish_reason == "stop" - factory.assert_called_once_with("fallback-a") + factory.assert_called_once_with(_fallback("fallback-a")) class TestFallbackTriesModelsInOrder: @@ -185,15 +404,15 @@ class TestFallbackTriesModelsInOrder: fb = FallbackProvider( primary=primary, - fallback_models=["fallback-a", "fallback-b"], + fallback_presets=[_fallback("fallback-a"), _fallback("fallback-b")], provider_factory=factory, ) result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) assert result.content == "b ok" assert factory.call_count == 2 - factory.assert_any_call("fallback-a") - factory.assert_any_call("fallback-b") + factory.assert_any_call(_fallback("fallback-a")) + factory.assert_any_call(_fallback("fallback-b")) class TestAllFallbacksFail: @@ -205,7 +424,7 @@ class TestAllFallbacksFail: fb = FallbackProvider( primary=primary, - fallback_models=["fallback-a"], + fallback_presets=[_fallback("fallback-a")], provider_factory=factory, ) @@ -223,7 +442,7 @@ class TestFactoryExceptionSkipsModel: fb = FallbackProvider( primary=primary, - fallback_models=["fallback-a", "fallback-b"], + fallback_presets=[_fallback("fallback-a"), _fallback("fallback-b")], provider_factory=factory, ) @@ -242,13 +461,43 @@ class TestFallbackModelParameter: fb = FallbackProvider( primary=primary, - fallback_models=["fallback-model"], + fallback_presets=[_fallback("fallback-model")], provider_factory=factory, ) await fb.chat(messages=[{"role": "user", "content": "hi"}], model="primary-model") assert fallback.chat_calls[0]["model"] == "fallback-model" + @pytest.mark.asyncio + async def test_uses_fallback_generation_fields(self) -> None: + primary = _FakeProvider("primary", _error_response()) + fallback = _FakeProvider("fallback", _make_response("ok")) + fb = FallbackProvider( + primary=primary, + fallback_presets=[ + _fallback( + "fallback-model", + max_tokens=1234, + temperature=0.4, + reasoning_effort=None, + ) + ], + provider_factory=MagicMock(return_value=fallback), + ) + + await fb.chat( + messages=[{"role": "user", "content": "hi"}], + model="primary-model", + max_tokens=8192, + temperature=0.1, + reasoning_effort="high", + ) + + assert fallback.chat_calls[0]["model"] == "fallback-model" + assert fallback.chat_calls[0]["max_tokens"] == 1234 + assert fallback.chat_calls[0]["temperature"] == 0.4 + assert "reasoning_effort" not in fallback.chat_calls[0] + class TestNoFallbackWhenEmptyList: @pytest.mark.asyncio @@ -258,7 +507,7 @@ class TestNoFallbackWhenEmptyList: fb = FallbackProvider( primary=primary, - fallback_models=[], + fallback_presets=[], provider_factory=factory, ) @@ -277,7 +526,7 @@ class TestChatStreamFailover: fb = FallbackProvider( primary=primary, - fallback_models=["fallback-a"], + fallback_presets=[_fallback("fallback-a")], provider_factory=factory, ) @@ -291,7 +540,7 @@ class TestGetDefaultModel: primary = _FakeProvider("primary") fb = FallbackProvider( primary=primary, - fallback_models=["a"], + fallback_presets=[_fallback("a")], provider_factory=MagicMock(), ) assert fb.get_default_model() == "primary/model" @@ -305,7 +554,7 @@ class TestCircuitBreaker: factory = MagicMock(return_value=fallback) fb = FallbackProvider( primary=primary, - fallback_models=["fallback-a"], + fallback_presets=[_fallback("fallback-a")], provider_factory=factory, ) @@ -329,7 +578,7 @@ class TestCircuitBreaker: factory = MagicMock(return_value=fallback) fb = FallbackProvider( primary=primary, - fallback_models=["fallback-a"], + fallback_presets=[_fallback("fallback-a")], provider_factory=factory, ) @@ -357,7 +606,7 @@ class TestGenerationForwarded: primary.generation = GenerationSettings(temperature=0.5, max_tokens=1024) fb = FallbackProvider( primary=primary, - fallback_models=["a"], + fallback_presets=[_fallback("a")], provider_factory=MagicMock(), ) assert fb.generation.temperature == 0.5