mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
fix(image-generation): align media delivery and mime handling
This commit is contained in:
parent
d7a73093a8
commit
44b7bba9bd
@ -31,8 +31,8 @@ from nanobot.config.paths import get_workspace_path
|
|||||||
media=ArraySchema(
|
media=ArraySchema(
|
||||||
StringSchema(""),
|
StringSchema(""),
|
||||||
description=(
|
description=(
|
||||||
"Optional list of existing file paths to attach for proactive or cross-channel delivery. "
|
"Optional list of existing file paths to attach. "
|
||||||
"Do not use this to resend generate_image outputs in the current chat."
|
"Use artifact paths returned by generate_image here when delivering generated images."
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
buttons=ArraySchema(
|
buttons=ArraySchema(
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
|
import binascii
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -67,8 +68,16 @@ def image_path_to_inline_data(path: str | Path) -> dict[str, str]:
|
|||||||
return {"mimeType": mime, "data": encoded}
|
return {"mimeType": mime, "data": encoded}
|
||||||
|
|
||||||
|
|
||||||
def _b64_png_data_url(value: str) -> str:
|
def _b64_image_data_url(value: str) -> str:
|
||||||
return f"data:image/png;base64,{value}"
|
encoded = "".join(value.split())
|
||||||
|
try:
|
||||||
|
raw = base64.b64decode(encoded, validate=True)
|
||||||
|
except binascii.Error as exc:
|
||||||
|
raise ImageGenerationError("generated image payload was not valid base64") from exc
|
||||||
|
mime = detect_image_mime(raw)
|
||||||
|
if mime is None:
|
||||||
|
raise ImageGenerationError("generated image payload was not a supported image")
|
||||||
|
return f"data:{mime};base64,{encoded}"
|
||||||
|
|
||||||
|
|
||||||
def _aihubmix_size(aspect_ratio: str | None, image_size: str | None) -> str:
|
def _aihubmix_size(aspect_ratio: str | None, image_size: str | None) -> str:
|
||||||
@ -598,13 +607,13 @@ async def _aihubmix_images_from_payload(
|
|||||||
|
|
||||||
b64_json = value.get("b64_json")
|
b64_json = value.get("b64_json")
|
||||||
if isinstance(b64_json, str) and b64_json:
|
if isinstance(b64_json, str) and b64_json:
|
||||||
images.append(_b64_png_data_url(b64_json))
|
images.append(_b64_image_data_url(b64_json))
|
||||||
elif b64_json is not None:
|
elif b64_json is not None:
|
||||||
await collect(b64_json)
|
await collect(b64_json)
|
||||||
|
|
||||||
bytes_base64 = value.get("bytesBase64") or value.get("bytes_base64") or value.get("base64")
|
bytes_base64 = value.get("bytesBase64") or value.get("bytes_base64") or value.get("base64")
|
||||||
if isinstance(bytes_base64, str) and bytes_base64:
|
if isinstance(bytes_base64, str) and bytes_base64:
|
||||||
images.append(_b64_png_data_url(bytes_base64))
|
images.append(_b64_image_data_url(bytes_base64))
|
||||||
|
|
||||||
image_url = value.get("image_url") or value.get("imageUrl")
|
image_url = value.get("image_url") or value.get("imageUrl")
|
||||||
if isinstance(image_url, dict):
|
if isinstance(image_url, dict):
|
||||||
@ -738,7 +747,7 @@ def _minimax_images_from_payload(payload: dict[str, Any]) -> list[str]:
|
|||||||
return images
|
return images
|
||||||
for b64 in data.get("image_base64") or []:
|
for b64 in data.get("image_base64") or []:
|
||||||
if isinstance(b64, str) and b64:
|
if isinstance(b64, str) and b64:
|
||||||
images.append(_b64_png_data_url(b64))
|
images.append(_b64_image_data_url(b64))
|
||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@ -11,6 +12,7 @@ from nanobot.providers.image_generation import (
|
|||||||
GeminiImageGenerationClient,
|
GeminiImageGenerationClient,
|
||||||
GeneratedImageResponse,
|
GeneratedImageResponse,
|
||||||
ImageGenerationError,
|
ImageGenerationError,
|
||||||
|
MiniMaxImageGenerationClient,
|
||||||
OpenRouterImageGenerationClient,
|
OpenRouterImageGenerationClient,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -24,6 +26,7 @@ PNG_DATA_URL = (
|
|||||||
"data:image/png;base64,"
|
"data:image/png;base64,"
|
||||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+/p9sAAAAASUVORK5CYII="
|
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+/p9sAAAAASUVORK5CYII="
|
||||||
)
|
)
|
||||||
|
JPEG_BYTES = b"\xff\xd8\xff\xe0" + b"0" * 12
|
||||||
|
|
||||||
|
|
||||||
class FakeResponse:
|
class FakeResponse:
|
||||||
@ -205,6 +208,20 @@ async def test_aihubmix_image_generation_downloads_url_response() -> None:
|
|||||||
assert fake.get_calls[0]["url"] == "https://cdn.example/image.png"
|
assert fake.get_calls[0]["url"] == "https://cdn.example/image.png"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_aihubmix_base64_response_uses_detected_mime() -> None:
|
||||||
|
raw_b64 = base64.b64encode(JPEG_BYTES).decode("ascii")
|
||||||
|
fake = FakeClient(FakeResponse({"output": {"b64_json": raw_b64}}))
|
||||||
|
client = AIHubMixImageGenerationClient(
|
||||||
|
api_key="sk-ahm-test",
|
||||||
|
client=fake, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.generate(prompt="draw", model="gpt-image-2-free")
|
||||||
|
|
||||||
|
assert response.images == [f"data:image/jpeg;base64,{raw_b64}"]
|
||||||
|
|
||||||
|
|
||||||
RAW_B64 = PNG_DATA_URL.removeprefix("data:image/png;base64,")
|
RAW_B64 = PNG_DATA_URL.removeprefix("data:image/png;base64,")
|
||||||
|
|
||||||
|
|
||||||
@ -337,3 +354,36 @@ async def test_gemini_no_images_raises() -> None:
|
|||||||
|
|
||||||
with pytest.raises(ImageGenerationError, match="returned no images"):
|
with pytest.raises(ImageGenerationError, match="returned no images"):
|
||||||
await client.generate(prompt="draw", model="gemini-2.0-flash-preview-image-generation")
|
await client.generate(prompt="draw", model="gemini-2.0-flash-preview-image-generation")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_minimax_payload_and_response_with_reference_image(tmp_path: Path) -> None:
|
||||||
|
ref = tmp_path / "ref.png"
|
||||||
|
ref.write_bytes(PNG_BYTES)
|
||||||
|
fake = FakeClient(FakeResponse({"data": {"image_base64": [RAW_B64]}}))
|
||||||
|
client = MiniMaxImageGenerationClient(
|
||||||
|
api_key="sk-mm-test",
|
||||||
|
api_base="https://api.minimaxi.com/v1/",
|
||||||
|
extra_headers={"X-Test": "1"},
|
||||||
|
client=fake, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.generate(
|
||||||
|
prompt="draw a character",
|
||||||
|
model="image-01",
|
||||||
|
reference_images=[str(ref)],
|
||||||
|
aspect_ratio="21:9",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.images == [PNG_DATA_URL]
|
||||||
|
call = fake.calls[0]
|
||||||
|
assert call["url"] == "https://api.minimaxi.com/v1/image_generation"
|
||||||
|
assert call["headers"]["Authorization"] == "Bearer sk-mm-test"
|
||||||
|
assert call["headers"]["X-Test"] == "1"
|
||||||
|
body = call["json"]
|
||||||
|
assert body["model"] == "image-01"
|
||||||
|
assert body["prompt"] == "draw a character"
|
||||||
|
assert body["response_format"] == "base64"
|
||||||
|
assert body["aspect_ratio"] == "21:9"
|
||||||
|
assert body["subject_reference"][0]["type"] == "character"
|
||||||
|
assert body["subject_reference"][0]["image_file"].startswith("data:image/png;base64,")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user