diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 4a25b5ca3..532b4da0d 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -16,6 +16,7 @@ from nanobot.agent.hook import AgentHook, AgentHookContext from nanobot.agent.tools.ask import AskUserInterrupt from nanobot.agent.tools.registry import ToolRegistry from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest +from nanobot.providers.failover import ModelCandidate, ModelRouter from nanobot.utils.helpers import ( build_assistant_message, estimate_message_tokens, @@ -621,24 +622,14 @@ class AgentRunner: messages, tools=spec.tools.get_definitions(), ) - response = await self._call_provider(self.provider, kwargs, hook, context, spec, timeout_s) - - if response.finish_reason == "error" and spec.fallback_models: - for fb_model in spec.fallback_models: - logger.warning( - "Primary model {} failed, trying fallback: {}", - spec.model, - fb_model, - ) - fb_provider, resolved_model = self._resolve_fallback_provider(fb_model) - fb_kwargs = dict(kwargs, model=resolved_model) - response = await self._call_provider( - fb_provider, fb_kwargs, hook, context, spec, timeout_s, - ) - if response.finish_reason != "error": - break - - return response + provider: LLMProvider = self.provider + request_timeout = timeout_s + if spec.fallback_models: + provider = self._build_model_router(spec, timeout_s) + # ModelRouter applies the same timeout per candidate, preserving + # fallback on primary timeouts instead of timing out the whole chain. + request_timeout = None + return await self._call_provider(provider, kwargs, hook, context, spec, request_timeout) async def _call_provider( self, @@ -716,6 +707,25 @@ class AgentRunner: return provider, provider.get_default_model() return self.provider, model + def _build_model_router( + self, + spec: AgentRunSpec, + timeout_s: float | None, + ) -> ModelRouter: + candidates = [ + ModelCandidate( + label=model, + resolver=lambda m=model: self._resolve_fallback_provider(m), + ) + for model in spec.fallback_models + ] + return ModelRouter( + primary_provider=self.provider, + primary_model=spec.model, + fallback_candidates=candidates, + per_candidate_timeout_s=timeout_s, + ) + async def _request_finalization_retry( self, spec: AgentRunSpec, diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index a3f5e24e8..34abb8fc5 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -527,8 +527,7 @@ def _make_cli_provider_factory(config: Config): def factory(model_or_preset: str): preset = presets.get(model_or_preset) actual_model = preset.model if preset else model_or_preset - provider_name = config.get_provider_name(actual_model) - key = provider_name or actual_model + key = actual_model if key not in cache: cache[key] = _make_provider_for_model(config, actual_model, preset=preset) return cache[key] diff --git a/nanobot/nanobot.py b/nanobot/nanobot.py index e976eae29..1eaef6636 100644 --- a/nanobot/nanobot.py +++ b/nanobot/nanobot.py @@ -224,8 +224,7 @@ def _make_provider_factory(config: Any): def factory(model_or_preset: str): preset = presets.get(model_or_preset) actual_model = preset.model if preset else model_or_preset - provider_name = config.get_provider_name(actual_model) - key = provider_name or actual_model + key = actual_model if key not in cache: cache[key] = _make_provider_for_model(config, actual_model, preset=preset) return cache[key] diff --git a/nanobot/providers/failover.py b/nanobot/providers/failover.py new file mode 100644 index 000000000..660652fd4 --- /dev/null +++ b/nanobot/providers/failover.py @@ -0,0 +1,276 @@ +"""Provider-like failover router used after provider-local retry is exhausted.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import Any + +from loguru import logger + +from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse + + +@dataclass(frozen=True) +class ModelCandidate: + """A lazily resolved model/provider candidate.""" + + label: str + resolver: Callable[[], tuple[LLMProvider, str]] + + +class ModelRouter(LLMProvider): + """Try fallback model candidates for eligible transient final errors.""" + + supports_progress_deltas = False + + _BLOCKED_STATUS_CODES = frozenset({400, 401, 403, 404, 422}) + _QUOTA_MARKERS = ( + "insufficient_quota", + "insufficient quota", + "quota exceeded", + "quota_exceeded", + "quota exhausted", + "quota_exhausted", + "billing hard limit", + "billing_hard_limit_reached", + "billing not active", + "insufficient balance", + "insufficient_balance", + "credit balance too low", + "payment required", + "out of credits", + "out of quota", + "exceeded your current quota", + ) + _NON_FAILOVER_MARKERS = ( + "context length", + "context_length", + "maximum context", + "max context", + "token budget", + "too many tokens", + "schema", + "invalid request", + "invalid_request", + "invalid parameter", + "invalid_parameter", + "unsupported", + "unauthorized", + "authentication", + "permission", + "forbidden", + "refusal", + "content policy", + "content_filter", + "policy violation", + "safety", + ) + + def __init__( + self, + *, + primary_provider: LLMProvider, + primary_model: str, + fallback_candidates: list[ModelCandidate], + per_candidate_timeout_s: float | None = None, + ) -> None: + super().__init__( + api_key=getattr(primary_provider, "api_key", None), + api_base=getattr(primary_provider, "api_base", None), + ) + self.primary_provider = primary_provider + self.primary_model = primary_model + self.fallback_candidates = list(fallback_candidates) + self.per_candidate_timeout_s = per_candidate_timeout_s + self.generation = getattr(primary_provider, "generation", GenerationSettings()) + + def get_default_model(self) -> str: + return self.primary_model + + async def chat(self, **kwargs: Any) -> LLMResponse: + return await self.primary_provider.chat(**kwargs) + + async def chat_stream(self, **kwargs: Any) -> LLMResponse: + return await self.primary_provider.chat_stream(**kwargs) + + @classmethod + def _is_quota_error(cls, response: LLMResponse) -> bool: + tokens = { + cls._normalize_error_token(response.error_type), + cls._normalize_error_token(response.error_code), + } + if any(token in cls._NON_RETRYABLE_429_ERROR_TOKENS for token in tokens if token): + return True + content = (response.content or "").lower() + return any(marker in content for marker in cls._QUOTA_MARKERS) + + @classmethod + def _is_blocked_error(cls, response: LLMResponse) -> bool: + status = response.error_status_code + if status is not None and int(status) in cls._BLOCKED_STATUS_CODES: + return True + if response.finish_reason in {"refusal", "content_filter"}: + return True + content = (response.content or "").lower() + return any(marker in content for marker in cls._NON_FAILOVER_MARKERS) + + @classmethod + def _should_failover(cls, response: LLMResponse) -> bool: + if response.finish_reason != "error": + return False + if cls._is_blocked_error(response): + return False + if cls._is_quota_error(response): + return False + return cls._is_transient_response(response) + + async def _with_timeout(self, coro: Awaitable[LLMResponse]) -> LLMResponse: + timeout_s = self.per_candidate_timeout_s + if timeout_s is None: + return await coro + try: + return await asyncio.wait_for(coro, timeout=timeout_s) + except asyncio.TimeoutError: + return LLMResponse( + content=f"Error calling LLM: timed out after {timeout_s:g}s", + finish_reason="error", + error_kind="timeout", + ) + + @staticmethod + def _resolver_error(candidate: ModelCandidate, exc: Exception) -> LLMResponse: + logger.warning("Failed to resolve fallback model {}: {}", candidate.label, exc) + return LLMResponse( + content=f"Error configuring fallback model {candidate.label}: {exc}", + finish_reason="error", + error_kind="configuration", + error_should_retry=False, + ) + + def _candidate_chain(self) -> list[ModelCandidate]: + return [ + ModelCandidate( + label=self.primary_model, + resolver=lambda: (self.primary_provider, self.primary_model), + ), + *self.fallback_candidates, + ] + + async def _route( + self, + call: Callable[[LLMProvider, str, Callable[[str], Awaitable[None]] | None], Awaitable[LLMResponse]], + *, + on_content_delta: Callable[[str], Awaitable[None]] | None = None, + ) -> LLMResponse: + last_response: LLMResponse | None = None + chain = self._candidate_chain() + for index, candidate in enumerate(chain): + try: + provider, model = candidate.resolver() + except asyncio.CancelledError: + raise + except Exception as exc: + response = self._resolver_error(candidate, exc) + else: + response = await self._with_timeout(call(provider, model, on_content_delta)) + + if response.finish_reason != "error": + if index > 0: + logger.info("LLM failover selected model={}", candidate.label) + return response + + last_response = response + if not self._should_failover(response): + return response + if index + 1 >= len(chain): + logger.warning("LLM failover exhausted after model={}", candidate.label) + return response + logger.warning( + "LLM failover model={} next_model={} status={} kind={}", + candidate.label, + chain[index + 1].label, + response.error_status_code, + response.error_kind or response.error_type or response.error_code or "unknown", + ) + + return last_response or LLMResponse( + content="No available fallback model candidate.", + finish_reason="error", + error_kind="configuration", + error_should_retry=False, + ) + + async def chat_with_retry( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: object = LLMProvider._SENTINEL, + temperature: object = LLMProvider._SENTINEL, + reasoning_effort: object = LLMProvider._SENTINEL, + tool_choice: str | dict[str, Any] | None = None, + retry_mode: str = "standard", + on_retry_wait: Callable[[str], Awaitable[None]] | None = None, + ) -> LLMResponse: + async def call( + provider: LLMProvider, + candidate_model: str, + _delta: Callable[[str], Awaitable[None]] | None, + ) -> LLMResponse: + return await provider.chat_with_retry( + messages=messages, + tools=tools, + model=candidate_model, + max_tokens=max_tokens, + temperature=temperature, + reasoning_effort=reasoning_effort, + tool_choice=tool_choice, + retry_mode=retry_mode, + on_retry_wait=on_retry_wait, + ) + + return await self._route(call) + + async def chat_stream_with_retry( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: object = LLMProvider._SENTINEL, + temperature: object = LLMProvider._SENTINEL, + reasoning_effort: object = LLMProvider._SENTINEL, + tool_choice: str | dict[str, Any] | None = None, + on_content_delta: Callable[[str], Awaitable[None]] | None = None, + retry_mode: str = "standard", + on_retry_wait: Callable[[str], Awaitable[None]] | None = None, + ) -> LLMResponse: + async def call( + provider: LLMProvider, + candidate_model: str, + external_delta: Callable[[str], Awaitable[None]] | None, + ) -> LLMResponse: + buffered: list[str] = [] + + async def buffer_delta(delta: str) -> None: + buffered.append(delta) + + response = await provider.chat_stream_with_retry( + messages=messages, + tools=tools, + model=candidate_model, + max_tokens=max_tokens, + temperature=temperature, + reasoning_effort=reasoning_effort, + tool_choice=tool_choice, + on_content_delta=buffer_delta if external_delta else None, + retry_mode=retry_mode, + on_retry_wait=on_retry_wait, + ) + if response.finish_reason != "error" and external_delta: + for delta in buffered: + await external_delta(delta) + return response + + return await self._route(call, on_content_delta=on_content_delta) diff --git a/tests/agent/test_runner_fallback.py b/tests/agent/test_runner_fallback.py index bfd0bc225..3ade85041 100644 --- a/tests/agent/test_runner_fallback.py +++ b/tests/agent/test_runner_fallback.py @@ -6,7 +6,8 @@ from unittest.mock import AsyncMock, MagicMock import pytest -from nanobot.agent.runner import AgentRunSpec, AgentRunner +from nanobot.agent.hook import AgentHook, AgentHookContext +from nanobot.agent.runner import AgentRunner, AgentRunSpec from nanobot.providers.base import LLMResponse @@ -24,6 +25,10 @@ def _make_provider(*, model_response: LLMResponse | None = None): return p +def _transient_error(content: str = "server unavailable") -> LLMResponse: + return LLMResponse(content=content, finish_reason="error", error_status_code=503) + + def _base_spec(**overrides) -> AgentRunSpec: defaults = dict( initial_messages=[ @@ -56,7 +61,7 @@ async def test_no_fallback_when_primary_succeeds(): @pytest.mark.asyncio async def test_fallback_triggered_on_primary_error(): """Primary fails -> first fallback succeeds.""" - err = LLMResponse(content=None, finish_reason="error", usage={}) + err = _transient_error() ok = LLMResponse(content="fallback-ok", tool_calls=[], usage={}) primary = _make_provider(model_response=err) @@ -76,12 +81,12 @@ async def test_fallback_triggered_on_primary_error(): @pytest.mark.asyncio async def test_all_fallbacks_fail_returns_last_error(): """Primary + all fallbacks fail -> return last error response.""" - err = LLMResponse(content=None, finish_reason="error", usage={}) + err = _transient_error() primary = _make_provider(model_response=err) fb1 = _make_provider(model_response=err) fb2 = _make_provider(model_response=LLMResponse( - content="last-error", finish_reason="error", usage={}, + content="last-error", finish_reason="error", error_status_code=500, usage={}, )) providers = {"fb-1": fb1, "fb-2": fb2} @@ -110,7 +115,7 @@ async def test_empty_fallback_list_no_retry(): @pytest.mark.asyncio async def test_cross_provider_fallback(): """Fallback uses a different provider instance (cross-provider).""" - err = LLMResponse(content=None, finish_reason="error", usage={}) + err = _transient_error() ok = LLMResponse(content="cross-provider-ok", tool_calls=[], usage={}) primary = _make_provider(model_response=err) @@ -132,7 +137,7 @@ async def test_cross_provider_fallback(): @pytest.mark.asyncio async def test_fallback_skips_to_second_on_first_error(): """First fallback also fails -> second fallback succeeds.""" - err = LLMResponse(content=None, finish_reason="error", usage={}) + err = _transient_error() ok = LLMResponse(content="second-fb-ok", tool_calls=[], usage={}) primary = _make_provider(model_response=err) @@ -158,7 +163,7 @@ async def test_fallback_reuses_same_provider_without_factory(): async def chat_with_retry(*, messages, model, **kw): call_count["n"] += 1 if call_count["n"] == 1: - return LLMResponse(content=None, finish_reason="error", usage={}) + return _transient_error() return LLMResponse(content=f"ok-via-{model}", tool_calls=[], usage={}) primary = MagicMock() @@ -173,7 +178,7 @@ async def test_fallback_reuses_same_provider_without_factory(): @pytest.mark.asyncio async def test_fallback_provider_cached(): """Provider factory is called once per unique provider, not per attempt.""" - err = LLMResponse(content=None, finish_reason="error", usage={}) + err = _transient_error() ok = LLMResponse(content="cached-ok", tool_calls=[], usage={}) primary = _make_provider(model_response=err) @@ -188,3 +193,77 @@ async def test_fallback_provider_cached(): result = await runner.run(_base_spec(fallback_models=["same-provider-model-a", "same-provider-model-b"])) assert result.final_content == "cached-ok" + + +@pytest.mark.asyncio +async def test_non_transient_error_does_not_fallback(): + """Auth/config-style errors should surface instead of hiding bugs via fallback.""" + primary = _make_provider(model_response=LLMResponse( + content="401 unauthorized", + finish_reason="error", + error_status_code=401, + )) + fallback = _make_provider(model_response=LLMResponse(content="fallback-ok")) + factory = MagicMock(return_value=fallback) + + runner = AgentRunner(primary, provider_factory=factory) + result = await runner.run(_base_spec(fallback_models=["fb-model"])) + + factory.assert_not_called() + assert result.error is not None + + +@pytest.mark.asyncio +async def test_quota_error_does_not_fallback_by_default(): + """Quota/billing/payment 429s should not route by default.""" + primary = _make_provider(model_response=LLMResponse( + content="insufficient quota", + finish_reason="error", + error_status_code=429, + error_code="insufficient_quota", + )) + fallback = _make_provider(model_response=LLMResponse(content="fallback-ok")) + factory = MagicMock(return_value=fallback) + + runner = AgentRunner(primary, provider_factory=factory) + result = await runner.run(_base_spec(fallback_models=["fb-model"])) + + factory.assert_not_called() + assert result.error is not None + + +@pytest.mark.asyncio +async def test_streaming_fallback_discards_failed_primary_deltas(): + """Buffered streaming prevents primary partial output from leaking on fallback.""" + streamed: list[str] = [] + + class StreamingHook(AgentHook): + def wants_streaming(self) -> bool: + return True + + async def on_stream(self, context: AgentHookContext, delta: str) -> None: + streamed.append(delta) + + async def primary_stream(*, on_content_delta, **kwargs): + await on_content_delta("bad partial") + return _transient_error() + + async def fallback_stream(*, on_content_delta, **kwargs): + await on_content_delta("good") + await on_content_delta(" answer") + return LLMResponse(content="good answer", tool_calls=[], usage={}) + + primary = MagicMock() + primary.chat_stream_with_retry = primary_stream + fallback = MagicMock() + fallback.chat_stream_with_retry = fallback_stream + factory = MagicMock(return_value=fallback) + + runner = AgentRunner(primary, provider_factory=factory) + result = await runner.run(_base_spec( + fallback_models=["fb-model"], + hook=StreamingHook(), + )) + + assert result.final_content == "good answer" + assert streamed == ["good", " answer"]