fix(transcription): retry Whisper calls and guard malformed responses

A single transient failure between the agent and an OpenAI/Groq Whisper
endpoint currently vanishes as `return ""` in transcribe(). The voice
message arrives as the empty string and there is no way to tell real
silence apart from a failed upload. A malformed but successful response
body is even worse: the JSON-decode error escapes the helper unhandled.

Add a shared `_post_transcription_with_retry` used by both providers.

Retry behaviour:
  - exponential backoff 1s -> 2s -> 4s, up to 3 retries (4 attempts)
  - retryable HTTP statuses: 408, 429, 500, 502, 503, 504
  - retryable exceptions: TimeoutException, ConnectError, ReadError,
    WriteError, RemoteProtocolError

Non-transient failures short-circuit to "" on the first attempt --
retrying a misconfigured key or a broken upload only burns rate-limit
quota. Branches that short-circuit:
  - missing API key, missing audio file
  - file-read errors (PermissionError, OSError) on the audio path,
    preserving the nightly contract for direct provider callers
  - HTTP auth/4xx body issues via raise_for_status()
  - response.json() parse failures
  - non-dict JSON payloads

Sharing one helper means OpenAI and Groq cannot drift apart silently.

Thread `language` through the helper. The multipart files dict is rebuilt
inside the per-attempt loop, so when a caller sets self.language the
`language` field is sent on every attempt -- not just the first.

Tests cover:
  - every advertised retryable status and exception, parameterized
  - language present on attempts 1 and 2 of a 503->200 sequence
  - language absent when unset; present when set (both providers)
  - malformed JSON body and non-dict JSON body short-circuit to ""
  - PermissionError on file read short-circuits with no HTTP attempt
  - max-attempts give-up, exponential-backoff schedule, auth no-retry,
    missing-key / missing-file short-circuit

Test stub fix: the _StubResponse in tests/channels/test_channel_plugins.py
declared no status_code, which the new helper reads for retry classification.
Set status_code = 200 so the stub advertises the successful response that
those tests already simulate. Also moved the two transcription-provider
imports to the top of that file (previously placed mid-file) so the file
is ruff-clean (E402).
This commit is contained in:
mohamed-elkholy95 2026-04-25 17:45:36 -04:00 committed by chengyongru
parent e54fbfeb2a
commit 7ebf611be8
3 changed files with 428 additions and 43 deletions

View File

@ -1,11 +1,123 @@
"""Voice transcription providers (Groq and OpenAI Whisper).""" """Voice transcription providers (Groq and OpenAI Whisper)."""
import asyncio
import os import os
from pathlib import Path from pathlib import Path
import httpx import httpx
from loguru import logger from loguru import logger
# Up to 3 retries (4 attempts total) with exponential backoff on transient
# failures. Whisper endpoints occasionally return 502/503 under load, and
# mobile-network transcription callers hit sporadic connect/read errors.
# Without this, a voice message silently becomes the empty string.
_MAX_RETRIES = 3
_BACKOFF_S = (1.0, 2.0, 4.0)
_RETRYABLE_STATUS = {408, 429, 500, 502, 503, 504}
_RETRYABLE_EXCEPTIONS = (
httpx.TimeoutException,
httpx.ConnectError,
httpx.ReadError,
httpx.WriteError,
httpx.RemoteProtocolError,
)
async def _post_transcription_with_retry(
url: str,
*,
api_key: str,
path: Path,
model: str,
provider_label: str,
language: str | None = None,
) -> str:
"""POST an audio file for transcription, retrying on transient errors.
Retries on connect/read/timeout failures and on 408/429/5xx responses.
Other errors (including 4xx such as 401/403) return "" immediately the
caller's config is wrong and retrying only wastes quota.
When ``language`` is provided, it is forwarded as the ``language``
multipart field on every attempt (the dict is rebuilt per attempt so the
same field is present on retries).
"""
try:
data = path.read_bytes()
except OSError as e:
logger.error("{} transcription error: cannot read audio file: {}", provider_label, e)
return ""
headers = {"Authorization": f"Bearer {api_key}"}
async with httpx.AsyncClient() as client:
for attempt in range(_MAX_RETRIES + 1):
files = {
"file": (path.name, data),
"model": (None, model),
}
if language:
files["language"] = (None, language)
try:
response = await client.post(url, headers=headers, files=files, timeout=60.0)
except _RETRYABLE_EXCEPTIONS as e:
if attempt < _MAX_RETRIES:
logger.warning(
"{} transcription transient error (attempt {}/{}): {}",
provider_label,
attempt + 1,
_MAX_RETRIES + 1,
e,
)
await asyncio.sleep(_BACKOFF_S[attempt])
continue
logger.error(
"{} transcription error after {} attempts: {}",
provider_label,
_MAX_RETRIES + 1,
e,
)
return ""
except Exception as e:
logger.error("{} transcription error: {}", provider_label, e)
return ""
if response.status_code in _RETRYABLE_STATUS and attempt < _MAX_RETRIES:
logger.warning(
"{} transcription transient HTTP {} (attempt {}/{})",
provider_label,
response.status_code,
attempt + 1,
_MAX_RETRIES + 1,
)
await asyncio.sleep(_BACKOFF_S[attempt])
continue
try:
response.raise_for_status()
except Exception as e:
logger.error("{} transcription error: {}", provider_label, e)
return ""
try:
payload = response.json()
except Exception as e:
logger.error(
"{} transcription error: malformed response body: {}",
provider_label,
e,
)
return ""
if not isinstance(payload, dict):
logger.error(
"{} transcription error: unexpected response shape: {!r}",
provider_label,
type(payload).__name__,
)
return ""
return payload.get("text", "")
return ""
class OpenAITranscriptionProvider: class OpenAITranscriptionProvider:
"""Voice transcription provider using OpenAI's Whisper API.""" """Voice transcription provider using OpenAI's Whisper API."""
@ -32,21 +144,14 @@ class OpenAITranscriptionProvider:
if not path.exists(): if not path.exists():
logger.error("Audio file not found: {}", file_path) logger.error("Audio file not found: {}", file_path)
return "" return ""
try: return await _post_transcription_with_retry(
async with httpx.AsyncClient() as client: self.api_url,
with open(path, "rb") as f: api_key=self.api_key,
files = {"file": (path.name, f), "model": (None, "whisper-1")} path=path,
if self.language: model="whisper-1",
files["language"] = (None, self.language) provider_label="OpenAI",
headers = {"Authorization": f"Bearer {self.api_key}"} language=self.language,
response = await client.post( )
self.api_url, headers=headers, files=files, timeout=60.0,
)
response.raise_for_status()
return response.json().get("text", "")
except Exception as e:
logger.error("OpenAI transcription error: {}", e)
return ""
class GroqTranscriptionProvider: class GroqTranscriptionProvider:
@ -63,7 +168,11 @@ class GroqTranscriptionProvider:
language: str | None = None, language: str | None = None,
): ):
self.api_key = api_key or os.environ.get("GROQ_API_KEY") self.api_key = api_key or os.environ.get("GROQ_API_KEY")
self.api_url = api_base or os.environ.get("GROQ_BASE_URL") or "https://api.groq.com/openai/v1/audio/transcriptions" self.api_url = (
api_base
or os.environ.get("GROQ_BASE_URL")
or "https://api.groq.com/openai/v1/audio/transcriptions"
)
self.language = language or None self.language = language or None
async def transcribe(self, file_path: str | Path) -> str: async def transcribe(self, file_path: str | Path) -> str:
@ -85,30 +194,11 @@ class GroqTranscriptionProvider:
logger.error("Audio file not found: {}", file_path) logger.error("Audio file not found: {}", file_path)
return "" return ""
try: return await _post_transcription_with_retry(
async with httpx.AsyncClient() as client: self.api_url,
with open(path, "rb") as f: api_key=self.api_key,
files = { path=path,
"file": (path.name, f), model="whisper-large-v3",
"model": (None, "whisper-large-v3"), provider_label="Groq",
} language=self.language,
if self.language: )
files["language"] = (None, self.language)
headers = {
"Authorization": f"Bearer {self.api_key}",
}
response = await client.post(
self.api_url,
headers=headers,
files=files,
timeout=60.0
)
response.raise_for_status()
data = response.json()
return data.get("text", "")
except Exception as e:
logger.error("Groq transcription error: {}", e)
return ""

View File

@ -342,6 +342,8 @@ async def test_base_channel_passes_language_to_groq_transcription_provider():
class _StubResponse: class _StubResponse:
status_code = 200
def raise_for_status(self): def raise_for_status(self):
return None return None

View File

@ -0,0 +1,293 @@
"""Tests for transcription retry behavior on transient errors (B10)."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import AsyncMock, patch
import httpx
import pytest
from nanobot.providers.transcription import GroqTranscriptionProvider, OpenAITranscriptionProvider
@pytest.fixture
def audio_file(tmp_path: Path) -> Path:
p = tmp_path / "voice.ogg"
p.write_bytes(b"OggS\x00fake-audio-bytes")
return p
def _response(status: int, payload: dict[str, object] | None = None) -> httpx.Response:
request = httpx.Request("POST", "https://example.test/audio/transcriptions")
return httpx.Response(status_code=status, json=payload or {}, request=request)
def _raw_response(status: int, content: bytes) -> httpx.Response:
"""Build a Response with a raw, possibly-malformed body (bypasses json= encoding)."""
request = httpx.Request("POST", "https://example.test/audio/transcriptions")
return httpx.Response(status_code=status, content=content, request=request)
# ---------------------------------------------------------------------------
# OpenAI provider — retry on transient HTTP + network errors
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_openai_retries_on_5xx_then_succeeds(audio_file: Path) -> None:
"""Transient 503 is retried; a subsequent 200 yields the text."""
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock(side_effect=[_response(503), _response(200, {"text": "hello"})])
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == "hello"
assert post.await_count == 2
@pytest.mark.asyncio
async def test_openai_retries_on_429_then_succeeds(audio_file: Path) -> None:
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock(side_effect=[_response(429), _response(200, {"text": "rate ok"})])
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == "rate ok"
assert post.await_count == 2
@pytest.mark.asyncio
async def test_openai_retries_on_connect_error(audio_file: Path) -> None:
"""Network-level transient errors are retried."""
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock(side_effect=[httpx.ConnectError("boom"), _response(200, {"text": "ok"})])
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == "ok"
assert post.await_count == 2
@pytest.mark.asyncio
async def test_openai_does_not_retry_on_auth_error(audio_file: Path) -> None:
"""401 is the user's misconfiguration — retrying wastes time and rate-limit quota."""
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock(return_value=_response(401, {"error": {"message": "bad key"}}))
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == ""
assert post.await_count == 1
@pytest.mark.asyncio
async def test_openai_gives_up_after_max_attempts(audio_file: Path) -> None:
"""Persistent 503 returns "" after the final retry — never hangs."""
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock(return_value=_response(503))
sleep = AsyncMock()
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", sleep):
result = await provider.transcribe(audio_file)
assert result == ""
# 4 attempts total (initial + 3 retries) with 3 sleeps between them.
assert post.await_count == 4
assert sleep.await_count == 3
@pytest.mark.asyncio
async def test_openai_backoff_grows_exponentially(audio_file: Path) -> None:
"""Verify the backoff schedule is exponential (1s, 2s, 4s)."""
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock(return_value=_response(503))
sleep = AsyncMock()
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", sleep):
await provider.transcribe(audio_file)
delays = [call.args[0] for call in sleep.await_args_list]
assert delays == [1.0, 2.0, 4.0]
# ---------------------------------------------------------------------------
# Groq provider — same semantics (both go through the shared helper)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_groq_retries_on_5xx_then_succeeds(audio_file: Path) -> None:
provider = GroqTranscriptionProvider(api_key="gsk-test")
post = AsyncMock(side_effect=[_response(502), _response(200, {"text": "groq ok"})])
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == "groq ok"
assert post.await_count == 2
@pytest.mark.asyncio
async def test_groq_does_not_retry_on_auth_error(audio_file: Path) -> None:
provider = GroqTranscriptionProvider(api_key="gsk-test")
post = AsyncMock(return_value=_response(403))
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == ""
assert post.await_count == 1
# ---------------------------------------------------------------------------
# Regression: missing file / missing key must still short-circuit
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_openai_missing_api_key_short_circuits(tmp_path: Path) -> None:
provider = OpenAITranscriptionProvider(api_key=None)
# Ensure env var doesn't accidentally satisfy it.
with patch.dict("os.environ", {}, clear=True):
provider = OpenAITranscriptionProvider(api_key=None)
post = AsyncMock()
with patch("httpx.AsyncClient.post", post):
assert await provider.transcribe(tmp_path / "voice.ogg") == ""
assert post.await_count == 0
@pytest.mark.asyncio
async def test_openai_missing_file_short_circuits() -> None:
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock()
with patch("httpx.AsyncClient.post", post):
assert await provider.transcribe("/nonexistent/path/voice.ogg") == ""
assert post.await_count == 0
@pytest.mark.asyncio
async def test_returns_empty_when_file_unreadable(audio_file: Path) -> None:
"""Existing file that cannot be read (PermissionError/OSError): "" with no HTTP attempt."""
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock()
with patch.object(Path, "read_bytes", side_effect=PermissionError("denied")), patch(
"httpx.AsyncClient.post", post
):
result = await provider.transcribe(audio_file)
assert result == ""
assert post.await_count == 0
# ---------------------------------------------------------------------------
# language: forwarded through the helper to the multipart body, on every attempt
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"provider_cls,language",
[(OpenAITranscriptionProvider, "en"), (GroqTranscriptionProvider, "ko")],
ids=["openai", "groq"],
)
@pytest.mark.asyncio
async def test_provider_forwards_language_in_multipart(
audio_file: Path, provider_cls: type, language: str
) -> None:
"""When ``language`` is set, the helper sends it as a multipart field."""
provider = provider_cls(api_key="k", language=language)
post = AsyncMock(return_value=_response(200, {"text": "ok"}))
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == "ok"
assert post.await_count == 1
files = post.await_args_list[0].kwargs["files"]
assert files["language"] == (None, language)
@pytest.mark.parametrize(
"provider_cls",
[OpenAITranscriptionProvider, GroqTranscriptionProvider],
ids=["openai", "groq"],
)
@pytest.mark.asyncio
async def test_provider_omits_language_when_unset(
audio_file: Path, provider_cls: type
) -> None:
"""When ``language`` is None, no ``language`` field is sent."""
provider = provider_cls(api_key="k")
post = AsyncMock(return_value=_response(200, {"text": "ok"}))
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == "ok"
assert post.await_count == 1
files = post.await_args_list[0].kwargs["files"]
assert "language" not in files
@pytest.mark.asyncio
async def test_language_survives_retry(audio_file: Path) -> None:
"""Regression: language must be present on every retry attempt, not just the first."""
provider = OpenAITranscriptionProvider(api_key="sk-test", language="ja")
post = AsyncMock(side_effect=[_response(503), _response(200, {"text": "konnichiwa"})])
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == "konnichiwa"
assert post.await_count == 2
for call in post.await_args_list:
assert call.kwargs["files"]["language"] == (None, "ja")
# ---------------------------------------------------------------------------
# Malformed / unexpected response bodies must short-circuit, not escape
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_returns_empty_on_malformed_json_body(audio_file: Path) -> None:
"""200 with invalid JSON: log and return "" immediately (no retry, no exception)."""
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock(return_value=_raw_response(200, b"<html>not json</html>"))
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == ""
assert post.await_count == 1
@pytest.mark.asyncio
async def test_returns_empty_on_non_dict_json_body(audio_file: Path) -> None:
"""200 with a JSON array (not dict): no AttributeError leak; return "" immediately."""
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock(return_value=_raw_response(200, b"[]"))
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == ""
assert post.await_count == 1
# ---------------------------------------------------------------------------
# Pin the full advertised retry contract: all retryable statuses + exceptions
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("status", [408, 429, 500, 502, 503, 504])
@pytest.mark.asyncio
async def test_retries_on_every_advertised_transient_status(
audio_file: Path, status: int
) -> None:
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock(side_effect=[_response(status), _response(200, {"text": "ok"})])
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == "ok"
assert post.await_count == 2
@pytest.mark.parametrize(
"exc",
[
httpx.TimeoutException("t"),
httpx.ConnectError("c"),
httpx.ReadError("r"),
httpx.WriteError("w"),
httpx.RemoteProtocolError("p"),
],
ids=["timeout", "connect", "read", "write", "remote_protocol"],
)
@pytest.mark.asyncio
async def test_retries_on_every_advertised_transient_exception(
audio_file: Path, exc: Exception
) -> None:
provider = OpenAITranscriptionProvider(api_key="sk-test")
post = AsyncMock(side_effect=[exc, _response(200, {"text": "recovered"})])
with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()):
result = await provider.transcribe(audio_file)
assert result == "recovered"
assert post.await_count == 2