mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 15:24:06 +00:00
fix: validate named custom provider endpoints
This commit is contained in:
parent
a9308eb8e2
commit
09d24e6c25
@ -406,16 +406,24 @@ class Config(BaseSettings):
|
|||||||
|
|
||||||
resolved = preset or self.resolve_preset()
|
resolved = preset or self.resolve_preset()
|
||||||
forced = resolved.provider
|
forced = resolved.provider
|
||||||
|
|
||||||
|
def _custom_provider_by_name(name: str) -> tuple[ProviderConfig, str] | None:
|
||||||
|
normalized = name.replace("-", "_").lower()
|
||||||
|
for attr_name, provider in (self.providers.model_extra or {}).items():
|
||||||
|
if not isinstance(provider, ProviderConfig):
|
||||||
|
continue
|
||||||
|
if attr_name.replace("-", "_").lower() == normalized:
|
||||||
|
return provider, attr_name
|
||||||
|
return None
|
||||||
|
|
||||||
if forced != "auto":
|
if forced != "auto":
|
||||||
spec = find_by_name(forced)
|
spec = find_by_name(forced)
|
||||||
if spec:
|
if spec:
|
||||||
p = getattr(self.providers, spec.name, None)
|
p = getattr(self.providers, spec.name, None)
|
||||||
return (p, spec.name) if p else (None, None)
|
return (p, spec.name) if p else (None, None)
|
||||||
# Check for custom provider by name (try both original and normalized)
|
custom = _custom_provider_by_name(forced)
|
||||||
for name_to_try in (forced, forced.replace("-", "_")):
|
if custom is not None:
|
||||||
p = getattr(self.providers, name_to_try, None)
|
return custom
|
||||||
if p and isinstance(p, ProviderConfig):
|
|
||||||
return p, name_to_try
|
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
model_lower = (model or resolved.model).lower()
|
model_lower = (model or resolved.model).lower()
|
||||||
@ -436,13 +444,14 @@ class Config(BaseSettings):
|
|||||||
if spec.is_oauth or spec.is_local or spec.is_direct or p.api_key:
|
if spec.is_oauth or spec.is_local or spec.is_direct or p.api_key:
|
||||||
return p, spec.name
|
return p, spec.name
|
||||||
|
|
||||||
# Check for custom provider by prefix (e.g., "myprovider/gpt-4")
|
# Check for custom provider by prefix (e.g., "companyProxy/gpt-4").
|
||||||
# Try both original prefix and normalized (snake_case) prefix
|
# Return the matching provider even when apiBase is missing, so a
|
||||||
|
# malformed explicit prefix fails instead of falling through to a
|
||||||
|
# different custom provider.
|
||||||
if model_prefix:
|
if model_prefix:
|
||||||
for prefix_to_try in (model_prefix, normalized_prefix):
|
custom = _custom_provider_by_name(normalized_prefix)
|
||||||
p = getattr(self.providers, prefix_to_try, None)
|
if custom is not None:
|
||||||
if p and isinstance(p, ProviderConfig) and p.api_base:
|
return custom
|
||||||
return p, prefix_to_try
|
|
||||||
|
|
||||||
# Match by keyword (order follows PROVIDERS registry)
|
# Match by keyword (order follows PROVIDERS registry)
|
||||||
for spec in PROVIDERS:
|
for spec in PROVIDERS:
|
||||||
|
|||||||
@ -42,6 +42,8 @@ def _make_provider_core(
|
|||||||
p = config.get_provider(model, preset=resolved)
|
p = config.get_provider(model, preset=resolved)
|
||||||
spec = find_by_name(provider_name) if provider_name else None
|
spec = find_by_name(provider_name) if provider_name else None
|
||||||
if provider_name and not spec and p:
|
if provider_name and not spec and p:
|
||||||
|
if not p.api_base:
|
||||||
|
raise ValueError(f"Provider '{provider_name}' requires api_base in config.")
|
||||||
spec = create_dynamic_spec(provider_name)
|
spec = create_dynamic_spec(provider_name)
|
||||||
if spec and spec.is_transcription_only:
|
if spec and spec.is_transcription_only:
|
||||||
raise ValueError(f"Provider '{provider_name}' only supports transcription.")
|
raise ValueError(f"Provider '{provider_name}' only supports transcription.")
|
||||||
@ -50,6 +52,14 @@ def _make_provider_core(
|
|||||||
if backend == "azure_openai":
|
if backend == "azure_openai":
|
||||||
if not p or not p.api_base:
|
if not p or not p.api_base:
|
||||||
raise ValueError("Azure OpenAI requires api_base in config.")
|
raise ValueError("Azure OpenAI requires api_base in config.")
|
||||||
|
elif (
|
||||||
|
backend == "openai_compat"
|
||||||
|
and spec
|
||||||
|
and spec.is_direct
|
||||||
|
and not spec.default_api_base
|
||||||
|
and not (p and p.api_base)
|
||||||
|
):
|
||||||
|
raise ValueError(f"Provider '{provider_name}' requires api_base in config.")
|
||||||
elif backend == "openai_compat" and not model.startswith("bedrock/"):
|
elif backend == "openai_compat" and not model.startswith("bedrock/"):
|
||||||
needs_key = not (p and p.api_key)
|
needs_key = not (p and p.api_key)
|
||||||
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
|
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
|
||||||
|
|||||||
@ -270,6 +270,12 @@ def _provider_requires_api_key(spec: Any) -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _provider_requires_api_base(spec: Any) -> bool:
|
||||||
|
if spec.name == "azure_openai":
|
||||||
|
return True
|
||||||
|
return bool(spec.backend == "openai_compat" and spec.is_direct and not spec.default_api_base)
|
||||||
|
|
||||||
|
|
||||||
def _oauth_provider_status(spec: Any) -> dict[str, Any]:
|
def _oauth_provider_status(spec: Any) -> dict[str, Any]:
|
||||||
if not getattr(spec, "is_oauth", False):
|
if not getattr(spec, "is_oauth", False):
|
||||||
return {"configured": False, "account": None, "expires_at": None, "login_supported": False}
|
return {"configured": False, "account": None, "expires_at": None, "login_supported": False}
|
||||||
@ -321,7 +327,7 @@ def _oauth_provider_status(spec: Any) -> dict[str, Any]:
|
|||||||
def _provider_configured_for_settings(spec: Any, provider_config: Any) -> bool:
|
def _provider_configured_for_settings(spec: Any, provider_config: Any) -> bool:
|
||||||
if spec.is_oauth:
|
if spec.is_oauth:
|
||||||
return bool(_oauth_provider_status(spec)["configured"])
|
return bool(_oauth_provider_status(spec)["configured"])
|
||||||
if spec.name == "azure_openai":
|
if _provider_requires_api_base(spec):
|
||||||
return bool(provider_config.api_base)
|
return bool(provider_config.api_base)
|
||||||
if _provider_requires_api_key(spec):
|
if _provider_requires_api_key(spec):
|
||||||
return bool(provider_config.api_key)
|
return bool(provider_config.api_key)
|
||||||
|
|||||||
@ -688,6 +688,41 @@ def test_make_provider_treats_dynamic_custom_provider_as_direct():
|
|||||||
assert kwargs["base_url"] == "https://example.com/v1"
|
assert kwargs["base_url"] == "https://example.com/v1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_provider_rejects_dynamic_custom_provider_without_api_base():
|
||||||
|
config = Config.model_validate(
|
||||||
|
{
|
||||||
|
"agents": {"defaults": {"provider": "my-company-api", "model": "gpt-4o-mini"}},
|
||||||
|
"providers": {
|
||||||
|
"my-company-api": {
|
||||||
|
"apiKey": "sk-test",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Provider 'my-company-api' requires api_base"):
|
||||||
|
make_provider(config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_provider_rejects_auto_dynamic_custom_prefix_without_api_base():
|
||||||
|
config = Config.model_validate(
|
||||||
|
{
|
||||||
|
"agents": {"defaults": {"provider": "auto", "model": "companyProxy/gpt-4o"}},
|
||||||
|
"providers": {
|
||||||
|
"otherProxy": {
|
||||||
|
"apiBase": "https://other.example.test/v1",
|
||||||
|
},
|
||||||
|
"companyProxy": {
|
||||||
|
"apiKey": "sk-company",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Provider 'companyProxy' requires api_base"):
|
||||||
|
make_provider(config)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_agent_runtime(tmp_path):
|
def mock_agent_runtime(tmp_path):
|
||||||
"""Mock agent command dependencies for focused CLI tests."""
|
"""Mock agent command dependencies for focused CLI tests."""
|
||||||
|
|||||||
@ -79,6 +79,50 @@ def test_custom_provider_fallback_uses_model_extra_without_pydantic_warnings() -
|
|||||||
assert config.get_provider_name() == "my-company-api"
|
assert config.get_provider_name() == "my-company-api"
|
||||||
|
|
||||||
|
|
||||||
|
def test_dynamic_custom_provider_prefix_matches_camel_case_key() -> None:
|
||||||
|
config = Config.model_validate({
|
||||||
|
"agents": {
|
||||||
|
"defaults": {
|
||||||
|
"provider": "auto",
|
||||||
|
"model": "companyProxy/gpt-4o-mini",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"providers": {
|
||||||
|
"otherProxy": {
|
||||||
|
"apiBase": "https://other.example.test/v1",
|
||||||
|
},
|
||||||
|
"companyProxy": {
|
||||||
|
"apiBase": "https://company.example.test/v1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert config.get_provider_name() == "companyProxy"
|
||||||
|
assert config.get_api_base() == "https://company.example.test/v1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_dynamic_custom_provider_prefix_does_not_fall_through_when_base_missing() -> None:
|
||||||
|
config = Config.model_validate({
|
||||||
|
"agents": {
|
||||||
|
"defaults": {
|
||||||
|
"provider": "auto",
|
||||||
|
"model": "companyProxy/gpt-4o-mini",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"providers": {
|
||||||
|
"otherProxy": {
|
||||||
|
"apiBase": "https://other.example.test/v1",
|
||||||
|
},
|
||||||
|
"companyProxy": {
|
||||||
|
"apiKey": "sk-company",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert config.get_provider_name() == "companyProxy"
|
||||||
|
assert config.get_api_base() is None
|
||||||
|
|
||||||
|
|
||||||
def test_legacy_defaults_config_without_presets_still_resolves() -> None:
|
def test_legacy_defaults_config_without_presets_still_resolves() -> None:
|
||||||
config = Config.model_validate({
|
config = Config.model_validate({
|
||||||
"agents": {
|
"agents": {
|
||||||
|
|||||||
@ -113,6 +113,31 @@ def test_create_model_configuration_accepts_dynamic_custom_provider(
|
|||||||
assert saved.model_presets["tenant-model"].model == "gpt-4o-mini"
|
assert saved.model_presets["tenant-model"].model == "gpt-4o-mini"
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_model_configuration_rejects_dynamic_custom_provider_without_api_base(
|
||||||
|
tmp_path,
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
config = Config.model_validate({
|
||||||
|
"providers": {
|
||||||
|
DYNAMIC_PROVIDER_NAME: {
|
||||||
|
"apiKey": "sk-test",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
save_config(config, config_path)
|
||||||
|
monkeypatch.setattr("nanobot.config.loader._current_config_path", config_path)
|
||||||
|
|
||||||
|
with pytest.raises(WebUISettingsError, match="provider is not configured"):
|
||||||
|
create_model_configuration(
|
||||||
|
{
|
||||||
|
"label": ["Tenant model"],
|
||||||
|
"provider": [DYNAMIC_PROVIDER_NAME],
|
||||||
|
"model": ["gpt-4o-mini"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_create_model_configuration_rejects_unconfigured_provider(
|
def test_create_model_configuration_rejects_unconfigured_provider(
|
||||||
tmp_path,
|
tmp_path,
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
@ -315,6 +340,29 @@ def test_settings_payload_includes_dynamic_custom_provider(
|
|||||||
assert providers[DYNAMIC_PROVIDER_NAME]["api_base"] == DYNAMIC_PROVIDER_API_BASE
|
assert providers[DYNAMIC_PROVIDER_NAME]["api_base"] == DYNAMIC_PROVIDER_API_BASE
|
||||||
|
|
||||||
|
|
||||||
|
def test_settings_payload_marks_dynamic_custom_provider_without_api_base_unconfigured(
|
||||||
|
tmp_path,
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
config = Config.model_validate({
|
||||||
|
"providers": {
|
||||||
|
DYNAMIC_PROVIDER_NAME: {
|
||||||
|
"apiKey": "sk-test",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
save_config(config, config_path)
|
||||||
|
monkeypatch.setattr("nanobot.config.loader._current_config_path", config_path)
|
||||||
|
|
||||||
|
payload = settings_payload()
|
||||||
|
providers = {row["name"]: row for row in payload["providers"]}
|
||||||
|
|
||||||
|
assert providers[DYNAMIC_PROVIDER_NAME]["configured"] is False
|
||||||
|
assert providers[DYNAMIC_PROVIDER_NAME]["api_key_hint"] == "••••"
|
||||||
|
assert providers[DYNAMIC_PROVIDER_NAME]["api_base"] is None
|
||||||
|
|
||||||
|
|
||||||
def test_settings_payload_includes_network_safety_fields(
|
def test_settings_payload_includes_network_safety_fields(
|
||||||
tmp_path,
|
tmp_path,
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user