refactor(image-generation): introduce provider registry to eliminate manual wiring

Adds ImageGenerationProvider ABC with shared __init__, _http_post(), and
_require_images(). Introduces _IMAGE_GEN_PROVIDERS registry with
register/get/image_gen_provider_configs() helpers.

Four existing providers (OpenRouter, AIHubMix, Gemini, MiniMax) now inherit
from the base class and self-register. Adding a new provider only requires
writing one class + one registration line.

Eliminates if/else chains in the tool dispatch and hardcoded provider config
dicts in commands.py (3 sites) and nanobot.py (1 site). Fixes the agent CLI
command missing image_generation_provider_configs entirely.

Also simplifies test monkeypatch targets to patch the registry lookup.
This commit is contained in:
chengyongru 2026-05-18 17:20:54 +08:00 committed by Xubin Ren
parent 7367741ac1
commit c588d56a77
6 changed files with 188 additions and 200 deletions

View File

@ -17,11 +17,9 @@ from nanobot.agent.tools.schema import (
from nanobot.config.paths import get_media_dir from nanobot.config.paths import get_media_dir
from nanobot.config.schema import Base from nanobot.config.schema import Base
from nanobot.providers.image_generation import ( from nanobot.providers.image_generation import (
AIHubMixImageGenerationClient,
GeminiImageGenerationClient,
ImageGenerationError, ImageGenerationError,
MiniMaxImageGenerationClient, ImageGenerationProvider,
OpenRouterImageGenerationClient, get_image_gen_provider,
) )
from nanobot.utils.artifacts import ( from nanobot.utils.artifacts import (
ArtifactError, ArtifactError,
@ -119,37 +117,24 @@ class ImageGenerationTool(Tool):
def _provider_config(self) -> ProviderConfig | None: def _provider_config(self) -> ProviderConfig | None:
return self.provider_configs.get(self.config.provider) return self.provider_configs.get(self.config.provider)
def _provider_client( def _provider_client(self) -> ImageGenerationProvider | None:
self,
) -> OpenRouterImageGenerationClient | AIHubMixImageGenerationClient | MiniMaxImageGenerationClient | GeminiImageGenerationClient | None:
provider = self._provider_config() provider = self._provider_config()
cls = get_image_gen_provider(self.config.provider)
if cls is None:
return None
kwargs = { kwargs = {
"api_key": provider.api_key if provider else None, "api_key": provider.api_key if provider else None,
"api_base": provider.api_base if provider else None, "api_base": provider.api_base if provider else None,
"extra_headers": provider.extra_headers if provider else None, "extra_headers": provider.extra_headers if provider else None,
"extra_body": provider.extra_body if provider else None, "extra_body": provider.extra_body if provider else None,
} }
if self.config.provider == "openrouter": return cls(**kwargs)
return OpenRouterImageGenerationClient(**kwargs)
if self.config.provider == "aihubmix":
return AIHubMixImageGenerationClient(**kwargs)
if self.config.provider == "minimax":
return MiniMaxImageGenerationClient(**kwargs)
if self.config.provider == "gemini":
return GeminiImageGenerationClient(**kwargs)
return None
def _missing_api_key_error(self) -> str: def _missing_api_key_error(self) -> str:
provider = self.config.provider cls = get_image_gen_provider(self.config.provider)
if provider == "openrouter": if cls and cls.missing_key_message:
return "Error: OpenRouter API key is not configured. Set providers.openrouter.apiKey." return f"Error: {cls.missing_key_message}"
if provider == "aihubmix": return f"Error: {self.config.provider} API key is not configured."
return "Error: AIHubMix API key is not configured. Set providers.aihubmix.apiKey."
if provider == "minimax":
return "Error: MiniMax API key is not configured. Set providers.minimax.apiKey."
if provider == "gemini":
return "Error: Gemini API key is not configured. Set providers.gemini.apiKey."
return f"Error: {provider} API key is not configured."
def _resolve_reference_image(self, value: str) -> str: def _resolve_reference_image(self, value: str) -> str:
raw_path = Path(value).expanduser() raw_path = Path(value).expanduser()

View File

@ -620,6 +620,7 @@ def serve(
from nanobot.api.server import create_app from nanobot.api.server import create_app
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.providers.image_generation import image_gen_provider_configs
from nanobot.session.manager import SessionManager from nanobot.session.manager import SessionManager
if verbose: if verbose:
@ -639,12 +640,7 @@ def serve(
agent_loop = AgentLoop.from_config( agent_loop = AgentLoop.from_config(
runtime_config, bus, runtime_config, bus,
session_manager=session_manager, session_manager=session_manager,
image_generation_provider_configs={ image_generation_provider_configs=image_gen_provider_configs(runtime_config),
"openrouter": runtime_config.providers.openrouter,
"aihubmix": runtime_config.providers.aihubmix,
"minimax": runtime_config.providers.minimax,
"gemini": runtime_config.providers.gemini,
},
) )
except ValueError as exc: except ValueError as exc:
console.print(f"[red]Error: {exc}[/red]") console.print(f"[red]Error: {exc}[/red]")
@ -724,6 +720,7 @@ def _run_gateway(
from nanobot.cron.types import CronJob from nanobot.cron.types import CronJob
from nanobot.heartbeat.service import HeartbeatService from nanobot.heartbeat.service import HeartbeatService
from nanobot.providers.factory import build_provider_snapshot, load_provider_snapshot from nanobot.providers.factory import build_provider_snapshot, load_provider_snapshot
from nanobot.providers.image_generation import image_gen_provider_configs
from nanobot.session.manager import SessionManager from nanobot.session.manager import SessionManager
port = port if port is not None else config.gateway.port port = port if port is not None else config.gateway.port
@ -754,12 +751,7 @@ def _run_gateway(
context_window_tokens=provider_snapshot.context_window_tokens, context_window_tokens=provider_snapshot.context_window_tokens,
cron_service=cron, cron_service=cron,
session_manager=session_manager, session_manager=session_manager,
image_generation_provider_configs={ image_generation_provider_configs=image_gen_provider_configs(config),
"openrouter": config.providers.openrouter,
"aihubmix": config.providers.aihubmix,
"minimax": config.providers.minimax,
"gemini": config.providers.gemini,
},
provider_snapshot_loader=load_provider_snapshot, provider_snapshot_loader=load_provider_snapshot,
runtime_model_publisher=lambda model, preset: publish_runtime_model_update( runtime_model_publisher=lambda model, preset: publish_runtime_model_update(
bus, bus,
@ -1126,6 +1118,7 @@ def agent(
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.cron.service import CronService from nanobot.cron.service import CronService
from nanobot.providers.image_generation import image_gen_provider_configs
config = _load_runtime_config(config, workspace) config = _load_runtime_config(config, workspace)
sync_workspace_templates(config.workspace_path) sync_workspace_templates(config.workspace_path)
@ -1149,6 +1142,7 @@ def agent(
agent_loop = AgentLoop.from_config( agent_loop = AgentLoop.from_config(
config, bus, config, bus,
cron_service=cron, cron_service=cron,
image_generation_provider_configs=image_gen_provider_configs(config),
) )
except ValueError as exc: except ValueError as exc:
console.print(f"[red]Error: {exc}[/red]") console.print(f"[red]Error: {exc}[/red]")

View File

@ -8,6 +8,7 @@ from typing import Any
from nanobot.agent.hook import AgentHook, SDKCaptureHook from nanobot.agent.hook import AgentHook, SDKCaptureHook
from nanobot.agent.loop import AgentLoop from nanobot.agent.loop import AgentLoop
from nanobot.providers.image_generation import image_gen_provider_configs
@dataclass(slots=True) @dataclass(slots=True)
@ -63,12 +64,7 @@ class Nanobot:
loop = AgentLoop.from_config( loop = AgentLoop.from_config(
config, config,
image_generation_provider_configs={ image_generation_provider_configs=image_gen_provider_configs(config),
"openrouter": config.providers.openrouter,
"aihubmix": config.providers.aihubmix,
"minimax": config.providers.minimax,
"gemini": config.providers.gemini,
},
) )
return cls(loop) return cls(loop)

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import base64 import base64
from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -44,15 +45,6 @@ class GeneratedImageResponse:
raw: dict[str, Any] raw: dict[str, Any]
def _provider_base_url(provider: str, api_base: str | None, fallback: str) -> str:
if api_base:
return api_base.rstrip("/")
spec = find_by_name(provider)
if spec and spec.default_api_base:
return spec.default_api_base.rstrip("/")
return fallback
def _read_image_b64(path: str | Path) -> tuple[str, str]: def _read_image_b64(path: str | Path) -> tuple[str, str]:
"""Return ``(mime, base64)`` for the image at ``path``.""" """Return ``(mime, base64)`` for the image at ``path``."""
p = Path(path).expanduser() p = Path(path).expanduser()
@ -120,8 +112,44 @@ async def _download_image_data_url(
return f"data:{mime};base64,{encoded}" return f"data:{mime};base64,{encoded}"
class OpenRouterImageGenerationClient: # ---------------------------------------------------------------------------
"""Small async client for OpenRouter Chat Completions image generation.""" # Registry
# ---------------------------------------------------------------------------
_IMAGE_GEN_PROVIDERS: dict[str, type[ImageGenerationProvider]] = {}
def register_image_gen_provider(cls: type[ImageGenerationProvider]) -> None:
name = cls.provider_name
if not name:
raise ValueError(f"{cls.__name__} must set provider_name")
_IMAGE_GEN_PROVIDERS[name] = cls
def get_image_gen_provider(name: str) -> type[ImageGenerationProvider] | None:
return _IMAGE_GEN_PROVIDERS.get(name)
def image_gen_provider_configs(config: Any) -> dict[str, Any]:
providers_cfg = config.providers
return {
name: pc
for name in _IMAGE_GEN_PROVIDERS
if (pc := getattr(providers_cfg, name, None)) is not None
}
# ---------------------------------------------------------------------------
# Base class
# ---------------------------------------------------------------------------
class ImageGenerationProvider(ABC):
"""Base class for image generation provider clients."""
provider_name: str = ""
missing_key_message: str = ""
default_timeout: float = _DEFAULT_TIMEOUT_S
def __init__( def __init__(
self, self,
@ -130,20 +158,71 @@ class OpenRouterImageGenerationClient:
api_base: str | None = None, api_base: str | None = None,
extra_headers: dict[str, str] | None = None, extra_headers: dict[str, str] | None = None,
extra_body: dict[str, Any] | None = None, extra_body: dict[str, Any] | None = None,
timeout: float = _DEFAULT_TIMEOUT_S, timeout: float | None = None,
client: httpx.AsyncClient | None = None, client: httpx.AsyncClient | None = None,
) -> None: ) -> None:
self.api_key = api_key self.api_key = api_key
self.api_base = _provider_base_url( self.api_base = self._resolve_base_url(api_base)
"openrouter",
api_base,
"https://openrouter.ai/api/v1",
)
self.extra_headers = extra_headers or {} self.extra_headers = extra_headers or {}
self.extra_body = extra_body or {} self.extra_body = extra_body or {}
self.timeout = timeout self.timeout = timeout if timeout is not None else self.default_timeout
self._client = client self._client = client
def _resolve_base_url(self, api_base: str | None) -> str:
if api_base:
return api_base.rstrip("/")
spec = find_by_name(self.provider_name)
if spec and spec.default_api_base:
return spec.default_api_base.rstrip("/")
return self._default_base_url()
def _default_base_url(self) -> str:
return ""
@abstractmethod
async def generate(
self,
*,
prompt: str,
model: str,
reference_images: list[str] | None = None,
aspect_ratio: str | None = None,
image_size: str | None = None,
) -> GeneratedImageResponse: ...
def _require_images(self, images: list[str], data: dict[str, Any]) -> None:
if images:
return
provider_error = data.get("error") if isinstance(data, dict) else None
label = self.provider_name
if provider_error:
raise ImageGenerationError(f"{label} returned no images: {provider_error}")
raise ImageGenerationError(f"{label} returned no images for this request")
async def _http_post(
self,
url: str,
*,
headers: dict[str, str],
body: dict[str, Any],
) -> httpx.Response:
if self._client is not None:
return await self._client.post(url, headers=headers, json=body)
async with httpx.AsyncClient(timeout=self.timeout) as c:
return await c.post(url, headers=headers, json=body)
class OpenRouterImageGenerationClient(ImageGenerationProvider):
"""Small async client for OpenRouter Chat Completions image generation."""
provider_name = "openrouter"
missing_key_message = (
"OpenRouter API key is not configured. Set providers.openrouter.apiKey."
)
def _default_base_url(self) -> str:
return "https://openrouter.ai/api/v1"
async def generate( async def generate(
self, self,
*, *,
@ -154,9 +233,7 @@ class OpenRouterImageGenerationClient:
image_size: str | None = None, image_size: str | None = None,
) -> GeneratedImageResponse: ) -> GeneratedImageResponse:
if not self.api_key: if not self.api_key:
raise ImageGenerationError( raise ImageGenerationError(self.missing_key_message)
"OpenRouter API key is not configured. Set providers.openrouter.apiKey."
)
content: str | list[dict[str, Any]] content: str | list[dict[str, Any]]
references = list(reference_images or []) references = list(reference_images or [])
@ -192,12 +269,7 @@ class OpenRouterImageGenerationClient:
**self.extra_headers, **self.extra_headers,
} }
url = f"{self.api_base}/chat/completions" url = f"{self.api_base}/chat/completions"
response = await self._http_post(url, headers=headers, body=body)
if self._client is not None:
response = await self._client.post(url, headers=headers, json=body)
else:
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(url, headers=headers, json=body)
try: try:
response.raise_for_status() response.raise_for_status()
@ -222,11 +294,7 @@ class OpenRouterImageGenerationClient:
if isinstance(url_value, str) and url_value.startswith("data:image/"): if isinstance(url_value, str) and url_value.startswith("data:image/"):
images.append(url_value) images.append(url_value)
if not images: self._require_images(images, data)
provider_error = data.get("error") if isinstance(data, dict) else None
if provider_error:
raise ImageGenerationError(f"OpenRouter returned no images: {provider_error}")
raise ImageGenerationError("OpenRouter returned no images for this request")
return GeneratedImageResponse( return GeneratedImageResponse(
images=images, images=images,
@ -235,29 +303,17 @@ class OpenRouterImageGenerationClient:
) )
class AIHubMixImageGenerationClient: class AIHubMixImageGenerationClient(ImageGenerationProvider):
"""Small async client for AIHubMix unified image generation.""" """Small async client for AIHubMix unified image generation."""
def __init__( provider_name = "aihubmix"
self, missing_key_message = (
*, "AIHubMix API key is not configured. Set providers.aihubmix.apiKey."
api_key: str | None, )
api_base: str | None = None, default_timeout = _AIHUBMIX_TIMEOUT_S
extra_headers: dict[str, str] | None = None,
extra_body: dict[str, Any] | None = None, def _default_base_url(self) -> str:
timeout: float = _AIHUBMIX_TIMEOUT_S, return "https://aihubmix.com/v1"
client: httpx.AsyncClient | None = None,
) -> None:
self.api_key = api_key
self.api_base = _provider_base_url(
"aihubmix",
api_base,
"https://aihubmix.com/v1",
)
self.extra_headers = extra_headers or {}
self.extra_body = extra_body or {}
self.timeout = timeout
self._client = client
async def generate( async def generate(
self, self,
@ -269,9 +325,7 @@ class AIHubMixImageGenerationClient:
image_size: str | None = None, image_size: str | None = None,
) -> GeneratedImageResponse: ) -> GeneratedImageResponse:
if not self.api_key: if not self.api_key:
raise ImageGenerationError( raise ImageGenerationError(self.missing_key_message)
"AIHubMix API key is not configured. Set providers.aihubmix.apiKey."
)
refs = list(reference_images or []) refs = list(reference_images or [])
headers = { headers = {
@ -280,16 +334,8 @@ class AIHubMixImageGenerationClient:
} }
size = _aihubmix_size(aspect_ratio, image_size) size = _aihubmix_size(aspect_ratio, image_size)
if self._client is not None: client = self._client or httpx.AsyncClient(timeout=self.timeout)
return await self._generate_with_client( try:
self._client,
prompt=prompt,
model=model,
reference_images=refs,
size=size,
headers=headers,
)
async with httpx.AsyncClient(timeout=self.timeout) as client:
return await self._generate_with_client( return await self._generate_with_client(
client, client,
prompt=prompt, prompt=prompt,
@ -298,6 +344,9 @@ class AIHubMixImageGenerationClient:
size=size, size=size,
headers=headers, headers=headers,
) )
finally:
if self._client is None:
await client.aclose()
async def _generate_with_client( async def _generate_with_client(
self, self,
@ -346,11 +395,7 @@ class AIHubMixImageGenerationClient:
payload = response.json() payload = response.json()
images = await _aihubmix_images_from_payload(client, payload) images = await _aihubmix_images_from_payload(client, payload)
if not images: self._require_images(images, payload)
provider_error = payload.get("error") if isinstance(payload, dict) else None
if provider_error:
raise ImageGenerationError(f"AIHubMix returned no images: {provider_error}")
raise ImageGenerationError("AIHubMix returned no images for this request")
return GeneratedImageResponse(images=images, content="", raw=payload) return GeneratedImageResponse(images=images, content="", raw=payload)
@ -370,31 +415,25 @@ def _http_error_detail(response: httpx.Response) -> str:
return response.text[:500] or "<empty response body>" return response.text[:500] or "<empty response body>"
class GeminiImageGenerationClient: class GeminiImageGenerationClient(ImageGenerationProvider):
"""Async client for Gemini/Imagen image generation via the Generative Language API.""" """Async client for Gemini/Imagen image generation via the Generative Language API."""
def __init__( provider_name = "gemini"
self, missing_key_message = (
*, "Gemini API key is not configured. Set providers.gemini.apiKey."
api_key: str | None, )
api_base: str | None = None, default_timeout = _GEMINI_DEFAULT_TIMEOUT_S
extra_headers: dict[str, str] | None = None,
extra_body: dict[str, Any] | None = None, def _default_base_url(self) -> str:
timeout: float = _GEMINI_DEFAULT_TIMEOUT_S, return "https://generativelanguage.googleapis.com/v1beta"
client: httpx.AsyncClient | None = None,
) -> None: def _resolve_base_url(self, api_base: str | None) -> str:
self.api_key = api_key
# The Gemini provider's registry default_api_base is the OpenAI-compat # The Gemini provider's registry default_api_base is the OpenAI-compat
# shim (.../v1beta/openai/), which has no image endpoints. Image # shim (.../v1beta/openai/), which has no image endpoints.
# generation needs the native Generative Language API base, so we don't # Skip the registry lookup and use the native API base directly.
# use _provider_base_url() here. if api_base:
self.api_base = ( return api_base.rstrip("/")
api_base or "https://generativelanguage.googleapis.com/v1beta" return self._default_base_url()
).rstrip("/")
self.extra_headers = extra_headers or {}
self.extra_body = extra_body or {}
self.timeout = timeout
self._client = client
async def generate( async def generate(
self, self,
@ -406,9 +445,7 @@ class GeminiImageGenerationClient:
image_size: str | None = None, image_size: str | None = None,
) -> GeneratedImageResponse: ) -> GeneratedImageResponse:
if not self.api_key: if not self.api_key:
raise ImageGenerationError( raise ImageGenerationError(self.missing_key_message)
"Gemini API key is not configured. Set providers.gemini.apiKey."
)
if "imagen" in model.lower(): if "imagen" in model.lower():
if reference_images: if reference_images:
logger.warning( logger.warning(
@ -446,12 +483,7 @@ class GeminiImageGenerationClient:
"Content-Type": "application/json", "Content-Type": "application/json",
**self.extra_headers, **self.extra_headers,
} }
response = await self._http_post(url, headers=headers, body=body)
if self._client is not None:
response = await self._client.post(url, headers=headers, json=body)
else:
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(url, headers=headers, json=body)
try: try:
response.raise_for_status() response.raise_for_status()
@ -472,11 +504,7 @@ class GeminiImageGenerationClient:
if isinstance(b64, str) and b64: if isinstance(b64, str) and b64:
images.append(f"data:{mime};base64,{b64}") images.append(f"data:{mime};base64,{b64}")
if not images: self._require_images(images, data)
provider_error = data.get("error") if isinstance(data, dict) else None
if provider_error:
raise ImageGenerationError(f"Gemini Imagen returned no images: {provider_error}")
raise ImageGenerationError("Gemini Imagen returned no images for this request")
return GeneratedImageResponse(images=images, content="", raw=data) return GeneratedImageResponse(images=images, content="", raw=data)
@ -504,12 +532,7 @@ class GeminiImageGenerationClient:
"Content-Type": "application/json", "Content-Type": "application/json",
**self.extra_headers, **self.extra_headers,
} }
response = await self._http_post(url, headers=headers, body=body)
if self._client is not None:
response = await self._client.post(url, headers=headers, json=body)
else:
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(url, headers=headers, json=body)
try: try:
response.raise_for_status() response.raise_for_status()
@ -539,11 +562,7 @@ class GeminiImageGenerationClient:
if b64: if b64:
images.append(f"data:{mime};base64,{b64}") images.append(f"data:{mime};base64,{b64}")
if not images: self._require_images(images, data)
provider_error = data.get("error") if isinstance(data, dict) else None
if provider_error:
raise ImageGenerationError(f"Gemini returned no images: {provider_error}")
raise ImageGenerationError("Gemini returned no images for this request")
return GeneratedImageResponse( return GeneratedImageResponse(
images=images, images=images,
@ -620,29 +639,17 @@ _MINIMAX_ASPECT_RATIO_SIZES = {
} }
class MiniMaxImageGenerationClient: class MiniMaxImageGenerationClient(ImageGenerationProvider):
"""Async client for MiniMax image generation API.""" """Async client for MiniMax image generation API."""
def __init__( provider_name = "minimax"
self, missing_key_message = (
*, "MiniMax API key is not configured. Set providers.minimax.apiKey."
api_key: str | None, )
api_base: str | None = None, default_timeout = _MINIMAX_TIMEOUT_S
extra_headers: dict[str, str] | None = None,
extra_body: dict[str, Any] | None = None, def _default_base_url(self) -> str:
timeout: float = _MINIMAX_TIMEOUT_S, return "https://api.minimaxi.com/v1"
client: httpx.AsyncClient | None = None,
) -> None:
self.api_key = api_key
self.api_base = _provider_base_url(
"minimax",
api_base,
"https://api.minimaxi.com/v1",
)
self.extra_headers = extra_headers or {}
self.extra_body = extra_body or {}
self.timeout = timeout
self._client = client
def _resolve_aspect_ratio(self, aspect_ratio: str | None) -> str: def _resolve_aspect_ratio(self, aspect_ratio: str | None) -> str:
if aspect_ratio and aspect_ratio in _MINIMAX_ASPECT_RATIO_SIZES: if aspect_ratio and aspect_ratio in _MINIMAX_ASPECT_RATIO_SIZES:
@ -659,9 +666,7 @@ class MiniMaxImageGenerationClient:
image_size: str | None = None, image_size: str | None = None,
) -> GeneratedImageResponse: ) -> GeneratedImageResponse:
if not self.api_key: if not self.api_key:
raise ImageGenerationError( raise ImageGenerationError(self.missing_key_message)
"MiniMax API key is not configured. Set providers.minimax.apiKey."
)
headers = { headers = {
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
@ -687,10 +692,12 @@ class MiniMaxImageGenerationClient:
body.update(self.extra_body) body.update(self.extra_body)
if self._client is not None: client = self._client or httpx.AsyncClient(timeout=self.timeout)
return await self._generate_with_client(self._client, body, headers) try:
async with httpx.AsyncClient(timeout=self.timeout) as client:
return await self._generate_with_client(client, body, headers) return await self._generate_with_client(client, body, headers)
finally:
if self._client is None:
await client.aclose()
async def _generate_with_client( async def _generate_with_client(
self, self,
@ -715,11 +722,7 @@ class MiniMaxImageGenerationClient:
payload = response.json() payload = response.json()
images = _minimax_images_from_payload(payload) images = _minimax_images_from_payload(payload)
if not images: self._require_images(images, payload)
provider_error = payload.get("error") if isinstance(payload, dict) else None
if provider_error:
raise ImageGenerationError(f"MiniMax returned no images: {provider_error}")
raise ImageGenerationError("MiniMax returned no images for this request")
return GeneratedImageResponse(images=images, content="", raw=payload) return GeneratedImageResponse(images=images, content="", raw=payload)
@ -737,3 +740,13 @@ def _minimax_images_from_payload(payload: dict[str, Any]) -> list[str]:
if isinstance(b64, str) and b64: if isinstance(b64, str) and b64:
images.append(_b64_png_data_url(b64)) images.append(_b64_png_data_url(b64))
return images return images
# ---------------------------------------------------------------------------
# Provider registration
# ---------------------------------------------------------------------------
register_image_gen_provider(OpenRouterImageGenerationClient)
register_image_gen_provider(AIHubMixImageGenerationClient)
register_image_gen_provider(GeminiImageGenerationClient)
register_image_gen_provider(MiniMaxImageGenerationClient)

View File

@ -35,8 +35,8 @@ async def test_generated_image_media_is_attached_to_final_assistant_message(
) -> None: ) -> None:
set_config_path(tmp_path / "config.json") set_config_path(tmp_path / "config.json")
monkeypatch.setattr( monkeypatch.setattr(
"nanobot.agent.tools.image_generation.OpenRouterImageGenerationClient", "nanobot.agent.tools.image_generation.get_image_gen_provider",
FakeImageClient, lambda name: FakeImageClient if name == "openrouter" else None,
) )
provider = MagicMock() provider = MagicMock()
provider.get_default_model.return_value = "test-model" provider.get_default_model.return_value = "test-model"

View File

@ -44,8 +44,8 @@ async def test_generate_image_tool_stores_artifact_and_source_images(
set_config_path(tmp_path / "config.json") set_config_path(tmp_path / "config.json")
FakeImageClient.instances = [] FakeImageClient.instances = []
monkeypatch.setattr( monkeypatch.setattr(
"nanobot.agent.tools.image_generation.OpenRouterImageGenerationClient", "nanobot.agent.tools.image_generation.get_image_gen_provider",
FakeImageClient, lambda name: FakeImageClient if name == "openrouter" else None,
) )
ref = tmp_path / "ref.png" ref = tmp_path / "ref.png"
ref.write_bytes(PNG_BYTES) ref.write_bytes(PNG_BYTES)
@ -98,8 +98,8 @@ async def test_generate_image_tool_selects_aihubmix_provider(
set_config_path(tmp_path / "config.json") set_config_path(tmp_path / "config.json")
FakeImageClient.instances = [] FakeImageClient.instances = []
monkeypatch.setattr( monkeypatch.setattr(
"nanobot.agent.tools.image_generation.AIHubMixImageGenerationClient", "nanobot.agent.tools.image_generation.get_image_gen_provider",
FakeImageClient, lambda name: FakeImageClient if name == "aihubmix" else None,
) )
tool = ImageGenerationTool( tool = ImageGenerationTool(
workspace=tmp_path, workspace=tmp_path,