mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 08:02:30 +00:00
refactor(image-generation): introduce provider registry to eliminate manual wiring
Adds ImageGenerationProvider ABC with shared __init__, _http_post(), and _require_images(). Introduces _IMAGE_GEN_PROVIDERS registry with register/get/image_gen_provider_configs() helpers. Four existing providers (OpenRouter, AIHubMix, Gemini, MiniMax) now inherit from the base class and self-register. Adding a new provider only requires writing one class + one registration line. Eliminates if/else chains in the tool dispatch and hardcoded provider config dicts in commands.py (3 sites) and nanobot.py (1 site). Fixes the agent CLI command missing image_generation_provider_configs entirely. Also simplifies test monkeypatch targets to patch the registry lookup.
This commit is contained in:
parent
bb788cdb7d
commit
7aa5b9b17b
@ -17,11 +17,9 @@ from nanobot.agent.tools.schema import (
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import Base
|
||||
from nanobot.providers.image_generation import (
|
||||
AIHubMixImageGenerationClient,
|
||||
GeminiImageGenerationClient,
|
||||
ImageGenerationError,
|
||||
MiniMaxImageGenerationClient,
|
||||
OpenRouterImageGenerationClient,
|
||||
ImageGenerationProvider,
|
||||
get_image_gen_provider,
|
||||
)
|
||||
from nanobot.utils.artifacts import (
|
||||
ArtifactError,
|
||||
@ -119,37 +117,24 @@ class ImageGenerationTool(Tool):
|
||||
def _provider_config(self) -> ProviderConfig | None:
|
||||
return self.provider_configs.get(self.config.provider)
|
||||
|
||||
def _provider_client(
|
||||
self,
|
||||
) -> OpenRouterImageGenerationClient | AIHubMixImageGenerationClient | MiniMaxImageGenerationClient | GeminiImageGenerationClient | None:
|
||||
def _provider_client(self) -> ImageGenerationProvider | None:
|
||||
provider = self._provider_config()
|
||||
cls = get_image_gen_provider(self.config.provider)
|
||||
if cls is None:
|
||||
return None
|
||||
kwargs = {
|
||||
"api_key": provider.api_key if provider else None,
|
||||
"api_base": provider.api_base if provider else None,
|
||||
"extra_headers": provider.extra_headers if provider else None,
|
||||
"extra_body": provider.extra_body if provider else None,
|
||||
}
|
||||
if self.config.provider == "openrouter":
|
||||
return OpenRouterImageGenerationClient(**kwargs)
|
||||
if self.config.provider == "aihubmix":
|
||||
return AIHubMixImageGenerationClient(**kwargs)
|
||||
if self.config.provider == "minimax":
|
||||
return MiniMaxImageGenerationClient(**kwargs)
|
||||
if self.config.provider == "gemini":
|
||||
return GeminiImageGenerationClient(**kwargs)
|
||||
return None
|
||||
return cls(**kwargs)
|
||||
|
||||
def _missing_api_key_error(self) -> str:
|
||||
provider = self.config.provider
|
||||
if provider == "openrouter":
|
||||
return "Error: OpenRouter API key is not configured. Set providers.openrouter.apiKey."
|
||||
if provider == "aihubmix":
|
||||
return "Error: AIHubMix API key is not configured. Set providers.aihubmix.apiKey."
|
||||
if provider == "minimax":
|
||||
return "Error: MiniMax API key is not configured. Set providers.minimax.apiKey."
|
||||
if provider == "gemini":
|
||||
return "Error: Gemini API key is not configured. Set providers.gemini.apiKey."
|
||||
return f"Error: {provider} API key is not configured."
|
||||
cls = get_image_gen_provider(self.config.provider)
|
||||
if cls and cls.missing_key_message:
|
||||
return f"Error: {cls.missing_key_message}"
|
||||
return f"Error: {self.config.provider} API key is not configured."
|
||||
|
||||
def _resolve_reference_image(self, value: str) -> str:
|
||||
raw_path = Path(value).expanduser()
|
||||
|
||||
@ -620,6 +620,7 @@ def serve(
|
||||
|
||||
from nanobot.api.server import create_app
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.image_generation import image_gen_provider_configs
|
||||
from nanobot.session.manager import SessionManager
|
||||
|
||||
if verbose:
|
||||
@ -639,12 +640,7 @@ def serve(
|
||||
agent_loop = AgentLoop.from_config(
|
||||
runtime_config, bus,
|
||||
session_manager=session_manager,
|
||||
image_generation_provider_configs={
|
||||
"openrouter": runtime_config.providers.openrouter,
|
||||
"aihubmix": runtime_config.providers.aihubmix,
|
||||
"minimax": runtime_config.providers.minimax,
|
||||
"gemini": runtime_config.providers.gemini,
|
||||
},
|
||||
image_generation_provider_configs=image_gen_provider_configs(runtime_config),
|
||||
)
|
||||
except ValueError as exc:
|
||||
console.print(f"[red]Error: {exc}[/red]")
|
||||
@ -724,6 +720,7 @@ def _run_gateway(
|
||||
from nanobot.cron.types import CronJob
|
||||
from nanobot.heartbeat.service import HeartbeatService
|
||||
from nanobot.providers.factory import build_provider_snapshot, load_provider_snapshot
|
||||
from nanobot.providers.image_generation import image_gen_provider_configs
|
||||
from nanobot.session.manager import SessionManager
|
||||
|
||||
port = port if port is not None else config.gateway.port
|
||||
@ -754,12 +751,7 @@ def _run_gateway(
|
||||
context_window_tokens=provider_snapshot.context_window_tokens,
|
||||
cron_service=cron,
|
||||
session_manager=session_manager,
|
||||
image_generation_provider_configs={
|
||||
"openrouter": config.providers.openrouter,
|
||||
"aihubmix": config.providers.aihubmix,
|
||||
"minimax": config.providers.minimax,
|
||||
"gemini": config.providers.gemini,
|
||||
},
|
||||
image_generation_provider_configs=image_gen_provider_configs(config),
|
||||
provider_snapshot_loader=load_provider_snapshot,
|
||||
runtime_model_publisher=lambda model, preset: publish_runtime_model_update(
|
||||
bus,
|
||||
@ -1126,6 +1118,7 @@ def agent(
|
||||
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.providers.image_generation import image_gen_provider_configs
|
||||
|
||||
config = _load_runtime_config(config, workspace)
|
||||
sync_workspace_templates(config.workspace_path)
|
||||
@ -1149,6 +1142,7 @@ def agent(
|
||||
agent_loop = AgentLoop.from_config(
|
||||
config, bus,
|
||||
cron_service=cron,
|
||||
image_generation_provider_configs=image_gen_provider_configs(config),
|
||||
)
|
||||
except ValueError as exc:
|
||||
console.print(f"[red]Error: {exc}[/red]")
|
||||
|
||||
@ -8,6 +8,7 @@ from typing import Any
|
||||
|
||||
from nanobot.agent.hook import AgentHook, SDKCaptureHook
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.providers.image_generation import image_gen_provider_configs
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -63,12 +64,7 @@ class Nanobot:
|
||||
|
||||
loop = AgentLoop.from_config(
|
||||
config,
|
||||
image_generation_provider_configs={
|
||||
"openrouter": config.providers.openrouter,
|
||||
"aihubmix": config.providers.aihubmix,
|
||||
"minimax": config.providers.minimax,
|
||||
"gemini": config.providers.gemini,
|
||||
},
|
||||
image_generation_provider_configs=image_gen_provider_configs(config),
|
||||
)
|
||||
return cls(loop)
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@ -44,15 +45,6 @@ class GeneratedImageResponse:
|
||||
raw: dict[str, Any]
|
||||
|
||||
|
||||
def _provider_base_url(provider: str, api_base: str | None, fallback: str) -> str:
|
||||
if api_base:
|
||||
return api_base.rstrip("/")
|
||||
spec = find_by_name(provider)
|
||||
if spec and spec.default_api_base:
|
||||
return spec.default_api_base.rstrip("/")
|
||||
return fallback
|
||||
|
||||
|
||||
def _read_image_b64(path: str | Path) -> tuple[str, str]:
|
||||
"""Return ``(mime, base64)`` for the image at ``path``."""
|
||||
p = Path(path).expanduser()
|
||||
@ -120,8 +112,44 @@ async def _download_image_data_url(
|
||||
return f"data:{mime};base64,{encoded}"
|
||||
|
||||
|
||||
class OpenRouterImageGenerationClient:
|
||||
"""Small async client for OpenRouter Chat Completions image generation."""
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_IMAGE_GEN_PROVIDERS: dict[str, type[ImageGenerationProvider]] = {}
|
||||
|
||||
|
||||
def register_image_gen_provider(cls: type[ImageGenerationProvider]) -> None:
|
||||
name = cls.provider_name
|
||||
if not name:
|
||||
raise ValueError(f"{cls.__name__} must set provider_name")
|
||||
_IMAGE_GEN_PROVIDERS[name] = cls
|
||||
|
||||
|
||||
def get_image_gen_provider(name: str) -> type[ImageGenerationProvider] | None:
|
||||
return _IMAGE_GEN_PROVIDERS.get(name)
|
||||
|
||||
|
||||
def image_gen_provider_configs(config: Any) -> dict[str, Any]:
|
||||
providers_cfg = config.providers
|
||||
return {
|
||||
name: pc
|
||||
for name in _IMAGE_GEN_PROVIDERS
|
||||
if (pc := getattr(providers_cfg, name, None)) is not None
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Base class
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ImageGenerationProvider(ABC):
|
||||
"""Base class for image generation provider clients."""
|
||||
|
||||
provider_name: str = ""
|
||||
missing_key_message: str = ""
|
||||
default_timeout: float = _DEFAULT_TIMEOUT_S
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -130,20 +158,71 @@ class OpenRouterImageGenerationClient:
|
||||
api_base: str | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
extra_body: dict[str, Any] | None = None,
|
||||
timeout: float = _DEFAULT_TIMEOUT_S,
|
||||
timeout: float | None = None,
|
||||
client: httpx.AsyncClient | None = None,
|
||||
) -> None:
|
||||
self.api_key = api_key
|
||||
self.api_base = _provider_base_url(
|
||||
"openrouter",
|
||||
api_base,
|
||||
"https://openrouter.ai/api/v1",
|
||||
)
|
||||
self.api_base = self._resolve_base_url(api_base)
|
||||
self.extra_headers = extra_headers or {}
|
||||
self.extra_body = extra_body or {}
|
||||
self.timeout = timeout
|
||||
self.timeout = timeout if timeout is not None else self.default_timeout
|
||||
self._client = client
|
||||
|
||||
def _resolve_base_url(self, api_base: str | None) -> str:
|
||||
if api_base:
|
||||
return api_base.rstrip("/")
|
||||
spec = find_by_name(self.provider_name)
|
||||
if spec and spec.default_api_base:
|
||||
return spec.default_api_base.rstrip("/")
|
||||
return self._default_base_url()
|
||||
|
||||
def _default_base_url(self) -> str:
|
||||
return ""
|
||||
|
||||
@abstractmethod
|
||||
async def generate(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
model: str,
|
||||
reference_images: list[str] | None = None,
|
||||
aspect_ratio: str | None = None,
|
||||
image_size: str | None = None,
|
||||
) -> GeneratedImageResponse: ...
|
||||
|
||||
def _require_images(self, images: list[str], data: dict[str, Any]) -> None:
|
||||
if images:
|
||||
return
|
||||
provider_error = data.get("error") if isinstance(data, dict) else None
|
||||
label = self.provider_name
|
||||
if provider_error:
|
||||
raise ImageGenerationError(f"{label} returned no images: {provider_error}")
|
||||
raise ImageGenerationError(f"{label} returned no images for this request")
|
||||
|
||||
async def _http_post(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
headers: dict[str, str],
|
||||
body: dict[str, Any],
|
||||
) -> httpx.Response:
|
||||
if self._client is not None:
|
||||
return await self._client.post(url, headers=headers, json=body)
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as c:
|
||||
return await c.post(url, headers=headers, json=body)
|
||||
|
||||
|
||||
class OpenRouterImageGenerationClient(ImageGenerationProvider):
|
||||
"""Small async client for OpenRouter Chat Completions image generation."""
|
||||
|
||||
provider_name = "openrouter"
|
||||
missing_key_message = (
|
||||
"OpenRouter API key is not configured. Set providers.openrouter.apiKey."
|
||||
)
|
||||
|
||||
def _default_base_url(self) -> str:
|
||||
return "https://openrouter.ai/api/v1"
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
*,
|
||||
@ -154,9 +233,7 @@ class OpenRouterImageGenerationClient:
|
||||
image_size: str | None = None,
|
||||
) -> GeneratedImageResponse:
|
||||
if not self.api_key:
|
||||
raise ImageGenerationError(
|
||||
"OpenRouter API key is not configured. Set providers.openrouter.apiKey."
|
||||
)
|
||||
raise ImageGenerationError(self.missing_key_message)
|
||||
|
||||
content: str | list[dict[str, Any]]
|
||||
references = list(reference_images or [])
|
||||
@ -192,12 +269,7 @@ class OpenRouterImageGenerationClient:
|
||||
**self.extra_headers,
|
||||
}
|
||||
url = f"{self.api_base}/chat/completions"
|
||||
|
||||
if self._client is not None:
|
||||
response = await self._client.post(url, headers=headers, json=body)
|
||||
else:
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(url, headers=headers, json=body)
|
||||
response = await self._http_post(url, headers=headers, body=body)
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
@ -222,11 +294,7 @@ class OpenRouterImageGenerationClient:
|
||||
if isinstance(url_value, str) and url_value.startswith("data:image/"):
|
||||
images.append(url_value)
|
||||
|
||||
if not images:
|
||||
provider_error = data.get("error") if isinstance(data, dict) else None
|
||||
if provider_error:
|
||||
raise ImageGenerationError(f"OpenRouter returned no images: {provider_error}")
|
||||
raise ImageGenerationError("OpenRouter returned no images for this request")
|
||||
self._require_images(images, data)
|
||||
|
||||
return GeneratedImageResponse(
|
||||
images=images,
|
||||
@ -235,29 +303,17 @@ class OpenRouterImageGenerationClient:
|
||||
)
|
||||
|
||||
|
||||
class AIHubMixImageGenerationClient:
|
||||
class AIHubMixImageGenerationClient(ImageGenerationProvider):
|
||||
"""Small async client for AIHubMix unified image generation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str | None,
|
||||
api_base: str | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
extra_body: dict[str, Any] | None = None,
|
||||
timeout: float = _AIHUBMIX_TIMEOUT_S,
|
||||
client: httpx.AsyncClient | None = None,
|
||||
) -> None:
|
||||
self.api_key = api_key
|
||||
self.api_base = _provider_base_url(
|
||||
"aihubmix",
|
||||
api_base,
|
||||
"https://aihubmix.com/v1",
|
||||
)
|
||||
self.extra_headers = extra_headers or {}
|
||||
self.extra_body = extra_body or {}
|
||||
self.timeout = timeout
|
||||
self._client = client
|
||||
provider_name = "aihubmix"
|
||||
missing_key_message = (
|
||||
"AIHubMix API key is not configured. Set providers.aihubmix.apiKey."
|
||||
)
|
||||
default_timeout = _AIHUBMIX_TIMEOUT_S
|
||||
|
||||
def _default_base_url(self) -> str:
|
||||
return "https://aihubmix.com/v1"
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
@ -269,9 +325,7 @@ class AIHubMixImageGenerationClient:
|
||||
image_size: str | None = None,
|
||||
) -> GeneratedImageResponse:
|
||||
if not self.api_key:
|
||||
raise ImageGenerationError(
|
||||
"AIHubMix API key is not configured. Set providers.aihubmix.apiKey."
|
||||
)
|
||||
raise ImageGenerationError(self.missing_key_message)
|
||||
|
||||
refs = list(reference_images or [])
|
||||
headers = {
|
||||
@ -280,16 +334,8 @@ class AIHubMixImageGenerationClient:
|
||||
}
|
||||
size = _aihubmix_size(aspect_ratio, image_size)
|
||||
|
||||
if self._client is not None:
|
||||
return await self._generate_with_client(
|
||||
self._client,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
reference_images=refs,
|
||||
size=size,
|
||||
headers=headers,
|
||||
)
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
client = self._client or httpx.AsyncClient(timeout=self.timeout)
|
||||
try:
|
||||
return await self._generate_with_client(
|
||||
client,
|
||||
prompt=prompt,
|
||||
@ -298,6 +344,9 @@ class AIHubMixImageGenerationClient:
|
||||
size=size,
|
||||
headers=headers,
|
||||
)
|
||||
finally:
|
||||
if self._client is None:
|
||||
await client.aclose()
|
||||
|
||||
async def _generate_with_client(
|
||||
self,
|
||||
@ -346,11 +395,7 @@ class AIHubMixImageGenerationClient:
|
||||
payload = response.json()
|
||||
images = await _aihubmix_images_from_payload(client, payload)
|
||||
|
||||
if not images:
|
||||
provider_error = payload.get("error") if isinstance(payload, dict) else None
|
||||
if provider_error:
|
||||
raise ImageGenerationError(f"AIHubMix returned no images: {provider_error}")
|
||||
raise ImageGenerationError("AIHubMix returned no images for this request")
|
||||
self._require_images(images, payload)
|
||||
|
||||
return GeneratedImageResponse(images=images, content="", raw=payload)
|
||||
|
||||
@ -370,31 +415,25 @@ def _http_error_detail(response: httpx.Response) -> str:
|
||||
return response.text[:500] or "<empty response body>"
|
||||
|
||||
|
||||
class GeminiImageGenerationClient:
|
||||
class GeminiImageGenerationClient(ImageGenerationProvider):
|
||||
"""Async client for Gemini/Imagen image generation via the Generative Language API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str | None,
|
||||
api_base: str | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
extra_body: dict[str, Any] | None = None,
|
||||
timeout: float = _GEMINI_DEFAULT_TIMEOUT_S,
|
||||
client: httpx.AsyncClient | None = None,
|
||||
) -> None:
|
||||
self.api_key = api_key
|
||||
provider_name = "gemini"
|
||||
missing_key_message = (
|
||||
"Gemini API key is not configured. Set providers.gemini.apiKey."
|
||||
)
|
||||
default_timeout = _GEMINI_DEFAULT_TIMEOUT_S
|
||||
|
||||
def _default_base_url(self) -> str:
|
||||
return "https://generativelanguage.googleapis.com/v1beta"
|
||||
|
||||
def _resolve_base_url(self, api_base: str | None) -> str:
|
||||
# The Gemini provider's registry default_api_base is the OpenAI-compat
|
||||
# shim (.../v1beta/openai/), which has no image endpoints. Image
|
||||
# generation needs the native Generative Language API base, so we don't
|
||||
# use _provider_base_url() here.
|
||||
self.api_base = (
|
||||
api_base or "https://generativelanguage.googleapis.com/v1beta"
|
||||
).rstrip("/")
|
||||
self.extra_headers = extra_headers or {}
|
||||
self.extra_body = extra_body or {}
|
||||
self.timeout = timeout
|
||||
self._client = client
|
||||
# shim (.../v1beta/openai/), which has no image endpoints.
|
||||
# Skip the registry lookup and use the native API base directly.
|
||||
if api_base:
|
||||
return api_base.rstrip("/")
|
||||
return self._default_base_url()
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
@ -406,9 +445,7 @@ class GeminiImageGenerationClient:
|
||||
image_size: str | None = None,
|
||||
) -> GeneratedImageResponse:
|
||||
if not self.api_key:
|
||||
raise ImageGenerationError(
|
||||
"Gemini API key is not configured. Set providers.gemini.apiKey."
|
||||
)
|
||||
raise ImageGenerationError(self.missing_key_message)
|
||||
if "imagen" in model.lower():
|
||||
if reference_images:
|
||||
logger.warning(
|
||||
@ -446,12 +483,7 @@ class GeminiImageGenerationClient:
|
||||
"Content-Type": "application/json",
|
||||
**self.extra_headers,
|
||||
}
|
||||
|
||||
if self._client is not None:
|
||||
response = await self._client.post(url, headers=headers, json=body)
|
||||
else:
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(url, headers=headers, json=body)
|
||||
response = await self._http_post(url, headers=headers, body=body)
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
@ -472,11 +504,7 @@ class GeminiImageGenerationClient:
|
||||
if isinstance(b64, str) and b64:
|
||||
images.append(f"data:{mime};base64,{b64}")
|
||||
|
||||
if not images:
|
||||
provider_error = data.get("error") if isinstance(data, dict) else None
|
||||
if provider_error:
|
||||
raise ImageGenerationError(f"Gemini Imagen returned no images: {provider_error}")
|
||||
raise ImageGenerationError("Gemini Imagen returned no images for this request")
|
||||
self._require_images(images, data)
|
||||
|
||||
return GeneratedImageResponse(images=images, content="", raw=data)
|
||||
|
||||
@ -504,12 +532,7 @@ class GeminiImageGenerationClient:
|
||||
"Content-Type": "application/json",
|
||||
**self.extra_headers,
|
||||
}
|
||||
|
||||
if self._client is not None:
|
||||
response = await self._client.post(url, headers=headers, json=body)
|
||||
else:
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(url, headers=headers, json=body)
|
||||
response = await self._http_post(url, headers=headers, body=body)
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
@ -539,11 +562,7 @@ class GeminiImageGenerationClient:
|
||||
if b64:
|
||||
images.append(f"data:{mime};base64,{b64}")
|
||||
|
||||
if not images:
|
||||
provider_error = data.get("error") if isinstance(data, dict) else None
|
||||
if provider_error:
|
||||
raise ImageGenerationError(f"Gemini returned no images: {provider_error}")
|
||||
raise ImageGenerationError("Gemini returned no images for this request")
|
||||
self._require_images(images, data)
|
||||
|
||||
return GeneratedImageResponse(
|
||||
images=images,
|
||||
@ -620,29 +639,17 @@ _MINIMAX_ASPECT_RATIO_SIZES = {
|
||||
}
|
||||
|
||||
|
||||
class MiniMaxImageGenerationClient:
|
||||
class MiniMaxImageGenerationClient(ImageGenerationProvider):
|
||||
"""Async client for MiniMax image generation API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str | None,
|
||||
api_base: str | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
extra_body: dict[str, Any] | None = None,
|
||||
timeout: float = _MINIMAX_TIMEOUT_S,
|
||||
client: httpx.AsyncClient | None = None,
|
||||
) -> None:
|
||||
self.api_key = api_key
|
||||
self.api_base = _provider_base_url(
|
||||
"minimax",
|
||||
api_base,
|
||||
"https://api.minimaxi.com/v1",
|
||||
)
|
||||
self.extra_headers = extra_headers or {}
|
||||
self.extra_body = extra_body or {}
|
||||
self.timeout = timeout
|
||||
self._client = client
|
||||
provider_name = "minimax"
|
||||
missing_key_message = (
|
||||
"MiniMax API key is not configured. Set providers.minimax.apiKey."
|
||||
)
|
||||
default_timeout = _MINIMAX_TIMEOUT_S
|
||||
|
||||
def _default_base_url(self) -> str:
|
||||
return "https://api.minimaxi.com/v1"
|
||||
|
||||
def _resolve_aspect_ratio(self, aspect_ratio: str | None) -> str:
|
||||
if aspect_ratio and aspect_ratio in _MINIMAX_ASPECT_RATIO_SIZES:
|
||||
@ -659,9 +666,7 @@ class MiniMaxImageGenerationClient:
|
||||
image_size: str | None = None,
|
||||
) -> GeneratedImageResponse:
|
||||
if not self.api_key:
|
||||
raise ImageGenerationError(
|
||||
"MiniMax API key is not configured. Set providers.minimax.apiKey."
|
||||
)
|
||||
raise ImageGenerationError(self.missing_key_message)
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
@ -687,10 +692,12 @@ class MiniMaxImageGenerationClient:
|
||||
|
||||
body.update(self.extra_body)
|
||||
|
||||
if self._client is not None:
|
||||
return await self._generate_with_client(self._client, body, headers)
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
client = self._client or httpx.AsyncClient(timeout=self.timeout)
|
||||
try:
|
||||
return await self._generate_with_client(client, body, headers)
|
||||
finally:
|
||||
if self._client is None:
|
||||
await client.aclose()
|
||||
|
||||
async def _generate_with_client(
|
||||
self,
|
||||
@ -715,11 +722,7 @@ class MiniMaxImageGenerationClient:
|
||||
payload = response.json()
|
||||
images = _minimax_images_from_payload(payload)
|
||||
|
||||
if not images:
|
||||
provider_error = payload.get("error") if isinstance(payload, dict) else None
|
||||
if provider_error:
|
||||
raise ImageGenerationError(f"MiniMax returned no images: {provider_error}")
|
||||
raise ImageGenerationError("MiniMax returned no images for this request")
|
||||
self._require_images(images, payload)
|
||||
|
||||
return GeneratedImageResponse(images=images, content="", raw=payload)
|
||||
|
||||
@ -737,3 +740,13 @@ def _minimax_images_from_payload(payload: dict[str, Any]) -> list[str]:
|
||||
if isinstance(b64, str) and b64:
|
||||
images.append(_b64_png_data_url(b64))
|
||||
return images
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
register_image_gen_provider(OpenRouterImageGenerationClient)
|
||||
register_image_gen_provider(AIHubMixImageGenerationClient)
|
||||
register_image_gen_provider(GeminiImageGenerationClient)
|
||||
register_image_gen_provider(MiniMaxImageGenerationClient)
|
||||
|
||||
@ -35,8 +35,8 @@ async def test_generated_image_media_is_attached_to_final_assistant_message(
|
||||
) -> None:
|
||||
set_config_path(tmp_path / "config.json")
|
||||
monkeypatch.setattr(
|
||||
"nanobot.agent.tools.image_generation.OpenRouterImageGenerationClient",
|
||||
FakeImageClient,
|
||||
"nanobot.agent.tools.image_generation.get_image_gen_provider",
|
||||
lambda name: FakeImageClient if name == "openrouter" else None,
|
||||
)
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
|
||||
@ -44,8 +44,8 @@ async def test_generate_image_tool_stores_artifact_and_source_images(
|
||||
set_config_path(tmp_path / "config.json")
|
||||
FakeImageClient.instances = []
|
||||
monkeypatch.setattr(
|
||||
"nanobot.agent.tools.image_generation.OpenRouterImageGenerationClient",
|
||||
FakeImageClient,
|
||||
"nanobot.agent.tools.image_generation.get_image_gen_provider",
|
||||
lambda name: FakeImageClient if name == "openrouter" else None,
|
||||
)
|
||||
ref = tmp_path / "ref.png"
|
||||
ref.write_bytes(PNG_BYTES)
|
||||
@ -98,8 +98,8 @@ async def test_generate_image_tool_selects_aihubmix_provider(
|
||||
set_config_path(tmp_path / "config.json")
|
||||
FakeImageClient.instances = []
|
||||
monkeypatch.setattr(
|
||||
"nanobot.agent.tools.image_generation.AIHubMixImageGenerationClient",
|
||||
FakeImageClient,
|
||||
"nanobot.agent.tools.image_generation.get_image_gen_provider",
|
||||
lambda name: FakeImageClient if name == "aihubmix" else None,
|
||||
)
|
||||
tool = ImageGenerationTool(
|
||||
workspace=tmp_path,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user