mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
feat: add fallback_models support for automatic model failover
When the primary model fails (finish_reason="error" after exhausting provider-level retries), automatically try each model in the configured fallback_models list. Supports cross-provider fallback via a cached provider_factory that resolves the correct provider for each model string. Config: agents.defaults.fallback_models: ["model-b", "provider/model-c"] Changes: - AgentDefaults: add fallback_models field - AgentRunSpec: add fallback_models field - AgentRunner: add provider_factory, _call_provider, _resolve_fallback_provider - AgentLoop: accept and forward fallback_models + provider_factory - nanobot.py: extract _make_provider_for_model, add _make_provider_factory - cli/commands.py: add _make_cli_provider_factory, wire all AgentLoop sites - tests/agent/test_runner_fallback.py: 8 test cases covering primary success, single/multi fallback, cross-provider, no-factory reuse, caching Made-with: Cursor
This commit is contained in:
parent
83f437a088
commit
2e5930e355
@ -200,6 +200,8 @@ class AgentLoop:
|
||||
max_tool_result_chars: int | None = None,
|
||||
provider_retry_mode: str = "standard",
|
||||
tool_hint_max_length: int | None = None,
|
||||
fallback_models: list[str] | None = None,
|
||||
provider_factory: Any | None = None,
|
||||
web_config: WebToolsConfig | None = None,
|
||||
exec_config: ExecToolConfig | None = None,
|
||||
cron_service: CronService | None = None,
|
||||
@ -250,6 +252,7 @@ class AgentLoop:
|
||||
tool_hint_max_length if tool_hint_max_length is not None
|
||||
else defaults.tool_hint_max_length
|
||||
)
|
||||
self.fallback_models = fallback_models or []
|
||||
self.web_config = web_config or WebToolsConfig()
|
||||
self.exec_config = exec_config or ExecToolConfig()
|
||||
self.cron_service = cron_service
|
||||
@ -263,7 +266,7 @@ class AgentLoop:
|
||||
# One file-read/write tracker per logical session. The tool registry is
|
||||
# shared by this loop, so tools resolve the active state via contextvars.
|
||||
self._file_state_store = FileStateStore()
|
||||
self.runner = AgentRunner(provider)
|
||||
self.runner = AgentRunner(provider, provider_factory=provider_factory)
|
||||
self.subagents = SubagentManager(
|
||||
provider=provider,
|
||||
workspace=workspace,
|
||||
@ -681,6 +684,7 @@ class AgentLoop:
|
||||
context_window_tokens=self.context_window_tokens,
|
||||
context_block_limit=self.context_block_limit,
|
||||
provider_retry_mode=self.provider_retry_mode,
|
||||
fallback_models=self.fallback_models,
|
||||
progress_callback=on_progress,
|
||||
stream_progress_deltas=on_stream is not None,
|
||||
retry_wait_callback=on_retry_wait,
|
||||
|
||||
@ -75,6 +75,7 @@ class AgentRunSpec:
|
||||
context_window_tokens: int | None = None
|
||||
context_block_limit: int | None = None
|
||||
provider_retry_mode: str = "standard"
|
||||
fallback_models: list[str] = field(default_factory=list)
|
||||
progress_callback: Any | None = None
|
||||
stream_progress_deltas: bool = True
|
||||
retry_wait_callback: Any | None = None
|
||||
@ -97,11 +98,21 @@ class AgentRunResult:
|
||||
had_injections: bool = False
|
||||
|
||||
|
||||
ProviderFactory = Any # Callable[[str], LLMProvider] — avoids circular import
|
||||
|
||||
|
||||
class AgentRunner:
|
||||
"""Run a tool-capable LLM loop without product-layer concerns."""
|
||||
|
||||
def __init__(self, provider: LLMProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider: LLMProvider,
|
||||
*,
|
||||
provider_factory: ProviderFactory | None = None,
|
||||
):
|
||||
self.provider = provider
|
||||
self._provider_factory = provider_factory
|
||||
self._fallback_providers: dict[str, LLMProvider] = {}
|
||||
|
||||
@staticmethod
|
||||
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
|
||||
@ -594,12 +605,9 @@ class AgentRunner:
|
||||
messages: list[dict[str, Any]],
|
||||
hook: AgentHook,
|
||||
context: AgentHookContext,
|
||||
):
|
||||
) -> LLMResponse:
|
||||
timeout_s: float | None = spec.llm_timeout_s
|
||||
if timeout_s is None:
|
||||
# Default to a finite timeout to avoid per-session lock starvation when an LLM
|
||||
# request hangs indefinitely (e.g. gateway/network stall).
|
||||
# Set NANOBOT_LLM_TIMEOUT_S=0 to disable.
|
||||
raw = os.environ.get("NANOBOT_LLM_TIMEOUT_S", "300").strip()
|
||||
try:
|
||||
timeout_s = float(raw)
|
||||
@ -613,12 +621,40 @@ class AgentRunner:
|
||||
messages,
|
||||
tools=spec.tools.get_definitions(),
|
||||
)
|
||||
response = await self._call_provider(self.provider, kwargs, hook, context, spec, timeout_s)
|
||||
|
||||
if response.finish_reason == "error" and spec.fallback_models:
|
||||
for fb_model in spec.fallback_models:
|
||||
logger.warning(
|
||||
"Primary model {} failed, trying fallback: {}",
|
||||
spec.model,
|
||||
fb_model,
|
||||
)
|
||||
fb_provider, resolved_model = self._resolve_fallback_provider(fb_model)
|
||||
fb_kwargs = dict(kwargs, model=resolved_model)
|
||||
response = await self._call_provider(
|
||||
fb_provider, fb_kwargs, hook, context, spec, timeout_s,
|
||||
)
|
||||
if response.finish_reason != "error":
|
||||
break
|
||||
|
||||
return response
|
||||
|
||||
async def _call_provider(
|
||||
self,
|
||||
provider: LLMProvider,
|
||||
kwargs: dict[str, Any],
|
||||
hook: AgentHook,
|
||||
context: AgentHookContext,
|
||||
spec: AgentRunSpec,
|
||||
timeout_s: float | None = None,
|
||||
) -> LLMResponse:
|
||||
wants_streaming = hook.wants_streaming()
|
||||
wants_progress_streaming = (
|
||||
not wants_streaming
|
||||
and spec.stream_progress_deltas
|
||||
and spec.progress_callback is not None
|
||||
and getattr(self.provider, "supports_progress_deltas", False) is True
|
||||
and getattr(provider, "supports_progress_deltas", False) is True
|
||||
)
|
||||
|
||||
if wants_streaming:
|
||||
@ -627,7 +663,7 @@ class AgentRunner:
|
||||
context.streamed_content = True
|
||||
await hook.on_stream(context, delta)
|
||||
|
||||
coro = self.provider.chat_stream_with_retry(
|
||||
coro = provider.chat_stream_with_retry(
|
||||
**kwargs,
|
||||
on_content_delta=_stream,
|
||||
)
|
||||
@ -646,12 +682,12 @@ class AgentRunner:
|
||||
context.streamed_content = True
|
||||
await spec.progress_callback(incremental)
|
||||
|
||||
coro = self.provider.chat_stream_with_retry(
|
||||
coro = provider.chat_stream_with_retry(
|
||||
**kwargs,
|
||||
on_content_delta=_stream_progress,
|
||||
)
|
||||
else:
|
||||
coro = self.provider.chat_with_retry(**kwargs)
|
||||
coro = provider.chat_with_retry(**kwargs)
|
||||
|
||||
if timeout_s is None:
|
||||
return await coro
|
||||
@ -664,6 +700,22 @@ class AgentRunner:
|
||||
error_kind="timeout",
|
||||
)
|
||||
|
||||
def _resolve_fallback_provider(self, model: str) -> tuple[LLMProvider, str]:
|
||||
"""Return (provider, actual_model_name) for a fallback model.
|
||||
|
||||
When a provider_factory is available (and the model string may be a
|
||||
preset name), the factory resolves the actual model; otherwise the
|
||||
primary provider is reused with the raw model string.
|
||||
"""
|
||||
if model in self._fallback_providers:
|
||||
p = self._fallback_providers[model]
|
||||
return p, p.get_default_model()
|
||||
if self._provider_factory:
|
||||
provider = self._provider_factory(model)
|
||||
self._fallback_providers[model] = provider
|
||||
return provider, provider.get_default_model()
|
||||
return self.provider, model
|
||||
|
||||
async def _request_finalization_retry(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
|
||||
@ -513,6 +513,29 @@ def _make_provider(config: Config):
|
||||
return provider
|
||||
|
||||
|
||||
def _make_cli_provider_factory(config: Config):
|
||||
"""Build a cached factory for fallback model providers (CLI side).
|
||||
|
||||
Supports preset names: if a model string matches a preset, the preset's
|
||||
full config is used for provider creation.
|
||||
"""
|
||||
from nanobot.nanobot import _make_provider_for_model
|
||||
|
||||
cache: dict[str, Any] = {}
|
||||
presets = getattr(config, "model_presets", {}) or {}
|
||||
|
||||
def factory(model_or_preset: str):
|
||||
preset = presets.get(model_or_preset)
|
||||
actual_model = preset.model if preset else model_or_preset
|
||||
provider_name = config.get_provider_name(actual_model)
|
||||
key = provider_name or actual_model
|
||||
if key not in cache:
|
||||
cache[key] = _make_provider_for_model(config, actual_model, preset=preset)
|
||||
return cache[key]
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
|
||||
"""Load config and optionally override the active workspace."""
|
||||
from nanobot.config.loader import load_config, resolve_config_env_vars, set_config_path
|
||||
@ -608,6 +631,8 @@ def serve(
|
||||
sync_workspace_templates(runtime_config.workspace_path)
|
||||
bus = MessageBus()
|
||||
provider = _make_provider(runtime_config)
|
||||
defaults = runtime_config.agents.defaults
|
||||
pf = _make_cli_provider_factory(runtime_config) if defaults.fallback_models else None
|
||||
session_manager = SessionManager(runtime_config.workspace_path)
|
||||
_resolved = runtime_config.resolve_preset()
|
||||
agent_loop = AgentLoop(
|
||||
@ -615,12 +640,13 @@ def serve(
|
||||
provider=provider,
|
||||
workspace=runtime_config.workspace_path,
|
||||
model=_resolved.model,
|
||||
max_iterations=runtime_config.agents.defaults.max_tool_iterations,
|
||||
max_iterations=defaults.max_tool_iterations,
|
||||
context_window_tokens=_resolved.context_window_tokens,
|
||||
context_block_limit=runtime_config.agents.defaults.context_block_limit,
|
||||
max_tool_result_chars=runtime_config.agents.defaults.max_tool_result_chars,
|
||||
provider_retry_mode=runtime_config.agents.defaults.provider_retry_mode,
|
||||
tool_hint_max_length=runtime_config.agents.defaults.tool_hint_max_length,
|
||||
context_block_limit=defaults.context_block_limit,
|
||||
max_tool_result_chars=defaults.max_tool_result_chars,
|
||||
provider_retry_mode=defaults.provider_retry_mode,
|
||||
fallback_models=defaults.fallback_models,
|
||||
provider_factory=pf,
|
||||
web_config=runtime_config.tools.web,
|
||||
exec_config=runtime_config.tools.exec,
|
||||
restrict_to_workspace=runtime_config.tools.restrict_to_workspace,
|
||||
@ -639,7 +665,7 @@ def serve(
|
||||
)
|
||||
|
||||
model_name = _resolved.model
|
||||
preset_name = runtime_config.agents.defaults.model_preset
|
||||
preset_name = defaults.model_preset
|
||||
preset_tag = f" (preset: {preset_name})" if preset_name else ""
|
||||
console.print(f"{__logo__} Starting OpenAI-compatible API server")
|
||||
console.print(f" [cyan]Endpoint[/cyan] : http://{host}:{port}/v1/chat/completions")
|
||||
@ -721,12 +747,14 @@ def _run_gateway(
|
||||
console.print(f"{__logo__} Starting nanobot gateway version {__version__} on port {port}...")
|
||||
sync_workspace_templates(config.workspace_path)
|
||||
bus = MessageBus()
|
||||
provider = _make_provider(config)
|
||||
gw_defaults = config.agents.defaults
|
||||
gw_pf = _make_cli_provider_factory(config) if gw_defaults.fallback_models else None
|
||||
try:
|
||||
provider_snapshot = build_provider_snapshot(config)
|
||||
except ValueError as exc:
|
||||
console.print(f"[red]Error: {exc}[/red]")
|
||||
raise typer.Exit(1) from exc
|
||||
provider = provider_snapshot.provider
|
||||
session_manager = SessionManager(config.workspace_path)
|
||||
|
||||
# Preserve existing single-workspace installs, but keep custom workspaces clean.
|
||||
@ -744,13 +772,14 @@ def _run_gateway(
|
||||
provider=provider,
|
||||
workspace=config.workspace_path,
|
||||
model=_resolved.model,
|
||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||
max_iterations=gw_defaults.max_tool_iterations,
|
||||
context_window_tokens=_resolved.context_window_tokens,
|
||||
web_config=config.tools.web,
|
||||
context_block_limit=config.agents.defaults.context_block_limit,
|
||||
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
|
||||
provider_retry_mode=config.agents.defaults.provider_retry_mode,
|
||||
tool_hint_max_length=config.agents.defaults.tool_hint_max_length,
|
||||
context_block_limit=gw_defaults.context_block_limit,
|
||||
max_tool_result_chars=gw_defaults.max_tool_result_chars,
|
||||
provider_retry_mode=gw_defaults.provider_retry_mode,
|
||||
fallback_models=gw_defaults.fallback_models,
|
||||
provider_factory=gw_pf,
|
||||
exec_config=config.tools.exec,
|
||||
cron_service=cron,
|
||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||
@ -1120,6 +1149,8 @@ def agent(
|
||||
|
||||
bus = MessageBus()
|
||||
provider = _make_provider(config)
|
||||
chat_defaults = config.agents.defaults
|
||||
chat_pf = _make_cli_provider_factory(config) if chat_defaults.fallback_models else None
|
||||
|
||||
# Preserve existing single-workspace installs, but keep custom workspaces clean.
|
||||
if is_default_workspace(config.workspace_path):
|
||||
@ -1140,13 +1171,14 @@ def agent(
|
||||
provider=provider,
|
||||
workspace=config.workspace_path,
|
||||
model=_resolved.model,
|
||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||
max_iterations=chat_defaults.max_tool_iterations,
|
||||
context_window_tokens=_resolved.context_window_tokens,
|
||||
web_config=config.tools.web,
|
||||
context_block_limit=config.agents.defaults.context_block_limit,
|
||||
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
|
||||
provider_retry_mode=config.agents.defaults.provider_retry_mode,
|
||||
tool_hint_max_length=config.agents.defaults.tool_hint_max_length,
|
||||
context_block_limit=chat_defaults.context_block_limit,
|
||||
max_tool_result_chars=chat_defaults.max_tool_result_chars,
|
||||
provider_retry_mode=chat_defaults.provider_retry_mode,
|
||||
fallback_models=chat_defaults.fallback_models,
|
||||
provider_factory=chat_pf,
|
||||
exec_config=config.tools.exec,
|
||||
cron_service=cron,
|
||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||
|
||||
@ -105,6 +105,7 @@ class AgentDefaults(Base):
|
||||
serialization_alias="toolHintMaxLength",
|
||||
) # Max characters for tool hint display (e.g. "$ cd …/project && npm test")
|
||||
reasoning_effort: str | None = None # low / medium / high / adaptive - enables LLM thinking mode
|
||||
fallback_models: list[str] = Field(default_factory=list)
|
||||
timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"
|
||||
unified_session: bool = False # Share one session across all channels (single-user multi-device)
|
||||
disabled_skills: list[str] = Field(default_factory=list) # Skill names to exclude from loading (e.g. ["summarize", "skill-creator"])
|
||||
|
||||
@ -66,6 +66,7 @@ class Nanobot:
|
||||
bus = MessageBus()
|
||||
defaults = config.agents.defaults
|
||||
_resolved = config.resolve_preset()
|
||||
pf = _make_provider_factory(config) if defaults.fallback_models else None
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=bus,
|
||||
@ -78,6 +79,8 @@ class Nanobot:
|
||||
max_tool_result_chars=defaults.max_tool_result_chars,
|
||||
provider_retry_mode=defaults.provider_retry_mode,
|
||||
tool_hint_max_length=defaults.tool_hint_max_length,
|
||||
fallback_models=defaults.fallback_models,
|
||||
provider_factory=pf,
|
||||
web_config=config.tools.web,
|
||||
exec_config=config.tools.exec,
|
||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||
@ -127,14 +130,22 @@ class Nanobot:
|
||||
)
|
||||
|
||||
|
||||
def _make_provider(config: Any) -> Any:
|
||||
"""Create the LLM provider from config (extracted from CLI)."""
|
||||
def _make_provider_for_model(
|
||||
config: Any,
|
||||
model: str,
|
||||
*,
|
||||
preset: Any | None = None,
|
||||
) -> Any:
|
||||
"""Create an LLM provider instance for a specific model string.
|
||||
|
||||
When *preset* is given, its generation settings (temperature, max_tokens,
|
||||
reasoning_effort) override the active preset defaults.
|
||||
"""
|
||||
from nanobot.providers.base import GenerationSettings
|
||||
from nanobot.providers.factory import make_provider
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
resolved = config.resolve_preset()
|
||||
model = resolved.model
|
||||
gen_src = preset or config.resolve_preset()
|
||||
provider_name = config.get_provider_name(model)
|
||||
p = config.get_provider(model)
|
||||
spec = find_by_name(provider_name) if provider_name else None
|
||||
@ -185,8 +196,34 @@ def _make_provider(config: Any) -> Any:
|
||||
)
|
||||
|
||||
provider.generation = GenerationSettings(
|
||||
temperature=resolved.temperature,
|
||||
max_tokens=resolved.max_tokens,
|
||||
reasoning_effort=resolved.reasoning_effort,
|
||||
temperature=gen_src.temperature,
|
||||
max_tokens=gen_src.max_tokens,
|
||||
reasoning_effort=gen_src.reasoning_effort,
|
||||
)
|
||||
return provider
|
||||
|
||||
|
||||
def _make_provider(config: Any) -> Any:
|
||||
"""Create the LLM provider for the primary model from config."""
|
||||
return _make_provider_for_model(config, config.resolve_preset().model)
|
||||
|
||||
|
||||
def _make_provider_factory(config: Any):
|
||||
"""Build a cached factory that creates providers for arbitrary model strings.
|
||||
|
||||
If a model string matches a preset name in ``config.model_presets``, the
|
||||
preset's full config (model, temperature, max_tokens, …) is used.
|
||||
"""
|
||||
cache: dict[str, Any] = {}
|
||||
presets = getattr(config, "model_presets", {}) or {}
|
||||
|
||||
def factory(model_or_preset: str):
|
||||
preset = presets.get(model_or_preset)
|
||||
actual_model = preset.model if preset else model_or_preset
|
||||
provider_name = config.get_provider_name(actual_model)
|
||||
key = provider_name or actual_model
|
||||
if key not in cache:
|
||||
cache[key] = _make_provider_for_model(config, actual_model, preset=preset)
|
||||
return cache[key]
|
||||
|
||||
return factory
|
||||
|
||||
190
tests/agent/test_runner_fallback.py
Normal file
190
tests/agent/test_runner_fallback.py
Normal file
@ -0,0 +1,190 @@
|
||||
"""Tests for the provider fallback models feature in AgentRunner."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
|
||||
def _make_tools():
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
tools.execute = AsyncMock(return_value="ok")
|
||||
return tools
|
||||
|
||||
|
||||
def _make_provider(*, model_response: LLMResponse | None = None):
|
||||
p = MagicMock()
|
||||
if model_response is not None:
|
||||
p.chat_with_retry = AsyncMock(return_value=model_response)
|
||||
return p
|
||||
|
||||
|
||||
def _base_spec(**overrides) -> AgentRunSpec:
|
||||
defaults = dict(
|
||||
initial_messages=[
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": "hello"},
|
||||
],
|
||||
tools=_make_tools(),
|
||||
model="primary-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=8000,
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return AgentRunSpec(**defaults)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_fallback_when_primary_succeeds():
|
||||
"""Primary succeeds -> fallback list never consulted."""
|
||||
ok = LLMResponse(content="done", tool_calls=[], usage={})
|
||||
provider = _make_provider(model_response=ok)
|
||||
factory = MagicMock()
|
||||
|
||||
runner = AgentRunner(provider, provider_factory=factory)
|
||||
result = await runner.run(_base_spec(fallback_models=["fb-1", "fb-2"]))
|
||||
|
||||
assert result.final_content == "done"
|
||||
factory.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_triggered_on_primary_error():
|
||||
"""Primary fails -> first fallback succeeds."""
|
||||
err = LLMResponse(content=None, finish_reason="error", usage={})
|
||||
ok = LLMResponse(content="fallback-ok", tool_calls=[], usage={})
|
||||
|
||||
primary = _make_provider(model_response=err)
|
||||
|
||||
fb_provider = MagicMock()
|
||||
fb_provider.chat_with_retry = AsyncMock(return_value=ok)
|
||||
factory = MagicMock(return_value=fb_provider)
|
||||
|
||||
runner = AgentRunner(primary, provider_factory=factory)
|
||||
result = await runner.run(_base_spec(fallback_models=["fb-model"]))
|
||||
|
||||
assert result.final_content == "fallback-ok"
|
||||
factory.assert_called_once_with("fb-model")
|
||||
fb_provider.chat_with_retry.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_fallbacks_fail_returns_last_error():
|
||||
"""Primary + all fallbacks fail -> return last error response."""
|
||||
err = LLMResponse(content=None, finish_reason="error", usage={})
|
||||
|
||||
primary = _make_provider(model_response=err)
|
||||
fb1 = _make_provider(model_response=err)
|
||||
fb2 = _make_provider(model_response=LLMResponse(
|
||||
content="last-error", finish_reason="error", usage={},
|
||||
))
|
||||
|
||||
providers = {"fb-1": fb1, "fb-2": fb2}
|
||||
factory = MagicMock(side_effect=lambda m: providers[m])
|
||||
|
||||
runner = AgentRunner(primary, provider_factory=factory)
|
||||
result = await runner.run(_base_spec(fallback_models=["fb-1", "fb-2"]))
|
||||
|
||||
assert result.error is not None or result.final_content is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_fallback_list_no_retry():
|
||||
"""Empty fallback_models -> no fallback attempted."""
|
||||
err = LLMResponse(content=None, finish_reason="error", usage={})
|
||||
primary = _make_provider(model_response=err)
|
||||
factory = MagicMock()
|
||||
|
||||
runner = AgentRunner(primary, provider_factory=factory)
|
||||
result = await runner.run(_base_spec(fallback_models=[]))
|
||||
|
||||
factory.assert_not_called()
|
||||
assert result.error is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cross_provider_fallback():
|
||||
"""Fallback uses a different provider instance (cross-provider)."""
|
||||
err = LLMResponse(content=None, finish_reason="error", usage={})
|
||||
ok = LLMResponse(content="cross-provider-ok", tool_calls=[], usage={})
|
||||
|
||||
primary = _make_provider(model_response=err)
|
||||
anthropic_provider = MagicMock()
|
||||
anthropic_provider.chat_with_retry = AsyncMock(return_value=ok)
|
||||
|
||||
def cross_factory(model: str):
|
||||
if model == "anthropic/claude-sonnet":
|
||||
return anthropic_provider
|
||||
raise ValueError(f"unexpected model: {model}")
|
||||
|
||||
runner = AgentRunner(primary, provider_factory=cross_factory)
|
||||
result = await runner.run(_base_spec(fallback_models=["anthropic/claude-sonnet"]))
|
||||
|
||||
assert result.final_content == "cross-provider-ok"
|
||||
anthropic_provider.chat_with_retry.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_skips_to_second_on_first_error():
|
||||
"""First fallback also fails -> second fallback succeeds."""
|
||||
err = LLMResponse(content=None, finish_reason="error", usage={})
|
||||
ok = LLMResponse(content="second-fb-ok", tool_calls=[], usage={})
|
||||
|
||||
primary = _make_provider(model_response=err)
|
||||
fb1 = _make_provider(model_response=err)
|
||||
fb2 = MagicMock()
|
||||
fb2.chat_with_retry = AsyncMock(return_value=ok)
|
||||
|
||||
providers = {"fb-1": fb1, "fb-2": fb2}
|
||||
factory = MagicMock(side_effect=lambda m: providers[m])
|
||||
|
||||
runner = AgentRunner(primary, provider_factory=factory)
|
||||
result = await runner.run(_base_spec(fallback_models=["fb-1", "fb-2"]))
|
||||
|
||||
assert result.final_content == "second-fb-ok"
|
||||
assert factory.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_reuses_same_provider_without_factory():
|
||||
"""No provider_factory -> fallback reuses primary provider with different model."""
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def chat_with_retry(*, messages, model, **kw):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
return LLMResponse(content=None, finish_reason="error", usage={})
|
||||
return LLMResponse(content=f"ok-via-{model}", tool_calls=[], usage={})
|
||||
|
||||
primary = MagicMock()
|
||||
primary.chat_with_retry = chat_with_retry
|
||||
|
||||
runner = AgentRunner(primary, provider_factory=None)
|
||||
result = await runner.run(_base_spec(fallback_models=["fallback-model"]))
|
||||
|
||||
assert result.final_content == "ok-via-fallback-model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_provider_cached():
|
||||
"""Provider factory is called once per unique provider, not per attempt."""
|
||||
err = LLMResponse(content=None, finish_reason="error", usage={})
|
||||
ok = LLMResponse(content="cached-ok", tool_calls=[], usage={})
|
||||
|
||||
primary = _make_provider(model_response=err)
|
||||
|
||||
fb_provider = MagicMock()
|
||||
call_seq = [err, ok]
|
||||
fb_provider.chat_with_retry = AsyncMock(side_effect=call_seq)
|
||||
|
||||
factory = MagicMock(return_value=fb_provider)
|
||||
|
||||
runner = AgentRunner(primary, provider_factory=factory)
|
||||
result = await runner.run(_base_spec(fallback_models=["same-provider-model-a", "same-provider-model-b"]))
|
||||
|
||||
assert result.final_content == "cached-ok"
|
||||
Loading…
x
Reference in New Issue
Block a user