mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 07:14:08 +00:00
feat(providers): add extra_query config for OpenAI-compatible providers
Adds ProviderConfig.extra_query, threaded into AsyncOpenAI(default_query) so that Azure-style gateways requiring query params like api-version can be configured without URL hacks. Also updates provider_signature to track extra_query changes so per-turn refresh rebuilds the provider when the value changes. Addresses the extra_query portion of #4204. The max_completion_tokens model-awareness enhancement is intentionally left separate.
This commit is contained in:
parent
9c81280300
commit
28f3a20d64
@ -183,6 +183,7 @@ class ProviderConfig(Base):
|
||||
api_type: Literal["auto", "chat_completions", "responses"] = "auto" # Request API surface
|
||||
extra_headers: dict[str, str] | None = None # Custom headers (e.g. APP-Code for AiHubMix)
|
||||
extra_body: dict[str, Any] | None = None # Extra provider request fields; shape depends on provider/API surface
|
||||
extra_query: dict[str, str] | None = None # Extra query params (e.g. api-version for Azure-style gateways)
|
||||
|
||||
|
||||
class BedrockProviderConfig(ProviderConfig):
|
||||
|
||||
@ -99,6 +99,7 @@ def _make_provider_core(
|
||||
spec=spec,
|
||||
extra_body=p.extra_body if p else None,
|
||||
api_type=p.api_type if p and provider_name == "openai" else "auto",
|
||||
extra_query=p.extra_query if p else None,
|
||||
)
|
||||
|
||||
provider.generation = resolved.to_generation_settings()
|
||||
@ -185,6 +186,7 @@ def provider_signature(
|
||||
fp.extra_headers if fp else None,
|
||||
fp.extra_body if fp else None,
|
||||
fp.api_type if fp else "auto",
|
||||
fp.extra_query if fp else None,
|
||||
getattr(fp, "region", None) if fp else None,
|
||||
getattr(fp, "profile", None) if fp else None,
|
||||
fallback.max_tokens,
|
||||
@ -202,6 +204,7 @@ def provider_signature(
|
||||
p.extra_headers if p else None,
|
||||
p.extra_body if p else None,
|
||||
p.api_type if p else "auto",
|
||||
p.extra_query if p else None,
|
||||
getattr(p, "region", None) if p else None,
|
||||
getattr(p, "profile", None) if p else None,
|
||||
resolved.max_tokens,
|
||||
|
||||
@ -331,6 +331,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
spec: ProviderSpec | None = None,
|
||||
extra_body: dict[str, Any] | None = None,
|
||||
api_type: str = "auto",
|
||||
extra_query: dict[str, str] | None = None,
|
||||
):
|
||||
super().__init__(api_key, api_base)
|
||||
self.default_model = default_model
|
||||
@ -338,6 +339,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
self._spec = spec
|
||||
self._extra_body = extra_body or {}
|
||||
self._api_type = api_type if spec and spec.name == "openai" else "auto"
|
||||
self._extra_query = extra_query or {}
|
||||
|
||||
if api_key and spec and spec.env_key:
|
||||
self._setup_env(api_key, api_base)
|
||||
@ -386,6 +388,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
api_key=self._api_key_for_client,
|
||||
base_url=self._effective_base,
|
||||
default_headers=self._default_headers,
|
||||
default_query=self._extra_query or None,
|
||||
max_retries=0,
|
||||
timeout=timeout_s,
|
||||
http_client=http_client,
|
||||
|
||||
@ -241,7 +241,7 @@ def test_inline_fallback_reasoning_effort_does_not_inherit_primary() -> None:
|
||||
signature = provider_signature(config)
|
||||
fallback_signatures = signature[-1]
|
||||
|
||||
assert fallback_signatures[0][12] is None
|
||||
assert fallback_signatures[0][13] is None
|
||||
|
||||
|
||||
# -- FallbackProvider tests --
|
||||
|
||||
101
tests/providers/test_extra_query_config.py
Normal file
101
tests/providers/test_extra_query_config.py
Normal file
@ -0,0 +1,101 @@
|
||||
"""Tests for provider extra_query config injection into client defaults."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from nanobot.config.schema import Config, ProviderConfig
|
||||
from nanobot.providers.factory import provider_signature
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
|
||||
class TestExtraQuerySchema:
|
||||
"""Verify ProviderConfig accepts extra_query."""
|
||||
|
||||
def test_default_is_none(self) -> None:
|
||||
config = ProviderConfig()
|
||||
assert config.extra_query is None
|
||||
|
||||
def test_accepts_dict(self) -> None:
|
||||
config = ProviderConfig(extra_query={"api-version": "2024-02-01"})
|
||||
assert config.extra_query == {"api-version": "2024-02-01"}
|
||||
|
||||
|
||||
class TestExtraQueryInit:
|
||||
"""Verify the provider stores extra_query from config."""
|
||||
|
||||
def test_default_is_empty(self) -> None:
|
||||
provider = OpenAICompatProvider(api_key="test")
|
||||
assert provider._extra_query == {}
|
||||
|
||||
def test_none_becomes_empty(self) -> None:
|
||||
provider = OpenAICompatProvider(api_key="test", extra_query=None)
|
||||
assert provider._extra_query == {}
|
||||
|
||||
def test_dict_stored(self) -> None:
|
||||
query = {"api-version": "v1"}
|
||||
provider = OpenAICompatProvider(api_key="test", extra_query=query)
|
||||
assert provider._extra_query == query
|
||||
|
||||
|
||||
class TestExtraQueryBuildClient:
|
||||
"""Verify extra_query flows into AsyncOpenAI default_query."""
|
||||
|
||||
def test_build_client_passes_default_query(self) -> None:
|
||||
mock_client = MagicMock()
|
||||
with patch(
|
||||
"nanobot.providers.openai_compat_provider.AsyncOpenAI",
|
||||
return_value=mock_client,
|
||||
) as mock_async_openai:
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="test",
|
||||
extra_query={"api-version": "v1"},
|
||||
)
|
||||
provider._build_client()
|
||||
|
||||
assert provider._client is mock_client
|
||||
assert mock_async_openai.call_args.kwargs["default_query"] == {"api-version": "v1"}
|
||||
|
||||
def test_build_client_passes_no_default_query_when_empty(self) -> None:
|
||||
mock_client = MagicMock()
|
||||
with patch(
|
||||
"nanobot.providers.openai_compat_provider.AsyncOpenAI",
|
||||
return_value=mock_client,
|
||||
) as mock_async_openai:
|
||||
provider = OpenAICompatProvider(api_key="test")
|
||||
provider._build_client()
|
||||
|
||||
assert provider._client is mock_client
|
||||
kwargs = mock_async_openai.call_args.kwargs
|
||||
assert "default_query" not in kwargs or kwargs["default_query"] is None
|
||||
|
||||
|
||||
class TestProviderSignatureIncludesExtraQuery:
|
||||
"""Verify provider_signature tracks provider extra_query changes."""
|
||||
|
||||
def test_provider_signature_tracks_extra_query(self) -> None:
|
||||
base = {
|
||||
"agents": {"defaults": {"modelPreset": "fast"}},
|
||||
"modelPresets": {
|
||||
"fast": {"model": "custom/test-model", "provider": "custom"},
|
||||
},
|
||||
"providers": {
|
||||
"custom": {
|
||||
"apiKey": "test-key",
|
||||
"extra_query": None,
|
||||
},
|
||||
},
|
||||
}
|
||||
changed_query = {
|
||||
**base,
|
||||
"providers": {
|
||||
"custom": {
|
||||
"apiKey": "test-key",
|
||||
"extra_query": {"api-version": "v1"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
signature = provider_signature(Config.model_validate(base))
|
||||
|
||||
assert signature != provider_signature(Config.model_validate(changed_query))
|
||||
Loading…
x
Reference in New Issue
Block a user