nanobot/tests/providers/test_stepfun_asr.py

428 lines
14 KiB
Python

"""Tests for StepFun ASR SSE transcription provider."""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from nanobot.audio.transcription_registry import (
get_transcription_provider,
transcription_provider_names,
)
from nanobot.config.schema import Config
from nanobot.providers.transcription import StepFunTranscriptionProvider
@pytest.fixture
def audio_file(tmp_path: Path) -> Path:
    """Create ``voice.ogg`` under pytest's temp dir with placeholder bytes and return it."""
    target = tmp_path / "voice.ogg"
    # The content is not real audio; it is only a short byte string starting with "OggS".
    target.write_bytes(b"OggS\x00fake-audio-bytes")
    return target
# ---------------------------------------------------------------------------
# Defaults and base normalization
# ---------------------------------------------------------------------------
def test_stepfun_defaults() -> None:
    """A provider built with only an API key exposes the default URL and model."""
    expected_url = "https://api.stepfun.com/v1/audio/asr/sse"
    expected_model = "stepaudio-2.5-asr"

    stepfun = StepFunTranscriptionProvider(api_key="sk-test")

    assert stepfun.api_url == expected_url
    assert stepfun.model == expected_model
def test_stepfun_api_base_overrides_url() -> None:
    """An ``api_base`` already ending in the ASR SSE path becomes ``api_url`` unchanged."""
    full_endpoint = "https://api.stepfun.com/step_plan/v1/audio/asr/sse"

    stepfun = StepFunTranscriptionProvider(api_key="sk-test", api_base=full_endpoint)

    assert stepfun.api_url == "https://api.stepfun.com/step_plan/v1/audio/asr/sse"
def test_stepfun_api_base_appends_asr_path() -> None:
    """An ``api_base`` stopping at ``/v1`` gets ``/audio/asr/sse`` appended."""
    versioned_base = "https://api.stepfun.com/step_plan/v1"

    stepfun = StepFunTranscriptionProvider(api_key="sk-test", api_base=versioned_base)

    assert stepfun.api_url == "https://api.stepfun.com/step_plan/v1/audio/asr/sse"
def test_stepfun_custom_model() -> None:
    """An explicit ``model`` argument is stored on the provider as given."""
    chosen_model = "stepaudio-2-asr-pro"
    stepfun = StepFunTranscriptionProvider(api_key="sk-test", model=chosen_model)
    assert stepfun.model == "stepaudio-2-asr-pro"
# ---------------------------------------------------------------------------
# Short-circuit: missing key / missing file
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_missing_api_key_short_circuits(audio_file: Path) -> None:
    """With ``api_key=None``, ``transcribe`` returns "" and the stream is never opened."""
    # Construct under an emptied environment; presumably this stops an ambient
    # variable from standing in for the missing key (confirm against provider).
    with patch.dict("os.environ", {}, clear=True):
        keyless = StepFunTranscriptionProvider(api_key=None)

    stream_spy = MagicMock()
    with patch("httpx.AsyncClient.stream", stream_spy):
        transcript = await keyless.transcribe(audio_file)
        assert transcript == ""
    stream_spy.assert_not_called()
@pytest.mark.asyncio
async def test_missing_file_short_circuits(audio_file: Path) -> None:
    """A path that does not exist yields "" and the stream is never opened."""
    absent_path = "/nonexistent/path/voice.ogg"
    stepfun = StepFunTranscriptionProvider(api_key="sk-test")

    stream_spy = MagicMock()
    with patch("httpx.AsyncClient.stream", stream_spy):
        transcript = await stepfun.transcribe(absent_path)
        assert transcript == ""
    stream_spy.assert_not_called()
# ---------------------------------------------------------------------------
# SSE stream parsing: happy path
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_sse_delta_then_done(audio_file: Path) -> None:
    """Two delta events followed by a done event resolve to the done event's text."""
    sequence = [
        {"type": "transcript.text.delta", "session_id": "s1", "text": ""},
        {"type": "transcript.text.delta", "session_id": "s1", "text": "你好"},
        {"type": "transcript.text.done", "session_id": "s1", "text": "你好世界"},
    ]
    # One "data: <json>" line per event.
    sse_lines = ["data: " + json.dumps(event) for event in sequence]

    stepfun = StepFunTranscriptionProvider(api_key="sk-test")
    with patch("httpx.AsyncClient.stream", _make_stream_cm(200, sse_lines)):
        transcript = await stepfun.transcribe(audio_file)

    assert transcript == "你好世界"
@pytest.mark.asyncio
async def test_sse_only_done_event(audio_file: Path) -> None:
    """A stream holding just one done event (no deltas) returns that event's text."""
    done_event = {"type": "transcript.text.done", "session_id": "s1", "text": "hello world"}
    sse_lines = [f"data: {json.dumps(done_event)}"]

    stepfun = StepFunTranscriptionProvider(api_key="sk-test")
    with patch("httpx.AsyncClient.stream", _make_stream_cm(200, sse_lines)):
        transcript = await stepfun.transcribe(audio_file)

    assert transcript == "hello world"
@pytest.mark.asyncio
async def test_sse_error_event(audio_file: Path) -> None:
    """A stream whose only event has type ``error`` makes ``transcribe`` return ""."""
    error_event = {"type": "error", "session_id": "s1", "message": "audio too short"}
    sse_lines = [f"data: {json.dumps(error_event)}"]

    stepfun = StepFunTranscriptionProvider(api_key="sk-test")
    with patch("httpx.AsyncClient.stream", _make_stream_cm(200, sse_lines)):
        transcript = await stepfun.transcribe(audio_file)

    assert transcript == ""
@pytest.mark.asyncio
async def test_sse_ignores_non_data_lines(audio_file: Path) -> None:
    """A blank line and an ``event:`` line before the data line do not affect the result."""
    done_event = {"type": "transcript.text.done", "session_id": "s1", "text": "result"}
    noise = ["", "event: session.start"]  # blank line, then a non-"data:" line
    stream_lines = [*noise, f"data: {json.dumps(done_event)}"]

    stepfun = StepFunTranscriptionProvider(api_key="sk-test")
    with patch("httpx.AsyncClient.stream", _make_stream_cm(200, stream_lines)):
        transcript = await stepfun.transcribe(audio_file)

    assert transcript == "result"
@pytest.mark.asyncio
async def test_sse_malformed_json_skipped(audio_file: Path) -> None:
    """A data line that is not valid JSON is skipped; the later done event still wins."""
    done_event = {"type": "transcript.text.done", "session_id": "s1", "text": "ok"}
    stream_lines = [
        "data: not-json-at-all",
        "data: " + json.dumps(done_event),
    ]

    stepfun = StepFunTranscriptionProvider(api_key="sk-test")
    with patch("httpx.AsyncClient.stream", _make_stream_cm(200, stream_lines)):
        transcript = await stepfun.transcribe(audio_file)

    assert transcript == "ok"
# ---------------------------------------------------------------------------
# Retry contract
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_retries_on_503_then_succeeds(audio_file: Path) -> None:
    """A 503 on the first attempt followed by a good SSE stream yields the text."""
    done_event = {"type": "transcript.text.done", "session_id": "s1", "text": "ok"}
    good_stream = [f"data: {json.dumps(done_event)}"]
    # Scripted attempts: the first call gets a failing 503 response, the
    # second gets a 200 response serving ``good_stream``.
    scripted_stream = _make_stream_cm_sequence([503, good_stream])
    stepfun = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", scripted_stream):
        # asyncio.sleep is replaced with an AsyncMock for the duration of the call.
        with patch("asyncio.sleep", AsyncMock()):
            transcript = await stepfun.transcribe(audio_file)

    assert transcript == "ok"
@pytest.mark.asyncio
async def test_gives_up_after_max_retries(audio_file: Path) -> None:
    """Persistent 503 responses end with "" after more than one attempt."""
    # Four failing entries; the sequence helper keeps replaying the last
    # entry if the provider asks for more attempts than are scripted.
    attempts: list[list[str] | int] = [503, 503, 503, 503]
    stream_cm = _make_stream_cm_sequence(attempts)
    provider = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", stream_cm), patch(
        "asyncio.sleep", AsyncMock()
    ):
        result = await provider.transcribe(audio_file)

    assert result == ""
    # Guard against a false pass: "" would also come back if the provider
    # bailed out on the very first 503 without retrying at all.
    assert stream_cm.call_count > 1
@pytest.mark.asyncio
async def test_sse_empty_text_done_returns_empty(audio_file: Path) -> None:
    """A done event with empty text returns "" after exactly one stream attempt."""
    events = [
        {"type": "transcript.text.done", "session_id": "s1", "text": ""},
    ]
    lines = [f"data: {json.dumps(e)}" for e in events]
    provider = StepFunTranscriptionProvider(api_key="sk-test")
    stream_cm = _make_stream_cm(200, lines)

    with patch("httpx.AsyncClient.stream", stream_cm), patch(
        "asyncio.sleep", AsyncMock()
    ):
        result = await provider.transcribe(audio_file)

    assert result == ""
    # The mock serves the same response on every call, so a provider that
    # retried on empty text would still end up returning "". Counting the
    # calls is what actually pins the "no retry" half of the contract.
    assert stream_cm.call_count == 1
@pytest.mark.asyncio
async def test_401_returns_empty_without_retry(audio_file: Path) -> None:
    """A 401 response yields "" after a single attempt and no awaited sleep."""
    unauthorized_stream = _make_stream_cm(401, [])
    sleep_spy = AsyncMock()
    stepfun = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", unauthorized_stream):
        with patch("asyncio.sleep", sleep_spy):
            transcript = await stepfun.transcribe(audio_file)

    assert transcript == ""
    # Exactly one stream call and no awaited sleep: the request was not retried.
    assert unauthorized_stream.call_count == 1
    sleep_spy.assert_not_awaited()
@pytest.mark.asyncio
async def test_retries_on_connect_error(audio_file: Path) -> None:
    """A ``ConnectError`` on the first attempt is retried and the second attempt succeeds."""
    done_event = {"type": "transcript.text.done", "session_id": "s1", "text": "ok"}
    success_lines = [f"data: {json.dumps(done_event)}"]
    attempts = 0

    class FakeResponse:
        """Serves as both the async context manager returned by stream()
        and the response object bound in `async with ... as resp`."""

        status_code = 200
        reason_phrase = "OK"

        async def __aenter__(self) -> "FakeResponse":
            return self

        async def __aexit__(self, *exc: object) -> None:
            pass

        async def aiter_lines(self) -> Any:
            for line in success_lines:
                yield line

        def raise_for_status(self) -> None:
            pass

    def fake_stream(*_args: object, **_kwargs: object) -> FakeResponse:
        # Patched onto the class as a plain function, this is looked up as a
        # bound method, so the client instance arrives as the first positional
        # argument ahead of the HTTP method and URL. Accepting arbitrary
        # arguments keeps the stub independent of how the provider spells the
        # call (positional or keyword).
        nonlocal attempts
        attempts += 1
        if attempts == 1:
            raise httpx.ConnectError("boom")
        return FakeResponse()

    provider = StepFunTranscriptionProvider(api_key="sk-test")
    with patch("httpx.AsyncClient.stream", fake_stream), patch(
        "asyncio.sleep", AsyncMock()
    ):
        result = await provider.transcribe(audio_file)

    assert result == "ok"
    # One failed attempt plus one successful attempt.
    assert attempts == 2
# ---------------------------------------------------------------------------
# Registry integration
# ---------------------------------------------------------------------------
def test_stepfun_in_registry() -> None:
    """The registry lists ``stepfun`` and its spec names the expected model and adapter."""
    registered = transcription_provider_names()
    assert "stepfun" in registered

    stepfun_spec = get_transcription_provider("stepfun")
    assert stepfun_spec is not None
    assert stepfun_spec.default_model == "stepaudio-2.5-asr"
    # The adapter is recorded as a "module:ClassName" string.
    adapter_path = "nanobot.providers.transcription:StepFunTranscriptionProvider"
    assert stepfun_spec.adapter == adapter_path
def test_config_resolves_stepfun() -> None:
    """Transcription and StepFun provider settings on ``Config`` come back resolved."""
    model_name = "stepaudio-2.5-asr"
    endpoint = "https://api.stepfun.com/step_plan/v1/audio/asr/sse"

    cfg = Config()
    transcription_cfg = cfg.transcription
    transcription_cfg.provider = "stepfun"
    transcription_cfg.model = model_name
    transcription_cfg.language = "zh"
    stepfun_cfg = cfg.providers.stepfun
    stepfun_cfg.api_key = "step-test"
    stepfun_cfg.api_base = endpoint

    # Imported inside the test, exactly where the original performs it.
    from nanobot.audio.transcription import resolve_transcription_config

    outcome = resolve_transcription_config(cfg)

    assert outcome.provider == "stepfun"
    assert outcome.model == "stepaudio-2.5-asr"
    assert outcome.language == "zh"
    assert outcome.api_key == "step-test"
    assert outcome.api_base == "https://api.stepfun.com/step_plan/v1/audio/asr/sse"
    assert outcome.configured is True
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_stream_cm(status: int, lines: list[str]) -> MagicMock:
    """Return a stand-in for `AsyncClient.stream` serving *lines* with *status*.

    Every call to the returned mock hands back the same fake response object,
    which acts both as the async context manager and as the response it
    yields. ``raise_for_status()`` raises ``httpx.HTTPStatusError`` when
    *status* is 400 or above and does nothing otherwise.
    """

    class FakeResponse:
        # Fixed for the lifetime of the mock; read through the instance.
        status_code = status
        reason_phrase = "OK" if status == 200 else "Error"

        async def __aenter__(self) -> "FakeResponse":
            return self

        async def __aexit__(self, *exc: object) -> None:
            return None

        async def aiter_lines(self) -> Any:
            for entry in lines:
                yield entry

        def raise_for_status(self) -> None:
            if self.status_code < 400:
                return
            raise httpx.HTTPStatusError(
                f"HTTP {self.status_code}",
                request=httpx.Request("POST", "https://example.test"),
                response=httpx.Response(self.status_code),
            )

    return MagicMock(return_value=FakeResponse())
def _make_stream_cm_sequence(statuses: list[list[str] | int]) -> MagicMock:
    """Build an `AsyncClient.stream` mock that replays a scripted list of attempts.

    Each call to the returned mock consumes the next entry of *statuses*:

    * an ``int`` produces a response carrying that status code whose
      ``raise_for_status()`` raises ``httpx.HTTPStatusError``;
    * a ``list[str]`` produces a 200 response whose ``aiter_lines()`` yields
      those SSE lines.

    Once the script is exhausted, the final entry is replayed for every
    further call, so a script of only ints keeps failing indefinitely.

    Raises:
        ValueError: if *statuses* is empty (there would be nothing to serve).
    """
    # Snapshot the script so later mutation by the caller cannot change it.
    script = list(statuses)
    if not script:
        raise ValueError("statuses must contain at least one entry")

    class FakeResponse:
        """Successful (200) response serving a fixed list of SSE lines."""

        def __init__(self, lines: list[str]) -> None:
            self._lines = lines
            self.status_code = 200
            self.reason_phrase = "OK"

        async def __aenter__(self) -> "FakeResponse":
            return self

        async def __aexit__(self, *exc: object) -> None:
            pass

        async def aiter_lines(self) -> Any:
            for line in self._lines:
                yield line

        def raise_for_status(self) -> None:
            pass

    class FailingResponse:
        """Response with an error status; ``raise_for_status()`` always raises."""

        def __init__(self, status: int) -> None:
            self.status_code = status
            self.reason_phrase = "Error"

        async def __aenter__(self) -> "FailingResponse":
            return self

        async def __aexit__(self, *exc: object) -> None:
            pass

        async def aiter_lines(self) -> Any:
            # A single empty line: the body carries no SSE data.
            yield ""

        def raise_for_status(self) -> None:
            raise httpx.HTTPStatusError(
                f"HTTP {self.status_code}",
                request=httpx.Request("POST", "https://example.test"),
                response=httpx.Response(self.status_code),
            )

    calls_made = 0

    def _next(*_args: object, **_kwargs: object) -> Any:
        nonlocal calls_made
        # Clamp to the last entry so calls beyond the script repeat it.
        entry = script[min(calls_made, len(script) - 1)]
        calls_made += 1
        if isinstance(entry, int):
            return FailingResponse(entry)
        return FakeResponse(entry)

    return MagicMock(side_effect=_next)