From a7b34422f3c61c78e66cda21c5c520298808af57 Mon Sep 17 00:00:00 2001 From: Haisam Abbas Date: Wed, 20 May 2026 14:06:55 +0500 Subject: [PATCH] fix Gemini image base and provider docs --- nanobot/providers/image_generation.py | 25 ++++++++++++++---------- tests/providers/test_image_generation.py | 5 +++++ tests/providers/test_litellm_kwargs.py | 6 ++++++ 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/nanobot/providers/image_generation.py b/nanobot/providers/image_generation.py index b267f5c00..8f25195cf 100644 --- a/nanobot/providers/image_generation.py +++ b/nanobot/providers/image_generation.py @@ -129,6 +129,11 @@ _IMAGE_GEN_PROVIDERS: dict[str, type[ImageGenerationProvider]] = {} def register_image_gen_provider(cls: type[ImageGenerationProvider]) -> None: + """Register an image provider at import time only. + + The registry is populated by module side effects so provider discovery + stays lazy and consistent across the process. + """ name = cls.provider_name if not name: raise ValueError(f"{cls.__name__} must set provider_name") @@ -204,7 +209,7 @@ class ImageGenerationProvider(ABC): image_size: str | None = None, ) -> GeneratedImageResponse: ... - def _require_images(self, images: list[str], data: dict[str, Any]) -> None: + def _ensure_images(self, images: list[str], data: dict[str, Any]) -> None: if images: return provider_error = data.get("error") if isinstance(data, dict) else None @@ -311,7 +316,7 @@ class OpenRouterImageGenerationClient(ImageGenerationProvider): if isinstance(url_value, str) and url_value.startswith("data:image/"): images.append(url_value) - self._require_images(images, data) + self._ensure_images(images, data) return GeneratedImageResponse( images=images, @@ -413,7 +418,7 @@ class AIHubMixImageGenerationClient(ImageGenerationProvider): payload = response.json() images = await _aihubmix_images_from_payload(client, payload) - self._require_images(images, payload) + self._ensure_images(images, payload) return GeneratedImageResponse(images=images, content="", raw=payload) @@ -446,9 +451,9 @@ class GeminiImageGenerationClient(ImageGenerationProvider): return "https://generativelanguage.googleapis.com/v1beta" def _resolve_base_url(self, api_base: str | None) -> str: - # The Gemini provider's registry default_api_base is the OpenAI-compat - # shim (.../v1beta/openai/), which has no image endpoints. - # Skip the registry lookup and use the native API base directly. + # Gemini chat completions use the registry's OpenAI-compatible shim. + # Image generation must hit the native Generative Language API, so we + # intentionally bypass the shared registry lookup here. if api_base: return api_base.rstrip("/") return self._default_base_url() @@ -522,7 +527,7 @@ class GeminiImageGenerationClient(ImageGenerationProvider): if isinstance(b64, str) and b64: images.append(f"data:{mime};base64,{b64}") - self._require_images(images, data) + self._ensure_images(images, data) return GeneratedImageResponse(images=images, content="", raw=data) @@ -580,7 +585,7 @@ class GeminiImageGenerationClient(ImageGenerationProvider): if b64: images.append(f"data:{mime};base64,{b64}") - self._require_images(images, data) + self._ensure_images(images, data) return GeneratedImageResponse( images=images, @@ -734,7 +739,7 @@ class MiniMaxImageGenerationClient(ImageGenerationProvider): payload = response.json() images = _minimax_images_from_payload(payload) - self._require_images(images, payload) + self._ensure_images(images, payload) return GeneratedImageResponse(images=images, content="", raw=payload) @@ -840,7 +845,7 @@ class StepFunImageGenerationClient(ImageGenerationProvider): payload = response.json() images = _stepfun_images_from_payload(payload) - self._require_images(images, payload) + self._ensure_images(images, payload) return GeneratedImageResponse(images=images, content="", raw=payload) diff --git a/tests/providers/test_image_generation.py b/tests/providers/test_image_generation.py index eea3a3fe6..c42d947d5 100644 --- a/tests/providers/test_image_generation.py +++ b/tests/providers/test_image_generation.py @@ -348,6 +348,11 @@ async def test_gemini_requires_api_key() -> None: await client.generate(prompt="draw", model="imagen-4.0-generate-001") +def test_gemini_image_client_uses_native_api_base_by_default() -> None: + client = GeminiImageGenerationClient(api_key="AIza-test") + assert client.api_base == "https://generativelanguage.googleapis.com/v1beta" + + @pytest.mark.asyncio async def test_gemini_no_images_raises() -> None: fake = FakeClient(FakeResponse({"candidates": [{"content": {"parts": [{"text": "sorry"}]}}]})) diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py index 5f2ffec59..76414ad35 100644 --- a/tests/providers/test_litellm_kwargs.py +++ b/tests/providers/test_litellm_kwargs.py @@ -449,6 +449,12 @@ def test_gemma_routes_to_gemini_provider() -> None: assert "gemma" in spec.keywords +def test_gemini_spec_keeps_openai_compat_base() -> None: + spec = find_by_name("gemini") + assert spec is not None + assert spec.default_api_base == "https://generativelanguage.googleapis.com/v1beta/openai/" + + async def test_openrouter_sets_default_attribution_headers() -> None: spec = find_by_name("openrouter") with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_client_cls: