mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
fix(transcription): retry Whisper calls and guard malformed responses
A single transient failure between the agent and an OpenAI/Groq Whisper
endpoint currently vanishes as `return ""` in transcribe(). The voice
message arrives as the empty string and there is no way to tell real
silence apart from a failed upload. A malformed but successful response
body is even worse: the JSON-decode error escapes the helper unhandled.
Add a shared `_post_transcription_with_retry` used by both providers.
Retry behaviour:
- exponential backoff 1s -> 2s -> 4s, up to 3 retries (4 attempts)
- retryable HTTP statuses: 408, 429, 500, 502, 503, 504
- retryable exceptions: TimeoutException, ConnectError, ReadError,
WriteError, RemoteProtocolError
Non-transient failures short-circuit to "" on the first attempt --
retrying a misconfigured key or a broken upload only burns rate-limit
quota. Branches that short-circuit:
- missing API key, missing audio file
- file-read errors (PermissionError, OSError) on the audio path,
preserving the nightly contract for direct provider callers
- HTTP auth/4xx body issues via raise_for_status()
- response.json() parse failures
- non-dict JSON payloads
Sharing one helper means OpenAI and Groq cannot drift apart silently.
Thread `language` through the helper. The multipart files dict is rebuilt
inside the per-attempt loop, so when a caller sets self.language the
`language` field is sent on every attempt -- not just the first.
Tests cover:
- every advertised retryable status and exception, parameterized
- language present on attempts 1 and 2 of a 503->200 sequence
- language absent when unset; present when set (both providers)
- malformed JSON body and non-dict JSON body short-circuit to ""
- PermissionError on file read short-circuits with no HTTP attempt
- max-attempts give-up, exponential-backoff schedule, auth no-retry,
missing-key / missing-file short-circuit
Test stub fix: the _StubResponse in tests/channels/test_channel_plugins.py
declared no status_code, which the new helper reads for retry classification.
Set status_code = 200 so the stub advertises the successful response that
those tests already simulate. Also moved the two transcription-provider
imports to the top of that file (previously placed mid-file) so the file
is ruff-clean (E402).
This commit is contained in:
parent
e54fbfeb2a
commit
7ebf611be8
@ -1,11 +1,123 @@
|
|||||||
"""Voice transcription providers (Groq and OpenAI Whisper)."""
|
"""Voice transcription providers (Groq and OpenAI Whisper)."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
# Up to 3 retries (4 attempts total) with exponential backoff on transient
|
||||||
|
# failures. Whisper endpoints occasionally return 502/503 under load, and
|
||||||
|
# mobile-network transcription callers hit sporadic connect/read errors.
|
||||||
|
# Without this, a voice message silently becomes the empty string.
|
||||||
|
_MAX_RETRIES = 3
|
||||||
|
_BACKOFF_S = (1.0, 2.0, 4.0)
|
||||||
|
_RETRYABLE_STATUS = {408, 429, 500, 502, 503, 504}
|
||||||
|
_RETRYABLE_EXCEPTIONS = (
|
||||||
|
httpx.TimeoutException,
|
||||||
|
httpx.ConnectError,
|
||||||
|
httpx.ReadError,
|
||||||
|
httpx.WriteError,
|
||||||
|
httpx.RemoteProtocolError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _post_transcription_with_retry(
|
||||||
|
url: str,
|
||||||
|
*,
|
||||||
|
api_key: str,
|
||||||
|
path: Path,
|
||||||
|
model: str,
|
||||||
|
provider_label: str,
|
||||||
|
language: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""POST an audio file for transcription, retrying on transient errors.
|
||||||
|
|
||||||
|
Retries on connect/read/timeout failures and on 408/429/5xx responses.
|
||||||
|
Other errors (including 4xx such as 401/403) return "" immediately — the
|
||||||
|
caller's config is wrong and retrying only wastes quota.
|
||||||
|
|
||||||
|
When ``language`` is provided, it is forwarded as the ``language``
|
||||||
|
multipart field on every attempt (the dict is rebuilt per attempt so the
|
||||||
|
same field is present on retries).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = path.read_bytes()
|
||||||
|
except OSError as e:
|
||||||
|
logger.error("{} transcription error: cannot read audio file: {}", provider_label, e)
|
||||||
|
return ""
|
||||||
|
headers = {"Authorization": f"Bearer {api_key}"}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
for attempt in range(_MAX_RETRIES + 1):
|
||||||
|
files = {
|
||||||
|
"file": (path.name, data),
|
||||||
|
"model": (None, model),
|
||||||
|
}
|
||||||
|
if language:
|
||||||
|
files["language"] = (None, language)
|
||||||
|
try:
|
||||||
|
response = await client.post(url, headers=headers, files=files, timeout=60.0)
|
||||||
|
except _RETRYABLE_EXCEPTIONS as e:
|
||||||
|
if attempt < _MAX_RETRIES:
|
||||||
|
logger.warning(
|
||||||
|
"{} transcription transient error (attempt {}/{}): {}",
|
||||||
|
provider_label,
|
||||||
|
attempt + 1,
|
||||||
|
_MAX_RETRIES + 1,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(_BACKOFF_S[attempt])
|
||||||
|
continue
|
||||||
|
logger.error(
|
||||||
|
"{} transcription error after {} attempts: {}",
|
||||||
|
provider_label,
|
||||||
|
_MAX_RETRIES + 1,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return ""
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("{} transcription error: {}", provider_label, e)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if response.status_code in _RETRYABLE_STATUS and attempt < _MAX_RETRIES:
|
||||||
|
logger.warning(
|
||||||
|
"{} transcription transient HTTP {} (attempt {}/{})",
|
||||||
|
provider_label,
|
||||||
|
response.status_code,
|
||||||
|
attempt + 1,
|
||||||
|
_MAX_RETRIES + 1,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(_BACKOFF_S[attempt])
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
response.raise_for_status()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("{} transcription error: {}", provider_label, e)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = response.json()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"{} transcription error: malformed response body: {}",
|
||||||
|
provider_label,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return ""
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
logger.error(
|
||||||
|
"{} transcription error: unexpected response shape: {!r}",
|
||||||
|
provider_label,
|
||||||
|
type(payload).__name__,
|
||||||
|
)
|
||||||
|
return ""
|
||||||
|
return payload.get("text", "")
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
class OpenAITranscriptionProvider:
|
class OpenAITranscriptionProvider:
|
||||||
"""Voice transcription provider using OpenAI's Whisper API."""
|
"""Voice transcription provider using OpenAI's Whisper API."""
|
||||||
@ -32,21 +144,14 @@ class OpenAITranscriptionProvider:
|
|||||||
if not path.exists():
|
if not path.exists():
|
||||||
logger.error("Audio file not found: {}", file_path)
|
logger.error("Audio file not found: {}", file_path)
|
||||||
return ""
|
return ""
|
||||||
try:
|
return await _post_transcription_with_retry(
|
||||||
async with httpx.AsyncClient() as client:
|
self.api_url,
|
||||||
with open(path, "rb") as f:
|
api_key=self.api_key,
|
||||||
files = {"file": (path.name, f), "model": (None, "whisper-1")}
|
path=path,
|
||||||
if self.language:
|
model="whisper-1",
|
||||||
files["language"] = (None, self.language)
|
provider_label="OpenAI",
|
||||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
language=self.language,
|
||||||
response = await client.post(
|
)
|
||||||
self.api_url, headers=headers, files=files, timeout=60.0,
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
return response.json().get("text", "")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("OpenAI transcription error: {}", e)
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
class GroqTranscriptionProvider:
|
class GroqTranscriptionProvider:
|
||||||
@ -63,7 +168,11 @@ class GroqTranscriptionProvider:
|
|||||||
language: str | None = None,
|
language: str | None = None,
|
||||||
):
|
):
|
||||||
self.api_key = api_key or os.environ.get("GROQ_API_KEY")
|
self.api_key = api_key or os.environ.get("GROQ_API_KEY")
|
||||||
self.api_url = api_base or os.environ.get("GROQ_BASE_URL") or "https://api.groq.com/openai/v1/audio/transcriptions"
|
self.api_url = (
|
||||||
|
api_base
|
||||||
|
or os.environ.get("GROQ_BASE_URL")
|
||||||
|
or "https://api.groq.com/openai/v1/audio/transcriptions"
|
||||||
|
)
|
||||||
self.language = language or None
|
self.language = language or None
|
||||||
|
|
||||||
async def transcribe(self, file_path: str | Path) -> str:
|
async def transcribe(self, file_path: str | Path) -> str:
|
||||||
@ -85,30 +194,11 @@ class GroqTranscriptionProvider:
|
|||||||
logger.error("Audio file not found: {}", file_path)
|
logger.error("Audio file not found: {}", file_path)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
try:
|
return await _post_transcription_with_retry(
|
||||||
async with httpx.AsyncClient() as client:
|
self.api_url,
|
||||||
with open(path, "rb") as f:
|
api_key=self.api_key,
|
||||||
files = {
|
path=path,
|
||||||
"file": (path.name, f),
|
model="whisper-large-v3",
|
||||||
"model": (None, "whisper-large-v3"),
|
provider_label="Groq",
|
||||||
}
|
language=self.language,
|
||||||
if self.language:
|
)
|
||||||
files["language"] = (None, self.language)
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
}
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
self.api_url,
|
|
||||||
headers=headers,
|
|
||||||
files=files,
|
|
||||||
timeout=60.0
|
|
||||||
)
|
|
||||||
|
|
||||||
response.raise_for_status()
|
|
||||||
data = response.json()
|
|
||||||
return data.get("text", "")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Groq transcription error: {}", e)
|
|
||||||
return ""
|
|
||||||
|
|||||||
@ -342,6 +342,8 @@ async def test_base_channel_passes_language_to_groq_transcription_provider():
|
|||||||
|
|
||||||
|
|
||||||
class _StubResponse:
|
class _StubResponse:
|
||||||
|
status_code = 200
|
||||||
|
|
||||||
def raise_for_status(self):
|
def raise_for_status(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
293
tests/providers/test_transcription.py
Normal file
293
tests/providers/test_transcription.py
Normal file
@ -0,0 +1,293 @@
|
|||||||
|
"""Tests for transcription retry behavior on transient errors (B10)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.providers.transcription import GroqTranscriptionProvider, OpenAITranscriptionProvider
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def audio_file(tmp_path: Path) -> Path:
|
||||||
|
p = tmp_path / "voice.ogg"
|
||||||
|
p.write_bytes(b"OggS\x00fake-audio-bytes")
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
def _response(status: int, payload: dict[str, object] | None = None) -> httpx.Response:
|
||||||
|
request = httpx.Request("POST", "https://example.test/audio/transcriptions")
|
||||||
|
return httpx.Response(status_code=status, json=payload or {}, request=request)
|
||||||
|
|
||||||
|
|
||||||
|
def _raw_response(status: int, content: bytes) -> httpx.Response:
|
||||||
|
"""Build a Response with a raw, possibly-malformed body (bypasses json= encoding)."""
|
||||||
|
request = httpx.Request("POST", "https://example.test/audio/transcriptions")
|
||||||
|
return httpx.Response(status_code=status, content=content, request=request)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# OpenAI provider — retry on transient HTTP + network errors
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_retries_on_5xx_then_succeeds(audio_file: Path) -> None:
|
||||||
|
"""Transient 503 is retried; a subsequent 200 yields the text."""
|
||||||
|
provider = OpenAITranscriptionProvider(api_key="sk-test")
|
||||||
|
post = AsyncMock(side_effect=[_response(503), _response(200, {"text": "hello"})])
|
||||||
|
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
|
||||||
|
result = await provider.transcribe(audio_file)
|
||||||
|
assert result == "hello"
|
||||||
|
assert post.await_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_retries_on_429_then_succeeds(audio_file: Path) -> None:
|
||||||
|
provider = OpenAITranscriptionProvider(api_key="sk-test")
|
||||||
|
post = AsyncMock(side_effect=[_response(429), _response(200, {"text": "rate ok"})])
|
||||||
|
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
|
||||||
|
result = await provider.transcribe(audio_file)
|
||||||
|
assert result == "rate ok"
|
||||||
|
assert post.await_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_retries_on_connect_error(audio_file: Path) -> None:
|
||||||
|
"""Network-level transient errors are retried."""
|
||||||
|
provider = OpenAITranscriptionProvider(api_key="sk-test")
|
||||||
|
post = AsyncMock(side_effect=[httpx.ConnectError("boom"), _response(200, {"text": "ok"})])
|
||||||
|
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
|
||||||
|
result = await provider.transcribe(audio_file)
|
||||||
|
assert result == "ok"
|
||||||
|
assert post.await_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_does_not_retry_on_auth_error(audio_file: Path) -> None:
|
||||||
|
"""401 is the user's misconfiguration — retrying wastes time and rate-limit quota."""
|
||||||
|
provider = OpenAITranscriptionProvider(api_key="sk-test")
|
||||||
|
post = AsyncMock(return_value=_response(401, {"error": {"message": "bad key"}}))
|
||||||
|
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
|
||||||
|
result = await provider.transcribe(audio_file)
|
||||||
|
assert result == ""
|
||||||
|
assert post.await_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_gives_up_after_max_attempts(audio_file: Path) -> None:
|
||||||
|
"""Persistent 503 returns "" after the final retry — never hangs."""
|
||||||
|
provider = OpenAITranscriptionProvider(api_key="sk-test")
|
||||||
|
post = AsyncMock(return_value=_response(503))
|
||||||
|
sleep = AsyncMock()
|
||||||
|
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", sleep):
|
||||||
|
result = await provider.transcribe(audio_file)
|
||||||
|
assert result == ""
|
||||||
|
# 4 attempts total (initial + 3 retries) with 3 sleeps between them.
|
||||||
|
assert post.await_count == 4
|
||||||
|
assert sleep.await_count == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_backoff_grows_exponentially(audio_file: Path) -> None:
|
||||||
|
"""Verify the backoff schedule is exponential (1s, 2s, 4s)."""
|
||||||
|
provider = OpenAITranscriptionProvider(api_key="sk-test")
|
||||||
|
post = AsyncMock(return_value=_response(503))
|
||||||
|
sleep = AsyncMock()
|
||||||
|
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", sleep):
|
||||||
|
await provider.transcribe(audio_file)
|
||||||
|
delays = [call.args[0] for call in sleep.await_args_list]
|
||||||
|
assert delays == [1.0, 2.0, 4.0]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Groq provider — same semantics (both go through the shared helper)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_groq_retries_on_5xx_then_succeeds(audio_file: Path) -> None:
|
||||||
|
provider = GroqTranscriptionProvider(api_key="gsk-test")
|
||||||
|
post = AsyncMock(side_effect=[_response(502), _response(200, {"text": "groq ok"})])
|
||||||
|
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
|
||||||
|
result = await provider.transcribe(audio_file)
|
||||||
|
assert result == "groq ok"
|
||||||
|
assert post.await_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_groq_does_not_retry_on_auth_error(audio_file: Path) -> None:
|
||||||
|
provider = GroqTranscriptionProvider(api_key="gsk-test")
|
||||||
|
post = AsyncMock(return_value=_response(403))
|
||||||
|
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
|
||||||
|
result = await provider.transcribe(audio_file)
|
||||||
|
assert result == ""
|
||||||
|
assert post.await_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Regression: missing file / missing key must still short-circuit
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_missing_api_key_short_circuits(tmp_path: Path) -> None:
|
||||||
|
provider = OpenAITranscriptionProvider(api_key=None)
|
||||||
|
# Ensure env var doesn't accidentally satisfy it.
|
||||||
|
with patch.dict("os.environ", {}, clear=True):
|
||||||
|
provider = OpenAITranscriptionProvider(api_key=None)
|
||||||
|
post = AsyncMock()
|
||||||
|
with patch("httpx.AsyncClient.post", post):
|
||||||
|
assert await provider.transcribe(tmp_path / "voice.ogg") == ""
|
||||||
|
assert post.await_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_missing_file_short_circuits() -> None:
|
||||||
|
provider = OpenAITranscriptionProvider(api_key="sk-test")
|
||||||
|
post = AsyncMock()
|
||||||
|
with patch("httpx.AsyncClient.post", post):
|
||||||
|
assert await provider.transcribe("/nonexistent/path/voice.ogg") == ""
|
||||||
|
assert post.await_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_empty_when_file_unreadable(audio_file: Path) -> None:
|
||||||
|
"""Existing file that cannot be read (PermissionError/OSError): "" with no HTTP attempt."""
|
||||||
|
provider = OpenAITranscriptionProvider(api_key="sk-test")
|
||||||
|
post = AsyncMock()
|
||||||
|
with patch.object(Path, "read_bytes", side_effect=PermissionError("denied")), patch(
|
||||||
|
"httpx.AsyncClient.post", post
|
||||||
|
):
|
||||||
|
result = await provider.transcribe(audio_file)
|
||||||
|
assert result == ""
|
||||||
|
assert post.await_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# language: forwarded through the helper to the multipart body, on every attempt
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"provider_cls,language",
|
||||||
|
[(OpenAITranscriptionProvider, "en"), (GroqTranscriptionProvider, "ko")],
|
||||||
|
ids=["openai", "groq"],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_provider_forwards_language_in_multipart(
|
||||||
|
audio_file: Path, provider_cls: type, language: str
|
||||||
|
) -> None:
|
||||||
|
"""When ``language`` is set, the helper sends it as a multipart field."""
|
||||||
|
provider = provider_cls(api_key="k", language=language)
|
||||||
|
post = AsyncMock(return_value=_response(200, {"text": "ok"}))
|
||||||
|
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
|
||||||
|
result = await provider.transcribe(audio_file)
|
||||||
|
assert result == "ok"
|
||||||
|
assert post.await_count == 1
|
||||||
|
files = post.await_args_list[0].kwargs["files"]
|
||||||
|
assert files["language"] == (None, language)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"provider_cls",
|
||||||
|
[OpenAITranscriptionProvider, GroqTranscriptionProvider],
|
||||||
|
ids=["openai", "groq"],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_provider_omits_language_when_unset(
|
||||||
|
audio_file: Path, provider_cls: type
|
||||||
|
) -> None:
|
||||||
|
"""When ``language`` is None, no ``language`` field is sent."""
|
||||||
|
provider = provider_cls(api_key="k")
|
||||||
|
post = AsyncMock(return_value=_response(200, {"text": "ok"}))
|
||||||
|
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
|
||||||
|
result = await provider.transcribe(audio_file)
|
||||||
|
assert result == "ok"
|
||||||
|
assert post.await_count == 1
|
||||||
|
files = post.await_args_list[0].kwargs["files"]
|
||||||
|
assert "language" not in files
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_language_survives_retry(audio_file: Path) -> None:
|
||||||
|
"""Regression: language must be present on every retry attempt, not just the first."""
|
||||||
|
provider = OpenAITranscriptionProvider(api_key="sk-test", language="ja")
|
||||||
|
post = AsyncMock(side_effect=[_response(503), _response(200, {"text": "konnichiwa"})])
|
||||||
|
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
|
||||||
|
result = await provider.transcribe(audio_file)
|
||||||
|
assert result == "konnichiwa"
|
||||||
|
assert post.await_count == 2
|
||||||
|
for call in post.await_args_list:
|
||||||
|
assert call.kwargs["files"]["language"] == (None, "ja")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Malformed / unexpected response bodies must short-circuit, not escape
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_empty_on_malformed_json_body(audio_file: Path) -> None:
|
||||||
|
"""200 with invalid JSON: log and return "" immediately (no retry, no exception)."""
|
||||||
|
provider = OpenAITranscriptionProvider(api_key="sk-test")
|
||||||
|
post = AsyncMock(return_value=_raw_response(200, b"<html>not json</html>"))
|
||||||
|
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
|
||||||
|
result = await provider.transcribe(audio_file)
|
||||||
|
assert result == ""
|
||||||
|
assert post.await_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_empty_on_non_dict_json_body(audio_file: Path) -> None:
|
||||||
|
"""200 with a JSON array (not dict): no AttributeError leak; return "" immediately."""
|
||||||
|
provider = OpenAITranscriptionProvider(api_key="sk-test")
|
||||||
|
post = AsyncMock(return_value=_raw_response(200, b"[]"))
|
||||||
|
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
|
||||||
|
result = await provider.transcribe(audio_file)
|
||||||
|
assert result == ""
|
||||||
|
assert post.await_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pin the full advertised retry contract: all retryable statuses + exceptions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("status", [408, 429, 500, 502, 503, 504])
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retries_on_every_advertised_transient_status(
|
||||||
|
audio_file: Path, status: int
|
||||||
|
) -> None:
|
||||||
|
provider = OpenAITranscriptionProvider(api_key="sk-test")
|
||||||
|
post = AsyncMock(side_effect=[_response(status), _response(200, {"text": "ok"})])
|
||||||
|
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
|
||||||
|
result = await provider.transcribe(audio_file)
|
||||||
|
assert result == "ok"
|
||||||
|
assert post.await_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"exc",
|
||||||
|
[
|
||||||
|
httpx.TimeoutException("t"),
|
||||||
|
httpx.ConnectError("c"),
|
||||||
|
httpx.ReadError("r"),
|
||||||
|
httpx.WriteError("w"),
|
||||||
|
httpx.RemoteProtocolError("p"),
|
||||||
|
],
|
||||||
|
ids=["timeout", "connect", "read", "write", "remote_protocol"],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retries_on_every_advertised_transient_exception(
|
||||||
|
audio_file: Path, exc: Exception
|
||||||
|
) -> None:
|
||||||
|
provider = OpenAITranscriptionProvider(api_key="sk-test")
|
||||||
|
post = AsyncMock(side_effect=[exc, _response(200, {"text": "recovered"})])
|
||||||
|
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
|
||||||
|
result = await provider.transcribe(audio_file)
|
||||||
|
assert result == "recovered"
|
||||||
|
assert post.await_count == 2
|
||||||
Loading…
x
Reference in New Issue
Block a user