mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
feat(image-generation): add Gemini provider support
Adds GeminiImageGenerationClient covering both Imagen 4 (:predict) and Gemini Flash (:generateContent), wires the gemini ProviderConfig through the SDK, API server, and gateway entry points, and updates the image-generation docs and skill. Errors from the Gemini endpoints are logged and surface with the HTTP status and parsed message instead of an empty string. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
4e0d872588
commit
7367741ac1
@ -48,6 +48,28 @@ AIHubMix example:
|
||||
}
|
||||
```
|
||||
|
||||
Gemini example (Imagen 4):
|
||||
|
||||
```json
|
||||
{
|
||||
"providers": {
|
||||
"gemini": {
|
||||
"apiKey": "${GEMINI_API_KEY}"
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
"imageGeneration": {
|
||||
"enabled": true,
|
||||
"provider": "gemini",
|
||||
"model": "imagen-4.0-generate-001",
|
||||
"defaultAspectRatio": "1:1"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
For Gemini Flash (which supports reference-image edits) see the [Gemini](#gemini) section below.
|
||||
|
||||
> [!TIP]
|
||||
> Prefer environment variables for API keys. nanobot resolves `${VAR_NAME}` values from the environment at startup.
|
||||
|
||||
@ -69,7 +91,7 @@ The WebUI hides provider storage details from the user. The agent sees the saved
|
||||
| Option | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `tools.imageGeneration.enabled` | boolean | `false` | Register the `generate_image` tool |
|
||||
| `tools.imageGeneration.provider` | string | `"openrouter"` | Image provider name. Currently `openrouter` and `aihubmix` are supported |
|
||||
| `tools.imageGeneration.provider` | string | `"openrouter"` | Image provider name. Supported values: `openrouter`, `aihubmix`, `minimax`, `gemini` |
|
||||
| `tools.imageGeneration.model` | string | `"openai/gpt-5.4-image-2"` | Provider model name |
|
||||
| `tools.imageGeneration.defaultAspectRatio` | string | `"1:1"` | Default ratio when the prompt/tool call does not specify one |
|
||||
| `tools.imageGeneration.defaultImageSize` | string | `"1K"` | Default size hint, for example `1K`, `2K`, `4K`, or `1024x1024` |
|
||||
@ -139,6 +161,36 @@ Configure:
|
||||
|
||||
`quality: low` is optional. It can make free image models faster and less likely to time out, but it is not required for correctness.
|
||||
|
||||
### Gemini
|
||||
|
||||
nanobot supports two Gemini image generation model families via Google's Generative Language API:
|
||||
|
||||
| Model | Endpoint | Reference images |
|
||||
|-------|----------|-----------------|
|
||||
| `imagen-4.0-generate-001` | `:predict` | Not supported by this integration |
|
||||
| `gemini-2.5-flash-image` | `:generateContent` | Supported |
|
||||
|
||||
For reference-image edits, use a Gemini Flash image model:
|
||||
|
||||
```json
|
||||
{
|
||||
"providers": {
|
||||
"gemini": {
|
||||
"apiKey": "${GEMINI_API_KEY}"
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
"imageGeneration": {
|
||||
"enabled": true,
|
||||
"provider": "gemini",
|
||||
"model": "gemini-2.5-flash-image"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Imagen 4 supports the aspect ratios `1:1`, `9:16`, `16:9`, `3:4`, and `4:3`. Unsupported ratios are ignored and the model uses its default. The `defaultImageSize` setting has no effect on Gemini models. `defaultAspectRatio` is applied to Imagen models only; Gemini Flash requests do not send an aspect ratio, so the model chooses its own output dimensions. Reference images passed with an Imagen model are ignored (with a warning logged).
|
||||
|
||||
## Artifacts
|
||||
|
||||
Generated images are stored under the active nanobot instance's media directory:
|
||||
@ -193,7 +245,7 @@ Use the reference image. Keep the same robot and composition, change the palette
|
||||
|---------|-------|
|
||||
| `generate_image` is not available | Set `tools.imageGeneration.enabled` to `true` and restart the gateway |
|
||||
| Missing API key error | Configure `providers.<provider>.apiKey`; if using `${VAR_NAME}`, confirm the environment variable is visible to the gateway process |
|
||||
| `unsupported image generation provider` | Use `openrouter` or `aihubmix` |
|
||||
| `unsupported image generation provider` | Use `openrouter`, `aihubmix`, `minimax`, or `gemini` |
|
||||
| AIHubMix says `Incorrect model ID` | Use `model: "gpt-image-2-free"`; nanobot expands it to the required `openai/gpt-image-2-free` model path internally |
|
||||
| Generation times out | Try a smaller/default image size, set AIHubMix `extraBody.quality` to `"low"`, or retry later |
|
||||
| Reference image rejected | Reference image paths must be inside the workspace or nanobot media directory and must be valid image files |
|
||||
|
||||
@ -18,6 +18,7 @@ from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import Base
|
||||
from nanobot.providers.image_generation import (
|
||||
AIHubMixImageGenerationClient,
|
||||
GeminiImageGenerationClient,
|
||||
ImageGenerationError,
|
||||
MiniMaxImageGenerationClient,
|
||||
OpenRouterImageGenerationClient,
|
||||
@ -120,7 +121,7 @@ class ImageGenerationTool(Tool):
|
||||
|
||||
def _provider_client(
|
||||
self,
|
||||
) -> OpenRouterImageGenerationClient | AIHubMixImageGenerationClient | MiniMaxImageGenerationClient | None:
|
||||
) -> OpenRouterImageGenerationClient | AIHubMixImageGenerationClient | MiniMaxImageGenerationClient | GeminiImageGenerationClient | None:
|
||||
provider = self._provider_config()
|
||||
kwargs = {
|
||||
"api_key": provider.api_key if provider else None,
|
||||
@ -134,6 +135,8 @@ class ImageGenerationTool(Tool):
|
||||
return AIHubMixImageGenerationClient(**kwargs)
|
||||
if self.config.provider == "minimax":
|
||||
return MiniMaxImageGenerationClient(**kwargs)
|
||||
if self.config.provider == "gemini":
|
||||
return GeminiImageGenerationClient(**kwargs)
|
||||
return None
|
||||
|
||||
def _missing_api_key_error(self) -> str:
|
||||
@ -144,6 +147,8 @@ class ImageGenerationTool(Tool):
|
||||
return "Error: AIHubMix API key is not configured. Set providers.aihubmix.apiKey."
|
||||
if provider == "minimax":
|
||||
return "Error: MiniMax API key is not configured. Set providers.minimax.apiKey."
|
||||
if provider == "gemini":
|
||||
return "Error: Gemini API key is not configured. Set providers.gemini.apiKey."
|
||||
return f"Error: {provider} API key is not configured."
|
||||
|
||||
def _resolve_reference_image(self, value: str) -> str:
|
||||
|
||||
@ -643,6 +643,7 @@ def serve(
|
||||
"openrouter": runtime_config.providers.openrouter,
|
||||
"aihubmix": runtime_config.providers.aihubmix,
|
||||
"minimax": runtime_config.providers.minimax,
|
||||
"gemini": runtime_config.providers.gemini,
|
||||
},
|
||||
)
|
||||
except ValueError as exc:
|
||||
@ -757,6 +758,7 @@ def _run_gateway(
|
||||
"openrouter": config.providers.openrouter,
|
||||
"aihubmix": config.providers.aihubmix,
|
||||
"minimax": config.providers.minimax,
|
||||
"gemini": config.providers.gemini,
|
||||
},
|
||||
provider_snapshot_loader=load_provider_snapshot,
|
||||
runtime_model_publisher=lambda model, preset: publish_runtime_model_update(
|
||||
|
||||
@ -67,6 +67,7 @@ class Nanobot:
|
||||
"openrouter": config.providers.openrouter,
|
||||
"aihubmix": config.providers.aihubmix,
|
||||
"minimax": config.providers.minimax,
|
||||
"gemini": config.providers.gemini,
|
||||
},
|
||||
)
|
||||
return cls(loop)
|
||||
|
||||
@ -8,6 +8,7 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.providers.registry import find_by_name
|
||||
from nanobot.utils.helpers import detect_image_mime
|
||||
@ -26,6 +27,8 @@ _AIHUBMIX_ASPECT_RATIO_SIZES = {
|
||||
"4:3": "1536x1024",
|
||||
"16:9": "1536x1024",
|
||||
}
|
||||
_GEMINI_DEFAULT_TIMEOUT_S = 120.0
|
||||
_GEMINI_IMAGEN_ASPECT_RATIOS = {"1:1", "9:16", "16:9", "3:4", "4:3"}
|
||||
|
||||
|
||||
class ImageGenerationError(RuntimeError):
|
||||
@ -50,17 +53,28 @@ def _provider_base_url(provider: str, api_base: str | None, fallback: str) -> st
|
||||
return fallback
|
||||
|
||||
|
||||
def image_path_to_data_url(path: str | Path) -> str:
|
||||
"""Convert a local image path to an image data URL."""
|
||||
def _read_image_b64(path: str | Path) -> tuple[str, str]:
    """Return ``(mime, base64)`` for the image at ``path``.

    Args:
        path: Location of the image file; a leading ``~`` is expanded.

    Returns:
        The detected MIME type and the file contents as ASCII base64 text.

    Raises:
        ImageGenerationError: If ``detect_image_mime`` does not recognise the
            file contents as an image.
    """
    image_path = Path(path).expanduser()
    raw = image_path.read_bytes()
    mime = detect_image_mime(raw)
    if mime is None:
        raise ImageGenerationError(f"unsupported reference image: {image_path}")
    # Encode exactly once; the payload can be several megabytes.
    return mime, base64.b64encode(raw).decode("ascii")
|
||||
|
||||
|
||||
def image_path_to_data_url(path: str | Path) -> str:
    """Read the image at ``path`` and return it as a ``data:`` URL.

    The MIME type and base64 payload both come from ``_read_image_b64``.
    """
    mime_type, payload = _read_image_b64(path)
    return f"data:{mime_type};base64,{payload}"
|
||||
|
||||
|
||||
def image_path_to_inline_data(path: str | Path) -> dict[str, str]:
    """Read the image at ``path`` and return a Gemini ``inlineData`` dict.

    The result has two keys: ``mimeType`` and ``data`` (the base64 payload),
    both taken from ``_read_image_b64``.
    """
    mime_type, payload = _read_image_b64(path)
    inline_data = {"mimeType": mime_type, "data": payload}
    return inline_data
|
||||
|
||||
|
||||
def _b64_png_data_url(value: str) -> str:
|
||||
return f"data:image/png;base64,{value}"
|
||||
|
||||
@ -341,6 +355,203 @@ class AIHubMixImageGenerationClient:
|
||||
return GeneratedImageResponse(images=images, content="", raw=payload)
|
||||
|
||||
|
||||
def _http_error_detail(response: httpx.Response) -> str:
|
||||
"""Extract a readable error message from an HTTP error response."""
|
||||
try:
|
||||
data = response.json()
|
||||
if isinstance(data, dict):
|
||||
err = data.get("error")
|
||||
if isinstance(err, dict):
|
||||
return err.get("message") or str(err)
|
||||
if err:
|
||||
return str(err)
|
||||
except Exception:
|
||||
pass
|
||||
return response.text[:500] or "<empty response body>"
|
||||
|
||||
|
||||
class GeminiImageGenerationClient:
    """Async client for Gemini/Imagen image generation via the Generative Language API.

    The endpoint is chosen from the model name:

    * names containing ``imagen`` use ``:predict`` (text-to-image only);
    * every other model uses ``:generateContent`` and accepts reference images.
    """

    def __init__(
        self,
        *,
        api_key: str | None,
        api_base: str | None = None,
        extra_headers: dict[str, str] | None = None,
        extra_body: dict[str, Any] | None = None,
        timeout: float = _GEMINI_DEFAULT_TIMEOUT_S,
        client: httpx.AsyncClient | None = None,
    ) -> None:
        """Store connection settings.

        Args:
            api_key: Key sent in the ``x-goog-api-key`` header.
            api_base: Override for the API root; a trailing ``/`` is stripped.
            extra_headers: Headers merged over the default request headers.
            extra_body: Top-level keys merged over the generated request body.
            timeout: Timeout for the client created per request when ``client``
                is not supplied.
            client: Optional pre-built HTTP client to reuse for every request.
        """
        self.api_key = api_key
        # The Gemini provider's registry default_api_base is the OpenAI-compat
        # shim (.../v1beta/openai/), which has no image endpoints. Image
        # generation needs the native Generative Language API base, so we don't
        # use _provider_base_url() here.
        self.api_base = (
            api_base or "https://generativelanguage.googleapis.com/v1beta"
        ).rstrip("/")
        self.extra_headers = extra_headers or {}
        self.extra_body = extra_body or {}
        self.timeout = timeout
        self._client = client

    async def generate(
        self,
        *,
        prompt: str,
        model: str,
        reference_images: list[str] | None = None,
        aspect_ratio: str | None = None,
        image_size: str | None = None,
    ) -> GeneratedImageResponse:
        """Generate images for ``prompt`` with the given ``model``.

        Args:
            prompt: Text prompt.
            model: Model name; selects the endpoint (see the class docstring).
            reference_images: Local image paths. Used by non-Imagen models only;
                ignored with a warning for Imagen models.
            aspect_ratio: Forwarded to Imagen models only.
            image_size: Accepted for signature parity; not sent to either endpoint.

        Raises:
            ImageGenerationError: If no API key is configured, the request
                fails, or the response contains no image.
        """
        if not self.api_key:
            raise ImageGenerationError(
                "Gemini API key is not configured. Set providers.gemini.apiKey."
            )
        if "imagen" in model.lower():
            if reference_images:
                logger.warning(
                    "Imagen models do not support reference images; "
                    "ignoring {} reference image(s) for {}",
                    len(reference_images),
                    model,
                )
            return await self._generate_imagen(
                prompt=prompt, model=model, aspect_ratio=aspect_ratio
            )
        return await self._generate_gemini_flash(
            prompt=prompt, model=model, reference_images=reference_images or []
        )

    async def _post_json(
        self,
        url: str,
        body: dict[str, Any],
        *,
        label: str,
    ) -> dict[str, Any]:
        """POST ``body`` to ``url`` and return the decoded JSON object.

        ``label`` prefixes log and error messages (for example ``"Gemini Imagen"``).

        Raises:
            ImageGenerationError: On an HTTP error status, a body that is not
                valid JSON, or a JSON body that is not an object.
        """
        headers = {
            "x-goog-api-key": self.api_key or "",
            "Content-Type": "application/json",
            **self.extra_headers,
        }

        if self._client is not None:
            response = await self._client.post(url, headers=headers, json=body)
        else:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                response = await client.post(url, headers=headers, json=body)

        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            detail = _http_error_detail(response)
            logger.error("{} generation failed (HTTP {}): {}", label, response.status_code, detail)
            raise ImageGenerationError(
                f"{label} generation failed (HTTP {response.status_code}): {detail}"
            ) from exc

        try:
            data = response.json()
        except ValueError as exc:
            # Surface a malformed success body as a provider error instead of
            # letting a raw JSON decode error escape to the caller.
            raise ImageGenerationError(f"{label} generation returned a non-JSON response") from exc
        if not isinstance(data, dict):
            raise ImageGenerationError(
                f"{label} generation returned an unexpected response: {type(data).__name__}"
            )
        return data

    async def _generate_imagen(
        self,
        *,
        prompt: str,
        model: str,
        aspect_ratio: str | None,
    ) -> GeneratedImageResponse:
        """Call the Imagen ``:predict`` endpoint and collect images as data URLs."""
        parameters: dict[str, Any] = {"sampleCount": 1}
        # Unsupported ratios are dropped so the model uses its own default.
        if aspect_ratio in _GEMINI_IMAGEN_ASPECT_RATIOS:
            parameters["aspectRatio"] = aspect_ratio
        body: dict[str, Any] = {
            "instances": [{"prompt": prompt}],
            "parameters": parameters,
        }
        # Top-level keys from extra_body replace the generated ones wholesale.
        body.update(self.extra_body)

        data = await self._post_json(
            f"{self.api_base}/models/{model}:predict", body, label="Gemini Imagen"
        )

        images: list[str] = []
        for prediction in data.get("predictions") or []:
            if not isinstance(prediction, dict):
                continue
            b64 = prediction.get("bytesBase64Encoded")
            if isinstance(b64, str) and b64:
                # ``or`` also covers an explicit null mimeType.
                mime = prediction.get("mimeType") or "image/png"
                images.append(f"data:{mime};base64,{b64}")

        if not images:
            provider_error = data.get("error")
            if provider_error:
                raise ImageGenerationError(f"Gemini Imagen returned no images: {provider_error}")
            raise ImageGenerationError("Gemini Imagen returned no images for this request")

        return GeneratedImageResponse(images=images, content="", raw=data)

    async def _generate_gemini_flash(
        self,
        *,
        prompt: str,
        model: str,
        reference_images: list[str],
    ) -> GeneratedImageResponse:
        """Call ``:generateContent`` and collect inline images plus any text."""
        # Reference images go first, followed by the text prompt.
        parts: list[dict[str, Any]] = [
            {"inlineData": image_path_to_inline_data(path)} for path in reference_images
        ]
        parts.append({"text": prompt})

        body: dict[str, Any] = {
            "contents": [{"role": "user", "parts": parts}],
            "generationConfig": {"responseModalities": ["TEXT", "IMAGE"]},
        }
        # Top-level keys from extra_body replace the generated ones wholesale.
        body.update(self.extra_body)

        data = await self._post_json(
            f"{self.api_base}/models/{model}:generateContent", body, label="Gemini image"
        )

        images: list[str] = []
        text_parts: list[str] = []
        for candidate in data.get("candidates") or []:
            if not isinstance(candidate, dict):
                continue
            content = candidate.get("content")
            if not isinstance(content, dict):
                continue
            for part in content.get("parts") or []:
                if not isinstance(part, dict):
                    continue
                text = part.get("text")
                if isinstance(text, str) and text:
                    text_parts.append(text)
                inline = part.get("inlineData")
                if isinstance(inline, dict):
                    b64 = inline.get("data")
                    if isinstance(b64, str) and b64:
                        mime = inline.get("mimeType") or "image/png"
                        images.append(f"data:{mime};base64,{b64}")

        content_text = "\n".join(text_parts).strip()

        if not images:
            provider_error = data.get("error")
            if provider_error:
                raise ImageGenerationError(f"Gemini returned no images: {provider_error}")
            if content_text:
                # Include the model's text reply, which is often the only
                # explanation of why no image was produced.
                raise ImageGenerationError(
                    f"Gemini returned no images for this request: {content_text[:500]}"
                )
            raise ImageGenerationError("Gemini returned no images for this request")

        return GeneratedImageResponse(images=images, content=content_text, raw=data)
|
||||
|
||||
|
||||
async def _aihubmix_images_from_payload(
|
||||
client: httpx.AsyncClient,
|
||||
payload: dict[str, Any],
|
||||
|
||||
@ -88,6 +88,27 @@ AIHubMix `gpt-image-2-free` uses AIHubMix's unified predictions endpoint interna
|
||||
|
||||
`providers.aihubmix.extraBody` can be used for provider-specific options. For example, `"extraBody": {"quality": "low"}` is optional but can make `gpt-image-2-free` faster and less likely to time out.
|
||||
|
||||
For Gemini, the image tool supports two model families. Imagen 4 (`imagen-4.0-generate-001`) supports text-to-image only. Gemini Flash (`gemini-2.5-flash-image`) also supports reference-image edits. Configuration:
|
||||
|
||||
```json
|
||||
{
|
||||
"providers": {
|
||||
"gemini": {
|
||||
"apiKey": "AIza..."
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
"imageGeneration": {
|
||||
"enabled": true,
|
||||
"provider": "gemini",
|
||||
"model": "imagen-4.0-generate-001"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
For Gemini models, `defaultImageSize` has no effect. `defaultAspectRatio` applies to Imagen models only (Gemini Flash requests do not send an aspect ratio). Imagen 4 supports `1:1`, `9:16`, `16:9`, `3:4`, and `4:3`.
|
||||
|
||||
## Examples
|
||||
|
||||
Generate a new image:
|
||||
|
||||
@ -8,6 +8,7 @@ import pytest
|
||||
|
||||
from nanobot.providers.image_generation import (
|
||||
AIHubMixImageGenerationClient,
|
||||
GeminiImageGenerationClient,
|
||||
GeneratedImageResponse,
|
||||
ImageGenerationError,
|
||||
OpenRouterImageGenerationClient,
|
||||
@ -202,3 +203,137 @@ async def test_aihubmix_image_generation_downloads_url_response() -> None:
|
||||
|
||||
assert response.images[0].startswith("data:image/png;base64,")
|
||||
assert fake.get_calls[0]["url"] == "https://cdn.example/image.png"
|
||||
|
||||
|
||||
RAW_B64 = PNG_DATA_URL.removeprefix("data:image/png;base64,")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_gemini_imagen_payload_and_response() -> None:
    """Imagen models POST to ``:predict`` with prompt, sample count and aspect ratio."""
    prediction = {"bytesBase64Encoded": RAW_B64, "mimeType": "image/png"}
    http = FakeClient(FakeResponse({"predictions": [prediction]}))
    gemini = GeminiImageGenerationClient(
        api_key="AIza-test",
        api_base="https://generativelanguage.googleapis.com/v1beta",
        client=http,  # type: ignore[arg-type]
    )

    result = await gemini.generate(
        prompt="a sunset",
        model="imagen-4.0-generate-001",
        aspect_ratio="16:9",
    )

    # Response side: one PNG data URL and no accompanying text.
    assert result.images == [PNG_DATA_URL]
    assert result.content == ""

    # Request side: endpoint, header-based auth, and body shape.
    request = http.calls[0]
    assert request["url"].endswith(":predict")
    assert request["headers"]["x-goog-api-key"] == "AIza-test"
    assert "params" not in request
    payload = request["json"]
    assert payload["instances"] == [{"prompt": "a sunset"}]
    assert payload["parameters"]["sampleCount"] == 1
    assert payload["parameters"]["aspectRatio"] == "16:9"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_gemini_imagen_ignores_unsupported_aspect_ratio() -> None:
    """An aspect ratio outside the Imagen allow-list is left out of the request."""
    prediction = {"bytesBase64Encoded": RAW_B64, "mimeType": "image/png"}
    http = FakeClient(FakeResponse({"predictions": [prediction]}))
    gemini = GeminiImageGenerationClient(api_key="AIza-test", client=http)  # type: ignore[arg-type]

    await gemini.generate(prompt="a sunset", model="imagen-4.0-generate-001", aspect_ratio="2:3")

    sent_parameters = http.calls[0]["json"]["parameters"]
    assert "aspectRatio" not in sent_parameters
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_gemini_flash_payload_and_response() -> None:
    """Non-Imagen models POST to ``:generateContent`` and return text plus image."""
    reply_parts = [
        {"text": "here is your image"},
        {"inlineData": {"mimeType": "image/png", "data": RAW_B64}},
    ]
    http = FakeClient(FakeResponse({"candidates": [{"content": {"parts": reply_parts}}]}))
    gemini = GeminiImageGenerationClient(
        api_key="AIza-test",
        api_base="https://generativelanguage.googleapis.com/v1beta",
        client=http,  # type: ignore[arg-type]
    )

    result = await gemini.generate(
        prompt="draw a cat",
        model="gemini-2.0-flash-preview-image-generation",
    )

    # Response side: the inline image and the model's text are both surfaced.
    assert result.images == [PNG_DATA_URL]
    assert result.content == "here is your image"

    # Request side: endpoint, header-based auth, modalities and prompt part.
    request = http.calls[0]
    assert request["url"].endswith(":generateContent")
    assert request["headers"]["x-goog-api-key"] == "AIza-test"
    assert "params" not in request
    payload = request["json"]
    assert payload["generationConfig"]["responseModalities"] == ["TEXT", "IMAGE"]
    assert payload["contents"][0]["parts"][-1] == {"text": "draw a cat"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_gemini_flash_reference_images(tmp_path: Path) -> None:
    """Reference images are sent as ``inlineData`` parts ahead of the prompt text."""
    reference = tmp_path / "ref.png"
    reference.write_bytes(PNG_BYTES)
    image_part = {"inlineData": {"mimeType": "image/png", "data": RAW_B64}}
    http = FakeClient(FakeResponse({"candidates": [{"content": {"parts": [image_part]}}]}))
    gemini = GeminiImageGenerationClient(api_key="AIza-test", client=http)  # type: ignore[arg-type]

    result = await gemini.generate(
        prompt="edit this",
        model="gemini-2.0-flash-preview-image-generation",
        reference_images=[str(reference)],
    )

    assert result.images == [PNG_DATA_URL]
    sent_parts = http.calls[0]["json"]["contents"][0]["parts"]
    first_inline = sent_parts[0]["inlineData"]
    assert first_inline["mimeType"] == "image/png"
    assert first_inline["data"].startswith("iVBOR")
    assert sent_parts[1] == {"text": "edit this"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_gemini_requires_api_key() -> None:
    """Calling ``generate`` without an API key raises ``ImageGenerationError``."""
    keyless = GeminiImageGenerationClient(api_key=None)

    with pytest.raises(ImageGenerationError, match="API key"):
        await keyless.generate(prompt="draw", model="imagen-4.0-generate-001")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_gemini_no_images_raises() -> None:
    """A text-only ``generateContent`` reply is reported as an error."""
    text_only_reply = {"candidates": [{"content": {"parts": [{"text": "sorry"}]}}]}
    http = FakeClient(FakeResponse(text_only_reply))
    gemini = GeminiImageGenerationClient(api_key="AIza-test", client=http)  # type: ignore[arg-type]

    with pytest.raises(ImageGenerationError, match="returned no images"):
        await gemini.generate(prompt="draw", model="gemini-2.0-flash-preview-image-generation")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user