mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 07:14:08 +00:00
feat(asr): add StepFun ASR SSE transcription provider
- Add StepFunTranscriptionProvider class in nanobot/providers/transcription.py - New _post_stepfun_asr_with_retry() function handling SSE stream parsing (transcript.text.delta → transcript.text.done event sequence) - Register 'stepfun' in transcription_registry.py with default model stepaudio-2.5-asr - Reuse existing stepfun provider config (apiBase can point to Plan endpoint) - Add 17 tests covering SSE parsing, retry contract, empty-text edge case, and registry integration - Update docs/configuration.md with stepfun ASR documentation StepFun ASR uses a dedicated SSE endpoint (/v1/audio/asr/sse) rather than the chat-completions or Whisper multipart formats used by other providers. Users on Step Plan can set apiBase to the Plan endpoint.
This commit is contained in:
parent
31bfec58d0
commit
7930058348
@ -239,7 +239,7 @@ Tracing covers the providers that go through nanobot's OpenAI-compatible client
|
||||
| `lm_studio` | LLM (local, LM Studio) | — |
|
||||
| `atomic_chat` | LLM (local, [Atomic Chat](https://atomic.chat/)) | — |
|
||||
| `mistral` | LLM | [docs.mistral.ai](https://docs.mistral.ai/) |
|
||||
| `stepfun` | LLM (Step Fun/阶跃星辰) | [platform.stepfun.com](https://platform.stepfun.com) |
|
||||
| `stepfun` | LLM (Step Fun/阶跃星辰) + Voice transcription (ASR) | [platform.stepfun.com](https://platform.stepfun.com) |
|
||||
| `ovms` | LLM (local, OpenVINO Model Server) | [docs.openvino.ai](https://docs.openvino.ai/2026/model-server/ovms_docs_llm_quickstart.html) |
|
||||
| `vllm` | LLM (local, any OpenAI-compatible server) | — |
|
||||
| `nvidia` | LLM (NVIDIA NIM) | [build.nvidia.com](https://build.nvidia.com/) |
|
||||
@ -1294,8 +1294,8 @@ Configure transcription under the top-level `transcription` section:
|
||||
| Setting | Default | Description |
|
||||
|---------|---------|-------------|
|
||||
| `enabled` | `true` | Enables audio transcription for both chat-channel voice messages and WebUI/desktop microphone input. |
|
||||
| `provider` | `"groq"` | Transcription backend: `"groq"`, `"openai"`, `"openrouter"`, `"xiaomi_mimo"`, or `"assemblyai"`. |
|
||||
| `model` | provider default | Optional transcription model override. Defaults to `whisper-large-v3` for Groq, `whisper-1` for OpenAI, `openai/whisper-1` for OpenRouter, `mimo-v2.5-asr` for Xiaomi MiMo ASR, and `universal-3-pro,universal-2` for AssemblyAI. OpenRouter accepts only speech-to-text models on its transcription endpoint, such as `nvidia/parakeet-tdt-0.6b-v3`, `openai/whisper-1`, or `openai/gpt-4o-transcribe`; chat LLMs are rejected there. AssemblyAI accepts a comma-separated model fallback list. |
|
||||
| `provider` | `"groq"` | Transcription backend: `"groq"`, `"openai"`, `"openrouter"`, `"xiaomi_mimo"`, `"stepfun"`, or `"assemblyai"`. |
|
||||
| `model` | provider default | Optional transcription model override. Defaults to `whisper-large-v3` for Groq, `whisper-1` for OpenAI, `openai/whisper-1` for OpenRouter, `mimo-v2.5-asr` for Xiaomi MiMo ASR, `stepaudio-2.5-asr` for StepFun ASR, and `universal-3-pro,universal-2` for AssemblyAI. OpenRouter accepts only speech-to-text models on its transcription endpoint, such as `nvidia/parakeet-tdt-0.6b-v3`, `openai/whisper-1`, or `openai/gpt-4o-transcribe`; chat LLMs are rejected there. AssemblyAI accepts a comma-separated model fallback list. |
|
||||
| `language` | `null` | Optional ISO-639 language hint, e.g. `"en"`, `"zh"`, `"ko"`, or `"ja"`. |
|
||||
| `maxDurationSec` | `120` | Maximum WebUI/desktop recording duration. |
|
||||
| `maxUploadMb` | `25` | Maximum WebUI/desktop audio upload size. |
|
||||
|
||||
@ -64,6 +64,11 @@ TRANSCRIPTION_PROVIDERS: tuple[TranscriptionProviderSpec, ...] = (
|
||||
adapter="nanobot.providers.transcription:XiaomiMiMoTranscriptionProvider",
|
||||
aliases=("mimo", "xiaomi"),
|
||||
),
|
||||
TranscriptionProviderSpec(
|
||||
name="stepfun",
|
||||
default_model="stepaudio-2.5-asr",
|
||||
adapter="nanobot.providers.transcription:StepFunTranscriptionProvider",
|
||||
),
|
||||
TranscriptionProviderSpec(
|
||||
name="assemblyai",
|
||||
default_model="universal-3-pro,universal-2",
|
||||
|
||||
@ -219,7 +219,7 @@ class ProvidersConfig(Base):
|
||||
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
minimax_anthropic: ProviderConfig = Field(default_factory=ProviderConfig) # MiniMax Anthropic endpoint (thinking)
|
||||
mistral: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
stepfun: ProviderConfig = Field(default_factory=ProviderConfig) # Step Fun (阶跃星辰)
|
||||
stepfun: ProviderConfig = Field(default_factory=ProviderConfig) # Step Fun (阶跃星辰) — LLM + ASR (set apiBase to Plan URL for ASR)
|
||||
xiaomi_mimo: ProviderConfig = Field(default_factory=ProviderConfig) # Xiaomi MIMO (小米)
|
||||
longcat: ProviderConfig = Field(default_factory=ProviderConfig) # LongCat
|
||||
ant_ling: ProviderConfig = Field(default_factory=ProviderConfig) # Ant Ling
|
||||
|
||||
@ -8,6 +8,7 @@ WebUI upload validation, and channel integration live in
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
@ -306,6 +307,119 @@ async def _post_xiaomi_mimo_asr_with_retry(
|
||||
return await _post_with_retry(build_request, provider_label, _text_from_chat_payload)
|
||||
|
||||
|
||||
async def _read_stepfun_final_text(resp: "httpx.Response", provider_label: str) -> str | None:
    """Scan a StepFun SSE response for its terminal event.

    Returns the transcript carried by ``transcript.text.done``, ``""`` when the
    stream reports an ``error`` event, or ``None`` when the stream ends without
    either of them.
    """
    async for line in resp.aiter_lines():
        # Only ``data:`` fields carry event payloads; ``event:`` fields and
        # blank separator lines are ignored.
        if not line.startswith("data:"):
            continue
        payload_str = line[len("data:") :].strip()
        if not payload_str:
            continue
        try:
            payload = json.loads(payload_str)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            continue
        if not isinstance(payload, dict):
            # Valid JSON that is not an object cannot be an event; skip it
            # instead of failing on ``.get``.
            continue

        event_type = payload.get("type", "")
        if event_type == "error":
            msg = payload.get("message", "unknown error")
            logger.error("{} ASR error: {}", provider_label, msg)
            return ""
        if event_type == "transcript.text.done":
            text = payload.get("text")
            return text if isinstance(text, str) else ""
    return None


async def _post_stepfun_asr_with_retry(
    url: str,
    *,
    api_key: str | None,
    path: Path,
    model: str,
    provider_label: str,
    language: str | None = None,
) -> str:
    """POST audio to StepFun ASR SSE endpoint and collect final text.

    Args:
        url: Full URL of the SSE transcription endpoint.
        api_key: Bearer token; the Authorization header is omitted when falsy.
        path: Audio file to upload (sent base64-encoded inside the JSON body).
        model: ASR model name placed in the request.
        provider_label: Human-readable provider name used in log messages.
        language: Optional language hint added to the request when set.

    Returns:
        The final transcript, or ``""`` on any failure (unreadable file,
        ``error`` event, HTTP/transport failure, or no final event after all
        attempts). This function never raises for those conditions.

    Retry policy: responses whose status is in ``_RETRYABLE_STATUS``, transport
    errors (``httpx.RequestError``) and streams that end without a final event
    are retried up to ``_MAX_RETRIES`` times with ``_BACKOFF_S`` delays. Other
    HTTP errors and unexpected exceptions fail immediately.
    """
    try:
        data = path.read_bytes()
    except OSError as e:
        logger.exception("{} transcription error: cannot read audio file: {}", provider_label, e)
        return ""

    # Unknown or missing extensions are declared as "wav".
    suffix = path.suffix.lstrip(".").lower()
    audio_type = suffix if suffix in ("ogg", "mp3", "wav", "pcm") else "wav"

    transcription: dict[str, Any] = {"model": model, "enable_itn": True}
    if language:
        transcription["language"] = language
    body: dict[str, Any] = {
        "audio": {
            "data": base64.b64encode(data).decode("ascii"),
            "input": {
                "transcription": transcription,
                "format": {"type": audio_type},
            },
        },
    }

    headers = {
        "Content-Type": "application/json",
        "Accept": "text/event-stream",
    }
    if api_key:
        # Avoid sending the literal string "Bearer None" when no key is set.
        headers["Authorization"] = f"Bearer {api_key}"

    async with httpx.AsyncClient() as client:
        for attempt in range(_MAX_RETRIES + 1):
            can_retry = attempt < _MAX_RETRIES
            try:
                async with client.stream(
                    "POST", url, headers=headers, json=body, timeout=60.0
                ) as resp:
                    transient = resp.status_code in _RETRYABLE_STATUS and can_retry
                    if transient:
                        logger.warning(
                            "{} transcription transient HTTP {} (attempt {}/{})",
                            provider_label,
                            resp.status_code,
                            attempt + 1,
                            _MAX_RETRIES + 1,
                        )
                        final_text = None
                    else:
                        resp.raise_for_status()
                        final_text = await _read_stepfun_final_text(resp, provider_label)
            except httpx.HTTPStatusError:
                # Reached only for non-retryable statuses or for a retryable
                # status on the last attempt, so retrying again is pointless.
                logger.exception(
                    "{} transcription failed after {} attempts",
                    provider_label,
                    attempt + 1,
                )
                return ""
            except httpx.RequestError:
                # Transport-level failure (connect, timeout, protocol, ...).
                if not can_retry:
                    logger.exception("{} transcription request error", provider_label)
                    return ""
                logger.warning(
                    "{} transcription request error, retrying (attempt {}/{})",
                    provider_label,
                    attempt + 1,
                    _MAX_RETRIES + 1,
                )
            except Exception:
                # Unexpected failure: keep the best-effort contract (return "")
                # but do not burn retries on what is most likely a bug.
                logger.exception("{} transcription request error", provider_label)
                return ""
            else:
                if final_text is not None:
                    return final_text
                if not transient:
                    # Stream ended without a final event — retry if attempts remain
                    if not can_retry:
                        logger.error(
                            "{} transcription: stream ended without final text after {} attempts",
                            provider_label,
                            _MAX_RETRIES + 1,
                        )
                        return ""
                    logger.warning(
                        "{} transcription: no final event (attempt {}/{})",
                        provider_label,
                        attempt + 1,
                        _MAX_RETRIES + 1,
                    )

            # Every path that gets here has attempts left. The response
            # context is already closed, so the connection is not held open
            # while backing off.
            await asyncio.sleep(_BACKOFF_S[attempt])

    # Defensive: the last attempt always returns from inside the loop.
    return ""
|
||||
|
||||
|
||||
async def _post_with_retry(
|
||||
build_request: Callable[[], dict[str, Any]],
|
||||
provider_label: str,
|
||||
@ -663,3 +777,44 @@ class XiaomiMiMoTranscriptionProvider:
|
||||
provider_label="Xiaomi MiMo",
|
||||
language=self.language,
|
||||
)
|
||||
|
||||
|
||||
class StepFunTranscriptionProvider:
    """Transcribe audio files through the StepFun ASR SSE endpoint."""

    _DEFAULT_URL = "https://api.stepfun.com/v1/audio/asr/sse"

    def __init__(
        self,
        api_key: str | None = None,
        api_base: str | None = None,
        language: str | None = None,
        model: str | None = None,
    ):
        """Store connection settings, falling back to defaults for falsy values.

        ``api_key`` falls back to the ``STEPFUN_API_KEY`` environment variable.
        ``api_base`` is taken as the complete endpoint URL with no path
        appended, so it can point at the Plan endpoint
        (https://api.stepfun.com/step_plan/v1/audio/asr/sse) or a compatible
        proxy.
        """
        if not api_key:
            api_key = os.environ.get("STEPFUN_API_KEY")
        self.api_key = api_key
        self.api_url = api_base if api_base else self._DEFAULT_URL
        # Empty strings are normalised to None.
        self.language = language if language else None
        self.model = model if model else "stepaudio-2.5-asr"
        logger.debug("StepFun transcription endpoint: {}", self.api_url)

    async def transcribe(self, file_path: str | Path) -> str:
        """Return the transcript for *file_path*.

        Returns ``""`` without sending a request when no API key is configured
        or the file does not exist.
        """
        if not self.api_key:
            logger.warning("StepFun API key not configured for transcription")
            return ""

        audio_path = Path(file_path)
        if not audio_path.exists():
            logger.error("Audio file not found: {}", file_path)
            return ""

        transcript = await _post_stepfun_asr_with_retry(
            self.api_url,
            api_key=self.api_key,
            path=audio_path,
            model=self.model,
            provider_label="StepFun",
            language=self.language,
        )
        return transcript
|
||||
|
||||
418
tests/providers/test_stepfun_asr.py
Normal file
418
tests/providers/test_stepfun_asr.py
Normal file
@ -0,0 +1,418 @@
|
||||
"""Tests for StepFun ASR SSE transcription provider."""
|
||||
|
||||
from __future__ import annotations

import json
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest

from nanobot.audio.transcription_registry import (
    get_transcription_provider,
    transcription_provider_names,
)
from nanobot.config.schema import Config
from nanobot.providers.transcription import StepFunTranscriptionProvider
|
||||
|
||||
|
||||
@pytest.fixture
def audio_file(tmp_path: Path) -> Path:
    """Write a tiny fake Ogg file into the temp directory and return its path."""
    target = tmp_path.joinpath("voice.ogg")
    target.write_bytes(b"OggS\x00fake-audio-bytes")
    return target
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Defaults and endpoint/model overrides
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_stepfun_defaults() -> None:
    """With only an API key, the stock endpoint and model are used."""
    default_provider = StepFunTranscriptionProvider(api_key="sk-test")

    expected_url = "https://api.stepfun.com/v1/audio/asr/sse"
    expected_model = "stepaudio-2.5-asr"
    assert default_provider.api_url == expected_url
    assert default_provider.model == expected_model
|
||||
|
||||
|
||||
def test_stepfun_api_base_overrides_url() -> None:
    """An explicit api_base becomes the endpoint URL unchanged."""
    plan_url = "https://api.stepfun.com/step_plan/v1/audio/asr/sse"
    plan_provider = StepFunTranscriptionProvider(api_key="sk-test", api_base=plan_url)
    assert plan_provider.api_url == "https://api.stepfun.com/step_plan/v1/audio/asr/sse"
|
||||
|
||||
|
||||
def test_stepfun_custom_model() -> None:
    """A model passed to the constructor replaces the default."""
    chosen_model = "stepaudio-2-asr-pro"
    custom = StepFunTranscriptionProvider(api_key="sk-test", model=chosen_model)
    assert custom.model == "stepaudio-2-asr-pro"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Short-circuit: missing key / missing file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_missing_api_key_short_circuits(audio_file: Path) -> None:
    """Without a key (argument or environment), no HTTP request is made."""
    # The provider is built under an emptied environment so no key can be
    # picked up from STEPFUN_API_KEY.
    with patch.dict("os.environ", {}, clear=True):
        keyless = StepFunTranscriptionProvider(api_key=None)

    http_stream = MagicMock()
    with patch("httpx.AsyncClient.stream", http_stream):
        transcript = await keyless.transcribe(audio_file)
        assert transcript == ""

    http_stream.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_missing_file_short_circuits(audio_file: Path) -> None:
    """A path that does not exist returns "" without any HTTP request."""
    # Derive the missing path from the fixture's temp directory so the test
    # does not depend on what happens to exist at a fixed absolute path.
    missing = audio_file.with_name("does-not-exist.ogg")
    assert not missing.exists()

    provider = StepFunTranscriptionProvider(api_key="sk-test")
    stream_mock = MagicMock()
    with patch("httpx.AsyncClient.stream", stream_mock):
        assert await provider.transcribe(missing) == ""
    stream_mock.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSE stream parsing: happy path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sse_delta_then_done(audio_file: Path) -> None:
    """Delta events followed by text.done yield the text of the done event."""
    sse_events = [
        {"type": "transcript.text.delta", "session_id": "s1", "text": "你"},
        {"type": "transcript.text.delta", "session_id": "s1", "text": "你好"},
        {"type": "transcript.text.done", "session_id": "s1", "text": "你好世界"},
    ]
    sse_lines = ["data: " + json.dumps(event) for event in sse_events]
    fake_stream = _make_stream_cm(200, sse_lines)
    stepfun = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        transcript = await stepfun.transcribe(audio_file)

    assert transcript == "你好世界"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sse_only_done_event(audio_file: Path) -> None:
    """A lone transcript.text.done event (no deltas) is enough."""
    done_event = {"type": "transcript.text.done", "session_id": "s1", "text": "hello world"}
    fake_stream = _make_stream_cm(200, ["data: " + json.dumps(done_event)])
    stepfun = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        transcript = await stepfun.transcribe(audio_file)

    assert transcript == "hello world"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sse_error_event(audio_file: Path) -> None:
    """An error event in the SSE stream produces an empty transcript."""
    error_event = {"type": "error", "session_id": "s1", "message": "audio too short"}
    fake_stream = _make_stream_cm(200, ["data: " + json.dumps(error_event)])
    stepfun = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        transcript = await stepfun.transcribe(audio_file)

    assert transcript == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sse_ignores_non_data_lines(audio_file: Path) -> None:
    """Blank lines and lines lacking the 'data:' prefix do not affect the result."""
    done_event = {"type": "transcript.text.done", "session_id": "s1", "text": "result"}
    stream_lines = [
        "",  # blank separator
        "event: session.start",  # SSE field other than data
        "data: " + json.dumps(done_event),
    ]
    fake_stream = _make_stream_cm(200, stream_lines)
    stepfun = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        transcript = await stepfun.transcribe(audio_file)

    assert transcript == "result"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sse_malformed_json_skipped(audio_file: Path) -> None:
    """A data line holding malformed JSON is skipped; later events still count."""
    done_event = {"type": "transcript.text.done", "session_id": "s1", "text": "ok"}
    stream_lines = ["data: not-json-at-all", "data: " + json.dumps(done_event)]
    fake_stream = _make_stream_cm(200, stream_lines)
    stepfun = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        transcript = await stepfun.transcribe(audio_file)

    assert transcript == "ok"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Retry contract
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_retries_on_503_then_succeeds(audio_file: Path) -> None:
    """Transient 503 is retried, then a successful SSE stream yields text."""
    success_lines = [
        f"data: {json.dumps({'type': 'transcript.text.done', 'session_id': 's1', 'text': 'ok'})}",
    ]
    # First call: 503 (FailingResponse), second call: success (FakeResponse with lines)
    stream_cm = _make_stream_cm_sequence([503, success_lines])

    provider = StepFunTranscriptionProvider(api_key="sk-test")
    with patch("httpx.AsyncClient.stream", stream_cm), patch(
        "asyncio.sleep", AsyncMock()
    ):
        result = await provider.transcribe(audio_file)

    assert result == "ok"
    # Prove that a retry actually took place: one failed request plus one
    # successful request.
    assert stream_cm.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_gives_up_after_max_retries(audio_file: Path) -> None:
    """A 503 on every attempt ends with an empty transcript."""
    always_503: list[list[str] | int] = [503] * 4  # 4 failing HTTP responses
    fake_stream = _make_stream_cm_sequence(always_503)
    stepfun = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        with patch("asyncio.sleep", AsyncMock()):
            transcript = await stepfun.transcribe(audio_file)

    assert transcript == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sse_empty_text_done_returns_empty(audio_file: Path) -> None:
    """A transcript.text.done event with empty text yields ""."""
    empty_done = {"type": "transcript.text.done", "session_id": "s1", "text": ""}
    fake_stream = _make_stream_cm(200, ["data: " + json.dumps(empty_done)])
    stepfun = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", fake_stream):
        with patch("asyncio.sleep", AsyncMock()):
            transcript = await stepfun.transcribe(audio_file)

    assert transcript == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_401_returns_empty_after_retries(audio_file: Path) -> None:
    """A 401 response ends with an empty transcript rather than an exception."""
    unauthorized_stream = _make_stream_cm(401, [])
    stepfun = StepFunTranscriptionProvider(api_key="sk-test")

    with patch("httpx.AsyncClient.stream", unauthorized_stream):
        with patch("asyncio.sleep", AsyncMock()):
            transcript = await stepfun.transcribe(audio_file)

    assert transcript == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_retries_on_connect_error(audio_file: Path) -> None:
    """A connection error on the first request is retried and then succeeds."""
    done_event = {"type": "transcript.text.done", "session_id": "s1", "text": "ok"}
    sse_lines = ["data: " + json.dumps(done_event)]
    stream_calls = 0

    class FakeResponse:
        """Acts as the context manager returned by stream() and as the
        response object bound by ``async with ... as resp``."""

        status_code = 200
        reason_phrase = "OK"

        async def __aenter__(self) -> "FakeResponse":
            return self

        async def __aexit__(self, *exc: object) -> None:
            return None

        def raise_for_status(self) -> None:
            return None

        async def aiter_lines(self) -> Any:
            for sse_line in sse_lines:
                yield sse_line

    def fake_stream(method: str, url: str, *args: object, **kwargs: object) -> FakeResponse:
        nonlocal stream_calls
        stream_calls += 1
        # Only the very first request fails at the connection level.
        if stream_calls == 1:
            raise httpx.ConnectError("boom")
        return FakeResponse()

    stepfun = StepFunTranscriptionProvider(api_key="sk-test")
    with patch("httpx.AsyncClient.stream", fake_stream):
        with patch("asyncio.sleep", AsyncMock()):
            transcript = await stepfun.transcribe(audio_file)

    assert transcript == "ok"
    assert stream_calls == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_stepfun_in_registry() -> None:
    """The registry lists 'stepfun' with the expected default model and adapter."""
    registered_names = transcription_provider_names()
    assert "stepfun" in registered_names

    stepfun_spec = get_transcription_provider("stepfun")
    assert stepfun_spec is not None
    assert stepfun_spec.default_model == "stepaudio-2.5-asr"
    expected_adapter = "nanobot.providers.transcription:StepFunTranscriptionProvider"
    assert stepfun_spec.adapter == expected_adapter
|
||||
|
||||
|
||||
def test_config_resolves_stepfun() -> None:
    """Transcription settings plus the stepfun provider entry resolve together."""
    from nanobot.audio.transcription import resolve_transcription_config

    plan_url = "https://api.stepfun.com/step_plan/v1/audio/asr/sse"

    cfg = Config()
    cfg.transcription.provider = "stepfun"
    cfg.transcription.model = "stepaudio-2.5-asr"
    cfg.transcription.language = "zh"
    cfg.providers.stepfun.api_key = "step-test"
    cfg.providers.stepfun.api_base = plan_url

    outcome = resolve_transcription_config(cfg)

    assert outcome.provider == "stepfun"
    assert outcome.model == "stepaudio-2.5-asr"
    assert outcome.language == "zh"
    assert outcome.api_key == "step-test"
    assert outcome.api_base == "https://api.stepfun.com/step_plan/v1/audio/asr/sse"
    assert outcome.configured is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_stream_cm(status: int, lines: list[str]) -> MagicMock:
    """Return a stand-in for `AsyncClient.stream` serving *lines* with *status*.

    Every call to the mock returns the same fake response object, which works
    both as the async context manager and as the response bound inside it.
    """

    class FakeResponse:
        def __init__(self) -> None:
            self.status_code = status
            self.reason_phrase = "Error" if status != 200 else "OK"

        async def __aenter__(self) -> "FakeResponse":
            return self

        async def __aexit__(self, *exc: object) -> None:
            return None

        async def aiter_lines(self) -> Any:
            for sse_line in lines:
                yield sse_line

        def raise_for_status(self) -> None:
            # Statuses below 400 pass silently.
            if self.status_code < 400:
                return
            raise httpx.HTTPStatusError(
                f"HTTP {self.status_code}",
                request=httpx.Request("POST", "https://example.test"),
                response=httpx.Response(self.status_code),
            )

    return MagicMock(return_value=FakeResponse())
|
||||
|
||||
|
||||
def _make_stream_cm_sequence(statuses: list[list[str] | int]) -> MagicMock:
    """Build a stream mock that answers successive calls from *statuses*.

    An ``int`` entry produces a response with that status code whose
    ``raise_for_status()`` raises ``httpx.HTTPStatusError``. A ``list[str]``
    entry produces a 200 response that streams those SSE lines. Once the
    sequence is exhausted, the last entry is reused for every further call.
    """
    remaining = list(statuses)

    class FakeResponse:
        def __init__(self, lines: list[str]) -> None:
            self._lines = lines
            self.status_code = 200
            self.reason_phrase = "OK"

        async def __aenter__(self) -> "FakeResponse":
            return self

        async def __aexit__(self, *exc: object) -> None:
            pass

        async def aiter_lines(self) -> Any:
            for line in self._lines:
                yield line

        def raise_for_status(self) -> None:
            pass

    class FailingResponse:
        def __init__(self, status: int) -> None:
            self.status_code = status
            self.reason_phrase = "Error"

        async def __aenter__(self) -> "FailingResponse":
            return self

        async def __aexit__(self, *exc: object) -> None:
            pass

        async def aiter_lines(self) -> Any:
            yield ""
            return

        def raise_for_status(self) -> None:
            raise httpx.HTTPStatusError(
                f"HTTP {self.status_code}",
                request=httpx.Request("POST", "https://example.test"),
                response=httpx.Response(self.status_code),
            )

    call_count = [0]

    def _next(method: str, url: str, **kwargs: object) -> Any:
        # Clamp to the last entry so extra calls repeat the final behaviour.
        idx = min(call_count[0], len(remaining) - 1)
        entry = remaining[idx]
        call_count[0] += 1
        if isinstance(entry, int):
            return FailingResponse(entry)
        return FakeResponse(entry)

    cm = MagicMock(side_effect=_next)
    return cm
|
||||
Loading…
x
Reference in New Issue
Block a user