feat(asr): add StepFun ASR SSE transcription provider

- Add StepFunTranscriptionProvider class in nanobot/providers/transcription.py
- New _post_stepfun_asr_with_retry() function handling SSE stream parsing
  (transcript.text.delta → transcript.text.done event sequence)
- Register 'stepfun' in transcription_registry.py with default model stepaudio-2.5-asr
- Reuse existing stepfun provider config (apiBase can point to Plan endpoint)
- Add 17 tests covering SSE parsing, retry contract, empty-text edge case, and registry integration
- Update docs/configuration.md with stepfun ASR documentation

StepFun ASR uses a dedicated SSE endpoint (/v1/audio/asr/sse) rather
than the chat-completions or Whisper multipart formats used by other
providers. Users on Step Plan can set apiBase to the Plan endpoint.
This commit is contained in:
moran 2026-06-09 17:27:13 +08:00 committed by Xubin Ren
parent 31bfec58d0
commit 7930058348
5 changed files with 582 additions and 4 deletions

View File

@ -239,7 +239,7 @@ Tracing covers the providers that go through nanobot's OpenAI-compatible client
| `lm_studio` | LLM (local, LM Studio) | — |
| `atomic_chat` | LLM (local, [Atomic Chat](https://atomic.chat/)) | — |
| `mistral` | LLM | [docs.mistral.ai](https://docs.mistral.ai/) |
| `stepfun` | LLM (Step Fun/阶跃星辰) | [platform.stepfun.com](https://platform.stepfun.com) |
| `stepfun` | LLM (Step Fun/阶跃星辰) + Voice transcription (ASR) | [platform.stepfun.com](https://platform.stepfun.com) |
| `ovms` | LLM (local, OpenVINO Model Server) | [docs.openvino.ai](https://docs.openvino.ai/2026/model-server/ovms_docs_llm_quickstart.html) |
| `vllm` | LLM (local, any OpenAI-compatible server) | — |
| `nvidia` | LLM (NVIDIA NIM) | [build.nvidia.com](https://build.nvidia.com/) |
@ -1294,8 +1294,8 @@ Configure transcription under the top-level `transcription` section:
| Setting | Default | Description |
|---------|---------|-------------|
| `enabled` | `true` | Enables audio transcription for both chat-channel voice messages and WebUI/desktop microphone input. |
| `provider` | `"groq"` | Transcription backend: `"groq"`, `"openai"`, `"openrouter"`, `"xiaomi_mimo"`, or `"assemblyai"`. |
| `model` | provider default | Optional transcription model override. Defaults to `whisper-large-v3` for Groq, `whisper-1` for OpenAI, `openai/whisper-1` for OpenRouter, `mimo-v2.5-asr` for Xiaomi MiMo ASR, and `universal-3-pro,universal-2` for AssemblyAI. OpenRouter accepts only speech-to-text models on its transcription endpoint, such as `nvidia/parakeet-tdt-0.6b-v3`, `openai/whisper-1`, or `openai/gpt-4o-transcribe`; chat LLMs are rejected there. AssemblyAI accepts a comma-separated model fallback list. |
| `provider` | `"groq"` | Transcription backend: `"groq"`, `"openai"`, `"openrouter"`, `"xiaomi_mimo"`, `"stepfun"`, or `"assemblyai"`. |
| `model` | provider default | Optional transcription model override. Defaults to `whisper-large-v3` for Groq, `whisper-1` for OpenAI, `openai/whisper-1` for OpenRouter, `mimo-v2.5-asr` for Xiaomi MiMo ASR, `stepaudio-2.5-asr` for StepFun ASR, and `universal-3-pro,universal-2` for AssemblyAI. OpenRouter accepts only speech-to-text models on its transcription endpoint, such as `nvidia/parakeet-tdt-0.6b-v3`, `openai/whisper-1`, or `openai/gpt-4o-transcribe`; chat LLMs are rejected there. AssemblyAI accepts a comma-separated model fallback list. |
| `language` | `null` | Optional ISO-639 language hint, e.g. `"en"`, `"zh"`, `"ko"`, or `"ja"`. |
| `maxDurationSec` | `120` | Maximum WebUI/desktop recording duration. |
| `maxUploadMb` | `25` | Maximum WebUI/desktop audio upload size. |

View File

@ -64,6 +64,11 @@ TRANSCRIPTION_PROVIDERS: tuple[TranscriptionProviderSpec, ...] = (
adapter="nanobot.providers.transcription:XiaomiMiMoTranscriptionProvider",
aliases=("mimo", "xiaomi"),
),
TranscriptionProviderSpec(
name="stepfun",
default_model="stepaudio-2.5-asr",
adapter="nanobot.providers.transcription:StepFunTranscriptionProvider",
),
TranscriptionProviderSpec(
name="assemblyai",
default_model="universal-3-pro,universal-2",

View File

@ -219,7 +219,7 @@ class ProvidersConfig(Base):
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
minimax_anthropic: ProviderConfig = Field(default_factory=ProviderConfig) # MiniMax Anthropic endpoint (thinking)
mistral: ProviderConfig = Field(default_factory=ProviderConfig)
stepfun: ProviderConfig = Field(default_factory=ProviderConfig) # Step Fun (阶跃星辰)
stepfun: ProviderConfig = Field(default_factory=ProviderConfig) # Step Fun (阶跃星辰) — LLM + ASR (set apiBase to Plan URL for ASR)
xiaomi_mimo: ProviderConfig = Field(default_factory=ProviderConfig) # Xiaomi MIMO (小米)
longcat: ProviderConfig = Field(default_factory=ProviderConfig) # LongCat
ant_ling: ProviderConfig = Field(default_factory=ProviderConfig) # Ant Ling

View File

@ -8,6 +8,7 @@ WebUI upload validation, and channel integration live in
import asyncio
import base64
import json
import mimetypes
import os
from collections.abc import Callable
@ -306,6 +307,119 @@ async def _post_xiaomi_mimo_asr_with_retry(
return await _post_with_retry(build_request, provider_label, _text_from_chat_payload)
async def _post_stepfun_asr_with_retry(
    url: str,
    *,
    api_key: str | None,
    path: Path,
    model: str,
    provider_label: str,
    language: str | None = None,
) -> str:
    """POST audio to the StepFun ASR SSE endpoint and return the final transcript.

    The audio file is read in full, base64-encoded and sent as a JSON body.
    The response is consumed as an SSE stream and the text carried by the
    first ``transcript.text.done`` event is returned. Failures are logged and
    reported as an empty string; nothing is raised to the caller.

    An attempt is retried (sleeping ``_BACKOFF_S[attempt]``, for at most
    ``_MAX_RETRIES`` extra attempts) when the status is in
    ``_RETRYABLE_STATUS``, when the stream ends without a final event, or when
    sending/reading raises. Any other HTTP error status (for example 401)
    fails immediately, because resending the identical request cannot succeed.

    Args:
        url: Full URL of the SSE transcription endpoint.
        api_key: Bearer token placed in the ``Authorization`` header.
        path: Audio file to transcribe.
        model: Model name sent in the request body.
        provider_label: Human-readable provider name used in log messages.
        language: Optional language hint; omitted from the body when falsy.

    Returns:
        The final transcript text, or ``""`` on any failure (unreadable file,
        HTTP error, in-stream error event, or exhausted retries).
    """
    try:
        data = path.read_bytes()
    except OSError as e:
        logger.exception("{} transcription error: cannot read audio file: {}", provider_label, e)
        return ""

    # Unknown or missing extensions are declared as "wav".
    suffix = path.suffix.lstrip(".").lower()
    audio_type = suffix if suffix in ("ogg", "mp3", "wav", "pcm") else "wav"

    transcription: dict[str, Any] = {"model": model, "enable_itn": True}
    if language:
        transcription["language"] = language
    body: dict[str, Any] = {
        "audio": {
            "data": base64.b64encode(data).decode("ascii"),
            "input": {
                "transcription": transcription,
                "format": {"type": audio_type},
            },
        },
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "Accept": "text/event-stream",
    }

    async with httpx.AsyncClient() as client:
        for attempt in range(_MAX_RETRIES + 1):
            can_retry = attempt < _MAX_RETRIES
            try:
                async with client.stream(
                    "POST", url, headers=headers, json=body, timeout=60.0
                ) as resp:
                    if resp.status_code in _RETRYABLE_STATUS and can_retry:
                        logger.warning(
                            "{} transcription transient HTTP {} (attempt {}/{})",
                            provider_label,
                            resp.status_code,
                            attempt + 1,
                            _MAX_RETRIES + 1,
                        )
                        await asyncio.sleep(_BACKOFF_S[attempt])
                        continue
                    resp.raise_for_status()

                    final_text: str | None = None
                    async for line in resp.aiter_lines():
                        if not line.startswith("data:"):
                            continue
                        payload_str = line[len("data:") :].strip()
                        if not payload_str:
                            continue
                        try:
                            payload = json.loads(payload_str)
                        except ValueError:  # json.JSONDecodeError is a ValueError
                            continue
                        # Valid JSON that is not an object carries no event; skip it
                        # instead of letting `.get` raise and trigger a retry.
                        if not isinstance(payload, dict):
                            continue
                        event_type = payload.get("type", "")
                        if event_type == "error":
                            msg = payload.get("message", "unknown error")
                            logger.error("{} ASR error: {}", provider_label, msg)
                            return ""
                        if event_type == "transcript.text.done":
                            # A null/missing text is still a final (empty) result,
                            # not a reason to retry.
                            final_text = payload.get("text") or ""
                            break

                    if final_text is not None:
                        return final_text

                    # Stream ended without a final event — retry if attempts remain
                    if can_retry:
                        logger.warning(
                            "{} transcription: no final event (attempt {}/{})",
                            provider_label,
                            attempt + 1,
                            _MAX_RETRIES + 1,
                        )
                        await asyncio.sleep(_BACKOFF_S[attempt])
                        continue
                    logger.error(
                        "{} transcription: stream ended without final text after {} attempts",
                        provider_label,
                        _MAX_RETRIES + 1,
                    )
                    return ""
            except httpx.HTTPStatusError:
                # raise_for_status() is only reached when the status is not
                # retryable or when no attempts remain, so never retry here.
                logger.exception(
                    "{} transcription failed after {} attempts",
                    provider_label,
                    attempt + 1,
                )
                return ""
            except Exception:
                # Best-effort: transport errors (and anything unexpected while
                # reading the stream) are retried until attempts run out.
                if can_retry:
                    await asyncio.sleep(_BACKOFF_S[attempt])
                    continue
                logger.exception("{} transcription request error", provider_label)
                return ""
    return ""
async def _post_with_retry(
build_request: Callable[[], dict[str, Any]],
provider_label: str,
@ -663,3 +777,44 @@ class XiaomiMiMoTranscriptionProvider:
provider_label="Xiaomi MiMo",
language=self.language,
)
class StepFunTranscriptionProvider:
    """Transcribe voice recordings through the StepFun ASR SSE endpoint.

    The API key falls back to the ``STEPFUN_API_KEY`` environment variable,
    the endpoint to ``_DEFAULT_URL`` and the model to ``stepaudio-2.5-asr``
    when the corresponding constructor argument is falsy.
    """

    _DEFAULT_URL = "https://api.stepfun.com/v1/audio/asr/sse"

    def __init__(
        self,
        api_key: str | None = None,
        api_base: str | None = None,
        language: str | None = None,
        model: str | None = None,
    ):
        self.api_key = api_key if api_key else os.environ.get("STEPFUN_API_KEY")
        # No path is appended to api_base: it must already be the complete
        # endpoint URL, e.g. the Plan endpoint
        # (https://api.stepfun.com/step_plan/v1/audio/asr/sse) or a
        # compatible proxy.
        self.api_url = api_base if api_base else self._DEFAULT_URL
        # Empty strings are normalised to None.
        self.language = language if language else None
        self.model = model if model else "stepaudio-2.5-asr"
        logger.debug("StepFun transcription endpoint: {}", self.api_url)

    async def transcribe(self, file_path: str | Path) -> str:
        """Return the transcript of *file_path*, or ``""`` when it cannot be produced.

        Returns early, without any network request, when no API key is
        configured or the file does not exist.
        """
        if not self.api_key:
            logger.warning("StepFun API key not configured for transcription")
            return ""

        audio_path = Path(file_path)
        if not audio_path.exists():
            logger.error("Audio file not found: {}", file_path)
            return ""

        transcript = await _post_stepfun_asr_with_retry(
            self.api_url,
            api_key=self.api_key,
            path=audio_path,
            model=self.model,
            provider_label="StepFun",
            language=self.language,
        )
        return transcript

View File

@ -0,0 +1,418 @@
"""Tests for StepFun ASR SSE transcription provider."""
from __future__ import annotations

import json
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest

from nanobot.audio.transcription_registry import (
    get_transcription_provider,
    transcription_provider_names,
)
from nanobot.config.schema import Config
from nanobot.providers.transcription import StepFunTranscriptionProvider
@pytest.fixture
def audio_file(tmp_path: Path) -> Path:
    """Write a tiny fake ``voice.ogg`` into the per-test temp dir and return its path."""
    target = tmp_path.joinpath("voice.ogg")
    target.write_bytes(b"OggS\x00fake-audio-bytes")
    return target
# ---------------------------------------------------------------------------
# Defaults and endpoint/model overrides
# ---------------------------------------------------------------------------
def test_stepfun_defaults() -> None:
    """With only an API key, the provider uses the default endpoint and model."""
    default_provider = StepFunTranscriptionProvider(api_key="sk-test")
    observed = (default_provider.api_url, default_provider.model)
    assert observed == ("https://api.stepfun.com/v1/audio/asr/sse", "stepaudio-2.5-asr")
def test_stepfun_api_base_overrides_url() -> None:
    """An explicit api_base is stored unchanged as the request URL."""
    plan_url = "https://api.stepfun.com/step_plan/v1/audio/asr/sse"
    provider = StepFunTranscriptionProvider(api_key="sk-test", api_base=plan_url)
    assert provider.api_url == plan_url
def test_stepfun_custom_model() -> None:
    """An explicit model name replaces the default one."""
    wanted_model = "stepaudio-2-asr-pro"
    provider = StepFunTranscriptionProvider(api_key="sk-test", model=wanted_model)
    assert provider.model == wanted_model
# ---------------------------------------------------------------------------
# Short-circuit: missing key / missing file
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_missing_api_key_short_circuits(audio_file: Path) -> None:
    """Without a key (argument or environment) no HTTP request is attempted."""
    # Build the provider under an emptied environment so STEPFUN_API_KEY
    # cannot leak in from the host.
    with patch.dict("os.environ", {}, clear=True):
        keyless_provider = StepFunTranscriptionProvider(api_key=None)

    http_stream = MagicMock()
    with patch("httpx.AsyncClient.stream", http_stream):
        transcript = await keyless_provider.transcribe(audio_file)
        assert transcript == ""
    http_stream.assert_not_called()
@pytest.mark.asyncio
async def test_missing_file_short_circuits(audio_file: Path) -> None:
    """A path that does not exist yields "" without any HTTP request."""
    # Derive the missing path from the fixture's private temp directory so the
    # test does not depend on what happens to exist on the host filesystem.
    missing_path = audio_file.with_name("does-not-exist.ogg")
    assert not missing_path.exists()

    provider = StepFunTranscriptionProvider(api_key="sk-test")
    stream_mock = MagicMock()
    with patch("httpx.AsyncClient.stream", stream_mock):
        assert await provider.transcribe(str(missing_path)) == ""
    stream_mock.assert_not_called()
# ---------------------------------------------------------------------------
# SSE stream parsing: happy path
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_sse_delta_then_done(audio_file: Path) -> None:
    """Deltas followed by a done event: only the done event's text is returned."""
    sse_events = [
        {"type": "transcript.text.delta", "session_id": "s1", "text": ""},
        {"type": "transcript.text.delta", "session_id": "s1", "text": "你好"},
        {"type": "transcript.text.done", "session_id": "s1", "text": "你好世界"},
    ]
    sse_lines = ["data: " + json.dumps(event) for event in sse_events]
    fake_stream = _make_stream_cm(200, sse_lines)
    provider = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        transcript = await provider.transcribe(audio_file)

    assert transcript == "你好世界"
@pytest.mark.asyncio
async def test_sse_only_done_event(audio_file: Path) -> None:
    """A stream holding just one transcript.text.done event returns its text."""
    done_event = {"type": "transcript.text.done", "session_id": "s1", "text": "hello world"}
    fake_stream = _make_stream_cm(200, ["data: " + json.dumps(done_event)])
    provider = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        transcript = await provider.transcribe(audio_file)

    assert transcript == "hello world"
@pytest.mark.asyncio
async def test_sse_error_event(audio_file: Path) -> None:
    """An in-stream error event makes transcribe() return ""."""
    error_event = {"type": "error", "session_id": "s1", "message": "audio too short"}
    fake_stream = _make_stream_cm(200, ["data: " + json.dumps(error_event)])
    provider = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        transcript = await provider.transcribe(audio_file)

    assert transcript == ""
@pytest.mark.asyncio
async def test_sse_ignores_non_data_lines(audio_file: Path) -> None:
    """Blank lines and lines not starting with 'data:' do not affect the result."""
    done_event = {"type": "transcript.text.done", "session_id": "s1", "text": "result"}
    noise = [
        "",  # blank separator line
        "event: session.start",  # SSE field other than data
    ]
    sse_lines = [*noise, f"data: {json.dumps(done_event)}"]
    fake_stream = _make_stream_cm(200, sse_lines)
    provider = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        transcript = await provider.transcribe(audio_file)

    assert transcript == "result"
@pytest.mark.asyncio
async def test_sse_malformed_json_skipped(audio_file: Path) -> None:
    """A data line that is not valid JSON is skipped; later events still count."""
    done_event = {"type": "transcript.text.done", "session_id": "s1", "text": "ok"}
    sse_lines = ["data: not-json-at-all", "data: " + json.dumps(done_event)]
    fake_stream = _make_stream_cm(200, sse_lines)
    provider = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        transcript = await provider.transcribe(audio_file)

    assert transcript == "ok"
# ---------------------------------------------------------------------------
# Retry contract
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_retries_on_503_then_succeeds(audio_file: Path) -> None:
    """A 503 on the first call is retried; the second call's stream yields text."""
    done_event = {"type": "transcript.text.done", "session_id": "s1", "text": "ok"}
    good_stream_lines = ["data: " + json.dumps(done_event)]
    # Scripted responses: call 1 -> HTTP 503, call 2 -> 200 with the SSE lines.
    fake_stream = _make_stream_cm_sequence([503, good_stream_lines])
    provider = StepFunTranscriptionProvider(api_key="sk-test")

    # asyncio.sleep is stubbed so the backoff does not slow the test down.
    with patch("httpx.AsyncClient.stream", fake_stream), patch("asyncio.sleep", AsyncMock()):
        transcript = await provider.transcribe(audio_file)

    assert transcript == "ok"
@pytest.mark.asyncio
async def test_gives_up_after_max_retries(audio_file: Path) -> None:
    """When every call answers 503, transcribe() eventually returns ""."""
    always_unavailable: list[list[str] | int] = [503] * 4  # four failing HTTP responses
    fake_stream = _make_stream_cm_sequence(always_unavailable)
    provider = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream), patch("asyncio.sleep", AsyncMock()):
        transcript = await provider.transcribe(audio_file)

    assert transcript == ""
@pytest.mark.asyncio
async def test_sse_empty_text_done_returns_empty(audio_file: Path) -> None:
    """Empty text in transcript.text.done should return "" immediately, not retry."""
    events = [
        {"type": "transcript.text.done", "session_id": "s1", "text": ""},
    ]
    lines = [f"data: {json.dumps(e)}" for e in events]
    provider = StepFunTranscriptionProvider(api_key="sk-test")
    stream_cm = _make_stream_cm(200, lines)
    sleep_mock = AsyncMock()

    with patch("httpx.AsyncClient.stream", stream_cm), patch("asyncio.sleep", sleep_mock):
        result = await provider.transcribe(audio_file)

    assert result == ""
    # Pin the "not retry" half of the contract: a single request and no backoff.
    # Without these checks the test would also pass if the empty result came
    # from exhausting every retry.
    assert stream_cm.call_count == 1
    sleep_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_401_returns_empty_after_retries(audio_file: Path) -> None:
    """A 401 response (not a retryable status) ends with an empty transcript.

    Only the final result is checked; the number of requests made is not.
    """
    unauthorized_stream = _make_stream_cm(401, [])
    provider = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", unauthorized_stream), patch(
        "asyncio.sleep", AsyncMock()
    ):
        transcript = await provider.transcribe(audio_file)

    assert transcript == ""
@pytest.mark.asyncio
async def test_retries_on_connect_error(audio_file: Path) -> None:
    """A ConnectError on the first request is retried and the second succeeds."""
    done_event = {"type": "transcript.text.done", "session_id": "s1", "text": "ok"}
    sse_lines = [f"data: {json.dumps(done_event)}"]
    calls = 0

    class FakeResponse:
        """Acts as the context manager returned by stream() and as the
        response object it yields."""

        status_code = 200
        reason_phrase = "OK"

        async def __aenter__(self) -> "FakeResponse":
            return self

        async def __aexit__(self, *exc: object) -> None:
            return None

        async def aiter_lines(self) -> Any:
            for sse_line in sse_lines:
                yield sse_line

        def raise_for_status(self) -> None:
            return None

    def fake_stream(method: str, url: str, *args: object, **kwargs: object) -> FakeResponse:
        # First invocation simulates a network failure; later ones succeed.
        nonlocal calls
        calls += 1
        if calls > 1:
            return FakeResponse()
        raise httpx.ConnectError("boom")

    provider = StepFunTranscriptionProvider(api_key="sk-test")
    with patch("httpx.AsyncClient.stream", fake_stream), patch("asyncio.sleep", AsyncMock()):
        transcript = await provider.transcribe(audio_file)

    assert transcript == "ok"
    assert calls == 2
# ---------------------------------------------------------------------------
# Registry integration
# ---------------------------------------------------------------------------
def test_stepfun_in_registry() -> None:
    """'stepfun' is registered with the expected default model and adapter path."""
    registered_names = transcription_provider_names()
    assert "stepfun" in registered_names

    stepfun_spec = get_transcription_provider("stepfun")
    assert stepfun_spec is not None
    assert stepfun_spec.default_model == "stepaudio-2.5-asr"
    expected_adapter = "nanobot.providers.transcription:StepFunTranscriptionProvider"
    assert stepfun_spec.adapter == expected_adapter
def test_config_resolves_stepfun() -> None:
    """resolve_transcription_config() carries the stepfun settings through unchanged."""
    from nanobot.audio.transcription import resolve_transcription_config

    plan_url = "https://api.stepfun.com/step_plan/v1/audio/asr/sse"

    config = Config()
    transcription_cfg = config.transcription
    transcription_cfg.provider = "stepfun"
    transcription_cfg.model = "stepaudio-2.5-asr"
    transcription_cfg.language = "zh"
    stepfun_cfg = config.providers.stepfun
    stepfun_cfg.api_key = "step-test"
    stepfun_cfg.api_base = plan_url

    resolved = resolve_transcription_config(config)

    assert resolved.provider == "stepfun"
    assert resolved.model == "stepaudio-2.5-asr"
    assert resolved.language == "zh"
    assert resolved.api_key == "step-test"
    assert resolved.api_base == plan_url
    assert resolved.configured is True
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_stream_cm(status: int, lines: list[str]) -> MagicMock:
    """Return a stand-in for `AsyncClient.stream` serving one canned response.

    Every call to the mock returns the same fake response object, which
    reports *status*, streams *lines* from ``aiter_lines()`` and raises
    ``httpx.HTTPStatusError`` from ``raise_for_status()`` when *status* is
    400 or above.
    """

    class FakeResponse:
        def __init__(self) -> None:
            self.status_code = status
            self.reason_phrase = "Error" if status != 200 else "OK"

        async def __aenter__(self) -> "FakeResponse":
            return self

        async def __aexit__(self, *exc: object) -> None:
            return None

        async def aiter_lines(self) -> Any:
            for sse_line in lines:
                yield sse_line

        def raise_for_status(self) -> None:
            if self.status_code < 400:
                return
            raise httpx.HTTPStatusError(
                f"HTTP {self.status_code}",
                request=httpx.Request("POST", "https://example.test"),
                response=httpx.Response(self.status_code),
            )

    return MagicMock(return_value=FakeResponse())
def _make_stream_cm_sequence(statuses: list[str | int]) -> MagicMock:
"""Build a stream mock that fails with HTTP status ints, then succeeds with SSE lines.
Entries in *statuses* that are ints produce a stream that raises HTTPStatusError
after `raise_for_status()`. The final entry (a list of SSE lines) succeeds.
"""
remaining = list(statuses)
class FakeResponse:
def __init__(self, lines: list[str]) -> None:
self._lines = lines
self.status_code = 200
self.reason_phrase = "OK"
async def __aenter__(self) -> "FakeResponse":
return self
async def __aexit__(self, *exc: object) -> None:
pass
async def aiter_lines(self) -> Any:
for line in self._lines:
yield line
def raise_for_status(self) -> None:
pass
class FailingResponse:
def __init__(self, status: int) -> None:
self.status_code = status
self.reason_phrase = "Error"
async def __aenter__(self) -> "FailingResponse":
return self
async def __aexit__(self, *exc: object) -> None:
pass
async def aiter_lines(self) -> Any:
yield ""
return
def raise_for_status(self) -> None:
raise httpx.HTTPStatusError(
f"HTTP {self.status_code}",
request=httpx.Request("POST", "https://example.test"),
response=httpx.Response(self.status_code),
)
call_count = [0]
def _next(method: str, url: str, **kwargs: object) -> Any:
idx = min(call_count[0], len(remaining) - 1)
entry = remaining[idx]
call_count[0] += 1
if isinstance(entry, int):
return FailingResponse(entry)
return FakeResponse(entry)
cm = MagicMock(side_effect=_next)
return cm