mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-24 18:42:35 +00:00
1460 lines
48 KiB
Python
1460 lines
48 KiB
Python
"""Image generation provider helpers."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import base64
|
|
import binascii
|
|
import re
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import httpx
|
|
from loguru import logger
|
|
|
|
from nanobot.providers.registry import find_by_name
|
|
from nanobot.utils.helpers import detect_image_mime
|
|
|
|
_OPENROUTER_ATTRIBUTION_HEADERS = {
|
|
"HTTP-Referer": "https://github.com/HKUDS/nanobot",
|
|
"X-OpenRouter-Title": "nanobot",
|
|
"X-OpenRouter-Categories": "cli-agent,personal-agent",
|
|
}
|
|
_DEFAULT_TIMEOUT_S = 120.0
|
|
_AIHUBMIX_TIMEOUT_S = 300.0
|
|
_AIHUBMIX_ASPECT_RATIO_SIZES = {
|
|
"1:1": "1024x1024",
|
|
"3:4": "1024x1536",
|
|
"9:16": "1024x1536",
|
|
"4:3": "1536x1024",
|
|
"16:9": "1536x1024",
|
|
}
|
|
_GEMINI_DEFAULT_TIMEOUT_S = 120.0
|
|
_GEMINI_IMAGEN_ASPECT_RATIOS = {"1:1", "9:16", "16:9", "3:4", "4:3"}
|
|
_OLLAMA_DEFAULT_SIDE = 1024
|
|
_OLLAMA_SIZE_PRESETS = {
|
|
"1K": 1024,
|
|
"2K": 2048,
|
|
"4K": 4096,
|
|
}
|
|
_OLLAMA_EXPLICIT_SIZE_RE = re.compile(r"^\s*(\d+)\s*[xX]\s*(\d+)\s*$")
|
|
_OLLAMA_ASPECT_RATIO_RE = re.compile(r"^\s*(\d+)\s*:\s*(\d+)\s*$")
|
|
|
|
|
|
class ImageGenerationError(RuntimeError):
|
|
"""Raised when the image generation provider cannot return images."""
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class GeneratedImageResponse:
|
|
"""Images and optional text returned by the provider."""
|
|
|
|
images: list[str]
|
|
content: str
|
|
raw: dict[str, Any]
|
|
|
|
|
|
def _read_image_b64(path: str | Path) -> tuple[str, str]:
|
|
"""Return ``(mime, base64)`` for the image at ``path``."""
|
|
p = Path(path).expanduser()
|
|
raw = p.read_bytes()
|
|
mime = detect_image_mime(raw)
|
|
if mime is None:
|
|
raise ImageGenerationError(f"unsupported reference image: {p}")
|
|
return mime, base64.b64encode(raw).decode("ascii")
|
|
|
|
|
|
def image_path_to_data_url(path: str | Path) -> str:
|
|
"""Convert a local image path to an image data URL."""
|
|
mime, encoded = _read_image_b64(path)
|
|
return f"data:{mime};base64,{encoded}"
|
|
|
|
|
|
def image_path_to_inline_data(path: str | Path) -> dict[str, str]:
|
|
"""Convert a local image path to a Gemini ``inlineData`` payload dict."""
|
|
mime, encoded = _read_image_b64(path)
|
|
return {"mimeType": mime, "data": encoded}
|
|
|
|
|
|
def _b64_image_data_url(value: str) -> str:
|
|
encoded = "".join(value.split())
|
|
try:
|
|
raw = base64.b64decode(encoded, validate=True)
|
|
except binascii.Error as exc:
|
|
raise ImageGenerationError("generated image payload was not valid base64") from exc
|
|
mime = detect_image_mime(raw)
|
|
if mime is None:
|
|
raise ImageGenerationError("generated image payload was not a supported image")
|
|
return f"data:{mime};base64,{encoded}"
|
|
|
|
|
|
def _aihubmix_size(aspect_ratio: str | None, image_size: str | None) -> str:
|
|
"""Return an OpenAI Images API size string for AIHubMix.
|
|
|
|
The WebUI emits compact size hints like ``1K`` for OpenRouter. AIHubMix's
|
|
Images API expects OpenAI-style dimensions or ``auto``, so only pass
|
|
through explicit dimension strings and otherwise derive the closest
|
|
supported orientation from aspect ratio.
|
|
"""
|
|
if image_size and "x" in image_size.lower():
|
|
return image_size
|
|
if aspect_ratio in _AIHUBMIX_ASPECT_RATIO_SIZES:
|
|
return _AIHUBMIX_ASPECT_RATIO_SIZES[aspect_ratio]
|
|
return "auto"
|
|
|
|
|
|
def _aihubmix_model_path(model: str) -> str:
|
|
if "/" in model:
|
|
return model
|
|
if model.startswith(("gpt-image-", "dall-e-")):
|
|
return f"openai/{model}"
|
|
return model
|
|
|
|
|
|
async def _download_image_data_url(
|
|
client: httpx.AsyncClient,
|
|
url: str,
|
|
) -> str:
|
|
response = await client.get(url)
|
|
try:
|
|
response.raise_for_status()
|
|
except httpx.HTTPStatusError as exc:
|
|
detail = response.text[:500]
|
|
raise ImageGenerationError(f"failed to download generated image: {detail}") from exc
|
|
raw = response.content
|
|
mime = detect_image_mime(raw)
|
|
if mime is None:
|
|
raise ImageGenerationError("generated image URL did not return a supported image")
|
|
encoded = base64.b64encode(raw).decode("ascii")
|
|
return f"data:{mime};base64,{encoded}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Registry
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_IMAGE_GEN_PROVIDERS: dict[str, type[ImageGenerationProvider]] = {}
|
|
|
|
|
|
def register_image_gen_provider(cls: type[ImageGenerationProvider]) -> None:
|
|
"""Register an image provider at import time only.
|
|
|
|
The registry is populated by module side effects so provider discovery
|
|
stays lazy and consistent across the process.
|
|
"""
|
|
name = cls.provider_name
|
|
if not name:
|
|
raise ValueError(f"{cls.__name__} must set provider_name")
|
|
_IMAGE_GEN_PROVIDERS[name] = cls
|
|
|
|
|
|
def get_image_gen_provider(name: str) -> type[ImageGenerationProvider] | None:
|
|
return _IMAGE_GEN_PROVIDERS.get(name)
|
|
|
|
|
|
def image_gen_provider_names() -> tuple[str, ...]:
|
|
"""Return registered image generation provider names in registry order."""
|
|
return tuple(_IMAGE_GEN_PROVIDERS)
|
|
|
|
|
|
def image_gen_provider_configs(config: Any) -> dict[str, Any]:
|
|
providers_cfg = config.providers
|
|
return {
|
|
name: pc
|
|
for name in _IMAGE_GEN_PROVIDERS
|
|
if (pc := getattr(providers_cfg, name, None)) is not None
|
|
}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Base class
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class ImageGenerationProvider(ABC):
|
|
"""Base class for image generation provider clients."""
|
|
|
|
provider_name: str = ""
|
|
missing_key_message: str = ""
|
|
default_timeout: float = _DEFAULT_TIMEOUT_S
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
api_key: str | None,
|
|
api_base: str | None = None,
|
|
extra_headers: dict[str, str] | None = None,
|
|
extra_body: dict[str, Any] | None = None,
|
|
timeout: float | None = None,
|
|
client: httpx.AsyncClient | None = None,
|
|
) -> None:
|
|
self.api_key = api_key
|
|
self.api_base = self._resolve_base_url(api_base)
|
|
self.extra_headers = extra_headers or {}
|
|
self.extra_body = extra_body or {}
|
|
self.timeout = timeout if timeout is not None else self.default_timeout
|
|
self._client = client
|
|
|
|
def _resolve_base_url(self, api_base: str | None) -> str:
|
|
if api_base:
|
|
return api_base.rstrip("/")
|
|
spec = find_by_name(self.provider_name)
|
|
if spec and spec.default_api_base:
|
|
return spec.default_api_base.rstrip("/")
|
|
return self._default_base_url()
|
|
|
|
def _default_base_url(self) -> str:
|
|
return ""
|
|
|
|
@abstractmethod
|
|
async def generate(
|
|
self,
|
|
*,
|
|
prompt: str,
|
|
model: str,
|
|
reference_images: list[str] | None = None,
|
|
aspect_ratio: str | None = None,
|
|
image_size: str | None = None,
|
|
) -> GeneratedImageResponse: ...
|
|
|
|
def _require_images(self, images: list[str], data: dict[str, Any]) -> None:
|
|
if images:
|
|
return
|
|
provider_error = data.get("error") if isinstance(data, dict) else None
|
|
label = self.provider_name
|
|
if provider_error:
|
|
raise ImageGenerationError(f"{label} returned no images: {provider_error}")
|
|
raise ImageGenerationError(f"{label} returned no images for this request")
|
|
|
|
async def _http_post(
|
|
self,
|
|
url: str,
|
|
*,
|
|
headers: dict[str, str],
|
|
body: dict[str, Any],
|
|
client: httpx.AsyncClient | None = None,
|
|
) -> httpx.Response:
|
|
if client is not None:
|
|
return await client.post(url, headers=headers, json=body)
|
|
if self._client is not None:
|
|
return await self._client.post(url, headers=headers, json=body)
|
|
async with httpx.AsyncClient(timeout=self.timeout) as c:
|
|
return await c.post(url, headers=headers, json=body)
|
|
|
|
|
|
class OpenRouterImageGenerationClient(ImageGenerationProvider):
|
|
"""Small async client for OpenRouter Chat Completions image generation."""
|
|
|
|
provider_name = "openrouter"
|
|
missing_key_message = (
|
|
"OpenRouter API key is not configured. Set providers.openrouter.apiKey."
|
|
)
|
|
|
|
def _default_base_url(self) -> str:
|
|
return "https://openrouter.ai/api/v1"
|
|
|
|
async def generate(
|
|
self,
|
|
*,
|
|
prompt: str,
|
|
model: str,
|
|
reference_images: list[str] | None = None,
|
|
aspect_ratio: str | None = None,
|
|
image_size: str | None = None,
|
|
) -> GeneratedImageResponse:
|
|
if not self.api_key:
|
|
raise ImageGenerationError(self.missing_key_message)
|
|
|
|
content: str | list[dict[str, Any]]
|
|
references = list(reference_images or [])
|
|
if references:
|
|
blocks: list[dict[str, Any]] = [{"type": "text", "text": prompt}]
|
|
blocks.extend(
|
|
{"type": "image_url", "image_url": {"url": image_path_to_data_url(path)}}
|
|
for path in references
|
|
)
|
|
content = blocks
|
|
else:
|
|
content = prompt
|
|
|
|
body: dict[str, Any] = {
|
|
"model": model,
|
|
"messages": [{"role": "user", "content": content}],
|
|
"modalities": ["image", "text"],
|
|
"stream": False,
|
|
}
|
|
image_config: dict[str, str] = {}
|
|
if aspect_ratio:
|
|
image_config["aspect_ratio"] = aspect_ratio
|
|
if image_size:
|
|
image_config["image_size"] = image_size
|
|
if image_config:
|
|
body["image_config"] = image_config
|
|
body.update(self.extra_body)
|
|
|
|
headers = {
|
|
"Authorization": f"Bearer {self.api_key}",
|
|
"Content-Type": "application/json",
|
|
**_OPENROUTER_ATTRIBUTION_HEADERS,
|
|
**self.extra_headers,
|
|
}
|
|
url = f"{self.api_base}/chat/completions"
|
|
response = await self._http_post(url, headers=headers, body=body)
|
|
|
|
try:
|
|
response.raise_for_status()
|
|
except httpx.HTTPStatusError as exc:
|
|
detail = response.text[:500]
|
|
raise ImageGenerationError(f"OpenRouter image generation failed: {detail}") from exc
|
|
|
|
data = response.json()
|
|
images: list[str] = []
|
|
text_parts: list[str] = []
|
|
for choice in data.get("choices") or []:
|
|
if not isinstance(choice, dict):
|
|
continue
|
|
message = choice.get("message") or {}
|
|
if isinstance(message.get("content"), str):
|
|
text_parts.append(message["content"])
|
|
for image in message.get("images") or []:
|
|
if not isinstance(image, dict):
|
|
continue
|
|
image_url = image.get("image_url") or image.get("imageUrl") or {}
|
|
url_value = image_url.get("url") if isinstance(image_url, dict) else None
|
|
if isinstance(url_value, str) and url_value.startswith("data:image/"):
|
|
images.append(url_value)
|
|
|
|
self._require_images(images, data)
|
|
|
|
return GeneratedImageResponse(
|
|
images=images,
|
|
content="\n".join(part for part in text_parts if part).strip(),
|
|
raw=data,
|
|
)
|
|
|
|
|
|
class AIHubMixImageGenerationClient(ImageGenerationProvider):
|
|
"""Small async client for AIHubMix unified image generation."""
|
|
|
|
provider_name = "aihubmix"
|
|
missing_key_message = (
|
|
"AIHubMix API key is not configured. Set providers.aihubmix.apiKey."
|
|
)
|
|
default_timeout = _AIHUBMIX_TIMEOUT_S
|
|
|
|
def _default_base_url(self) -> str:
|
|
return "https://aihubmix.com/v1"
|
|
|
|
async def generate(
|
|
self,
|
|
*,
|
|
prompt: str,
|
|
model: str,
|
|
reference_images: list[str] | None = None,
|
|
aspect_ratio: str | None = None,
|
|
image_size: str | None = None,
|
|
) -> GeneratedImageResponse:
|
|
if not self.api_key:
|
|
raise ImageGenerationError(self.missing_key_message)
|
|
|
|
refs = list(reference_images or [])
|
|
headers = {
|
|
"Authorization": f"Bearer {self.api_key}",
|
|
**self.extra_headers,
|
|
}
|
|
size = _aihubmix_size(aspect_ratio, image_size)
|
|
|
|
client = self._client or httpx.AsyncClient(timeout=self.timeout)
|
|
try:
|
|
return await self._generate_with_client(
|
|
client,
|
|
prompt=prompt,
|
|
model=model,
|
|
reference_images=refs,
|
|
size=size,
|
|
headers=headers,
|
|
)
|
|
finally:
|
|
if self._client is None:
|
|
await client.aclose()
|
|
|
|
async def _generate_with_client(
|
|
self,
|
|
client: httpx.AsyncClient,
|
|
*,
|
|
prompt: str,
|
|
model: str,
|
|
reference_images: list[str],
|
|
size: str,
|
|
headers: dict[str, str],
|
|
) -> GeneratedImageResponse:
|
|
image_input: str | list[str] | None = None
|
|
if reference_images:
|
|
image_refs = [image_path_to_data_url(path) for path in reference_images]
|
|
image_input = image_refs[0] if len(image_refs) == 1 else image_refs
|
|
|
|
input_body: dict[str, Any] = {
|
|
"prompt": prompt,
|
|
"n": 1,
|
|
"size": size,
|
|
}
|
|
if image_input is not None:
|
|
input_body["image"] = image_input
|
|
input_body.update(self.extra_body)
|
|
|
|
body = {"input": input_body}
|
|
model_path = _aihubmix_model_path(model)
|
|
url = f"{self.api_base}/models/{model_path}/predictions"
|
|
try:
|
|
response = await self._http_post(
|
|
url,
|
|
headers={**headers, "Content-Type": "application/json"},
|
|
body=body,
|
|
client=client,
|
|
)
|
|
except httpx.TimeoutException as exc:
|
|
raise ImageGenerationError("AIHubMix image generation timed out") from exc
|
|
except httpx.RequestError as exc:
|
|
raise ImageGenerationError(f"AIHubMix image generation request failed: {exc}") from exc
|
|
|
|
try:
|
|
response.raise_for_status()
|
|
except httpx.HTTPStatusError as exc:
|
|
detail = response.text[:500]
|
|
raise ImageGenerationError(f"AIHubMix image generation failed: {detail}") from exc
|
|
|
|
payload = response.json()
|
|
images = await _aihubmix_images_from_payload(client, payload)
|
|
|
|
self._require_images(images, payload)
|
|
|
|
return GeneratedImageResponse(images=images, content="", raw=payload)
|
|
|
|
|
|
def _http_error_detail(response: httpx.Response) -> str:
|
|
"""Extract a readable error message from an HTTP error response."""
|
|
try:
|
|
data = response.json()
|
|
if isinstance(data, dict):
|
|
err = data.get("error")
|
|
if isinstance(err, dict):
|
|
return err.get("message") or str(err)
|
|
if err:
|
|
return str(err)
|
|
except Exception:
|
|
pass
|
|
return response.text[:500] or "<empty response body>"
|
|
|
|
|
|
def _round_to_multiple(value: float, multiple: int = 8) -> int:
|
|
rounded = int(round(value / multiple) * multiple)
|
|
return max(multiple, rounded)
|
|
|
|
|
|
def _ollama_dimensions(aspect_ratio: str | None, image_size: str | None) -> tuple[int, int]:
|
|
if image_size:
|
|
size = image_size.strip()
|
|
explicit = _OLLAMA_EXPLICIT_SIZE_RE.fullmatch(size)
|
|
if explicit:
|
|
return int(explicit.group(1)), int(explicit.group(2))
|
|
long_side = _OLLAMA_SIZE_PRESETS.get(size.upper(), _OLLAMA_DEFAULT_SIDE)
|
|
else:
|
|
long_side = _OLLAMA_DEFAULT_SIDE
|
|
|
|
if not aspect_ratio:
|
|
return long_side, long_side
|
|
|
|
ratio = _OLLAMA_ASPECT_RATIO_RE.fullmatch(aspect_ratio.strip())
|
|
if ratio is None:
|
|
return long_side, long_side
|
|
|
|
width_ratio = int(ratio.group(1))
|
|
height_ratio = int(ratio.group(2))
|
|
if width_ratio <= 0 or height_ratio <= 0:
|
|
return long_side, long_side
|
|
|
|
if width_ratio >= height_ratio:
|
|
width = long_side
|
|
height = _round_to_multiple(long_side * height_ratio / width_ratio)
|
|
else:
|
|
height = long_side
|
|
width = _round_to_multiple(long_side * width_ratio / height_ratio)
|
|
return max(8, width), max(8, height)
|
|
|
|
|
|
def _ollama_image_data_url(value: str) -> str:
|
|
if value.startswith("data:image/"):
|
|
return value
|
|
return _b64_image_data_url(value)
|
|
|
|
|
|
def _ollama_images_from_payload(payload: dict[str, Any]) -> list[str]:
|
|
images: list[str] = []
|
|
|
|
def collect(value: Any) -> None:
|
|
if isinstance(value, str) and value:
|
|
images.append(_ollama_image_data_url(value))
|
|
elif isinstance(value, list):
|
|
for item in value:
|
|
collect(item)
|
|
|
|
collect(payload.get("image"))
|
|
collect(payload.get("images"))
|
|
return images
|
|
|
|
|
|
class OllamaImageGenerationClient(ImageGenerationProvider):
|
|
"""Async client for Ollama native image generation models."""
|
|
|
|
provider_name = "ollama"
|
|
default_timeout = 300.0
|
|
|
|
def _default_base_url(self) -> str:
|
|
return "http://localhost:11434/api"
|
|
|
|
def _resolve_base_url(self, api_base: str | None) -> str:
|
|
if api_base:
|
|
base = api_base.rstrip("/")
|
|
if base.endswith("/v1"):
|
|
return f"{base[:-3]}/api"
|
|
return base
|
|
return self._default_base_url()
|
|
|
|
async def generate(
|
|
self,
|
|
*,
|
|
prompt: str,
|
|
model: str,
|
|
reference_images: list[str] | None = None,
|
|
aspect_ratio: str | None = None,
|
|
image_size: str | None = None,
|
|
) -> GeneratedImageResponse:
|
|
if reference_images:
|
|
raise ImageGenerationError(
|
|
"Ollama image generation does not support reference images"
|
|
)
|
|
|
|
width, height = _ollama_dimensions(aspect_ratio, image_size)
|
|
body: dict[str, Any] = {
|
|
"model": model,
|
|
"prompt": prompt,
|
|
"width": width,
|
|
"height": height,
|
|
"steps": 0,
|
|
}
|
|
body.update(self.extra_body)
|
|
body["stream"] = False
|
|
|
|
headers = {
|
|
"Content-Type": "application/json",
|
|
**self.extra_headers,
|
|
}
|
|
if self.api_key:
|
|
headers["Authorization"] = f"Bearer {self.api_key}"
|
|
|
|
url = f"{self.api_base}/generate"
|
|
response = await self._http_post(url, headers=headers, body=body)
|
|
|
|
try:
|
|
response.raise_for_status()
|
|
except httpx.HTTPStatusError as exc:
|
|
detail = _http_error_detail(response)
|
|
logger.error(
|
|
"Ollama image generation failed (HTTP {}): {}",
|
|
response.status_code,
|
|
detail,
|
|
)
|
|
raise ImageGenerationError(
|
|
f"Ollama image generation failed (HTTP {response.status_code}): {detail}"
|
|
) from exc
|
|
|
|
data = response.json()
|
|
images = _ollama_images_from_payload(data)
|
|
|
|
self._require_images(images, data)
|
|
|
|
response_text = data.get("response")
|
|
content = response_text if isinstance(response_text, str) else ""
|
|
|
|
return GeneratedImageResponse(images=images, content=content, raw=data)
|
|
|
|
|
|
class GeminiImageGenerationClient(ImageGenerationProvider):
|
|
"""Async client for Gemini/Imagen image generation via the Generative Language API."""
|
|
|
|
provider_name = "gemini"
|
|
missing_key_message = (
|
|
"Gemini API key is not configured. Set providers.gemini.apiKey."
|
|
)
|
|
default_timeout = _GEMINI_DEFAULT_TIMEOUT_S
|
|
|
|
def _default_base_url(self) -> str:
|
|
return "https://generativelanguage.googleapis.com/v1beta"
|
|
|
|
def _resolve_base_url(self, api_base: str | None) -> str:
|
|
# Gemini chat completions use the registry's OpenAI-compatible shim.
|
|
# Image generation must hit the native Generative Language API, so we
|
|
# intentionally bypass the shared registry lookup here.
|
|
if api_base:
|
|
return api_base.rstrip("/")
|
|
return self._default_base_url()
|
|
|
|
async def generate(
|
|
self,
|
|
*,
|
|
prompt: str,
|
|
model: str,
|
|
reference_images: list[str] | None = None,
|
|
aspect_ratio: str | None = None,
|
|
image_size: str | None = None,
|
|
) -> GeneratedImageResponse:
|
|
if not self.api_key:
|
|
raise ImageGenerationError(self.missing_key_message)
|
|
if "imagen" in model.lower():
|
|
if reference_images:
|
|
logger.warning(
|
|
"Imagen models do not support reference images; "
|
|
"ignoring {} reference image(s) for {}",
|
|
len(reference_images),
|
|
model,
|
|
)
|
|
return await self._generate_imagen(
|
|
prompt=prompt, model=model, aspect_ratio=aspect_ratio
|
|
)
|
|
return await self._generate_gemini_flash(
|
|
prompt=prompt, model=model, reference_images=reference_images or []
|
|
)
|
|
|
|
async def _generate_imagen(
|
|
self,
|
|
*,
|
|
prompt: str,
|
|
model: str,
|
|
aspect_ratio: str | None,
|
|
) -> GeneratedImageResponse:
|
|
parameters: dict[str, Any] = {"sampleCount": 1}
|
|
if aspect_ratio in _GEMINI_IMAGEN_ASPECT_RATIOS:
|
|
parameters["aspectRatio"] = aspect_ratio
|
|
body: dict[str, Any] = {
|
|
"instances": [{"prompt": prompt}],
|
|
"parameters": parameters,
|
|
}
|
|
body.update(self.extra_body)
|
|
|
|
url = f"{self.api_base}/models/{model}:predict"
|
|
headers = {
|
|
"x-goog-api-key": self.api_key or "",
|
|
"Content-Type": "application/json",
|
|
**self.extra_headers,
|
|
}
|
|
response = await self._http_post(url, headers=headers, body=body)
|
|
|
|
try:
|
|
response.raise_for_status()
|
|
except httpx.HTTPStatusError as exc:
|
|
detail = _http_error_detail(response)
|
|
logger.error("Gemini Imagen generation failed (HTTP {}): {}", response.status_code, detail)
|
|
raise ImageGenerationError(
|
|
f"Gemini Imagen generation failed (HTTP {response.status_code}): {detail}"
|
|
) from exc
|
|
|
|
data = response.json()
|
|
images: list[str] = []
|
|
for prediction in data.get("predictions") or []:
|
|
if not isinstance(prediction, dict):
|
|
continue
|
|
b64 = prediction.get("bytesBase64Encoded")
|
|
mime = prediction.get("mimeType", "image/png")
|
|
if isinstance(b64, str) and b64:
|
|
images.append(f"data:{mime};base64,{b64}")
|
|
|
|
self._require_images(images, data)
|
|
|
|
return GeneratedImageResponse(images=images, content="", raw=data)
|
|
|
|
async def _generate_gemini_flash(
|
|
self,
|
|
*,
|
|
prompt: str,
|
|
model: str,
|
|
reference_images: list[str],
|
|
) -> GeneratedImageResponse:
|
|
parts: list[dict[str, Any]] = [
|
|
{"inlineData": image_path_to_inline_data(path)} for path in reference_images
|
|
]
|
|
parts.append({"text": prompt})
|
|
|
|
body: dict[str, Any] = {
|
|
"contents": [{"role": "user", "parts": parts}],
|
|
"generationConfig": {"responseModalities": ["TEXT", "IMAGE"]},
|
|
}
|
|
body.update(self.extra_body)
|
|
|
|
url = f"{self.api_base}/models/{model}:generateContent"
|
|
headers = {
|
|
"x-goog-api-key": self.api_key or "",
|
|
"Content-Type": "application/json",
|
|
**self.extra_headers,
|
|
}
|
|
response = await self._http_post(url, headers=headers, body=body)
|
|
|
|
try:
|
|
response.raise_for_status()
|
|
except httpx.HTTPStatusError as exc:
|
|
detail = _http_error_detail(response)
|
|
logger.error("Gemini image generation failed (HTTP {}): {}", response.status_code, detail)
|
|
raise ImageGenerationError(
|
|
f"Gemini image generation failed (HTTP {response.status_code}): {detail}"
|
|
) from exc
|
|
|
|
data = response.json()
|
|
images: list[str] = []
|
|
text_parts: list[str] = []
|
|
for candidate in data.get("candidates") or []:
|
|
if not isinstance(candidate, dict):
|
|
continue
|
|
content = candidate.get("content") or {}
|
|
for part in content.get("parts") or []:
|
|
if not isinstance(part, dict):
|
|
continue
|
|
if "text" in part:
|
|
text_parts.append(part["text"])
|
|
inline = part.get("inlineData")
|
|
if isinstance(inline, dict):
|
|
mime = inline.get("mimeType", "image/png")
|
|
b64 = inline.get("data", "")
|
|
if b64:
|
|
images.append(f"data:{mime};base64,{b64}")
|
|
|
|
self._require_images(images, data)
|
|
|
|
return GeneratedImageResponse(
|
|
images=images,
|
|
content="\n".join(t for t in text_parts if t).strip(),
|
|
raw=data,
|
|
)
|
|
|
|
|
|
async def _aihubmix_images_from_payload(
|
|
client: httpx.AsyncClient,
|
|
payload: dict[str, Any],
|
|
) -> list[str]:
|
|
images: list[str] = []
|
|
candidates: list[Any] = []
|
|
if "data" in payload:
|
|
candidates.append(payload["data"])
|
|
if "output" in payload:
|
|
candidates.append(payload["output"])
|
|
|
|
async def collect(value: Any) -> None:
|
|
if isinstance(value, list):
|
|
for item in value:
|
|
await collect(item)
|
|
return
|
|
if isinstance(value, str):
|
|
if value.startswith("data:image/"):
|
|
images.append(value)
|
|
elif value.startswith(("http://", "https://")):
|
|
images.append(await _download_image_data_url(client, value))
|
|
return
|
|
if not isinstance(value, dict):
|
|
return
|
|
|
|
b64_json = value.get("b64_json")
|
|
if isinstance(b64_json, str) and b64_json:
|
|
images.append(_b64_image_data_url(b64_json))
|
|
elif b64_json is not None:
|
|
await collect(b64_json)
|
|
|
|
bytes_base64 = value.get("bytesBase64") or value.get("bytes_base64") or value.get("base64")
|
|
if isinstance(bytes_base64, str) and bytes_base64:
|
|
images.append(_b64_image_data_url(bytes_base64))
|
|
|
|
image_url = value.get("image_url") or value.get("imageUrl")
|
|
if isinstance(image_url, dict):
|
|
await collect(image_url.get("url"))
|
|
elif image_url is not None:
|
|
await collect(image_url)
|
|
|
|
url_value = value.get("url")
|
|
if url_value is not None:
|
|
await collect(url_value)
|
|
|
|
for key in ("images", "image", "output"):
|
|
if key in value:
|
|
await collect(value[key])
|
|
|
|
for candidate in candidates:
|
|
await collect(candidate)
|
|
return images
|
|
|
|
|
|
_MINIMAX_TIMEOUT_S = 300.0
|
|
|
|
_MINIMAX_ASPECT_RATIO_SIZES = {
|
|
"1:1": "1:1",
|
|
"16:9": "16:9",
|
|
"4:3": "4:3",
|
|
"3:2": "3:2",
|
|
"2:3": "2:3",
|
|
"3:4": "3:4",
|
|
"9:16": "9:16",
|
|
"21:9": "21:9",
|
|
}
|
|
|
|
|
|
class MiniMaxImageGenerationClient(ImageGenerationProvider):
|
|
"""Async client for MiniMax image generation API."""
|
|
|
|
provider_name = "minimax"
|
|
missing_key_message = (
|
|
"MiniMax API key is not configured. Set providers.minimax.apiKey."
|
|
)
|
|
default_timeout = _MINIMAX_TIMEOUT_S
|
|
|
|
def _default_base_url(self) -> str:
|
|
return "https://api.minimaxi.com/v1"
|
|
|
|
def _resolve_aspect_ratio(self, aspect_ratio: str | None) -> str:
|
|
if aspect_ratio and aspect_ratio in _MINIMAX_ASPECT_RATIO_SIZES:
|
|
return _MINIMAX_ASPECT_RATIO_SIZES[aspect_ratio]
|
|
return "1:1"
|
|
|
|
async def generate(
|
|
self,
|
|
*,
|
|
prompt: str,
|
|
model: str,
|
|
reference_images: list[str] | None = None,
|
|
aspect_ratio: str | None = None,
|
|
image_size: str | None = None,
|
|
) -> GeneratedImageResponse:
|
|
if not self.api_key:
|
|
raise ImageGenerationError(self.missing_key_message)
|
|
|
|
headers = {
|
|
"Authorization": f"Bearer {self.api_key}",
|
|
"Content-Type": "application/json",
|
|
**self.extra_headers,
|
|
}
|
|
|
|
body: dict[str, Any] = {
|
|
"model": model,
|
|
"prompt": prompt,
|
|
"response_format": "base64",
|
|
}
|
|
|
|
resolved_ratio = self._resolve_aspect_ratio(aspect_ratio)
|
|
body["aspect_ratio"] = resolved_ratio
|
|
|
|
refs = list(reference_images or [])
|
|
if refs:
|
|
image_refs = [image_path_to_data_url(path) for path in refs]
|
|
body["subject_reference"] = [
|
|
{"type": "character", "image_file": ref} for ref in image_refs
|
|
]
|
|
|
|
body.update(self.extra_body)
|
|
|
|
return await self._generate_with_client(body, headers)
|
|
|
|
async def _generate_with_client(
|
|
self,
|
|
body: dict[str, Any],
|
|
headers: dict[str, str],
|
|
) -> GeneratedImageResponse:
|
|
url = f"{self.api_base}/image_generation"
|
|
try:
|
|
response = await self._http_post(url, headers=headers, body=body)
|
|
except httpx.TimeoutException as exc:
|
|
raise ImageGenerationError("MiniMax image generation timed out") from exc
|
|
except httpx.RequestError as exc:
|
|
raise ImageGenerationError(f"MiniMax image generation request failed: {exc}") from exc
|
|
|
|
try:
|
|
response.raise_for_status()
|
|
except httpx.HTTPStatusError as exc:
|
|
detail = response.text[:500]
|
|
raise ImageGenerationError(f"MiniMax image generation failed: {detail}") from exc
|
|
|
|
payload = response.json()
|
|
images = _minimax_images_from_payload(payload)
|
|
|
|
self._require_images(images, payload)
|
|
|
|
return GeneratedImageResponse(images=images, content="", raw=payload)
|
|
|
|
|
|
def _minimax_images_from_payload(payload: dict[str, Any]) -> list[str]:
|
|
"""Extract base64 images from MiniMax API response.
|
|
|
|
MiniMax returns images in ``data.image_base64`` (list of base64 strings).
|
|
"""
|
|
images: list[str] = []
|
|
data = payload.get("data")
|
|
if not isinstance(data, dict):
|
|
return images
|
|
for b64 in data.get("image_base64") or []:
|
|
if isinstance(b64, str) and b64:
|
|
images.append(_b64_image_data_url(b64))
|
|
return images
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# OpenAI image generation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_OPENAI_DALLE2_SUPPORTED_SIZES = {"256x256", "512x512", "1024x1024"}
|
|
_OPENAI_DALLE3_SUPPORTED_SIZES = {"1024x1024", "1792x1024", "1024x1792"}
|
|
_OPENAI_GPT_IMAGE_SUPPORTED_SIZES = {
|
|
"1024x1024",
|
|
"1536x1024",
|
|
"1024x1536",
|
|
"auto",
|
|
}
|
|
_OPENAI_DALLE2_ASPECT_RATIO_SIZES = {
|
|
"1:1": "1024x1024",
|
|
"16:9": "1024x1024",
|
|
"9:16": "1024x1024",
|
|
"3:4": "1024x1024",
|
|
"4:3": "1024x1024",
|
|
}
|
|
_OPENAI_DALLE3_ASPECT_RATIO_SIZES = {
|
|
"1:1": "1024x1024",
|
|
"16:9": "1792x1024",
|
|
"9:16": "1024x1792",
|
|
"3:4": "1024x1792",
|
|
"4:3": "1792x1024",
|
|
}
|
|
_OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES = {
|
|
"1:1": "1024x1024",
|
|
"16:9": "1536x1024",
|
|
"9:16": "1024x1536",
|
|
"3:4": "1024x1536",
|
|
"4:3": "1536x1024",
|
|
}
|
|
|
|
|
|
class OpenAIImageGenerationClient(ImageGenerationProvider):
|
|
"""OpenAI Images API using an API key (``providers.openai.apiKey``)."""
|
|
|
|
provider_name = "openai"
|
|
missing_key_message = (
|
|
"OpenAI API key is not configured. Set providers.openai.apiKey."
|
|
)
|
|
|
|
def _default_base_url(self) -> str:
|
|
return "https://api.openai.com/v1"
|
|
|
|
@staticmethod
|
|
def _strip_model_prefix(model: str) -> str:
|
|
"""Remove ``openai/`` prefix if present (OpenRouter convention)."""
|
|
if model.startswith("openai/") or model.startswith("openai_codex/"):
|
|
return model.split("/", 1)[1]
|
|
return model
|
|
|
|
async def generate(
|
|
self,
|
|
*,
|
|
prompt: str,
|
|
model: str,
|
|
reference_images: list[str] | None = None,
|
|
aspect_ratio: str | None = None,
|
|
image_size: str | None = None,
|
|
) -> GeneratedImageResponse:
|
|
if not self.api_key:
|
|
raise ImageGenerationError(self.missing_key_message)
|
|
|
|
if reference_images:
|
|
logger.warning(
|
|
"DALL-E models do not support reference images; "
|
|
"ignoring {} reference image(s) for {}",
|
|
len(reference_images),
|
|
model,
|
|
)
|
|
|
|
headers = {
|
|
"Authorization": f"Bearer {self.api_key}",
|
|
"Content-Type": "application/json",
|
|
**self.extra_headers,
|
|
}
|
|
|
|
clean_model = self._strip_model_prefix(model)
|
|
body: dict[str, Any] = {
|
|
"model": clean_model,
|
|
"prompt": prompt,
|
|
}
|
|
|
|
if not _openai_is_gpt_image_model(clean_model):
|
|
body["response_format"] = "b64_json"
|
|
body["n"] = 1
|
|
|
|
size = _openai_size(clean_model, aspect_ratio, image_size)
|
|
if size:
|
|
body["size"] = size
|
|
|
|
body.update(self.extra_body)
|
|
|
|
logger.info("OpenAI Images API request: POST {}/images/generations body={}", self.api_base, body)
|
|
|
|
response = await self._http_post(
|
|
f"{self.api_base}/images/generations",
|
|
headers=headers,
|
|
body=body,
|
|
)
|
|
|
|
try:
|
|
response.raise_for_status()
|
|
except httpx.HTTPStatusError as exc:
|
|
detail = response.text[:1000]
|
|
logger.error("OpenAI Images API error ({}): {}", response.status_code, detail)
|
|
raise ImageGenerationError(
|
|
f"OpenAI image generation failed (HTTP {response.status_code}): {detail}"
|
|
) from exc
|
|
|
|
payload = response.json()
|
|
logger.info("OpenAI Images API response ({}): {}", response.status_code,
|
|
{k: v for k, v in payload.items() if k != "data"})
|
|
|
|
client = self._client
|
|
owns_client = client is None
|
|
if owns_client:
|
|
client = httpx.AsyncClient(timeout=self.timeout)
|
|
try:
|
|
images = await _openai_images_from_payload(client, payload)
|
|
finally:
|
|
if owns_client:
|
|
await client.aclose()
|
|
|
|
self._require_images(images, payload)
|
|
|
|
return GeneratedImageResponse(images=images, content="", raw=payload)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# OpenAI Codex image generation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class CodexImageGenerationClient(ImageGenerationProvider):
|
|
"""OpenAI image generation via Codex subscription OAuth.
|
|
|
|
Uses the Codex Responses API with the ``image_generation`` tool
|
|
(the same mechanism ChatGPT uses internally). No API key required —
|
|
the Codex OAuth token from ``oauth_cli_kit`` is used instead.
|
|
"""
|
|
|
|
provider_name = "openai_codex"
|
|
missing_key_message = (
|
|
"Codex OAuth token is unavailable. "
|
|
"Log in with Codex subscription first."
|
|
)
|
|
|
|
def _default_base_url(self) -> str:
|
|
return "https://chatgpt.com/backend-api"
|
|
|
|
def _codex_model(self, model: str) -> str:
|
|
"""Strip the ``openai-codex/`` prefix if present."""
|
|
if model.startswith(("openai-codex/", "openai_codex/")):
|
|
return model.split("/", 1)[1]
|
|
return model
|
|
|
|
async def generate(
|
|
self,
|
|
*,
|
|
prompt: str,
|
|
model: str,
|
|
reference_images: list[str] | None = None,
|
|
aspect_ratio: str | None = None,
|
|
image_size: str | None = None,
|
|
) -> GeneratedImageResponse:
|
|
try:
|
|
from oauth_cli_kit import get_token as get_codex_token
|
|
except ImportError:
|
|
raise ImageGenerationError(self.missing_key_message)
|
|
|
|
try:
|
|
token = await asyncio.to_thread(get_codex_token)
|
|
except Exception as exc:
|
|
raise ImageGenerationError(self.missing_key_message) from exc
|
|
if not token or not token.access:
|
|
raise ImageGenerationError(self.missing_key_message)
|
|
|
|
logger.info(
|
|
"Using Codex OAuth token for image generation (account: {})",
|
|
token.account_id,
|
|
)
|
|
|
|
if reference_images:
|
|
logger.warning(
|
|
"Codex image generation does not support reference images; "
|
|
"ignoring {} reference image(s)",
|
|
len(reference_images),
|
|
)
|
|
|
|
headers = {
|
|
"Authorization": f"Bearer {token.access}",
|
|
"chatgpt-account-id": token.account_id,
|
|
"OpenAI-Beta": "responses=experimental",
|
|
"originator": "nanobot",
|
|
"User-Agent": "nanobot (python)",
|
|
"Content-Type": "application/json",
|
|
**self.extra_headers,
|
|
}
|
|
|
|
body: dict[str, Any] = {
|
|
"model": self._codex_model(model),
|
|
"instructions": "Generate an image based on the user's request.",
|
|
"input": [{"role": "user", "content": prompt}],
|
|
"tools": [{"type": "image_generation"}],
|
|
"tool_choice": "auto",
|
|
"stream": True,
|
|
"store": False,
|
|
}
|
|
body.update(self.extra_body)
|
|
|
|
logger.info("Codex Responses API request: POST {}/codex/responses body={}",
|
|
self.api_base, {k: v for k, v in body.items() if k != "input"})
|
|
|
|
response = await self._http_post(
|
|
f"{self.api_base}/codex/responses",
|
|
headers=headers,
|
|
body=body,
|
|
)
|
|
|
|
try:
|
|
response.raise_for_status()
|
|
except httpx.HTTPStatusError as exc:
|
|
detail = response.text[:1000]
|
|
logger.error("Codex Responses API error ({}): {}", response.status_code, detail)
|
|
raise ImageGenerationError(
|
|
f"Codex image generation failed (HTTP {response.status_code}): {detail}"
|
|
) from exc
|
|
|
|
images, content_text = await _parse_codex_sse_images(response)
|
|
|
|
raw = {"status": "completed"}
|
|
self._require_images(images, raw)
|
|
|
|
return GeneratedImageResponse(images=images, content=content_text, raw=raw)
|
|
|
|
|
|
def _openai_size(
|
|
model: str,
|
|
aspect_ratio: str | None,
|
|
image_size: str | None,
|
|
) -> str:
|
|
"""Resolve aspect ratio or image_size to an OpenAI Images API size string."""
|
|
sizes, supported_sizes = _openai_size_options(model)
|
|
explicit_size = _normalize_openai_image_size(image_size)
|
|
if explicit_size and _openai_explicit_size_supported(
|
|
explicit_size,
|
|
supported_sizes=supported_sizes,
|
|
):
|
|
return explicit_size
|
|
if explicit_size:
|
|
logger.warning(
|
|
"OpenAI image size '{}' is not supported by {}; using aspect ratio/default size",
|
|
explicit_size,
|
|
model,
|
|
)
|
|
if aspect_ratio and aspect_ratio in sizes:
|
|
return sizes[aspect_ratio]
|
|
return "1024x1024"
|
|
|
|
|
|
def _openai_is_gpt_image_model(model: str) -> bool:
|
|
normalized = model.lower()
|
|
return normalized.startswith(("gpt-image", "chatgpt-image"))
|
|
|
|
|
|
def _openai_size_options(model: str) -> tuple[dict[str, str], set[str] | None]:
|
|
normalized = model.lower()
|
|
if normalized.startswith("dall-e-2"):
|
|
return _OPENAI_DALLE2_ASPECT_RATIO_SIZES, _OPENAI_DALLE2_SUPPORTED_SIZES
|
|
if normalized.startswith("dall-e-3"):
|
|
return _OPENAI_DALLE3_ASPECT_RATIO_SIZES, _OPENAI_DALLE3_SUPPORTED_SIZES
|
|
if normalized.startswith("gpt-image-2"):
|
|
return _OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES, None
|
|
return _OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES, _OPENAI_GPT_IMAGE_SUPPORTED_SIZES
|
|
|
|
|
|
def _normalize_openai_image_size(image_size: str | None) -> str | None:
|
|
if not image_size:
|
|
return None
|
|
normalized = image_size.strip().lower()
|
|
return normalized or None
|
|
|
|
|
|
def _openai_explicit_size_supported(
|
|
size: str,
|
|
*,
|
|
supported_sizes: set[str] | None,
|
|
) -> bool:
|
|
if supported_sizes is not None:
|
|
return size in supported_sizes
|
|
width, sep, height = size.partition("x")
|
|
return bool(sep and width.isdecimal() and height.isdecimal())
|
|
|
|
|
|
async def _openai_images_from_payload(
|
|
client: httpx.AsyncClient,
|
|
payload: dict[str, Any],
|
|
) -> list[str]:
|
|
"""Extract images from OpenAI Images API response.
|
|
|
|
Handles both ``b64_json`` (preferred) and ``url`` (downloaded) formats.
|
|
"""
|
|
images: list[str] = []
|
|
for item in payload.get("data") or []:
|
|
if not isinstance(item, dict):
|
|
continue
|
|
b64 = item.get("b64_json")
|
|
if isinstance(b64, str) and b64:
|
|
images.append(_b64_image_data_url(b64))
|
|
continue
|
|
url = item.get("url")
|
|
if isinstance(url, str) and url:
|
|
images.append(await _download_image_data_url(client, url))
|
|
return images
|
|
|
|
|
|
def _codex_responses_images_from_payload(payload: dict[str, Any]) -> list[str]:
|
|
"""Extract images from Codex Responses API ``image_generation_call`` output."""
|
|
images: list[str] = []
|
|
for item in payload.get("output") or []:
|
|
if not isinstance(item, dict):
|
|
continue
|
|
if item.get("type") != "image_generation_call":
|
|
continue
|
|
result = item.get("result")
|
|
if isinstance(result, str):
|
|
images.append(result if result.startswith("data:image/") else _b64_image_data_url(result))
|
|
continue
|
|
if isinstance(result, dict):
|
|
image_url = result.get("image_url") or result.get("image") or ""
|
|
if isinstance(image_url, str):
|
|
images.append(image_url if image_url.startswith("data:image/") else _b64_image_data_url(image_url))
|
|
return images
|
|
|
|
|
|
async def _parse_codex_sse_images(
|
|
response: httpx.Response,
|
|
) -> tuple[list[str], str]:
|
|
"""Parse a Codex Responses API SSE stream for image generation output.
|
|
|
|
Returns ``(images, content_text)``.
|
|
"""
|
|
import json as _json
|
|
|
|
images: list[str] = []
|
|
text_parts: list[str] = []
|
|
|
|
buffer: list[str] = []
|
|
async for line_bytes in response.aiter_lines():
|
|
line = line_bytes.strip()
|
|
if line == "":
|
|
if buffer:
|
|
data_lines = []
|
|
for bl in buffer:
|
|
if bl.startswith("data:"):
|
|
data_lines.append(bl[5:].strip())
|
|
buffer.clear()
|
|
if data_lines:
|
|
raw = "".join(data_lines)
|
|
if raw == "[DONE]":
|
|
break
|
|
try:
|
|
event = _json.loads(raw)
|
|
except Exception:
|
|
continue
|
|
ev_type = event.get("type", "")
|
|
if ev_type in ("error", "response.failed"):
|
|
logger.error("Codex SSE failure: {}", raw[:2000])
|
|
_collect_images_from_sse_event(event, images)
|
|
_collect_text_from_sse_event(event, text_parts)
|
|
continue
|
|
buffer.append(line)
|
|
|
|
# flush remaining
|
|
if buffer:
|
|
data_lines = [bl[5:].strip() for bl in buffer if bl.startswith("data:")]
|
|
raw = "".join(data_lines)
|
|
if raw and raw != "[DONE]":
|
|
try:
|
|
event = _json.loads(raw)
|
|
except Exception:
|
|
pass
|
|
else:
|
|
_collect_images_from_sse_event(event, images)
|
|
_collect_text_from_sse_event(event, text_parts)
|
|
|
|
return images, "".join(text_parts).strip()
|
|
|
|
|
|
def _collect_images_from_sse_event(event: dict[str, Any], images: list[str]) -> None:
|
|
if event.get("type") != "response.output_item.done":
|
|
return
|
|
item = event.get("item") or {}
|
|
if item.get("type") != "image_generation_call":
|
|
return
|
|
result = item.get("result")
|
|
if isinstance(result, str):
|
|
if result.startswith("data:image/"):
|
|
images.append(result)
|
|
else:
|
|
images.append(_b64_image_data_url(result))
|
|
elif isinstance(result, dict):
|
|
image_url = result.get("image_url") or result.get("image") or ""
|
|
if isinstance(image_url, str):
|
|
if image_url.startswith("data:image/"):
|
|
images.append(image_url)
|
|
else:
|
|
images.append(_b64_image_data_url(image_url))
|
|
|
|
|
|
def _collect_text_from_sse_event(event: dict[str, Any], text_parts: list[str]) -> None:
|
|
if event.get("type") == "response.output_text.delta":
|
|
delta = event.get("delta")
|
|
if isinstance(delta, str) and delta:
|
|
text_parts.append(delta)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# StepFun (阶跃星辰) image generation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_STEPFUN_ASPECT_RATIO_SIZES = {
|
|
"1:1": "1024x1024",
|
|
"16:9": "1280x800",
|
|
"9:16": "800x1280",
|
|
"3:4": "768x1360",
|
|
"4:3": "1360x768",
|
|
}
|
|
|
|
|
|
class StepFunImageGenerationClient(ImageGenerationProvider):
|
|
"""Async client for StepFun (阶跃星辰) image generation.
|
|
|
|
Supports:
|
|
- Text-to-image via step-image-edit-2 (default model)
|
|
- Reference-image-guided generation via style_reference (step-1x-medium)
|
|
"""
|
|
|
|
provider_name = "stepfun"
|
|
missing_key_message = (
|
|
"StepFun API key is not configured. Set providers.stepfun.apiKey."
|
|
)
|
|
default_timeout = 120.0
|
|
|
|
def _default_base_url(self) -> str:
|
|
return "https://api.stepfun.com/v1"
|
|
|
|
async def generate(
|
|
self,
|
|
*,
|
|
prompt: str,
|
|
model: str,
|
|
reference_images: list[str] | None = None,
|
|
aspect_ratio: str | None = None,
|
|
image_size: str | None = None,
|
|
) -> GeneratedImageResponse:
|
|
if not self.api_key:
|
|
raise ImageGenerationError(self.missing_key_message)
|
|
|
|
headers = {
|
|
"Authorization": f"Bearer {self.api_key}",
|
|
"Content-Type": "application/json",
|
|
**self.extra_headers,
|
|
}
|
|
|
|
body: dict[str, Any] = {
|
|
"model": model,
|
|
"prompt": prompt,
|
|
"response_format": "b64_json",
|
|
"n": 1,
|
|
}
|
|
|
|
# Map aspect ratio / image_size to StepFun size string
|
|
size = _stepfun_size(aspect_ratio, image_size)
|
|
if size:
|
|
body["size"] = size
|
|
|
|
# step-1x-medium supports style_reference for reference-image-guided generation
|
|
refs = list(reference_images or [])
|
|
if refs and "1x" in model:
|
|
body["style_reference"] = {
|
|
"source_url": image_path_to_data_url(refs[0]),
|
|
}
|
|
|
|
body.update(self.extra_body)
|
|
|
|
response = await self._http_post(
|
|
f"{self.api_base}/images/generations",
|
|
headers=headers,
|
|
body=body,
|
|
)
|
|
|
|
try:
|
|
response.raise_for_status()
|
|
except httpx.HTTPStatusError as exc:
|
|
detail = response.text[:500]
|
|
raise ImageGenerationError(
|
|
f"StepFun image generation failed: {detail}"
|
|
) from exc
|
|
|
|
payload = response.json()
|
|
images = _stepfun_images_from_payload(payload)
|
|
|
|
self._require_images(images, payload)
|
|
|
|
return GeneratedImageResponse(images=images, content="", raw=payload)
|
|
|
|
|
|
def _stepfun_size(
|
|
aspect_ratio: str | None,
|
|
image_size: str | None,
|
|
) -> str:
|
|
"""Resolve aspect ratio / image_size to StepFun size string.
|
|
|
|
StepFun expects ``WIDTHxHEIGHT`` (note: width x height, not the more
|
|
common ``HxW`` order used by other providers). The accepted sizes are
|
|
``1024x1024``, ``768x1360``, ``896x1184``, ``1360x768``, ``1184x896``.
|
|
"""
|
|
if image_size and "x" in image_size.lower():
|
|
return image_size
|
|
if aspect_ratio and aspect_ratio in _STEPFUN_ASPECT_RATIO_SIZES:
|
|
return _STEPFUN_ASPECT_RATIO_SIZES[aspect_ratio]
|
|
return "1024x1024"
|
|
|
|
|
|
def _stepfun_images_from_payload(payload: dict[str, Any]) -> list[str]:
|
|
"""Extract base64 images from StepFun API response.
|
|
|
|
StepFun returns images in ``data[].b64_json`` (base64 strings).
|
|
"""
|
|
images: list[str] = []
|
|
for item in payload.get("data") or []:
|
|
if not isinstance(item, dict):
|
|
continue
|
|
b64 = item.get("b64_json")
|
|
if isinstance(b64, str) and b64:
|
|
images.append(_b64_image_data_url(b64))
|
|
return images
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Provider registration
|
|
# ---------------------------------------------------------------------------
|
|
|
|
register_image_gen_provider(AIHubMixImageGenerationClient)
|
|
register_image_gen_provider(CodexImageGenerationClient)
|
|
register_image_gen_provider(GeminiImageGenerationClient)
|
|
register_image_gen_provider(OllamaImageGenerationClient)
|
|
register_image_gen_provider(MiniMaxImageGenerationClient)
|
|
register_image_gen_provider(OpenAIImageGenerationClient)
|
|
register_image_gen_provider(OpenRouterImageGenerationClient)
|
|
register_image_gen_provider(StepFunImageGenerationClient)
|