mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
fix(image-generation): align media delivery and mime handling
This commit is contained in:
parent
d7a73093a8
commit
44b7bba9bd
@ -31,8 +31,8 @@ from nanobot.config.paths import get_workspace_path
|
||||
media=ArraySchema(
|
||||
StringSchema(""),
|
||||
description=(
|
||||
"Optional list of existing file paths to attach for proactive or cross-channel delivery. "
|
||||
"Do not use this to resend generate_image outputs in the current chat."
|
||||
"Optional list of existing file paths to attach. "
|
||||
"Use artifact paths returned by generate_image here when delivering generated images."
|
||||
),
|
||||
),
|
||||
buttons=ArraySchema(
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
@ -67,8 +68,16 @@ def image_path_to_inline_data(path: str | Path) -> dict[str, str]:
|
||||
return {"mimeType": mime, "data": encoded}
|
||||
|
||||
|
||||
def _b64_png_data_url(value: str) -> str:
|
||||
return f"data:image/png;base64,{value}"
|
||||
def _b64_image_data_url(value: str) -> str:
|
||||
encoded = "".join(value.split())
|
||||
try:
|
||||
raw = base64.b64decode(encoded, validate=True)
|
||||
except binascii.Error as exc:
|
||||
raise ImageGenerationError("generated image payload was not valid base64") from exc
|
||||
mime = detect_image_mime(raw)
|
||||
if mime is None:
|
||||
raise ImageGenerationError("generated image payload was not a supported image")
|
||||
return f"data:{mime};base64,{encoded}"
|
||||
|
||||
|
||||
def _aihubmix_size(aspect_ratio: str | None, image_size: str | None) -> str:
|
||||
@ -598,13 +607,13 @@ async def _aihubmix_images_from_payload(
|
||||
|
||||
b64_json = value.get("b64_json")
|
||||
if isinstance(b64_json, str) and b64_json:
|
||||
images.append(_b64_png_data_url(b64_json))
|
||||
images.append(_b64_image_data_url(b64_json))
|
||||
elif b64_json is not None:
|
||||
await collect(b64_json)
|
||||
|
||||
bytes_base64 = value.get("bytesBase64") or value.get("bytes_base64") or value.get("base64")
|
||||
if isinstance(bytes_base64, str) and bytes_base64:
|
||||
images.append(_b64_png_data_url(bytes_base64))
|
||||
images.append(_b64_image_data_url(bytes_base64))
|
||||
|
||||
image_url = value.get("image_url") or value.get("imageUrl")
|
||||
if isinstance(image_url, dict):
|
||||
@ -738,7 +747,7 @@ def _minimax_images_from_payload(payload: dict[str, Any]) -> list[str]:
|
||||
return images
|
||||
for b64 in data.get("image_base64") or []:
|
||||
if isinstance(b64, str) and b64:
|
||||
images.append(_b64_png_data_url(b64))
|
||||
images.append(_b64_image_data_url(b64))
|
||||
return images
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@ -11,6 +12,7 @@ from nanobot.providers.image_generation import (
|
||||
GeminiImageGenerationClient,
|
||||
GeneratedImageResponse,
|
||||
ImageGenerationError,
|
||||
MiniMaxImageGenerationClient,
|
||||
OpenRouterImageGenerationClient,
|
||||
)
|
||||
|
||||
@ -24,6 +26,7 @@ PNG_DATA_URL = (
|
||||
"data:image/png;base64,"
|
||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+/p9sAAAAASUVORK5CYII="
|
||||
)
|
||||
JPEG_BYTES = b"\xff\xd8\xff\xe0" + b"0" * 12
|
||||
|
||||
|
||||
class FakeResponse:
|
||||
@ -205,6 +208,20 @@ async def test_aihubmix_image_generation_downloads_url_response() -> None:
|
||||
assert fake.get_calls[0]["url"] == "https://cdn.example/image.png"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aihubmix_base64_response_uses_detected_mime() -> None:
|
||||
raw_b64 = base64.b64encode(JPEG_BYTES).decode("ascii")
|
||||
fake = FakeClient(FakeResponse({"output": {"b64_json": raw_b64}}))
|
||||
client = AIHubMixImageGenerationClient(
|
||||
api_key="sk-ahm-test",
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
response = await client.generate(prompt="draw", model="gpt-image-2-free")
|
||||
|
||||
assert response.images == [f"data:image/jpeg;base64,{raw_b64}"]
|
||||
|
||||
|
||||
RAW_B64 = PNG_DATA_URL.removeprefix("data:image/png;base64,")
|
||||
|
||||
|
||||
@ -337,3 +354,36 @@ async def test_gemini_no_images_raises() -> None:
|
||||
|
||||
with pytest.raises(ImageGenerationError, match="returned no images"):
|
||||
await client.generate(prompt="draw", model="gemini-2.0-flash-preview-image-generation")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_minimax_payload_and_response_with_reference_image(tmp_path: Path) -> None:
|
||||
ref = tmp_path / "ref.png"
|
||||
ref.write_bytes(PNG_BYTES)
|
||||
fake = FakeClient(FakeResponse({"data": {"image_base64": [RAW_B64]}}))
|
||||
client = MiniMaxImageGenerationClient(
|
||||
api_key="sk-mm-test",
|
||||
api_base="https://api.minimaxi.com/v1/",
|
||||
extra_headers={"X-Test": "1"},
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
response = await client.generate(
|
||||
prompt="draw a character",
|
||||
model="image-01",
|
||||
reference_images=[str(ref)],
|
||||
aspect_ratio="21:9",
|
||||
)
|
||||
|
||||
assert response.images == [PNG_DATA_URL]
|
||||
call = fake.calls[0]
|
||||
assert call["url"] == "https://api.minimaxi.com/v1/image_generation"
|
||||
assert call["headers"]["Authorization"] == "Bearer sk-mm-test"
|
||||
assert call["headers"]["X-Test"] == "1"
|
||||
body = call["json"]
|
||||
assert body["model"] == "image-01"
|
||||
assert body["prompt"] == "draw a character"
|
||||
assert body["response_format"] == "base64"
|
||||
assert body["aspect_ratio"] == "21:9"
|
||||
assert body["subject_reference"][0]["type"] == "character"
|
||||
assert body["subject_reference"][0]["image_file"].startswith("data:image/png;base64,")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user