mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
Add MiniMaxImageGenerationClient with support for: - Text-to-image generation via MiniMax image-01 model - Reference image support (subject_reference) - Aspect ratio selection - Proper error handling aligned with existing providers Wire up MiniMax provider config in ImageGenerationTool, gateway, serve, and Nanobot class.
529 lines
17 KiB
Python
529 lines
17 KiB
Python
"""Image generation provider helpers."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import httpx
|
|
|
|
from nanobot.providers.registry import find_by_name
|
|
from nanobot.utils.helpers import detect_image_mime
|
|
|
|
_OPENROUTER_ATTRIBUTION_HEADERS = {
|
|
"HTTP-Referer": "https://github.com/HKUDS/nanobot",
|
|
"X-OpenRouter-Title": "nanobot",
|
|
"X-OpenRouter-Categories": "cli-agent,personal-agent",
|
|
}
|
|
_DEFAULT_TIMEOUT_S = 120.0
|
|
_AIHUBMIX_TIMEOUT_S = 300.0
|
|
_AIHUBMIX_ASPECT_RATIO_SIZES = {
|
|
"1:1": "1024x1024",
|
|
"3:4": "1024x1536",
|
|
"9:16": "1024x1536",
|
|
"4:3": "1536x1024",
|
|
"16:9": "1536x1024",
|
|
}
|
|
|
|
|
|
class ImageGenerationError(RuntimeError):
|
|
"""Raised when the image generation provider cannot return images."""
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class GeneratedImageResponse:
|
|
"""Images and optional text returned by the provider."""
|
|
|
|
images: list[str]
|
|
content: str
|
|
raw: dict[str, Any]
|
|
|
|
|
|
def _provider_base_url(provider: str, api_base: str | None, fallback: str) -> str:
|
|
if api_base:
|
|
return api_base.rstrip("/")
|
|
spec = find_by_name(provider)
|
|
if spec and spec.default_api_base:
|
|
return spec.default_api_base.rstrip("/")
|
|
return fallback
|
|
|
|
|
|
def image_path_to_data_url(path: str | Path) -> str:
|
|
"""Convert a local image path to an image data URL."""
|
|
p = Path(path).expanduser()
|
|
raw = p.read_bytes()
|
|
mime = detect_image_mime(raw)
|
|
if mime is None:
|
|
raise ImageGenerationError(f"unsupported reference image: {p}")
|
|
encoded = base64.b64encode(raw).decode("ascii")
|
|
return f"data:{mime};base64,{encoded}"
|
|
|
|
|
|
def _b64_png_data_url(value: str) -> str:
|
|
return f"data:image/png;base64,{value}"
|
|
|
|
|
|
def _aihubmix_size(aspect_ratio: str | None, image_size: str | None) -> str:
|
|
"""Return an OpenAI Images API size string for AIHubMix.
|
|
|
|
The WebUI emits compact size hints like ``1K`` for OpenRouter. AIHubMix's
|
|
Images API expects OpenAI-style dimensions or ``auto``, so only pass
|
|
through explicit dimension strings and otherwise derive the closest
|
|
supported orientation from aspect ratio.
|
|
"""
|
|
if image_size and "x" in image_size.lower():
|
|
return image_size
|
|
if aspect_ratio in _AIHUBMIX_ASPECT_RATIO_SIZES:
|
|
return _AIHUBMIX_ASPECT_RATIO_SIZES[aspect_ratio]
|
|
return "auto"
|
|
|
|
|
|
def _aihubmix_model_path(model: str) -> str:
|
|
if "/" in model:
|
|
return model
|
|
if model.startswith(("gpt-image-", "dall-e-")):
|
|
return f"openai/{model}"
|
|
return model
|
|
|
|
|
|
async def _download_image_data_url(
|
|
client: httpx.AsyncClient,
|
|
url: str,
|
|
) -> str:
|
|
response = await client.get(url)
|
|
try:
|
|
response.raise_for_status()
|
|
except httpx.HTTPStatusError as exc:
|
|
detail = response.text[:500]
|
|
raise ImageGenerationError(f"failed to download generated image: {detail}") from exc
|
|
raw = response.content
|
|
mime = detect_image_mime(raw)
|
|
if mime is None:
|
|
raise ImageGenerationError("generated image URL did not return a supported image")
|
|
encoded = base64.b64encode(raw).decode("ascii")
|
|
return f"data:{mime};base64,{encoded}"
|
|
|
|
|
|
class OpenRouterImageGenerationClient:
|
|
"""Small async client for OpenRouter Chat Completions image generation."""
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
api_key: str | None,
|
|
api_base: str | None = None,
|
|
extra_headers: dict[str, str] | None = None,
|
|
extra_body: dict[str, Any] | None = None,
|
|
timeout: float = _DEFAULT_TIMEOUT_S,
|
|
client: httpx.AsyncClient | None = None,
|
|
) -> None:
|
|
self.api_key = api_key
|
|
self.api_base = _provider_base_url(
|
|
"openrouter",
|
|
api_base,
|
|
"https://openrouter.ai/api/v1",
|
|
)
|
|
self.extra_headers = extra_headers or {}
|
|
self.extra_body = extra_body or {}
|
|
self.timeout = timeout
|
|
self._client = client
|
|
|
|
async def generate(
|
|
self,
|
|
*,
|
|
prompt: str,
|
|
model: str,
|
|
reference_images: list[str] | None = None,
|
|
aspect_ratio: str | None = None,
|
|
image_size: str | None = None,
|
|
) -> GeneratedImageResponse:
|
|
if not self.api_key:
|
|
raise ImageGenerationError(
|
|
"OpenRouter API key is not configured. Set providers.openrouter.apiKey."
|
|
)
|
|
|
|
content: str | list[dict[str, Any]]
|
|
references = list(reference_images or [])
|
|
if references:
|
|
blocks: list[dict[str, Any]] = [{"type": "text", "text": prompt}]
|
|
blocks.extend(
|
|
{"type": "image_url", "image_url": {"url": image_path_to_data_url(path)}}
|
|
for path in references
|
|
)
|
|
content = blocks
|
|
else:
|
|
content = prompt
|
|
|
|
body: dict[str, Any] = {
|
|
"model": model,
|
|
"messages": [{"role": "user", "content": content}],
|
|
"modalities": ["image", "text"],
|
|
"stream": False,
|
|
}
|
|
image_config: dict[str, str] = {}
|
|
if aspect_ratio:
|
|
image_config["aspect_ratio"] = aspect_ratio
|
|
if image_size:
|
|
image_config["image_size"] = image_size
|
|
if image_config:
|
|
body["image_config"] = image_config
|
|
body.update(self.extra_body)
|
|
|
|
headers = {
|
|
"Authorization": f"Bearer {self.api_key}",
|
|
"Content-Type": "application/json",
|
|
**_OPENROUTER_ATTRIBUTION_HEADERS,
|
|
**self.extra_headers,
|
|
}
|
|
url = f"{self.api_base}/chat/completions"
|
|
|
|
if self._client is not None:
|
|
response = await self._client.post(url, headers=headers, json=body)
|
|
else:
|
|
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
|
response = await client.post(url, headers=headers, json=body)
|
|
|
|
try:
|
|
response.raise_for_status()
|
|
except httpx.HTTPStatusError as exc:
|
|
detail = response.text[:500]
|
|
raise ImageGenerationError(f"OpenRouter image generation failed: {detail}") from exc
|
|
|
|
data = response.json()
|
|
images: list[str] = []
|
|
text_parts: list[str] = []
|
|
for choice in data.get("choices") or []:
|
|
if not isinstance(choice, dict):
|
|
continue
|
|
message = choice.get("message") or {}
|
|
if isinstance(message.get("content"), str):
|
|
text_parts.append(message["content"])
|
|
for image in message.get("images") or []:
|
|
if not isinstance(image, dict):
|
|
continue
|
|
image_url = image.get("image_url") or image.get("imageUrl") or {}
|
|
url_value = image_url.get("url") if isinstance(image_url, dict) else None
|
|
if isinstance(url_value, str) and url_value.startswith("data:image/"):
|
|
images.append(url_value)
|
|
|
|
if not images:
|
|
provider_error = data.get("error") if isinstance(data, dict) else None
|
|
if provider_error:
|
|
raise ImageGenerationError(f"OpenRouter returned no images: {provider_error}")
|
|
raise ImageGenerationError("OpenRouter returned no images for this request")
|
|
|
|
return GeneratedImageResponse(
|
|
images=images,
|
|
content="\n".join(part for part in text_parts if part).strip(),
|
|
raw=data,
|
|
)
|
|
|
|
|
|
class AIHubMixImageGenerationClient:
|
|
"""Small async client for AIHubMix unified image generation."""
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
api_key: str | None,
|
|
api_base: str | None = None,
|
|
extra_headers: dict[str, str] | None = None,
|
|
extra_body: dict[str, Any] | None = None,
|
|
timeout: float = _AIHUBMIX_TIMEOUT_S,
|
|
client: httpx.AsyncClient | None = None,
|
|
) -> None:
|
|
self.api_key = api_key
|
|
self.api_base = _provider_base_url(
|
|
"aihubmix",
|
|
api_base,
|
|
"https://aihubmix.com/v1",
|
|
)
|
|
self.extra_headers = extra_headers or {}
|
|
self.extra_body = extra_body or {}
|
|
self.timeout = timeout
|
|
self._client = client
|
|
|
|
async def generate(
|
|
self,
|
|
*,
|
|
prompt: str,
|
|
model: str,
|
|
reference_images: list[str] | None = None,
|
|
aspect_ratio: str | None = None,
|
|
image_size: str | None = None,
|
|
) -> GeneratedImageResponse:
|
|
if not self.api_key:
|
|
raise ImageGenerationError(
|
|
"AIHubMix API key is not configured. Set providers.aihubmix.apiKey."
|
|
)
|
|
|
|
refs = list(reference_images or [])
|
|
headers = {
|
|
"Authorization": f"Bearer {self.api_key}",
|
|
**self.extra_headers,
|
|
}
|
|
size = _aihubmix_size(aspect_ratio, image_size)
|
|
|
|
if self._client is not None:
|
|
return await self._generate_with_client(
|
|
self._client,
|
|
prompt=prompt,
|
|
model=model,
|
|
reference_images=refs,
|
|
size=size,
|
|
headers=headers,
|
|
)
|
|
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
|
return await self._generate_with_client(
|
|
client,
|
|
prompt=prompt,
|
|
model=model,
|
|
reference_images=refs,
|
|
size=size,
|
|
headers=headers,
|
|
)
|
|
|
|
async def _generate_with_client(
|
|
self,
|
|
client: httpx.AsyncClient,
|
|
*,
|
|
prompt: str,
|
|
model: str,
|
|
reference_images: list[str],
|
|
size: str,
|
|
headers: dict[str, str],
|
|
) -> GeneratedImageResponse:
|
|
image_input: str | list[str] | None = None
|
|
if reference_images:
|
|
image_refs = [image_path_to_data_url(path) for path in reference_images]
|
|
image_input = image_refs[0] if len(image_refs) == 1 else image_refs
|
|
|
|
input_body: dict[str, Any] = {
|
|
"prompt": prompt,
|
|
"n": 1,
|
|
"size": size,
|
|
}
|
|
if image_input is not None:
|
|
input_body["image"] = image_input
|
|
input_body.update(self.extra_body)
|
|
|
|
body = {"input": input_body}
|
|
model_path = _aihubmix_model_path(model)
|
|
url = f"{self.api_base}/models/{model_path}/predictions"
|
|
try:
|
|
response = await client.post(
|
|
url,
|
|
headers={**headers, "Content-Type": "application/json"},
|
|
json=body,
|
|
)
|
|
except httpx.TimeoutException as exc:
|
|
raise ImageGenerationError("AIHubMix image generation timed out") from exc
|
|
except httpx.RequestError as exc:
|
|
raise ImageGenerationError(f"AIHubMix image generation request failed: {exc}") from exc
|
|
|
|
try:
|
|
response.raise_for_status()
|
|
except httpx.HTTPStatusError as exc:
|
|
detail = response.text[:500]
|
|
raise ImageGenerationError(f"AIHubMix image generation failed: {detail}") from exc
|
|
|
|
payload = response.json()
|
|
images = await _aihubmix_images_from_payload(client, payload)
|
|
|
|
if not images:
|
|
provider_error = payload.get("error") if isinstance(payload, dict) else None
|
|
if provider_error:
|
|
raise ImageGenerationError(f"AIHubMix returned no images: {provider_error}")
|
|
raise ImageGenerationError("AIHubMix returned no images for this request")
|
|
|
|
return GeneratedImageResponse(images=images, content="", raw=payload)
|
|
|
|
|
|
async def _aihubmix_images_from_payload(
|
|
client: httpx.AsyncClient,
|
|
payload: dict[str, Any],
|
|
) -> list[str]:
|
|
images: list[str] = []
|
|
candidates: list[Any] = []
|
|
if "data" in payload:
|
|
candidates.append(payload["data"])
|
|
if "output" in payload:
|
|
candidates.append(payload["output"])
|
|
|
|
async def collect(value: Any) -> None:
|
|
if isinstance(value, list):
|
|
for item in value:
|
|
await collect(item)
|
|
return
|
|
if isinstance(value, str):
|
|
if value.startswith("data:image/"):
|
|
images.append(value)
|
|
elif value.startswith(("http://", "https://")):
|
|
images.append(await _download_image_data_url(client, value))
|
|
return
|
|
if not isinstance(value, dict):
|
|
return
|
|
|
|
b64_json = value.get("b64_json")
|
|
if isinstance(b64_json, str) and b64_json:
|
|
images.append(_b64_png_data_url(b64_json))
|
|
elif b64_json is not None:
|
|
await collect(b64_json)
|
|
|
|
bytes_base64 = value.get("bytesBase64") or value.get("bytes_base64") or value.get("base64")
|
|
if isinstance(bytes_base64, str) and bytes_base64:
|
|
images.append(_b64_png_data_url(bytes_base64))
|
|
|
|
image_url = value.get("image_url") or value.get("imageUrl")
|
|
if isinstance(image_url, dict):
|
|
await collect(image_url.get("url"))
|
|
elif image_url is not None:
|
|
await collect(image_url)
|
|
|
|
url_value = value.get("url")
|
|
if url_value is not None:
|
|
await collect(url_value)
|
|
|
|
for key in ("images", "image", "output"):
|
|
if key in value:
|
|
await collect(value[key])
|
|
|
|
for candidate in candidates:
|
|
await collect(candidate)
|
|
return images
|
|
|
|
|
|
_MINIMAX_TIMEOUT_S = 300.0
|
|
|
|
_MINIMAX_ASPECT_RATIO_SIZES = {
|
|
"1:1": "1:1",
|
|
"16:9": "16:9",
|
|
"4:3": "4:3",
|
|
"3:2": "3:2",
|
|
"2:3": "2:3",
|
|
"3:4": "3:4",
|
|
"9:16": "9:16",
|
|
"21:9": "21:9",
|
|
}
|
|
|
|
|
|
class MiniMaxImageGenerationClient:
|
|
"""Async client for MiniMax image generation API."""
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
api_key: str | None,
|
|
api_base: str | None = None,
|
|
extra_headers: dict[str, str] | None = None,
|
|
extra_body: dict[str, Any] | None = None,
|
|
timeout: float = _MINIMAX_TIMEOUT_S,
|
|
client: httpx.AsyncClient | None = None,
|
|
) -> None:
|
|
self.api_key = api_key
|
|
self.api_base = _provider_base_url(
|
|
"minimax",
|
|
api_base,
|
|
"https://api.minimaxi.com/v1",
|
|
)
|
|
self.extra_headers = extra_headers or {}
|
|
self.extra_body = extra_body or {}
|
|
self.timeout = timeout
|
|
self._client = client
|
|
|
|
def _resolve_aspect_ratio(self, aspect_ratio: str | None) -> str:
|
|
if aspect_ratio and aspect_ratio in _MINIMAX_ASPECT_RATIO_SIZES:
|
|
return _MINIMAX_ASPECT_RATIO_SIZES[aspect_ratio]
|
|
return "1:1"
|
|
|
|
async def generate(
|
|
self,
|
|
*,
|
|
prompt: str,
|
|
model: str,
|
|
reference_images: list[str] | None = None,
|
|
aspect_ratio: str | None = None,
|
|
image_size: str | None = None,
|
|
) -> GeneratedImageResponse:
|
|
if not self.api_key:
|
|
raise ImageGenerationError(
|
|
"MiniMax API key is not configured. Set providers.minimax.apiKey."
|
|
)
|
|
|
|
headers = {
|
|
"Authorization": f"Bearer {self.api_key}",
|
|
"Content-Type": "application/json",
|
|
**self.extra_headers,
|
|
}
|
|
|
|
body: dict[str, Any] = {
|
|
"model": model,
|
|
"prompt": prompt,
|
|
"response_format": "base64",
|
|
}
|
|
|
|
resolved_ratio = self._resolve_aspect_ratio(aspect_ratio)
|
|
body["aspect_ratio"] = resolved_ratio
|
|
|
|
refs = list(reference_images or [])
|
|
if refs:
|
|
image_refs = [image_path_to_data_url(path) for path in refs]
|
|
body["subject_reference"] = [
|
|
{"type": "character", "image_file": ref} for ref in image_refs
|
|
]
|
|
|
|
body.update(self.extra_body)
|
|
|
|
if self._client is not None:
|
|
return await self._generate_with_client(self._client, body, headers)
|
|
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
|
return await self._generate_with_client(client, body, headers)
|
|
|
|
async def _generate_with_client(
|
|
self,
|
|
client: httpx.AsyncClient,
|
|
body: dict[str, Any],
|
|
headers: dict[str, str],
|
|
) -> GeneratedImageResponse:
|
|
url = f"{self.api_base}/image_generation"
|
|
try:
|
|
response = await client.post(url, headers=headers, json=body)
|
|
except httpx.TimeoutException as exc:
|
|
raise ImageGenerationError("MiniMax image generation timed out") from exc
|
|
except httpx.RequestError as exc:
|
|
raise ImageGenerationError(f"MiniMax image generation request failed: {exc}") from exc
|
|
|
|
try:
|
|
response.raise_for_status()
|
|
except httpx.HTTPStatusError as exc:
|
|
detail = response.text[:500]
|
|
raise ImageGenerationError(f"MiniMax image generation failed: {detail}") from exc
|
|
|
|
payload = response.json()
|
|
images = _minimax_images_from_payload(payload)
|
|
|
|
if not images:
|
|
provider_error = payload.get("error") if isinstance(payload, dict) else None
|
|
if provider_error:
|
|
raise ImageGenerationError(f"MiniMax returned no images: {provider_error}")
|
|
raise ImageGenerationError("MiniMax returned no images for this request")
|
|
|
|
return GeneratedImageResponse(images=images, content="", raw=payload)
|
|
|
|
|
|
def _minimax_images_from_payload(payload: dict[str, Any]) -> list[str]:
|
|
"""Extract base64 images from MiniMax API response.
|
|
|
|
MiniMax returns images in ``data.image_base64`` (list of base64 strings).
|
|
"""
|
|
images: list[str] = []
|
|
data = payload.get("data")
|
|
if not isinstance(data, dict):
|
|
return images
|
|
for b64 in data.get("image_base64") or []:
|
|
if isinstance(b64, str) and b64:
|
|
images.append(_b64_png_data_url(b64))
|
|
return images
|