mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
Refine fallback routing on model presets
This commit is contained in:
parent
2e5930e355
commit
7c270577e1
@ -16,6 +16,7 @@ from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.agent.tools.ask import AskUserInterrupt
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.failover import ModelCandidate, ModelRouter
|
||||
from nanobot.utils.helpers import (
|
||||
build_assistant_message,
|
||||
estimate_message_tokens,
|
||||
@ -621,24 +622,14 @@ class AgentRunner:
|
||||
messages,
|
||||
tools=spec.tools.get_definitions(),
|
||||
)
|
||||
response = await self._call_provider(self.provider, kwargs, hook, context, spec, timeout_s)
|
||||
|
||||
if response.finish_reason == "error" and spec.fallback_models:
|
||||
for fb_model in spec.fallback_models:
|
||||
logger.warning(
|
||||
"Primary model {} failed, trying fallback: {}",
|
||||
spec.model,
|
||||
fb_model,
|
||||
)
|
||||
fb_provider, resolved_model = self._resolve_fallback_provider(fb_model)
|
||||
fb_kwargs = dict(kwargs, model=resolved_model)
|
||||
response = await self._call_provider(
|
||||
fb_provider, fb_kwargs, hook, context, spec, timeout_s,
|
||||
)
|
||||
if response.finish_reason != "error":
|
||||
break
|
||||
|
||||
return response
|
||||
provider: LLMProvider = self.provider
|
||||
request_timeout = timeout_s
|
||||
if spec.fallback_models:
|
||||
provider = self._build_model_router(spec, timeout_s)
|
||||
# ModelRouter applies the same timeout per candidate, preserving
|
||||
# fallback on primary timeouts instead of timing out the whole chain.
|
||||
request_timeout = None
|
||||
return await self._call_provider(provider, kwargs, hook, context, spec, request_timeout)
|
||||
|
||||
async def _call_provider(
|
||||
self,
|
||||
@ -716,6 +707,25 @@ class AgentRunner:
|
||||
return provider, provider.get_default_model()
|
||||
return self.provider, model
|
||||
|
||||
def _build_model_router(
    self,
    spec: AgentRunSpec,
    timeout_s: float | None,
) -> ModelRouter:
    """Wrap the primary provider and the spec's fallback models in a router.

    Args:
        spec: Run spec supplying the primary model name (``spec.model``)
            and the ordered ``spec.fallback_models`` names.
        timeout_s: Timeout handed to the router as its per-candidate
            timeout; ``None`` is passed through unchanged.

    Returns:
        A ``ModelRouter`` whose primary is ``self.provider`` and whose
        fallback candidates resolve lazily via
        ``self._resolve_fallback_provider``.
    """

    def as_candidate(name: str) -> ModelCandidate:
        # A dedicated scope per name keeps each resolver bound to its own
        # fallback model instead of the last one iterated.
        return ModelCandidate(
            label=name,
            resolver=lambda: self._resolve_fallback_provider(name),
        )

    fallbacks = [as_candidate(name) for name in spec.fallback_models]
    return ModelRouter(
        primary_provider=self.provider,
        primary_model=spec.model,
        fallback_candidates=fallbacks,
        per_candidate_timeout_s=timeout_s,
    )
|
||||
|
||||
async def _request_finalization_retry(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
|
||||
@ -527,8 +527,7 @@ def _make_cli_provider_factory(config: Config):
|
||||
def factory(model_or_preset: str):
    """Return the provider for a model name or preset name, building it once.

    Args:
        model_or_preset: Either a key of the enclosing ``presets`` mapping
            or a plain model name.

    Returns:
        The provider stored in the enclosing ``cache`` for the resolved
        model; it is created with ``_make_provider_for_model`` on first use.
    """
    preset = presets.get(model_or_preset)
    actual_model = preset.model if preset else model_or_preset
    # The cache is keyed by the resolved model name. The earlier
    # provider-name lookup was overwritten immediately and never used, so
    # it has been dropped.
    # NOTE(review): two presets that share a model also share one cached
    # provider (built with whichever preset came first) — confirm intended.
    key = actual_model
    if key not in cache:
        cache[key] = _make_provider_for_model(config, actual_model, preset=preset)
    return cache[key]
|
||||
|
||||
@ -220,8 +220,7 @@ def _make_provider_factory(config: Any):
|
||||
def factory(model_or_preset: str):
    """Return the provider for a model name or preset name, building it once.

    Args:
        model_or_preset: Either a key of the enclosing ``presets`` mapping
            or a plain model name.

    Returns:
        The provider stored in the enclosing ``cache`` for the resolved
        model; it is created with ``_make_provider_for_model`` on first use.
    """
    preset = presets.get(model_or_preset)
    actual_model = preset.model if preset else model_or_preset
    # The cache is keyed by the resolved model name. The earlier
    # provider-name lookup was overwritten immediately and never used, so
    # it has been dropped.
    # NOTE(review): two presets that share a model also share one cached
    # provider (built with whichever preset came first) — confirm intended.
    key = actual_model
    if key not in cache:
        cache[key] = _make_provider_for_model(config, actual_model, preset=preset)
    return cache[key]
|
||||
|
||||
276
nanobot/providers/failover.py
Normal file
276
nanobot/providers/failover.py
Normal file
@ -0,0 +1,276 @@
|
||||
"""Provider-like failover router used after provider-local retry is exhausted."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ModelCandidate:
    """One entry in a failover chain whose provider is resolved on demand.

    Instances are immutable (frozen dataclass).
    """

    # Name used to identify this candidate in log and error messages.
    label: str
    # Zero-argument callable returning ``(provider, model_name)`` for this
    # candidate; it is invoked only when the candidate is actually tried.
    resolver: Callable[[], tuple[LLMProvider, str]]
|
||||
|
||||
|
||||
class ModelRouter(LLMProvider):
    """Try fallback model candidates for eligible transient final errors.

    The router presents the provider interface. ``chat_with_retry`` and
    ``chat_stream_with_retry`` try the primary provider first and then each
    fallback candidate in order, moving on only while ``_should_failover``
    accepts the error response that came back.

    ``_normalize_error_token``, ``_NON_RETRYABLE_429_ERROR_TOKENS`` and
    ``_is_transient_response`` are not defined in this class; presumably
    they are inherited from ``LLMProvider`` — verify against the base class.
    """

    supports_progress_deltas = False

    # HTTP status codes that are never routed to a fallback.
    _BLOCKED_STATUS_CODES = frozenset({400, 401, 403, 404, 422})

    # Lower-case substrings of the error text that mark quota/billing errors.
    _QUOTA_MARKERS = (
        "insufficient_quota",
        "insufficient quota",
        "quota exceeded",
        "quota_exceeded",
        "quota exhausted",
        "quota_exhausted",
        "billing hard limit",
        "billing_hard_limit_reached",
        "billing not active",
        "insufficient balance",
        "insufficient_balance",
        "credit balance too low",
        "payment required",
        "out of credits",
        "out of quota",
        "exceeded your current quota",
    )

    # Lower-case substrings of the error text that block failover.
    _NON_FAILOVER_MARKERS = (
        "context length",
        "context_length",
        "maximum context",
        "max context",
        "token budget",
        "too many tokens",
        "schema",
        "invalid request",
        "invalid_request",
        "invalid parameter",
        "invalid_parameter",
        "unsupported",
        "unauthorized",
        "authentication",
        "permission",
        "forbidden",
        "refusal",
        "content policy",
        "content_filter",
        "policy violation",
        "safety",
    )

    def __init__(
        self,
        *,
        primary_provider: LLMProvider,
        primary_model: str,
        fallback_candidates: list[ModelCandidate],
        per_candidate_timeout_s: float | None = None,
    ) -> None:
        """Set up the router around a primary provider.

        Args:
            primary_provider: Provider tried first; its ``api_key``,
                ``api_base`` and ``generation`` attributes are mirrored
                when present.
            primary_model: Model name used with the primary provider.
            fallback_candidates: Ordered candidates tried after the primary;
                the list is copied.
            per_candidate_timeout_s: Timeout applied to each candidate call,
                or ``None`` for no router-level timeout.
        """
        super().__init__(
            api_key=getattr(primary_provider, "api_key", None),
            api_base=getattr(primary_provider, "api_base", None),
        )
        self.primary_provider = primary_provider
        self.primary_model = primary_model
        self.fallback_candidates = list(fallback_candidates)
        self.per_candidate_timeout_s = per_candidate_timeout_s
        self.generation = getattr(primary_provider, "generation", GenerationSettings())

    def get_default_model(self) -> str:
        """Return the primary model name."""
        return self.primary_model

    async def chat(self, **kwargs: Any) -> LLMResponse:
        """Delegate straight to the primary provider; no routing happens here."""
        return await self.primary_provider.chat(**kwargs)

    async def chat_stream(self, **kwargs: Any) -> LLMResponse:
        """Delegate straight to the primary provider; no routing happens here."""
        return await self.primary_provider.chat_stream(**kwargs)

    @classmethod
    def _is_quota_error(cls, response: LLMResponse) -> bool:
        """Return True when the response looks like a quota/billing error.

        The normalized error type/code are checked against
        ``_NON_RETRYABLE_429_ERROR_TOKENS`` first, then the lower-cased
        content is scanned for ``_QUOTA_MARKERS``.
        """
        normalized = (
            cls._normalize_error_token(response.error_type),
            cls._normalize_error_token(response.error_code),
        )
        for token in normalized:
            if token and token in cls._NON_RETRYABLE_429_ERROR_TOKENS:
                return True
        text = (response.content or "").lower()
        return any(marker in text for marker in cls._QUOTA_MARKERS)

    @classmethod
    def _is_blocked_error(cls, response: LLMResponse) -> bool:
        """Return True for errors that must never be routed to a fallback.

        Blocked means: a status code in ``_BLOCKED_STATUS_CODES``, a
        ``refusal``/``content_filter`` finish reason, or content containing
        one of ``_NON_FAILOVER_MARKERS``.
        """
        code = response.error_status_code
        if code is not None and int(code) in cls._BLOCKED_STATUS_CODES:
            return True
        if response.finish_reason in ("refusal", "content_filter"):
            return True
        text = (response.content or "").lower()
        return any(marker in text for marker in cls._NON_FAILOVER_MARKERS)

    @classmethod
    def _should_failover(cls, response: LLMResponse) -> bool:
        """Decide whether the next candidate should be tried after *response*.

        Only ``finish_reason == "error"`` responses that are neither blocked
        nor quota errors are eligible; the final verdict comes from
        ``_is_transient_response``.
        """
        if response.finish_reason != "error":
            return False
        if cls._is_blocked_error(response) or cls._is_quota_error(response):
            return False
        return cls._is_transient_response(response)

    async def _with_timeout(self, coro: Awaitable[LLMResponse]) -> LLMResponse:
        """Await *coro*, enforcing the per-candidate timeout when one is set.

        A timeout is reported as an error ``LLMResponse`` with
        ``error_kind="timeout"`` rather than as a raised exception.
        """
        timeout_s = self.per_candidate_timeout_s
        if timeout_s is None:
            return await coro
        try:
            result = await asyncio.wait_for(coro, timeout=timeout_s)
        except asyncio.TimeoutError:
            return LLMResponse(
                content=f"Error calling LLM: timed out after {timeout_s:g}s",
                finish_reason="error",
                error_kind="timeout",
            )
        return result

    @staticmethod
    def _resolver_error(candidate: ModelCandidate, exc: Exception) -> LLMResponse:
        """Log a failed candidate resolution and describe it as an error response."""
        logger.warning("Failed to resolve fallback model {}: {}", candidate.label, exc)
        failure = LLMResponse(
            content=f"Error configuring fallback model {candidate.label}: {exc}",
            finish_reason="error",
            error_kind="configuration",
            error_should_retry=False,
        )
        return failure

    def _candidate_chain(self) -> list[ModelCandidate]:
        """Return the primary candidate followed by the configured fallbacks."""
        primary = ModelCandidate(
            label=self.primary_model,
            resolver=lambda: (self.primary_provider, self.primary_model),
        )
        return [primary, *self.fallback_candidates]

    async def _route(
        self,
        call: Callable[[LLMProvider, str, Callable[[str], Awaitable[None]] | None], Awaitable[LLMResponse]],
        *,
        on_content_delta: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Run *call* against each candidate in order until one is acceptable.

        Args:
            call: Coroutine function invoked as
                ``call(provider, model, on_content_delta)`` per candidate.
            on_content_delta: Optional delta callback forwarded to *call*.

        Returns:
            The first non-error response; otherwise the first error that
            ``_should_failover`` rejects, or the last candidate's error once
            the chain is exhausted.
        """
        chain = self._candidate_chain()
        final_position = len(chain) - 1
        last_response: LLMResponse | None = None

        for position, candidate in enumerate(chain):
            try:
                provider, model = candidate.resolver()
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                # A candidate that cannot be resolved is treated like a
                # failed call, so the usual failover rules decide what
                # happens next.
                response = self._resolver_error(candidate, exc)
            else:
                response = await self._with_timeout(call(provider, model, on_content_delta))

            if response.finish_reason != "error":
                if position > 0:
                    logger.info("LLM failover selected model={}", candidate.label)
                return response

            last_response = response
            if not self._should_failover(response):
                return response
            if position >= final_position:
                logger.warning("LLM failover exhausted after model={}", candidate.label)
                return response

            successor = chain[position + 1]
            logger.warning(
                "LLM failover model={} next_model={} status={} kind={}",
                candidate.label,
                successor.label,
                response.error_status_code,
                response.error_kind or response.error_type or response.error_code or "unknown",
            )

        # Defensive tail: the chain always contains the primary candidate, so
        # the loop normally returns before reaching this point.
        if last_response:
            return last_response
        return LLMResponse(
            content="No available fallback model candidate.",
            finish_reason="error",
            error_kind="configuration",
            error_should_retry=False,
        )

    async def chat_with_retry(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: object = LLMProvider._SENTINEL,
        temperature: object = LLMProvider._SENTINEL,
        reasoning_effort: object = LLMProvider._SENTINEL,
        tool_choice: str | dict[str, Any] | None = None,
        retry_mode: str = "standard",
        on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Route a non-streaming request through the candidate chain.

        ``model`` is accepted for interface compatibility but not used: each
        candidate is called with its own resolved model name. All other
        arguments are forwarded unchanged to the candidate's
        ``chat_with_retry``.
        """
        forwarded: dict[str, Any] = {
            "messages": messages,
            "tools": tools,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "reasoning_effort": reasoning_effort,
            "tool_choice": tool_choice,
            "retry_mode": retry_mode,
            "on_retry_wait": on_retry_wait,
        }

        async def attempt(
            provider: LLMProvider,
            candidate_model: str,
            _delta: Callable[[str], Awaitable[None]] | None,
        ) -> LLMResponse:
            return await provider.chat_with_retry(model=candidate_model, **forwarded)

        return await self._route(attempt)

    async def chat_stream_with_retry(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: object = LLMProvider._SENTINEL,
        temperature: object = LLMProvider._SENTINEL,
        reasoning_effort: object = LLMProvider._SENTINEL,
        tool_choice: str | dict[str, Any] | None = None,
        on_content_delta: Callable[[str], Awaitable[None]] | None = None,
        retry_mode: str = "standard",
        on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Route a streaming request through the candidate chain.

        Deltas are buffered per candidate and replayed to
        ``on_content_delta`` only when that candidate finishes without an
        error, so partial output from a failed candidate never reaches the
        caller. ``model`` is accepted for interface compatibility but not
        used: each candidate is called with its own resolved model name.
        """
        forwarded: dict[str, Any] = {
            "messages": messages,
            "tools": tools,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "reasoning_effort": reasoning_effort,
            "tool_choice": tool_choice,
            "retry_mode": retry_mode,
            "on_retry_wait": on_retry_wait,
        }

        async def attempt(
            provider: LLMProvider,
            candidate_model: str,
            external_delta: Callable[[str], Awaitable[None]] | None,
        ) -> LLMResponse:
            pending: list[str] = []

            async def hold_delta(delta: str) -> None:
                pending.append(delta)

            response = await provider.chat_stream_with_retry(
                model=candidate_model,
                on_content_delta=hold_delta if external_delta else None,
                **forwarded,
            )
            if response.finish_reason != "error" and external_delta:
                # Replay only after success, preserving the original order.
                for delta in pending:
                    await external_delta(delta)
            return response

        return await self._route(attempt, on_content_delta=on_content_delta)
|
||||
@ -6,7 +6,8 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
|
||||
@ -24,6 +25,10 @@ def _make_provider(*, model_response: LLMResponse | None = None):
|
||||
return p
|
||||
|
||||
|
||||
def _transient_error(content: str = "server unavailable") -> LLMResponse:
    """Build an error response carrying status code 503 for fallback tests."""
    return LLMResponse(
        content=content,
        finish_reason="error",
        error_status_code=503,
    )
|
||||
|
||||
|
||||
def _base_spec(**overrides) -> AgentRunSpec:
|
||||
defaults = dict(
|
||||
initial_messages=[
|
||||
@ -56,7 +61,7 @@ async def test_no_fallback_when_primary_succeeds():
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_triggered_on_primary_error():
|
||||
"""Primary fails -> first fallback succeeds."""
|
||||
err = LLMResponse(content=None, finish_reason="error", usage={})
|
||||
err = _transient_error()
|
||||
ok = LLMResponse(content="fallback-ok", tool_calls=[], usage={})
|
||||
|
||||
primary = _make_provider(model_response=err)
|
||||
@ -76,12 +81,12 @@ async def test_fallback_triggered_on_primary_error():
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_fallbacks_fail_returns_last_error():
|
||||
"""Primary + all fallbacks fail -> return last error response."""
|
||||
err = LLMResponse(content=None, finish_reason="error", usage={})
|
||||
err = _transient_error()
|
||||
|
||||
primary = _make_provider(model_response=err)
|
||||
fb1 = _make_provider(model_response=err)
|
||||
fb2 = _make_provider(model_response=LLMResponse(
|
||||
content="last-error", finish_reason="error", usage={},
|
||||
content="last-error", finish_reason="error", error_status_code=500, usage={},
|
||||
))
|
||||
|
||||
providers = {"fb-1": fb1, "fb-2": fb2}
|
||||
@ -110,7 +115,7 @@ async def test_empty_fallback_list_no_retry():
|
||||
@pytest.mark.asyncio
|
||||
async def test_cross_provider_fallback():
|
||||
"""Fallback uses a different provider instance (cross-provider)."""
|
||||
err = LLMResponse(content=None, finish_reason="error", usage={})
|
||||
err = _transient_error()
|
||||
ok = LLMResponse(content="cross-provider-ok", tool_calls=[], usage={})
|
||||
|
||||
primary = _make_provider(model_response=err)
|
||||
@ -132,7 +137,7 @@ async def test_cross_provider_fallback():
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_skips_to_second_on_first_error():
|
||||
"""First fallback also fails -> second fallback succeeds."""
|
||||
err = LLMResponse(content=None, finish_reason="error", usage={})
|
||||
err = _transient_error()
|
||||
ok = LLMResponse(content="second-fb-ok", tool_calls=[], usage={})
|
||||
|
||||
primary = _make_provider(model_response=err)
|
||||
@ -158,7 +163,7 @@ async def test_fallback_reuses_same_provider_without_factory():
|
||||
async def chat_with_retry(*, messages, model, **kw):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
return LLMResponse(content=None, finish_reason="error", usage={})
|
||||
return _transient_error()
|
||||
return LLMResponse(content=f"ok-via-{model}", tool_calls=[], usage={})
|
||||
|
||||
primary = MagicMock()
|
||||
@ -173,7 +178,7 @@ async def test_fallback_reuses_same_provider_without_factory():
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_provider_cached():
|
||||
"""Provider factory is called once per unique provider, not per attempt."""
|
||||
err = LLMResponse(content=None, finish_reason="error", usage={})
|
||||
err = _transient_error()
|
||||
ok = LLMResponse(content="cached-ok", tool_calls=[], usage={})
|
||||
|
||||
primary = _make_provider(model_response=err)
|
||||
@ -188,3 +193,77 @@ async def test_fallback_provider_cached():
|
||||
result = await runner.run(_base_spec(fallback_models=["same-provider-model-a", "same-provider-model-b"]))
|
||||
|
||||
assert result.final_content == "cached-ok"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_non_transient_error_does_not_fallback():
    """A 401 from the primary is surfaced and the fallback factory is never used."""
    auth_failure = LLMResponse(
        content="401 unauthorized",
        finish_reason="error",
        error_status_code=401,
    )
    primary = _make_provider(model_response=auth_failure)
    fallback = _make_provider(model_response=LLMResponse(content="fallback-ok"))
    factory = MagicMock(return_value=fallback)
    runner = AgentRunner(primary, provider_factory=factory)

    spec = _base_spec(fallback_models=["fb-model"])
    result = await runner.run(spec)

    # No fallback provider may be resolved for a blocked error.
    factory.assert_not_called()
    assert result.error is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_quota_error_does_not_fallback_by_default():
    """A 429 carrying an insufficient-quota code is surfaced without routing."""
    quota_failure = LLMResponse(
        content="insufficient quota",
        finish_reason="error",
        error_status_code=429,
        error_code="insufficient_quota",
    )
    primary = _make_provider(model_response=quota_failure)
    fallback = _make_provider(model_response=LLMResponse(content="fallback-ok"))
    factory = MagicMock(return_value=fallback)
    runner = AgentRunner(primary, provider_factory=factory)

    spec = _base_spec(fallback_models=["fb-model"])
    result = await runner.run(spec)

    # Quota errors must not trigger resolution of any fallback provider.
    factory.assert_not_called()
    assert result.error is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_streaming_fallback_discards_failed_primary_deltas():
    """Deltas from a failed primary stream never reach the hook; only the
    fallback's deltas do, in order."""
    received: list[str] = []

    class StreamingHook(AgentHook):
        """Hook that opts into streaming and records every delta it is given."""

        def wants_streaming(self) -> bool:
            return True

        async def on_stream(self, context: AgentHookContext, delta: str) -> None:
            received.append(delta)

    async def primary_stream(*, on_content_delta, **kwargs):
        # Emit partial output and then fail with a transient error.
        await on_content_delta("bad partial")
        return _transient_error()

    async def fallback_stream(*, on_content_delta, **kwargs):
        for piece in ("good", " answer"):
            await on_content_delta(piece)
        return LLMResponse(content="good answer", tool_calls=[], usage={})

    primary = MagicMock()
    primary.chat_stream_with_retry = primary_stream
    fallback = MagicMock()
    fallback.chat_stream_with_retry = fallback_stream
    factory = MagicMock(return_value=fallback)
    runner = AgentRunner(primary, provider_factory=factory)

    spec = _base_spec(
        fallback_models=["fb-model"],
        hook=StreamingHook(),
    )
    result = await runner.run(spec)

    assert result.final_content == "good answer"
    assert received == ["good", " answer"]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user