mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
refactor(image-generation): introduce provider registry to eliminate manual wiring
Adds an ImageGenerationProvider ABC with a shared __init__, _http_post(), and _require_images(). Introduces the _IMAGE_GEN_PROVIDERS registry with the register_image_gen_provider(), get_image_gen_provider(), and image_gen_provider_configs() helpers. The four existing providers (OpenRouter, AIHubMix, Gemini, MiniMax) now inherit from the base class and self-register. Adding a new provider only requires writing one class plus one registration line. Eliminates the if/else chains in the tool dispatch and the hardcoded provider config dicts in commands.py (3 sites) and nanobot.py (1 site). Fixes the agent CLI command, which was missing image_generation_provider_configs entirely. Also simplifies the test monkeypatch targets so that they patch the registry lookup.
This commit is contained in:
parent
7367741ac1
commit
c588d56a77
@ -17,11 +17,9 @@ from nanobot.agent.tools.schema import (
|
|||||||
from nanobot.config.paths import get_media_dir
|
from nanobot.config.paths import get_media_dir
|
||||||
from nanobot.config.schema import Base
|
from nanobot.config.schema import Base
|
||||||
from nanobot.providers.image_generation import (
|
from nanobot.providers.image_generation import (
|
||||||
AIHubMixImageGenerationClient,
|
|
||||||
GeminiImageGenerationClient,
|
|
||||||
ImageGenerationError,
|
ImageGenerationError,
|
||||||
MiniMaxImageGenerationClient,
|
ImageGenerationProvider,
|
||||||
OpenRouterImageGenerationClient,
|
get_image_gen_provider,
|
||||||
)
|
)
|
||||||
from nanobot.utils.artifacts import (
|
from nanobot.utils.artifacts import (
|
||||||
ArtifactError,
|
ArtifactError,
|
||||||
@ -119,37 +117,24 @@ class ImageGenerationTool(Tool):
|
|||||||
def _provider_config(self) -> ProviderConfig | None:
|
def _provider_config(self) -> ProviderConfig | None:
|
||||||
return self.provider_configs.get(self.config.provider)
|
return self.provider_configs.get(self.config.provider)
|
||||||
|
|
||||||
def _provider_client(
|
def _provider_client(self) -> ImageGenerationProvider | None:
|
||||||
self,
|
|
||||||
) -> OpenRouterImageGenerationClient | AIHubMixImageGenerationClient | MiniMaxImageGenerationClient | GeminiImageGenerationClient | None:
|
|
||||||
provider = self._provider_config()
|
provider = self._provider_config()
|
||||||
|
cls = get_image_gen_provider(self.config.provider)
|
||||||
|
if cls is None:
|
||||||
|
return None
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"api_key": provider.api_key if provider else None,
|
"api_key": provider.api_key if provider else None,
|
||||||
"api_base": provider.api_base if provider else None,
|
"api_base": provider.api_base if provider else None,
|
||||||
"extra_headers": provider.extra_headers if provider else None,
|
"extra_headers": provider.extra_headers if provider else None,
|
||||||
"extra_body": provider.extra_body if provider else None,
|
"extra_body": provider.extra_body if provider else None,
|
||||||
}
|
}
|
||||||
if self.config.provider == "openrouter":
|
return cls(**kwargs)
|
||||||
return OpenRouterImageGenerationClient(**kwargs)
|
|
||||||
if self.config.provider == "aihubmix":
|
|
||||||
return AIHubMixImageGenerationClient(**kwargs)
|
|
||||||
if self.config.provider == "minimax":
|
|
||||||
return MiniMaxImageGenerationClient(**kwargs)
|
|
||||||
if self.config.provider == "gemini":
|
|
||||||
return GeminiImageGenerationClient(**kwargs)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _missing_api_key_error(self) -> str:
|
def _missing_api_key_error(self) -> str:
|
||||||
provider = self.config.provider
|
cls = get_image_gen_provider(self.config.provider)
|
||||||
if provider == "openrouter":
|
if cls and cls.missing_key_message:
|
||||||
return "Error: OpenRouter API key is not configured. Set providers.openrouter.apiKey."
|
return f"Error: {cls.missing_key_message}"
|
||||||
if provider == "aihubmix":
|
return f"Error: {self.config.provider} API key is not configured."
|
||||||
return "Error: AIHubMix API key is not configured. Set providers.aihubmix.apiKey."
|
|
||||||
if provider == "minimax":
|
|
||||||
return "Error: MiniMax API key is not configured. Set providers.minimax.apiKey."
|
|
||||||
if provider == "gemini":
|
|
||||||
return "Error: Gemini API key is not configured. Set providers.gemini.apiKey."
|
|
||||||
return f"Error: {provider} API key is not configured."
|
|
||||||
|
|
||||||
def _resolve_reference_image(self, value: str) -> str:
|
def _resolve_reference_image(self, value: str) -> str:
|
||||||
raw_path = Path(value).expanduser()
|
raw_path = Path(value).expanduser()
|
||||||
|
|||||||
@ -620,6 +620,7 @@ def serve(
|
|||||||
|
|
||||||
from nanobot.api.server import create_app
|
from nanobot.api.server import create_app
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.image_generation import image_gen_provider_configs
|
||||||
from nanobot.session.manager import SessionManager
|
from nanobot.session.manager import SessionManager
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
@ -639,12 +640,7 @@ def serve(
|
|||||||
agent_loop = AgentLoop.from_config(
|
agent_loop = AgentLoop.from_config(
|
||||||
runtime_config, bus,
|
runtime_config, bus,
|
||||||
session_manager=session_manager,
|
session_manager=session_manager,
|
||||||
image_generation_provider_configs={
|
image_generation_provider_configs=image_gen_provider_configs(runtime_config),
|
||||||
"openrouter": runtime_config.providers.openrouter,
|
|
||||||
"aihubmix": runtime_config.providers.aihubmix,
|
|
||||||
"minimax": runtime_config.providers.minimax,
|
|
||||||
"gemini": runtime_config.providers.gemini,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
console.print(f"[red]Error: {exc}[/red]")
|
console.print(f"[red]Error: {exc}[/red]")
|
||||||
@ -724,6 +720,7 @@ def _run_gateway(
|
|||||||
from nanobot.cron.types import CronJob
|
from nanobot.cron.types import CronJob
|
||||||
from nanobot.heartbeat.service import HeartbeatService
|
from nanobot.heartbeat.service import HeartbeatService
|
||||||
from nanobot.providers.factory import build_provider_snapshot, load_provider_snapshot
|
from nanobot.providers.factory import build_provider_snapshot, load_provider_snapshot
|
||||||
|
from nanobot.providers.image_generation import image_gen_provider_configs
|
||||||
from nanobot.session.manager import SessionManager
|
from nanobot.session.manager import SessionManager
|
||||||
|
|
||||||
port = port if port is not None else config.gateway.port
|
port = port if port is not None else config.gateway.port
|
||||||
@ -754,12 +751,7 @@ def _run_gateway(
|
|||||||
context_window_tokens=provider_snapshot.context_window_tokens,
|
context_window_tokens=provider_snapshot.context_window_tokens,
|
||||||
cron_service=cron,
|
cron_service=cron,
|
||||||
session_manager=session_manager,
|
session_manager=session_manager,
|
||||||
image_generation_provider_configs={
|
image_generation_provider_configs=image_gen_provider_configs(config),
|
||||||
"openrouter": config.providers.openrouter,
|
|
||||||
"aihubmix": config.providers.aihubmix,
|
|
||||||
"minimax": config.providers.minimax,
|
|
||||||
"gemini": config.providers.gemini,
|
|
||||||
},
|
|
||||||
provider_snapshot_loader=load_provider_snapshot,
|
provider_snapshot_loader=load_provider_snapshot,
|
||||||
runtime_model_publisher=lambda model, preset: publish_runtime_model_update(
|
runtime_model_publisher=lambda model, preset: publish_runtime_model_update(
|
||||||
bus,
|
bus,
|
||||||
@ -1126,6 +1118,7 @@ def agent(
|
|||||||
|
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.cron.service import CronService
|
from nanobot.cron.service import CronService
|
||||||
|
from nanobot.providers.image_generation import image_gen_provider_configs
|
||||||
|
|
||||||
config = _load_runtime_config(config, workspace)
|
config = _load_runtime_config(config, workspace)
|
||||||
sync_workspace_templates(config.workspace_path)
|
sync_workspace_templates(config.workspace_path)
|
||||||
@ -1149,6 +1142,7 @@ def agent(
|
|||||||
agent_loop = AgentLoop.from_config(
|
agent_loop = AgentLoop.from_config(
|
||||||
config, bus,
|
config, bus,
|
||||||
cron_service=cron,
|
cron_service=cron,
|
||||||
|
image_generation_provider_configs=image_gen_provider_configs(config),
|
||||||
)
|
)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
console.print(f"[red]Error: {exc}[/red]")
|
console.print(f"[red]Error: {exc}[/red]")
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from typing import Any
|
|||||||
|
|
||||||
from nanobot.agent.hook import AgentHook, SDKCaptureHook
|
from nanobot.agent.hook import AgentHook, SDKCaptureHook
|
||||||
from nanobot.agent.loop import AgentLoop
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.providers.image_generation import image_gen_provider_configs
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
@ -63,12 +64,7 @@ class Nanobot:
|
|||||||
|
|
||||||
loop = AgentLoop.from_config(
|
loop = AgentLoop.from_config(
|
||||||
config,
|
config,
|
||||||
image_generation_provider_configs={
|
image_generation_provider_configs=image_gen_provider_configs(config),
|
||||||
"openrouter": config.providers.openrouter,
|
|
||||||
"aihubmix": config.providers.aihubmix,
|
|
||||||
"minimax": config.providers.minimax,
|
|
||||||
"gemini": config.providers.gemini,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
return cls(loop)
|
return cls(loop)
|
||||||
|
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@ -44,15 +45,6 @@ class GeneratedImageResponse:
|
|||||||
raw: dict[str, Any]
|
raw: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
def _provider_base_url(provider: str, api_base: str | None, fallback: str) -> str:
|
|
||||||
if api_base:
|
|
||||||
return api_base.rstrip("/")
|
|
||||||
spec = find_by_name(provider)
|
|
||||||
if spec and spec.default_api_base:
|
|
||||||
return spec.default_api_base.rstrip("/")
|
|
||||||
return fallback
|
|
||||||
|
|
||||||
|
|
||||||
def _read_image_b64(path: str | Path) -> tuple[str, str]:
|
def _read_image_b64(path: str | Path) -> tuple[str, str]:
|
||||||
"""Return ``(mime, base64)`` for the image at ``path``."""
|
"""Return ``(mime, base64)`` for the image at ``path``."""
|
||||||
p = Path(path).expanduser()
|
p = Path(path).expanduser()
|
||||||
@ -120,8 +112,44 @@ async def _download_image_data_url(
|
|||||||
return f"data:{mime};base64,{encoded}"
|
return f"data:{mime};base64,{encoded}"
|
||||||
|
|
||||||
|
|
||||||
class OpenRouterImageGenerationClient:
|
# ---------------------------------------------------------------------------
|
||||||
"""Small async client for OpenRouter Chat Completions image generation."""
|
# Registry
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_IMAGE_GEN_PROVIDERS: dict[str, type[ImageGenerationProvider]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_image_gen_provider(cls: type[ImageGenerationProvider]) -> None:
|
||||||
|
name = cls.provider_name
|
||||||
|
if not name:
|
||||||
|
raise ValueError(f"{cls.__name__} must set provider_name")
|
||||||
|
_IMAGE_GEN_PROVIDERS[name] = cls
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_gen_provider(name: str) -> type[ImageGenerationProvider] | None:
|
||||||
|
return _IMAGE_GEN_PROVIDERS.get(name)
|
||||||
|
|
||||||
|
|
||||||
|
def image_gen_provider_configs(config: Any) -> dict[str, Any]:
|
||||||
|
providers_cfg = config.providers
|
||||||
|
return {
|
||||||
|
name: pc
|
||||||
|
for name in _IMAGE_GEN_PROVIDERS
|
||||||
|
if (pc := getattr(providers_cfg, name, None)) is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Base class
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class ImageGenerationProvider(ABC):
|
||||||
|
"""Base class for image generation provider clients."""
|
||||||
|
|
||||||
|
provider_name: str = ""
|
||||||
|
missing_key_message: str = ""
|
||||||
|
default_timeout: float = _DEFAULT_TIMEOUT_S
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -130,20 +158,71 @@ class OpenRouterImageGenerationClient:
|
|||||||
api_base: str | None = None,
|
api_base: str | None = None,
|
||||||
extra_headers: dict[str, str] | None = None,
|
extra_headers: dict[str, str] | None = None,
|
||||||
extra_body: dict[str, Any] | None = None,
|
extra_body: dict[str, Any] | None = None,
|
||||||
timeout: float = _DEFAULT_TIMEOUT_S,
|
timeout: float | None = None,
|
||||||
client: httpx.AsyncClient | None = None,
|
client: httpx.AsyncClient | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.api_base = _provider_base_url(
|
self.api_base = self._resolve_base_url(api_base)
|
||||||
"openrouter",
|
|
||||||
api_base,
|
|
||||||
"https://openrouter.ai/api/v1",
|
|
||||||
)
|
|
||||||
self.extra_headers = extra_headers or {}
|
self.extra_headers = extra_headers or {}
|
||||||
self.extra_body = extra_body or {}
|
self.extra_body = extra_body or {}
|
||||||
self.timeout = timeout
|
self.timeout = timeout if timeout is not None else self.default_timeout
|
||||||
self._client = client
|
self._client = client
|
||||||
|
|
||||||
|
def _resolve_base_url(self, api_base: str | None) -> str:
|
||||||
|
if api_base:
|
||||||
|
return api_base.rstrip("/")
|
||||||
|
spec = find_by_name(self.provider_name)
|
||||||
|
if spec and spec.default_api_base:
|
||||||
|
return spec.default_api_base.rstrip("/")
|
||||||
|
return self._default_base_url()
|
||||||
|
|
||||||
|
def _default_base_url(self) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def generate(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
prompt: str,
|
||||||
|
model: str,
|
||||||
|
reference_images: list[str] | None = None,
|
||||||
|
aspect_ratio: str | None = None,
|
||||||
|
image_size: str | None = None,
|
||||||
|
) -> GeneratedImageResponse: ...
|
||||||
|
|
||||||
|
def _require_images(self, images: list[str], data: dict[str, Any]) -> None:
|
||||||
|
if images:
|
||||||
|
return
|
||||||
|
provider_error = data.get("error") if isinstance(data, dict) else None
|
||||||
|
label = self.provider_name
|
||||||
|
if provider_error:
|
||||||
|
raise ImageGenerationError(f"{label} returned no images: {provider_error}")
|
||||||
|
raise ImageGenerationError(f"{label} returned no images for this request")
|
||||||
|
|
||||||
|
async def _http_post(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
*,
|
||||||
|
headers: dict[str, str],
|
||||||
|
body: dict[str, Any],
|
||||||
|
) -> httpx.Response:
|
||||||
|
if self._client is not None:
|
||||||
|
return await self._client.post(url, headers=headers, json=body)
|
||||||
|
async with httpx.AsyncClient(timeout=self.timeout) as c:
|
||||||
|
return await c.post(url, headers=headers, json=body)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenRouterImageGenerationClient(ImageGenerationProvider):
|
||||||
|
"""Small async client for OpenRouter Chat Completions image generation."""
|
||||||
|
|
||||||
|
provider_name = "openrouter"
|
||||||
|
missing_key_message = (
|
||||||
|
"OpenRouter API key is not configured. Set providers.openrouter.apiKey."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _default_base_url(self) -> str:
|
||||||
|
return "https://openrouter.ai/api/v1"
|
||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@ -154,9 +233,7 @@ class OpenRouterImageGenerationClient:
|
|||||||
image_size: str | None = None,
|
image_size: str | None = None,
|
||||||
) -> GeneratedImageResponse:
|
) -> GeneratedImageResponse:
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ImageGenerationError(
|
raise ImageGenerationError(self.missing_key_message)
|
||||||
"OpenRouter API key is not configured. Set providers.openrouter.apiKey."
|
|
||||||
)
|
|
||||||
|
|
||||||
content: str | list[dict[str, Any]]
|
content: str | list[dict[str, Any]]
|
||||||
references = list(reference_images or [])
|
references = list(reference_images or [])
|
||||||
@ -192,12 +269,7 @@ class OpenRouterImageGenerationClient:
|
|||||||
**self.extra_headers,
|
**self.extra_headers,
|
||||||
}
|
}
|
||||||
url = f"{self.api_base}/chat/completions"
|
url = f"{self.api_base}/chat/completions"
|
||||||
|
response = await self._http_post(url, headers=headers, body=body)
|
||||||
if self._client is not None:
|
|
||||||
response = await self._client.post(url, headers=headers, json=body)
|
|
||||||
else:
|
|
||||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
|
||||||
response = await client.post(url, headers=headers, json=body)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
@ -222,11 +294,7 @@ class OpenRouterImageGenerationClient:
|
|||||||
if isinstance(url_value, str) and url_value.startswith("data:image/"):
|
if isinstance(url_value, str) and url_value.startswith("data:image/"):
|
||||||
images.append(url_value)
|
images.append(url_value)
|
||||||
|
|
||||||
if not images:
|
self._require_images(images, data)
|
||||||
provider_error = data.get("error") if isinstance(data, dict) else None
|
|
||||||
if provider_error:
|
|
||||||
raise ImageGenerationError(f"OpenRouter returned no images: {provider_error}")
|
|
||||||
raise ImageGenerationError("OpenRouter returned no images for this request")
|
|
||||||
|
|
||||||
return GeneratedImageResponse(
|
return GeneratedImageResponse(
|
||||||
images=images,
|
images=images,
|
||||||
@ -235,29 +303,17 @@ class OpenRouterImageGenerationClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AIHubMixImageGenerationClient:
|
class AIHubMixImageGenerationClient(ImageGenerationProvider):
|
||||||
"""Small async client for AIHubMix unified image generation."""
|
"""Small async client for AIHubMix unified image generation."""
|
||||||
|
|
||||||
def __init__(
|
provider_name = "aihubmix"
|
||||||
self,
|
missing_key_message = (
|
||||||
*,
|
"AIHubMix API key is not configured. Set providers.aihubmix.apiKey."
|
||||||
api_key: str | None,
|
)
|
||||||
api_base: str | None = None,
|
default_timeout = _AIHUBMIX_TIMEOUT_S
|
||||||
extra_headers: dict[str, str] | None = None,
|
|
||||||
extra_body: dict[str, Any] | None = None,
|
def _default_base_url(self) -> str:
|
||||||
timeout: float = _AIHUBMIX_TIMEOUT_S,
|
return "https://aihubmix.com/v1"
|
||||||
client: httpx.AsyncClient | None = None,
|
|
||||||
) -> None:
|
|
||||||
self.api_key = api_key
|
|
||||||
self.api_base = _provider_base_url(
|
|
||||||
"aihubmix",
|
|
||||||
api_base,
|
|
||||||
"https://aihubmix.com/v1",
|
|
||||||
)
|
|
||||||
self.extra_headers = extra_headers or {}
|
|
||||||
self.extra_body = extra_body or {}
|
|
||||||
self.timeout = timeout
|
|
||||||
self._client = client
|
|
||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
@ -269,9 +325,7 @@ class AIHubMixImageGenerationClient:
|
|||||||
image_size: str | None = None,
|
image_size: str | None = None,
|
||||||
) -> GeneratedImageResponse:
|
) -> GeneratedImageResponse:
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ImageGenerationError(
|
raise ImageGenerationError(self.missing_key_message)
|
||||||
"AIHubMix API key is not configured. Set providers.aihubmix.apiKey."
|
|
||||||
)
|
|
||||||
|
|
||||||
refs = list(reference_images or [])
|
refs = list(reference_images or [])
|
||||||
headers = {
|
headers = {
|
||||||
@ -280,16 +334,8 @@ class AIHubMixImageGenerationClient:
|
|||||||
}
|
}
|
||||||
size = _aihubmix_size(aspect_ratio, image_size)
|
size = _aihubmix_size(aspect_ratio, image_size)
|
||||||
|
|
||||||
if self._client is not None:
|
client = self._client or httpx.AsyncClient(timeout=self.timeout)
|
||||||
return await self._generate_with_client(
|
try:
|
||||||
self._client,
|
|
||||||
prompt=prompt,
|
|
||||||
model=model,
|
|
||||||
reference_images=refs,
|
|
||||||
size=size,
|
|
||||||
headers=headers,
|
|
||||||
)
|
|
||||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
|
||||||
return await self._generate_with_client(
|
return await self._generate_with_client(
|
||||||
client,
|
client,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
@ -298,6 +344,9 @@ class AIHubMixImageGenerationClient:
|
|||||||
size=size,
|
size=size,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
|
finally:
|
||||||
|
if self._client is None:
|
||||||
|
await client.aclose()
|
||||||
|
|
||||||
async def _generate_with_client(
|
async def _generate_with_client(
|
||||||
self,
|
self,
|
||||||
@ -346,11 +395,7 @@ class AIHubMixImageGenerationClient:
|
|||||||
payload = response.json()
|
payload = response.json()
|
||||||
images = await _aihubmix_images_from_payload(client, payload)
|
images = await _aihubmix_images_from_payload(client, payload)
|
||||||
|
|
||||||
if not images:
|
self._require_images(images, payload)
|
||||||
provider_error = payload.get("error") if isinstance(payload, dict) else None
|
|
||||||
if provider_error:
|
|
||||||
raise ImageGenerationError(f"AIHubMix returned no images: {provider_error}")
|
|
||||||
raise ImageGenerationError("AIHubMix returned no images for this request")
|
|
||||||
|
|
||||||
return GeneratedImageResponse(images=images, content="", raw=payload)
|
return GeneratedImageResponse(images=images, content="", raw=payload)
|
||||||
|
|
||||||
@ -370,31 +415,25 @@ def _http_error_detail(response: httpx.Response) -> str:
|
|||||||
return response.text[:500] or "<empty response body>"
|
return response.text[:500] or "<empty response body>"
|
||||||
|
|
||||||
|
|
||||||
class GeminiImageGenerationClient:
|
class GeminiImageGenerationClient(ImageGenerationProvider):
|
||||||
"""Async client for Gemini/Imagen image generation via the Generative Language API."""
|
"""Async client for Gemini/Imagen image generation via the Generative Language API."""
|
||||||
|
|
||||||
def __init__(
|
provider_name = "gemini"
|
||||||
self,
|
missing_key_message = (
|
||||||
*,
|
"Gemini API key is not configured. Set providers.gemini.apiKey."
|
||||||
api_key: str | None,
|
)
|
||||||
api_base: str | None = None,
|
default_timeout = _GEMINI_DEFAULT_TIMEOUT_S
|
||||||
extra_headers: dict[str, str] | None = None,
|
|
||||||
extra_body: dict[str, Any] | None = None,
|
def _default_base_url(self) -> str:
|
||||||
timeout: float = _GEMINI_DEFAULT_TIMEOUT_S,
|
return "https://generativelanguage.googleapis.com/v1beta"
|
||||||
client: httpx.AsyncClient | None = None,
|
|
||||||
) -> None:
|
def _resolve_base_url(self, api_base: str | None) -> str:
|
||||||
self.api_key = api_key
|
|
||||||
# The Gemini provider's registry default_api_base is the OpenAI-compat
|
# The Gemini provider's registry default_api_base is the OpenAI-compat
|
||||||
# shim (.../v1beta/openai/), which has no image endpoints. Image
|
# shim (.../v1beta/openai/), which has no image endpoints.
|
||||||
# generation needs the native Generative Language API base, so we don't
|
# Skip the registry lookup and use the native API base directly.
|
||||||
# use _provider_base_url() here.
|
if api_base:
|
||||||
self.api_base = (
|
return api_base.rstrip("/")
|
||||||
api_base or "https://generativelanguage.googleapis.com/v1beta"
|
return self._default_base_url()
|
||||||
).rstrip("/")
|
|
||||||
self.extra_headers = extra_headers or {}
|
|
||||||
self.extra_body = extra_body or {}
|
|
||||||
self.timeout = timeout
|
|
||||||
self._client = client
|
|
||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
@ -406,9 +445,7 @@ class GeminiImageGenerationClient:
|
|||||||
image_size: str | None = None,
|
image_size: str | None = None,
|
||||||
) -> GeneratedImageResponse:
|
) -> GeneratedImageResponse:
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ImageGenerationError(
|
raise ImageGenerationError(self.missing_key_message)
|
||||||
"Gemini API key is not configured. Set providers.gemini.apiKey."
|
|
||||||
)
|
|
||||||
if "imagen" in model.lower():
|
if "imagen" in model.lower():
|
||||||
if reference_images:
|
if reference_images:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@ -446,12 +483,7 @@ class GeminiImageGenerationClient:
|
|||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
**self.extra_headers,
|
**self.extra_headers,
|
||||||
}
|
}
|
||||||
|
response = await self._http_post(url, headers=headers, body=body)
|
||||||
if self._client is not None:
|
|
||||||
response = await self._client.post(url, headers=headers, json=body)
|
|
||||||
else:
|
|
||||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
|
||||||
response = await client.post(url, headers=headers, json=body)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
@ -472,11 +504,7 @@ class GeminiImageGenerationClient:
|
|||||||
if isinstance(b64, str) and b64:
|
if isinstance(b64, str) and b64:
|
||||||
images.append(f"data:{mime};base64,{b64}")
|
images.append(f"data:{mime};base64,{b64}")
|
||||||
|
|
||||||
if not images:
|
self._require_images(images, data)
|
||||||
provider_error = data.get("error") if isinstance(data, dict) else None
|
|
||||||
if provider_error:
|
|
||||||
raise ImageGenerationError(f"Gemini Imagen returned no images: {provider_error}")
|
|
||||||
raise ImageGenerationError("Gemini Imagen returned no images for this request")
|
|
||||||
|
|
||||||
return GeneratedImageResponse(images=images, content="", raw=data)
|
return GeneratedImageResponse(images=images, content="", raw=data)
|
||||||
|
|
||||||
@ -504,12 +532,7 @@ class GeminiImageGenerationClient:
|
|||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
**self.extra_headers,
|
**self.extra_headers,
|
||||||
}
|
}
|
||||||
|
response = await self._http_post(url, headers=headers, body=body)
|
||||||
if self._client is not None:
|
|
||||||
response = await self._client.post(url, headers=headers, json=body)
|
|
||||||
else:
|
|
||||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
|
||||||
response = await client.post(url, headers=headers, json=body)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
@ -539,11 +562,7 @@ class GeminiImageGenerationClient:
|
|||||||
if b64:
|
if b64:
|
||||||
images.append(f"data:{mime};base64,{b64}")
|
images.append(f"data:{mime};base64,{b64}")
|
||||||
|
|
||||||
if not images:
|
self._require_images(images, data)
|
||||||
provider_error = data.get("error") if isinstance(data, dict) else None
|
|
||||||
if provider_error:
|
|
||||||
raise ImageGenerationError(f"Gemini returned no images: {provider_error}")
|
|
||||||
raise ImageGenerationError("Gemini returned no images for this request")
|
|
||||||
|
|
||||||
return GeneratedImageResponse(
|
return GeneratedImageResponse(
|
||||||
images=images,
|
images=images,
|
||||||
@ -620,29 +639,17 @@ _MINIMAX_ASPECT_RATIO_SIZES = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class MiniMaxImageGenerationClient:
|
class MiniMaxImageGenerationClient(ImageGenerationProvider):
|
||||||
"""Async client for MiniMax image generation API."""
|
"""Async client for MiniMax image generation API."""
|
||||||
|
|
||||||
def __init__(
|
provider_name = "minimax"
|
||||||
self,
|
missing_key_message = (
|
||||||
*,
|
"MiniMax API key is not configured. Set providers.minimax.apiKey."
|
||||||
api_key: str | None,
|
)
|
||||||
api_base: str | None = None,
|
default_timeout = _MINIMAX_TIMEOUT_S
|
||||||
extra_headers: dict[str, str] | None = None,
|
|
||||||
extra_body: dict[str, Any] | None = None,
|
def _default_base_url(self) -> str:
|
||||||
timeout: float = _MINIMAX_TIMEOUT_S,
|
return "https://api.minimaxi.com/v1"
|
||||||
client: httpx.AsyncClient | None = None,
|
|
||||||
) -> None:
|
|
||||||
self.api_key = api_key
|
|
||||||
self.api_base = _provider_base_url(
|
|
||||||
"minimax",
|
|
||||||
api_base,
|
|
||||||
"https://api.minimaxi.com/v1",
|
|
||||||
)
|
|
||||||
self.extra_headers = extra_headers or {}
|
|
||||||
self.extra_body = extra_body or {}
|
|
||||||
self.timeout = timeout
|
|
||||||
self._client = client
|
|
||||||
|
|
||||||
def _resolve_aspect_ratio(self, aspect_ratio: str | None) -> str:
|
def _resolve_aspect_ratio(self, aspect_ratio: str | None) -> str:
|
||||||
if aspect_ratio and aspect_ratio in _MINIMAX_ASPECT_RATIO_SIZES:
|
if aspect_ratio and aspect_ratio in _MINIMAX_ASPECT_RATIO_SIZES:
|
||||||
@ -659,9 +666,7 @@ class MiniMaxImageGenerationClient:
|
|||||||
image_size: str | None = None,
|
image_size: str | None = None,
|
||||||
) -> GeneratedImageResponse:
|
) -> GeneratedImageResponse:
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ImageGenerationError(
|
raise ImageGenerationError(self.missing_key_message)
|
||||||
"MiniMax API key is not configured. Set providers.minimax.apiKey."
|
|
||||||
)
|
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
@ -687,10 +692,12 @@ class MiniMaxImageGenerationClient:
|
|||||||
|
|
||||||
body.update(self.extra_body)
|
body.update(self.extra_body)
|
||||||
|
|
||||||
if self._client is not None:
|
client = self._client or httpx.AsyncClient(timeout=self.timeout)
|
||||||
return await self._generate_with_client(self._client, body, headers)
|
try:
|
||||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
|
||||||
return await self._generate_with_client(client, body, headers)
|
return await self._generate_with_client(client, body, headers)
|
||||||
|
finally:
|
||||||
|
if self._client is None:
|
||||||
|
await client.aclose()
|
||||||
|
|
||||||
async def _generate_with_client(
|
async def _generate_with_client(
|
||||||
self,
|
self,
|
||||||
@ -715,11 +722,7 @@ class MiniMaxImageGenerationClient:
|
|||||||
payload = response.json()
|
payload = response.json()
|
||||||
images = _minimax_images_from_payload(payload)
|
images = _minimax_images_from_payload(payload)
|
||||||
|
|
||||||
if not images:
|
self._require_images(images, payload)
|
||||||
provider_error = payload.get("error") if isinstance(payload, dict) else None
|
|
||||||
if provider_error:
|
|
||||||
raise ImageGenerationError(f"MiniMax returned no images: {provider_error}")
|
|
||||||
raise ImageGenerationError("MiniMax returned no images for this request")
|
|
||||||
|
|
||||||
return GeneratedImageResponse(images=images, content="", raw=payload)
|
return GeneratedImageResponse(images=images, content="", raw=payload)
|
||||||
|
|
||||||
@ -737,3 +740,13 @@ def _minimax_images_from_payload(payload: dict[str, Any]) -> list[str]:
|
|||||||
if isinstance(b64, str) and b64:
|
if isinstance(b64, str) and b64:
|
||||||
images.append(_b64_png_data_url(b64))
|
images.append(_b64_png_data_url(b64))
|
||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Provider registration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
register_image_gen_provider(OpenRouterImageGenerationClient)
|
||||||
|
register_image_gen_provider(AIHubMixImageGenerationClient)
|
||||||
|
register_image_gen_provider(GeminiImageGenerationClient)
|
||||||
|
register_image_gen_provider(MiniMaxImageGenerationClient)
|
||||||
|
|||||||
@ -35,8 +35,8 @@ async def test_generated_image_media_is_attached_to_final_assistant_message(
|
|||||||
) -> None:
|
) -> None:
|
||||||
set_config_path(tmp_path / "config.json")
|
set_config_path(tmp_path / "config.json")
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.agent.tools.image_generation.OpenRouterImageGenerationClient",
|
"nanobot.agent.tools.image_generation.get_image_gen_provider",
|
||||||
FakeImageClient,
|
lambda name: FakeImageClient if name == "openrouter" else None,
|
||||||
)
|
)
|
||||||
provider = MagicMock()
|
provider = MagicMock()
|
||||||
provider.get_default_model.return_value = "test-model"
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
|||||||
@ -44,8 +44,8 @@ async def test_generate_image_tool_stores_artifact_and_source_images(
|
|||||||
set_config_path(tmp_path / "config.json")
|
set_config_path(tmp_path / "config.json")
|
||||||
FakeImageClient.instances = []
|
FakeImageClient.instances = []
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.agent.tools.image_generation.OpenRouterImageGenerationClient",
|
"nanobot.agent.tools.image_generation.get_image_gen_provider",
|
||||||
FakeImageClient,
|
lambda name: FakeImageClient if name == "openrouter" else None,
|
||||||
)
|
)
|
||||||
ref = tmp_path / "ref.png"
|
ref = tmp_path / "ref.png"
|
||||||
ref.write_bytes(PNG_BYTES)
|
ref.write_bytes(PNG_BYTES)
|
||||||
@ -98,8 +98,8 @@ async def test_generate_image_tool_selects_aihubmix_provider(
|
|||||||
set_config_path(tmp_path / "config.json")
|
set_config_path(tmp_path / "config.json")
|
||||||
FakeImageClient.instances = []
|
FakeImageClient.instances = []
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.agent.tools.image_generation.AIHubMixImageGenerationClient",
|
"nanobot.agent.tools.image_generation.get_image_gen_provider",
|
||||||
FakeImageClient,
|
lambda name: FakeImageClient if name == "aihubmix" else None,
|
||||||
)
|
)
|
||||||
tool = ImageGenerationTool(
|
tool = ImageGenerationTool(
|
||||||
workspace=tmp_path,
|
workspace=tmp_path,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user