mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-04 08:45:54 +00:00
fix(transcription): harden language parameter validation and tests
- Add ISO-639 pattern validation (2-3 lowercase letters) to schema
- Normalize empty language to None in provider constructors
- Extract shared httpx mock stubs, parameterize provider tests
- Add test for language=None omitting field from multipart body
- Add test for Pydantic pattern validation rejecting invalid codes
This commit is contained in:
parent
123d69bfb7
commit
f6a417e77d
@ -29,7 +29,7 @@ class ChannelsConfig(Base):
|
||||
send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…"))
|
||||
send_max_retries: int = Field(default=3, ge=0, le=10) # Max delivery attempts (initial send included)
|
||||
transcription_provider: str = "groq" # Voice transcription backend: "groq" or "openai"
|
||||
transcription_language: str | None = None # Optional ISO-639-1 hint for audio transcription
|
||||
transcription_language: str | None = Field(default=None, pattern=r"^[a-z]{2,3}$") # Optional ISO-639 language code hint (2-3 lowercase letters) for audio transcription
|
||||
|
||||
|
||||
class DreamConfig(Base):
|
||||
|
||||
@ -22,7 +22,7 @@ class OpenAITranscriptionProvider:
|
||||
or os.environ.get("OPENAI_TRANSCRIPTION_BASE_URL")
|
||||
or "https://api.openai.com/v1/audio/transcriptions"
|
||||
)
|
||||
self.language = language
|
||||
self.language = language or None
|
||||
|
||||
async def transcribe(self, file_path: str | Path) -> str:
|
||||
if not self.api_key:
|
||||
@ -64,7 +64,7 @@ class GroqTranscriptionProvider:
|
||||
):
|
||||
self.api_key = api_key or os.environ.get("GROQ_API_KEY")
|
||||
self.api_url = api_base or os.environ.get("GROQ_BASE_URL") or "https://api.groq.com/openai/v1/audio/transcriptions"
|
||||
self.language = language
|
||||
self.language = language or None
|
||||
|
||||
async def transcribe(self, file_path: str | Path) -> str:
|
||||
"""
|
||||
|
||||
@ -334,21 +334,24 @@ async def test_base_channel_passes_language_to_groq_transcription_provider():
|
||||
assert captured["language"] == "ko"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_groq_transcription_provider_includes_language(tmp_path):
|
||||
from nanobot.providers.transcription import GroqTranscriptionProvider
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transcription provider HTTP tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
audio = tmp_path / "sample.wav"
|
||||
audio.write_bytes(b"audio")
|
||||
captured: dict[str, object] = {}
|
||||
from nanobot.providers.transcription import GroqTranscriptionProvider as _GroqProvider
|
||||
from nanobot.providers.transcription import OpenAITranscriptionProvider as _OpenAIProvider
|
||||
|
||||
class _Response:
|
||||
def raise_for_status(self):
|
||||
return None
|
||||
|
||||
def json(self):
|
||||
return {"text": "hello"}
|
||||
class _StubResponse:
|
||||
def raise_for_status(self):
|
||||
return None
|
||||
|
||||
def json(self):
|
||||
return {"text": "hello"}
|
||||
|
||||
|
||||
def _stub_async_client(captured: dict[str, object]):
|
||||
"""Return an httpx.AsyncClient stub that records POST calls into *captured*."""
|
||||
class _AsyncClient:
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
@ -357,19 +360,50 @@ async def test_groq_transcription_provider_includes_language(tmp_path):
|
||||
return False
|
||||
|
||||
async def post(self, url, headers=None, files=None, timeout=None):
|
||||
captured["url"] = url
|
||||
captured["headers"] = headers
|
||||
captured["files"] = files
|
||||
captured["timeout"] = timeout
|
||||
return _Response()
|
||||
return _StubResponse()
|
||||
|
||||
provider = GroqTranscriptionProvider(api_key="k", language="ko")
|
||||
return _AsyncClient()
|
||||
|
||||
with patch("nanobot.providers.transcription.httpx.AsyncClient", return_value=_AsyncClient()):
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_cls,language",
|
||||
[(_GroqProvider, "ko"), (_OpenAIProvider, "en")],
|
||||
ids=["groq", "openai"],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcription_provider_includes_language(tmp_path, provider_cls, language):
|
||||
"""Provider must include the 'language' field in multipart body when set."""
|
||||
audio = tmp_path / "sample.wav"
|
||||
audio.write_bytes(b"audio")
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
with patch("nanobot.providers.transcription.httpx.AsyncClient", return_value=_stub_async_client(captured)):
|
||||
provider = provider_cls(api_key="k", language=language)
|
||||
result = await provider.transcribe(audio)
|
||||
|
||||
assert result == "hello"
|
||||
assert captured["files"]["language"] == (None, "ko")
|
||||
assert captured["files"]["language"] == (None, language)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_cls",
|
||||
[_GroqProvider, _OpenAIProvider],
|
||||
ids=["groq", "openai"],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcription_provider_omits_language_when_none(tmp_path, provider_cls):
|
||||
"""When language is not set, the 'language' key must be absent from the multipart body."""
|
||||
audio = tmp_path / "sample.wav"
|
||||
audio.write_bytes(b"audio")
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
with patch("nanobot.providers.transcription.httpx.AsyncClient", return_value=_stub_async_client(captured)):
|
||||
provider = provider_cls(api_key="k")
|
||||
result = await provider.transcribe(audio)
|
||||
|
||||
assert result == "hello"
|
||||
assert "language" not in captured["files"]
|
||||
|
||||
|
||||
def test_channels_login_uses_discovered_plugin_class(monkeypatch):
|
||||
@ -530,6 +564,24 @@ def test_channels_config_send_max_retries_upper_bound():
|
||||
ChannelsConfig(send_max_retries=11)
|
||||
|
||||
|
||||
def test_channels_config_transcription_language_pattern():
|
||||
"""transcription_language must match ISO-639 format (2-3 lowercase letters) or be None."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
# Valid values
|
||||
assert ChannelsConfig(transcription_language="en").transcription_language == "en"
|
||||
assert ChannelsConfig(transcription_language="kor").transcription_language == "kor"
|
||||
assert ChannelsConfig(transcription_language=None).transcription_language is None
|
||||
|
||||
# Invalid values
|
||||
with pytest.raises(ValidationError):
|
||||
ChannelsConfig(transcription_language="EN") # uppercase
|
||||
with pytest.raises(ValidationError):
|
||||
ChannelsConfig(transcription_language="english") # full word
|
||||
with pytest.raises(ValidationError):
|
||||
ChannelsConfig(transcription_language="en-US") # BCP 47 tag
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_with_retry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user