mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
Adds ImageGenerationProvider ABC with shared __init__, _http_post(), and _require_images(). Introduces _IMAGE_GEN_PROVIDERS registry with register/get/image_gen_provider_configs() helpers. Four existing providers (OpenRouter, AIHubMix, Gemini, MiniMax) now inherit from the base class and self-register. Adding a new provider only requires writing one class + one registration line. Eliminates if/else chains in the tool dispatch and hardcoded provider config dicts in commands.py (3 sites) and nanobot.py (1 site). Fixes the agent CLI command missing image_generation_provider_configs entirely. Also simplifies test monkeypatch targets to patch the registry lookup.
155 lines
5.1 KiB
Python
155 lines
5.1 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import pytest
|
|
|
|
from nanobot.agent.tools.image_generation import ImageGenerationTool
|
|
from nanobot.config.loader import set_config_path
|
|
from nanobot.config.schema import ImageGenerationToolConfig, ProviderConfig
|
|
from nanobot.providers.image_generation import GeneratedImageResponse
|
|
|
|
PNG_BYTES = (
|
|
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01"
|
|
b"\x00\x00\x00\x01\x08\x04\x00\x00\x00\xb5\x1c\x0c\x02"
|
|
b"\x00\x00\x00\x0bIDATx\xdacd\xfc\xff\x1f\x00\x03\x03"
|
|
b"\x02\x00\xef\xbf\xa7\xdb\x00\x00\x00\x00IEND\xaeB`\x82"
|
|
)
|
|
PNG_DATA_URL = (
|
|
"data:image/png;base64,"
|
|
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+/p9sAAAAASUVORK5CYII="
|
|
)
|
|
|
|
|
|
class FakeImageClient:
|
|
instances: list["FakeImageClient"] = []
|
|
|
|
def __init__(self, **kwargs: Any) -> None:
|
|
self.kwargs = kwargs
|
|
self.calls: list[dict[str, Any]] = []
|
|
self.instances.append(self)
|
|
|
|
async def generate(self, **kwargs: Any) -> GeneratedImageResponse:
|
|
self.calls.append(kwargs)
|
|
return GeneratedImageResponse(images=[PNG_DATA_URL], content="", raw={})
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_generate_image_tool_stores_artifact_and_source_images(
|
|
tmp_path: Path,
|
|
monkeypatch: pytest.MonkeyPatch,
|
|
) -> None:
|
|
set_config_path(tmp_path / "config.json")
|
|
FakeImageClient.instances = []
|
|
monkeypatch.setattr(
|
|
"nanobot.agent.tools.image_generation.get_image_gen_provider",
|
|
lambda name: FakeImageClient if name == "openrouter" else None,
|
|
)
|
|
ref = tmp_path / "ref.png"
|
|
ref.write_bytes(PNG_BYTES)
|
|
tool = ImageGenerationTool(
|
|
workspace=tmp_path,
|
|
config=ImageGenerationToolConfig(enabled=True, max_images_per_turn=2),
|
|
provider_config=ProviderConfig(api_key="sk-or-test"),
|
|
)
|
|
|
|
result = await tool.execute(
|
|
prompt="make this blue",
|
|
reference_images=["ref.png"],
|
|
aspect_ratio="16:9",
|
|
image_size="2K",
|
|
count=2,
|
|
)
|
|
|
|
payload = json.loads(result)
|
|
artifacts = payload["artifacts"]
|
|
assert len(artifacts) == 2
|
|
assert Path(artifacts[0]["path"]).is_file()
|
|
assert artifacts[0]["source_images"] == [str(ref.resolve())]
|
|
assert artifacts[0]["model"] == "openai/gpt-5.4-image-2"
|
|
|
|
fake = FakeImageClient.instances[0]
|
|
assert fake.kwargs["api_key"] == "sk-or-test"
|
|
assert len(fake.calls) == 2
|
|
assert fake.calls[0]["aspect_ratio"] == "16:9"
|
|
assert fake.calls[0]["image_size"] == "2K"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_generate_image_tool_reports_missing_key(tmp_path: Path) -> None:
|
|
tool = ImageGenerationTool(
|
|
workspace=tmp_path,
|
|
config=ImageGenerationToolConfig(enabled=True),
|
|
provider_config=ProviderConfig(),
|
|
)
|
|
|
|
result = await tool.execute(prompt="draw")
|
|
|
|
assert result.startswith("Error: OpenRouter API key is not configured")
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_generate_image_tool_selects_aihubmix_provider(
|
|
tmp_path: Path,
|
|
monkeypatch: pytest.MonkeyPatch,
|
|
) -> None:
|
|
set_config_path(tmp_path / "config.json")
|
|
FakeImageClient.instances = []
|
|
monkeypatch.setattr(
|
|
"nanobot.agent.tools.image_generation.get_image_gen_provider",
|
|
lambda name: FakeImageClient if name == "aihubmix" else None,
|
|
)
|
|
tool = ImageGenerationTool(
|
|
workspace=tmp_path,
|
|
config=ImageGenerationToolConfig(
|
|
enabled=True,
|
|
provider="aihubmix",
|
|
model="gpt-image-2-free",
|
|
),
|
|
provider_configs={
|
|
"openrouter": ProviderConfig(api_key="sk-or-test"),
|
|
"aihubmix": ProviderConfig(api_key="sk-ahm-test", extra_body={"quality": "low"}),
|
|
},
|
|
)
|
|
|
|
result = await tool.execute(prompt="draw a poster", aspect_ratio="3:4")
|
|
|
|
payload = json.loads(result)
|
|
assert len(payload["artifacts"]) == 1
|
|
fake = FakeImageClient.instances[0]
|
|
assert fake.kwargs["api_key"] == "sk-ahm-test"
|
|
assert fake.kwargs["extra_body"] == {"quality": "low"}
|
|
assert fake.calls[0]["model"] == "gpt-image-2-free"
|
|
assert fake.calls[0]["aspect_ratio"] == "3:4"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_generate_image_tool_reports_missing_aihubmix_key(tmp_path: Path) -> None:
|
|
tool = ImageGenerationTool(
|
|
workspace=tmp_path,
|
|
config=ImageGenerationToolConfig(enabled=True, provider="aihubmix"),
|
|
provider_configs={"aihubmix": ProviderConfig()},
|
|
)
|
|
|
|
result = await tool.execute(prompt="draw")
|
|
|
|
assert result.startswith("Error: AIHubMix API key is not configured")
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_generate_image_tool_rejects_reference_outside_workspace(tmp_path: Path) -> None:
|
|
set_config_path(tmp_path / "config.json")
|
|
outside = tmp_path.parent / "outside.png"
|
|
outside.write_bytes(PNG_BYTES)
|
|
tool = ImageGenerationTool(
|
|
workspace=tmp_path,
|
|
config=ImageGenerationToolConfig(enabled=True),
|
|
provider_config=ProviderConfig(api_key="sk-or-test"),
|
|
)
|
|
|
|
result = await tool.execute(prompt="edit", reference_images=[str(outside)])
|
|
|
|
assert "reference_images must be inside the workspace" in result
|