fix: validate named custom provider endpoints

This commit is contained in:
chengyongru 2026-06-11 16:41:31 +08:00 committed by Xubin Ren
parent a9308eb8e2
commit 09d24e6c25
6 changed files with 164 additions and 12 deletions

View File

@ -406,16 +406,24 @@ class Config(BaseSettings):
resolved = preset or self.resolve_preset() resolved = preset or self.resolve_preset()
forced = resolved.provider forced = resolved.provider
def _custom_provider_by_name(name: str) -> tuple[ProviderConfig, str] | None:
normalized = name.replace("-", "_").lower()
for attr_name, provider in (self.providers.model_extra or {}).items():
if not isinstance(provider, ProviderConfig):
continue
if attr_name.replace("-", "_").lower() == normalized:
return provider, attr_name
return None
if forced != "auto": if forced != "auto":
spec = find_by_name(forced) spec = find_by_name(forced)
if spec: if spec:
p = getattr(self.providers, spec.name, None) p = getattr(self.providers, spec.name, None)
return (p, spec.name) if p else (None, None) return (p, spec.name) if p else (None, None)
# Check for custom provider by name (try both original and normalized) custom = _custom_provider_by_name(forced)
for name_to_try in (forced, forced.replace("-", "_")): if custom is not None:
p = getattr(self.providers, name_to_try, None) return custom
if p and isinstance(p, ProviderConfig):
return p, name_to_try
return None, None return None, None
model_lower = (model or resolved.model).lower() model_lower = (model or resolved.model).lower()
@ -436,13 +444,14 @@ class Config(BaseSettings):
if spec.is_oauth or spec.is_local or spec.is_direct or p.api_key: if spec.is_oauth or spec.is_local or spec.is_direct or p.api_key:
return p, spec.name return p, spec.name
# Check for custom provider by prefix (e.g., "myprovider/gpt-4") # Check for custom provider by prefix (e.g., "companyProxy/gpt-4").
# Try both original prefix and normalized (snake_case) prefix # Return the matching provider even when apiBase is missing, so a
# malformed explicit prefix fails instead of falling through to a
# different custom provider.
if model_prefix: if model_prefix:
for prefix_to_try in (model_prefix, normalized_prefix): custom = _custom_provider_by_name(normalized_prefix)
p = getattr(self.providers, prefix_to_try, None) if custom is not None:
if p and isinstance(p, ProviderConfig) and p.api_base: return custom
return p, prefix_to_try
# Match by keyword (order follows PROVIDERS registry) # Match by keyword (order follows PROVIDERS registry)
for spec in PROVIDERS: for spec in PROVIDERS:

View File

@ -42,6 +42,8 @@ def _make_provider_core(
p = config.get_provider(model, preset=resolved) p = config.get_provider(model, preset=resolved)
spec = find_by_name(provider_name) if provider_name else None spec = find_by_name(provider_name) if provider_name else None
if provider_name and not spec and p: if provider_name and not spec and p:
if not p.api_base:
raise ValueError(f"Provider '{provider_name}' requires api_base in config.")
spec = create_dynamic_spec(provider_name) spec = create_dynamic_spec(provider_name)
if spec and spec.is_transcription_only: if spec and spec.is_transcription_only:
raise ValueError(f"Provider '{provider_name}' only supports transcription.") raise ValueError(f"Provider '{provider_name}' only supports transcription.")
@ -50,6 +52,14 @@ def _make_provider_core(
if backend == "azure_openai": if backend == "azure_openai":
if not p or not p.api_base: if not p or not p.api_base:
raise ValueError("Azure OpenAI requires api_base in config.") raise ValueError("Azure OpenAI requires api_base in config.")
elif (
backend == "openai_compat"
and spec
and spec.is_direct
and not spec.default_api_base
and not (p and p.api_base)
):
raise ValueError(f"Provider '{provider_name}' requires api_base in config.")
elif backend == "openai_compat" and not model.startswith("bedrock/"): elif backend == "openai_compat" and not model.startswith("bedrock/"):
needs_key = not (p and p.api_key) needs_key = not (p and p.api_key)
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct) exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)

View File

@ -270,6 +270,12 @@ def _provider_requires_api_key(spec: Any) -> bool:
return True return True
def _provider_requires_api_base(spec: Any) -> bool:
if spec.name == "azure_openai":
return True
return bool(spec.backend == "openai_compat" and spec.is_direct and not spec.default_api_base)
def _oauth_provider_status(spec: Any) -> dict[str, Any]: def _oauth_provider_status(spec: Any) -> dict[str, Any]:
if not getattr(spec, "is_oauth", False): if not getattr(spec, "is_oauth", False):
return {"configured": False, "account": None, "expires_at": None, "login_supported": False} return {"configured": False, "account": None, "expires_at": None, "login_supported": False}
@ -321,7 +327,7 @@ def _oauth_provider_status(spec: Any) -> dict[str, Any]:
def _provider_configured_for_settings(spec: Any, provider_config: Any) -> bool: def _provider_configured_for_settings(spec: Any, provider_config: Any) -> bool:
if spec.is_oauth: if spec.is_oauth:
return bool(_oauth_provider_status(spec)["configured"]) return bool(_oauth_provider_status(spec)["configured"])
if spec.name == "azure_openai": if _provider_requires_api_base(spec):
return bool(provider_config.api_base) return bool(provider_config.api_base)
if _provider_requires_api_key(spec): if _provider_requires_api_key(spec):
return bool(provider_config.api_key) return bool(provider_config.api_key)

View File

@ -688,6 +688,41 @@ def test_make_provider_treats_dynamic_custom_provider_as_direct():
assert kwargs["base_url"] == "https://example.com/v1" assert kwargs["base_url"] == "https://example.com/v1"
def test_make_provider_rejects_dynamic_custom_provider_without_api_base():
config = Config.model_validate(
{
"agents": {"defaults": {"provider": "my-company-api", "model": "gpt-4o-mini"}},
"providers": {
"my-company-api": {
"apiKey": "sk-test",
}
},
}
)
with pytest.raises(ValueError, match="Provider 'my-company-api' requires api_base"):
make_provider(config)
def test_make_provider_rejects_auto_dynamic_custom_prefix_without_api_base():
config = Config.model_validate(
{
"agents": {"defaults": {"provider": "auto", "model": "companyProxy/gpt-4o"}},
"providers": {
"otherProxy": {
"apiBase": "https://other.example.test/v1",
},
"companyProxy": {
"apiKey": "sk-company",
},
},
}
)
with pytest.raises(ValueError, match="Provider 'companyProxy' requires api_base"):
make_provider(config)
@pytest.fixture @pytest.fixture
def mock_agent_runtime(tmp_path): def mock_agent_runtime(tmp_path):
"""Mock agent command dependencies for focused CLI tests.""" """Mock agent command dependencies for focused CLI tests."""

View File

@ -79,6 +79,50 @@ def test_custom_provider_fallback_uses_model_extra_without_pydantic_warnings() -
assert config.get_provider_name() == "my-company-api" assert config.get_provider_name() == "my-company-api"
def test_dynamic_custom_provider_prefix_matches_camel_case_key() -> None:
config = Config.model_validate({
"agents": {
"defaults": {
"provider": "auto",
"model": "companyProxy/gpt-4o-mini",
}
},
"providers": {
"otherProxy": {
"apiBase": "https://other.example.test/v1",
},
"companyProxy": {
"apiBase": "https://company.example.test/v1",
},
},
})
assert config.get_provider_name() == "companyProxy"
assert config.get_api_base() == "https://company.example.test/v1"
def test_dynamic_custom_provider_prefix_does_not_fall_through_when_base_missing() -> None:
config = Config.model_validate({
"agents": {
"defaults": {
"provider": "auto",
"model": "companyProxy/gpt-4o-mini",
}
},
"providers": {
"otherProxy": {
"apiBase": "https://other.example.test/v1",
},
"companyProxy": {
"apiKey": "sk-company",
},
},
})
assert config.get_provider_name() == "companyProxy"
assert config.get_api_base() is None
def test_legacy_defaults_config_without_presets_still_resolves() -> None: def test_legacy_defaults_config_without_presets_still_resolves() -> None:
config = Config.model_validate({ config = Config.model_validate({
"agents": { "agents": {

View File

@ -113,6 +113,31 @@ def test_create_model_configuration_accepts_dynamic_custom_provider(
assert saved.model_presets["tenant-model"].model == "gpt-4o-mini" assert saved.model_presets["tenant-model"].model == "gpt-4o-mini"
def test_create_model_configuration_rejects_dynamic_custom_provider_without_api_base(
tmp_path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
config_path = tmp_path / "config.json"
config = Config.model_validate({
"providers": {
DYNAMIC_PROVIDER_NAME: {
"apiKey": "sk-test",
}
}
})
save_config(config, config_path)
monkeypatch.setattr("nanobot.config.loader._current_config_path", config_path)
with pytest.raises(WebUISettingsError, match="provider is not configured"):
create_model_configuration(
{
"label": ["Tenant model"],
"provider": [DYNAMIC_PROVIDER_NAME],
"model": ["gpt-4o-mini"],
}
)
def test_create_model_configuration_rejects_unconfigured_provider( def test_create_model_configuration_rejects_unconfigured_provider(
tmp_path, tmp_path,
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
@ -315,6 +340,29 @@ def test_settings_payload_includes_dynamic_custom_provider(
assert providers[DYNAMIC_PROVIDER_NAME]["api_base"] == DYNAMIC_PROVIDER_API_BASE assert providers[DYNAMIC_PROVIDER_NAME]["api_base"] == DYNAMIC_PROVIDER_API_BASE
def test_settings_payload_marks_dynamic_custom_provider_without_api_base_unconfigured(
tmp_path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
config_path = tmp_path / "config.json"
config = Config.model_validate({
"providers": {
DYNAMIC_PROVIDER_NAME: {
"apiKey": "sk-test",
}
}
})
save_config(config, config_path)
monkeypatch.setattr("nanobot.config.loader._current_config_path", config_path)
payload = settings_payload()
providers = {row["name"]: row for row in payload["providers"]}
assert providers[DYNAMIC_PROVIDER_NAME]["configured"] is False
assert providers[DYNAMIC_PROVIDER_NAME]["api_key_hint"] == "••••"
assert providers[DYNAMIC_PROVIDER_NAME]["api_base"] is None
def test_settings_payload_includes_network_safety_fields( def test_settings_payload_includes_network_safety_fields(
tmp_path, tmp_path,
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,