diff --git a/nanobot/providers/image_generation.py b/nanobot/providers/image_generation.py index 6cab279b1..1316f1d43 100644 --- a/nanobot/providers/image_generation.py +++ b/nanobot/providers/image_generation.py @@ -139,6 +139,11 @@ _IMAGE_GEN_PROVIDERS: dict[str, type[ImageGenerationProvider]] = {} def register_image_gen_provider(cls: type[ImageGenerationProvider]) -> None: + """Register an image provider at import time only. + + The registry is populated by module side effects so provider discovery + stays lazy and consistent across the process. + """ name = cls.provider_name if not name: raise ValueError(f"{cls.__name__} must set provider_name") @@ -229,7 +234,10 @@ class ImageGenerationProvider(ABC): *, headers: dict[str, str], body: dict[str, Any], + client: httpx.AsyncClient | None = None, ) -> httpx.Response: + if client is not None: + return await client.post(url, headers=headers, json=body) if self._client is not None: return await self._client.post(url, headers=headers, json=body) async with httpx.AsyncClient(timeout=self.timeout) as c: @@ -400,10 +408,11 @@ class AIHubMixImageGenerationClient(ImageGenerationProvider): model_path = _aihubmix_model_path(model) url = f"{self.api_base}/models/{model_path}/predictions" try: - response = await client.post( + response = await self._http_post( url, headers={**headers, "Content-Type": "application/json"}, - json=body, + body=body, + client=client, ) except httpx.TimeoutException as exc: raise ImageGenerationError("AIHubMix image generation timed out") from exc @@ -585,9 +594,9 @@ class GeminiImageGenerationClient(ImageGenerationProvider): return "https://generativelanguage.googleapis.com/v1beta" def _resolve_base_url(self, api_base: str | None) -> str: - # The Gemini provider's registry default_api_base is the OpenAI-compat - # shim (.../v1beta/openai/), which has no image endpoints. - # Skip the registry lookup and use the native API base directly. + # Gemini chat completions use the registry's OpenAI-compatible shim. + # Image generation must hit the native Generative Language API, so we + # intentionally bypass the shared registry lookup here. if api_base: return api_base.rstrip("/") return self._default_base_url() @@ -849,22 +858,16 @@ class MiniMaxImageGenerationClient(ImageGenerationProvider): body.update(self.extra_body) - client = self._client or httpx.AsyncClient(timeout=self.timeout) - try: - return await self._generate_with_client(client, body, headers) - finally: - if self._client is None: - await client.aclose() + return await self._generate_with_client(body, headers) async def _generate_with_client( self, - client: httpx.AsyncClient, body: dict[str, Any], headers: dict[str, str], ) -> GeneratedImageResponse: url = f"{self.api_base}/image_generation" try: - response = await client.post(url, headers=headers, json=body) + response = await self._http_post(url, headers=headers, body=body) except httpx.TimeoutException as exc: raise ImageGenerationError("MiniMax image generation timed out") from exc except httpx.RequestError as exc: diff --git a/tests/providers/test_image_generation.py b/tests/providers/test_image_generation.py index f3ca1459c..77025895c 100644 --- a/tests/providers/test_image_generation.py +++ b/tests/providers/test_image_generation.py @@ -410,6 +410,11 @@ async def test_gemini_requires_api_key() -> None: await client.generate(prompt="draw", model="imagen-4.0-generate-001") +def test_gemini_image_client_uses_native_api_base_by_default() -> None: + client = GeminiImageGenerationClient(api_key="AIza-test") + assert client.api_base == "https://generativelanguage.googleapis.com/v1beta" + + @pytest.mark.asyncio async def test_gemini_no_images_raises() -> None: fake = FakeClient(FakeResponse({"candidates": [{"content": {"parts": [{"text": "sorry"}]}}]})) @@ -452,6 +457,17 @@ async def test_minimax_payload_and_response_with_reference_image(tmp_path: Path) assert body["subject_reference"][0]["image_file"].startswith("data:image/png;base64,") +@pytest.mark.asyncio +async def test_minimax_base64_response_uses_detected_mime() -> None: + raw_b64 = base64.b64encode(JPEG_BYTES).decode("ascii") + fake = FakeClient(FakeResponse({"data": {"image_base64": [raw_b64]}})) + client = MiniMaxImageGenerationClient(api_key="sk-mm-test", client=fake) # type: ignore[arg-type] + + response = await client.generate(prompt="draw", model="image-01") + + assert response.images == [f"data:image/jpeg;base64,{raw_b64}"] + + # --------------------------------------------------------------------------- # StepFun (阶跃星辰) # --------------------------------------------------------------------------- diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py index 924ee0060..6e00cba19 100644 --- a/tests/providers/test_litellm_kwargs.py +++ b/tests/providers/test_litellm_kwargs.py @@ -458,6 +458,12 @@ def test_gemma_routes_to_gemini_provider() -> None: assert "gemma" in spec.keywords +def test_gemini_spec_keeps_openai_compat_base() -> None: + spec = find_by_name("gemini") + assert spec is not None + assert spec.default_api_base == "https://generativelanguage.googleapis.com/v1beta/openai/" + + async def test_openrouter_sets_default_attribution_headers() -> None: spec = find_by_name("openrouter") with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_client_cls: