mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-24 02:22:52 +00:00
Merge PR #3929: Unify image provider HTTP handling and document Gemini image base URLs
Unify image provider HTTP handling and document Gemini image base URLs
This commit is contained in:
commit
782d761b81
@ -139,6 +139,11 @@ _IMAGE_GEN_PROVIDERS: dict[str, type[ImageGenerationProvider]] = {}
|
|||||||
|
|
||||||
|
|
||||||
def register_image_gen_provider(cls: type[ImageGenerationProvider]) -> None:
|
def register_image_gen_provider(cls: type[ImageGenerationProvider]) -> None:
|
||||||
|
"""Register an image provider at import time only.
|
||||||
|
|
||||||
|
The registry is populated by module side effects so provider discovery
|
||||||
|
stays lazy and consistent across the process.
|
||||||
|
"""
|
||||||
name = cls.provider_name
|
name = cls.provider_name
|
||||||
if not name:
|
if not name:
|
||||||
raise ValueError(f"{cls.__name__} must set provider_name")
|
raise ValueError(f"{cls.__name__} must set provider_name")
|
||||||
@ -229,7 +234,10 @@ class ImageGenerationProvider(ABC):
|
|||||||
*,
|
*,
|
||||||
headers: dict[str, str],
|
headers: dict[str, str],
|
||||||
body: dict[str, Any],
|
body: dict[str, Any],
|
||||||
|
client: httpx.AsyncClient | None = None,
|
||||||
) -> httpx.Response:
|
) -> httpx.Response:
|
||||||
|
if client is not None:
|
||||||
|
return await client.post(url, headers=headers, json=body)
|
||||||
if self._client is not None:
|
if self._client is not None:
|
||||||
return await self._client.post(url, headers=headers, json=body)
|
return await self._client.post(url, headers=headers, json=body)
|
||||||
async with httpx.AsyncClient(timeout=self.timeout) as c:
|
async with httpx.AsyncClient(timeout=self.timeout) as c:
|
||||||
@ -400,10 +408,11 @@ class AIHubMixImageGenerationClient(ImageGenerationProvider):
|
|||||||
model_path = _aihubmix_model_path(model)
|
model_path = _aihubmix_model_path(model)
|
||||||
url = f"{self.api_base}/models/{model_path}/predictions"
|
url = f"{self.api_base}/models/{model_path}/predictions"
|
||||||
try:
|
try:
|
||||||
response = await client.post(
|
response = await self._http_post(
|
||||||
url,
|
url,
|
||||||
headers={**headers, "Content-Type": "application/json"},
|
headers={**headers, "Content-Type": "application/json"},
|
||||||
json=body,
|
body=body,
|
||||||
|
client=client,
|
||||||
)
|
)
|
||||||
except httpx.TimeoutException as exc:
|
except httpx.TimeoutException as exc:
|
||||||
raise ImageGenerationError("AIHubMix image generation timed out") from exc
|
raise ImageGenerationError("AIHubMix image generation timed out") from exc
|
||||||
@ -585,9 +594,9 @@ class GeminiImageGenerationClient(ImageGenerationProvider):
|
|||||||
return "https://generativelanguage.googleapis.com/v1beta"
|
return "https://generativelanguage.googleapis.com/v1beta"
|
||||||
|
|
||||||
def _resolve_base_url(self, api_base: str | None) -> str:
|
def _resolve_base_url(self, api_base: str | None) -> str:
|
||||||
# The Gemini provider's registry default_api_base is the OpenAI-compat
|
# Gemini chat completions use the registry's OpenAI-compatible shim.
|
||||||
# shim (.../v1beta/openai/), which has no image endpoints.
|
# Image generation must hit the native Generative Language API, so we
|
||||||
# Skip the registry lookup and use the native API base directly.
|
# intentionally bypass the shared registry lookup here.
|
||||||
if api_base:
|
if api_base:
|
||||||
return api_base.rstrip("/")
|
return api_base.rstrip("/")
|
||||||
return self._default_base_url()
|
return self._default_base_url()
|
||||||
@ -849,22 +858,16 @@ class MiniMaxImageGenerationClient(ImageGenerationProvider):
|
|||||||
|
|
||||||
body.update(self.extra_body)
|
body.update(self.extra_body)
|
||||||
|
|
||||||
client = self._client or httpx.AsyncClient(timeout=self.timeout)
|
return await self._generate_with_client(body, headers)
|
||||||
try:
|
|
||||||
return await self._generate_with_client(client, body, headers)
|
|
||||||
finally:
|
|
||||||
if self._client is None:
|
|
||||||
await client.aclose()
|
|
||||||
|
|
||||||
async def _generate_with_client(
|
async def _generate_with_client(
|
||||||
self,
|
self,
|
||||||
client: httpx.AsyncClient,
|
|
||||||
body: dict[str, Any],
|
body: dict[str, Any],
|
||||||
headers: dict[str, str],
|
headers: dict[str, str],
|
||||||
) -> GeneratedImageResponse:
|
) -> GeneratedImageResponse:
|
||||||
url = f"{self.api_base}/image_generation"
|
url = f"{self.api_base}/image_generation"
|
||||||
try:
|
try:
|
||||||
response = await client.post(url, headers=headers, json=body)
|
response = await self._http_post(url, headers=headers, body=body)
|
||||||
except httpx.TimeoutException as exc:
|
except httpx.TimeoutException as exc:
|
||||||
raise ImageGenerationError("MiniMax image generation timed out") from exc
|
raise ImageGenerationError("MiniMax image generation timed out") from exc
|
||||||
except httpx.RequestError as exc:
|
except httpx.RequestError as exc:
|
||||||
|
|||||||
@ -410,6 +410,11 @@ async def test_gemini_requires_api_key() -> None:
|
|||||||
await client.generate(prompt="draw", model="imagen-4.0-generate-001")
|
await client.generate(prompt="draw", model="imagen-4.0-generate-001")
|
||||||
|
|
||||||
|
|
||||||
|
def test_gemini_image_client_uses_native_api_base_by_default() -> None:
|
||||||
|
client = GeminiImageGenerationClient(api_key="AIza-test")
|
||||||
|
assert client.api_base == "https://generativelanguage.googleapis.com/v1beta"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_gemini_no_images_raises() -> None:
|
async def test_gemini_no_images_raises() -> None:
|
||||||
fake = FakeClient(FakeResponse({"candidates": [{"content": {"parts": [{"text": "sorry"}]}}]}))
|
fake = FakeClient(FakeResponse({"candidates": [{"content": {"parts": [{"text": "sorry"}]}}]}))
|
||||||
@ -452,6 +457,17 @@ async def test_minimax_payload_and_response_with_reference_image(tmp_path: Path)
|
|||||||
assert body["subject_reference"][0]["image_file"].startswith("data:image/png;base64,")
|
assert body["subject_reference"][0]["image_file"].startswith("data:image/png;base64,")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_minimax_base64_response_uses_detected_mime() -> None:
|
||||||
|
raw_b64 = base64.b64encode(JPEG_BYTES).decode("ascii")
|
||||||
|
fake = FakeClient(FakeResponse({"data": {"image_base64": [raw_b64]}}))
|
||||||
|
client = MiniMaxImageGenerationClient(api_key="sk-mm-test", client=fake) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
response = await client.generate(prompt="draw", model="image-01")
|
||||||
|
|
||||||
|
assert response.images == [f"data:image/jpeg;base64,{raw_b64}"]
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# StepFun (阶跃星辰)
|
# StepFun (阶跃星辰)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@ -458,6 +458,12 @@ def test_gemma_routes_to_gemini_provider() -> None:
|
|||||||
assert "gemma" in spec.keywords
|
assert "gemma" in spec.keywords
|
||||||
|
|
||||||
|
|
||||||
|
def test_gemini_spec_keeps_openai_compat_base() -> None:
|
||||||
|
spec = find_by_name("gemini")
|
||||||
|
assert spec is not None
|
||||||
|
assert spec.default_api_base == "https://generativelanguage.googleapis.com/v1beta/openai/"
|
||||||
|
|
||||||
|
|
||||||
async def test_openrouter_sets_default_attribution_headers() -> None:
|
async def test_openrouter_sets_default_attribution_headers() -> None:
|
||||||
spec = find_by_name("openrouter")
|
spec = find_by_name("openrouter")
|
||||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_client_cls:
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_client_cls:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user