From 793005834825e91a89b16474598e2015706435c8 Mon Sep 17 00:00:00 2001 From: moran Date: Tue, 9 Jun 2026 17:27:13 +0800 Subject: [PATCH] feat(asr): add StepFun ASR SSE transcription provider MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add StepFunTranscriptionProvider class in nanobot/providers/transcription.py - New _post_stepfun_asr_with_retry() function handling SSE stream parsing (transcript.text.delta → transcript.text.done event sequence) - Register 'stepfun' in transcription_registry.py with default model stepaudio-2.5-asr - Reuse existing stepfun provider config (apiBase can point to Plan endpoint) - Add 17 tests covering SSE parsing, retry contract, empty-text edge case, and registry integration - Update docs/configuration.md with stepfun ASR documentation StepFun ASR uses a dedicated SSE endpoint (/v1/audio/asr/sse) rather than the chat-completions or Whisper multipart formats used by other providers. Users on Step Plan can set apiBase to the Plan endpoint. --- docs/configuration.md | 6 +- nanobot/audio/transcription_registry.py | 5 + nanobot/config/schema.py | 2 +- nanobot/providers/transcription.py | 155 +++++++++ tests/providers/test_stepfun_asr.py | 418 ++++++++++++++++++++++++ 5 files changed, 582 insertions(+), 4 deletions(-) create mode 100644 tests/providers/test_stepfun_asr.py diff --git a/docs/configuration.md b/docs/configuration.md index 5bb54b53a..378b4bed6 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -239,7 +239,7 @@ Tracing covers the providers that go through nanobot's OpenAI-compatible client | `lm_studio` | LLM (local, LM Studio) | — | | `atomic_chat` | LLM (local, [Atomic Chat](https://atomic.chat/)) | — | | `mistral` | LLM | [docs.mistral.ai](https://docs.mistral.ai/) | -| `stepfun` | LLM (Step Fun/阶跃星辰) | [platform.stepfun.com](https://platform.stepfun.com) | +| `stepfun` | LLM (Step Fun/阶跃星辰) + Voice transcription (ASR) | [platform.stepfun.com](https://platform.stepfun.com) | | `ovms` | LLM (local, OpenVINO Model Server) | [docs.openvino.ai](https://docs.openvino.ai/2026/model-server/ovms_docs_llm_quickstart.html) | | `vllm` | LLM (local, any OpenAI-compatible server) | — | | `nvidia` | LLM (NVIDIA NIM) | [build.nvidia.com](https://build.nvidia.com/) | @@ -1294,8 +1294,8 @@ Configure transcription under the top-level `transcription` section: | Setting | Default | Description | |---------|---------|-------------| | `enabled` | `true` | Enables audio transcription for both chat-channel voice messages and WebUI/desktop microphone input. | -| `provider` | `"groq"` | Transcription backend: `"groq"`, `"openai"`, `"openrouter"`, `"xiaomi_mimo"`, or `"assemblyai"`. | -| `model` | provider default | Optional transcription model override. Defaults to `whisper-large-v3` for Groq, `whisper-1` for OpenAI, `openai/whisper-1` for OpenRouter, `mimo-v2.5-asr` for Xiaomi MiMo ASR, and `universal-3-pro,universal-2` for AssemblyAI. OpenRouter accepts only speech-to-text models on its transcription endpoint, such as `nvidia/parakeet-tdt-0.6b-v3`, `openai/whisper-1`, or `openai/gpt-4o-transcribe`; chat LLMs are rejected there. AssemblyAI accepts a comma-separated model fallback list. | +| `provider` | `"groq"` | Transcription backend: `"groq"`, `"openai"`, `"openrouter"`, `"xiaomi_mimo"`, `"stepfun"`, or `"assemblyai"`. | +| `model` | provider default | Optional transcription model override. Defaults to `whisper-large-v3` for Groq, `whisper-1` for OpenAI, `openai/whisper-1` for OpenRouter, `mimo-v2.5-asr` for Xiaomi MiMo ASR, `stepaudio-2.5-asr` for StepFun ASR, and `universal-3-pro,universal-2` for AssemblyAI. OpenRouter accepts only speech-to-text models on its transcription endpoint, such as `nvidia/parakeet-tdt-0.6b-v3`, `openai/whisper-1`, or `openai/gpt-4o-transcribe`; chat LLMs are rejected there. AssemblyAI accepts a comma-separated model fallback list. | | `language` | `null` | Optional ISO-639 language hint, e.g. `"en"`, `"zh"`, `"ko"`, or `"ja"`. | | `maxDurationSec` | `120` | Maximum WebUI/desktop recording duration. | | `maxUploadMb` | `25` | Maximum WebUI/desktop audio upload size. | diff --git a/nanobot/audio/transcription_registry.py b/nanobot/audio/transcription_registry.py index 3cea122fb..ed4208a1a 100644 --- a/nanobot/audio/transcription_registry.py +++ b/nanobot/audio/transcription_registry.py @@ -64,6 +64,11 @@ TRANSCRIPTION_PROVIDERS: tuple[TranscriptionProviderSpec, ...] = ( adapter="nanobot.providers.transcription:XiaomiMiMoTranscriptionProvider", aliases=("mimo", "xiaomi"), ), + TranscriptionProviderSpec( + name="stepfun", + default_model="stepaudio-2.5-asr", + adapter="nanobot.providers.transcription:StepFunTranscriptionProvider", + ), TranscriptionProviderSpec( name="assemblyai", default_model="universal-3-pro,universal-2", diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 53a8eacd5..ac69f8a28 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -219,7 +219,7 @@ class ProvidersConfig(Base): minimax: ProviderConfig = Field(default_factory=ProviderConfig) minimax_anthropic: ProviderConfig = Field(default_factory=ProviderConfig) # MiniMax Anthropic endpoint (thinking) mistral: ProviderConfig = Field(default_factory=ProviderConfig) - stepfun: ProviderConfig = Field(default_factory=ProviderConfig) # Step Fun (阶跃星辰) + stepfun: ProviderConfig = Field(default_factory=ProviderConfig) # Step Fun (阶跃星辰) — LLM + ASR (set apiBase to Plan URL for ASR) xiaomi_mimo: ProviderConfig = Field(default_factory=ProviderConfig) # Xiaomi MIMO (小米) longcat: ProviderConfig = Field(default_factory=ProviderConfig) # LongCat ant_ling: ProviderConfig = Field(default_factory=ProviderConfig) # Ant Ling diff --git a/nanobot/providers/transcription.py b/nanobot/providers/transcription.py index f2b7051c3..9df6a6a8d 100644 --- a/nanobot/providers/transcription.py +++ b/nanobot/providers/transcription.py @@ -8,6 +8,7 @@ WebUI upload validation, and channel integration live in import asyncio import base64 +import json import mimetypes import os from collections.abc import Callable @@ -306,6 +307,119 @@ async def _post_xiaomi_mimo_asr_with_retry( return await _post_with_retry(build_request, provider_label, _text_from_chat_payload) +async def _post_stepfun_asr_with_retry( + url: str, + *, + api_key: str | None, + path: Path, + model: str, + provider_label: str, + language: str | None = None, +) -> str: + """POST audio to StepFun ASR SSE endpoint and collect final text.""" + try: + data = path.read_bytes() + except OSError as e: + logger.exception("{} transcription error: cannot read audio file: {}", provider_label, e) + return "" + + suffix = path.suffix.lstrip(".").lower() + audio_type = suffix if suffix in ("ogg", "mp3", "wav", "pcm") else "wav" + + body: dict[str, Any] = { + "audio": { + "data": base64.b64encode(data).decode("ascii"), + "input": { + "transcription": { + "model": model, + "enable_itn": True, + }, + "format": {"type": audio_type}, + }, + }, + } + if language: + body["audio"]["input"]["transcription"]["language"] = language + + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "Accept": "text/event-stream", + } + + async with httpx.AsyncClient() as client: + for attempt in range(_MAX_RETRIES + 1): + try: + async with client.stream( + "POST", url, headers=headers, json=body, timeout=60.0 + ) as resp: + if resp.status_code in _RETRYABLE_STATUS and attempt < _MAX_RETRIES: + logger.warning( + "{} transcription transient HTTP {} (attempt {}/{})", + provider_label, + resp.status_code, + attempt + 1, + _MAX_RETRIES + 1, + ) + await asyncio.sleep(_BACKOFF_S[attempt]) + continue + resp.raise_for_status() + final_text = None + async for line in resp.aiter_lines(): + if not line.startswith("data:"): + continue + payload_str = line[len("data:") :].strip() + if not payload_str: + continue + try: + payload = json.loads(payload_str) + except (json.JSONDecodeError, ValueError): + continue + event_type = payload.get("type", "") + if event_type == "error": + msg = payload.get("message", "unknown error") + logger.error("{} ASR error: {}", provider_label, msg) + return "" + if event_type == "transcript.text.done": + final_text = payload.get("text", "") + break + if final_text is not None: + return final_text + # Stream ended without a final event — retry if attempts remain + if attempt < _MAX_RETRIES: + logger.warning( + "{} transcription: no final event (attempt {}/{})", + provider_label, + attempt + 1, + _MAX_RETRIES + 1, + ) + await asyncio.sleep(_BACKOFF_S[attempt]) + continue + logger.error( + "{} transcription: stream ended without final text after {} attempts", + provider_label, + _MAX_RETRIES + 1, + ) + return "" + except httpx.HTTPStatusError: + if attempt < _MAX_RETRIES: + await asyncio.sleep(_BACKOFF_S[attempt]) + continue + logger.exception( + "{} transcription failed after {} attempts", + provider_label, + _MAX_RETRIES + 1, + ) + return "" + except (httpx.RequestError, Exception): + if attempt < _MAX_RETRIES: + await asyncio.sleep(_BACKOFF_S[attempt]) + continue + logger.exception("{} transcription request error", provider_label) + return "" + return "" + + async def _post_with_retry( build_request: Callable[[], dict[str, Any]], provider_label: str, @@ -663,3 +777,44 @@ class XiaomiMiMoTranscriptionProvider: provider_label="Xiaomi MiMo", language=self.language, ) + + +class StepFunTranscriptionProvider: + """Voice transcription provider using StepFun ASR SSE endpoint.""" + + _DEFAULT_URL = "https://api.stepfun.com/v1/audio/asr/sse" + + def __init__( + self, + api_key: str | None = None, + api_base: str | None = None, + language: str | None = None, + model: str | None = None, + ): + self.api_key = api_key or os.environ.get("STEPFUN_API_KEY") + # api_base is used verbatim; users can point to the Plan endpoint + # (https://api.stepfun.com/step_plan/v1/audio/asr/sse) or any + # compatible proxy. + self.api_url = api_base or self._DEFAULT_URL + self.language = language or None + self.model = model or "stepaudio-2.5-asr" + logger.debug("StepFun transcription endpoint: {}", self.api_url) + + async def transcribe(self, file_path: str | Path) -> str: + if not self.api_key: + logger.warning("StepFun API key not configured for transcription") + return "" + + path = Path(file_path) + if not path.exists(): + logger.error("Audio file not found: {}", file_path) + return "" + + return await _post_stepfun_asr_with_retry( + self.api_url, + api_key=self.api_key, + path=path, + model=self.model, + provider_label="StepFun", + language=self.language, + ) diff --git a/tests/providers/test_stepfun_asr.py b/tests/providers/test_stepfun_asr.py new file mode 100644 index 000000000..3056fad01 --- /dev/null +++ b/tests/providers/test_stepfun_asr.py @@ -0,0 +1,418 @@ +"""Tests for StepFun ASR SSE transcription provider.""" + +from __future__ import annotations + +import json +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from nanobot.audio.transcription_registry import ( + get_transcription_provider, + transcription_provider_names, +) +from nanobot.config.schema import Config +from nanobot.providers.transcription import StepFunTranscriptionProvider + + +@pytest.fixture +def audio_file(tmp_path: Path) -> Path: + p = tmp_path / "voice.ogg" + p.write_bytes(b"OggS\x00fake-audio-bytes") + return p + + +# --------------------------------------------------------------------------- +# Defaults and base normalization +# --------------------------------------------------------------------------- + + +def test_stepfun_defaults() -> None: + provider = StepFunTranscriptionProvider(api_key="sk-test") + assert provider.api_url == "https://api.stepfun.com/v1/audio/asr/sse" + assert provider.model == "stepaudio-2.5-asr" + + +def test_stepfun_api_base_overrides_url() -> None: + provider = StepFunTranscriptionProvider( + api_key="sk-test", + api_base="https://api.stepfun.com/step_plan/v1/audio/asr/sse", + ) + assert provider.api_url == "https://api.stepfun.com/step_plan/v1/audio/asr/sse" + + +def test_stepfun_custom_model() -> None: + provider = StepFunTranscriptionProvider(api_key="sk-test", model="stepaudio-2-asr-pro") + assert provider.model == "stepaudio-2-asr-pro" + + +# --------------------------------------------------------------------------- +# Short-circuit: missing key / missing file +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_missing_api_key_short_circuits(audio_file: Path) -> None: + with patch.dict("os.environ", {}, clear=True): + provider = StepFunTranscriptionProvider(api_key=None) + stream_mock = MagicMock() + with patch("httpx.AsyncClient.stream", stream_mock): + assert await provider.transcribe(audio_file) == "" + stream_mock.assert_not_called() + + +@pytest.mark.asyncio +async def test_missing_file_short_circuits(audio_file: Path) -> None: + provider = StepFunTranscriptionProvider(api_key="sk-test") + stream_mock = MagicMock() + with patch("httpx.AsyncClient.stream", stream_mock): + assert await provider.transcribe("/nonexistent/path/voice.ogg") == "" + stream_mock.assert_not_called() + + +# --------------------------------------------------------------------------- +# SSE stream parsing: happy path +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_sse_delta_then_done(audio_file: Path) -> None: + """Simulates the real SSE event sequence: delta(s) -> text.done.""" + events = [ + {"type": "transcript.text.delta", "session_id": "s1", "text": "你"}, + {"type": "transcript.text.delta", "session_id": "s1", "text": "你好"}, + {"type": "transcript.text.done", "session_id": "s1", "text": "你好世界"}, + ] + lines = [f"data: {json.dumps(e)}" for e in events] + + provider = StepFunTranscriptionProvider(api_key="sk-test") + stream_cm = _make_stream_cm(200, lines) + + with patch("httpx.AsyncClient.stream", stream_cm): + result = await provider.transcribe(audio_file) + + assert result == "你好世界" + + +@pytest.mark.asyncio +async def test_sse_only_done_event(audio_file: Path) -> None: + """Single transcript.text.done event without deltas.""" + events = [ + {"type": "transcript.text.done", "session_id": "s1", "text": "hello world"}, + ] + lines = [f"data: {json.dumps(e)}" for e in events] + + provider = StepFunTranscriptionProvider(api_key="sk-test") + stream_cm = _make_stream_cm(200, lines) + + with patch("httpx.AsyncClient.stream", stream_cm): + result = await provider.transcribe(audio_file) + + assert result == "hello world" + + +@pytest.mark.asyncio +async def test_sse_error_event(audio_file: Path) -> None: + """Error event in SSE stream returns "" immediately.""" + events = [ + {"type": "error", "session_id": "s1", "message": "audio too short"}, + ] + lines = [f"data: {json.dumps(e)}" for e in events] + + provider = StepFunTranscriptionProvider(api_key="sk-test") + stream_cm = _make_stream_cm(200, lines) + + with patch("httpx.AsyncClient.stream", stream_cm): + result = await provider.transcribe(audio_file) + + assert result == "" + + +@pytest.mark.asyncio +async def test_sse_ignores_non_data_lines(audio_file: Path) -> None: + """Empty lines and lines without 'data:' prefix are ignored.""" + events = [ + {"type": "transcript.text.done", "session_id": "s1", "text": "result"}, + ] + raw_lines = [ + "", # empty line + "event: session.start", # non-data event + f"data: {json.dumps(events[0])}", + ] + + provider = StepFunTranscriptionProvider(api_key="sk-test") + stream_cm = _make_stream_cm(200, raw_lines) + + with patch("httpx.AsyncClient.stream", stream_cm): + result = await provider.transcribe(audio_file) + + assert result == "result" + + +@pytest.mark.asyncio +async def test_sse_malformed_json_skipped(audio_file: Path) -> None: + """Malformed JSON in data lines are skipped gracefully.""" + events = [ + {"type": "transcript.text.done", "session_id": "s1", "text": "ok"}, + ] + raw_lines = [ + "data: not-json-at-all", + f"data: {json.dumps(events[0])}", + ] + + provider = StepFunTranscriptionProvider(api_key="sk-test") + stream_cm = _make_stream_cm(200, raw_lines) + + with patch("httpx.AsyncClient.stream", stream_cm): + result = await provider.transcribe(audio_file) + + assert result == "ok" + + +# --------------------------------------------------------------------------- +# Retry contract +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_retries_on_503_then_succeeds(audio_file: Path) -> None: + """Transient 503 is retried, then a successful SSE stream yields text.""" + success_lines = [ + f"data: {json.dumps({'type': 'transcript.text.done', 'session_id': 's1', 'text': 'ok'})}", + ] + # First call: 503 (FailingResponse), second call: success (FakeResponse with lines) + stream_cm = _make_stream_cm_sequence([503, success_lines]) + + provider = StepFunTranscriptionProvider(api_key="sk-test") + with patch("httpx.AsyncClient.stream", stream_cm), patch( + "asyncio.sleep", AsyncMock() + ): + result = await provider.transcribe(audio_file) + + assert result == "ok" + + +@pytest.mark.asyncio +async def test_gives_up_after_max_retries(audio_file: Path) -> None: + """Persistent 503 returns "" after all retries exhausted.""" + attempts: list[list[str] | int] = [503, 503, 503, 503] # 4 failing HTTP responses + stream_cm = _make_stream_cm_sequence(attempts) + + provider = StepFunTranscriptionProvider(api_key="sk-test") + with patch("httpx.AsyncClient.stream", stream_cm), patch( + "asyncio.sleep", AsyncMock() + ): + result = await provider.transcribe(audio_file) + + assert result == "" + + +@pytest.mark.asyncio +async def test_sse_empty_text_done_returns_empty(audio_file: Path) -> None: + """Empty text in transcript.text.done should return "" immediately, not retry.""" + events = [ + {"type": "transcript.text.done", "session_id": "s1", "text": ""}, + ] + lines = [f"data: {json.dumps(e)}" for e in events] + + provider = StepFunTranscriptionProvider(api_key="sk-test") + stream_cm = _make_stream_cm(200, lines) + + with patch("httpx.AsyncClient.stream", stream_cm), patch( + "asyncio.sleep", AsyncMock() + ): + result = await provider.transcribe(audio_file) + + assert result == "" + + +@pytest.mark.asyncio +async def test_401_returns_empty_after_retries(audio_file: Path) -> None: + """401 is not in the retryable set but HTTPStatusError still triggers + the retry loop; all attempts exhaust and return "".""" + stream_cm = _make_stream_cm(401, []) + + provider = StepFunTranscriptionProvider(api_key="sk-test") + with patch("httpx.AsyncClient.stream", stream_cm), patch( + "asyncio.sleep", AsyncMock() + ): + result = await provider.transcribe(audio_file) + + assert result == "" + + +@pytest.mark.asyncio +async def test_retries_on_connect_error(audio_file: Path) -> None: + """Network-level transient errors are retried.""" + success_lines = [ + f"data: {json.dumps({'type': 'transcript.text.done', 'session_id': 's1', 'text': 'ok'})}", + ] + call_count = [0] + + class FakeResponse: + """Serves as both the async context manager returned by stream() + and the response object bound in `async with ... as resp`.""" + status_code = 200 + reason_phrase = "OK" + + async def __aenter__(self) -> "FakeResponse": + return self + + async def __aexit__(self, *exc: object) -> None: + pass + + async def aiter_lines(self) -> Any: + for line in success_lines: + yield line + + def raise_for_status(self) -> None: + pass + + def fake_stream(method: str, url: str, *args: object, **kwargs: object) -> FakeResponse: + call_count[0] += 1 + if call_count[0] == 1: + raise httpx.ConnectError("boom") + return FakeResponse() + + provider = StepFunTranscriptionProvider(api_key="sk-test") + with patch("httpx.AsyncClient.stream", fake_stream), patch( + "asyncio.sleep", AsyncMock() + ): + result = await provider.transcribe(audio_file) + + assert result == "ok" + assert call_count[0] == 2 + + +# --------------------------------------------------------------------------- +# Registry integration +# --------------------------------------------------------------------------- + + +def test_stepfun_in_registry() -> None: + assert "stepfun" in transcription_provider_names() + spec = get_transcription_provider("stepfun") + assert spec is not None + assert spec.default_model == "stepaudio-2.5-asr" + assert spec.adapter == "nanobot.providers.transcription:StepFunTranscriptionProvider" + + +def test_config_resolves_stepfun() -> None: + config = Config() + config.transcription.provider = "stepfun" + config.transcription.model = "stepaudio-2.5-asr" + config.transcription.language = "zh" + config.providers.stepfun.api_key = "step-test" + config.providers.stepfun.api_base = "https://api.stepfun.com/step_plan/v1/audio/asr/sse" + + from nanobot.audio.transcription import resolve_transcription_config + + resolved = resolve_transcription_config(config) + + assert resolved.provider == "stepfun" + assert resolved.model == "stepaudio-2.5-asr" + assert resolved.language == "zh" + assert resolved.api_key == "step-test" + assert resolved.api_base == "https://api.stepfun.com/step_plan/v1/audio/asr/sse" + assert resolved.configured is True + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_stream_cm(status: int, lines: list[str]) -> MagicMock: + """Build a mock for `AsyncClient.stream` that yields *lines* as SSE.""" + + class FakeResponse: + def __init__(self) -> None: + self.status_code = status + self.reason_phrase = "OK" if status == 200 else "Error" + + async def __aenter__(self) -> "FakeResponse": + return self + + async def __aexit__(self, *exc: object) -> None: + pass + + async def aiter_lines(self) -> Any: + for line in lines: + yield line + + def raise_for_status(self) -> None: + if self.status_code >= 400: + raise httpx.HTTPStatusError( + f"HTTP {self.status_code}", + request=httpx.Request("POST", "https://example.test"), + response=httpx.Response(self.status_code), + ) + + cm = MagicMock() + cm.return_value = FakeResponse() + return cm + + +def _make_stream_cm_sequence(statuses: list[str | int]) -> MagicMock: + """Build a stream mock that fails with HTTP status ints, then succeeds with SSE lines. + + Entries in *statuses* that are ints produce a stream that raises HTTPStatusError + after `raise_for_status()`. The final entry (a list of SSE lines) succeeds. + """ + remaining = list(statuses) + + class FakeResponse: + def __init__(self, lines: list[str]) -> None: + self._lines = lines + self.status_code = 200 + self.reason_phrase = "OK" + + async def __aenter__(self) -> "FakeResponse": + return self + + async def __aexit__(self, *exc: object) -> None: + pass + + async def aiter_lines(self) -> Any: + for line in self._lines: + yield line + + def raise_for_status(self) -> None: + pass + + class FailingResponse: + def __init__(self, status: int) -> None: + self.status_code = status + self.reason_phrase = "Error" + + async def __aenter__(self) -> "FailingResponse": + return self + + async def __aexit__(self, *exc: object) -> None: + pass + + async def aiter_lines(self) -> Any: + yield "" + return + + def raise_for_status(self) -> None: + raise httpx.HTTPStatusError( + f"HTTP {self.status_code}", + request=httpx.Request("POST", "https://example.test"), + response=httpx.Response(self.status_code), + ) + + call_count = [0] + + def _next(method: str, url: str, **kwargs: object) -> Any: + idx = min(call_count[0], len(remaining) - 1) + entry = remaining[idx] + call_count[0] += 1 + if isinstance(entry, int): + return FailingResponse(entry) + return FakeResponse(entry) + + cm = MagicMock(side_effect=_next) + return cm