mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-24 18:42:35 +00:00
fix Gemini image base and provider docs
This commit is contained in:
parent
72f999f8f7
commit
a7b34422f3
@ -129,6 +129,11 @@ _IMAGE_GEN_PROVIDERS: dict[str, type[ImageGenerationProvider]] = {}
|
|||||||
|
|
||||||
|
|
||||||
def register_image_gen_provider(cls: type[ImageGenerationProvider]) -> None:
|
def register_image_gen_provider(cls: type[ImageGenerationProvider]) -> None:
|
||||||
|
"""Register an image provider at import time only.
|
||||||
|
|
||||||
|
The registry is populated by module side effects so provider discovery
|
||||||
|
stays lazy and consistent across the process.
|
||||||
|
"""
|
||||||
name = cls.provider_name
|
name = cls.provider_name
|
||||||
if not name:
|
if not name:
|
||||||
raise ValueError(f"{cls.__name__} must set provider_name")
|
raise ValueError(f"{cls.__name__} must set provider_name")
|
||||||
@ -204,7 +209,7 @@ class ImageGenerationProvider(ABC):
|
|||||||
image_size: str | None = None,
|
image_size: str | None = None,
|
||||||
) -> GeneratedImageResponse: ...
|
) -> GeneratedImageResponse: ...
|
||||||
|
|
||||||
def _require_images(self, images: list[str], data: dict[str, Any]) -> None:
|
def _ensure_images(self, images: list[str], data: dict[str, Any]) -> None:
|
||||||
if images:
|
if images:
|
||||||
return
|
return
|
||||||
provider_error = data.get("error") if isinstance(data, dict) else None
|
provider_error = data.get("error") if isinstance(data, dict) else None
|
||||||
@ -311,7 +316,7 @@ class OpenRouterImageGenerationClient(ImageGenerationProvider):
|
|||||||
if isinstance(url_value, str) and url_value.startswith("data:image/"):
|
if isinstance(url_value, str) and url_value.startswith("data:image/"):
|
||||||
images.append(url_value)
|
images.append(url_value)
|
||||||
|
|
||||||
self._require_images(images, data)
|
self._ensure_images(images, data)
|
||||||
|
|
||||||
return GeneratedImageResponse(
|
return GeneratedImageResponse(
|
||||||
images=images,
|
images=images,
|
||||||
@ -413,7 +418,7 @@ class AIHubMixImageGenerationClient(ImageGenerationProvider):
|
|||||||
payload = response.json()
|
payload = response.json()
|
||||||
images = await _aihubmix_images_from_payload(client, payload)
|
images = await _aihubmix_images_from_payload(client, payload)
|
||||||
|
|
||||||
self._require_images(images, payload)
|
self._ensure_images(images, payload)
|
||||||
|
|
||||||
return GeneratedImageResponse(images=images, content="", raw=payload)
|
return GeneratedImageResponse(images=images, content="", raw=payload)
|
||||||
|
|
||||||
@ -446,9 +451,9 @@ class GeminiImageGenerationClient(ImageGenerationProvider):
|
|||||||
return "https://generativelanguage.googleapis.com/v1beta"
|
return "https://generativelanguage.googleapis.com/v1beta"
|
||||||
|
|
||||||
def _resolve_base_url(self, api_base: str | None) -> str:
|
def _resolve_base_url(self, api_base: str | None) -> str:
|
||||||
# The Gemini provider's registry default_api_base is the OpenAI-compat
|
# Gemini chat completions use the registry's OpenAI-compatible shim.
|
||||||
# shim (.../v1beta/openai/), which has no image endpoints.
|
# Image generation must hit the native Generative Language API, so we
|
||||||
# Skip the registry lookup and use the native API base directly.
|
# intentionally bypass the shared registry lookup here.
|
||||||
if api_base:
|
if api_base:
|
||||||
return api_base.rstrip("/")
|
return api_base.rstrip("/")
|
||||||
return self._default_base_url()
|
return self._default_base_url()
|
||||||
@ -522,7 +527,7 @@ class GeminiImageGenerationClient(ImageGenerationProvider):
|
|||||||
if isinstance(b64, str) and b64:
|
if isinstance(b64, str) and b64:
|
||||||
images.append(f"data:{mime};base64,{b64}")
|
images.append(f"data:{mime};base64,{b64}")
|
||||||
|
|
||||||
self._require_images(images, data)
|
self._ensure_images(images, data)
|
||||||
|
|
||||||
return GeneratedImageResponse(images=images, content="", raw=data)
|
return GeneratedImageResponse(images=images, content="", raw=data)
|
||||||
|
|
||||||
@ -580,7 +585,7 @@ class GeminiImageGenerationClient(ImageGenerationProvider):
|
|||||||
if b64:
|
if b64:
|
||||||
images.append(f"data:{mime};base64,{b64}")
|
images.append(f"data:{mime};base64,{b64}")
|
||||||
|
|
||||||
self._require_images(images, data)
|
self._ensure_images(images, data)
|
||||||
|
|
||||||
return GeneratedImageResponse(
|
return GeneratedImageResponse(
|
||||||
images=images,
|
images=images,
|
||||||
@ -734,7 +739,7 @@ class MiniMaxImageGenerationClient(ImageGenerationProvider):
|
|||||||
payload = response.json()
|
payload = response.json()
|
||||||
images = _minimax_images_from_payload(payload)
|
images = _minimax_images_from_payload(payload)
|
||||||
|
|
||||||
self._require_images(images, payload)
|
self._ensure_images(images, payload)
|
||||||
|
|
||||||
return GeneratedImageResponse(images=images, content="", raw=payload)
|
return GeneratedImageResponse(images=images, content="", raw=payload)
|
||||||
|
|
||||||
@ -840,7 +845,7 @@ class StepFunImageGenerationClient(ImageGenerationProvider):
|
|||||||
payload = response.json()
|
payload = response.json()
|
||||||
images = _stepfun_images_from_payload(payload)
|
images = _stepfun_images_from_payload(payload)
|
||||||
|
|
||||||
self._require_images(images, payload)
|
self._ensure_images(images, payload)
|
||||||
|
|
||||||
return GeneratedImageResponse(images=images, content="", raw=payload)
|
return GeneratedImageResponse(images=images, content="", raw=payload)
|
||||||
|
|
||||||
|
|||||||
@ -348,6 +348,11 @@ async def test_gemini_requires_api_key() -> None:
|
|||||||
await client.generate(prompt="draw", model="imagen-4.0-generate-001")
|
await client.generate(prompt="draw", model="imagen-4.0-generate-001")
|
||||||
|
|
||||||
|
|
||||||
|
def test_gemini_image_client_uses_native_api_base_by_default() -> None:
|
||||||
|
client = GeminiImageGenerationClient(api_key="AIza-test")
|
||||||
|
assert client.api_base == "https://generativelanguage.googleapis.com/v1beta"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_gemini_no_images_raises() -> None:
|
async def test_gemini_no_images_raises() -> None:
|
||||||
fake = FakeClient(FakeResponse({"candidates": [{"content": {"parts": [{"text": "sorry"}]}}]}))
|
fake = FakeClient(FakeResponse({"candidates": [{"content": {"parts": [{"text": "sorry"}]}}]}))
|
||||||
|
|||||||
@ -449,6 +449,12 @@ def test_gemma_routes_to_gemini_provider() -> None:
|
|||||||
assert "gemma" in spec.keywords
|
assert "gemma" in spec.keywords
|
||||||
|
|
||||||
|
|
||||||
|
def test_gemini_spec_keeps_openai_compat_base() -> None:
|
||||||
|
spec = find_by_name("gemini")
|
||||||
|
assert spec is not None
|
||||||
|
assert spec.default_api_base == "https://generativelanguage.googleapis.com/v1beta/openai/"
|
||||||
|
|
||||||
|
|
||||||
async def test_openrouter_sets_default_attribution_headers() -> None:
|
async def test_openrouter_sets_default_attribution_headers() -> None:
|
||||||
spec = find_by_name("openrouter")
|
spec = find_by_name("openrouter")
|
||||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_client_cls:
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_client_cls:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user