refactor(image-generation): introduce provider registry to eliminate manual wiring

Adds ImageGenerationProvider ABC with shared __init__, _http_post(), and
_require_images(). Introduces _IMAGE_GEN_PROVIDERS registry with
register/get/image_gen_provider_configs() helpers.

Four existing providers (OpenRouter, AIHubMix, Gemini, MiniMax) now inherit
from the base class and self-register. Adding a new provider only requires
writing one class + one registration line.

Eliminates if/else chains in the tool dispatch and replaces the hardcoded
provider config dicts in commands.py (serve and gateway) and nanobot.py with
image_gen_provider_configs(). Also fixes the agent CLI command, which did not
pass image_generation_provider_configs at all.

Also simplifies test monkeypatch targets to patch the registry lookup.
This commit is contained in:
chengyongru 2026-05-18 17:20:54 +08:00
parent bb788cdb7d
commit 7aa5b9b17b
6 changed files with 188 additions and 200 deletions

View File

@ -17,11 +17,9 @@ from nanobot.agent.tools.schema import (
from nanobot.config.paths import get_media_dir
from nanobot.config.schema import Base
from nanobot.providers.image_generation import (
AIHubMixImageGenerationClient,
GeminiImageGenerationClient,
ImageGenerationError,
MiniMaxImageGenerationClient,
OpenRouterImageGenerationClient,
ImageGenerationProvider,
get_image_gen_provider,
)
from nanobot.utils.artifacts import (
ArtifactError,
@ -119,37 +117,24 @@ class ImageGenerationTool(Tool):
def _provider_config(self) -> ProviderConfig | None:
    """Return the config entry for the currently selected provider.

    Looks up ``self.config.provider`` in ``self.provider_configs`` and
    returns ``None`` when that provider has no entry.
    """
    lookup = self.provider_configs.get
    return lookup(self.config.provider)
def _provider_client(
self,
) -> OpenRouterImageGenerationClient | AIHubMixImageGenerationClient | MiniMaxImageGenerationClient | GeminiImageGenerationClient | None:
def _provider_client(self) -> ImageGenerationProvider | None:
provider = self._provider_config()
cls = get_image_gen_provider(self.config.provider)
if cls is None:
return None
kwargs = {
"api_key": provider.api_key if provider else None,
"api_base": provider.api_base if provider else None,
"extra_headers": provider.extra_headers if provider else None,
"extra_body": provider.extra_body if provider else None,
}
if self.config.provider == "openrouter":
return OpenRouterImageGenerationClient(**kwargs)
if self.config.provider == "aihubmix":
return AIHubMixImageGenerationClient(**kwargs)
if self.config.provider == "minimax":
return MiniMaxImageGenerationClient(**kwargs)
if self.config.provider == "gemini":
return GeminiImageGenerationClient(**kwargs)
return None
return cls(**kwargs)
def _missing_api_key_error(self) -> str:
provider = self.config.provider
if provider == "openrouter":
return "Error: OpenRouter API key is not configured. Set providers.openrouter.apiKey."
if provider == "aihubmix":
return "Error: AIHubMix API key is not configured. Set providers.aihubmix.apiKey."
if provider == "minimax":
return "Error: MiniMax API key is not configured. Set providers.minimax.apiKey."
if provider == "gemini":
return "Error: Gemini API key is not configured. Set providers.gemini.apiKey."
return f"Error: {provider} API key is not configured."
cls = get_image_gen_provider(self.config.provider)
if cls and cls.missing_key_message:
return f"Error: {cls.missing_key_message}"
return f"Error: {self.config.provider} API key is not configured."
def _resolve_reference_image(self, value: str) -> str:
raw_path = Path(value).expanduser()

View File

@ -620,6 +620,7 @@ def serve(
from nanobot.api.server import create_app
from nanobot.bus.queue import MessageBus
from nanobot.providers.image_generation import image_gen_provider_configs
from nanobot.session.manager import SessionManager
if verbose:
@ -639,12 +640,7 @@ def serve(
agent_loop = AgentLoop.from_config(
runtime_config, bus,
session_manager=session_manager,
image_generation_provider_configs={
"openrouter": runtime_config.providers.openrouter,
"aihubmix": runtime_config.providers.aihubmix,
"minimax": runtime_config.providers.minimax,
"gemini": runtime_config.providers.gemini,
},
image_generation_provider_configs=image_gen_provider_configs(runtime_config),
)
except ValueError as exc:
console.print(f"[red]Error: {exc}[/red]")
@ -724,6 +720,7 @@ def _run_gateway(
from nanobot.cron.types import CronJob
from nanobot.heartbeat.service import HeartbeatService
from nanobot.providers.factory import build_provider_snapshot, load_provider_snapshot
from nanobot.providers.image_generation import image_gen_provider_configs
from nanobot.session.manager import SessionManager
port = port if port is not None else config.gateway.port
@ -754,12 +751,7 @@ def _run_gateway(
context_window_tokens=provider_snapshot.context_window_tokens,
cron_service=cron,
session_manager=session_manager,
image_generation_provider_configs={
"openrouter": config.providers.openrouter,
"aihubmix": config.providers.aihubmix,
"minimax": config.providers.minimax,
"gemini": config.providers.gemini,
},
image_generation_provider_configs=image_gen_provider_configs(config),
provider_snapshot_loader=load_provider_snapshot,
runtime_model_publisher=lambda model, preset: publish_runtime_model_update(
bus,
@ -1126,6 +1118,7 @@ def agent(
from nanobot.bus.queue import MessageBus
from nanobot.cron.service import CronService
from nanobot.providers.image_generation import image_gen_provider_configs
config = _load_runtime_config(config, workspace)
sync_workspace_templates(config.workspace_path)
@ -1149,6 +1142,7 @@ def agent(
agent_loop = AgentLoop.from_config(
config, bus,
cron_service=cron,
image_generation_provider_configs=image_gen_provider_configs(config),
)
except ValueError as exc:
console.print(f"[red]Error: {exc}[/red]")

View File

@ -8,6 +8,7 @@ from typing import Any
from nanobot.agent.hook import AgentHook, SDKCaptureHook
from nanobot.agent.loop import AgentLoop
from nanobot.providers.image_generation import image_gen_provider_configs
@dataclass(slots=True)
@ -63,12 +64,7 @@ class Nanobot:
loop = AgentLoop.from_config(
config,
image_generation_provider_configs={
"openrouter": config.providers.openrouter,
"aihubmix": config.providers.aihubmix,
"minimax": config.providers.minimax,
"gemini": config.providers.gemini,
},
image_generation_provider_configs=image_gen_provider_configs(config),
)
return cls(loop)

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import base64
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any
@ -44,15 +45,6 @@ class GeneratedImageResponse:
raw: dict[str, Any]
def _provider_base_url(provider: str, api_base: str | None, fallback: str) -> str:
if api_base:
return api_base.rstrip("/")
spec = find_by_name(provider)
if spec and spec.default_api_base:
return spec.default_api_base.rstrip("/")
return fallback
def _read_image_b64(path: str | Path) -> tuple[str, str]:
"""Return ``(mime, base64)`` for the image at ``path``."""
p = Path(path).expanduser()
@ -120,8 +112,44 @@ async def _download_image_data_url(
return f"data:{mime};base64,{encoded}"
class OpenRouterImageGenerationClient:
"""Small async client for OpenRouter Chat Completions image generation."""
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
# Maps each provider's ``provider_name`` to the client class registered for it.
_IMAGE_GEN_PROVIDERS: dict[str, type[ImageGenerationProvider]] = {}


def register_image_gen_provider(
    cls: type[ImageGenerationProvider],
) -> type[ImageGenerationProvider]:
    """Add *cls* to the registry under its ``provider_name``.

    Registering a second class under an already-used name replaces the
    earlier entry.

    Args:
        cls: Provider class exposing a non-empty ``provider_name``.

    Returns:
        *cls* itself, so the function can also be applied as a class
        decorator. Callers that ignore the return value are unaffected.

    Raises:
        ValueError: If ``provider_name`` is missing or empty.
    """
    # getattr keeps the error uniform for classes that lack the attribute
    # entirely, instead of leaking a bare AttributeError.
    name = getattr(cls, "provider_name", "")
    if not name:
        raise ValueError(f"{cls.__name__} must set provider_name")
    _IMAGE_GEN_PROVIDERS[name] = cls
    return cls


def get_image_gen_provider(name: str) -> type[ImageGenerationProvider] | None:
    """Return the class registered as *name*, or ``None`` if there is none."""
    return _IMAGE_GEN_PROVIDERS.get(name)


def image_gen_provider_configs(config: Any) -> dict[str, Any]:
    """Collect the config object of every registered provider.

    For each registered provider name, ``config.providers.<name>`` is looked
    up; names whose attribute is absent or ``None`` are left out.

    Args:
        config: Object with a ``providers`` attribute holding one attribute
            per provider name.

    Returns:
        Mapping of provider name to its config object, in registration order.
    """
    providers_cfg = config.providers
    found: dict[str, Any] = {}
    for name in _IMAGE_GEN_PROVIDERS:
        entry = getattr(providers_cfg, name, None)
        if entry is not None:
            found[name] = entry
    return found
# ---------------------------------------------------------------------------
# Base class
# ---------------------------------------------------------------------------
class ImageGenerationProvider(ABC):
"""Base class for image generation provider clients."""
provider_name: str = ""
missing_key_message: str = ""
default_timeout: float = _DEFAULT_TIMEOUT_S
def __init__(
self,
@ -130,20 +158,71 @@ class OpenRouterImageGenerationClient:
api_base: str | None = None,
extra_headers: dict[str, str] | None = None,
extra_body: dict[str, Any] | None = None,
timeout: float = _DEFAULT_TIMEOUT_S,
timeout: float | None = None,
client: httpx.AsyncClient | None = None,
) -> None:
self.api_key = api_key
self.api_base = _provider_base_url(
"openrouter",
api_base,
"https://openrouter.ai/api/v1",
)
self.api_base = self._resolve_base_url(api_base)
self.extra_headers = extra_headers or {}
self.extra_body = extra_body or {}
self.timeout = timeout
self.timeout = timeout if timeout is not None else self.default_timeout
self._client = client
def _resolve_base_url(self, api_base: str | None) -> str:
    """Pick the base URL this client should talk to.

    Resolution order:
      1. the explicit *api_base*, if truthy;
      2. ``default_api_base`` of the spec returned by
         ``find_by_name(self.provider_name)``, if any;
      3. ``self._default_base_url()``.

    Trailing slashes are stripped for the first two sources; the value from
    ``_default_base_url()`` is returned unchanged.
    """
    candidate = api_base
    if not candidate:
        spec = find_by_name(self.provider_name)
        candidate = spec.default_api_base if spec else None
    if candidate:
        return candidate.rstrip("/")
    return self._default_base_url()
def _default_base_url(self) -> str:
    """Return the hard-coded fallback base URL.

    The base implementation has no fallback and returns an empty string;
    provider subclasses override this with their own URL.
    """
    return ""
@abstractmethod
async def generate(
    self,
    *,
    prompt: str,
    model: str,
    reference_images: list[str] | None = None,
    aspect_ratio: str | None = None,
    image_size: str | None = None,
) -> GeneratedImageResponse:
    """Generate images for *prompt* with *model*; implemented by subclasses.

    All arguments are keyword-only.

    Args:
        prompt: Text prompt to send to the provider.
        model: Provider-side model identifier.
        reference_images: Optional reference images. NOTE(review): the
            expected string format (path vs. data URL) is provider-specific
            and not fixed here; confirm against each implementation.
        aspect_ratio: Optional aspect-ratio hint.
        image_size: Optional size hint.

    Returns:
        A ``GeneratedImageResponse`` describing the generated images.
    """
def _require_images(self, images: list[str], data: dict[str, Any]) -> None:
    """Raise ``ImageGenerationError`` when the provider returned no images.

    Args:
        images: Images extracted from the provider response.
        data: Decoded response payload. When it is a dict with a truthy
            ``"error"`` entry, that value is included in the message.

    Raises:
        ImageGenerationError: If *images* is empty. The message is prefixed
            with ``self.provider_name``.
    """
    if not images:
        detail = data.get("error") if isinstance(data, dict) else None
        if detail:
            message = f"{self.provider_name} returned no images: {detail}"
        else:
            message = f"{self.provider_name} returned no images for this request"
        raise ImageGenerationError(message)
async def _http_post(
    self,
    url: str,
    *,
    headers: dict[str, str],
    body: dict[str, Any],
) -> httpx.Response:
    """POST *body* as JSON to *url* and return the raw response.

    When a client was supplied at construction time (``self._client``) it is
    used as-is and left open. Otherwise a temporary ``httpx.AsyncClient``
    configured with ``self.timeout`` is created for this one request and
    closed afterwards.

    The response status is not checked here.
    """
    shared = self._client
    if shared is None:
        # No injected client: use a short-lived one bounded by self.timeout.
        async with httpx.AsyncClient(timeout=self.timeout) as ephemeral:
            return await ephemeral.post(url, headers=headers, json=body)
    return await shared.post(url, headers=headers, json=body)
class OpenRouterImageGenerationClient(ImageGenerationProvider):
"""Small async client for OpenRouter Chat Completions image generation."""
provider_name = "openrouter"
missing_key_message = (
"OpenRouter API key is not configured. Set providers.openrouter.apiKey."
)
def _default_base_url(self) -> str:
    """Return the hard-coded OpenRouter API base URL.

    This is the fallback used when no base URL is configured.
    """
    return "https://openrouter.ai/api/v1"
async def generate(
self,
*,
@ -154,9 +233,7 @@ class OpenRouterImageGenerationClient:
image_size: str | None = None,
) -> GeneratedImageResponse:
if not self.api_key:
raise ImageGenerationError(
"OpenRouter API key is not configured. Set providers.openrouter.apiKey."
)
raise ImageGenerationError(self.missing_key_message)
content: str | list[dict[str, Any]]
references = list(reference_images or [])
@ -192,12 +269,7 @@ class OpenRouterImageGenerationClient:
**self.extra_headers,
}
url = f"{self.api_base}/chat/completions"
if self._client is not None:
response = await self._client.post(url, headers=headers, json=body)
else:
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(url, headers=headers, json=body)
response = await self._http_post(url, headers=headers, body=body)
try:
response.raise_for_status()
@ -222,11 +294,7 @@ class OpenRouterImageGenerationClient:
if isinstance(url_value, str) and url_value.startswith("data:image/"):
images.append(url_value)
if not images:
provider_error = data.get("error") if isinstance(data, dict) else None
if provider_error:
raise ImageGenerationError(f"OpenRouter returned no images: {provider_error}")
raise ImageGenerationError("OpenRouter returned no images for this request")
self._require_images(images, data)
return GeneratedImageResponse(
images=images,
@ -235,29 +303,17 @@ class OpenRouterImageGenerationClient:
)
class AIHubMixImageGenerationClient:
class AIHubMixImageGenerationClient(ImageGenerationProvider):
"""Small async client for AIHubMix unified image generation."""
def __init__(
self,
*,
api_key: str | None,
api_base: str | None = None,
extra_headers: dict[str, str] | None = None,
extra_body: dict[str, Any] | None = None,
timeout: float = _AIHUBMIX_TIMEOUT_S,
client: httpx.AsyncClient | None = None,
) -> None:
self.api_key = api_key
self.api_base = _provider_base_url(
"aihubmix",
api_base,
"https://aihubmix.com/v1",
)
self.extra_headers = extra_headers or {}
self.extra_body = extra_body or {}
self.timeout = timeout
self._client = client
provider_name = "aihubmix"
missing_key_message = (
"AIHubMix API key is not configured. Set providers.aihubmix.apiKey."
)
default_timeout = _AIHUBMIX_TIMEOUT_S
def _default_base_url(self) -> str:
    """Return the hard-coded AIHubMix API base URL.

    This is the fallback used when no base URL is configured.
    """
    return "https://aihubmix.com/v1"
async def generate(
self,
@ -269,9 +325,7 @@ class AIHubMixImageGenerationClient:
image_size: str | None = None,
) -> GeneratedImageResponse:
if not self.api_key:
raise ImageGenerationError(
"AIHubMix API key is not configured. Set providers.aihubmix.apiKey."
)
raise ImageGenerationError(self.missing_key_message)
refs = list(reference_images or [])
headers = {
@ -280,16 +334,8 @@ class AIHubMixImageGenerationClient:
}
size = _aihubmix_size(aspect_ratio, image_size)
if self._client is not None:
return await self._generate_with_client(
self._client,
prompt=prompt,
model=model,
reference_images=refs,
size=size,
headers=headers,
)
async with httpx.AsyncClient(timeout=self.timeout) as client:
client = self._client or httpx.AsyncClient(timeout=self.timeout)
try:
return await self._generate_with_client(
client,
prompt=prompt,
@ -298,6 +344,9 @@ class AIHubMixImageGenerationClient:
size=size,
headers=headers,
)
finally:
if self._client is None:
await client.aclose()
async def _generate_with_client(
self,
@ -346,11 +395,7 @@ class AIHubMixImageGenerationClient:
payload = response.json()
images = await _aihubmix_images_from_payload(client, payload)
if not images:
provider_error = payload.get("error") if isinstance(payload, dict) else None
if provider_error:
raise ImageGenerationError(f"AIHubMix returned no images: {provider_error}")
raise ImageGenerationError("AIHubMix returned no images for this request")
self._require_images(images, payload)
return GeneratedImageResponse(images=images, content="", raw=payload)
@ -370,31 +415,25 @@ def _http_error_detail(response: httpx.Response) -> str:
return response.text[:500] or "<empty response body>"
class GeminiImageGenerationClient:
class GeminiImageGenerationClient(ImageGenerationProvider):
"""Async client for Gemini/Imagen image generation via the Generative Language API."""
def __init__(
self,
*,
api_key: str | None,
api_base: str | None = None,
extra_headers: dict[str, str] | None = None,
extra_body: dict[str, Any] | None = None,
timeout: float = _GEMINI_DEFAULT_TIMEOUT_S,
client: httpx.AsyncClient | None = None,
) -> None:
self.api_key = api_key
provider_name = "gemini"
missing_key_message = (
"Gemini API key is not configured. Set providers.gemini.apiKey."
)
default_timeout = _GEMINI_DEFAULT_TIMEOUT_S
def _default_base_url(self) -> str:
    """Return the native Generative Language API base URL.

    This is the native API root, not the OpenAI-compatible shim, which has
    no image endpoints.
    """
    return "https://generativelanguage.googleapis.com/v1beta"
def _resolve_base_url(self, api_base: str | None) -> str:
# The Gemini provider's registry default_api_base is the OpenAI-compat
# shim (.../v1beta/openai/), which has no image endpoints. Image
# generation needs the native Generative Language API base, so we don't
# use _provider_base_url() here.
self.api_base = (
api_base or "https://generativelanguage.googleapis.com/v1beta"
).rstrip("/")
self.extra_headers = extra_headers or {}
self.extra_body = extra_body or {}
self.timeout = timeout
self._client = client
# shim (.../v1beta/openai/), which has no image endpoints.
# Skip the registry lookup and use the native API base directly.
if api_base:
return api_base.rstrip("/")
return self._default_base_url()
async def generate(
self,
@ -406,9 +445,7 @@ class GeminiImageGenerationClient:
image_size: str | None = None,
) -> GeneratedImageResponse:
if not self.api_key:
raise ImageGenerationError(
"Gemini API key is not configured. Set providers.gemini.apiKey."
)
raise ImageGenerationError(self.missing_key_message)
if "imagen" in model.lower():
if reference_images:
logger.warning(
@ -446,12 +483,7 @@ class GeminiImageGenerationClient:
"Content-Type": "application/json",
**self.extra_headers,
}
if self._client is not None:
response = await self._client.post(url, headers=headers, json=body)
else:
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(url, headers=headers, json=body)
response = await self._http_post(url, headers=headers, body=body)
try:
response.raise_for_status()
@ -472,11 +504,7 @@ class GeminiImageGenerationClient:
if isinstance(b64, str) and b64:
images.append(f"data:{mime};base64,{b64}")
if not images:
provider_error = data.get("error") if isinstance(data, dict) else None
if provider_error:
raise ImageGenerationError(f"Gemini Imagen returned no images: {provider_error}")
raise ImageGenerationError("Gemini Imagen returned no images for this request")
self._require_images(images, data)
return GeneratedImageResponse(images=images, content="", raw=data)
@ -504,12 +532,7 @@ class GeminiImageGenerationClient:
"Content-Type": "application/json",
**self.extra_headers,
}
if self._client is not None:
response = await self._client.post(url, headers=headers, json=body)
else:
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(url, headers=headers, json=body)
response = await self._http_post(url, headers=headers, body=body)
try:
response.raise_for_status()
@ -539,11 +562,7 @@ class GeminiImageGenerationClient:
if b64:
images.append(f"data:{mime};base64,{b64}")
if not images:
provider_error = data.get("error") if isinstance(data, dict) else None
if provider_error:
raise ImageGenerationError(f"Gemini returned no images: {provider_error}")
raise ImageGenerationError("Gemini returned no images for this request")
self._require_images(images, data)
return GeneratedImageResponse(
images=images,
@ -620,29 +639,17 @@ _MINIMAX_ASPECT_RATIO_SIZES = {
}
class MiniMaxImageGenerationClient:
class MiniMaxImageGenerationClient(ImageGenerationProvider):
"""Async client for MiniMax image generation API."""
def __init__(
self,
*,
api_key: str | None,
api_base: str | None = None,
extra_headers: dict[str, str] | None = None,
extra_body: dict[str, Any] | None = None,
timeout: float = _MINIMAX_TIMEOUT_S,
client: httpx.AsyncClient | None = None,
) -> None:
self.api_key = api_key
self.api_base = _provider_base_url(
"minimax",
api_base,
"https://api.minimaxi.com/v1",
)
self.extra_headers = extra_headers or {}
self.extra_body = extra_body or {}
self.timeout = timeout
self._client = client
provider_name = "minimax"
missing_key_message = (
"MiniMax API key is not configured. Set providers.minimax.apiKey."
)
default_timeout = _MINIMAX_TIMEOUT_S
def _default_base_url(self) -> str:
    """Return the hard-coded MiniMax API base URL.

    This is the fallback used when no base URL is configured.
    """
    return "https://api.minimaxi.com/v1"
def _resolve_aspect_ratio(self, aspect_ratio: str | None) -> str:
if aspect_ratio and aspect_ratio in _MINIMAX_ASPECT_RATIO_SIZES:
@ -659,9 +666,7 @@ class MiniMaxImageGenerationClient:
image_size: str | None = None,
) -> GeneratedImageResponse:
if not self.api_key:
raise ImageGenerationError(
"MiniMax API key is not configured. Set providers.minimax.apiKey."
)
raise ImageGenerationError(self.missing_key_message)
headers = {
"Authorization": f"Bearer {self.api_key}",
@ -687,10 +692,12 @@ class MiniMaxImageGenerationClient:
body.update(self.extra_body)
if self._client is not None:
return await self._generate_with_client(self._client, body, headers)
async with httpx.AsyncClient(timeout=self.timeout) as client:
client = self._client or httpx.AsyncClient(timeout=self.timeout)
try:
return await self._generate_with_client(client, body, headers)
finally:
if self._client is None:
await client.aclose()
async def _generate_with_client(
self,
@ -715,11 +722,7 @@ class MiniMaxImageGenerationClient:
payload = response.json()
images = _minimax_images_from_payload(payload)
if not images:
provider_error = payload.get("error") if isinstance(payload, dict) else None
if provider_error:
raise ImageGenerationError(f"MiniMax returned no images: {provider_error}")
raise ImageGenerationError("MiniMax returned no images for this request")
self._require_images(images, payload)
return GeneratedImageResponse(images=images, content="", raw=payload)
@ -737,3 +740,13 @@ def _minimax_images_from_payload(payload: dict[str, Any]) -> list[str]:
if isinstance(b64, str) and b64:
images.append(_b64_png_data_url(b64))
return images
# ---------------------------------------------------------------------------
# Provider registration
# ---------------------------------------------------------------------------
# Import-time side effect: add the built-in providers to the registry so that
# get_image_gen_provider() and image_gen_provider_configs() can find them.
# These calls must come after the class definitions they reference.
register_image_gen_provider(OpenRouterImageGenerationClient)
register_image_gen_provider(AIHubMixImageGenerationClient)
register_image_gen_provider(GeminiImageGenerationClient)
register_image_gen_provider(MiniMaxImageGenerationClient)

View File

@ -35,8 +35,8 @@ async def test_generated_image_media_is_attached_to_final_assistant_message(
) -> None:
set_config_path(tmp_path / "config.json")
monkeypatch.setattr(
"nanobot.agent.tools.image_generation.OpenRouterImageGenerationClient",
FakeImageClient,
"nanobot.agent.tools.image_generation.get_image_gen_provider",
lambda name: FakeImageClient if name == "openrouter" else None,
)
provider = MagicMock()
provider.get_default_model.return_value = "test-model"

View File

@ -44,8 +44,8 @@ async def test_generate_image_tool_stores_artifact_and_source_images(
set_config_path(tmp_path / "config.json")
FakeImageClient.instances = []
monkeypatch.setattr(
"nanobot.agent.tools.image_generation.OpenRouterImageGenerationClient",
FakeImageClient,
"nanobot.agent.tools.image_generation.get_image_gen_provider",
lambda name: FakeImageClient if name == "openrouter" else None,
)
ref = tmp_path / "ref.png"
ref.write_bytes(PNG_BYTES)
@ -98,8 +98,8 @@ async def test_generate_image_tool_selects_aihubmix_provider(
set_config_path(tmp_path / "config.json")
FakeImageClient.instances = []
monkeypatch.setattr(
"nanobot.agent.tools.image_generation.AIHubMixImageGenerationClient",
FakeImageClient,
"nanobot.agent.tools.image_generation.get_image_gen_provider",
lambda name: FakeImageClient if name == "aihubmix" else None,
)
tool = ImageGenerationTool(
workspace=tmp_path,