mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-04 16:55:54 +00:00
fix(transcription): harden language parameter validation and tests
- Add ISO-639 pattern validation (2-3 lowercase letters) to schema
- Normalize empty language to None in provider constructors
- Extract shared httpx mock stubs, parameterize provider tests
- Add test for language=None omitting field from multipart body
- Add test for Pydantic pattern validation rejecting invalid codes
This commit is contained in:
parent
123d69bfb7
commit
f6a417e77d
@ -29,7 +29,7 @@ class ChannelsConfig(Base):
|
|||||||
send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…"))
|
send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…"))
|
||||||
send_max_retries: int = Field(default=3, ge=0, le=10) # Max delivery attempts (initial send included)
|
send_max_retries: int = Field(default=3, ge=0, le=10) # Max delivery attempts (initial send included)
|
||||||
transcription_provider: str = "groq" # Voice transcription backend: "groq" or "openai"
|
transcription_provider: str = "groq" # Voice transcription backend: "groq" or "openai"
|
||||||
transcription_language: str | None = None # Optional ISO-639-1 hint for audio transcription
|
transcription_language: str | None = Field(default=None, pattern=r"^[a-z]{2,3}$") # Optional ISO-639-1 hint for audio transcription
|
||||||
|
|
||||||
|
|
||||||
class DreamConfig(Base):
|
class DreamConfig(Base):
|
||||||
|
|||||||
@ -22,7 +22,7 @@ class OpenAITranscriptionProvider:
|
|||||||
or os.environ.get("OPENAI_TRANSCRIPTION_BASE_URL")
|
or os.environ.get("OPENAI_TRANSCRIPTION_BASE_URL")
|
||||||
or "https://api.openai.com/v1/audio/transcriptions"
|
or "https://api.openai.com/v1/audio/transcriptions"
|
||||||
)
|
)
|
||||||
self.language = language
|
self.language = language or None
|
||||||
|
|
||||||
async def transcribe(self, file_path: str | Path) -> str:
|
async def transcribe(self, file_path: str | Path) -> str:
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
@ -64,7 +64,7 @@ class GroqTranscriptionProvider:
|
|||||||
):
|
):
|
||||||
self.api_key = api_key or os.environ.get("GROQ_API_KEY")
|
self.api_key = api_key or os.environ.get("GROQ_API_KEY")
|
||||||
self.api_url = api_base or os.environ.get("GROQ_BASE_URL") or "https://api.groq.com/openai/v1/audio/transcriptions"
|
self.api_url = api_base or os.environ.get("GROQ_BASE_URL") or "https://api.groq.com/openai/v1/audio/transcriptions"
|
||||||
self.language = language
|
self.language = language or None
|
||||||
|
|
||||||
async def transcribe(self, file_path: str | Path) -> str:
|
async def transcribe(self, file_path: str | Path) -> str:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -334,21 +334,24 @@ async def test_base_channel_passes_language_to_groq_transcription_provider():
|
|||||||
assert captured["language"] == "ko"
|
assert captured["language"] == "ko"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
# ---------------------------------------------------------------------------
|
||||||
async def test_groq_transcription_provider_includes_language(tmp_path):
|
# Transcription provider HTTP tests
|
||||||
from nanobot.providers.transcription import GroqTranscriptionProvider
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
audio = tmp_path / "sample.wav"
|
from nanobot.providers.transcription import GroqTranscriptionProvider as _GroqProvider
|
||||||
audio.write_bytes(b"audio")
|
from nanobot.providers.transcription import OpenAITranscriptionProvider as _OpenAIProvider
|
||||||
captured: dict[str, object] = {}
|
|
||||||
|
|
||||||
class _Response:
|
|
||||||
def raise_for_status(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
def json(self):
|
class _StubResponse:
|
||||||
return {"text": "hello"}
|
def raise_for_status(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def json(self):
|
||||||
|
return {"text": "hello"}
|
||||||
|
|
||||||
|
|
||||||
|
def _stub_async_client(captured: dict[str, object]):
|
||||||
|
"""Return an httpx.AsyncClient stub that records POST calls into *captured*."""
|
||||||
class _AsyncClient:
|
class _AsyncClient:
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
return self
|
return self
|
||||||
@ -357,19 +360,50 @@ async def test_groq_transcription_provider_includes_language(tmp_path):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
async def post(self, url, headers=None, files=None, timeout=None):
|
async def post(self, url, headers=None, files=None, timeout=None):
|
||||||
captured["url"] = url
|
|
||||||
captured["headers"] = headers
|
|
||||||
captured["files"] = files
|
captured["files"] = files
|
||||||
captured["timeout"] = timeout
|
return _StubResponse()
|
||||||
return _Response()
|
|
||||||
|
|
||||||
provider = GroqTranscriptionProvider(api_key="k", language="ko")
|
return _AsyncClient()
|
||||||
|
|
||||||
with patch("nanobot.providers.transcription.httpx.AsyncClient", return_value=_AsyncClient()):
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"provider_cls,language",
|
||||||
|
[(_GroqProvider, "ko"), (_OpenAIProvider, "en")],
|
||||||
|
ids=["groq", "openai"],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_transcription_provider_includes_language(tmp_path, provider_cls, language):
|
||||||
|
"""Provider must include the 'language' field in multipart body when set."""
|
||||||
|
audio = tmp_path / "sample.wav"
|
||||||
|
audio.write_bytes(b"audio")
|
||||||
|
captured: dict[str, object] = {}
|
||||||
|
|
||||||
|
with patch("nanobot.providers.transcription.httpx.AsyncClient", return_value=_stub_async_client(captured)):
|
||||||
|
provider = provider_cls(api_key="k", language=language)
|
||||||
result = await provider.transcribe(audio)
|
result = await provider.transcribe(audio)
|
||||||
|
|
||||||
assert result == "hello"
|
assert result == "hello"
|
||||||
assert captured["files"]["language"] == (None, "ko")
|
assert captured["files"]["language"] == (None, language)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"provider_cls",
|
||||||
|
[_GroqProvider, _OpenAIProvider],
|
||||||
|
ids=["groq", "openai"],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_transcription_provider_omits_language_when_none(tmp_path, provider_cls):
|
||||||
|
"""When language is not set, the 'language' key must be absent from the multipart body."""
|
||||||
|
audio = tmp_path / "sample.wav"
|
||||||
|
audio.write_bytes(b"audio")
|
||||||
|
captured: dict[str, object] = {}
|
||||||
|
|
||||||
|
with patch("nanobot.providers.transcription.httpx.AsyncClient", return_value=_stub_async_client(captured)):
|
||||||
|
provider = provider_cls(api_key="k")
|
||||||
|
result = await provider.transcribe(audio)
|
||||||
|
|
||||||
|
assert result == "hello"
|
||||||
|
assert "language" not in captured["files"]
|
||||||
|
|
||||||
|
|
||||||
def test_channels_login_uses_discovered_plugin_class(monkeypatch):
|
def test_channels_login_uses_discovered_plugin_class(monkeypatch):
|
||||||
@ -530,6 +564,24 @@ def test_channels_config_send_max_retries_upper_bound():
|
|||||||
ChannelsConfig(send_max_retries=11)
|
ChannelsConfig(send_max_retries=11)
|
||||||
|
|
||||||
|
|
||||||
|
def test_channels_config_transcription_language_pattern():
|
||||||
|
"""transcription_language must match ISO-639 format (2-3 lowercase letters) or be None."""
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
# Valid values
|
||||||
|
assert ChannelsConfig(transcription_language="en").transcription_language == "en"
|
||||||
|
assert ChannelsConfig(transcription_language="kor").transcription_language == "kor"
|
||||||
|
assert ChannelsConfig(transcription_language=None).transcription_language is None
|
||||||
|
|
||||||
|
# Invalid values
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
ChannelsConfig(transcription_language="EN") # uppercase
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
ChannelsConfig(transcription_language="english") # full word
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
ChannelsConfig(transcription_language="en-US") # BCP 47 tag
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# _send_with_retry
|
# _send_with_retry
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user