mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 15:24:06 +00:00
feat(providers): add extra_query config for OpenAI-compatible providers
Adds ProviderConfig.extra_query, threaded into AsyncOpenAI(default_query) so that Azure-style gateways requiring query params like api-version can be configured without URL hacks. Also updates provider_signature to track extra_query changes so per-turn refresh rebuilds the provider when the value changes. Addresses the extra_query portion of #4204. The max_completion_tokens model-awareness enhancement is intentionally left separate.
This commit is contained in:
parent
9c81280300
commit
28f3a20d64
@ -183,6 +183,7 @@ class ProviderConfig(Base):
|
|||||||
api_type: Literal["auto", "chat_completions", "responses"] = "auto" # Request API surface
|
api_type: Literal["auto", "chat_completions", "responses"] = "auto" # Request API surface
|
||||||
extra_headers: dict[str, str] | None = None # Custom headers (e.g. APP-Code for AiHubMix)
|
extra_headers: dict[str, str] | None = None # Custom headers (e.g. APP-Code for AiHubMix)
|
||||||
extra_body: dict[str, Any] | None = None # Extra provider request fields; shape depends on provider/API surface
|
extra_body: dict[str, Any] | None = None # Extra provider request fields; shape depends on provider/API surface
|
||||||
|
extra_query: dict[str, str] | None = None # Extra query params (e.g. api-version for Azure-style gateways)
|
||||||
|
|
||||||
|
|
||||||
class BedrockProviderConfig(ProviderConfig):
|
class BedrockProviderConfig(ProviderConfig):
|
||||||
|
|||||||
@ -99,6 +99,7 @@ def _make_provider_core(
|
|||||||
spec=spec,
|
spec=spec,
|
||||||
extra_body=p.extra_body if p else None,
|
extra_body=p.extra_body if p else None,
|
||||||
api_type=p.api_type if p and provider_name == "openai" else "auto",
|
api_type=p.api_type if p and provider_name == "openai" else "auto",
|
||||||
|
extra_query=p.extra_query if p else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
provider.generation = resolved.to_generation_settings()
|
provider.generation = resolved.to_generation_settings()
|
||||||
@ -185,6 +186,7 @@ def provider_signature(
|
|||||||
fp.extra_headers if fp else None,
|
fp.extra_headers if fp else None,
|
||||||
fp.extra_body if fp else None,
|
fp.extra_body if fp else None,
|
||||||
fp.api_type if fp else "auto",
|
fp.api_type if fp else "auto",
|
||||||
|
fp.extra_query if fp else None,
|
||||||
getattr(fp, "region", None) if fp else None,
|
getattr(fp, "region", None) if fp else None,
|
||||||
getattr(fp, "profile", None) if fp else None,
|
getattr(fp, "profile", None) if fp else None,
|
||||||
fallback.max_tokens,
|
fallback.max_tokens,
|
||||||
@ -202,6 +204,7 @@ def provider_signature(
|
|||||||
p.extra_headers if p else None,
|
p.extra_headers if p else None,
|
||||||
p.extra_body if p else None,
|
p.extra_body if p else None,
|
||||||
p.api_type if p else "auto",
|
p.api_type if p else "auto",
|
||||||
|
p.extra_query if p else None,
|
||||||
getattr(p, "region", None) if p else None,
|
getattr(p, "region", None) if p else None,
|
||||||
getattr(p, "profile", None) if p else None,
|
getattr(p, "profile", None) if p else None,
|
||||||
resolved.max_tokens,
|
resolved.max_tokens,
|
||||||
|
|||||||
@ -331,6 +331,7 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
spec: ProviderSpec | None = None,
|
spec: ProviderSpec | None = None,
|
||||||
extra_body: dict[str, Any] | None = None,
|
extra_body: dict[str, Any] | None = None,
|
||||||
api_type: str = "auto",
|
api_type: str = "auto",
|
||||||
|
extra_query: dict[str, str] | None = None,
|
||||||
):
|
):
|
||||||
super().__init__(api_key, api_base)
|
super().__init__(api_key, api_base)
|
||||||
self.default_model = default_model
|
self.default_model = default_model
|
||||||
@ -338,6 +339,7 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
self._spec = spec
|
self._spec = spec
|
||||||
self._extra_body = extra_body or {}
|
self._extra_body = extra_body or {}
|
||||||
self._api_type = api_type if spec and spec.name == "openai" else "auto"
|
self._api_type = api_type if spec and spec.name == "openai" else "auto"
|
||||||
|
self._extra_query = extra_query or {}
|
||||||
|
|
||||||
if api_key and spec and spec.env_key:
|
if api_key and spec and spec.env_key:
|
||||||
self._setup_env(api_key, api_base)
|
self._setup_env(api_key, api_base)
|
||||||
@ -386,6 +388,7 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
api_key=self._api_key_for_client,
|
api_key=self._api_key_for_client,
|
||||||
base_url=self._effective_base,
|
base_url=self._effective_base,
|
||||||
default_headers=self._default_headers,
|
default_headers=self._default_headers,
|
||||||
|
default_query=self._extra_query or None,
|
||||||
max_retries=0,
|
max_retries=0,
|
||||||
timeout=timeout_s,
|
timeout=timeout_s,
|
||||||
http_client=http_client,
|
http_client=http_client,
|
||||||
|
|||||||
@ -241,7 +241,7 @@ def test_inline_fallback_reasoning_effort_does_not_inherit_primary() -> None:
|
|||||||
signature = provider_signature(config)
|
signature = provider_signature(config)
|
||||||
fallback_signatures = signature[-1]
|
fallback_signatures = signature[-1]
|
||||||
|
|
||||||
assert fallback_signatures[0][12] is None
|
assert fallback_signatures[0][13] is None
|
||||||
|
|
||||||
|
|
||||||
# -- FallbackProvider tests --
|
# -- FallbackProvider tests --
|
||||||
|
|||||||
101
tests/providers/test_extra_query_config.py
Normal file
101
tests/providers/test_extra_query_config.py
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
"""Tests for provider extra_query config injection into client defaults."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from nanobot.config.schema import Config, ProviderConfig
|
||||||
|
from nanobot.providers.factory import provider_signature
|
||||||
|
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtraQuerySchema:
|
||||||
|
"""Verify ProviderConfig accepts extra_query."""
|
||||||
|
|
||||||
|
def test_default_is_none(self) -> None:
|
||||||
|
config = ProviderConfig()
|
||||||
|
assert config.extra_query is None
|
||||||
|
|
||||||
|
def test_accepts_dict(self) -> None:
|
||||||
|
config = ProviderConfig(extra_query={"api-version": "2024-02-01"})
|
||||||
|
assert config.extra_query == {"api-version": "2024-02-01"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtraQueryInit:
|
||||||
|
"""Verify the provider stores extra_query from config."""
|
||||||
|
|
||||||
|
def test_default_is_empty(self) -> None:
|
||||||
|
provider = OpenAICompatProvider(api_key="test")
|
||||||
|
assert provider._extra_query == {}
|
||||||
|
|
||||||
|
def test_none_becomes_empty(self) -> None:
|
||||||
|
provider = OpenAICompatProvider(api_key="test", extra_query=None)
|
||||||
|
assert provider._extra_query == {}
|
||||||
|
|
||||||
|
def test_dict_stored(self) -> None:
|
||||||
|
query = {"api-version": "v1"}
|
||||||
|
provider = OpenAICompatProvider(api_key="test", extra_query=query)
|
||||||
|
assert provider._extra_query == query
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtraQueryBuildClient:
|
||||||
|
"""Verify extra_query flows into AsyncOpenAI default_query."""
|
||||||
|
|
||||||
|
def test_build_client_passes_default_query(self) -> None:
|
||||||
|
mock_client = MagicMock()
|
||||||
|
with patch(
|
||||||
|
"nanobot.providers.openai_compat_provider.AsyncOpenAI",
|
||||||
|
return_value=mock_client,
|
||||||
|
) as mock_async_openai:
|
||||||
|
provider = OpenAICompatProvider(
|
||||||
|
api_key="test",
|
||||||
|
extra_query={"api-version": "v1"},
|
||||||
|
)
|
||||||
|
provider._build_client()
|
||||||
|
|
||||||
|
assert provider._client is mock_client
|
||||||
|
assert mock_async_openai.call_args.kwargs["default_query"] == {"api-version": "v1"}
|
||||||
|
|
||||||
|
def test_build_client_passes_no_default_query_when_empty(self) -> None:
|
||||||
|
mock_client = MagicMock()
|
||||||
|
with patch(
|
||||||
|
"nanobot.providers.openai_compat_provider.AsyncOpenAI",
|
||||||
|
return_value=mock_client,
|
||||||
|
) as mock_async_openai:
|
||||||
|
provider = OpenAICompatProvider(api_key="test")
|
||||||
|
provider._build_client()
|
||||||
|
|
||||||
|
assert provider._client is mock_client
|
||||||
|
kwargs = mock_async_openai.call_args.kwargs
|
||||||
|
assert "default_query" not in kwargs or kwargs["default_query"] is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestProviderSignatureIncludesExtraQuery:
|
||||||
|
"""Verify provider_signature tracks provider extra_query changes."""
|
||||||
|
|
||||||
|
def test_provider_signature_tracks_extra_query(self) -> None:
|
||||||
|
base = {
|
||||||
|
"agents": {"defaults": {"modelPreset": "fast"}},
|
||||||
|
"modelPresets": {
|
||||||
|
"fast": {"model": "custom/test-model", "provider": "custom"},
|
||||||
|
},
|
||||||
|
"providers": {
|
||||||
|
"custom": {
|
||||||
|
"apiKey": "test-key",
|
||||||
|
"extra_query": None,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
changed_query = {
|
||||||
|
**base,
|
||||||
|
"providers": {
|
||||||
|
"custom": {
|
||||||
|
"apiKey": "test-key",
|
||||||
|
"extra_query": {"api-version": "v1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
signature = provider_signature(Config.model_validate(base))
|
||||||
|
|
||||||
|
assert signature != provider_signature(Config.model_validate(changed_query))
|
||||||
Loading…
x
Reference in New Issue
Block a user