mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 15:24:06 +00:00
fix: preserve dynamic custom provider semantics
maintainer edit: treat arbitrary custom provider names as direct OpenAI-compatible providers, validate their api_type consistently, and avoid Pydantic instance-field warnings in fallback routing.
This commit is contained in:
parent
e9e1489cee
commit
68c6844c0b
@ -241,6 +241,15 @@ class ProvidersConfig(Base):
|
|||||||
qianfan: ProviderConfig = Field(default_factory=ProviderConfig) # Qianfan (百度千帆)
|
qianfan: ProviderConfig = Field(default_factory=ProviderConfig) # Qianfan (百度千帆)
|
||||||
nvidia: ProviderConfig = Field(default_factory=ProviderConfig) # NVIDIA NIM (nvapi- keys)
|
nvidia: ProviderConfig = Field(default_factory=ProviderConfig) # NVIDIA NIM (nvapi- keys)
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def convert_extra_providers(self):
|
||||||
|
"""Convert extra fields (custom providers) to ProviderConfig objects."""
|
||||||
|
if self.model_extra:
|
||||||
|
for key, value in self.model_extra.items():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
self.model_extra[key] = ProviderConfig.model_validate(value)
|
||||||
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def _validate_api_type_scope(self) -> "ProvidersConfig":
|
def _validate_api_type_scope(self) -> "ProvidersConfig":
|
||||||
for name in self.__class__.model_fields:
|
for name in self.__class__.model_fields:
|
||||||
@ -249,15 +258,9 @@ class ProvidersConfig(Base):
|
|||||||
provider = getattr(self, name, None)
|
provider = getattr(self, name, None)
|
||||||
if isinstance(provider, ProviderConfig) and provider.api_type != "auto":
|
if isinstance(provider, ProviderConfig) and provider.api_type != "auto":
|
||||||
raise ValueError("providers.<name>.api_type is only supported for providers.openai")
|
raise ValueError("providers.<name>.api_type is only supported for providers.openai")
|
||||||
return self
|
for provider in (self.model_extra or {}).values():
|
||||||
|
if isinstance(provider, ProviderConfig) and provider.api_type != "auto":
|
||||||
@model_validator(mode="after")
|
raise ValueError("providers.<name>.api_type is only supported for providers.openai")
|
||||||
def convert_extra_providers(self):
|
|
||||||
"""Convert extra fields (custom providers) to ProviderConfig objects."""
|
|
||||||
if self.model_extra:
|
|
||||||
for key, value in self.model_extra.items():
|
|
||||||
if isinstance(value, dict):
|
|
||||||
self.model_extra[key] = ProviderConfig.model_validate(value)
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
@ -478,10 +481,7 @@ class Config(BaseSettings):
|
|||||||
return p, spec.name
|
return p, spec.name
|
||||||
|
|
||||||
# Final fallback: check for any configured custom provider
|
# Final fallback: check for any configured custom provider
|
||||||
for attr_name in dir(self.providers):
|
for attr_name, p in (self.providers.model_extra or {}).items():
|
||||||
if attr_name.startswith("_"):
|
|
||||||
continue
|
|
||||||
p = getattr(self.providers, attr_name, None)
|
|
||||||
if isinstance(p, ProviderConfig) and p.api_base:
|
if isinstance(p, ProviderConfig) and p.api_base:
|
||||||
return p, attr_name
|
return p, attr_name
|
||||||
|
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from pathlib import Path
|
|||||||
from nanobot.config.schema import Config, InlineFallbackConfig, ModelPresetConfig
|
from nanobot.config.schema import Config, InlineFallbackConfig, ModelPresetConfig
|
||||||
from nanobot.providers.base import LLMProvider
|
from nanobot.providers.base import LLMProvider
|
||||||
from nanobot.providers.fallback_provider import FallbackProvider
|
from nanobot.providers.fallback_provider import FallbackProvider
|
||||||
from nanobot.providers.registry import find_by_name
|
from nanobot.providers.registry import create_dynamic_spec, find_by_name
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@ -41,6 +41,8 @@ def _make_provider_core(
|
|||||||
provider_name = config.get_provider_name(model, preset=resolved)
|
provider_name = config.get_provider_name(model, preset=resolved)
|
||||||
p = config.get_provider(model, preset=resolved)
|
p = config.get_provider(model, preset=resolved)
|
||||||
spec = find_by_name(provider_name) if provider_name else None
|
spec = find_by_name(provider_name) if provider_name else None
|
||||||
|
if provider_name and not spec and p:
|
||||||
|
spec = create_dynamic_spec(provider_name)
|
||||||
if spec and spec.is_transcription_only:
|
if spec and spec.is_transcription_only:
|
||||||
raise ValueError(f"Provider '{provider_name}' only supports transcription.")
|
raise ValueError(f"Provider '{provider_name}' only supports transcription.")
|
||||||
backend = spec.backend if spec else "openai_compat"
|
backend = spec.backend if spec else "openai_compat"
|
||||||
|
|||||||
@ -664,6 +664,30 @@ def test_make_provider_passes_extra_headers_to_custom_provider():
|
|||||||
assert kwargs["default_headers"]["x-session-affinity"] == "sticky-session"
|
assert kwargs["default_headers"]["x-session-affinity"] == "sticky-session"
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_provider_treats_dynamic_custom_provider_as_direct():
|
||||||
|
config = Config.model_validate(
|
||||||
|
{
|
||||||
|
"agents": {"defaults": {"provider": "my-company-api", "model": "gpt-4o-mini"}},
|
||||||
|
"providers": {
|
||||||
|
"my-company-api": {
|
||||||
|
"apiBase": "https://example.com/v1",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_async_openai:
|
||||||
|
provider = make_provider(config)
|
||||||
|
asyncio.run(provider._ensure_client())
|
||||||
|
|
||||||
|
assert provider.get_default_model() == "gpt-4o-mini"
|
||||||
|
assert provider._spec.name == "my_company_api"
|
||||||
|
assert provider._spec.is_direct is True
|
||||||
|
kwargs = mock_async_openai.call_args.kwargs
|
||||||
|
assert kwargs["api_key"] == "no-key"
|
||||||
|
assert kwargs["base_url"] == "https://example.com/v1"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_agent_runtime(tmp_path):
|
def mock_agent_runtime(tmp_path):
|
||||||
"""Mock agent command dependencies for focused CLI tests."""
|
"""Mock agent command dependencies for focused CLI tests."""
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
import warnings
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from nanobot.config.schema import Config
|
from nanobot.config.schema import Config
|
||||||
@ -47,6 +49,35 @@ def test_provider_api_type_is_openai_only() -> None:
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="only supported"):
|
||||||
|
Config.model_validate({
|
||||||
|
"providers": {
|
||||||
|
"my-company-api": {
|
||||||
|
"apiBase": "https://example.test/v1",
|
||||||
|
"apiType": "responses",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_provider_fallback_uses_model_extra_without_pydantic_warnings() -> None:
|
||||||
|
config = Config.model_validate({
|
||||||
|
"agents": {
|
||||||
|
"defaults": {
|
||||||
|
"model": "unmatched-model",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"providers": {
|
||||||
|
"my-company-api": {
|
||||||
|
"apiBase": "https://example.test/v1",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("error")
|
||||||
|
assert config.get_provider_name() == "my-company-api"
|
||||||
|
|
||||||
|
|
||||||
def test_legacy_defaults_config_without_presets_still_resolves() -> None:
|
def test_legacy_defaults_config_without_presets_still_resolves() -> None:
|
||||||
config = Config.model_validate({
|
config = Config.model_validate({
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user