mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 16:42:25 +00:00
396 lines
13 KiB
Python
396 lines
13 KiB
Python
"""Image generation provider helpers."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import httpx
|
|
|
|
from nanobot.providers.registry import find_by_name
|
|
from nanobot.utils.helpers import detect_image_mime
|
|
|
|
_OPENROUTER_ATTRIBUTION_HEADERS = {
|
|
"HTTP-Referer": "https://github.com/HKUDS/nanobot",
|
|
"X-OpenRouter-Title": "nanobot",
|
|
"X-OpenRouter-Categories": "cli-agent,personal-agent",
|
|
}
|
|
_DEFAULT_TIMEOUT_S = 120.0
|
|
_AIHUBMIX_TIMEOUT_S = 300.0
|
|
_AIHUBMIX_ASPECT_RATIO_SIZES = {
|
|
"1:1": "1024x1024",
|
|
"3:4": "1024x1536",
|
|
"9:16": "1024x1536",
|
|
"4:3": "1536x1024",
|
|
"16:9": "1536x1024",
|
|
}
|
|
|
|
|
|
class ImageGenerationError(RuntimeError):
|
|
"""Raised when the image generation provider cannot return images."""
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class GeneratedImageResponse:
|
|
"""Images and optional text returned by the provider."""
|
|
|
|
images: list[str]
|
|
content: str
|
|
raw: dict[str, Any]
|
|
|
|
|
|
def _provider_base_url(provider: str, api_base: str | None, fallback: str) -> str:
|
|
if api_base:
|
|
return api_base.rstrip("/")
|
|
spec = find_by_name(provider)
|
|
if spec and spec.default_api_base:
|
|
return spec.default_api_base.rstrip("/")
|
|
return fallback
|
|
|
|
|
|
def image_path_to_data_url(path: str | Path) -> str:
|
|
"""Convert a local image path to an image data URL."""
|
|
p = Path(path).expanduser()
|
|
raw = p.read_bytes()
|
|
mime = detect_image_mime(raw)
|
|
if mime is None:
|
|
raise ImageGenerationError(f"unsupported reference image: {p}")
|
|
encoded = base64.b64encode(raw).decode("ascii")
|
|
return f"data:{mime};base64,{encoded}"
|
|
|
|
|
|
def _b64_png_data_url(value: str) -> str:
|
|
return f"data:image/png;base64,{value}"
|
|
|
|
|
|
def _aihubmix_size(aspect_ratio: str | None, image_size: str | None) -> str:
|
|
"""Return an OpenAI Images API size string for AIHubMix.
|
|
|
|
The WebUI emits compact size hints like ``1K`` for OpenRouter. AIHubMix's
|
|
Images API expects OpenAI-style dimensions or ``auto``, so only pass
|
|
through explicit dimension strings and otherwise derive the closest
|
|
supported orientation from aspect ratio.
|
|
"""
|
|
if image_size and "x" in image_size.lower():
|
|
return image_size
|
|
if aspect_ratio in _AIHUBMIX_ASPECT_RATIO_SIZES:
|
|
return _AIHUBMIX_ASPECT_RATIO_SIZES[aspect_ratio]
|
|
return "auto"
|
|
|
|
|
|
def _aihubmix_model_path(model: str) -> str:
|
|
if "/" in model:
|
|
return model
|
|
if model.startswith(("gpt-image-", "dall-e-")):
|
|
return f"openai/{model}"
|
|
return model
|
|
|
|
|
|
async def _download_image_data_url(
|
|
client: httpx.AsyncClient,
|
|
url: str,
|
|
) -> str:
|
|
response = await client.get(url)
|
|
try:
|
|
response.raise_for_status()
|
|
except httpx.HTTPStatusError as exc:
|
|
detail = response.text[:500]
|
|
raise ImageGenerationError(f"failed to download generated image: {detail}") from exc
|
|
raw = response.content
|
|
mime = detect_image_mime(raw)
|
|
if mime is None:
|
|
raise ImageGenerationError("generated image URL did not return a supported image")
|
|
encoded = base64.b64encode(raw).decode("ascii")
|
|
return f"data:{mime};base64,{encoded}"
|
|
|
|
|
|
class OpenRouterImageGenerationClient:
|
|
"""Small async client for OpenRouter Chat Completions image generation."""
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
api_key: str | None,
|
|
api_base: str | None = None,
|
|
extra_headers: dict[str, str] | None = None,
|
|
extra_body: dict[str, Any] | None = None,
|
|
timeout: float = _DEFAULT_TIMEOUT_S,
|
|
client: httpx.AsyncClient | None = None,
|
|
) -> None:
|
|
self.api_key = api_key
|
|
self.api_base = _provider_base_url(
|
|
"openrouter",
|
|
api_base,
|
|
"https://openrouter.ai/api/v1",
|
|
)
|
|
self.extra_headers = extra_headers or {}
|
|
self.extra_body = extra_body or {}
|
|
self.timeout = timeout
|
|
self._client = client
|
|
|
|
async def generate(
|
|
self,
|
|
*,
|
|
prompt: str,
|
|
model: str,
|
|
reference_images: list[str] | None = None,
|
|
aspect_ratio: str | None = None,
|
|
image_size: str | None = None,
|
|
) -> GeneratedImageResponse:
|
|
if not self.api_key:
|
|
raise ImageGenerationError(
|
|
"OpenRouter API key is not configured. Set providers.openrouter.apiKey."
|
|
)
|
|
|
|
content: str | list[dict[str, Any]]
|
|
references = list(reference_images or [])
|
|
if references:
|
|
blocks: list[dict[str, Any]] = [{"type": "text", "text": prompt}]
|
|
blocks.extend(
|
|
{"type": "image_url", "image_url": {"url": image_path_to_data_url(path)}}
|
|
for path in references
|
|
)
|
|
content = blocks
|
|
else:
|
|
content = prompt
|
|
|
|
body: dict[str, Any] = {
|
|
"model": model,
|
|
"messages": [{"role": "user", "content": content}],
|
|
"modalities": ["image", "text"],
|
|
"stream": False,
|
|
}
|
|
image_config: dict[str, str] = {}
|
|
if aspect_ratio:
|
|
image_config["aspect_ratio"] = aspect_ratio
|
|
if image_size:
|
|
image_config["image_size"] = image_size
|
|
if image_config:
|
|
body["image_config"] = image_config
|
|
body.update(self.extra_body)
|
|
|
|
headers = {
|
|
"Authorization": f"Bearer {self.api_key}",
|
|
"Content-Type": "application/json",
|
|
**_OPENROUTER_ATTRIBUTION_HEADERS,
|
|
**self.extra_headers,
|
|
}
|
|
url = f"{self.api_base}/chat/completions"
|
|
|
|
if self._client is not None:
|
|
response = await self._client.post(url, headers=headers, json=body)
|
|
else:
|
|
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
|
response = await client.post(url, headers=headers, json=body)
|
|
|
|
try:
|
|
response.raise_for_status()
|
|
except httpx.HTTPStatusError as exc:
|
|
detail = response.text[:500]
|
|
raise ImageGenerationError(f"OpenRouter image generation failed: {detail}") from exc
|
|
|
|
data = response.json()
|
|
images: list[str] = []
|
|
text_parts: list[str] = []
|
|
for choice in data.get("choices") or []:
|
|
if not isinstance(choice, dict):
|
|
continue
|
|
message = choice.get("message") or {}
|
|
if isinstance(message.get("content"), str):
|
|
text_parts.append(message["content"])
|
|
for image in message.get("images") or []:
|
|
if not isinstance(image, dict):
|
|
continue
|
|
image_url = image.get("image_url") or image.get("imageUrl") or {}
|
|
url_value = image_url.get("url") if isinstance(image_url, dict) else None
|
|
if isinstance(url_value, str) and url_value.startswith("data:image/"):
|
|
images.append(url_value)
|
|
|
|
if not images:
|
|
provider_error = data.get("error") if isinstance(data, dict) else None
|
|
if provider_error:
|
|
raise ImageGenerationError(f"OpenRouter returned no images: {provider_error}")
|
|
raise ImageGenerationError("OpenRouter returned no images for this request")
|
|
|
|
return GeneratedImageResponse(
|
|
images=images,
|
|
content="\n".join(part for part in text_parts if part).strip(),
|
|
raw=data,
|
|
)
|
|
|
|
|
|
class AIHubMixImageGenerationClient:
|
|
"""Small async client for AIHubMix unified image generation."""
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
api_key: str | None,
|
|
api_base: str | None = None,
|
|
extra_headers: dict[str, str] | None = None,
|
|
extra_body: dict[str, Any] | None = None,
|
|
timeout: float = _AIHUBMIX_TIMEOUT_S,
|
|
client: httpx.AsyncClient | None = None,
|
|
) -> None:
|
|
self.api_key = api_key
|
|
self.api_base = _provider_base_url(
|
|
"aihubmix",
|
|
api_base,
|
|
"https://aihubmix.com/v1",
|
|
)
|
|
self.extra_headers = extra_headers or {}
|
|
self.extra_body = extra_body or {}
|
|
self.timeout = timeout
|
|
self._client = client
|
|
|
|
async def generate(
|
|
self,
|
|
*,
|
|
prompt: str,
|
|
model: str,
|
|
reference_images: list[str] | None = None,
|
|
aspect_ratio: str | None = None,
|
|
image_size: str | None = None,
|
|
) -> GeneratedImageResponse:
|
|
if not self.api_key:
|
|
raise ImageGenerationError(
|
|
"AIHubMix API key is not configured. Set providers.aihubmix.apiKey."
|
|
)
|
|
|
|
refs = list(reference_images or [])
|
|
headers = {
|
|
"Authorization": f"Bearer {self.api_key}",
|
|
**self.extra_headers,
|
|
}
|
|
size = _aihubmix_size(aspect_ratio, image_size)
|
|
|
|
if self._client is not None:
|
|
return await self._generate_with_client(
|
|
self._client,
|
|
prompt=prompt,
|
|
model=model,
|
|
reference_images=refs,
|
|
size=size,
|
|
headers=headers,
|
|
)
|
|
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
|
return await self._generate_with_client(
|
|
client,
|
|
prompt=prompt,
|
|
model=model,
|
|
reference_images=refs,
|
|
size=size,
|
|
headers=headers,
|
|
)
|
|
|
|
async def _generate_with_client(
|
|
self,
|
|
client: httpx.AsyncClient,
|
|
*,
|
|
prompt: str,
|
|
model: str,
|
|
reference_images: list[str],
|
|
size: str,
|
|
headers: dict[str, str],
|
|
) -> GeneratedImageResponse:
|
|
image_input: str | list[str] | None = None
|
|
if reference_images:
|
|
image_refs = [image_path_to_data_url(path) for path in reference_images]
|
|
image_input = image_refs[0] if len(image_refs) == 1 else image_refs
|
|
|
|
input_body: dict[str, Any] = {
|
|
"prompt": prompt,
|
|
"n": 1,
|
|
"size": size,
|
|
}
|
|
if image_input is not None:
|
|
input_body["image"] = image_input
|
|
input_body.update(self.extra_body)
|
|
|
|
body = {"input": input_body}
|
|
model_path = _aihubmix_model_path(model)
|
|
url = f"{self.api_base}/models/{model_path}/predictions"
|
|
try:
|
|
response = await client.post(
|
|
url,
|
|
headers={**headers, "Content-Type": "application/json"},
|
|
json=body,
|
|
)
|
|
except httpx.TimeoutException as exc:
|
|
raise ImageGenerationError("AIHubMix image generation timed out") from exc
|
|
except httpx.RequestError as exc:
|
|
raise ImageGenerationError(f"AIHubMix image generation request failed: {exc}") from exc
|
|
|
|
try:
|
|
response.raise_for_status()
|
|
except httpx.HTTPStatusError as exc:
|
|
detail = response.text[:500]
|
|
raise ImageGenerationError(f"AIHubMix image generation failed: {detail}") from exc
|
|
|
|
payload = response.json()
|
|
images = await _aihubmix_images_from_payload(client, payload)
|
|
|
|
if not images:
|
|
provider_error = payload.get("error") if isinstance(payload, dict) else None
|
|
if provider_error:
|
|
raise ImageGenerationError(f"AIHubMix returned no images: {provider_error}")
|
|
raise ImageGenerationError("AIHubMix returned no images for this request")
|
|
|
|
return GeneratedImageResponse(images=images, content="", raw=payload)
|
|
|
|
|
|
async def _aihubmix_images_from_payload(
|
|
client: httpx.AsyncClient,
|
|
payload: dict[str, Any],
|
|
) -> list[str]:
|
|
images: list[str] = []
|
|
candidates: list[Any] = []
|
|
if "data" in payload:
|
|
candidates.append(payload["data"])
|
|
if "output" in payload:
|
|
candidates.append(payload["output"])
|
|
|
|
async def collect(value: Any) -> None:
|
|
if isinstance(value, list):
|
|
for item in value:
|
|
await collect(item)
|
|
return
|
|
if isinstance(value, str):
|
|
if value.startswith("data:image/"):
|
|
images.append(value)
|
|
elif value.startswith(("http://", "https://")):
|
|
images.append(await _download_image_data_url(client, value))
|
|
return
|
|
if not isinstance(value, dict):
|
|
return
|
|
|
|
b64_json = value.get("b64_json")
|
|
if isinstance(b64_json, str) and b64_json:
|
|
images.append(_b64_png_data_url(b64_json))
|
|
elif b64_json is not None:
|
|
await collect(b64_json)
|
|
|
|
bytes_base64 = value.get("bytesBase64") or value.get("bytes_base64") or value.get("base64")
|
|
if isinstance(bytes_base64, str) and bytes_base64:
|
|
images.append(_b64_png_data_url(bytes_base64))
|
|
|
|
image_url = value.get("image_url") or value.get("imageUrl")
|
|
if isinstance(image_url, dict):
|
|
await collect(image_url.get("url"))
|
|
elif image_url is not None:
|
|
await collect(image_url)
|
|
|
|
url_value = value.get("url")
|
|
if url_value is not None:
|
|
await collect(url_value)
|
|
|
|
for key in ("images", "image", "output"):
|
|
if key in value:
|
|
await collect(value[key])
|
|
|
|
for candidate in candidates:
|
|
await collect(candidate)
|
|
return images
|