# Mirror of https://github.com/HKUDS/nanobot.git (synced 2026-06-15 07:14:08 +00:00) -- 428 lines, 14 KiB, Python
"""Tests for StepFun ASR SSE transcription provider."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from nanobot.audio.transcription_registry import (
|
|
get_transcription_provider,
|
|
transcription_provider_names,
|
|
)
|
|
from nanobot.config.schema import Config
|
|
from nanobot.providers.transcription import StepFunTranscriptionProvider
|
|
|
|
|
|
@pytest.fixture
def audio_file(tmp_path: Path) -> Path:
    """Write a small placeholder ``voice.ogg`` into pytest's temp dir and return its path."""
    voice_path = tmp_path.joinpath("voice.ogg")
    # The payload is not decodable audio; the tests only need a file that exists.
    voice_path.write_bytes(b"OggS\x00fake-audio-bytes")
    return voice_path
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Defaults and base normalization
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_stepfun_defaults() -> None:
    """A provider built from just an API key exposes the default endpoint and model."""
    expected_url = "https://api.stepfun.com/v1/audio/asr/sse"
    expected_model = "stepaudio-2.5-asr"

    stepfun = StepFunTranscriptionProvider(api_key="sk-test")

    assert stepfun.api_url == expected_url
    assert stepfun.model == expected_model
|
|
|
|
|
|
def test_stepfun_api_base_overrides_url() -> None:
    """An ``api_base`` that already ends in the ASR SSE path is expected to be used as-is."""
    full_endpoint = "https://api.stepfun.com/step_plan/v1/audio/asr/sse"

    stepfun = StepFunTranscriptionProvider(api_key="sk-test", api_base=full_endpoint)

    assert stepfun.api_url == full_endpoint
|
|
|
|
|
|
def test_stepfun_api_base_appends_asr_path() -> None:
    """An ``api_base`` that stops at the version prefix is expected to gain ``/audio/asr/sse``."""
    version_prefix = "https://api.stepfun.com/step_plan/v1"

    stepfun = StepFunTranscriptionProvider(api_key="sk-test", api_base=version_prefix)

    assert stepfun.api_url == "https://api.stepfun.com/step_plan/v1/audio/asr/sse"
|
|
|
|
|
|
def test_stepfun_custom_model() -> None:
    """An explicitly passed ``model`` is what the provider reports back."""
    chosen_model = "stepaudio-2-asr-pro"
    stepfun = StepFunTranscriptionProvider(api_key="sk-test", model=chosen_model)
    assert stepfun.model == chosen_model
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Short-circuit: missing key / missing file
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_missing_api_key_short_circuits(audio_file: Path) -> None:
    """Without any API key, ``transcribe`` must return "" and never open a stream."""
    fake_stream = MagicMock()

    # The environment stays cleared for the whole call so that no ambient
    # variable can supply a key, either at construction or at transcription.
    with patch.dict("os.environ", {}, clear=True), patch(
        "httpx.AsyncClient.stream", fake_stream
    ):
        stepfun = StepFunTranscriptionProvider(api_key=None)
        transcript = await stepfun.transcribe(audio_file)

    assert transcript == ""
    fake_stream.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_missing_file_short_circuits() -> None:
    """A path that does not exist must yield "" without any HTTP stream being opened.

    The test deliberately requests no audio fixture: it passes a path that is
    expected to be absent, so creating a real file would serve no purpose.
    """
    provider = StepFunTranscriptionProvider(api_key="sk-test")
    stream_mock = MagicMock()

    with patch("httpx.AsyncClient.stream", stream_mock):
        transcript = await provider.transcribe("/nonexistent/path/voice.ogg")

    assert transcript == ""
    stream_mock.assert_not_called()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# SSE stream parsing: happy path
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_sse_delta_then_done(audio_file: Path) -> None:
    """Replay a delta, delta, done event sequence and expect the ``done`` text back."""
    payloads = [
        {"type": "transcript.text.delta", "session_id": "s1", "text": partial}
        for partial in ("你", "你好")
    ]
    payloads.append({"type": "transcript.text.done", "session_id": "s1", "text": "你好世界"})
    sse_lines = ["data: " + json.dumps(payload) for payload in payloads]

    fake_stream = _make_stream_cm(200, sse_lines)
    stepfun = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        transcript = await stepfun.transcribe(audio_file)

    assert transcript == "你好世界"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_sse_only_done_event(audio_file: Path) -> None:
    """A lone ``transcript.text.done`` event, with no deltas before it, supplies the text."""
    done_event = {"type": "transcript.text.done", "session_id": "s1", "text": "hello world"}
    sse_lines = ["data: " + json.dumps(done_event)]

    fake_stream = _make_stream_cm(200, sse_lines)
    stepfun = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        transcript = await stepfun.transcribe(audio_file)

    assert transcript == "hello world"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_sse_error_event(audio_file: Path) -> None:
    """An ``error`` event arriving on a 200 stream is expected to produce ""."""
    error_event = {"type": "error", "session_id": "s1", "message": "audio too short"}
    sse_lines = ["data: " + json.dumps(error_event)]

    fake_stream = _make_stream_cm(200, sse_lines)
    stepfun = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        transcript = await stepfun.transcribe(audio_file)

    assert transcript == ""
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_sse_ignores_non_data_lines(audio_file: Path) -> None:
    """Blank lines and lines lacking the ``data:`` prefix must not affect the result."""
    done_event = {"type": "transcript.text.done", "session_id": "s1", "text": "result"}
    noise = [
        "",  # empty line
        "event: session.start",  # non-data event
    ]
    sse_lines = [*noise, "data: " + json.dumps(done_event)]

    fake_stream = _make_stream_cm(200, sse_lines)
    stepfun = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        transcript = await stepfun.transcribe(audio_file)

    assert transcript == "result"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_sse_malformed_json_skipped(audio_file: Path) -> None:
    """A ``data:`` line holding invalid JSON is expected to be skipped, not to abort the stream."""
    done_event = {"type": "transcript.text.done", "session_id": "s1", "text": "ok"}
    sse_lines = [
        "data: not-json-at-all",
        "data: " + json.dumps(done_event),
    ]

    fake_stream = _make_stream_cm(200, sse_lines)
    stepfun = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        transcript = await stepfun.transcribe(audio_file)

    assert transcript == "ok"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Retry contract
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_retries_on_503_then_succeeds(audio_file: Path) -> None:
    """A first 503 followed by a good SSE stream should still end in the transcript."""
    done_event = {"type": "transcript.text.done", "session_id": "s1", "text": "ok"}
    good_lines = ["data: " + json.dumps(done_event)]

    # Call 1 gets a failing 503 response; call 2 gets the successful line stream.
    fake_stream = _make_stream_cm_sequence([503, good_lines])
    no_wait = AsyncMock()  # stands in for asyncio.sleep so the test never really waits

    stepfun = StepFunTranscriptionProvider(api_key="sk-test")
    with patch("httpx.AsyncClient.stream", fake_stream), patch("asyncio.sleep", no_wait):
        transcript = await stepfun.transcribe(audio_file)

    assert transcript == "ok"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_gives_up_after_max_retries(audio_file: Path) -> None:
    """When every attempt answers 503, the provider is expected to end up returning ""."""
    always_failing: list[list[str] | int] = [503] * 4  # 4 failing HTTP responses
    fake_stream = _make_stream_cm_sequence(always_failing)
    no_wait = AsyncMock()  # stands in for asyncio.sleep so the test never really waits

    stepfun = StepFunTranscriptionProvider(api_key="sk-test")
    with patch("httpx.AsyncClient.stream", fake_stream), patch("asyncio.sleep", no_wait):
        transcript = await stepfun.transcribe(audio_file)

    assert transcript == ""
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_sse_empty_text_done_returns_empty(audio_file: Path) -> None:
    """Empty text in transcript.text.done should return "" immediately, not retry.

    Both halves of that contract are checked: the returned value, and the
    absence of a retry (exactly one stream call, no back-off sleep awaited).
    """
    events = [
        {"type": "transcript.text.done", "session_id": "s1", "text": ""},
    ]
    lines = [f"data: {json.dumps(e)}" for e in events]

    provider = StepFunTranscriptionProvider(api_key="sk-test")
    stream_cm = _make_stream_cm(200, lines)
    sleep = AsyncMock()  # replaces asyncio.sleep so any back-off would be observable

    with patch("httpx.AsyncClient.stream", stream_cm), patch("asyncio.sleep", sleep):
        result = await provider.transcribe(audio_file)

    assert result == ""
    # "not retry": a second stream call or an awaited sleep would mean a retry happened.
    assert stream_cm.call_count == 1
    sleep.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_401_returns_empty_without_retry(audio_file: Path) -> None:
    """A 401 must give "" after a single attempt: bad credentials are not retryable."""
    fake_stream = _make_stream_cm(401, [])
    fake_sleep = AsyncMock()
    stepfun = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream), patch("asyncio.sleep", fake_sleep):
        transcript = await stepfun.transcribe(audio_file)

    assert transcript == ""
    # One stream call and no awaited sleep together show that nothing was retried.
    assert fake_stream.call_count == 1
    fake_sleep.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_retries_on_connect_error(audio_file: Path) -> None:
    """A connection failure on the first attempt is retried and the second attempt succeeds."""
    done_event = {"type": "transcript.text.done", "session_id": "s1", "text": "ok"}
    good_lines = ["data: " + json.dumps(done_event)]
    attempts = 0

    class FakeResponse:
        """Doubles as the async context manager returned by stream() and as
        the response object bound by `async with ... as resp`."""

        status_code = 200
        reason_phrase = "OK"

        async def __aenter__(self) -> "FakeResponse":
            return self

        async def __aexit__(self, *exc: object) -> None:
            return None

        async def aiter_lines(self) -> Any:
            for sse_line in good_lines:
                yield sse_line

        def raise_for_status(self) -> None:
            return None

    def fake_stream(method: str, url: str, *args: object, **kwargs: object) -> FakeResponse:
        # NOTE(review): this plain function is patched onto the class, so when it is
        # called on a client instance it binds like a method and the instance lands
        # in `method`; the parameter names are therefore shifted by one position.
        nonlocal attempts
        attempts += 1
        if attempts > 1:
            return FakeResponse()
        # Only the very first attempt fails at the network level.
        raise httpx.ConnectError("boom")

    stepfun = StepFunTranscriptionProvider(api_key="sk-test")
    with patch("httpx.AsyncClient.stream", fake_stream), patch("asyncio.sleep", AsyncMock()):
        transcript = await stepfun.transcribe(audio_file)

    assert transcript == "ok"
    assert attempts == 2
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Registry integration
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_stepfun_in_registry() -> None:
    """The registry lists ``stepfun`` and its spec carries the expected model and adapter."""
    registered_names = transcription_provider_names()
    assert "stepfun" in registered_names

    stepfun_spec = get_transcription_provider("stepfun")
    assert stepfun_spec is not None
    assert stepfun_spec.default_model == "stepaudio-2.5-asr"

    expected_adapter = "nanobot.providers.transcription:StepFunTranscriptionProvider"
    assert stepfun_spec.adapter == expected_adapter
|
|
|
|
|
|
def test_config_resolves_stepfun() -> None:
    """Resolving a config that selects StepFun returns the values set on it, marked configured."""
    endpoint = "https://api.stepfun.com/step_plan/v1/audio/asr/sse"

    cfg = Config()
    cfg.transcription.provider = "stepfun"
    cfg.transcription.model = "stepaudio-2.5-asr"
    cfg.transcription.language = "zh"
    cfg.providers.stepfun.api_key = "step-test"
    cfg.providers.stepfun.api_base = endpoint

    # Imported inside the test body, exactly where the original performs it.
    from nanobot.audio.transcription import resolve_transcription_config

    outcome = resolve_transcription_config(cfg)

    assert outcome.provider == "stepfun"
    assert outcome.model == "stepaudio-2.5-asr"
    assert outcome.language == "zh"
    assert outcome.api_key == "step-test"
    assert outcome.api_base == endpoint
    assert outcome.configured is True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_stream_cm(status: int, lines: list[str]) -> MagicMock:
    """Build a mock for `AsyncClient.stream` that yields *lines* as SSE.

    Args:
        status: HTTP status code the fake response reports.
        lines: Raw lines handed out, in order, by ``aiter_lines()``.

    Returns:
        A ``MagicMock`` whose every call returns the same fake response. Being
        a mock, it records its calls, so tests can inspect ``call_count``.
    """

    class FakeResponse:
        """Acts as both the async context manager and the response it yields."""

        def __init__(self) -> None:
            self.status_code = status
            # Same threshold as raise_for_status() below, so a status that is
            # accepted there (e.g. 204) is never labelled "Error" here.
            self.reason_phrase = "OK" if status < 400 else "Error"

        async def __aenter__(self) -> "FakeResponse":
            return self

        async def __aexit__(self, *exc: object) -> None:
            return None

        async def aiter_lines(self) -> Any:
            for line in lines:
                yield line

        def raise_for_status(self) -> None:
            """Raise ``httpx.HTTPStatusError`` for 4xx/5xx statuses; otherwise do nothing."""
            if self.status_code < 400:
                return
            raise httpx.HTTPStatusError(
                f"HTTP {self.status_code}",
                request=httpx.Request("POST", "https://example.test"),
                response=httpx.Response(self.status_code),
            )

    return MagicMock(return_value=FakeResponse())
|
|
|
|
|
|
def _make_stream_cm_sequence(statuses: list[str | int]) -> MagicMock:
|
|
"""Build a stream mock that fails with HTTP status ints, then succeeds with SSE lines.
|
|
|
|
Entries in *statuses* that are ints produce a stream that raises HTTPStatusError
|
|
after `raise_for_status()`. The final entry (a list of SSE lines) succeeds.
|
|
"""
|
|
remaining = list(statuses)
|
|
|
|
class FakeResponse:
|
|
def __init__(self, lines: list[str]) -> None:
|
|
self._lines = lines
|
|
self.status_code = 200
|
|
self.reason_phrase = "OK"
|
|
|
|
async def __aenter__(self) -> "FakeResponse":
|
|
return self
|
|
|
|
async def __aexit__(self, *exc: object) -> None:
|
|
pass
|
|
|
|
async def aiter_lines(self) -> Any:
|
|
for line in self._lines:
|
|
yield line
|
|
|
|
def raise_for_status(self) -> None:
|
|
pass
|
|
|
|
class FailingResponse:
|
|
def __init__(self, status: int) -> None:
|
|
self.status_code = status
|
|
self.reason_phrase = "Error"
|
|
|
|
async def __aenter__(self) -> "FailingResponse":
|
|
return self
|
|
|
|
async def __aexit__(self, *exc: object) -> None:
|
|
pass
|
|
|
|
async def aiter_lines(self) -> Any:
|
|
yield ""
|
|
return
|
|
|
|
def raise_for_status(self) -> None:
|
|
raise httpx.HTTPStatusError(
|
|
f"HTTP {self.status_code}",
|
|
request=httpx.Request("POST", "https://example.test"),
|
|
response=httpx.Response(self.status_code),
|
|
)
|
|
|
|
call_count = [0]
|
|
|
|
def _next(method: str, url: str, **kwargs: object) -> Any:
|
|
idx = min(call_count[0], len(remaining) - 1)
|
|
entry = remaining[idx]
|
|
call_count[0] += 1
|
|
if isinstance(entry, int):
|
|
return FailingResponse(entry)
|
|
return FakeResponse(entry)
|
|
|
|
cm = MagicMock(side_effect=_next)
|
|
return cm
|