feat(providers): add extra_query config for OpenAI-compatible providers

Adds ProviderConfig.extra_query, threaded into AsyncOpenAI(default_query)
so that Azure-style gateways requiring query params like api-version can
be configured without URL hacks.

Also updates provider_signature to track extra_query changes so per-turn
refresh rebuilds the provider when the value changes.

Addresses the extra_query portion of #4204. The max_completion_tokens
model-awareness enhancement is intentionally left separate.
This commit is contained in:
axelray-dev 2026-06-06 14:54:37 +08:00 committed by Xubin Ren
parent 9c81280300
commit 28f3a20d64
5 changed files with 109 additions and 1 deletions

View File

@ -183,6 +183,7 @@ class ProviderConfig(Base):
api_type: Literal["auto", "chat_completions", "responses"] = "auto" # Request API surface
extra_headers: dict[str, str] | None = None # Custom headers (e.g. APP-Code for AiHubMix)
extra_body: dict[str, Any] | None = None # Extra provider request fields; shape depends on provider/API surface
extra_query: dict[str, str] | None = None # Extra query params (e.g. api-version for Azure-style gateways)
class BedrockProviderConfig(ProviderConfig):

View File

@ -99,6 +99,7 @@ def _make_provider_core(
spec=spec,
extra_body=p.extra_body if p else None,
api_type=p.api_type if p and provider_name == "openai" else "auto",
extra_query=p.extra_query if p else None,
)
provider.generation = resolved.to_generation_settings()
@ -185,6 +186,7 @@ def provider_signature(
fp.extra_headers if fp else None,
fp.extra_body if fp else None,
fp.api_type if fp else "auto",
fp.extra_query if fp else None,
getattr(fp, "region", None) if fp else None,
getattr(fp, "profile", None) if fp else None,
fallback.max_tokens,
@ -202,6 +204,7 @@ def provider_signature(
p.extra_headers if p else None,
p.extra_body if p else None,
p.api_type if p else "auto",
p.extra_query if p else None,
getattr(p, "region", None) if p else None,
getattr(p, "profile", None) if p else None,
resolved.max_tokens,

View File

@ -331,6 +331,7 @@ class OpenAICompatProvider(LLMProvider):
spec: ProviderSpec | None = None,
extra_body: dict[str, Any] | None = None,
api_type: str = "auto",
extra_query: dict[str, str] | None = None,
):
super().__init__(api_key, api_base)
self.default_model = default_model
@ -338,6 +339,7 @@ class OpenAICompatProvider(LLMProvider):
self._spec = spec
self._extra_body = extra_body or {}
self._api_type = api_type if spec and spec.name == "openai" else "auto"
self._extra_query = extra_query or {}
if api_key and spec and spec.env_key:
self._setup_env(api_key, api_base)
@ -386,6 +388,7 @@ class OpenAICompatProvider(LLMProvider):
api_key=self._api_key_for_client,
base_url=self._effective_base,
default_headers=self._default_headers,
default_query=self._extra_query or None,
max_retries=0,
timeout=timeout_s,
http_client=http_client,

View File

@ -241,7 +241,7 @@ def test_inline_fallback_reasoning_effort_does_not_inherit_primary() -> None:
signature = provider_signature(config)
fallback_signatures = signature[-1]
assert fallback_signatures[0][12] is None
assert fallback_signatures[0][13] is None
# -- FallbackProvider tests --

View File

@ -0,0 +1,101 @@
"""Tests for provider extra_query config injection into client defaults."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
from nanobot.config.schema import Config, ProviderConfig
from nanobot.providers.factory import provider_signature
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
class TestExtraQuerySchema:
"""Verify ProviderConfig accepts extra_query."""
def test_default_is_none(self) -> None:
config = ProviderConfig()
assert config.extra_query is None
def test_accepts_dict(self) -> None:
config = ProviderConfig(extra_query={"api-version": "2024-02-01"})
assert config.extra_query == {"api-version": "2024-02-01"}
class TestExtraQueryInit:
"""Verify the provider stores extra_query from config."""
def test_default_is_empty(self) -> None:
provider = OpenAICompatProvider(api_key="test")
assert provider._extra_query == {}
def test_none_becomes_empty(self) -> None:
provider = OpenAICompatProvider(api_key="test", extra_query=None)
assert provider._extra_query == {}
def test_dict_stored(self) -> None:
query = {"api-version": "v1"}
provider = OpenAICompatProvider(api_key="test", extra_query=query)
assert provider._extra_query == query
class TestExtraQueryBuildClient:
"""Verify extra_query flows into AsyncOpenAI default_query."""
def test_build_client_passes_default_query(self) -> None:
mock_client = MagicMock()
with patch(
"nanobot.providers.openai_compat_provider.AsyncOpenAI",
return_value=mock_client,
) as mock_async_openai:
provider = OpenAICompatProvider(
api_key="test",
extra_query={"api-version": "v1"},
)
provider._build_client()
assert provider._client is mock_client
assert mock_async_openai.call_args.kwargs["default_query"] == {"api-version": "v1"}
def test_build_client_passes_no_default_query_when_empty(self) -> None:
mock_client = MagicMock()
with patch(
"nanobot.providers.openai_compat_provider.AsyncOpenAI",
return_value=mock_client,
) as mock_async_openai:
provider = OpenAICompatProvider(api_key="test")
provider._build_client()
assert provider._client is mock_client
kwargs = mock_async_openai.call_args.kwargs
assert "default_query" not in kwargs or kwargs["default_query"] is None
class TestProviderSignatureIncludesExtraQuery:
"""Verify provider_signature tracks provider extra_query changes."""
def test_provider_signature_tracks_extra_query(self) -> None:
base = {
"agents": {"defaults": {"modelPreset": "fast"}},
"modelPresets": {
"fast": {"model": "custom/test-model", "provider": "custom"},
},
"providers": {
"custom": {
"apiKey": "test-key",
"extra_query": None,
},
},
}
changed_query = {
**base,
"providers": {
"custom": {
"apiKey": "test-key",
"extra_query": {"api-version": "v1"},
},
},
}
signature = provider_signature(Config.model_validate(base))
assert signature != provider_signature(Config.model_validate(changed_query))