From f6a417e77d9b2a3f5958fb756f1e3b2a36f3d371 Mon Sep 17 00:00:00 2001 From: chengyongru Date: Wed, 22 Apr 2026 10:47:49 +0800 Subject: [PATCH] fix(transcription): harden language parameter validation and tests - Add ISO-639 pattern validation (2-3 lowercase letters) to schema - Normalize empty language to None in provider constructors - Extract shared httpx mock stubs, parameterize provider tests - Add test for language=None omitting field from multipart body - Add test for Pydantic pattern validation rejecting invalid codes --- nanobot/config/schema.py | 2 +- nanobot/providers/transcription.py | 4 +- tests/channels/test_channel_plugins.py | 88 ++++++++++++++++++++------ 3 files changed, 73 insertions(+), 21 deletions(-) diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 0c2b1b2ac..cca8f210f 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -29,7 +29,7 @@ class ChannelsConfig(Base): send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…")) send_max_retries: int = Field(default=3, ge=0, le=10) # Max delivery attempts (initial send included) transcription_provider: str = "groq" # Voice transcription backend: "groq" or "openai" - transcription_language: str | None = None # Optional ISO-639-1 hint for audio transcription + transcription_language: str | None = Field(default=None, pattern=r"^[a-z]{2,3}$") # Optional ISO-639-1 hint for audio transcription class DreamConfig(Base): diff --git a/nanobot/providers/transcription.py b/nanobot/providers/transcription.py index 969990166..10fcafd6d 100644 --- a/nanobot/providers/transcription.py +++ b/nanobot/providers/transcription.py @@ -22,7 +22,7 @@ class OpenAITranscriptionProvider: or os.environ.get("OPENAI_TRANSCRIPTION_BASE_URL") or "https://api.openai.com/v1/audio/transcriptions" ) - self.language = language + self.language = language or None async def transcribe(self, file_path: str | Path) -> str: if not self.api_key: @@ -64,7 +64,7 @@ class GroqTranscriptionProvider: ): self.api_key = api_key or os.environ.get("GROQ_API_KEY") self.api_url = api_base or os.environ.get("GROQ_BASE_URL") or "https://api.groq.com/openai/v1/audio/transcriptions" - self.language = language + self.language = language or None async def transcribe(self, file_path: str | Path) -> str: """ diff --git a/tests/channels/test_channel_plugins.py b/tests/channels/test_channel_plugins.py index 6abe21f7a..10f045bf8 100644 --- a/tests/channels/test_channel_plugins.py +++ b/tests/channels/test_channel_plugins.py @@ -334,21 +334,24 @@ async def test_base_channel_passes_language_to_groq_transcription_provider(): assert captured["language"] == "ko" -@pytest.mark.asyncio -async def test_groq_transcription_provider_includes_language(tmp_path): - from nanobot.providers.transcription import GroqTranscriptionProvider +# --------------------------------------------------------------------------- +# Transcription provider HTTP tests +# --------------------------------------------------------------------------- - audio = tmp_path / "sample.wav" - audio.write_bytes(b"audio") - captured: dict[str, object] = {} +from nanobot.providers.transcription import GroqTranscriptionProvider as _GroqProvider +from nanobot.providers.transcription import OpenAITranscriptionProvider as _OpenAIProvider - class _Response: - def raise_for_status(self): - return None - def json(self): - return {"text": "hello"} +class _StubResponse: + def raise_for_status(self): + return None + def json(self): + return {"text": "hello"} + + +def _stub_async_client(captured: dict[str, object]): + """Return an httpx.AsyncClient stub that records POST calls into *captured*.""" class _AsyncClient: async def __aenter__(self): return self @@ -357,19 +360,50 @@ async def test_groq_transcription_provider_includes_language(tmp_path): return False async def post(self, url, headers=None, files=None, timeout=None): - captured["url"] = url - captured["headers"] = headers captured["files"] = files - captured["timeout"] = timeout - return _Response() + return _StubResponse() - provider = GroqTranscriptionProvider(api_key="k", language="ko") + return _AsyncClient() - with patch("nanobot.providers.transcription.httpx.AsyncClient", return_value=_AsyncClient()): + +@pytest.mark.parametrize( + "provider_cls,language", + [(_GroqProvider, "ko"), (_OpenAIProvider, "en")], + ids=["groq", "openai"], +) +@pytest.mark.asyncio +async def test_transcription_provider_includes_language(tmp_path, provider_cls, language): + """Provider must include the 'language' field in multipart body when set.""" + audio = tmp_path / "sample.wav" + audio.write_bytes(b"audio") + captured: dict[str, object] = {} + + with patch("nanobot.providers.transcription.httpx.AsyncClient", return_value=_stub_async_client(captured)): + provider = provider_cls(api_key="k", language=language) result = await provider.transcribe(audio) assert result == "hello" - assert captured["files"]["language"] == (None, "ko") + assert captured["files"]["language"] == (None, language) + + +@pytest.mark.parametrize( + "provider_cls", + [_GroqProvider, _OpenAIProvider], + ids=["groq", "openai"], +) +@pytest.mark.asyncio +async def test_transcription_provider_omits_language_when_none(tmp_path, provider_cls): + """When language is not set, the 'language' key must be absent from the multipart body.""" + audio = tmp_path / "sample.wav" + audio.write_bytes(b"audio") + captured: dict[str, object] = {} + + with patch("nanobot.providers.transcription.httpx.AsyncClient", return_value=_stub_async_client(captured)): + provider = provider_cls(api_key="k") + result = await provider.transcribe(audio) + + assert result == "hello" + assert "language" not in captured["files"] def test_channels_login_uses_discovered_plugin_class(monkeypatch): @@ -530,6 +564,24 @@ def test_channels_config_send_max_retries_upper_bound(): ChannelsConfig(send_max_retries=11) +def test_channels_config_transcription_language_pattern(): + """transcription_language must match ISO-639 format (2-3 lowercase letters) or be None.""" + from pydantic import ValidationError + + # Valid values + assert ChannelsConfig(transcription_language="en").transcription_language == "en" + assert ChannelsConfig(transcription_language="kor").transcription_language == "kor" + assert ChannelsConfig(transcription_language=None).transcription_language is None + + # Invalid values + with pytest.raises(ValidationError): + ChannelsConfig(transcription_language="EN") # uppercase + with pytest.raises(ValidationError): + ChannelsConfig(transcription_language="english") # full word + with pytest.raises(ValidationError): + ChannelsConfig(transcription_language="en-US") # BCP 47 tag + + # --------------------------------------------------------------------------- # _send_with_retry # ---------------------------------------------------------------------------