mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-24 18:42:35 +00:00
feat(providers): add OpenAI and OpenAI Codex image generation providers
Add two new image generation providers: - `openai` — uses the standalone OpenAI Images API (`/v1/images/generations`) with an API key. Supports DALL-E and gpt-image-* models, with automatic parameter adjustment (gpt-image models don't accept response_format or n). - `openai_codex` — uses the Codex Responses API with the `image_generation` tool, authenticated via OAuth subscription token. The same mechanism ChatGPT uses internally. Also remove the API key pre-check in ImageGenerationTool so providers that handle their own auth fallback (like Codex OAuth) can work without a configured key. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
b0d3069621
commit
3483141ed7
@ -130,12 +130,6 @@ class ImageGenerationTool(Tool):
|
||||
}
|
||||
return cls(**kwargs)
|
||||
|
||||
def _missing_api_key_error(self) -> str:
|
||||
cls = get_image_gen_provider(self.config.provider)
|
||||
if cls and cls.missing_key_message:
|
||||
return f"Error: {cls.missing_key_message}"
|
||||
return f"Error: {self.config.provider} API key is not configured."
|
||||
|
||||
def _resolve_reference_image(self, value: str) -> str:
|
||||
raw_path = Path(value).expanduser()
|
||||
path = raw_path if raw_path.is_absolute() else self.workspace / raw_path
|
||||
@ -173,9 +167,6 @@ class ImageGenerationTool(Tool):
|
||||
client = self._provider_client()
|
||||
if client is None:
|
||||
return f"Error: unsupported image generation provider '{self.config.provider}'"
|
||||
provider = self._provider_config()
|
||||
if not provider or not provider.api_key:
|
||||
return self._missing_api_key_error()
|
||||
|
||||
requested = count or 1
|
||||
if requested > self.config.max_images_per_turn:
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import binascii
|
||||
from abc import ABC, abstractmethod
|
||||
@ -756,6 +757,359 @@ def _minimax_images_from_payload(payload: dict[str, Any]) -> list[str]:
|
||||
return images
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAI image generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_OPENAI_ASPECT_RATIO_SIZES = {
|
||||
"1:1": "1024x1024",
|
||||
"16:9": "1792x1024",
|
||||
"9:16": "1024x1792",
|
||||
"3:4": "1024x1360",
|
||||
"4:3": "1360x1024",
|
||||
}
|
||||
|
||||
|
||||
class OpenAIImageGenerationClient(ImageGenerationProvider):
|
||||
"""OpenAI Images API using an API key (``providers.openai.apiKey``)."""
|
||||
|
||||
provider_name = "openai"
|
||||
missing_key_message = (
|
||||
"OpenAI API key is not configured. Set providers.openai.apiKey."
|
||||
)
|
||||
|
||||
def _default_base_url(self) -> str:
|
||||
return "https://api.openai.com/v1"
|
||||
|
||||
@staticmethod
|
||||
def _strip_model_prefix(model: str) -> str:
|
||||
"""Remove ``openai/`` prefix if present (OpenRouter convention)."""
|
||||
if model.startswith("openai/") or model.startswith("openai_codex/"):
|
||||
return model.split("/", 1)[1]
|
||||
return model
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
model: str,
|
||||
reference_images: list[str] | None = None,
|
||||
aspect_ratio: str | None = None,
|
||||
image_size: str | None = None,
|
||||
) -> GeneratedImageResponse:
|
||||
if not self.api_key:
|
||||
raise ImageGenerationError(self.missing_key_message)
|
||||
|
||||
if reference_images:
|
||||
logger.warning(
|
||||
"DALL-E models do not support reference images; "
|
||||
"ignoring {} reference image(s) for {}",
|
||||
len(reference_images),
|
||||
model,
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
**self.extra_headers,
|
||||
}
|
||||
|
||||
clean_model = self._strip_model_prefix(model)
|
||||
body: dict[str, Any] = {
|
||||
"model": clean_model,
|
||||
"prompt": prompt,
|
||||
}
|
||||
|
||||
# gpt-image-* models don't support response_format or n
|
||||
if not clean_model.startswith("gpt-image"):
|
||||
body["response_format"] = "b64_json"
|
||||
body["n"] = 1
|
||||
|
||||
size = _openai_size(aspect_ratio, image_size)
|
||||
if size:
|
||||
body["size"] = size
|
||||
|
||||
body.update(self.extra_body)
|
||||
|
||||
logger.info("OpenAI Images API request: POST {}/images/generations body={}", self.api_base, body)
|
||||
|
||||
response = await self._http_post(
|
||||
f"{self.api_base}/images/generations",
|
||||
headers=headers,
|
||||
body=body,
|
||||
)
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
detail = response.text[:1000]
|
||||
logger.error("OpenAI Images API error ({}): {}", response.status_code, detail)
|
||||
raise ImageGenerationError(
|
||||
f"OpenAI image generation failed (HTTP {response.status_code}): {detail}"
|
||||
) from exc
|
||||
|
||||
payload = response.json()
|
||||
logger.info("OpenAI Images API response ({}): {}", response.status_code,
|
||||
{k: v for k, v in payload.items() if k != "data"})
|
||||
|
||||
client = self._client
|
||||
owns_client = client is None
|
||||
if owns_client:
|
||||
client = httpx.AsyncClient(timeout=self.timeout)
|
||||
try:
|
||||
images = await _openai_images_from_payload(client, payload)
|
||||
finally:
|
||||
if owns_client:
|
||||
await client.aclose()
|
||||
|
||||
self._require_images(images, payload)
|
||||
|
||||
return GeneratedImageResponse(images=images, content="", raw=payload)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAI Codex image generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CodexImageGenerationClient(ImageGenerationProvider):
|
||||
"""OpenAI image generation via Codex subscription OAuth.
|
||||
|
||||
Uses the Codex Responses API with the ``image_generation`` tool
|
||||
(the same mechanism ChatGPT uses internally). No API key required —
|
||||
the Codex OAuth token from ``oauth_cli_kit`` is used instead.
|
||||
"""
|
||||
|
||||
provider_name = "openai_codex"
|
||||
missing_key_message = (
|
||||
"Codex OAuth token is unavailable. "
|
||||
"Log in with Codex subscription first."
|
||||
)
|
||||
|
||||
def _default_base_url(self) -> str:
|
||||
return "https://chatgpt.com/backend-api"
|
||||
|
||||
def _codex_model(self, model: str) -> str:
|
||||
"""Strip the ``openai-codex/`` prefix if present."""
|
||||
if model.startswith(("openai-codex/", "openai_codex/")):
|
||||
return model.split("/", 1)[1]
|
||||
return model
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
model: str,
|
||||
reference_images: list[str] | None = None,
|
||||
aspect_ratio: str | None = None,
|
||||
image_size: str | None = None,
|
||||
) -> GeneratedImageResponse:
|
||||
try:
|
||||
from oauth_cli_kit import get_token as get_codex_token
|
||||
except ImportError:
|
||||
raise ImageGenerationError(self.missing_key_message)
|
||||
|
||||
try:
|
||||
token = await asyncio.to_thread(get_codex_token)
|
||||
except Exception as exc:
|
||||
raise ImageGenerationError(self.missing_key_message) from exc
|
||||
if not token or not token.access:
|
||||
raise ImageGenerationError(self.missing_key_message)
|
||||
|
||||
logger.info(
|
||||
"Using Codex OAuth token for image generation (account: {})",
|
||||
token.account_id,
|
||||
)
|
||||
|
||||
if reference_images:
|
||||
logger.warning(
|
||||
"Codex image generation does not support reference images; "
|
||||
"ignoring {} reference image(s)",
|
||||
len(reference_images),
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token.access}",
|
||||
"chatgpt-account-id": token.account_id,
|
||||
"OpenAI-Beta": "responses=experimental",
|
||||
"originator": "nanobot",
|
||||
"User-Agent": "nanobot (python)",
|
||||
"Content-Type": "application/json",
|
||||
**self.extra_headers,
|
||||
}
|
||||
|
||||
body: dict[str, Any] = {
|
||||
"model": self._codex_model(model),
|
||||
"instructions": "Generate an image based on the user's request.",
|
||||
"input": [{"role": "user", "content": prompt}],
|
||||
"tools": [{"type": "image_generation"}],
|
||||
"tool_choice": "auto",
|
||||
"stream": True,
|
||||
"store": False,
|
||||
}
|
||||
body.update(self.extra_body)
|
||||
|
||||
logger.info("Codex Responses API request: POST {}/codex/responses body={}",
|
||||
self.api_base, {k: v for k, v in body.items() if k != "input"})
|
||||
|
||||
response = await self._http_post(
|
||||
f"{self.api_base}/codex/responses",
|
||||
headers=headers,
|
||||
body=body,
|
||||
)
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
detail = response.text[:1000]
|
||||
logger.error("Codex Responses API error ({}): {}", response.status_code, detail)
|
||||
raise ImageGenerationError(
|
||||
f"Codex image generation failed (HTTP {response.status_code}): {detail}"
|
||||
) from exc
|
||||
|
||||
images, content_text = await _parse_codex_sse_images(response)
|
||||
|
||||
raw = {"status": "completed"}
|
||||
self._require_images(images, raw)
|
||||
|
||||
return GeneratedImageResponse(images=images, content=content_text, raw=raw)
|
||||
|
||||
|
||||
def _openai_size(
|
||||
aspect_ratio: str | None,
|
||||
image_size: str | None,
|
||||
) -> str:
|
||||
"""Resolve aspect ratio or image_size to an OpenAI Images API size string."""
|
||||
if image_size and "x" in image_size.lower():
|
||||
return image_size
|
||||
if aspect_ratio and aspect_ratio in _OPENAI_ASPECT_RATIO_SIZES:
|
||||
return _OPENAI_ASPECT_RATIO_SIZES[aspect_ratio]
|
||||
return "1024x1024"
|
||||
|
||||
|
||||
async def _openai_images_from_payload(
|
||||
client: httpx.AsyncClient,
|
||||
payload: dict[str, Any],
|
||||
) -> list[str]:
|
||||
"""Extract images from OpenAI Images API response.
|
||||
|
||||
Handles both ``b64_json`` (preferred) and ``url`` (downloaded) formats.
|
||||
"""
|
||||
images: list[str] = []
|
||||
for item in payload.get("data") or []:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
b64 = item.get("b64_json")
|
||||
if isinstance(b64, str) and b64:
|
||||
images.append(_b64_image_data_url(b64))
|
||||
continue
|
||||
url = item.get("url")
|
||||
if isinstance(url, str) and url:
|
||||
images.append(await _download_image_data_url(client, url))
|
||||
return images
|
||||
|
||||
|
||||
def _codex_responses_images_from_payload(payload: dict[str, Any]) -> list[str]:
|
||||
"""Extract images from Codex Responses API ``image_generation_call`` output."""
|
||||
images: list[str] = []
|
||||
for item in payload.get("output") or []:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if item.get("type") != "image_generation_call":
|
||||
continue
|
||||
result = item.get("result")
|
||||
if isinstance(result, str):
|
||||
images.append(result if result.startswith("data:image/") else _b64_image_data_url(result))
|
||||
continue
|
||||
if isinstance(result, dict):
|
||||
image_url = result.get("image_url") or result.get("image") or ""
|
||||
if isinstance(image_url, str):
|
||||
images.append(image_url if image_url.startswith("data:image/") else _b64_image_data_url(image_url))
|
||||
return images
|
||||
|
||||
|
||||
async def _parse_codex_sse_images(
|
||||
response: httpx.Response,
|
||||
) -> tuple[list[str], str]:
|
||||
"""Parse a Codex Responses API SSE stream for image generation output.
|
||||
|
||||
Returns ``(images, content_text)``.
|
||||
"""
|
||||
import json as _json
|
||||
|
||||
images: list[str] = []
|
||||
text_parts: list[str] = []
|
||||
|
||||
buffer: list[str] = []
|
||||
async for line_bytes in response.aiter_lines():
|
||||
line = line_bytes.strip()
|
||||
if line == "":
|
||||
if buffer:
|
||||
data_lines = []
|
||||
for bl in buffer:
|
||||
if bl.startswith("data:"):
|
||||
data_lines.append(bl[5:].strip())
|
||||
buffer.clear()
|
||||
if data_lines:
|
||||
raw = "".join(data_lines)
|
||||
if raw == "[DONE]":
|
||||
break
|
||||
try:
|
||||
event = _json.loads(raw)
|
||||
except Exception:
|
||||
continue
|
||||
ev_type = event.get("type", "")
|
||||
if ev_type in ("error", "response.failed"):
|
||||
logger.error("Codex SSE failure: {}", raw[:2000])
|
||||
_collect_images_from_sse_event(event, images)
|
||||
_collect_text_from_sse_event(event, text_parts)
|
||||
continue
|
||||
buffer.append(line)
|
||||
|
||||
# flush remaining
|
||||
if buffer:
|
||||
data_lines = [bl[5:].strip() for bl in buffer if bl.startswith("data:")]
|
||||
raw = "".join(data_lines)
|
||||
if raw and raw != "[DONE]":
|
||||
try:
|
||||
event = _json.loads(raw)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
_collect_images_from_sse_event(event, images)
|
||||
_collect_text_from_sse_event(event, text_parts)
|
||||
|
||||
return images, "\n".join(text_parts).strip()
|
||||
|
||||
|
||||
def _collect_images_from_sse_event(event: dict[str, Any], images: list[str]) -> None:
|
||||
if event.get("type") != "response.output_item.done":
|
||||
return
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") != "image_generation_call":
|
||||
return
|
||||
result = item.get("result")
|
||||
if isinstance(result, str):
|
||||
if result.startswith("data:image/"):
|
||||
images.append(result)
|
||||
else:
|
||||
images.append(_b64_image_data_url(result))
|
||||
elif isinstance(result, dict):
|
||||
image_url = result.get("image_url") or result.get("image") or ""
|
||||
if isinstance(image_url, str):
|
||||
if image_url.startswith("data:image/"):
|
||||
images.append(image_url)
|
||||
else:
|
||||
images.append(_b64_image_data_url(image_url))
|
||||
|
||||
|
||||
def _collect_text_from_sse_event(event: dict[str, Any], text_parts: list[str]) -> None:
|
||||
if event.get("type") == "response.output_text.delta":
|
||||
delta = event.get("delta")
|
||||
if isinstance(delta, str) and delta:
|
||||
text_parts.append(delta)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# StepFun (阶跃星辰) image generation
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -883,8 +1237,10 @@ def _stepfun_images_from_payload(payload: dict[str, Any]) -> list[str]:
|
||||
# Provider registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
register_image_gen_provider(OpenRouterImageGenerationClient)
|
||||
register_image_gen_provider(AIHubMixImageGenerationClient)
|
||||
register_image_gen_provider(CodexImageGenerationClient)
|
||||
register_image_gen_provider(GeminiImageGenerationClient)
|
||||
register_image_gen_provider(MiniMaxImageGenerationClient)
|
||||
register_image_gen_provider(OpenAIImageGenerationClient)
|
||||
register_image_gen_provider(OpenRouterImageGenerationClient)
|
||||
register_image_gen_provider(StepFunImageGenerationClient)
|
||||
|
||||
@ -9,10 +9,12 @@ import pytest
|
||||
|
||||
from nanobot.providers.image_generation import (
|
||||
AIHubMixImageGenerationClient,
|
||||
CodexImageGenerationClient,
|
||||
GeminiImageGenerationClient,
|
||||
GeneratedImageResponse,
|
||||
ImageGenerationError,
|
||||
MiniMaxImageGenerationClient,
|
||||
OpenAIImageGenerationClient,
|
||||
OpenRouterImageGenerationClient,
|
||||
StepFunImageGenerationClient,
|
||||
)
|
||||
@ -36,12 +38,14 @@ class FakeResponse:
|
||||
payload: dict[str, Any],
|
||||
status_code: int = 200,
|
||||
content: bytes = b"",
|
||||
sse_lines: list[str] | None = None,
|
||||
) -> None:
|
||||
self._payload = payload
|
||||
self.status_code = status_code
|
||||
self.text = str(payload)
|
||||
self.content = content
|
||||
self.request = httpx.Request("POST", "https://openrouter.ai/api/v1/chat/completions")
|
||||
self._sse_lines = sse_lines
|
||||
|
||||
def json(self) -> dict[str, Any]:
|
||||
return self._payload
|
||||
@ -51,6 +55,15 @@ class FakeResponse:
|
||||
response = httpx.Response(self.status_code, request=self.request, text=self.text)
|
||||
raise httpx.HTTPStatusError("failed", request=self.request, response=response)
|
||||
|
||||
async def aiter_lines(self):
|
||||
if self._sse_lines is not None:
|
||||
for line in self._sse_lines:
|
||||
yield line
|
||||
return
|
||||
# Fallback: treat response text as SSE lines
|
||||
for line in self.text.split("\n"):
|
||||
yield line
|
||||
|
||||
|
||||
class FakeClient:
|
||||
def __init__(self, response: FakeResponse) -> None:
|
||||
@ -515,3 +528,362 @@ async def test_stepfun_no_images_raises() -> None:
|
||||
|
||||
with pytest.raises(ImageGenerationError, match="returned no images"):
|
||||
await client.generate(prompt="draw", model="step-image-edit-2")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_payload_and_response() -> None:
|
||||
fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
|
||||
client = OpenAIImageGenerationClient(
|
||||
api_key="sk-openai-test",
|
||||
api_base="https://api.openai.com/v1",
|
||||
extra_headers={"X-Test": "1"},
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
response = await client.generate(
|
||||
prompt="a cat on the moon",
|
||||
model="dall-e-3",
|
||||
aspect_ratio="16:9",
|
||||
)
|
||||
|
||||
assert response.images == [PNG_DATA_URL]
|
||||
call = fake.calls[0]
|
||||
assert call["url"] == "https://api.openai.com/v1/images/generations"
|
||||
assert call["headers"]["Authorization"] == "Bearer sk-openai-test"
|
||||
assert call["headers"]["X-Test"] == "1"
|
||||
body = call["json"]
|
||||
assert body["model"] == "dall-e-3"
|
||||
assert body["prompt"] == "a cat on the moon"
|
||||
assert body["response_format"] == "b64_json"
|
||||
assert body["n"] == 1
|
||||
assert body["size"] == "1792x1024"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_b64_json_response_uses_detected_mime() -> None:
|
||||
raw_b64 = base64.b64encode(JPEG_BYTES).decode("ascii")
|
||||
fake = FakeClient(FakeResponse({"data": [{"b64_json": raw_b64}]}))
|
||||
client = OpenAIImageGenerationClient(
|
||||
api_key="sk-openai-test",
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
response = await client.generate(prompt="draw", model="dall-e-3")
|
||||
|
||||
assert response.images == [f"data:image/jpeg;base64,{raw_b64}"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_url_download_fallback() -> None:
|
||||
fake = FakeClient(FakeResponse({"data": [{"url": "https://cdn.example/image.png"}]}))
|
||||
fake.get_response = FakeResponse({}, content=PNG_BYTES)
|
||||
client = OpenAIImageGenerationClient(
|
||||
api_key="sk-openai-test",
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
response = await client.generate(prompt="draw", model="dall-e-3")
|
||||
|
||||
assert response.images[0].startswith("data:image/png;base64,")
|
||||
assert fake.get_calls[0]["url"] == "https://cdn.example/image.png"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_multiple_images() -> None:
|
||||
fake = FakeClient(FakeResponse({
|
||||
"data": [
|
||||
{"b64_json": RAW_B64},
|
||||
{"b64_json": RAW_B64},
|
||||
]
|
||||
}))
|
||||
client = OpenAIImageGenerationClient(
|
||||
api_key="sk-openai-test",
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
response = await client.generate(prompt="draw", model="dall-e-3")
|
||||
|
||||
assert len(response.images) == 2
|
||||
assert response.images == [PNG_DATA_URL, PNG_DATA_URL]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_aspect_ratio_to_size() -> None:
|
||||
fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
|
||||
client = OpenAIImageGenerationClient(
|
||||
api_key="sk-openai-test",
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
await client.generate(prompt="draw", model="dall-e-3", aspect_ratio="1:1")
|
||||
assert fake.calls[0]["json"]["size"] == "1024x1024"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_default_size_when_no_aspect_ratio() -> None:
|
||||
fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
|
||||
client = OpenAIImageGenerationClient(
|
||||
api_key="sk-openai-test",
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
await client.generate(prompt="draw", model="dall-e-3")
|
||||
|
||||
body = fake.calls[0]["json"]
|
||||
assert body["size"] == "1024x1024"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_uses_explicit_image_size() -> None:
|
||||
fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
|
||||
client = OpenAIImageGenerationClient(
|
||||
api_key="sk-openai-test",
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
await client.generate(
|
||||
prompt="draw",
|
||||
model="dall-e-3",
|
||||
aspect_ratio="16:9",
|
||||
image_size="1024x1024",
|
||||
)
|
||||
|
||||
body = fake.calls[0]["json"]
|
||||
assert body["size"] == "1024x1024"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_requires_api_key() -> None:
|
||||
client = OpenAIImageGenerationClient(api_key=None)
|
||||
|
||||
with pytest.raises(ImageGenerationError, match="API key"):
|
||||
await client.generate(prompt="draw", model="dall-e-3")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAI Codex (Responses API)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_codex_payload_and_response(monkeypatch) -> None:
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
|
||||
@dataclass
|
||||
class FakeToken:
|
||||
account_id: str = "acct-123"
|
||||
access: str = "oauth-token"
|
||||
|
||||
async def fake_to_thread(fn, *args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
|
||||
fake_oauth = SimpleNamespace(get_token=lambda: FakeToken())
|
||||
monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth)
|
||||
|
||||
sse_lines = [
|
||||
'data: {"type":"response.output_item.added","item":{"id":"ig_1","type":"image_generation_call","status":"in_progress"}}',
|
||||
"",
|
||||
f'data: {{"type":"response.output_item.done","item":{{"id":"ig_1","type":"image_generation_call","result":"{PNG_DATA_URL}","status":"completed"}}}}',
|
||||
"",
|
||||
'data: [DONE]',
|
||||
"",
|
||||
]
|
||||
fake = FakeClient(FakeResponse({}, sse_lines=sse_lines))
|
||||
client = CodexImageGenerationClient(
|
||||
api_key=None,
|
||||
api_base="https://chatgpt.com/backend-api",
|
||||
extra_headers={"X-Test": "1"},
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
response = await client.generate(
|
||||
prompt="draw a cat",
|
||||
model="gpt-5.4",
|
||||
)
|
||||
|
||||
assert response.images == [PNG_DATA_URL]
|
||||
assert response.content == ""
|
||||
call = fake.calls[0]
|
||||
assert call["url"] == "https://chatgpt.com/backend-api/codex/responses"
|
||||
assert call["headers"]["Authorization"] == "Bearer oauth-token"
|
||||
assert call["headers"]["chatgpt-account-id"] == "acct-123"
|
||||
assert call["headers"]["OpenAI-Beta"] == "responses=experimental"
|
||||
assert call["headers"]["X-Test"] == "1"
|
||||
body = call["json"]
|
||||
assert body["model"] == "gpt-5.4"
|
||||
assert body["instructions"] == "Generate an image based on the user's request."
|
||||
assert body["input"] == [{"role": "user", "content": "draw a cat"}]
|
||||
assert body["tools"] == [{"type": "image_generation"}]
|
||||
assert body["tool_choice"] == "auto"
|
||||
assert body["store"] is False
|
||||
assert body["stream"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_codex_strips_model_prefix(monkeypatch) -> None:
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
|
||||
@dataclass
|
||||
class FakeToken:
|
||||
account_id: str = "acct-123"
|
||||
access: str = "oauth-token"
|
||||
|
||||
async def fake_to_thread(fn, *args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
|
||||
fake_oauth = SimpleNamespace(get_token=lambda: FakeToken())
|
||||
monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth)
|
||||
|
||||
fake = FakeClient(FakeResponse({}, sse_lines=[
|
||||
f'data: {{"type":"response.output_item.done","item":{{"type":"image_generation_call","result":"{PNG_DATA_URL}"}}}}',
|
||||
"",
|
||||
'data: [DONE]',
|
||||
"",
|
||||
]))
|
||||
client = CodexImageGenerationClient(
|
||||
api_key=None, client=fake # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
await client.generate(prompt="draw", model="openai-codex/gpt-5.4")
|
||||
|
||||
assert fake.calls[0]["json"]["model"] == "gpt-5.4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_codex_requires_oauth(monkeypatch) -> None:
|
||||
async def fake_to_thread(fn, *args, **kwargs):
|
||||
raise RuntimeError("no token")
|
||||
|
||||
monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
|
||||
|
||||
client = CodexImageGenerationClient(api_key=None)
|
||||
|
||||
with pytest.raises(ImageGenerationError, match="OAuth token"):
|
||||
await client.generate(prompt="draw", model="gpt-5.4")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_codex_no_images_raises(monkeypatch) -> None:
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
|
||||
@dataclass
|
||||
class FakeToken:
|
||||
account_id: str = "acct-123"
|
||||
access: str = "oauth-token"
|
||||
|
||||
async def fake_to_thread(fn, *args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
|
||||
fake_oauth = SimpleNamespace(get_token=lambda: FakeToken())
|
||||
monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth)
|
||||
|
||||
fake = FakeClient(FakeResponse({}, sse_lines=[
|
||||
'data: {"type":"response.completed","response":{"status":"completed"}}',
|
||||
"",
|
||||
'data: [DONE]',
|
||||
"",
|
||||
]))
|
||||
client = CodexImageGenerationClient(
|
||||
api_key=None, client=fake # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
with pytest.raises(ImageGenerationError, match="returned no images"):
|
||||
await client.generate(prompt="draw", model="gpt-5.4")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_codex_extracts_text_content(monkeypatch) -> None:
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
|
||||
@dataclass
|
||||
class FakeToken:
|
||||
account_id: str = "acct-123"
|
||||
access: str = "oauth-token"
|
||||
|
||||
async def fake_to_thread(fn, *args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
|
||||
fake_oauth = SimpleNamespace(get_token=lambda: FakeToken())
|
||||
monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth)
|
||||
|
||||
fake = FakeClient(FakeResponse({}, sse_lines=[
|
||||
'data: {"type":"response.output_text.delta","delta":"Here "}',
|
||||
"",
|
||||
'data: {"type":"response.output_text.delta","delta":"is your cat image."}',
|
||||
"",
|
||||
f'data: {{"type":"response.output_item.done","item":{{"type":"image_generation_call","result":"{PNG_DATA_URL}"}}}}',
|
||||
"",
|
||||
'data: [DONE]',
|
||||
"",
|
||||
]))
|
||||
client = CodexImageGenerationClient(
|
||||
api_key=None, client=fake # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
response = await client.generate(prompt="draw a cat", model="gpt-5.4")
|
||||
|
||||
assert response.images == [PNG_DATA_URL]
|
||||
assert response.content == "Here is your cat image."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_codex_json_result_format(monkeypatch) -> None:
|
||||
"""image_generation_call result can be a dict with image_url key."""
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
|
||||
@dataclass
|
||||
class FakeToken:
|
||||
account_id: str = "acct-123"
|
||||
access: str = "oauth-token"
|
||||
|
||||
async def fake_to_thread(fn, *args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
|
||||
fake_oauth = SimpleNamespace(get_token=lambda: FakeToken())
|
||||
monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth)
|
||||
|
||||
fake = FakeClient(FakeResponse({}, sse_lines=[
|
||||
f'data: {{"type":"response.output_item.done","item":{{"type":"image_generation_call","result":{{"image_url":"{PNG_DATA_URL}"}}}}}}',
|
||||
"",
|
||||
'data: [DONE]',
|
||||
"",
|
||||
]))
|
||||
client = CodexImageGenerationClient(
|
||||
api_key=None, client=fake # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
response = await client.generate(prompt="draw", model="gpt-5.4")
|
||||
|
||||
assert response.images == [PNG_DATA_URL]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_no_images_raises() -> None:
|
||||
fake = FakeClient(FakeResponse({"data": []}))
|
||||
client = OpenAIImageGenerationClient(
|
||||
api_key="sk-openai-test",
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
with pytest.raises(ImageGenerationError, match="returned no images"):
|
||||
await client.generate(prompt="draw", model="dall-e-3")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user