Merge PR #3646: fix(transcription): retry Whisper calls on transient failures

fix(transcription): retry Whisper calls on transient failures
This commit is contained in:
Xubin Ren 2026-05-06 21:52:33 +08:00 committed by GitHub
commit d8fd4c80bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 425 additions and 43 deletions

View File

@ -1,11 +1,121 @@
"""Voice transcription providers (Groq and OpenAI Whisper).""" """Voice transcription providers (Groq and OpenAI Whisper)."""
import asyncio
import os import os
from pathlib import Path from pathlib import Path
import httpx import httpx
from loguru import logger from loguru import logger
# Up to 3 retries (4 attempts total) with exponential backoff on transient
# failures. Whisper endpoints occasionally return 502/503 under load, and
# mobile-network transcription callers hit sporadic connect/read errors.
# Without this, a voice message silently becomes the empty string.
_MAX_RETRIES = 3
_BACKOFF_S = (1.0, 2.0, 4.0)
_RETRYABLE_STATUS = {408, 429, 500, 502, 503, 504}
_RETRYABLE_EXCEPTIONS = (
httpx.TimeoutException,
httpx.ConnectError,
httpx.ReadError,
httpx.WriteError,
httpx.RemoteProtocolError,
)
async def _post_transcription_with_retry(
url: str,
*,
api_key: str | None,
path: Path,
model: str,
provider_label: str,
language: str | None = None,
) -> str:
"""POST an audio file for transcription, retrying on transient errors.
Retries on connect/read/timeout failures and on 408/429/5xx responses.
Other errors (including 4xx such as 401/403) return "" immediately the
caller's config is wrong and retrying only wastes quota.
When ``language`` is provided, it is forwarded as the ``language``
multipart field on every attempt (the dict is rebuilt per attempt so the
same field is present on retries).
"""
try:
data = path.read_bytes()
except OSError as e:
logger.error("{} transcription error: cannot read audio file: {}", provider_label, e)
return ""
headers = {"Authorization": f"Bearer {api_key}"}
async with httpx.AsyncClient() as client:
for attempt in range(_MAX_RETRIES + 1):
files = {
"file": (path.name, data),
"model": (None, model),
}
if language:
files["language"] = (None, language)
try:
response = await client.post(url, headers=headers, files=files, timeout=60.0)
except _RETRYABLE_EXCEPTIONS as e:
if attempt < _MAX_RETRIES:
logger.warning(
"{} transcription transient error (attempt {}/{}): {}",
provider_label,
attempt + 1,
_MAX_RETRIES + 1,
e,
)
await asyncio.sleep(_BACKOFF_S[attempt])
continue
logger.error(
"{} transcription error after {} attempts: {}",
provider_label,
_MAX_RETRIES + 1,
e,
)
return ""
except Exception as e:
logger.error("{} transcription error: {}", provider_label, e)
return ""
if response.status_code in _RETRYABLE_STATUS and attempt < _MAX_RETRIES:
logger.warning(
"{} transcription transient HTTP {} (attempt {}/{})",
provider_label,
response.status_code,
attempt + 1,
_MAX_RETRIES + 1,
)
await asyncio.sleep(_BACKOFF_S[attempt])
continue
try:
response.raise_for_status()
except Exception as e:
logger.error("{} transcription error: {}", provider_label, e)
return ""
try:
payload = response.json()
except Exception as e:
logger.error(
"{} transcription error: malformed response body: {}",
provider_label,
e,
)
return ""
if not isinstance(payload, dict):
logger.error(
"{} transcription error: unexpected response shape: {!r}",
provider_label,
type(payload).__name__,
)
return ""
return payload.get("text", "")
class OpenAITranscriptionProvider: class OpenAITranscriptionProvider:
"""Voice transcription provider using OpenAI's Whisper API.""" """Voice transcription provider using OpenAI's Whisper API."""
@ -32,21 +142,14 @@ class OpenAITranscriptionProvider:
if not path.exists(): if not path.exists():
logger.error("Audio file not found: {}", file_path) logger.error("Audio file not found: {}", file_path)
return "" return ""
try: return await _post_transcription_with_retry(
async with httpx.AsyncClient() as client: self.api_url,
with open(path, "rb") as f: api_key=self.api_key,
files = {"file": (path.name, f), "model": (None, "whisper-1")} path=path,
if self.language: model="whisper-1",
files["language"] = (None, self.language) provider_label="OpenAI",
headers = {"Authorization": f"Bearer {self.api_key}"} language=self.language,
response = await client.post( )
self.api_url, headers=headers, files=files, timeout=60.0,
)
response.raise_for_status()
return response.json().get("text", "")
except Exception:
logger.exception("OpenAI transcription error")
return ""
class GroqTranscriptionProvider: class GroqTranscriptionProvider:
@ -63,7 +166,11 @@ class GroqTranscriptionProvider:
language: str | None = None, language: str | None = None,
): ):
self.api_key = api_key or os.environ.get("GROQ_API_KEY") self.api_key = api_key or os.environ.get("GROQ_API_KEY")
self.api_url = api_base or os.environ.get("GROQ_BASE_URL") or "https://api.groq.com/openai/v1/audio/transcriptions" self.api_url = (
api_base
or os.environ.get("GROQ_BASE_URL")
or "https://api.groq.com/openai/v1/audio/transcriptions"
)
self.language = language or None self.language = language or None
async def transcribe(self, file_path: str | Path) -> str: async def transcribe(self, file_path: str | Path) -> str:
@ -85,30 +192,11 @@ class GroqTranscriptionProvider:
logger.error("Audio file not found: {}", file_path) logger.error("Audio file not found: {}", file_path)
return "" return ""
try: return await _post_transcription_with_retry(
async with httpx.AsyncClient() as client: self.api_url,
with open(path, "rb") as f: api_key=self.api_key,
files = { path=path,
"file": (path.name, f), model="whisper-large-v3",
"model": (None, "whisper-large-v3"), provider_label="Groq",
} language=self.language,
if self.language: )
files["language"] = (None, self.language)
headers = {
"Authorization": f"Bearer {self.api_key}",
}
response = await client.post(
self.api_url,
headers=headers,
files=files,
timeout=60.0
)
response.raise_for_status()
data = response.json()
return data.get("text", "")
except Exception:
logger.exception("Groq transcription error")
return ""

View File

@ -342,6 +342,8 @@ async def test_base_channel_passes_language_to_groq_transcription_provider():
class _StubResponse: class _StubResponse:
status_code = 200
def raise_for_status(self): def raise_for_status(self):
return None return None

View File

@ -0,0 +1,292 @@
"""Tests for transcription retry behavior on transient errors (B10)."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import AsyncMock, patch
import httpx
import pytest
from nanobot.providers.transcription import GroqTranscriptionProvider, OpenAITranscriptionProvider
@pytest.fixture
def audio_file(tmp_path: Path) -> Path:
p = tmp_path / "voice.ogg"
p.write_bytes(b"OggS\x00fake-audio-bytes")
return p
def _response(status: int, payload: dict[str, object] | None = None) -> httpx.Response:
request = httpx.Request("POST", "https://example.test/audio/transcriptions")
return httpx.Response(status_code=status, json=payload or {}, request=request)
def _raw_response(status: int, content: bytes) -> httpx.Response:
"""Build a Response with a raw, possibly-malformed body (bypasses json= encoding)."""
request = httpx.Request("POST", "https://example.test/audio/transcriptions")
return httpx.Response(status_code=status, content=content, request=request)
# ---------------------------------------------------------------------------
# OpenAI provider — retry on transient HTTP + network errors
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_openai_retries_on_5xx_then_succeeds(audio_file: Path) -> None:
"""Transient 503 is retried; a subsequent 200 yields the text."""
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock(side_effect=[_response(503), _response(200, {"text": "hello"})])
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == "hello"
assert post.await_count == 2
@pytest.mark.asyncio
async def test_openai_retries_on_429_then_succeeds(audio_file: Path) -> None:
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock(side_effect=[_response(429), _response(200, {"text": "rate ok"})])
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == "rate ok"
assert post.await_count == 2
@pytest.mark.asyncio
async def test_openai_retries_on_connect_error(audio_file: Path) -> None:
"""Network-level transient errors are retried."""
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock(side_effect=[httpx.ConnectError("boom"), _response(200, {"text": "ok"})])
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == "ok"
assert post.await_count == 2
@pytest.mark.asyncio
async def test_openai_does_not_retry_on_auth_error(audio_file: Path) -> None:
"""401 is the user's misconfiguration — retrying wastes time and rate-limit quota."""
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock(return_value=_response(401, {"error": {"message": "bad key"}}))
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == ""
assert post.await_count == 1
@pytest.mark.asyncio
async def test_openai_gives_up_after_max_attempts(audio_file: Path) -> None:
"""Persistent 503 returns "" after the final retry — never hangs."""
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock(return_value=_response(503))
sleep = AsyncMock()
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", sleep):
result = await provider.transcribe(audio_file)
assert result == ""
# 4 attempts total (initial + 3 retries) with 3 sleeps between them.
assert post.await_count == 4
assert sleep.await_count == 3
@pytest.mark.asyncio
async def test_openai_backoff_grows_exponentially(audio_file: Path) -> None:
"""Verify the backoff schedule is exponential (1s, 2s, 4s)."""
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock(return_value=_response(503))
sleep = AsyncMock()
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", sleep):
await provider.transcribe(audio_file)
delays = [call.args[0] for call in sleep.await_args_list]
assert delays == [1.0, 2.0, 4.0]
# ---------------------------------------------------------------------------
# Groq provider — same semantics (both go through the shared helper)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_groq_retries_on_5xx_then_succeeds(audio_file: Path) -> None:
provider = GroqTranscriptionProvider(api_key="gsk-test")
post = AsyncMock(side_effect=[_response(502), _response(200, {"text": "groq ok"})])
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == "groq ok"
assert post.await_count == 2
@pytest.mark.asyncio
async def test_groq_does_not_retry_on_auth_error(audio_file: Path) -> None:
provider = GroqTranscriptionProvider(api_key="gsk-test")
post = AsyncMock(return_value=_response(403))
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == ""
assert post.await_count == 1
# ---------------------------------------------------------------------------
# Regression: missing file / missing key must still short-circuit
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_openai_missing_api_key_short_circuits(audio_file: Path) -> None:
"""Missing API key short-circuits before any HTTP call, even when the file exists."""
with patch.dict("os.environ", {}, clear=True):
provider = OpenAITranscriptionProvider(api_key=None)
post = AsyncMock()
with patch("httpx.AsyncClient.post", post):
assert await provider.transcribe(audio_file) == ""
assert post.await_count == 0
@pytest.mark.asyncio
async def test_openai_missing_file_short_circuits() -> None:
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock()
with patch("httpx.AsyncClient.post", post):
assert await provider.transcribe("/nonexistent/path/voice.ogg") == ""
assert post.await_count == 0
@pytest.mark.asyncio
async def test_returns_empty_when_file_unreadable(audio_file: Path) -> None:
"""Existing file that cannot be read (PermissionError/OSError): "" with no HTTP attempt."""
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock()
with patch("pathlib.Path.read_bytes", side_effect=PermissionError("denied")), patch(
"httpx.AsyncClient.post", post
):
result = await provider.transcribe(audio_file)
assert result == ""
assert post.await_count == 0
# ---------------------------------------------------------------------------
# language: forwarded through the helper to the multipart body, on every attempt
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"provider_cls,language",
[(OpenAITranscriptionProvider, "en"), (GroqTranscriptionProvider, "ko")],
ids=["openai", "groq"],
)
@pytest.mark.asyncio
async def test_provider_forwards_language_in_multipart(
audio_file: Path, provider_cls: type, language: str
) -> None:
"""When ``language`` is set, the helper sends it as a multipart field."""
provider = provider_cls(api_key="k", language=language)
post = AsyncMock(return_value=_response(200, {"text": "ok"}))
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == "ok"
assert post.await_count == 1
files = post.await_args_list[0].kwargs["files"]
assert files["language"] == (None, language)
@pytest.mark.parametrize(
"provider_cls",
[OpenAITranscriptionProvider, GroqTranscriptionProvider],
ids=["openai", "groq"],
)
@pytest.mark.asyncio
async def test_provider_omits_language_when_unset(
audio_file: Path, provider_cls: type
) -> None:
"""When ``language`` is None, no ``language`` field is sent."""
provider = provider_cls(api_key="k")
post = AsyncMock(return_value=_response(200, {"text": "ok"}))
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == "ok"
assert post.await_count == 1
files = post.await_args_list[0].kwargs["files"]
assert "language" not in files
@pytest.mark.asyncio
async def test_language_survives_retry(audio_file: Path) -> None:
"""Regression: language must be present on every retry attempt, not just the first."""
provider = OpenAITranscriptionProvider(api_key="sk-test", language="ja")
post = AsyncMock(side_effect=[_response(503), _response(200, {"text": "konnichiwa"})])
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == "konnichiwa"
assert post.await_count == 2
for call in post.await_args_list:
assert call.kwargs["files"]["language"] == (None, "ja")
# ---------------------------------------------------------------------------
# Malformed / unexpected response bodies must short-circuit, not escape
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_returns_empty_on_malformed_json_body(audio_file: Path) -> None:
"""200 with invalid JSON: log and return "" immediately (no retry, no exception)."""
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock(return_value=_raw_response(200, b"<html>not json</html>"))
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == ""
assert post.await_count == 1
@pytest.mark.asyncio
async def test_returns_empty_on_non_dict_json_body(audio_file: Path) -> None:
"""200 with a JSON array (not dict): no AttributeError leak; return "" immediately."""
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock(return_value=_raw_response(200, b"[]"))
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == ""
assert post.await_count == 1
# ---------------------------------------------------------------------------
# Pin the full advertised retry contract: all retryable statuses + exceptions
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("status", [408, 429, 500, 502, 503, 504])
@pytest.mark.asyncio
async def test_retries_on_every_advertised_transient_status(
audio_file: Path, status: int
) -> None:
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock(side_effect=[_response(status), _response(200, {"text": "ok"})])
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == "ok"
assert post.await_count == 2
@pytest.mark.parametrize(
"exc",
[
httpx.TimeoutException("t"),
httpx.ConnectError("c"),
httpx.ReadError("r"),
httpx.WriteError("w"),
httpx.RemoteProtocolError("p"),
],
ids=["timeout", "connect", "read", "write", "remote_protocol"],
)
@pytest.mark.asyncio
async def test_retries_on_every_advertised_transient_exception(
audio_file: Path, exc: Exception
) -> None:
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock(side_effect=[exc, _response(200, {"text": "recovered"})])
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == "recovered"
assert post.await_count == 2