feat(providers): add extra_query config for OpenAI-compatible providers

Adds ProviderConfig.extra_query, threaded into AsyncOpenAI(default_query)
so that Azure-style gateways requiring query params like api-version can
be configured without URL hacks.

Also updates provider_signature to track extra_query changes so per-turn
refresh rebuilds the provider when the value changes.

Addresses the extra_query portion of #4204. The max_completion_tokens
model-awareness enhancement is intentionally left separate.
This commit is contained in:
axelray-dev 2026-06-06 14:54:37 +08:00 committed by Xubin Ren
parent 9c81280300
commit 28f3a20d64
5 changed files with 109 additions and 1 deletions

View File

@ -183,6 +183,7 @@ class ProviderConfig(Base):
api_type: Literal["auto", "chat_completions", "responses"] = "auto" # Request API surface api_type: Literal["auto", "chat_completions", "responses"] = "auto" # Request API surface
extra_headers: dict[str, str] | None = None # Custom headers (e.g. APP-Code for AiHubMix) extra_headers: dict[str, str] | None = None # Custom headers (e.g. APP-Code for AiHubMix)
extra_body: dict[str, Any] | None = None # Extra provider request fields; shape depends on provider/API surface extra_body: dict[str, Any] | None = None # Extra provider request fields; shape depends on provider/API surface
extra_query: dict[str, str] | None = None # Extra query params (e.g. api-version for Azure-style gateways)
class BedrockProviderConfig(ProviderConfig): class BedrockProviderConfig(ProviderConfig):

View File

@ -99,6 +99,7 @@ def _make_provider_core(
spec=spec, spec=spec,
extra_body=p.extra_body if p else None, extra_body=p.extra_body if p else None,
api_type=p.api_type if p and provider_name == "openai" else "auto", api_type=p.api_type if p and provider_name == "openai" else "auto",
extra_query=p.extra_query if p else None,
) )
provider.generation = resolved.to_generation_settings() provider.generation = resolved.to_generation_settings()
@ -185,6 +186,7 @@ def provider_signature(
fp.extra_headers if fp else None, fp.extra_headers if fp else None,
fp.extra_body if fp else None, fp.extra_body if fp else None,
fp.api_type if fp else "auto", fp.api_type if fp else "auto",
fp.extra_query if fp else None,
getattr(fp, "region", None) if fp else None, getattr(fp, "region", None) if fp else None,
getattr(fp, "profile", None) if fp else None, getattr(fp, "profile", None) if fp else None,
fallback.max_tokens, fallback.max_tokens,
@ -202,6 +204,7 @@ def provider_signature(
p.extra_headers if p else None, p.extra_headers if p else None,
p.extra_body if p else None, p.extra_body if p else None,
p.api_type if p else "auto", p.api_type if p else "auto",
p.extra_query if p else None,
getattr(p, "region", None) if p else None, getattr(p, "region", None) if p else None,
getattr(p, "profile", None) if p else None, getattr(p, "profile", None) if p else None,
resolved.max_tokens, resolved.max_tokens,

View File

@ -331,6 +331,7 @@ class OpenAICompatProvider(LLMProvider):
spec: ProviderSpec | None = None, spec: ProviderSpec | None = None,
extra_body: dict[str, Any] | None = None, extra_body: dict[str, Any] | None = None,
api_type: str = "auto", api_type: str = "auto",
extra_query: dict[str, str] | None = None,
): ):
super().__init__(api_key, api_base) super().__init__(api_key, api_base)
self.default_model = default_model self.default_model = default_model
@ -338,6 +339,7 @@ class OpenAICompatProvider(LLMProvider):
self._spec = spec self._spec = spec
self._extra_body = extra_body or {} self._extra_body = extra_body or {}
self._api_type = api_type if spec and spec.name == "openai" else "auto" self._api_type = api_type if spec and spec.name == "openai" else "auto"
self._extra_query = extra_query or {}
if api_key and spec and spec.env_key: if api_key and spec and spec.env_key:
self._setup_env(api_key, api_base) self._setup_env(api_key, api_base)
@ -386,6 +388,7 @@ class OpenAICompatProvider(LLMProvider):
api_key=self._api_key_for_client, api_key=self._api_key_for_client,
base_url=self._effective_base, base_url=self._effective_base,
default_headers=self._default_headers, default_headers=self._default_headers,
default_query=self._extra_query or None,
max_retries=0, max_retries=0,
timeout=timeout_s, timeout=timeout_s,
http_client=http_client, http_client=http_client,

View File

@ -241,7 +241,7 @@ def test_inline_fallback_reasoning_effort_does_not_inherit_primary() -> None:
signature = provider_signature(config) signature = provider_signature(config)
fallback_signatures = signature[-1] fallback_signatures = signature[-1]
assert fallback_signatures[0][12] is None assert fallback_signatures[0][13] is None
# -- FallbackProvider tests -- # -- FallbackProvider tests --

View File

@ -0,0 +1,101 @@
"""Tests for provider extra_query config injection into client defaults."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
from nanobot.config.schema import Config, ProviderConfig
from nanobot.providers.factory import provider_signature
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
class TestExtraQuerySchema:
"""Verify ProviderConfig accepts extra_query."""
def test_default_is_none(self) -> None:
config = ProviderConfig()
assert config.extra_query is None
def test_accepts_dict(self) -> None:
config = ProviderConfig(extra_query={"api-version": "2024-02-01"})
assert config.extra_query == {"api-version": "2024-02-01"}
class TestExtraQueryInit:
"""Verify the provider stores extra_query from config."""
def test_default_is_empty(self) -> None:
provider = OpenAICompatProvider(api_key="test")
assert provider._extra_query == {}
def test_none_becomes_empty(self) -> None:
provider = OpenAICompatProvider(api_key="test", extra_query=None)
assert provider._extra_query == {}
def test_dict_stored(self) -> None:
query = {"api-version": "v1"}
provider = OpenAICompatProvider(api_key="test", extra_query=query)
assert provider._extra_query == query
class TestExtraQueryBuildClient:
"""Verify extra_query flows into AsyncOpenAI default_query."""
def test_build_client_passes_default_query(self) -> None:
mock_client = MagicMock()
with patch(
"nanobot.providers.openai_compat_provider.AsyncOpenAI",
return_value=mock_client,
) as mock_async_openai:
provider = OpenAICompatProvider(
api_key="test",
extra_query={"api-version": "v1"},
)
provider._build_client()
assert provider._client is mock_client
assert mock_async_openai.call_args.kwargs["default_query"] == {"api-version": "v1"}
def test_build_client_passes_no_default_query_when_empty(self) -> None:
mock_client = MagicMock()
with patch(
"nanobot.providers.openai_compat_provider.AsyncOpenAI",
return_value=mock_client,
) as mock_async_openai:
provider = OpenAICompatProvider(api_key="test")
provider._build_client()
assert provider._client is mock_client
kwargs = mock_async_openai.call_args.kwargs
assert "default_query" not in kwargs or kwargs["default_query"] is None
class TestProviderSignatureIncludesExtraQuery:
"""Verify provider_signature tracks provider extra_query changes."""
def test_provider_signature_tracks_extra_query(self) -> None:
base = {
"agents": {"defaults": {"modelPreset": "fast"}},
"modelPresets": {
"fast": {"model": "custom/test-model", "provider": "custom"},
},
"providers": {
"custom": {
"apiKey": "test-key",
"extra_query": None,
},
},
}
changed_query = {
**base,
"providers": {
"custom": {
"apiKey": "test-key",
"extra_query": {"api-version": "v1"},
},
},
}
signature = provider_signature(Config.model_validate(base))
assert signature != provider_signature(Config.model_validate(changed_query))