diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 1ca13c4f2..0a19fbfd4 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -183,6 +183,7 @@ class ProviderConfig(Base): api_type: Literal["auto", "chat_completions", "responses"] = "auto" # Request API surface extra_headers: dict[str, str] | None = None # Custom headers (e.g. APP-Code for AiHubMix) extra_body: dict[str, Any] | None = None # Extra provider request fields; shape depends on provider/API surface + extra_query: dict[str, str] | None = None # Extra query params (e.g. api-version for Azure-style gateways) class BedrockProviderConfig(ProviderConfig): diff --git a/nanobot/providers/factory.py b/nanobot/providers/factory.py index e8275f93a..2e6b68c7d 100644 --- a/nanobot/providers/factory.py +++ b/nanobot/providers/factory.py @@ -99,6 +99,7 @@ def _make_provider_core( spec=spec, extra_body=p.extra_body if p else None, api_type=p.api_type if p and provider_name == "openai" else "auto", + extra_query=p.extra_query if p else None, ) provider.generation = resolved.to_generation_settings() @@ -185,6 +186,7 @@ def provider_signature( fp.extra_headers if fp else None, fp.extra_body if fp else None, fp.api_type if fp else "auto", + fp.extra_query if fp else None, getattr(fp, "region", None) if fp else None, getattr(fp, "profile", None) if fp else None, fallback.max_tokens, @@ -202,6 +204,7 @@ def provider_signature( p.extra_headers if p else None, p.extra_body if p else None, p.api_type if p else "auto", + p.extra_query if p else None, getattr(p, "region", None) if p else None, getattr(p, "profile", None) if p else None, resolved.max_tokens, diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index 6fe00b327..a0eb35176 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -331,6 +331,7 @@ class OpenAICompatProvider(LLMProvider): spec: ProviderSpec | None = None, extra_body: dict[str, Any] | None = None, api_type: str = "auto", + extra_query: dict[str, str] | None = None, ): super().__init__(api_key, api_base) self.default_model = default_model @@ -338,6 +339,7 @@ class OpenAICompatProvider(LLMProvider): self._spec = spec self._extra_body = extra_body or {} self._api_type = api_type if spec and spec.name == "openai" else "auto" + self._extra_query = extra_query or {} if api_key and spec and spec.env_key: self._setup_env(api_key, api_base) @@ -386,6 +388,7 @@ class OpenAICompatProvider(LLMProvider): api_key=self._api_key_for_client, base_url=self._effective_base, default_headers=self._default_headers, + default_query=self._extra_query or None, max_retries=0, timeout=timeout_s, http_client=http_client, diff --git a/tests/agent/test_runner_fallback.py b/tests/agent/test_runner_fallback.py index 4ae161e4a..a7a6f7c30 100644 --- a/tests/agent/test_runner_fallback.py +++ b/tests/agent/test_runner_fallback.py @@ -241,7 +241,7 @@ def test_inline_fallback_reasoning_effort_does_not_inherit_primary() -> None: signature = provider_signature(config) fallback_signatures = signature[-1] - assert fallback_signatures[0][12] is None + assert fallback_signatures[0][13] is None # -- FallbackProvider tests -- diff --git a/tests/providers/test_extra_query_config.py b/tests/providers/test_extra_query_config.py new file mode 100644 index 000000000..79e985261 --- /dev/null +++ b/tests/providers/test_extra_query_config.py @@ -0,0 +1,101 @@ +"""Tests for provider extra_query config injection into client defaults.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from nanobot.config.schema import Config, ProviderConfig +from nanobot.providers.factory import provider_signature +from nanobot.providers.openai_compat_provider import OpenAICompatProvider + + +class TestExtraQuerySchema: + """Verify ProviderConfig accepts extra_query.""" + + def test_default_is_none(self) -> None: + config = ProviderConfig() + assert config.extra_query is None + + def test_accepts_dict(self) -> None: + config = ProviderConfig(extra_query={"api-version": "2024-02-01"}) + assert config.extra_query == {"api-version": "2024-02-01"} + + +class TestExtraQueryInit: + """Verify the provider stores extra_query from config.""" + + def test_default_is_empty(self) -> None: + provider = OpenAICompatProvider(api_key="test") + assert provider._extra_query == {} + + def test_none_becomes_empty(self) -> None: + provider = OpenAICompatProvider(api_key="test", extra_query=None) + assert provider._extra_query == {} + + def test_dict_stored(self) -> None: + query = {"api-version": "v1"} + provider = OpenAICompatProvider(api_key="test", extra_query=query) + assert provider._extra_query == query + + +class TestExtraQueryBuildClient: + """Verify extra_query flows into AsyncOpenAI default_query.""" + + def test_build_client_passes_default_query(self) -> None: + mock_client = MagicMock() + with patch( + "nanobot.providers.openai_compat_provider.AsyncOpenAI", + return_value=mock_client, + ) as mock_async_openai: + provider = OpenAICompatProvider( + api_key="test", + extra_query={"api-version": "v1"}, + ) + provider._build_client() + + assert provider._client is mock_client + assert mock_async_openai.call_args.kwargs["default_query"] == {"api-version": "v1"} + + def test_build_client_passes_no_default_query_when_empty(self) -> None: + mock_client = MagicMock() + with patch( + "nanobot.providers.openai_compat_provider.AsyncOpenAI", + return_value=mock_client, + ) as mock_async_openai: + provider = OpenAICompatProvider(api_key="test") + provider._build_client() + + assert provider._client is mock_client + kwargs = mock_async_openai.call_args.kwargs + assert "default_query" not in kwargs or kwargs["default_query"] is None + + +class TestProviderSignatureIncludesExtraQuery: + """Verify provider_signature tracks provider extra_query changes.""" + + def test_provider_signature_tracks_extra_query(self) -> None: + base = { + "agents": {"defaults": {"modelPreset": "fast"}}, + "modelPresets": { + "fast": {"model": "custom/test-model", "provider": "custom"}, + }, + "providers": { + "custom": { + "apiKey": "test-key", + "extra_query": None, + }, + }, + } + changed_query = { + **base, + "providers": { + "custom": { + "apiKey": "test-key", + "extra_query": {"api-version": "v1"}, + }, + }, + } + + signature = provider_signature(Config.model_validate(base)) + + assert signature != provider_signature(Config.model_validate(changed_query))