From 84e8aed6b122b07b3ef212840c01ff7121a3b6f5 Mon Sep 17 00:00:00 2001 From: mohamed-elkholy95 Date: Sat, 25 Apr 2026 17:45:36 -0400 Subject: [PATCH] fix(transcription): retry Whisper calls and guard malformed responses A single transient failure between the agent and an OpenAI/Groq Whisper endpoint currently vanishes as `return ""` in transcribe(). The voice message arrives as the empty string and there is no way to tell real silence apart from a failed upload. A malformed but successful response body is even worse: the JSON-decode error escapes the helper unhandled. Add a shared `_post_transcription_with_retry` used by both providers. Retry behaviour: - exponential backoff 1s -> 2s -> 4s, up to 3 retries (4 attempts) - retryable HTTP statuses: 408, 429, 500, 502, 503, 504 - retryable exceptions: TimeoutException, ConnectError, ReadError, WriteError, RemoteProtocolError Non-transient failures short-circuit to "" on the first attempt -- retrying a misconfigured key or a broken upload only burns rate-limit quota. Branches that short-circuit: - missing API key, missing audio file - file-read errors (PermissionError, OSError) on the audio path, preserving the nightly contract for direct provider callers - HTTP auth/4xx body issues via raise_for_status() - response.json() parse failures - non-dict JSON payloads Sharing one helper means OpenAI and Groq cannot drift apart silently. Thread `language` through the helper. The multipart files dict is rebuilt inside the per-attempt loop, so when a caller sets self.language the `language` field is sent on every attempt -- not just the first. Tests cover: - every advertised retryable status and exception, parameterized - language present on attempts 1 and 2 of a 503->200 sequence - language absent when unset; present when set (both providers) - malformed JSON body and non-dict JSON body short-circuit to "" - PermissionError on file read short-circuits with no HTTP attempt - max-attempts give-up, exponential-backoff schedule, auth no-retry, missing-key / missing-file short-circuit Test stub fix: the _StubResponse in tests/channels/test_channel_plugins.py declared no status_code, which the new helper reads for retry classification. Set status_code = 200 so the stub advertises the successful response that those tests already simulate. Also moved the two transcription-provider imports to the top of that file (previously placed mid-file) so the file is ruff-clean (E402). --- nanobot/providers/transcription.py | 176 +++++++++++---- tests/channels/test_channel_plugins.py | 2 + tests/providers/test_transcription.py | 293 +++++++++++++++++++++++++ 3 files changed, 428 insertions(+), 43 deletions(-) create mode 100644 tests/providers/test_transcription.py diff --git a/nanobot/providers/transcription.py b/nanobot/providers/transcription.py index 10fcafd6d..25e09dab7 100644 --- a/nanobot/providers/transcription.py +++ b/nanobot/providers/transcription.py @@ -1,11 +1,123 @@ """Voice transcription providers (Groq and OpenAI Whisper).""" +import asyncio import os from pathlib import Path import httpx from loguru import logger +# Up to 3 retries (4 attempts total) with exponential backoff on transient +# failures. Whisper endpoints occasionally return 502/503 under load, and +# mobile-network transcription callers hit sporadic connect/read errors. +# Without this, a voice message silently becomes the empty string. +_MAX_RETRIES = 3 +_BACKOFF_S = (1.0, 2.0, 4.0) +_RETRYABLE_STATUS = {408, 429, 500, 502, 503, 504} +_RETRYABLE_EXCEPTIONS = ( + httpx.TimeoutException, + httpx.ConnectError, + httpx.ReadError, + httpx.WriteError, + httpx.RemoteProtocolError, +) + + +async def _post_transcription_with_retry( + url: str, + *, + api_key: str, + path: Path, + model: str, + provider_label: str, + language: str | None = None, +) -> str: + """POST an audio file for transcription, retrying on transient errors. + + Retries on connect/read/timeout failures and on 408/429/5xx responses. + Other errors (including 4xx such as 401/403) return "" immediately — the + caller's config is wrong and retrying only wastes quota. + + When ``language`` is provided, it is forwarded as the ``language`` + multipart field on every attempt (the dict is rebuilt per attempt so the + same field is present on retries). + """ + try: + data = path.read_bytes() + except OSError as e: + logger.error("{} transcription error: cannot read audio file: {}", provider_label, e) + return "" + headers = {"Authorization": f"Bearer {api_key}"} + + async with httpx.AsyncClient() as client: + for attempt in range(_MAX_RETRIES + 1): + files = { + "file": (path.name, data), + "model": (None, model), + } + if language: + files["language"] = (None, language) + try: + response = await client.post(url, headers=headers, files=files, timeout=60.0) + except _RETRYABLE_EXCEPTIONS as e: + if attempt < _MAX_RETRIES: + logger.warning( + "{} transcription transient error (attempt {}/{}): {}", + provider_label, + attempt + 1, + _MAX_RETRIES + 1, + e, + ) + await asyncio.sleep(_BACKOFF_S[attempt]) + continue + logger.error( + "{} transcription error after {} attempts: {}", + provider_label, + _MAX_RETRIES + 1, + e, + ) + return "" + except Exception as e: + logger.error("{} transcription error: {}", provider_label, e) + return "" + + if response.status_code in _RETRYABLE_STATUS and attempt < _MAX_RETRIES: + logger.warning( + "{} transcription transient HTTP {} (attempt {}/{})", + provider_label, + response.status_code, + attempt + 1, + _MAX_RETRIES + 1, + ) + await asyncio.sleep(_BACKOFF_S[attempt]) + continue + + try: + response.raise_for_status() + except Exception as e: + logger.error("{} transcription error: {}", provider_label, e) + return "" + + try: + payload = response.json() + except Exception as e: + logger.error( + "{} transcription error: malformed response body: {}", + provider_label, + e, + ) + return "" + if not isinstance(payload, dict): + logger.error( + "{} transcription error: unexpected response shape: {!r}", + provider_label, + type(payload).__name__, + ) + return "" + return payload.get("text", "") + + return "" + class OpenAITranscriptionProvider: """Voice transcription provider using OpenAI's Whisper API.""" @@ -32,21 +144,14 @@ class OpenAITranscriptionProvider: if not path.exists(): logger.error("Audio file not found: {}", file_path) return "" - try: - async with httpx.AsyncClient() as client: - with open(path, "rb") as f: - files = {"file": (path.name, f), "model": (None, "whisper-1")} - if self.language: - files["language"] = (None, self.language) - headers = {"Authorization": f"Bearer {self.api_key}"} - response = await client.post( - self.api_url, headers=headers, files=files, timeout=60.0, - ) - response.raise_for_status() - return response.json().get("text", "") - except Exception as e: - logger.error("OpenAI transcription error: {}", e) - return "" + return await _post_transcription_with_retry( + self.api_url, + api_key=self.api_key, + path=path, + model="whisper-1", + provider_label="OpenAI", + language=self.language, + ) class GroqTranscriptionProvider: @@ -63,7 +168,11 @@ class GroqTranscriptionProvider: language: str | None = None, ): self.api_key = api_key or os.environ.get("GROQ_API_KEY") - self.api_url = api_base or os.environ.get("GROQ_BASE_URL") or "https://api.groq.com/openai/v1/audio/transcriptions" + self.api_url = ( + api_base + or os.environ.get("GROQ_BASE_URL") + or "https://api.groq.com/openai/v1/audio/transcriptions" + ) self.language = language or None async def transcribe(self, file_path: str | Path) -> str: @@ -85,30 +194,11 @@ class GroqTranscriptionProvider: logger.error("Audio file not found: {}", file_path) return "" - try: - async with httpx.AsyncClient() as client: - with open(path, "rb") as f: - files = { - "file": (path.name, f), - "model": (None, "whisper-large-v3"), - } - if self.language: - files["language"] = (None, self.language) - headers = { - "Authorization": f"Bearer {self.api_key}", - } - - response = await client.post( - self.api_url, - headers=headers, - files=files, - timeout=60.0 - ) - - response.raise_for_status() - data = response.json() - return data.get("text", "") - - except Exception as e: - logger.error("Groq transcription error: {}", e) - return "" + return await _post_transcription_with_retry( + self.api_url, + api_key=self.api_key, + path=path, + model="whisper-large-v3", + provider_label="Groq", + language=self.language, + ) diff --git a/tests/channels/test_channel_plugins.py b/tests/channels/test_channel_plugins.py index 378cdd059..a32d96e1a 100644 --- a/tests/channels/test_channel_plugins.py +++ b/tests/channels/test_channel_plugins.py @@ -342,6 +342,8 @@ async def test_base_channel_passes_language_to_groq_transcription_provider(): class _StubResponse: + status_code = 200 + def raise_for_status(self): return None diff --git a/tests/providers/test_transcription.py b/tests/providers/test_transcription.py new file mode 100644 index 000000000..288290a92 --- /dev/null +++ b/tests/providers/test_transcription.py @@ -0,0 +1,293 @@ +"""Tests for transcription retry behavior on transient errors (B10).""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import AsyncMock, patch + +import httpx +import pytest + +from nanobot.providers.transcription import GroqTranscriptionProvider, OpenAITranscriptionProvider + + +@pytest.fixture +def audio_file(tmp_path: Path) -> Path: + p = tmp_path / "voice.ogg" + p.write_bytes(b"OggS\x00fake-audio-bytes") + return p + + +def _response(status: int, payload: dict[str, object] | None = None) -> httpx.Response: + request = httpx.Request("POST", "https://example.test/audio/transcriptions") + return httpx.Response(status_code=status, json=payload or {}, request=request) + + +def _raw_response(status: int, content: bytes) -> httpx.Response: + """Build a Response with a raw, possibly-malformed body (bypasses json= encoding).""" + request = httpx.Request("POST", "https://example.test/audio/transcriptions") + return httpx.Response(status_code=status, content=content, request=request) + + +# --------------------------------------------------------------------------- +# OpenAI provider — retry on transient HTTP + network errors +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_openai_retries_on_5xx_then_succeeds(audio_file: Path) -> None: + """Transient 503 is retried; a subsequent 200 yields the text.""" + provider = OpenAITranscriptionProvider(api_key="sk-test") + post = AsyncMock(side_effect=[_response(503), _response(200, {"text": "hello"})]) + with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()): + result = await provider.transcribe(audio_file) + assert result == "hello" + assert post.await_count == 2 + + +@pytest.mark.asyncio +async def test_openai_retries_on_429_then_succeeds(audio_file: Path) -> None: + provider = OpenAITranscriptionProvider(api_key="sk-test") + post = AsyncMock(side_effect=[_response(429), _response(200, {"text": "rate ok"})]) + with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()): + result = await provider.transcribe(audio_file) + assert result == "rate ok" + assert post.await_count == 2 + + +@pytest.mark.asyncio +async def test_openai_retries_on_connect_error(audio_file: Path) -> None: + """Network-level transient errors are retried.""" + provider = OpenAITranscriptionProvider(api_key="sk-test") + post = AsyncMock(side_effect=[httpx.ConnectError("boom"), _response(200, {"text": "ok"})]) + with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()): + result = await provider.transcribe(audio_file) + assert result == "ok" + assert post.await_count == 2 + + +@pytest.mark.asyncio +async def test_openai_does_not_retry_on_auth_error(audio_file: Path) -> None: + """401 is the user's misconfiguration — retrying wastes time and rate-limit quota.""" + provider = OpenAITranscriptionProvider(api_key="sk-test") + post = AsyncMock(return_value=_response(401, {"error": {"message": "bad key"}})) + with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()): + result = await provider.transcribe(audio_file) + assert result == "" + assert post.await_count == 1 + + +@pytest.mark.asyncio +async def test_openai_gives_up_after_max_attempts(audio_file: Path) -> None: + """Persistent 503 returns "" after the final retry — never hangs.""" + provider = OpenAITranscriptionProvider(api_key="sk-test") + post = AsyncMock(return_value=_response(503)) + sleep = AsyncMock() + with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", sleep): + result = await provider.transcribe(audio_file) + assert result == "" + # 4 attempts total (initial + 3 retries) with 3 sleeps between them. + assert post.await_count == 4 + assert sleep.await_count == 3 + + +@pytest.mark.asyncio +async def test_openai_backoff_grows_exponentially(audio_file: Path) -> None: + """Verify the backoff schedule is exponential (1s, 2s, 4s).""" + provider = OpenAITranscriptionProvider(api_key="sk-test") + post = AsyncMock(return_value=_response(503)) + sleep = AsyncMock() + with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", sleep): + await provider.transcribe(audio_file) + delays = [call.args[0] for call in sleep.await_args_list] + assert delays == [1.0, 2.0, 4.0] + + +# --------------------------------------------------------------------------- +# Groq provider — same semantics (both go through the shared helper) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_groq_retries_on_5xx_then_succeeds(audio_file: Path) -> None: + provider = GroqTranscriptionProvider(api_key="gsk-test") + post = AsyncMock(side_effect=[_response(502), _response(200, {"text": "groq ok"})]) + with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()): + result = await provider.transcribe(audio_file) + assert result == "groq ok" + assert post.await_count == 2 + + +@pytest.mark.asyncio +async def test_groq_does_not_retry_on_auth_error(audio_file: Path) -> None: + provider = GroqTranscriptionProvider(api_key="gsk-test") + post = AsyncMock(return_value=_response(403)) + with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()): + result = await provider.transcribe(audio_file) + assert result == "" + assert post.await_count == 1 + + +# --------------------------------------------------------------------------- +# Regression: missing file / missing key must still short-circuit +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_openai_missing_api_key_short_circuits(tmp_path: Path) -> None: + provider = OpenAITranscriptionProvider(api_key=None) + # Ensure env var doesn't accidentally satisfy it. + with patch.dict("os.environ", {}, clear=True): + provider = OpenAITranscriptionProvider(api_key=None) + post = AsyncMock() + with patch("httpx.AsyncClient.post", post): + assert await provider.transcribe(tmp_path / "voice.ogg") == "" + assert post.await_count == 0 + + +@pytest.mark.asyncio +async def test_openai_missing_file_short_circuits() -> None: + provider = OpenAITranscriptionProvider(api_key="sk-test") + post = AsyncMock() + with patch("httpx.AsyncClient.post", post): + assert await provider.transcribe("/nonexistent/path/voice.ogg") == "" + assert post.await_count == 0 + + +@pytest.mark.asyncio +async def test_returns_empty_when_file_unreadable(audio_file: Path) -> None: + """Existing file that cannot be read (PermissionError/OSError): "" with no HTTP attempt.""" + provider = OpenAITranscriptionProvider(api_key="sk-test") + post = AsyncMock() + with patch.object(Path, "read_bytes", side_effect=PermissionError("denied")), patch( + "httpx.AsyncClient.post", post + ): + result = await provider.transcribe(audio_file) + assert result == "" + assert post.await_count == 0 + + +# --------------------------------------------------------------------------- +# language: forwarded through the helper to the multipart body, on every attempt +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "provider_cls,language", + [(OpenAITranscriptionProvider, "en"), (GroqTranscriptionProvider, "ko")], + ids=["openai", "groq"], +) +@pytest.mark.asyncio +async def test_provider_forwards_language_in_multipart( + audio_file: Path, provider_cls: type, language: str +) -> None: + """When ``language`` is set, the helper sends it as a multipart field.""" + provider = provider_cls(api_key="k", language=language) + post = AsyncMock(return_value=_response(200, {"text": "ok"})) + with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()): + result = await provider.transcribe(audio_file) + assert result == "ok" + assert post.await_count == 1 + files = post.await_args_list[0].kwargs["files"] + assert files["language"] == (None, language) + + +@pytest.mark.parametrize( + "provider_cls", + [OpenAITranscriptionProvider, GroqTranscriptionProvider], + ids=["openai", "groq"], +) +@pytest.mark.asyncio +async def test_provider_omits_language_when_unset( + audio_file: Path, provider_cls: type +) -> None: + """When ``language`` is None, no ``language`` field is sent.""" + provider = provider_cls(api_key="k") + post = AsyncMock(return_value=_response(200, {"text": "ok"})) + with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()): + result = await provider.transcribe(audio_file) + assert result == "ok" + assert post.await_count == 1 + files = post.await_args_list[0].kwargs["files"] + assert "language" not in files + + +@pytest.mark.asyncio +async def test_language_survives_retry(audio_file: Path) -> None: + """Regression: language must be present on every retry attempt, not just the first.""" + provider = OpenAITranscriptionProvider(api_key="sk-test", language="ja") + post = AsyncMock(side_effect=[_response(503), _response(200, {"text": "konnichiwa"})]) + with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()): + result = await provider.transcribe(audio_file) + assert result == "konnichiwa" + assert post.await_count == 2 + for call in post.await_args_list: + assert call.kwargs["files"]["language"] == (None, "ja") + + +# --------------------------------------------------------------------------- +# Malformed / unexpected response bodies must short-circuit, not escape +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_returns_empty_on_malformed_json_body(audio_file: Path) -> None: + """200 with invalid JSON: log and return "" immediately (no retry, no exception).""" + provider = OpenAITranscriptionProvider(api_key="sk-test") + post = AsyncMock(return_value=_raw_response(200, b"not json")) + with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()): + result = await provider.transcribe(audio_file) + assert result == "" + assert post.await_count == 1 + + +@pytest.mark.asyncio +async def test_returns_empty_on_non_dict_json_body(audio_file: Path) -> None: + """200 with a JSON array (not dict): no AttributeError leak; return "" immediately.""" + provider = OpenAITranscriptionProvider(api_key="sk-test") + post = AsyncMock(return_value=_raw_response(200, b"[]")) + with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()): + result = await provider.transcribe(audio_file) + assert result == "" + assert post.await_count == 1 + + +# --------------------------------------------------------------------------- +# Pin the full advertised retry contract: all retryable statuses + exceptions +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("status", [408, 429, 500, 502, 503, 504]) +@pytest.mark.asyncio +async def test_retries_on_every_advertised_transient_status( + audio_file: Path, status: int +) -> None: + provider = OpenAITranscriptionProvider(api_key="sk-test") + post = AsyncMock(side_effect=[_response(status), _response(200, {"text": "ok"})]) + with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()): + result = await provider.transcribe(audio_file) + assert result == "ok" + assert post.await_count == 2 + + +@pytest.mark.parametrize( + "exc", + [ + httpx.TimeoutException("t"), + httpx.ConnectError("c"), + httpx.ReadError("r"), + httpx.WriteError("w"), + httpx.RemoteProtocolError("p"), + ], + ids=["timeout", "connect", "read", "write", "remote_protocol"], +) +@pytest.mark.asyncio +async def test_retries_on_every_advertised_transient_exception( + audio_file: Path, exc: Exception +) -> None: + provider = OpenAITranscriptionProvider(api_key="sk-test") + post = AsyncMock(side_effect=[exc, _response(200, {"text": "recovered"})]) + with patch("httpx.AsyncClient.post", post), patch("asyncio.sleep", AsyncMock()): + result = await provider.transcribe(audio_file) + assert result == "recovered" + assert post.await_count == 2