mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
Refine fallback routing on model presets
This commit is contained in:
parent
2e5930e355
commit
7c270577e1
@ -16,6 +16,7 @@ from nanobot.agent.hook import AgentHook, AgentHookContext
|
|||||||
from nanobot.agent.tools.ask import AskUserInterrupt
|
from nanobot.agent.tools.ask import AskUserInterrupt
|
||||||
from nanobot.agent.tools.registry import ToolRegistry
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
|
from nanobot.providers.failover import ModelCandidate, ModelRouter
|
||||||
from nanobot.utils.helpers import (
|
from nanobot.utils.helpers import (
|
||||||
build_assistant_message,
|
build_assistant_message,
|
||||||
estimate_message_tokens,
|
estimate_message_tokens,
|
||||||
@ -621,24 +622,14 @@ class AgentRunner:
|
|||||||
messages,
|
messages,
|
||||||
tools=spec.tools.get_definitions(),
|
tools=spec.tools.get_definitions(),
|
||||||
)
|
)
|
||||||
response = await self._call_provider(self.provider, kwargs, hook, context, spec, timeout_s)
|
provider: LLMProvider = self.provider
|
||||||
|
request_timeout = timeout_s
|
||||||
if response.finish_reason == "error" and spec.fallback_models:
|
if spec.fallback_models:
|
||||||
for fb_model in spec.fallback_models:
|
provider = self._build_model_router(spec, timeout_s)
|
||||||
logger.warning(
|
# ModelRouter applies the same timeout per candidate, preserving
|
||||||
"Primary model {} failed, trying fallback: {}",
|
# fallback on primary timeouts instead of timing out the whole chain.
|
||||||
spec.model,
|
request_timeout = None
|
||||||
fb_model,
|
return await self._call_provider(provider, kwargs, hook, context, spec, request_timeout)
|
||||||
)
|
|
||||||
fb_provider, resolved_model = self._resolve_fallback_provider(fb_model)
|
|
||||||
fb_kwargs = dict(kwargs, model=resolved_model)
|
|
||||||
response = await self._call_provider(
|
|
||||||
fb_provider, fb_kwargs, hook, context, spec, timeout_s,
|
|
||||||
)
|
|
||||||
if response.finish_reason != "error":
|
|
||||||
break
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
async def _call_provider(
|
async def _call_provider(
|
||||||
self,
|
self,
|
||||||
@ -716,6 +707,25 @@ class AgentRunner:
|
|||||||
return provider, provider.get_default_model()
|
return provider, provider.get_default_model()
|
||||||
return self.provider, model
|
return self.provider, model
|
||||||
|
|
||||||
|
def _build_model_router(
    self,
    spec: AgentRunSpec,
    timeout_s: float | None,
) -> ModelRouter:
    """Wrap the primary provider and ``spec.fallback_models`` in a ``ModelRouter``.

    Args:
        spec: Run spec supplying the primary model (``spec.model``) and the
            ordered fallback model names (``spec.fallback_models``).
        timeout_s: Timeout handed to the router as its per-candidate
            timeout; ``None`` disables the router-side timeout.

    Returns:
        A router whose primary is ``self.provider`` / ``spec.model`` and
        whose fallback candidates resolve lazily through
        ``self._resolve_fallback_provider``.
    """
    candidates = [
        ModelCandidate(
            label=model,
            # Bind ``model`` as a default argument: a plain closure would
            # late-bind and make every resolver see the last loop value.
            resolver=lambda m=model: self._resolve_fallback_provider(m),
        )
        for model in spec.fallback_models
    ]
    return ModelRouter(
        primary_provider=self.provider,
        primary_model=spec.model,
        fallback_candidates=candidates,
        per_candidate_timeout_s=timeout_s,
    )
|
||||||
|
|
||||||
async def _request_finalization_retry(
|
async def _request_finalization_retry(
|
||||||
self,
|
self,
|
||||||
spec: AgentRunSpec,
|
spec: AgentRunSpec,
|
||||||
|
|||||||
@ -527,8 +527,7 @@ def _make_cli_provider_factory(config: Config):
|
|||||||
def factory(model_or_preset: str):
|
def factory(model_or_preset: str):
|
||||||
preset = presets.get(model_or_preset)
|
preset = presets.get(model_or_preset)
|
||||||
actual_model = preset.model if preset else model_or_preset
|
actual_model = preset.model if preset else model_or_preset
|
||||||
provider_name = config.get_provider_name(actual_model)
|
key = actual_model
|
||||||
key = provider_name or actual_model
|
|
||||||
if key not in cache:
|
if key not in cache:
|
||||||
cache[key] = _make_provider_for_model(config, actual_model, preset=preset)
|
cache[key] = _make_provider_for_model(config, actual_model, preset=preset)
|
||||||
return cache[key]
|
return cache[key]
|
||||||
|
|||||||
@ -220,8 +220,7 @@ def _make_provider_factory(config: Any):
|
|||||||
def factory(model_or_preset: str):
|
def factory(model_or_preset: str):
|
||||||
preset = presets.get(model_or_preset)
|
preset = presets.get(model_or_preset)
|
||||||
actual_model = preset.model if preset else model_or_preset
|
actual_model = preset.model if preset else model_or_preset
|
||||||
provider_name = config.get_provider_name(actual_model)
|
key = actual_model
|
||||||
key = provider_name or actual_model
|
|
||||||
if key not in cache:
|
if key not in cache:
|
||||||
cache[key] = _make_provider_for_model(config, actual_model, preset=preset)
|
cache[key] = _make_provider_for_model(config, actual_model, preset=preset)
|
||||||
return cache[key]
|
return cache[key]
|
||||||
|
|||||||
276
nanobot/providers/failover.py
Normal file
276
nanobot/providers/failover.py
Normal file
@ -0,0 +1,276 @@
|
|||||||
|
"""Provider-like failover router used after provider-local retry is exhausted."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ModelCandidate:
    """A lazily resolved model/provider candidate.

    The provider is not built when the candidate is created; ``resolver`` is
    only invoked when the candidate is actually tried.
    """

    # Human-readable model name used in log messages and error text.
    label: str
    # Zero-argument callable returning ``(provider, model_name)`` for the call.
    resolver: Callable[[], tuple[LLMProvider, str]]
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRouter(LLMProvider):
    """Try fallback model candidates for eligible transient final errors.

    The router presents the ``LLMProvider`` interface. ``chat_with_retry`` and
    ``chat_stream_with_retry`` walk an ordered chain (primary first, then the
    configured fallbacks) and return the first non-error response, or the
    first error that is not eligible for failover, or the last error once the
    chain is exhausted. ``chat`` and ``chat_stream`` go straight to the
    primary provider without any failover.
    """

    # Streaming deltas are buffered per candidate and only replayed after that
    # candidate succeeds (see ``chat_stream_with_retry``).
    # NOTE(review): presumably read by callers to decide whether incremental
    # progress can be shown -- confirm against LLMProvider.
    supports_progress_deltas = False

    # HTTP status codes that ``_is_blocked_error`` treats as non-failover.
    _BLOCKED_STATUS_CODES = frozenset({400, 401, 403, 404, 422})
    # Lower-case substrings of the response content that mark a quota/billing
    # error (matched by ``_is_quota_error``).
    _QUOTA_MARKERS = (
        "insufficient_quota",
        "insufficient quota",
        "quota exceeded",
        "quota_exceeded",
        "quota exhausted",
        "quota_exhausted",
        "billing hard limit",
        "billing_hard_limit_reached",
        "billing not active",
        "insufficient balance",
        "insufficient_balance",
        "credit balance too low",
        "payment required",
        "out of credits",
        "out of quota",
        "exceeded your current quota",
    )
    # Lower-case substrings of the response content that block failover
    # (matched by ``_is_blocked_error``).
    _NON_FAILOVER_MARKERS = (
        "context length",
        "context_length",
        "maximum context",
        "max context",
        "token budget",
        "too many tokens",
        "schema",
        "invalid request",
        "invalid_request",
        "invalid parameter",
        "invalid_parameter",
        "unsupported",
        "unauthorized",
        "authentication",
        "permission",
        "forbidden",
        "refusal",
        "content policy",
        "content_filter",
        "policy violation",
        "safety",
    )

    def __init__(
        self,
        *,
        primary_provider: LLMProvider,
        primary_model: str,
        fallback_candidates: list[ModelCandidate],
        per_candidate_timeout_s: float | None = None,
    ) -> None:
        """Create a router around ``primary_provider``.

        Args:
            primary_provider: Provider tried first; its ``api_key``,
                ``api_base`` and ``generation`` attributes are mirrored when
                present.
            primary_model: Model name used for the primary attempt.
            fallback_candidates: Ordered candidates tried after the primary;
                the list is copied.
            per_candidate_timeout_s: Timeout applied to each candidate call
                separately, or ``None`` for no router-side timeout.
        """
        super().__init__(
            api_key=getattr(primary_provider, "api_key", None),
            api_base=getattr(primary_provider, "api_base", None),
        )
        self.primary_provider = primary_provider
        self.primary_model = primary_model
        self.fallback_candidates = list(fallback_candidates)
        self.per_candidate_timeout_s = per_candidate_timeout_s
        self.generation = getattr(primary_provider, "generation", GenerationSettings())

    def get_default_model(self) -> str:
        """Return the primary model name."""
        return self.primary_model

    async def chat(self, **kwargs: Any) -> LLMResponse:
        """Forward to the primary provider's ``chat``; no failover is applied."""
        return await self.primary_provider.chat(**kwargs)

    async def chat_stream(self, **kwargs: Any) -> LLMResponse:
        """Forward to the primary provider's ``chat_stream``; no failover is applied."""
        return await self.primary_provider.chat_stream(**kwargs)

    @classmethod
    def _is_quota_error(cls, response: LLMResponse) -> bool:
        """Return True when ``response`` looks like a quota/billing error.

        Checks the normalized ``error_type`` / ``error_code`` tokens first,
        then falls back to substring matching on the lower-cased content.
        """
        # NOTE(review): ``_normalize_error_token`` and
        # ``_NON_RETRYABLE_429_ERROR_TOKENS`` are not defined in this class;
        # presumably inherited from LLMProvider -- confirm.
        tokens = {
            cls._normalize_error_token(response.error_type),
            cls._normalize_error_token(response.error_code),
        }
        if any(token in cls._NON_RETRYABLE_429_ERROR_TOKENS for token in tokens if token):
            return True
        content = (response.content or "").lower()
        return any(marker in content for marker in cls._QUOTA_MARKERS)

    @classmethod
    def _is_blocked_error(cls, response: LLMResponse) -> bool:
        """Return True when ``response`` must not trigger failover.

        Blocked means: a status code in ``_BLOCKED_STATUS_CODES``, a
        ``refusal`` / ``content_filter`` finish reason, or content containing
        any of ``_NON_FAILOVER_MARKERS``.
        """
        status = response.error_status_code
        if status is not None and int(status) in cls._BLOCKED_STATUS_CODES:
            return True
        if response.finish_reason in {"refusal", "content_filter"}:
            return True
        content = (response.content or "").lower()
        return any(marker in content for marker in cls._NON_FAILOVER_MARKERS)

    @classmethod
    def _should_failover(cls, response: LLMResponse) -> bool:
        """Return True when the next candidate should be tried after ``response``.

        Only error responses that are neither blocked nor quota errors are
        considered, and the final decision is delegated to
        ``_is_transient_response``.
        """
        if response.finish_reason != "error":
            return False
        if cls._is_blocked_error(response):
            return False
        if cls._is_quota_error(response):
            return False
        # NOTE(review): ``_is_transient_response`` is not defined in this
        # class; presumably inherited from LLMProvider -- confirm.
        return cls._is_transient_response(response)

    async def _with_timeout(self, coro: Awaitable[LLMResponse]) -> LLMResponse:
        """Await ``coro`` under the per-candidate timeout, if one is set.

        A timeout is converted into an error ``LLMResponse`` with
        ``error_kind="timeout"`` instead of being raised.
        """
        timeout_s = self.per_candidate_timeout_s
        if timeout_s is None:
            return await coro
        try:
            return await asyncio.wait_for(coro, timeout=timeout_s)
        except asyncio.TimeoutError:
            return LLMResponse(
                content=f"Error calling LLM: timed out after {timeout_s:g}s",
                finish_reason="error",
                error_kind="timeout",
            )

    @staticmethod
    def _resolver_error(candidate: ModelCandidate, exc: Exception) -> LLMResponse:
        """Log a failed candidate resolution and wrap it as an error response."""
        logger.warning("Failed to resolve fallback model {}: {}", candidate.label, exc)
        return LLMResponse(
            content=f"Error configuring fallback model {candidate.label}: {exc}",
            finish_reason="error",
            error_kind="configuration",
            error_should_retry=False,
        )

    def _candidate_chain(self) -> list[ModelCandidate]:
        """Return the ordered chain: the primary candidate, then the fallbacks."""
        return [
            ModelCandidate(
                label=self.primary_model,
                resolver=lambda: (self.primary_provider, self.primary_model),
            ),
            *self.fallback_candidates,
        ]

    async def _route(
        self,
        call: Callable[
            [LLMProvider, str, Callable[[str], Awaitable[None]] | None],
            Awaitable[LLMResponse],
        ],
        *,
        on_content_delta: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Run ``call`` against each candidate in order until one settles.

        Args:
            call: Coroutine function invoked as
                ``call(provider, model, on_content_delta)`` per candidate.
            on_content_delta: Delta callback passed through to ``call``.

        Returns:
            The first non-error response; otherwise the first error response
            for which ``_should_failover`` is false; otherwise the error from
            the last candidate in the chain.
        """
        last_response: LLMResponse | None = None
        chain = self._candidate_chain()
        for index, candidate in enumerate(chain):
            try:
                provider, model = candidate.resolver()
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                # A candidate that cannot be resolved is reported as an error
                # response and then goes through the same checks below.
                response = self._resolver_error(candidate, exc)
            else:
                response = await self._with_timeout(call(provider, model, on_content_delta))

            if response.finish_reason != "error":
                if index > 0:
                    logger.info("LLM failover selected model={}", candidate.label)
                return response

            last_response = response
            if not self._should_failover(response):
                return response
            if index + 1 >= len(chain):
                logger.warning("LLM failover exhausted after model={}", candidate.label)
                return response
            logger.warning(
                "LLM failover model={} next_model={} status={} kind={}",
                candidate.label,
                chain[index + 1].label,
                response.error_status_code,
                response.error_kind or response.error_type or response.error_code or "unknown",
            )

        # Defensive tail: every loop iteration above either returns or moves
        # on to a later candidate, and the chain always holds the primary.
        return last_response or LLMResponse(
            content="No available fallback model candidate.",
            finish_reason="error",
            error_kind="configuration",
            error_should_retry=False,
        )

    async def chat_with_retry(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: object = LLMProvider._SENTINEL,
        temperature: object = LLMProvider._SENTINEL,
        reasoning_effort: object = LLMProvider._SENTINEL,
        tool_choice: str | dict[str, Any] | None = None,
        retry_mode: str = "standard",
        on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Call ``chat_with_retry`` on each candidate through ``_route``.

        ``model`` is accepted for interface compatibility but is not
        forwarded: every candidate is called with its own resolved model.
        All other arguments are passed to each candidate unchanged.
        """

        async def call(
            provider: LLMProvider,
            candidate_model: str,
            _delta: Callable[[str], Awaitable[None]] | None,
        ) -> LLMResponse:
            return await provider.chat_with_retry(
                messages=messages,
                tools=tools,
                model=candidate_model,
                max_tokens=max_tokens,
                temperature=temperature,
                reasoning_effort=reasoning_effort,
                tool_choice=tool_choice,
                retry_mode=retry_mode,
                on_retry_wait=on_retry_wait,
            )

        return await self._route(call)

    async def chat_stream_with_retry(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: object = LLMProvider._SENTINEL,
        temperature: object = LLMProvider._SENTINEL,
        reasoning_effort: object = LLMProvider._SENTINEL,
        tool_choice: str | dict[str, Any] | None = None,
        on_content_delta: Callable[[str], Awaitable[None]] | None = None,
        retry_mode: str = "standard",
        on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Call ``chat_stream_with_retry`` on each candidate through ``_route``.

        Deltas from a candidate are buffered and only forwarded to
        ``on_content_delta`` once that candidate returns a non-error
        response, so output from a failed candidate is never emitted.
        ``model`` is accepted for interface compatibility but is not
        forwarded: every candidate is called with its own resolved model.
        """

        async def call(
            provider: LLMProvider,
            candidate_model: str,
            external_delta: Callable[[str], Awaitable[None]] | None,
        ) -> LLMResponse:
            buffered: list[str] = []

            async def buffer_delta(delta: str) -> None:
                buffered.append(delta)

            response = await provider.chat_stream_with_retry(
                messages=messages,
                tools=tools,
                model=candidate_model,
                max_tokens=max_tokens,
                temperature=temperature,
                reasoning_effort=reasoning_effort,
                tool_choice=tool_choice,
                # Only ask the provider for deltas when the caller wants them.
                on_content_delta=buffer_delta if external_delta else None,
                retry_mode=retry_mode,
                on_retry_wait=on_retry_wait,
            )
            if response.finish_reason != "error" and external_delta:
                for delta in buffered:
                    await external_delta(delta)
            return response

        return await self._route(call, on_content_delta=on_content_delta)
|
||||||
@ -6,7 +6,8 @@ from unittest.mock import AsyncMock, MagicMock
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||||
|
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||||
from nanobot.providers.base import LLMResponse
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
|
||||||
@ -24,6 +25,10 @@ def _make_provider(*, model_response: LLMResponse | None = None):
|
|||||||
return p
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
def _transient_error(content: str = "server unavailable") -> LLMResponse:
    """Build an error response with HTTP status 503 for fallback tests."""
    return LLMResponse(content=content, finish_reason="error", error_status_code=503)
|
||||||
|
|
||||||
|
|
||||||
def _base_spec(**overrides) -> AgentRunSpec:
|
def _base_spec(**overrides) -> AgentRunSpec:
|
||||||
defaults = dict(
|
defaults = dict(
|
||||||
initial_messages=[
|
initial_messages=[
|
||||||
@ -56,7 +61,7 @@ async def test_no_fallback_when_primary_succeeds():
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_fallback_triggered_on_primary_error():
|
async def test_fallback_triggered_on_primary_error():
|
||||||
"""Primary fails -> first fallback succeeds."""
|
"""Primary fails -> first fallback succeeds."""
|
||||||
err = LLMResponse(content=None, finish_reason="error", usage={})
|
err = _transient_error()
|
||||||
ok = LLMResponse(content="fallback-ok", tool_calls=[], usage={})
|
ok = LLMResponse(content="fallback-ok", tool_calls=[], usage={})
|
||||||
|
|
||||||
primary = _make_provider(model_response=err)
|
primary = _make_provider(model_response=err)
|
||||||
@ -76,12 +81,12 @@ async def test_fallback_triggered_on_primary_error():
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_all_fallbacks_fail_returns_last_error():
|
async def test_all_fallbacks_fail_returns_last_error():
|
||||||
"""Primary + all fallbacks fail -> return last error response."""
|
"""Primary + all fallbacks fail -> return last error response."""
|
||||||
err = LLMResponse(content=None, finish_reason="error", usage={})
|
err = _transient_error()
|
||||||
|
|
||||||
primary = _make_provider(model_response=err)
|
primary = _make_provider(model_response=err)
|
||||||
fb1 = _make_provider(model_response=err)
|
fb1 = _make_provider(model_response=err)
|
||||||
fb2 = _make_provider(model_response=LLMResponse(
|
fb2 = _make_provider(model_response=LLMResponse(
|
||||||
content="last-error", finish_reason="error", usage={},
|
content="last-error", finish_reason="error", error_status_code=500, usage={},
|
||||||
))
|
))
|
||||||
|
|
||||||
providers = {"fb-1": fb1, "fb-2": fb2}
|
providers = {"fb-1": fb1, "fb-2": fb2}
|
||||||
@ -110,7 +115,7 @@ async def test_empty_fallback_list_no_retry():
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cross_provider_fallback():
|
async def test_cross_provider_fallback():
|
||||||
"""Fallback uses a different provider instance (cross-provider)."""
|
"""Fallback uses a different provider instance (cross-provider)."""
|
||||||
err = LLMResponse(content=None, finish_reason="error", usage={})
|
err = _transient_error()
|
||||||
ok = LLMResponse(content="cross-provider-ok", tool_calls=[], usage={})
|
ok = LLMResponse(content="cross-provider-ok", tool_calls=[], usage={})
|
||||||
|
|
||||||
primary = _make_provider(model_response=err)
|
primary = _make_provider(model_response=err)
|
||||||
@ -132,7 +137,7 @@ async def test_cross_provider_fallback():
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_fallback_skips_to_second_on_first_error():
|
async def test_fallback_skips_to_second_on_first_error():
|
||||||
"""First fallback also fails -> second fallback succeeds."""
|
"""First fallback also fails -> second fallback succeeds."""
|
||||||
err = LLMResponse(content=None, finish_reason="error", usage={})
|
err = _transient_error()
|
||||||
ok = LLMResponse(content="second-fb-ok", tool_calls=[], usage={})
|
ok = LLMResponse(content="second-fb-ok", tool_calls=[], usage={})
|
||||||
|
|
||||||
primary = _make_provider(model_response=err)
|
primary = _make_provider(model_response=err)
|
||||||
@ -158,7 +163,7 @@ async def test_fallback_reuses_same_provider_without_factory():
|
|||||||
async def chat_with_retry(*, messages, model, **kw):
|
async def chat_with_retry(*, messages, model, **kw):
|
||||||
call_count["n"] += 1
|
call_count["n"] += 1
|
||||||
if call_count["n"] == 1:
|
if call_count["n"] == 1:
|
||||||
return LLMResponse(content=None, finish_reason="error", usage={})
|
return _transient_error()
|
||||||
return LLMResponse(content=f"ok-via-{model}", tool_calls=[], usage={})
|
return LLMResponse(content=f"ok-via-{model}", tool_calls=[], usage={})
|
||||||
|
|
||||||
primary = MagicMock()
|
primary = MagicMock()
|
||||||
@ -173,7 +178,7 @@ async def test_fallback_reuses_same_provider_without_factory():
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_fallback_provider_cached():
|
async def test_fallback_provider_cached():
|
||||||
"""Provider factory is called once per unique provider, not per attempt."""
|
"""Provider factory is called once per unique provider, not per attempt."""
|
||||||
err = LLMResponse(content=None, finish_reason="error", usage={})
|
err = _transient_error()
|
||||||
ok = LLMResponse(content="cached-ok", tool_calls=[], usage={})
|
ok = LLMResponse(content="cached-ok", tool_calls=[], usage={})
|
||||||
|
|
||||||
primary = _make_provider(model_response=err)
|
primary = _make_provider(model_response=err)
|
||||||
@ -188,3 +193,77 @@ async def test_fallback_provider_cached():
|
|||||||
result = await runner.run(_base_spec(fallback_models=["same-provider-model-a", "same-provider-model-b"]))
|
result = await runner.run(_base_spec(fallback_models=["same-provider-model-a", "same-provider-model-b"]))
|
||||||
|
|
||||||
assert result.final_content == "cached-ok"
|
assert result.final_content == "cached-ok"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_non_transient_error_does_not_fallback():
    """Auth/config-style errors should surface instead of hiding bugs via fallback."""
    primary = _make_provider(model_response=LLMResponse(
        content="401 unauthorized",
        finish_reason="error",
        error_status_code=401,
    ))
    fallback = _make_provider(model_response=LLMResponse(content="fallback-ok"))
    factory = MagicMock(return_value=fallback)

    runner = AgentRunner(primary, provider_factory=factory)
    result = await runner.run(_base_spec(fallback_models=["fb-model"]))

    # The fallback provider must never be built for a 401 response.
    factory.assert_not_called()
    assert result.error is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_quota_error_does_not_fallback_by_default():
    """Quota/billing/payment 429s should not route by default."""
    primary = _make_provider(model_response=LLMResponse(
        content="insufficient quota",
        finish_reason="error",
        error_status_code=429,
        error_code="insufficient_quota",
    ))
    fallback = _make_provider(model_response=LLMResponse(content="fallback-ok"))
    factory = MagicMock(return_value=fallback)

    runner = AgentRunner(primary, provider_factory=factory)
    result = await runner.run(_base_spec(fallback_models=["fb-model"]))

    # The fallback provider must never be built for a quota error.
    factory.assert_not_called()
    assert result.error is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_streaming_fallback_discards_failed_primary_deltas():
    """Buffered streaming prevents primary partial output from leaking on fallback."""
    # Every delta that reaches the hook is recorded here.
    streamed: list[str] = []

    class StreamingHook(AgentHook):
        """Hook that opts in to streaming and records each delta."""

        def wants_streaming(self) -> bool:
            return True

        async def on_stream(self, context: AgentHookContext, delta: str) -> None:
            streamed.append(delta)

    async def primary_stream(*, on_content_delta, **kwargs):
        # Emit partial output, then fail with a 503 error response.
        await on_content_delta("bad partial")
        return _transient_error()

    async def fallback_stream(*, on_content_delta, **kwargs):
        await on_content_delta("good")
        await on_content_delta(" answer")
        return LLMResponse(content="good answer", tool_calls=[], usage={})

    primary = MagicMock()
    primary.chat_stream_with_retry = primary_stream
    fallback = MagicMock()
    fallback.chat_stream_with_retry = fallback_stream
    factory = MagicMock(return_value=fallback)

    runner = AgentRunner(primary, provider_factory=factory)
    result = await runner.run(_base_spec(
        fallback_models=["fb-model"],
        hook=StreamingHook(),
    ))

    assert result.final_content == "good answer"
    # Only the fallback's deltas may reach the hook; "bad partial" must not.
    assert streamed == ["good", " answer"]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user