mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 15:24:06 +00:00
feat(asr): add StepFun ASR SSE transcription provider
- Add StepFunTranscriptionProvider class in nanobot/providers/transcription.py
- New _post_stepfun_asr_with_retry() function handling SSE stream parsing (transcript.text.delta → transcript.text.done event sequence)
- Register 'stepfun' in transcription_registry.py with default model stepaudio-2.5-asr
- Reuse existing stepfun provider config (apiBase can point to Plan endpoint)
- Add 17 tests covering SSE parsing, retry contract, empty-text edge case, and registry integration
- Update docs/configuration.md with stepfun ASR documentation

StepFun ASR uses a dedicated SSE endpoint (/v1/audio/asr/sse) rather than the chat-completions or Whisper multipart formats used by other providers. Users on Step Plan can set apiBase to the Plan endpoint.
This commit is contained in:
parent
31bfec58d0
commit
7930058348
@ -239,7 +239,7 @@ Tracing covers the providers that go through nanobot's OpenAI-compatible client
|
|||||||
| `lm_studio` | LLM (local, LM Studio) | — |
|
| `lm_studio` | LLM (local, LM Studio) | — |
|
||||||
| `atomic_chat` | LLM (local, [Atomic Chat](https://atomic.chat/)) | — |
|
| `atomic_chat` | LLM (local, [Atomic Chat](https://atomic.chat/)) | — |
|
||||||
| `mistral` | LLM | [docs.mistral.ai](https://docs.mistral.ai/) |
|
| `mistral` | LLM | [docs.mistral.ai](https://docs.mistral.ai/) |
|
||||||
| `stepfun` | LLM (Step Fun/阶跃星辰) | [platform.stepfun.com](https://platform.stepfun.com) |
|
| `stepfun` | LLM (Step Fun/阶跃星辰) + Voice transcription (ASR) | [platform.stepfun.com](https://platform.stepfun.com) |
|
||||||
| `ovms` | LLM (local, OpenVINO Model Server) | [docs.openvino.ai](https://docs.openvino.ai/2026/model-server/ovms_docs_llm_quickstart.html) |
|
| `ovms` | LLM (local, OpenVINO Model Server) | [docs.openvino.ai](https://docs.openvino.ai/2026/model-server/ovms_docs_llm_quickstart.html) |
|
||||||
| `vllm` | LLM (local, any OpenAI-compatible server) | — |
|
| `vllm` | LLM (local, any OpenAI-compatible server) | — |
|
||||||
| `nvidia` | LLM (NVIDIA NIM) | [build.nvidia.com](https://build.nvidia.com/) |
|
| `nvidia` | LLM (NVIDIA NIM) | [build.nvidia.com](https://build.nvidia.com/) |
|
||||||
@ -1294,8 +1294,8 @@ Configure transcription under the top-level `transcription` section:
|
|||||||
| Setting | Default | Description |
|
| Setting | Default | Description |
|
||||||
|---------|---------|-------------|
|
|---------|---------|-------------|
|
||||||
| `enabled` | `true` | Enables audio transcription for both chat-channel voice messages and WebUI/desktop microphone input. |
|
| `enabled` | `true` | Enables audio transcription for both chat-channel voice messages and WebUI/desktop microphone input. |
|
||||||
| `provider` | `"groq"` | Transcription backend: `"groq"`, `"openai"`, `"openrouter"`, `"xiaomi_mimo"`, or `"assemblyai"`. |
|
| `provider` | `"groq"` | Transcription backend: `"groq"`, `"openai"`, `"openrouter"`, `"xiaomi_mimo"`, `"stepfun"`, or `"assemblyai"`. |
|
||||||
| `model` | provider default | Optional transcription model override. Defaults to `whisper-large-v3` for Groq, `whisper-1` for OpenAI, `openai/whisper-1` for OpenRouter, `mimo-v2.5-asr` for Xiaomi MiMo ASR, and `universal-3-pro,universal-2` for AssemblyAI. OpenRouter accepts only speech-to-text models on its transcription endpoint, such as `nvidia/parakeet-tdt-0.6b-v3`, `openai/whisper-1`, or `openai/gpt-4o-transcribe`; chat LLMs are rejected there. AssemblyAI accepts a comma-separated model fallback list. |
|
| `model` | provider default | Optional transcription model override. Defaults to `whisper-large-v3` for Groq, `whisper-1` for OpenAI, `openai/whisper-1` for OpenRouter, `mimo-v2.5-asr` for Xiaomi MiMo ASR, `stepaudio-2.5-asr` for StepFun ASR, and `universal-3-pro,universal-2` for AssemblyAI. OpenRouter accepts only speech-to-text models on its transcription endpoint, such as `nvidia/parakeet-tdt-0.6b-v3`, `openai/whisper-1`, or `openai/gpt-4o-transcribe`; chat LLMs are rejected there. AssemblyAI accepts a comma-separated model fallback list. |
|
||||||
| `language` | `null` | Optional ISO-639 language hint, e.g. `"en"`, `"zh"`, `"ko"`, or `"ja"`. |
|
| `language` | `null` | Optional ISO-639 language hint, e.g. `"en"`, `"zh"`, `"ko"`, or `"ja"`. |
|
||||||
| `maxDurationSec` | `120` | Maximum WebUI/desktop recording duration. |
|
| `maxDurationSec` | `120` | Maximum WebUI/desktop recording duration. |
|
||||||
| `maxUploadMb` | `25` | Maximum WebUI/desktop audio upload size. |
|
| `maxUploadMb` | `25` | Maximum WebUI/desktop audio upload size. |
|
||||||
|
|||||||
@ -64,6 +64,11 @@ TRANSCRIPTION_PROVIDERS: tuple[TranscriptionProviderSpec, ...] = (
|
|||||||
adapter="nanobot.providers.transcription:XiaomiMiMoTranscriptionProvider",
|
adapter="nanobot.providers.transcription:XiaomiMiMoTranscriptionProvider",
|
||||||
aliases=("mimo", "xiaomi"),
|
aliases=("mimo", "xiaomi"),
|
||||||
),
|
),
|
||||||
|
TranscriptionProviderSpec(
|
||||||
|
name="stepfun",
|
||||||
|
default_model="stepaudio-2.5-asr",
|
||||||
|
adapter="nanobot.providers.transcription:StepFunTranscriptionProvider",
|
||||||
|
),
|
||||||
TranscriptionProviderSpec(
|
TranscriptionProviderSpec(
|
||||||
name="assemblyai",
|
name="assemblyai",
|
||||||
default_model="universal-3-pro,universal-2",
|
default_model="universal-3-pro,universal-2",
|
||||||
|
|||||||
@ -219,7 +219,7 @@ class ProvidersConfig(Base):
|
|||||||
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
|
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
minimax_anthropic: ProviderConfig = Field(default_factory=ProviderConfig) # MiniMax Anthropic endpoint (thinking)
|
minimax_anthropic: ProviderConfig = Field(default_factory=ProviderConfig) # MiniMax Anthropic endpoint (thinking)
|
||||||
mistral: ProviderConfig = Field(default_factory=ProviderConfig)
|
mistral: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
stepfun: ProviderConfig = Field(default_factory=ProviderConfig) # Step Fun (阶跃星辰)
|
stepfun: ProviderConfig = Field(default_factory=ProviderConfig) # Step Fun (阶跃星辰) — LLM + ASR (set apiBase to Plan URL for ASR)
|
||||||
xiaomi_mimo: ProviderConfig = Field(default_factory=ProviderConfig) # Xiaomi MIMO (小米)
|
xiaomi_mimo: ProviderConfig = Field(default_factory=ProviderConfig) # Xiaomi MIMO (小米)
|
||||||
longcat: ProviderConfig = Field(default_factory=ProviderConfig) # LongCat
|
longcat: ProviderConfig = Field(default_factory=ProviderConfig) # LongCat
|
||||||
ant_ling: ProviderConfig = Field(default_factory=ProviderConfig) # Ant Ling
|
ant_ling: ProviderConfig = Field(default_factory=ProviderConfig) # Ant Ling
|
||||||
|
|||||||
@ -8,6 +8,7 @@ WebUI upload validation, and channel integration live in
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
|
import json
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
@ -306,6 +307,119 @@ async def _post_xiaomi_mimo_asr_with_retry(
|
|||||||
return await _post_with_retry(build_request, provider_label, _text_from_chat_payload)
|
return await _post_with_retry(build_request, provider_label, _text_from_chat_payload)
|
||||||
|
|
||||||
|
|
||||||
|
async def _post_stepfun_asr_with_retry(
|
||||||
|
url: str,
|
||||||
|
*,
|
||||||
|
api_key: str | None,
|
||||||
|
path: Path,
|
||||||
|
model: str,
|
||||||
|
provider_label: str,
|
||||||
|
language: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""POST audio to StepFun ASR SSE endpoint and collect final text."""
|
||||||
|
try:
|
||||||
|
data = path.read_bytes()
|
||||||
|
except OSError as e:
|
||||||
|
logger.exception("{} transcription error: cannot read audio file: {}", provider_label, e)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
suffix = path.suffix.lstrip(".").lower()
|
||||||
|
audio_type = suffix if suffix in ("ogg", "mp3", "wav", "pcm") else "wav"
|
||||||
|
|
||||||
|
body: dict[str, Any] = {
|
||||||
|
"audio": {
|
||||||
|
"data": base64.b64encode(data).decode("ascii"),
|
||||||
|
"input": {
|
||||||
|
"transcription": {
|
||||||
|
"model": model,
|
||||||
|
"enable_itn": True,
|
||||||
|
},
|
||||||
|
"format": {"type": audio_type},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if language:
|
||||||
|
body["audio"]["input"]["transcription"]["language"] = language
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "text/event-stream",
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
for attempt in range(_MAX_RETRIES + 1):
|
||||||
|
try:
|
||||||
|
async with client.stream(
|
||||||
|
"POST", url, headers=headers, json=body, timeout=60.0
|
||||||
|
) as resp:
|
||||||
|
if resp.status_code in _RETRYABLE_STATUS and attempt < _MAX_RETRIES:
|
||||||
|
logger.warning(
|
||||||
|
"{} transcription transient HTTP {} (attempt {}/{})",
|
||||||
|
provider_label,
|
||||||
|
resp.status_code,
|
||||||
|
attempt + 1,
|
||||||
|
_MAX_RETRIES + 1,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(_BACKOFF_S[attempt])
|
||||||
|
continue
|
||||||
|
resp.raise_for_status()
|
||||||
|
final_text = None
|
||||||
|
async for line in resp.aiter_lines():
|
||||||
|
if not line.startswith("data:"):
|
||||||
|
continue
|
||||||
|
payload_str = line[len("data:") :].strip()
|
||||||
|
if not payload_str:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
payload = json.loads(payload_str)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
continue
|
||||||
|
event_type = payload.get("type", "")
|
||||||
|
if event_type == "error":
|
||||||
|
msg = payload.get("message", "unknown error")
|
||||||
|
logger.error("{} ASR error: {}", provider_label, msg)
|
||||||
|
return ""
|
||||||
|
if event_type == "transcript.text.done":
|
||||||
|
final_text = payload.get("text", "")
|
||||||
|
break
|
||||||
|
if final_text is not None:
|
||||||
|
return final_text
|
||||||
|
# Stream ended without a final event — retry if attempts remain
|
||||||
|
if attempt < _MAX_RETRIES:
|
||||||
|
logger.warning(
|
||||||
|
"{} transcription: no final event (attempt {}/{})",
|
||||||
|
provider_label,
|
||||||
|
attempt + 1,
|
||||||
|
_MAX_RETRIES + 1,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(_BACKOFF_S[attempt])
|
||||||
|
continue
|
||||||
|
logger.error(
|
||||||
|
"{} transcription: stream ended without final text after {} attempts",
|
||||||
|
provider_label,
|
||||||
|
_MAX_RETRIES + 1,
|
||||||
|
)
|
||||||
|
return ""
|
||||||
|
except httpx.HTTPStatusError:
|
||||||
|
if attempt < _MAX_RETRIES:
|
||||||
|
await asyncio.sleep(_BACKOFF_S[attempt])
|
||||||
|
continue
|
||||||
|
logger.exception(
|
||||||
|
"{} transcription failed after {} attempts",
|
||||||
|
provider_label,
|
||||||
|
_MAX_RETRIES + 1,
|
||||||
|
)
|
||||||
|
return ""
|
||||||
|
except (httpx.RequestError, Exception):
|
||||||
|
if attempt < _MAX_RETRIES:
|
||||||
|
await asyncio.sleep(_BACKOFF_S[attempt])
|
||||||
|
continue
|
||||||
|
logger.exception("{} transcription request error", provider_label)
|
||||||
|
return ""
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
async def _post_with_retry(
|
async def _post_with_retry(
|
||||||
build_request: Callable[[], dict[str, Any]],
|
build_request: Callable[[], dict[str, Any]],
|
||||||
provider_label: str,
|
provider_label: str,
|
||||||
@ -663,3 +777,44 @@ class XiaomiMiMoTranscriptionProvider:
|
|||||||
provider_label="Xiaomi MiMo",
|
provider_label="Xiaomi MiMo",
|
||||||
language=self.language,
|
language=self.language,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StepFunTranscriptionProvider:
|
||||||
|
"""Voice transcription provider using StepFun ASR SSE endpoint."""
|
||||||
|
|
||||||
|
_DEFAULT_URL = "https://api.stepfun.com/v1/audio/asr/sse"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str | None = None,
|
||||||
|
api_base: str | None = None,
|
||||||
|
language: str | None = None,
|
||||||
|
model: str | None = None,
|
||||||
|
):
|
||||||
|
self.api_key = api_key or os.environ.get("STEPFUN_API_KEY")
|
||||||
|
# api_base is used verbatim; users can point to the Plan endpoint
|
||||||
|
# (https://api.stepfun.com/step_plan/v1/audio/asr/sse) or any
|
||||||
|
# compatible proxy.
|
||||||
|
self.api_url = api_base or self._DEFAULT_URL
|
||||||
|
self.language = language or None
|
||||||
|
self.model = model or "stepaudio-2.5-asr"
|
||||||
|
logger.debug("StepFun transcription endpoint: {}", self.api_url)
|
||||||
|
|
||||||
|
async def transcribe(self, file_path: str | Path) -> str:
|
||||||
|
if not self.api_key:
|
||||||
|
logger.warning("StepFun API key not configured for transcription")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
path = Path(file_path)
|
||||||
|
if not path.exists():
|
||||||
|
logger.error("Audio file not found: {}", file_path)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
return await _post_stepfun_asr_with_retry(
|
||||||
|
self.api_url,
|
||||||
|
api_key=self.api_key,
|
||||||
|
path=path,
|
||||||
|
model=self.model,
|
||||||
|
provider_label="StepFun",
|
||||||
|
language=self.language,
|
||||||
|
)
|
||||||
|
|||||||
418
tests/providers/test_stepfun_asr.py
Normal file
418
tests/providers/test_stepfun_asr.py
Normal file
@ -0,0 +1,418 @@
|
|||||||
|
"""Tests for StepFun ASR SSE transcription provider."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest

from nanobot.audio.transcription_registry import (
    get_transcription_provider,
    transcription_provider_names,
)
from nanobot.config.schema import Config
from nanobot.providers.transcription import StepFunTranscriptionProvider
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def audio_file(tmp_path: Path) -> Path:
    """Write a small fake ``voice.ogg`` into pytest's temp dir and return its path."""
    target = tmp_path / "voice.ogg"
    target.write_bytes(b"OggS\x00fake-audio-bytes")
    return target
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Defaults and base normalization
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_stepfun_defaults() -> None:
    """A provider built with only an API key uses the default endpoint and model."""
    default_provider = StepFunTranscriptionProvider(api_key="sk-test")
    expected = ("https://api.stepfun.com/v1/audio/asr/sse", "stepaudio-2.5-asr")
    assert (default_provider.api_url, default_provider.model) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_stepfun_api_base_overrides_url() -> None:
    """An explicit ``api_base`` replaces the default endpoint URL."""
    plan_url = "https://api.stepfun.com/step_plan/v1/audio/asr/sse"
    provider = StepFunTranscriptionProvider(api_key="sk-test", api_base=plan_url)
    assert provider.api_url == plan_url
|
||||||
|
|
||||||
|
|
||||||
|
def test_stepfun_custom_model() -> None:
    """An explicit ``model`` argument is stored on the provider."""
    chosen_model = "stepaudio-2-asr-pro"
    provider = StepFunTranscriptionProvider(api_key="sk-test", model=chosen_model)
    assert provider.model == chosen_model
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Short-circuit: missing key / missing file
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_missing_api_key_short_circuits(audio_file: Path) -> None:
    """Without a key (argument or environment) no HTTP request is attempted."""
    # Clear the environment so STEPFUN_API_KEY cannot leak in from the host.
    with patch.dict("os.environ", {}, clear=True):
        keyless = StepFunTranscriptionProvider(api_key=None)

    fake_stream = MagicMock()
    with patch("httpx.AsyncClient.stream", fake_stream):
        transcript = await keyless.transcribe(audio_file)

    assert transcript == ""
    fake_stream.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_missing_file_short_circuits(tmp_path: Path) -> None:
    """A path that does not exist returns "" without any HTTP request."""
    # A never-created file under pytest's temp dir is guaranteed to be missing,
    # unlike a hard-coded absolute path.
    missing = tmp_path / "missing-voice.ogg"
    provider = StepFunTranscriptionProvider(api_key="sk-test")

    stream_mock = MagicMock()
    with patch("httpx.AsyncClient.stream", stream_mock):
        assert await provider.transcribe(missing) == ""
    stream_mock.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# SSE stream parsing: happy path
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_sse_delta_then_done(audio_file: Path) -> None:
    """Deltas followed by ``transcript.text.done`` yield the final event's text."""
    sse_lines = [
        "data: " + json.dumps({"type": kind, "session_id": "s1", "text": text})
        for kind, text in (
            ("transcript.text.delta", "你"),
            ("transcript.text.delta", "你好"),
            ("transcript.text.done", "你好世界"),
        )
    ]
    fake_stream = _make_stream_cm(200, sse_lines)
    provider = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        transcript = await provider.transcribe(audio_file)

    assert transcript == "你好世界"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_sse_only_done_event(audio_file: Path) -> None:
    """A lone ``transcript.text.done`` event (no deltas) is enough."""
    done_event = {"type": "transcript.text.done", "session_id": "s1", "text": "hello world"}
    fake_stream = _make_stream_cm(200, ["data: " + json.dumps(done_event)])
    provider = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        transcript = await provider.transcribe(audio_file)

    assert transcript == "hello world"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_sse_error_event(audio_file: Path) -> None:
    """An ``error`` event in the stream makes ``transcribe`` return ""."""
    error_event = {"type": "error", "session_id": "s1", "message": "audio too short"}
    fake_stream = _make_stream_cm(200, ["data: " + json.dumps(error_event)])
    provider = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        transcript = await provider.transcribe(audio_file)

    assert transcript == ""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_sse_ignores_non_data_lines(audio_file: Path) -> None:
    """Blank lines and lines lacking the 'data:' prefix do not affect the result."""
    done_event = {"type": "transcript.text.done", "session_id": "s1", "text": "result"}
    stream_lines = []
    stream_lines.append("")  # empty line
    stream_lines.append("event: session.start")  # non-data event
    stream_lines.append("data: " + json.dumps(done_event))

    fake_stream = _make_stream_cm(200, stream_lines)
    provider = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        transcript = await provider.transcribe(audio_file)

    assert transcript == "result"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_sse_malformed_json_skipped(audio_file: Path) -> None:
    """A data line with invalid JSON is skipped; the later done event still wins."""
    done_event = {"type": "transcript.text.done", "session_id": "s1", "text": "ok"}
    stream_lines = ["data: not-json-at-all", "data: " + json.dumps(done_event)]

    fake_stream = _make_stream_cm(200, stream_lines)
    provider = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        transcript = await provider.transcribe(audio_file)

    assert transcript == "ok"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Retry contract
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_retries_on_503_then_succeeds(audio_file: Path) -> None:
    """A 503 on the first attempt is followed by a successful SSE stream."""
    done_payload = json.dumps({'type': 'transcript.text.done', 'session_id': 's1', 'text': 'ok'})
    # Scripted attempts: one 503 response, then a stream carrying the final event.
    fake_stream = _make_stream_cm_sequence([503, [f"data: {done_payload}"]])
    provider = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        # Replace asyncio.sleep so the backoff does not slow the test down.
        with patch("asyncio.sleep", AsyncMock()):
            transcript = await provider.transcribe(audio_file)

    assert transcript == "ok"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_gives_up_after_max_retries(audio_file: Path) -> None:
    """When every attempt answers 503 the provider returns ""."""
    scripted: list[list[str] | int] = [503] * 4  # 4 failing HTTP responses
    fake_stream = _make_stream_cm_sequence(scripted)
    provider = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        with patch("asyncio.sleep", AsyncMock()):
            transcript = await provider.transcribe(audio_file)

    assert transcript == ""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_sse_empty_text_done_returns_empty(audio_file: Path) -> None:
    """A done event whose text is empty yields "" as the transcript."""
    empty_done = {"type": "transcript.text.done", "session_id": "s1", "text": ""}
    fake_stream = _make_stream_cm(200, ["data: " + json.dumps(empty_done)])
    provider = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        with patch("asyncio.sleep", AsyncMock()):
            transcript = await provider.transcribe(audio_file)

    assert transcript == ""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_401_returns_empty_after_retries(audio_file: Path) -> None:
    """A 401 response (``raise_for_status`` raises HTTPStatusError) ends in ""."""
    fake_stream = _make_stream_cm(401, [])
    provider = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        with patch("asyncio.sleep", AsyncMock()):
            transcript = await provider.transcribe(audio_file)

    assert transcript == ""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_retries_on_connect_error(audio_file: Path) -> None:
    """Network-level transient errors are retried."""
    success_lines = [
        f"data: {json.dumps({'type': 'transcript.text.done', 'session_id': 's1', 'text': 'ok'})}",
    ]
    calls = 0

    class FakeResponse:
        """Serves as both the async context manager returned by stream()
        and the response object bound in `async with ... as resp`."""

        status_code = 200
        reason_phrase = "OK"

        async def __aenter__(self) -> "FakeResponse":
            return self

        async def __aexit__(self, *exc: object) -> None:
            pass

        async def aiter_lines(self) -> Any:
            for line in success_lines:
                yield line

        def raise_for_status(self) -> None:
            pass

    # patch() installs this plain function as a class attribute, so it is
    # bound like a method and receives the AsyncClient instance first.
    def fake_stream(
        _client: object, method: str, url: str, *args: object, **kwargs: object
    ) -> FakeResponse:
        nonlocal calls
        calls += 1
        if calls == 1:
            raise httpx.ConnectError("boom")
        return FakeResponse()

    provider = StepFunTranscriptionProvider(api_key="sk-test")
    with patch("httpx.AsyncClient.stream", fake_stream), patch(
        "asyncio.sleep", AsyncMock()
    ):
        result = await provider.transcribe(audio_file)

    assert result == "ok"
    # One failed connection attempt plus one successful retry.
    assert calls == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Registry integration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_stepfun_in_registry() -> None:
    """The registry lists 'stepfun' with the expected default model and adapter."""
    registered_names = transcription_provider_names()
    assert "stepfun" in registered_names

    stepfun_spec = get_transcription_provider("stepfun")
    assert stepfun_spec is not None
    assert stepfun_spec.default_model == "stepaudio-2.5-asr"
    expected_adapter = "nanobot.providers.transcription:StepFunTranscriptionProvider"
    assert stepfun_spec.adapter == expected_adapter
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_resolves_stepfun() -> None:
    """Transcription settings plus ``providers.stepfun`` credentials resolve together."""
    from nanobot.audio.transcription import resolve_transcription_config

    plan_url = "https://api.stepfun.com/step_plan/v1/audio/asr/sse"

    cfg = Config()
    cfg.transcription.provider = "stepfun"
    cfg.transcription.model = "stepaudio-2.5-asr"
    cfg.transcription.language = "zh"
    cfg.providers.stepfun.api_key = "step-test"
    cfg.providers.stepfun.api_base = plan_url

    outcome = resolve_transcription_config(cfg)

    assert outcome.provider == "stepfun"
    assert outcome.model == "stepaudio-2.5-asr"
    assert outcome.language == "zh"
    assert outcome.api_key == "step-test"
    assert outcome.api_base == plan_url
    assert outcome.configured is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_stream_cm(status: int, lines: list[str]) -> MagicMock:
    """Return a stand-in for `AsyncClient.stream` serving *lines* with HTTP *status*.

    Every call to the mock returns the same fake response object, which acts
    as both the async context manager and the response bound by
    ``async with ... as resp``.
    """

    class FakeResponse:
        def __init__(self) -> None:
            self.status_code = status
            if status == 200:
                self.reason_phrase = "OK"
            else:
                self.reason_phrase = "Error"

        async def __aenter__(self) -> "FakeResponse":
            return self

        async def __aexit__(self, *exc: object) -> None:
            return None

        async def aiter_lines(self) -> Any:
            for sse_line in lines:
                yield sse_line

        def raise_for_status(self) -> None:
            # Statuses below 400 are treated as success.
            if self.status_code < 400:
                return
            raise httpx.HTTPStatusError(
                f"HTTP {self.status_code}",
                request=httpx.Request("POST", "https://example.test"),
                response=httpx.Response(self.status_code),
            )

    return MagicMock(return_value=FakeResponse())
|
||||||
|
|
||||||
|
|
||||||
|
def _make_stream_cm_sequence(statuses: list[str | int]) -> MagicMock:
|
||||||
|
"""Build a stream mock that fails with HTTP status ints, then succeeds with SSE lines.
|
||||||
|
|
||||||
|
Entries in *statuses* that are ints produce a stream that raises HTTPStatusError
|
||||||
|
after `raise_for_status()`. The final entry (a list of SSE lines) succeeds.
|
||||||
|
"""
|
||||||
|
remaining = list(statuses)
|
||||||
|
|
||||||
|
class FakeResponse:
|
||||||
|
def __init__(self, lines: list[str]) -> None:
|
||||||
|
self._lines = lines
|
||||||
|
self.status_code = 200
|
||||||
|
self.reason_phrase = "OK"
|
||||||
|
|
||||||
|
async def __aenter__(self) -> "FakeResponse":
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *exc: object) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def aiter_lines(self) -> Any:
|
||||||
|
for line in self._lines:
|
||||||
|
yield line
|
||||||
|
|
||||||
|
def raise_for_status(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class FailingResponse:
|
||||||
|
def __init__(self, status: int) -> None:
|
||||||
|
self.status_code = status
|
||||||
|
self.reason_phrase = "Error"
|
||||||
|
|
||||||
|
async def __aenter__(self) -> "FailingResponse":
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *exc: object) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def aiter_lines(self) -> Any:
|
||||||
|
yield ""
|
||||||
|
return
|
||||||
|
|
||||||
|
def raise_for_status(self) -> None:
|
||||||
|
raise httpx.HTTPStatusError(
|
||||||
|
f"HTTP {self.status_code}",
|
||||||
|
request=httpx.Request("POST", "https://example.test"),
|
||||||
|
response=httpx.Response(self.status_code),
|
||||||
|
)
|
||||||
|
|
||||||
|
call_count = [0]
|
||||||
|
|
||||||
|
def _next(method: str, url: str, **kwargs: object) -> Any:
|
||||||
|
idx = min(call_count[0], len(remaining) - 1)
|
||||||
|
entry = remaining[idx]
|
||||||
|
call_count[0] += 1
|
||||||
|
if isinstance(entry, int):
|
||||||
|
return FailingResponse(entry)
|
||||||
|
return FakeResponse(entry)
|
||||||
|
|
||||||
|
cm = MagicMock(side_effect=_next)
|
||||||
|
return cm
|
||||||
Loading…
x
Reference in New Issue
Block a user