mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
feat: add fallback_models support for automatic model failover
When the primary model fails (finish_reason="error" after exhausting provider-level retries), automatically try each model in the configured fallback_models list. Supports cross-provider fallback via a cached provider_factory that resolves the correct provider for each model string.

Config:
  agents.defaults.fallback_models: ["model-b", "provider/model-c"]

Changes:
- AgentDefaults: add fallback_models field
- AgentRunSpec: add fallback_models field
- AgentRunner: add provider_factory, _call_provider, _resolve_fallback_provider
- AgentLoop: accept and forward fallback_models + provider_factory
- nanobot.py: extract _make_provider_for_model, add _make_provider_factory
- cli/commands.py: add _make_cli_provider_factory, wire all AgentLoop sites
- tests/agent/test_runner_fallback.py: 8 test cases covering primary success, single/multi fallback, cross-provider, no-factory reuse, caching

Made-with: Cursor
This commit is contained in:
parent
83f437a088
commit
2e5930e355
@ -200,6 +200,8 @@ class AgentLoop:
|
|||||||
max_tool_result_chars: int | None = None,
|
max_tool_result_chars: int | None = None,
|
||||||
provider_retry_mode: str = "standard",
|
provider_retry_mode: str = "standard",
|
||||||
tool_hint_max_length: int | None = None,
|
tool_hint_max_length: int | None = None,
|
||||||
|
fallback_models: list[str] | None = None,
|
||||||
|
provider_factory: Any | None = None,
|
||||||
web_config: WebToolsConfig | None = None,
|
web_config: WebToolsConfig | None = None,
|
||||||
exec_config: ExecToolConfig | None = None,
|
exec_config: ExecToolConfig | None = None,
|
||||||
cron_service: CronService | None = None,
|
cron_service: CronService | None = None,
|
||||||
@ -250,6 +252,7 @@ class AgentLoop:
|
|||||||
tool_hint_max_length if tool_hint_max_length is not None
|
tool_hint_max_length if tool_hint_max_length is not None
|
||||||
else defaults.tool_hint_max_length
|
else defaults.tool_hint_max_length
|
||||||
)
|
)
|
||||||
|
self.fallback_models = fallback_models or []
|
||||||
self.web_config = web_config or WebToolsConfig()
|
self.web_config = web_config or WebToolsConfig()
|
||||||
self.exec_config = exec_config or ExecToolConfig()
|
self.exec_config = exec_config or ExecToolConfig()
|
||||||
self.cron_service = cron_service
|
self.cron_service = cron_service
|
||||||
@ -263,7 +266,7 @@ class AgentLoop:
|
|||||||
# One file-read/write tracker per logical session. The tool registry is
|
# One file-read/write tracker per logical session. The tool registry is
|
||||||
# shared by this loop, so tools resolve the active state via contextvars.
|
# shared by this loop, so tools resolve the active state via contextvars.
|
||||||
self._file_state_store = FileStateStore()
|
self._file_state_store = FileStateStore()
|
||||||
self.runner = AgentRunner(provider)
|
self.runner = AgentRunner(provider, provider_factory=provider_factory)
|
||||||
self.subagents = SubagentManager(
|
self.subagents = SubagentManager(
|
||||||
provider=provider,
|
provider=provider,
|
||||||
workspace=workspace,
|
workspace=workspace,
|
||||||
@ -681,6 +684,7 @@ class AgentLoop:
|
|||||||
context_window_tokens=self.context_window_tokens,
|
context_window_tokens=self.context_window_tokens,
|
||||||
context_block_limit=self.context_block_limit,
|
context_block_limit=self.context_block_limit,
|
||||||
provider_retry_mode=self.provider_retry_mode,
|
provider_retry_mode=self.provider_retry_mode,
|
||||||
|
fallback_models=self.fallback_models,
|
||||||
progress_callback=on_progress,
|
progress_callback=on_progress,
|
||||||
stream_progress_deltas=on_stream is not None,
|
stream_progress_deltas=on_stream is not None,
|
||||||
retry_wait_callback=on_retry_wait,
|
retry_wait_callback=on_retry_wait,
|
||||||
|
|||||||
@ -75,6 +75,7 @@ class AgentRunSpec:
|
|||||||
context_window_tokens: int | None = None
|
context_window_tokens: int | None = None
|
||||||
context_block_limit: int | None = None
|
context_block_limit: int | None = None
|
||||||
provider_retry_mode: str = "standard"
|
provider_retry_mode: str = "standard"
|
||||||
|
fallback_models: list[str] = field(default_factory=list)
|
||||||
progress_callback: Any | None = None
|
progress_callback: Any | None = None
|
||||||
stream_progress_deltas: bool = True
|
stream_progress_deltas: bool = True
|
||||||
retry_wait_callback: Any | None = None
|
retry_wait_callback: Any | None = None
|
||||||
@ -97,11 +98,21 @@ class AgentRunResult:
|
|||||||
had_injections: bool = False
|
had_injections: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
ProviderFactory = Any # Callable[[str], LLMProvider] — avoids circular import
|
||||||
|
|
||||||
|
|
||||||
class AgentRunner:
|
class AgentRunner:
|
||||||
"""Run a tool-capable LLM loop without product-layer concerns."""
|
"""Run a tool-capable LLM loop without product-layer concerns."""
|
||||||
|
|
||||||
def __init__(self, provider: LLMProvider):
|
def __init__(
|
||||||
|
self,
|
||||||
|
provider: LLMProvider,
|
||||||
|
*,
|
||||||
|
provider_factory: ProviderFactory | None = None,
|
||||||
|
):
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
|
self._provider_factory = provider_factory
|
||||||
|
self._fallback_providers: dict[str, LLMProvider] = {}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
|
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
|
||||||
@ -594,12 +605,9 @@ class AgentRunner:
|
|||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
hook: AgentHook,
|
hook: AgentHook,
|
||||||
context: AgentHookContext,
|
context: AgentHookContext,
|
||||||
):
|
) -> LLMResponse:
|
||||||
timeout_s: float | None = spec.llm_timeout_s
|
timeout_s: float | None = spec.llm_timeout_s
|
||||||
if timeout_s is None:
|
if timeout_s is None:
|
||||||
# Default to a finite timeout to avoid per-session lock starvation when an LLM
|
|
||||||
# request hangs indefinitely (e.g. gateway/network stall).
|
|
||||||
# Set NANOBOT_LLM_TIMEOUT_S=0 to disable.
|
|
||||||
raw = os.environ.get("NANOBOT_LLM_TIMEOUT_S", "300").strip()
|
raw = os.environ.get("NANOBOT_LLM_TIMEOUT_S", "300").strip()
|
||||||
try:
|
try:
|
||||||
timeout_s = float(raw)
|
timeout_s = float(raw)
|
||||||
@ -613,12 +621,40 @@ class AgentRunner:
|
|||||||
messages,
|
messages,
|
||||||
tools=spec.tools.get_definitions(),
|
tools=spec.tools.get_definitions(),
|
||||||
)
|
)
|
||||||
|
response = await self._call_provider(self.provider, kwargs, hook, context, spec, timeout_s)
|
||||||
|
|
||||||
|
if response.finish_reason == "error" and spec.fallback_models:
|
||||||
|
for fb_model in spec.fallback_models:
|
||||||
|
logger.warning(
|
||||||
|
"Primary model {} failed, trying fallback: {}",
|
||||||
|
spec.model,
|
||||||
|
fb_model,
|
||||||
|
)
|
||||||
|
fb_provider, resolved_model = self._resolve_fallback_provider(fb_model)
|
||||||
|
fb_kwargs = dict(kwargs, model=resolved_model)
|
||||||
|
response = await self._call_provider(
|
||||||
|
fb_provider, fb_kwargs, hook, context, spec, timeout_s,
|
||||||
|
)
|
||||||
|
if response.finish_reason != "error":
|
||||||
|
break
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def _call_provider(
|
||||||
|
self,
|
||||||
|
provider: LLMProvider,
|
||||||
|
kwargs: dict[str, Any],
|
||||||
|
hook: AgentHook,
|
||||||
|
context: AgentHookContext,
|
||||||
|
spec: AgentRunSpec,
|
||||||
|
timeout_s: float | None = None,
|
||||||
|
) -> LLMResponse:
|
||||||
wants_streaming = hook.wants_streaming()
|
wants_streaming = hook.wants_streaming()
|
||||||
wants_progress_streaming = (
|
wants_progress_streaming = (
|
||||||
not wants_streaming
|
not wants_streaming
|
||||||
and spec.stream_progress_deltas
|
and spec.stream_progress_deltas
|
||||||
and spec.progress_callback is not None
|
and spec.progress_callback is not None
|
||||||
and getattr(self.provider, "supports_progress_deltas", False) is True
|
and getattr(provider, "supports_progress_deltas", False) is True
|
||||||
)
|
)
|
||||||
|
|
||||||
if wants_streaming:
|
if wants_streaming:
|
||||||
@ -627,7 +663,7 @@ class AgentRunner:
|
|||||||
context.streamed_content = True
|
context.streamed_content = True
|
||||||
await hook.on_stream(context, delta)
|
await hook.on_stream(context, delta)
|
||||||
|
|
||||||
coro = self.provider.chat_stream_with_retry(
|
coro = provider.chat_stream_with_retry(
|
||||||
**kwargs,
|
**kwargs,
|
||||||
on_content_delta=_stream,
|
on_content_delta=_stream,
|
||||||
)
|
)
|
||||||
@ -646,12 +682,12 @@ class AgentRunner:
|
|||||||
context.streamed_content = True
|
context.streamed_content = True
|
||||||
await spec.progress_callback(incremental)
|
await spec.progress_callback(incremental)
|
||||||
|
|
||||||
coro = self.provider.chat_stream_with_retry(
|
coro = provider.chat_stream_with_retry(
|
||||||
**kwargs,
|
**kwargs,
|
||||||
on_content_delta=_stream_progress,
|
on_content_delta=_stream_progress,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
coro = self.provider.chat_with_retry(**kwargs)
|
coro = provider.chat_with_retry(**kwargs)
|
||||||
|
|
||||||
if timeout_s is None:
|
if timeout_s is None:
|
||||||
return await coro
|
return await coro
|
||||||
@ -664,6 +700,22 @@ class AgentRunner:
|
|||||||
error_kind="timeout",
|
error_kind="timeout",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _resolve_fallback_provider(self, model: str) -> tuple[LLMProvider, str]:
|
||||||
|
"""Return (provider, actual_model_name) for a fallback model.
|
||||||
|
|
||||||
|
When a provider_factory is available (and the model string may be a
|
||||||
|
preset name), the factory resolves the actual model; otherwise the
|
||||||
|
primary provider is reused with the raw model string.
|
||||||
|
"""
|
||||||
|
if model in self._fallback_providers:
|
||||||
|
p = self._fallback_providers[model]
|
||||||
|
return p, p.get_default_model()
|
||||||
|
if self._provider_factory:
|
||||||
|
provider = self._provider_factory(model)
|
||||||
|
self._fallback_providers[model] = provider
|
||||||
|
return provider, provider.get_default_model()
|
||||||
|
return self.provider, model
|
||||||
|
|
||||||
async def _request_finalization_retry(
|
async def _request_finalization_retry(
|
||||||
self,
|
self,
|
||||||
spec: AgentRunSpec,
|
spec: AgentRunSpec,
|
||||||
|
|||||||
@ -513,6 +513,29 @@ def _make_provider(config: Config):
|
|||||||
return provider
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
def _make_cli_provider_factory(config: Config):
|
||||||
|
"""Build a cached factory for fallback model providers (CLI side).
|
||||||
|
|
||||||
|
Supports preset names: if a model string matches a preset, the preset's
|
||||||
|
full config is used for provider creation.
|
||||||
|
"""
|
||||||
|
from nanobot.nanobot import _make_provider_for_model
|
||||||
|
|
||||||
|
cache: dict[str, Any] = {}
|
||||||
|
presets = getattr(config, "model_presets", {}) or {}
|
||||||
|
|
||||||
|
def factory(model_or_preset: str):
|
||||||
|
preset = presets.get(model_or_preset)
|
||||||
|
actual_model = preset.model if preset else model_or_preset
|
||||||
|
provider_name = config.get_provider_name(actual_model)
|
||||||
|
key = provider_name or actual_model
|
||||||
|
if key not in cache:
|
||||||
|
cache[key] = _make_provider_for_model(config, actual_model, preset=preset)
|
||||||
|
return cache[key]
|
||||||
|
|
||||||
|
return factory
|
||||||
|
|
||||||
|
|
||||||
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
|
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
|
||||||
"""Load config and optionally override the active workspace."""
|
"""Load config and optionally override the active workspace."""
|
||||||
from nanobot.config.loader import load_config, resolve_config_env_vars, set_config_path
|
from nanobot.config.loader import load_config, resolve_config_env_vars, set_config_path
|
||||||
@ -608,6 +631,8 @@ def serve(
|
|||||||
sync_workspace_templates(runtime_config.workspace_path)
|
sync_workspace_templates(runtime_config.workspace_path)
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
provider = _make_provider(runtime_config)
|
provider = _make_provider(runtime_config)
|
||||||
|
defaults = runtime_config.agents.defaults
|
||||||
|
pf = _make_cli_provider_factory(runtime_config) if defaults.fallback_models else None
|
||||||
session_manager = SessionManager(runtime_config.workspace_path)
|
session_manager = SessionManager(runtime_config.workspace_path)
|
||||||
_resolved = runtime_config.resolve_preset()
|
_resolved = runtime_config.resolve_preset()
|
||||||
agent_loop = AgentLoop(
|
agent_loop = AgentLoop(
|
||||||
@ -615,12 +640,13 @@ def serve(
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
workspace=runtime_config.workspace_path,
|
workspace=runtime_config.workspace_path,
|
||||||
model=_resolved.model,
|
model=_resolved.model,
|
||||||
max_iterations=runtime_config.agents.defaults.max_tool_iterations,
|
max_iterations=defaults.max_tool_iterations,
|
||||||
context_window_tokens=_resolved.context_window_tokens,
|
context_window_tokens=_resolved.context_window_tokens,
|
||||||
context_block_limit=runtime_config.agents.defaults.context_block_limit,
|
context_block_limit=defaults.context_block_limit,
|
||||||
max_tool_result_chars=runtime_config.agents.defaults.max_tool_result_chars,
|
max_tool_result_chars=defaults.max_tool_result_chars,
|
||||||
provider_retry_mode=runtime_config.agents.defaults.provider_retry_mode,
|
provider_retry_mode=defaults.provider_retry_mode,
|
||||||
tool_hint_max_length=runtime_config.agents.defaults.tool_hint_max_length,
|
fallback_models=defaults.fallback_models,
|
||||||
|
provider_factory=pf,
|
||||||
web_config=runtime_config.tools.web,
|
web_config=runtime_config.tools.web,
|
||||||
exec_config=runtime_config.tools.exec,
|
exec_config=runtime_config.tools.exec,
|
||||||
restrict_to_workspace=runtime_config.tools.restrict_to_workspace,
|
restrict_to_workspace=runtime_config.tools.restrict_to_workspace,
|
||||||
@ -639,7 +665,7 @@ def serve(
|
|||||||
)
|
)
|
||||||
|
|
||||||
model_name = _resolved.model
|
model_name = _resolved.model
|
||||||
preset_name = runtime_config.agents.defaults.model_preset
|
preset_name = defaults.model_preset
|
||||||
preset_tag = f" (preset: {preset_name})" if preset_name else ""
|
preset_tag = f" (preset: {preset_name})" if preset_name else ""
|
||||||
console.print(f"{__logo__} Starting OpenAI-compatible API server")
|
console.print(f"{__logo__} Starting OpenAI-compatible API server")
|
||||||
console.print(f" [cyan]Endpoint[/cyan] : http://{host}:{port}/v1/chat/completions")
|
console.print(f" [cyan]Endpoint[/cyan] : http://{host}:{port}/v1/chat/completions")
|
||||||
@ -721,12 +747,14 @@ def _run_gateway(
|
|||||||
console.print(f"{__logo__} Starting nanobot gateway version {__version__} on port {port}...")
|
console.print(f"{__logo__} Starting nanobot gateway version {__version__} on port {port}...")
|
||||||
sync_workspace_templates(config.workspace_path)
|
sync_workspace_templates(config.workspace_path)
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
|
provider = _make_provider(config)
|
||||||
|
gw_defaults = config.agents.defaults
|
||||||
|
gw_pf = _make_cli_provider_factory(config) if gw_defaults.fallback_models else None
|
||||||
try:
|
try:
|
||||||
provider_snapshot = build_provider_snapshot(config)
|
provider_snapshot = build_provider_snapshot(config)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
console.print(f"[red]Error: {exc}[/red]")
|
console.print(f"[red]Error: {exc}[/red]")
|
||||||
raise typer.Exit(1) from exc
|
raise typer.Exit(1) from exc
|
||||||
provider = provider_snapshot.provider
|
|
||||||
session_manager = SessionManager(config.workspace_path)
|
session_manager = SessionManager(config.workspace_path)
|
||||||
|
|
||||||
# Preserve existing single-workspace installs, but keep custom workspaces clean.
|
# Preserve existing single-workspace installs, but keep custom workspaces clean.
|
||||||
@ -744,13 +772,14 @@ def _run_gateway(
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
workspace=config.workspace_path,
|
workspace=config.workspace_path,
|
||||||
model=_resolved.model,
|
model=_resolved.model,
|
||||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
max_iterations=gw_defaults.max_tool_iterations,
|
||||||
context_window_tokens=_resolved.context_window_tokens,
|
context_window_tokens=_resolved.context_window_tokens,
|
||||||
web_config=config.tools.web,
|
web_config=config.tools.web,
|
||||||
context_block_limit=config.agents.defaults.context_block_limit,
|
context_block_limit=gw_defaults.context_block_limit,
|
||||||
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
|
max_tool_result_chars=gw_defaults.max_tool_result_chars,
|
||||||
provider_retry_mode=config.agents.defaults.provider_retry_mode,
|
provider_retry_mode=gw_defaults.provider_retry_mode,
|
||||||
tool_hint_max_length=config.agents.defaults.tool_hint_max_length,
|
fallback_models=gw_defaults.fallback_models,
|
||||||
|
provider_factory=gw_pf,
|
||||||
exec_config=config.tools.exec,
|
exec_config=config.tools.exec,
|
||||||
cron_service=cron,
|
cron_service=cron,
|
||||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||||
@ -1120,6 +1149,8 @@ def agent(
|
|||||||
|
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
provider = _make_provider(config)
|
provider = _make_provider(config)
|
||||||
|
chat_defaults = config.agents.defaults
|
||||||
|
chat_pf = _make_cli_provider_factory(config) if chat_defaults.fallback_models else None
|
||||||
|
|
||||||
# Preserve existing single-workspace installs, but keep custom workspaces clean.
|
# Preserve existing single-workspace installs, but keep custom workspaces clean.
|
||||||
if is_default_workspace(config.workspace_path):
|
if is_default_workspace(config.workspace_path):
|
||||||
@ -1140,13 +1171,14 @@ def agent(
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
workspace=config.workspace_path,
|
workspace=config.workspace_path,
|
||||||
model=_resolved.model,
|
model=_resolved.model,
|
||||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
max_iterations=chat_defaults.max_tool_iterations,
|
||||||
context_window_tokens=_resolved.context_window_tokens,
|
context_window_tokens=_resolved.context_window_tokens,
|
||||||
web_config=config.tools.web,
|
web_config=config.tools.web,
|
||||||
context_block_limit=config.agents.defaults.context_block_limit,
|
context_block_limit=chat_defaults.context_block_limit,
|
||||||
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
|
max_tool_result_chars=chat_defaults.max_tool_result_chars,
|
||||||
provider_retry_mode=config.agents.defaults.provider_retry_mode,
|
provider_retry_mode=chat_defaults.provider_retry_mode,
|
||||||
tool_hint_max_length=config.agents.defaults.tool_hint_max_length,
|
fallback_models=chat_defaults.fallback_models,
|
||||||
|
provider_factory=chat_pf,
|
||||||
exec_config=config.tools.exec,
|
exec_config=config.tools.exec,
|
||||||
cron_service=cron,
|
cron_service=cron,
|
||||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||||
|
|||||||
@ -105,6 +105,7 @@ class AgentDefaults(Base):
|
|||||||
serialization_alias="toolHintMaxLength",
|
serialization_alias="toolHintMaxLength",
|
||||||
) # Max characters for tool hint display (e.g. "$ cd …/project && npm test")
|
) # Max characters for tool hint display (e.g. "$ cd …/project && npm test")
|
||||||
reasoning_effort: str | None = None # low / medium / high / adaptive - enables LLM thinking mode
|
reasoning_effort: str | None = None # low / medium / high / adaptive - enables LLM thinking mode
|
||||||
|
fallback_models: list[str] = Field(default_factory=list)
|
||||||
timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"
|
timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"
|
||||||
unified_session: bool = False # Share one session across all channels (single-user multi-device)
|
unified_session: bool = False # Share one session across all channels (single-user multi-device)
|
||||||
disabled_skills: list[str] = Field(default_factory=list) # Skill names to exclude from loading (e.g. ["summarize", "skill-creator"])
|
disabled_skills: list[str] = Field(default_factory=list) # Skill names to exclude from loading (e.g. ["summarize", "skill-creator"])
|
||||||
|
|||||||
@ -66,6 +66,7 @@ class Nanobot:
|
|||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
defaults = config.agents.defaults
|
defaults = config.agents.defaults
|
||||||
_resolved = config.resolve_preset()
|
_resolved = config.resolve_preset()
|
||||||
|
pf = _make_provider_factory(config) if defaults.fallback_models else None
|
||||||
|
|
||||||
loop = AgentLoop(
|
loop = AgentLoop(
|
||||||
bus=bus,
|
bus=bus,
|
||||||
@ -78,6 +79,8 @@ class Nanobot:
|
|||||||
max_tool_result_chars=defaults.max_tool_result_chars,
|
max_tool_result_chars=defaults.max_tool_result_chars,
|
||||||
provider_retry_mode=defaults.provider_retry_mode,
|
provider_retry_mode=defaults.provider_retry_mode,
|
||||||
tool_hint_max_length=defaults.tool_hint_max_length,
|
tool_hint_max_length=defaults.tool_hint_max_length,
|
||||||
|
fallback_models=defaults.fallback_models,
|
||||||
|
provider_factory=pf,
|
||||||
web_config=config.tools.web,
|
web_config=config.tools.web,
|
||||||
exec_config=config.tools.exec,
|
exec_config=config.tools.exec,
|
||||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||||
@ -127,14 +130,22 @@ class Nanobot:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _make_provider(config: Any) -> Any:
|
def _make_provider_for_model(
|
||||||
"""Create the LLM provider from config (extracted from CLI)."""
|
config: Any,
|
||||||
|
model: str,
|
||||||
|
*,
|
||||||
|
preset: Any | None = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Create an LLM provider instance for a specific model string.
|
||||||
|
|
||||||
|
When *preset* is given, its generation settings (temperature, max_tokens,
|
||||||
|
reasoning_effort) override the active preset defaults.
|
||||||
|
"""
|
||||||
from nanobot.providers.base import GenerationSettings
|
from nanobot.providers.base import GenerationSettings
|
||||||
from nanobot.providers.factory import make_provider
|
from nanobot.providers.factory import make_provider
|
||||||
from nanobot.providers.registry import find_by_name
|
from nanobot.providers.registry import find_by_name
|
||||||
|
|
||||||
resolved = config.resolve_preset()
|
gen_src = preset or config.resolve_preset()
|
||||||
model = resolved.model
|
|
||||||
provider_name = config.get_provider_name(model)
|
provider_name = config.get_provider_name(model)
|
||||||
p = config.get_provider(model)
|
p = config.get_provider(model)
|
||||||
spec = find_by_name(provider_name) if provider_name else None
|
spec = find_by_name(provider_name) if provider_name else None
|
||||||
@ -185,8 +196,34 @@ def _make_provider(config: Any) -> Any:
|
|||||||
)
|
)
|
||||||
|
|
||||||
provider.generation = GenerationSettings(
|
provider.generation = GenerationSettings(
|
||||||
temperature=resolved.temperature,
|
temperature=gen_src.temperature,
|
||||||
max_tokens=resolved.max_tokens,
|
max_tokens=gen_src.max_tokens,
|
||||||
reasoning_effort=resolved.reasoning_effort,
|
reasoning_effort=gen_src.reasoning_effort,
|
||||||
)
|
)
|
||||||
return provider
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
def _make_provider(config: Any) -> Any:
|
||||||
|
"""Create the LLM provider for the primary model from config."""
|
||||||
|
return _make_provider_for_model(config, config.resolve_preset().model)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_provider_factory(config: Any):
|
||||||
|
"""Build a cached factory that creates providers for arbitrary model strings.
|
||||||
|
|
||||||
|
If a model string matches a preset name in ``config.model_presets``, the
|
||||||
|
preset's full config (model, temperature, max_tokens, …) is used.
|
||||||
|
"""
|
||||||
|
cache: dict[str, Any] = {}
|
||||||
|
presets = getattr(config, "model_presets", {}) or {}
|
||||||
|
|
||||||
|
def factory(model_or_preset: str):
|
||||||
|
preset = presets.get(model_or_preset)
|
||||||
|
actual_model = preset.model if preset else model_or_preset
|
||||||
|
provider_name = config.get_provider_name(actual_model)
|
||||||
|
key = provider_name or actual_model
|
||||||
|
if key not in cache:
|
||||||
|
cache[key] = _make_provider_for_model(config, actual_model, preset=preset)
|
||||||
|
return cache[key]
|
||||||
|
|
||||||
|
return factory
|
||||||
|
|||||||
190
tests/agent/test_runner_fallback.py
Normal file
190
tests/agent/test_runner_fallback.py
Normal file
@ -0,0 +1,190 @@
|
|||||||
|
"""Tests for the provider fallback models feature in AgentRunner."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tools():
|
||||||
|
tools = MagicMock()
|
||||||
|
tools.get_definitions.return_value = []
|
||||||
|
tools.execute = AsyncMock(return_value="ok")
|
||||||
|
return tools
|
||||||
|
|
||||||
|
|
||||||
|
def _make_provider(*, model_response: LLMResponse | None = None):
|
||||||
|
p = MagicMock()
|
||||||
|
if model_response is not None:
|
||||||
|
p.chat_with_retry = AsyncMock(return_value=model_response)
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
def _base_spec(**overrides) -> AgentRunSpec:
|
||||||
|
defaults = dict(
|
||||||
|
initial_messages=[
|
||||||
|
{"role": "system", "content": "sys"},
|
||||||
|
{"role": "user", "content": "hello"},
|
||||||
|
],
|
||||||
|
tools=_make_tools(),
|
||||||
|
model="primary-model",
|
||||||
|
max_iterations=1,
|
||||||
|
max_tool_result_chars=8000,
|
||||||
|
)
|
||||||
|
defaults.update(overrides)
|
||||||
|
return AgentRunSpec(**defaults)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_fallback_when_primary_succeeds():
|
||||||
|
"""Primary succeeds -> fallback list never consulted."""
|
||||||
|
ok = LLMResponse(content="done", tool_calls=[], usage={})
|
||||||
|
provider = _make_provider(model_response=ok)
|
||||||
|
factory = MagicMock()
|
||||||
|
|
||||||
|
runner = AgentRunner(provider, provider_factory=factory)
|
||||||
|
result = await runner.run(_base_spec(fallback_models=["fb-1", "fb-2"]))
|
||||||
|
|
||||||
|
assert result.final_content == "done"
|
||||||
|
factory.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fallback_triggered_on_primary_error():
|
||||||
|
"""Primary fails -> first fallback succeeds."""
|
||||||
|
err = LLMResponse(content=None, finish_reason="error", usage={})
|
||||||
|
ok = LLMResponse(content="fallback-ok", tool_calls=[], usage={})
|
||||||
|
|
||||||
|
primary = _make_provider(model_response=err)
|
||||||
|
|
||||||
|
fb_provider = MagicMock()
|
||||||
|
fb_provider.chat_with_retry = AsyncMock(return_value=ok)
|
||||||
|
factory = MagicMock(return_value=fb_provider)
|
||||||
|
|
||||||
|
runner = AgentRunner(primary, provider_factory=factory)
|
||||||
|
result = await runner.run(_base_spec(fallback_models=["fb-model"]))
|
||||||
|
|
||||||
|
assert result.final_content == "fallback-ok"
|
||||||
|
factory.assert_called_once_with("fb-model")
|
||||||
|
fb_provider.chat_with_retry.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_all_fallbacks_fail_returns_last_error():
|
||||||
|
"""Primary + all fallbacks fail -> return last error response."""
|
||||||
|
err = LLMResponse(content=None, finish_reason="error", usage={})
|
||||||
|
|
||||||
|
primary = _make_provider(model_response=err)
|
||||||
|
fb1 = _make_provider(model_response=err)
|
||||||
|
fb2 = _make_provider(model_response=LLMResponse(
|
||||||
|
content="last-error", finish_reason="error", usage={},
|
||||||
|
))
|
||||||
|
|
||||||
|
providers = {"fb-1": fb1, "fb-2": fb2}
|
||||||
|
factory = MagicMock(side_effect=lambda m: providers[m])
|
||||||
|
|
||||||
|
runner = AgentRunner(primary, provider_factory=factory)
|
||||||
|
result = await runner.run(_base_spec(fallback_models=["fb-1", "fb-2"]))
|
||||||
|
|
||||||
|
assert result.error is not None or result.final_content is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_fallback_list_no_retry():
|
||||||
|
"""Empty fallback_models -> no fallback attempted."""
|
||||||
|
err = LLMResponse(content=None, finish_reason="error", usage={})
|
||||||
|
primary = _make_provider(model_response=err)
|
||||||
|
factory = MagicMock()
|
||||||
|
|
||||||
|
runner = AgentRunner(primary, provider_factory=factory)
|
||||||
|
result = await runner.run(_base_spec(fallback_models=[]))
|
||||||
|
|
||||||
|
factory.assert_not_called()
|
||||||
|
assert result.error is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cross_provider_fallback():
|
||||||
|
"""Fallback uses a different provider instance (cross-provider)."""
|
||||||
|
err = LLMResponse(content=None, finish_reason="error", usage={})
|
||||||
|
ok = LLMResponse(content="cross-provider-ok", tool_calls=[], usage={})
|
||||||
|
|
||||||
|
primary = _make_provider(model_response=err)
|
||||||
|
anthropic_provider = MagicMock()
|
||||||
|
anthropic_provider.chat_with_retry = AsyncMock(return_value=ok)
|
||||||
|
|
||||||
|
def cross_factory(model: str):
|
||||||
|
if model == "anthropic/claude-sonnet":
|
||||||
|
return anthropic_provider
|
||||||
|
raise ValueError(f"unexpected model: {model}")
|
||||||
|
|
||||||
|
runner = AgentRunner(primary, provider_factory=cross_factory)
|
||||||
|
result = await runner.run(_base_spec(fallback_models=["anthropic/claude-sonnet"]))
|
||||||
|
|
||||||
|
assert result.final_content == "cross-provider-ok"
|
||||||
|
anthropic_provider.chat_with_retry.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fallback_skips_to_second_on_first_error():
|
||||||
|
"""First fallback also fails -> second fallback succeeds."""
|
||||||
|
err = LLMResponse(content=None, finish_reason="error", usage={})
|
||||||
|
ok = LLMResponse(content="second-fb-ok", tool_calls=[], usage={})
|
||||||
|
|
||||||
|
primary = _make_provider(model_response=err)
|
||||||
|
fb1 = _make_provider(model_response=err)
|
||||||
|
fb2 = MagicMock()
|
||||||
|
fb2.chat_with_retry = AsyncMock(return_value=ok)
|
||||||
|
|
||||||
|
providers = {"fb-1": fb1, "fb-2": fb2}
|
||||||
|
factory = MagicMock(side_effect=lambda m: providers[m])
|
||||||
|
|
||||||
|
runner = AgentRunner(primary, provider_factory=factory)
|
||||||
|
result = await runner.run(_base_spec(fallback_models=["fb-1", "fb-2"]))
|
||||||
|
|
||||||
|
assert result.final_content == "second-fb-ok"
|
||||||
|
assert factory.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fallback_reuses_same_provider_without_factory():
|
||||||
|
"""No provider_factory -> fallback reuses primary provider with different model."""
|
||||||
|
call_count = {"n": 0}
|
||||||
|
|
||||||
|
async def chat_with_retry(*, messages, model, **kw):
|
||||||
|
call_count["n"] += 1
|
||||||
|
if call_count["n"] == 1:
|
||||||
|
return LLMResponse(content=None, finish_reason="error", usage={})
|
||||||
|
return LLMResponse(content=f"ok-via-{model}", tool_calls=[], usage={})
|
||||||
|
|
||||||
|
primary = MagicMock()
|
||||||
|
primary.chat_with_retry = chat_with_retry
|
||||||
|
|
||||||
|
runner = AgentRunner(primary, provider_factory=None)
|
||||||
|
result = await runner.run(_base_spec(fallback_models=["fallback-model"]))
|
||||||
|
|
||||||
|
assert result.final_content == "ok-via-fallback-model"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fallback_provider_cached():
|
||||||
|
"""Provider factory is called once per unique provider, not per attempt."""
|
||||||
|
err = LLMResponse(content=None, finish_reason="error", usage={})
|
||||||
|
ok = LLMResponse(content="cached-ok", tool_calls=[], usage={})
|
||||||
|
|
||||||
|
primary = _make_provider(model_response=err)
|
||||||
|
|
||||||
|
fb_provider = MagicMock()
|
||||||
|
call_seq = [err, ok]
|
||||||
|
fb_provider.chat_with_retry = AsyncMock(side_effect=call_seq)
|
||||||
|
|
||||||
|
factory = MagicMock(return_value=fb_provider)
|
||||||
|
|
||||||
|
runner = AgentRunner(primary, provider_factory=factory)
|
||||||
|
result = await runner.run(_base_spec(fallback_models=["same-provider-model-a", "same-provider-model-b"]))
|
||||||
|
|
||||||
|
assert result.final_content == "cached-ok"
|
||||||
Loading…
x
Reference in New Issue
Block a user