mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-24 10:32:45 +00:00
refactor image provider HTTP handling
This commit is contained in:
parent
e6587a8d8e
commit
72f999f8f7
@ -219,7 +219,10 @@ class ImageGenerationProvider(ABC):
|
|||||||
*,
|
*,
|
||||||
headers: dict[str, str],
|
headers: dict[str, str],
|
||||||
body: dict[str, Any],
|
body: dict[str, Any],
|
||||||
|
client: httpx.AsyncClient | None = None,
|
||||||
) -> httpx.Response:
|
) -> httpx.Response:
|
||||||
|
if client is not None:
|
||||||
|
return await client.post(url, headers=headers, json=body)
|
||||||
if self._client is not None:
|
if self._client is not None:
|
||||||
return await self._client.post(url, headers=headers, json=body)
|
return await self._client.post(url, headers=headers, json=body)
|
||||||
async with httpx.AsyncClient(timeout=self.timeout) as c:
|
async with httpx.AsyncClient(timeout=self.timeout) as c:
|
||||||
@ -390,10 +393,11 @@ class AIHubMixImageGenerationClient(ImageGenerationProvider):
|
|||||||
model_path = _aihubmix_model_path(model)
|
model_path = _aihubmix_model_path(model)
|
||||||
url = f"{self.api_base}/models/{model_path}/predictions"
|
url = f"{self.api_base}/models/{model_path}/predictions"
|
||||||
try:
|
try:
|
||||||
response = await client.post(
|
response = await self._http_post(
|
||||||
url,
|
url,
|
||||||
headers={**headers, "Content-Type": "application/json"},
|
headers={**headers, "Content-Type": "application/json"},
|
||||||
json=body,
|
body=body,
|
||||||
|
client=client,
|
||||||
)
|
)
|
||||||
except httpx.TimeoutException as exc:
|
except httpx.TimeoutException as exc:
|
||||||
raise ImageGenerationError("AIHubMix image generation timed out") from exc
|
raise ImageGenerationError("AIHubMix image generation timed out") from exc
|
||||||
@ -706,22 +710,16 @@ class MiniMaxImageGenerationClient(ImageGenerationProvider):
|
|||||||
|
|
||||||
body.update(self.extra_body)
|
body.update(self.extra_body)
|
||||||
|
|
||||||
client = self._client or httpx.AsyncClient(timeout=self.timeout)
|
return await self._generate_with_client(body, headers)
|
||||||
try:
|
|
||||||
return await self._generate_with_client(client, body, headers)
|
|
||||||
finally:
|
|
||||||
if self._client is None:
|
|
||||||
await client.aclose()
|
|
||||||
|
|
||||||
async def _generate_with_client(
|
async def _generate_with_client(
|
||||||
self,
|
self,
|
||||||
client: httpx.AsyncClient,
|
|
||||||
body: dict[str, Any],
|
body: dict[str, Any],
|
||||||
headers: dict[str, str],
|
headers: dict[str, str],
|
||||||
) -> GeneratedImageResponse:
|
) -> GeneratedImageResponse:
|
||||||
url = f"{self.api_base}/image_generation"
|
url = f"{self.api_base}/image_generation"
|
||||||
try:
|
try:
|
||||||
response = await client.post(url, headers=headers, json=body)
|
response = await self._http_post(url, headers=headers, body=body)
|
||||||
except httpx.TimeoutException as exc:
|
except httpx.TimeoutException as exc:
|
||||||
raise ImageGenerationError("MiniMax image generation timed out") from exc
|
raise ImageGenerationError("MiniMax image generation timed out") from exc
|
||||||
except httpx.RequestError as exc:
|
except httpx.RequestError as exc:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user