nanobot/tests/agent/test_loop_image_generation_media.py
chengyongru 7aa5b9b17b refactor(image-generation): introduce provider registry to eliminate manual wiring
Adds ImageGenerationProvider ABC with shared __init__, _http_post(), and
_require_images(). Introduces _IMAGE_GEN_PROVIDERS registry with
register/get/image_gen_provider_configs() helpers.

Four existing providers (OpenRouter, AIHubMix, Gemini, MiniMax) now inherit
from the base class and self-register. Adding a new provider only requires
writing one class + one registration line.

Eliminates if/else chains in the tool dispatch and hardcoded provider config
dicts in commands.py (3 sites) and nanobot.py (1 site). Fixes the agent CLI
command, which was missing image_generation_provider_configs entirely.

Also simplifies the test monkeypatch targets so that they patch the registry lookup.
2026-05-18 17:20:54 +08:00

90 lines
2.9 KiB
Python

from __future__ import annotations
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.config.loader import set_config_path
from nanobot.config.schema import ImageGenerationToolConfig, ProviderConfig, ToolsConfig
from nanobot.providers.base import LLMResponse, ToolCallRequest
from nanobot.providers.image_generation import GeneratedImageResponse
# Base64 payload of a tiny PNG image, kept separate from the data-URL prefix.
_PNG_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+/p9sAAAAASUVORK5CYII="
# Complete ``data:`` URL that the fake image provider returns as its only image.
PNG_DATA_URL = "data:image/png;base64," + _PNG_BASE64
class FakeImageClient:
    """Offline stand-in for an image-generation provider.

    It accepts and ignores any keyword arguments. Every ``generate`` call
    returns the same canned response, which contains ``PNG_DATA_URL`` as its
    only image.
    """

    def __init__(self, **kwargs: Any) -> None:
        # Constructor options are irrelevant to the fake; discard them.
        del kwargs

    async def generate(self, **kwargs: Any) -> GeneratedImageResponse:
        """Return a canned response holding a single PNG data URL."""
        canned_images = [PNG_DATA_URL]
        return GeneratedImageResponse(images=canned_images, content="", raw={})
@pytest.mark.asyncio
async def test_generated_image_media_is_attached_to_final_assistant_message(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Generated images are attached to the final reply and stored in the session.

    The scripted LLM first calls the ``generate_image`` tool and then replies
    "Done". The provider registry lookup is patched so that ``"openrouter"``
    resolves to ``FakeImageClient``, which returns a canned PNG data URL.

    The test checks three things:
    - the image is written to a non-empty file;
    - that file is attached to the outbound reply;
    - the same media list is recorded on the session's last assistant message.
    """
    # Use a per-test config path.
    # NOTE(review): set_config_path looks like it sets module-level state that
    # is not reset by this test; confirm it cannot leak into later tests.
    set_config_path(tmp_path / "config.json")
    # Only "openrouter" resolves to the fake client; any other name yields None.
    monkeypatch.setattr(
        "nanobot.agent.tools.image_generation.get_image_gen_provider",
        lambda name: FakeImageClient if name == "openrouter" else None,
    )

    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    provider.generation.max_tokens = 4096
    # Two scripted turns: first a tool call requesting an image, then a plain
    # text answer after the tool result has been fed back.
    provider.chat_with_retry = AsyncMock(
        side_effect=[
            LLMResponse(
                content="",
                finish_reason="tool_calls",
                tool_calls=[
                    ToolCallRequest(
                        id="call_img",
                        name="generate_image",
                        arguments={"prompt": "draw a tiny icon"},
                    )
                ],
            ),
            LLMResponse(content="Done", finish_reason="stop"),
        ]
    )
    # Awaitable stub so that, if a streaming path is taken, it does not hit a plain MagicMock.
    provider.chat_stream_with_retry = AsyncMock()

    loop = AgentLoop(
        bus=MessageBus(),
        provider=provider,
        workspace=tmp_path,
        model="test-model",
        # NOTE(review): relies on ImageGenerationToolConfig defaulting to the
        # "openrouter" provider, matching the patched lookup above — confirm.
        tools_config=ToolsConfig(
            image_generation=ImageGenerationToolConfig(enabled=True),
        ),
        image_generation_provider_config=ProviderConfig(api_key="sk-or-test"),
    )
    # Stub out token-based consolidation for this test.
    loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False)  # type: ignore[method-assign]

    result = await loop._process_message(
        InboundMessage(
            channel="websocket",
            sender_id="user",
            chat_id="chat-image",
            content="draw an icon",
        )
    )

    assert result is not None
    assert result.content == "Done"
    assert len(result.media) == 1
    media_path = Path(result.media[0])
    assert media_path.is_file()
    # is_file() also passes for an empty file; require that image bytes were written.
    assert media_path.stat().st_size > 0

    # The session is looked up by "<channel>:<chat_id>" of the inbound message.
    session = loop.sessions.get_or_create("websocket:chat-image")
    assert session.messages[-1]["role"] == "assistant"
    assert session.messages[-1]["media"] == result.media