fix(transcription): harden language parameter validation and tests

- Add ISO-639 pattern validation (2-3 lowercase letters) to schema
- Normalize empty language to None in provider constructors
- Extract shared httpx mock stubs, parameterize provider tests
- Add test for language=None omitting field from multipart body
- Add test for Pydantic pattern validation rejecting invalid codes
This commit is contained in:
chengyongru 2026-04-22 10:47:49 +08:00 committed by Xubin Ren
parent 123d69bfb7
commit f6a417e77d
3 changed files with 73 additions and 21 deletions

View File

@ -29,7 +29,7 @@ class ChannelsConfig(Base):
send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…")) send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…"))
send_max_retries: int = Field(default=3, ge=0, le=10) # Max delivery attempts (initial send included) send_max_retries: int = Field(default=3, ge=0, le=10) # Max delivery attempts (initial send included)
transcription_provider: str = "groq" # Voice transcription backend: "groq" or "openai" transcription_provider: str = "groq" # Voice transcription backend: "groq" or "openai"
transcription_language: str | None = None # Optional ISO-639-1 hint for audio transcription transcription_language: str | None = Field(default=None, pattern=r"^[a-z]{2,3}$") # Optional ISO-639-1 hint for audio transcription
class DreamConfig(Base): class DreamConfig(Base):

View File

@ -22,7 +22,7 @@ class OpenAITranscriptionProvider:
or os.environ.get("OPENAI_TRANSCRIPTION_BASE_URL") or os.environ.get("OPENAI_TRANSCRIPTION_BASE_URL")
or "https://api.openai.com/v1/audio/transcriptions" or "https://api.openai.com/v1/audio/transcriptions"
) )
self.language = language self.language = language or None
async def transcribe(self, file_path: str | Path) -> str: async def transcribe(self, file_path: str | Path) -> str:
if not self.api_key: if not self.api_key:
@ -64,7 +64,7 @@ class GroqTranscriptionProvider:
): ):
self.api_key = api_key or os.environ.get("GROQ_API_KEY") self.api_key = api_key or os.environ.get("GROQ_API_KEY")
self.api_url = api_base or os.environ.get("GROQ_BASE_URL") or "https://api.groq.com/openai/v1/audio/transcriptions" self.api_url = api_base or os.environ.get("GROQ_BASE_URL") or "https://api.groq.com/openai/v1/audio/transcriptions"
self.language = language self.language = language or None
async def transcribe(self, file_path: str | Path) -> str: async def transcribe(self, file_path: str | Path) -> str:
""" """

View File

@ -334,21 +334,24 @@ async def test_base_channel_passes_language_to_groq_transcription_provider():
assert captured["language"] == "ko" assert captured["language"] == "ko"
@pytest.mark.asyncio # ---------------------------------------------------------------------------
async def test_groq_transcription_provider_includes_language(tmp_path): # Transcription provider HTTP tests
from nanobot.providers.transcription import GroqTranscriptionProvider # ---------------------------------------------------------------------------
audio = tmp_path / "sample.wav" from nanobot.providers.transcription import GroqTranscriptionProvider as _GroqProvider
audio.write_bytes(b"audio") from nanobot.providers.transcription import OpenAITranscriptionProvider as _OpenAIProvider
captured: dict[str, object] = {}
class _Response:
def raise_for_status(self):
return None
def json(self): class _StubResponse:
return {"text": "hello"} def raise_for_status(self):
return None
def json(self):
return {"text": "hello"}
def _stub_async_client(captured: dict[str, object]):
"""Return an httpx.AsyncClient stub that records POST calls into *captured*."""
class _AsyncClient: class _AsyncClient:
async def __aenter__(self): async def __aenter__(self):
return self return self
@ -357,19 +360,50 @@ async def test_groq_transcription_provider_includes_language(tmp_path):
return False return False
async def post(self, url, headers=None, files=None, timeout=None): async def post(self, url, headers=None, files=None, timeout=None):
captured["url"] = url
captured["headers"] = headers
captured["files"] = files captured["files"] = files
captured["timeout"] = timeout return _StubResponse()
return _Response()
provider = GroqTranscriptionProvider(api_key="k", language="ko") return _AsyncClient()
with patch("nanobot.providers.transcription.httpx.AsyncClient", return_value=_AsyncClient()):
@pytest.mark.parametrize(
"provider_cls,language",
[(_GroqProvider, "ko"), (_OpenAIProvider, "en")],
ids=["groq", "openai"],
)
@pytest.mark.asyncio
async def test_transcription_provider_includes_language(tmp_path, provider_cls, language):
"""Provider must include the 'language' field in multipart body when set."""
audio = tmp_path / "sample.wav"
audio.write_bytes(b"audio")
captured: dict[str, object] = {}
with patch("nanobot.providers.transcription.httpx.AsyncClient", return_value=_stub_async_client(captured)):
provider = provider_cls(api_key="k", language=language)
result = await provider.transcribe(audio) result = await provider.transcribe(audio)
assert result == "hello" assert result == "hello"
assert captured["files"]["language"] == (None, "ko") assert captured["files"]["language"] == (None, language)
@pytest.mark.parametrize(
    "provider_cls",
    [_GroqProvider, _OpenAIProvider],
    ids=["groq", "openai"],
)
@pytest.mark.asyncio
async def test_transcription_provider_omits_language_when_none(tmp_path, provider_cls):
    """Without a language hint, no 'language' part may appear in the multipart body.

    Runs once per provider class. ``httpx.AsyncClient`` is patched with the
    recording stub so the ``files`` mapping handed to ``post`` can be inspected.
    """
    recorded: dict[str, object] = {}
    sample = tmp_path / "sample.wav"
    sample.write_bytes(b"audio")

    # Build the stub first, then swap it in for the real client constructor.
    fake_client = _stub_async_client(recorded)
    client_patch = patch(
        "nanobot.providers.transcription.httpx.AsyncClient",
        return_value=fake_client,
    )
    with client_patch:
        # No ``language`` argument: the provider is left on its default.
        transcript = await provider_cls(api_key="k").transcribe(sample)

    assert transcript == "hello"
    assert "language" not in recorded["files"]
def test_channels_login_uses_discovered_plugin_class(monkeypatch): def test_channels_login_uses_discovered_plugin_class(monkeypatch):
@ -530,6 +564,24 @@ def test_channels_config_send_max_retries_upper_bound():
ChannelsConfig(send_max_retries=11) ChannelsConfig(send_max_retries=11)
def test_channels_config_transcription_language_pattern():
    """transcription_language accepts None or 2-3 lowercase letters and rejects anything else."""
    from pydantic import ValidationError

    # Codes the schema pattern must let through unchanged.
    for accepted in ("en", "kor"):
        assert ChannelsConfig(transcription_language=accepted).transcription_language == accepted
    assert ChannelsConfig(transcription_language=None).transcription_language is None

    # In order: uppercase, full word, BCP 47 tag -- each must fail validation.
    for rejected in ("EN", "english", "en-US"):
        with pytest.raises(ValidationError):
            ChannelsConfig(transcription_language=rejected)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# _send_with_retry # _send_with_retry
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------