fix: preserve dynamic custom provider semantics

maintainer edit: treat arbitrary custom provider names as direct OpenAI-compatible providers, validate their api_type consistently, and avoid Pydantic instance-field warnings in fallback routing.
This commit is contained in:
chengyongru 2026-06-11 13:19:15 +08:00 committed by Xubin Ren
parent e9e1489cee
commit 68c6844c0b
4 changed files with 71 additions and 14 deletions

View File

@ -241,6 +241,15 @@ class ProvidersConfig(Base):
qianfan: ProviderConfig = Field(default_factory=ProviderConfig) # Qianfan (百度千帆) qianfan: ProviderConfig = Field(default_factory=ProviderConfig) # Qianfan (百度千帆)
nvidia: ProviderConfig = Field(default_factory=ProviderConfig) # NVIDIA NIM (nvapi- keys) nvidia: ProviderConfig = Field(default_factory=ProviderConfig) # NVIDIA NIM (nvapi- keys)
@model_validator(mode="after")
def convert_extra_providers(self):
"""Convert extra fields (custom providers) to ProviderConfig objects."""
if self.model_extra:
for key, value in self.model_extra.items():
if isinstance(value, dict):
self.model_extra[key] = ProviderConfig.model_validate(value)
return self
@model_validator(mode="after") @model_validator(mode="after")
def _validate_api_type_scope(self) -> "ProvidersConfig": def _validate_api_type_scope(self) -> "ProvidersConfig":
for name in self.__class__.model_fields: for name in self.__class__.model_fields:
@ -249,15 +258,9 @@ class ProvidersConfig(Base):
provider = getattr(self, name, None) provider = getattr(self, name, None)
if isinstance(provider, ProviderConfig) and provider.api_type != "auto": if isinstance(provider, ProviderConfig) and provider.api_type != "auto":
raise ValueError("providers.<name>.api_type is only supported for providers.openai") raise ValueError("providers.<name>.api_type is only supported for providers.openai")
return self for provider in (self.model_extra or {}).values():
if isinstance(provider, ProviderConfig) and provider.api_type != "auto":
@model_validator(mode="after") raise ValueError("providers.<name>.api_type is only supported for providers.openai")
def convert_extra_providers(self):
"""Convert extra fields (custom providers) to ProviderConfig objects."""
if self.model_extra:
for key, value in self.model_extra.items():
if isinstance(value, dict):
self.model_extra[key] = ProviderConfig.model_validate(value)
return self return self
@ -478,10 +481,7 @@ class Config(BaseSettings):
return p, spec.name return p, spec.name
# Final fallback: check for any configured custom provider # Final fallback: check for any configured custom provider
for attr_name in dir(self.providers): for attr_name, p in (self.providers.model_extra or {}).items():
if attr_name.startswith("_"):
continue
p = getattr(self.providers, attr_name, None)
if isinstance(p, ProviderConfig) and p.api_base: if isinstance(p, ProviderConfig) and p.api_base:
return p, attr_name return p, attr_name

View File

@ -8,7 +8,7 @@ from pathlib import Path
from nanobot.config.schema import Config, InlineFallbackConfig, ModelPresetConfig from nanobot.config.schema import Config, InlineFallbackConfig, ModelPresetConfig
from nanobot.providers.base import LLMProvider from nanobot.providers.base import LLMProvider
from nanobot.providers.fallback_provider import FallbackProvider from nanobot.providers.fallback_provider import FallbackProvider
from nanobot.providers.registry import find_by_name from nanobot.providers.registry import create_dynamic_spec, find_by_name
@dataclass(frozen=True) @dataclass(frozen=True)
@ -41,6 +41,8 @@ def _make_provider_core(
provider_name = config.get_provider_name(model, preset=resolved) provider_name = config.get_provider_name(model, preset=resolved)
p = config.get_provider(model, preset=resolved) p = config.get_provider(model, preset=resolved)
spec = find_by_name(provider_name) if provider_name else None spec = find_by_name(provider_name) if provider_name else None
if provider_name and not spec and p:
spec = create_dynamic_spec(provider_name)
if spec and spec.is_transcription_only: if spec and spec.is_transcription_only:
raise ValueError(f"Provider '{provider_name}' only supports transcription.") raise ValueError(f"Provider '{provider_name}' only supports transcription.")
backend = spec.backend if spec else "openai_compat" backend = spec.backend if spec else "openai_compat"

View File

@ -664,6 +664,30 @@ def test_make_provider_passes_extra_headers_to_custom_provider():
assert kwargs["default_headers"]["x-session-affinity"] == "sticky-session" assert kwargs["default_headers"]["x-session-affinity"] == "sticky-session"
def test_make_provider_treats_dynamic_custom_provider_as_direct():
config = Config.model_validate(
{
"agents": {"defaults": {"provider": "my-company-api", "model": "gpt-4o-mini"}},
"providers": {
"my-company-api": {
"apiBase": "https://example.com/v1",
}
},
}
)
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_async_openai:
provider = make_provider(config)
asyncio.run(provider._ensure_client())
assert provider.get_default_model() == "gpt-4o-mini"
assert provider._spec.name == "my_company_api"
assert provider._spec.is_direct is True
kwargs = mock_async_openai.call_args.kwargs
assert kwargs["api_key"] == "no-key"
assert kwargs["base_url"] == "https://example.com/v1"
@pytest.fixture @pytest.fixture
def mock_agent_runtime(tmp_path): def mock_agent_runtime(tmp_path):
"""Mock agent command dependencies for focused CLI tests.""" """Mock agent command dependencies for focused CLI tests."""

View File

@ -1,3 +1,5 @@
import warnings
import pytest import pytest
from nanobot.config.schema import Config from nanobot.config.schema import Config
@ -47,6 +49,35 @@ def test_provider_api_type_is_openai_only() -> None:
} }
}) })
with pytest.raises(ValueError, match="only supported"):
Config.model_validate({
"providers": {
"my-company-api": {
"apiBase": "https://example.test/v1",
"apiType": "responses",
}
}
})
def test_custom_provider_fallback_uses_model_extra_without_pydantic_warnings() -> None:
config = Config.model_validate({
"agents": {
"defaults": {
"model": "unmatched-model",
}
},
"providers": {
"my-company-api": {
"apiBase": "https://example.test/v1",
}
},
})
with warnings.catch_warnings():
warnings.simplefilter("error")
assert config.get_provider_name() == "my-company-api"
def test_legacy_defaults_config_without_presets_still_resolves() -> None: def test_legacy_defaults_config_without_presets_still_resolves() -> None:
config = Config.model_validate({ config = Config.model_validate({