diff --git a/nanobot/agent/tools/image_generation.py b/nanobot/agent/tools/image_generation.py index 3dec8eb92..f2f599ded 100644 --- a/nanobot/agent/tools/image_generation.py +++ b/nanobot/agent/tools/image_generation.py @@ -17,11 +17,9 @@ from nanobot.agent.tools.schema import ( from nanobot.config.paths import get_media_dir from nanobot.config.schema import Base from nanobot.providers.image_generation import ( - AIHubMixImageGenerationClient, - GeminiImageGenerationClient, ImageGenerationError, - MiniMaxImageGenerationClient, - OpenRouterImageGenerationClient, + ImageGenerationProvider, + get_image_gen_provider, ) from nanobot.utils.artifacts import ( ArtifactError, @@ -119,37 +117,24 @@ class ImageGenerationTool(Tool): def _provider_config(self) -> ProviderConfig | None: return self.provider_configs.get(self.config.provider) - def _provider_client( - self, - ) -> OpenRouterImageGenerationClient | AIHubMixImageGenerationClient | MiniMaxImageGenerationClient | GeminiImageGenerationClient | None: + def _provider_client(self) -> ImageGenerationProvider | None: provider = self._provider_config() + cls = get_image_gen_provider(self.config.provider) + if cls is None: + return None kwargs = { "api_key": provider.api_key if provider else None, "api_base": provider.api_base if provider else None, "extra_headers": provider.extra_headers if provider else None, "extra_body": provider.extra_body if provider else None, } - if self.config.provider == "openrouter": - return OpenRouterImageGenerationClient(**kwargs) - if self.config.provider == "aihubmix": - return AIHubMixImageGenerationClient(**kwargs) - if self.config.provider == "minimax": - return MiniMaxImageGenerationClient(**kwargs) - if self.config.provider == "gemini": - return GeminiImageGenerationClient(**kwargs) - return None + return cls(**kwargs) def _missing_api_key_error(self) -> str: - provider = self.config.provider - if provider == "openrouter": - return "Error: OpenRouter API key is not configured. Set providers.openrouter.apiKey." - if provider == "aihubmix": - return "Error: AIHubMix API key is not configured. Set providers.aihubmix.apiKey." - if provider == "minimax": - return "Error: MiniMax API key is not configured. Set providers.minimax.apiKey." - if provider == "gemini": - return "Error: Gemini API key is not configured. Set providers.gemini.apiKey." - return f"Error: {provider} API key is not configured." + cls = get_image_gen_provider(self.config.provider) + if cls and cls.missing_key_message: + return f"Error: {cls.missing_key_message}" + return f"Error: {self.config.provider} API key is not configured." def _resolve_reference_image(self, value: str) -> str: raw_path = Path(value).expanduser() diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index cedc03bd0..f7bf043a4 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -620,6 +620,7 @@ def serve( from nanobot.api.server import create_app from nanobot.bus.queue import MessageBus + from nanobot.providers.image_generation import image_gen_provider_configs from nanobot.session.manager import SessionManager if verbose: @@ -639,12 +640,7 @@ def serve( agent_loop = AgentLoop.from_config( runtime_config, bus, session_manager=session_manager, - image_generation_provider_configs={ - "openrouter": runtime_config.providers.openrouter, - "aihubmix": runtime_config.providers.aihubmix, - "minimax": runtime_config.providers.minimax, - "gemini": runtime_config.providers.gemini, - }, + image_generation_provider_configs=image_gen_provider_configs(runtime_config), ) except ValueError as exc: console.print(f"[red]Error: {exc}[/red]") @@ -724,6 +720,7 @@ def _run_gateway( from nanobot.cron.types import CronJob from nanobot.heartbeat.service import HeartbeatService from nanobot.providers.factory import build_provider_snapshot, load_provider_snapshot + from nanobot.providers.image_generation import image_gen_provider_configs from nanobot.session.manager import SessionManager port = port if port is not None else config.gateway.port @@ -754,12 +751,7 @@ def _run_gateway( context_window_tokens=provider_snapshot.context_window_tokens, cron_service=cron, session_manager=session_manager, - image_generation_provider_configs={ - "openrouter": config.providers.openrouter, - "aihubmix": config.providers.aihubmix, - "minimax": config.providers.minimax, - "gemini": config.providers.gemini, - }, + image_generation_provider_configs=image_gen_provider_configs(config), provider_snapshot_loader=load_provider_snapshot, runtime_model_publisher=lambda model, preset: publish_runtime_model_update( bus, @@ -1126,6 +1118,7 @@ def agent( from nanobot.bus.queue import MessageBus from nanobot.cron.service import CronService + from nanobot.providers.image_generation import image_gen_provider_configs config = _load_runtime_config(config, workspace) sync_workspace_templates(config.workspace_path) @@ -1149,6 +1142,7 @@ def agent( agent_loop = AgentLoop.from_config( config, bus, cron_service=cron, + image_generation_provider_configs=image_gen_provider_configs(config), ) except ValueError as exc: console.print(f"[red]Error: {exc}[/red]") diff --git a/nanobot/nanobot.py b/nanobot/nanobot.py index 527f81b16..95185ba47 100644 --- a/nanobot/nanobot.py +++ b/nanobot/nanobot.py @@ -8,6 +8,7 @@ from typing import Any from nanobot.agent.hook import AgentHook, SDKCaptureHook from nanobot.agent.loop import AgentLoop +from nanobot.providers.image_generation import image_gen_provider_configs @dataclass(slots=True) @@ -63,12 +64,7 @@ class Nanobot: loop = AgentLoop.from_config( config, - image_generation_provider_configs={ - "openrouter": config.providers.openrouter, - "aihubmix": config.providers.aihubmix, - "minimax": config.providers.minimax, - "gemini": config.providers.gemini, - }, + image_generation_provider_configs=image_gen_provider_configs(config), ) return cls(loop) diff --git a/nanobot/providers/image_generation.py b/nanobot/providers/image_generation.py index 1b0c5189d..070623798 100644 --- a/nanobot/providers/image_generation.py +++ b/nanobot/providers/image_generation.py @@ -3,6 +3,7 @@ from __future__ import annotations import base64 +from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path from typing import Any @@ -44,15 +45,6 @@ class GeneratedImageResponse: raw: dict[str, Any] -def _provider_base_url(provider: str, api_base: str | None, fallback: str) -> str: - if api_base: - return api_base.rstrip("/") - spec = find_by_name(provider) - if spec and spec.default_api_base: - return spec.default_api_base.rstrip("/") - return fallback - - def _read_image_b64(path: str | Path) -> tuple[str, str]: """Return ``(mime, base64)`` for the image at ``path``.""" p = Path(path).expanduser() @@ -120,8 +112,44 @@ async def _download_image_data_url( return f"data:{mime};base64,{encoded}" -class OpenRouterImageGenerationClient: - """Small async client for OpenRouter Chat Completions image generation.""" +# --------------------------------------------------------------------------- +# Registry +# --------------------------------------------------------------------------- + +_IMAGE_GEN_PROVIDERS: dict[str, type[ImageGenerationProvider]] = {} + + +def register_image_gen_provider(cls: type[ImageGenerationProvider]) -> None: + name = cls.provider_name + if not name: + raise ValueError(f"{cls.__name__} must set provider_name") + _IMAGE_GEN_PROVIDERS[name] = cls + + +def get_image_gen_provider(name: str) -> type[ImageGenerationProvider] | None: + return _IMAGE_GEN_PROVIDERS.get(name) + + +def image_gen_provider_configs(config: Any) -> dict[str, Any]: + providers_cfg = config.providers + return { + name: pc + for name in _IMAGE_GEN_PROVIDERS + if (pc := getattr(providers_cfg, name, None)) is not None + } + + +# --------------------------------------------------------------------------- +# Base class +# --------------------------------------------------------------------------- + + +class ImageGenerationProvider(ABC): + """Base class for image generation provider clients.""" + + provider_name: str = "" + missing_key_message: str = "" + default_timeout: float = _DEFAULT_TIMEOUT_S def __init__( self, @@ -130,20 +158,71 @@ class OpenRouterImageGenerationClient: api_base: str | None = None, extra_headers: dict[str, str] | None = None, extra_body: dict[str, Any] | None = None, - timeout: float = _DEFAULT_TIMEOUT_S, + timeout: float | None = None, client: httpx.AsyncClient | None = None, ) -> None: self.api_key = api_key - self.api_base = _provider_base_url( - "openrouter", - api_base, - "https://openrouter.ai/api/v1", - ) + self.api_base = self._resolve_base_url(api_base) self.extra_headers = extra_headers or {} self.extra_body = extra_body or {} - self.timeout = timeout + self.timeout = timeout if timeout is not None else self.default_timeout self._client = client + def _resolve_base_url(self, api_base: str | None) -> str: + if api_base: + return api_base.rstrip("/") + spec = find_by_name(self.provider_name) + if spec and spec.default_api_base: + return spec.default_api_base.rstrip("/") + return self._default_base_url() + + def _default_base_url(self) -> str: + return "" + + @abstractmethod + async def generate( + self, + *, + prompt: str, + model: str, + reference_images: list[str] | None = None, + aspect_ratio: str | None = None, + image_size: str | None = None, + ) -> GeneratedImageResponse: ... + + def _require_images(self, images: list[str], data: dict[str, Any]) -> None: + if images: + return + provider_error = data.get("error") if isinstance(data, dict) else None + label = self.provider_name + if provider_error: + raise ImageGenerationError(f"{label} returned no images: {provider_error}") + raise ImageGenerationError(f"{label} returned no images for this request") + + async def _http_post( + self, + url: str, + *, + headers: dict[str, str], + body: dict[str, Any], + ) -> httpx.Response: + if self._client is not None: + return await self._client.post(url, headers=headers, json=body) + async with httpx.AsyncClient(timeout=self.timeout) as c: + return await c.post(url, headers=headers, json=body) + + +class OpenRouterImageGenerationClient(ImageGenerationProvider): + """Small async client for OpenRouter Chat Completions image generation.""" + + provider_name = "openrouter" + missing_key_message = ( + "OpenRouter API key is not configured. Set providers.openrouter.apiKey." + ) + + def _default_base_url(self) -> str: + return "https://openrouter.ai/api/v1" + async def generate( self, *, @@ -154,9 +233,7 @@ class OpenRouterImageGenerationClient: image_size: str | None = None, ) -> GeneratedImageResponse: if not self.api_key: - raise ImageGenerationError( - "OpenRouter API key is not configured. Set providers.openrouter.apiKey." - ) + raise ImageGenerationError(self.missing_key_message) content: str | list[dict[str, Any]] references = list(reference_images or []) @@ -192,12 +269,7 @@ class OpenRouterImageGenerationClient: **self.extra_headers, } url = f"{self.api_base}/chat/completions" - - if self._client is not None: - response = await self._client.post(url, headers=headers, json=body) - else: - async with httpx.AsyncClient(timeout=self.timeout) as client: - response = await client.post(url, headers=headers, json=body) + response = await self._http_post(url, headers=headers, body=body) try: response.raise_for_status() @@ -222,11 +294,7 @@ class OpenRouterImageGenerationClient: if isinstance(url_value, str) and url_value.startswith("data:image/"): images.append(url_value) - if not images: - provider_error = data.get("error") if isinstance(data, dict) else None - if provider_error: - raise ImageGenerationError(f"OpenRouter returned no images: {provider_error}") - raise ImageGenerationError("OpenRouter returned no images for this request") + self._require_images(images, data) return GeneratedImageResponse( images=images, @@ -235,29 +303,17 @@ class OpenRouterImageGenerationClient: ) -class AIHubMixImageGenerationClient: +class AIHubMixImageGenerationClient(ImageGenerationProvider): """Small async client for AIHubMix unified image generation.""" - def __init__( - self, - *, - api_key: str | None, - api_base: str | None = None, - extra_headers: dict[str, str] | None = None, - extra_body: dict[str, Any] | None = None, - timeout: float = _AIHUBMIX_TIMEOUT_S, - client: httpx.AsyncClient | None = None, - ) -> None: - self.api_key = api_key - self.api_base = _provider_base_url( - "aihubmix", - api_base, - "https://aihubmix.com/v1", - ) - self.extra_headers = extra_headers or {} - self.extra_body = extra_body or {} - self.timeout = timeout - self._client = client + provider_name = "aihubmix" + missing_key_message = ( + "AIHubMix API key is not configured. Set providers.aihubmix.apiKey." + ) + default_timeout = _AIHUBMIX_TIMEOUT_S + + def _default_base_url(self) -> str: + return "https://aihubmix.com/v1" async def generate( self, @@ -269,9 +325,7 @@ class AIHubMixImageGenerationClient: image_size: str | None = None, ) -> GeneratedImageResponse: if not self.api_key: - raise ImageGenerationError( - "AIHubMix API key is not configured. Set providers.aihubmix.apiKey." - ) + raise ImageGenerationError(self.missing_key_message) refs = list(reference_images or []) headers = { @@ -280,16 +334,8 @@ class AIHubMixImageGenerationClient: } size = _aihubmix_size(aspect_ratio, image_size) - if self._client is not None: - return await self._generate_with_client( - self._client, - prompt=prompt, - model=model, - reference_images=refs, - size=size, - headers=headers, - ) - async with httpx.AsyncClient(timeout=self.timeout) as client: + client = self._client or httpx.AsyncClient(timeout=self.timeout) + try: return await self._generate_with_client( client, prompt=prompt, @@ -298,6 +344,9 @@ class AIHubMixImageGenerationClient: size=size, headers=headers, ) + finally: + if self._client is None: + await client.aclose() async def _generate_with_client( self, @@ -346,11 +395,7 @@ class AIHubMixImageGenerationClient: payload = response.json() images = await _aihubmix_images_from_payload(client, payload) - if not images: - provider_error = payload.get("error") if isinstance(payload, dict) else None - if provider_error: - raise ImageGenerationError(f"AIHubMix returned no images: {provider_error}") - raise ImageGenerationError("AIHubMix returned no images for this request") + self._require_images(images, payload) return GeneratedImageResponse(images=images, content="", raw=payload) @@ -370,31 +415,25 @@ def _http_error_detail(response: httpx.Response) -> str: return response.text[:500] or "" -class GeminiImageGenerationClient: +class GeminiImageGenerationClient(ImageGenerationProvider): """Async client for Gemini/Imagen image generation via the Generative Language API.""" - def __init__( - self, - *, - api_key: str | None, - api_base: str | None = None, - extra_headers: dict[str, str] | None = None, - extra_body: dict[str, Any] | None = None, - timeout: float = _GEMINI_DEFAULT_TIMEOUT_S, - client: httpx.AsyncClient | None = None, - ) -> None: - self.api_key = api_key + provider_name = "gemini" + missing_key_message = ( + "Gemini API key is not configured. Set providers.gemini.apiKey." + ) + default_timeout = _GEMINI_DEFAULT_TIMEOUT_S + + def _default_base_url(self) -> str: + return "https://generativelanguage.googleapis.com/v1beta" + + def _resolve_base_url(self, api_base: str | None) -> str: # The Gemini provider's registry default_api_base is the OpenAI-compat - # shim (.../v1beta/openai/), which has no image endpoints. Image - # generation needs the native Generative Language API base, so we don't - # use _provider_base_url() here. - self.api_base = ( - api_base or "https://generativelanguage.googleapis.com/v1beta" - ).rstrip("/") - self.extra_headers = extra_headers or {} - self.extra_body = extra_body or {} - self.timeout = timeout - self._client = client + # shim (.../v1beta/openai/), which has no image endpoints. + # Skip the registry lookup and use the native API base directly. + if api_base: + return api_base.rstrip("/") + return self._default_base_url() async def generate( self, @@ -406,9 +445,7 @@ class GeminiImageGenerationClient: image_size: str | None = None, ) -> GeneratedImageResponse: if not self.api_key: - raise ImageGenerationError( - "Gemini API key is not configured. Set providers.gemini.apiKey." - ) + raise ImageGenerationError(self.missing_key_message) if "imagen" in model.lower(): if reference_images: logger.warning( @@ -446,12 +483,7 @@ class GeminiImageGenerationClient: "Content-Type": "application/json", **self.extra_headers, } - - if self._client is not None: - response = await self._client.post(url, headers=headers, json=body) - else: - async with httpx.AsyncClient(timeout=self.timeout) as client: - response = await client.post(url, headers=headers, json=body) + response = await self._http_post(url, headers=headers, body=body) try: response.raise_for_status() @@ -472,11 +504,7 @@ class GeminiImageGenerationClient: if isinstance(b64, str) and b64: images.append(f"data:{mime};base64,{b64}") - if not images: - provider_error = data.get("error") if isinstance(data, dict) else None - if provider_error: - raise ImageGenerationError(f"Gemini Imagen returned no images: {provider_error}") - raise ImageGenerationError("Gemini Imagen returned no images for this request") + self._require_images(images, data) return GeneratedImageResponse(images=images, content="", raw=data) @@ -504,12 +532,7 @@ class GeminiImageGenerationClient: "Content-Type": "application/json", **self.extra_headers, } - - if self._client is not None: - response = await self._client.post(url, headers=headers, json=body) - else: - async with httpx.AsyncClient(timeout=self.timeout) as client: - response = await client.post(url, headers=headers, json=body) + response = await self._http_post(url, headers=headers, body=body) try: response.raise_for_status() @@ -539,11 +562,7 @@ class GeminiImageGenerationClient: if b64: images.append(f"data:{mime};base64,{b64}") - if not images: - provider_error = data.get("error") if isinstance(data, dict) else None - if provider_error: - raise ImageGenerationError(f"Gemini returned no images: {provider_error}") - raise ImageGenerationError("Gemini returned no images for this request") + self._require_images(images, data) return GeneratedImageResponse( images=images, @@ -620,29 +639,17 @@ _MINIMAX_ASPECT_RATIO_SIZES = { } -class MiniMaxImageGenerationClient: +class MiniMaxImageGenerationClient(ImageGenerationProvider): """Async client for MiniMax image generation API.""" - def __init__( - self, - *, - api_key: str | None, - api_base: str | None = None, - extra_headers: dict[str, str] | None = None, - extra_body: dict[str, Any] | None = None, - timeout: float = _MINIMAX_TIMEOUT_S, - client: httpx.AsyncClient | None = None, - ) -> None: - self.api_key = api_key - self.api_base = _provider_base_url( - "minimax", - api_base, - "https://api.minimaxi.com/v1", - ) - self.extra_headers = extra_headers or {} - self.extra_body = extra_body or {} - self.timeout = timeout - self._client = client + provider_name = "minimax" + missing_key_message = ( + "MiniMax API key is not configured. Set providers.minimax.apiKey." + ) + default_timeout = _MINIMAX_TIMEOUT_S + + def _default_base_url(self) -> str: + return "https://api.minimaxi.com/v1" def _resolve_aspect_ratio(self, aspect_ratio: str | None) -> str: if aspect_ratio and aspect_ratio in _MINIMAX_ASPECT_RATIO_SIZES: @@ -659,9 +666,7 @@ class MiniMaxImageGenerationClient: image_size: str | None = None, ) -> GeneratedImageResponse: if not self.api_key: - raise ImageGenerationError( - "MiniMax API key is not configured. Set providers.minimax.apiKey." - ) + raise ImageGenerationError(self.missing_key_message) headers = { "Authorization": f"Bearer {self.api_key}", @@ -687,10 +692,12 @@ class MiniMaxImageGenerationClient: body.update(self.extra_body) - if self._client is not None: - return await self._generate_with_client(self._client, body, headers) - async with httpx.AsyncClient(timeout=self.timeout) as client: + client = self._client or httpx.AsyncClient(timeout=self.timeout) + try: return await self._generate_with_client(client, body, headers) + finally: + if self._client is None: + await client.aclose() async def _generate_with_client( self, @@ -715,11 +722,7 @@ class MiniMaxImageGenerationClient: payload = response.json() images = _minimax_images_from_payload(payload) - if not images: - provider_error = payload.get("error") if isinstance(payload, dict) else None - if provider_error: - raise ImageGenerationError(f"MiniMax returned no images: {provider_error}") - raise ImageGenerationError("MiniMax returned no images for this request") + self._require_images(images, payload) return GeneratedImageResponse(images=images, content="", raw=payload) @@ -737,3 +740,13 @@ def _minimax_images_from_payload(payload: dict[str, Any]) -> list[str]: if isinstance(b64, str) and b64: images.append(_b64_png_data_url(b64)) return images + + +# --------------------------------------------------------------------------- +# Provider registration +# --------------------------------------------------------------------------- + +register_image_gen_provider(OpenRouterImageGenerationClient) +register_image_gen_provider(AIHubMixImageGenerationClient) +register_image_gen_provider(GeminiImageGenerationClient) +register_image_gen_provider(MiniMaxImageGenerationClient) diff --git a/tests/agent/test_loop_image_generation_media.py b/tests/agent/test_loop_image_generation_media.py index 6c10ecb1c..73904be93 100644 --- a/tests/agent/test_loop_image_generation_media.py +++ b/tests/agent/test_loop_image_generation_media.py @@ -35,8 +35,8 @@ async def test_generated_image_media_is_attached_to_final_assistant_message( ) -> None: set_config_path(tmp_path / "config.json") monkeypatch.setattr( - "nanobot.agent.tools.image_generation.OpenRouterImageGenerationClient", - FakeImageClient, + "nanobot.agent.tools.image_generation.get_image_gen_provider", + lambda name: FakeImageClient if name == "openrouter" else None, ) provider = MagicMock() provider.get_default_model.return_value = "test-model" diff --git a/tests/tools/test_image_generation_tool.py b/tests/tools/test_image_generation_tool.py index 2afdbdff2..92ed8a339 100644 --- a/tests/tools/test_image_generation_tool.py +++ b/tests/tools/test_image_generation_tool.py @@ -44,8 +44,8 @@ async def test_generate_image_tool_stores_artifact_and_source_images( set_config_path(tmp_path / "config.json") FakeImageClient.instances = [] monkeypatch.setattr( - "nanobot.agent.tools.image_generation.OpenRouterImageGenerationClient", - FakeImageClient, + "nanobot.agent.tools.image_generation.get_image_gen_provider", + lambda name: FakeImageClient if name == "openrouter" else None, ) ref = tmp_path / "ref.png" ref.write_bytes(PNG_BYTES) @@ -98,8 +98,8 @@ async def test_generate_image_tool_selects_aihubmix_provider( set_config_path(tmp_path / "config.json") FakeImageClient.instances = [] monkeypatch.setattr( - "nanobot.agent.tools.image_generation.AIHubMixImageGenerationClient", - FakeImageClient, + "nanobot.agent.tools.image_generation.get_image_gen_provider", + lambda name: FakeImageClient if name == "aihubmix" else None, ) tool = ImageGenerationTool( workspace=tmp_path,