fix(asr): normalize StepFun transcription endpoint

This commit is contained in:
Xubin Ren 2026-06-10 15:06:37 +08:00
parent 7930058348
commit 62a35c21b8
2 changed files with 31 additions and 15 deletions

View File

@ -20,6 +20,7 @@ from loguru import logger
_CHAT_COMPLETIONS_PATH = "chat/completions"
_TRANSCRIPTIONS_PATH = "audio/transcriptions"
_STEPFUN_ASR_PATH = "audio/asr/sse"
_ASSEMBLYAI_DEFAULT_API_BASE = "https://api.assemblyai.com/v2"
_ASSEMBLYAI_POLL_ATTEMPTS = 60
_ASSEMBLYAI_POLL_INTERVAL_S = 2.0
@ -72,6 +73,13 @@ def _resolve_api_path(api_base: str | None, default_base: str, path: str) -> str
return f"{base}/{path.lstrip('/')}"
def _resolve_stepfun_asr_url(api_base: str | None) -> str:
base = (api_base or "https://api.stepfun.com/v1").rstrip("/")
if base.endswith(_STEPFUN_ASR_PATH):
return base
return f"{base}/{_STEPFUN_ASR_PATH}"
def _audio_mime_type(path: Path) -> str:
return (
_AUDIO_MIME_OVERRIDES.get(path.suffix.lower())
@ -401,14 +409,15 @@ async def _post_stepfun_asr_with_retry(
_MAX_RETRIES + 1,
)
return ""
except httpx.HTTPStatusError:
if attempt < _MAX_RETRIES:
except httpx.HTTPStatusError as e:
if e.response.status_code in _RETRYABLE_STATUS and attempt < _MAX_RETRIES:
await asyncio.sleep(_BACKOFF_S[attempt])
continue
logger.exception(
"{} transcription failed after {} attempts",
logger.error(
"{} transcription HTTP {}{}",
provider_label,
_MAX_RETRIES + 1,
e.response.status_code,
f" {e.response.reason_phrase}" if e.response.reason_phrase else "",
)
return ""
except (httpx.RequestError, Exception):
@ -792,10 +801,8 @@ class StepFunTranscriptionProvider:
model: str | None = None,
):
self.api_key = api_key or os.environ.get("STEPFUN_API_KEY")
# api_base is used verbatim; users can point to the Plan endpoint
# (https://api.stepfun.com/step_plan/v1/audio/asr/sse) or any
# compatible proxy.
self.api_url = api_base or self._DEFAULT_URL
# api_base accepts either a StepFun base URL or the full SSE endpoint.
self.api_url = _resolve_stepfun_asr_url(api_base)
self.language = language or None
self.model = model or "stepaudio-2.5-asr"
logger.debug("StepFun transcription endpoint: {}", self.api_url)

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import json
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
@ -43,6 +44,14 @@ def test_stepfun_api_base_overrides_url() -> None:
assert provider.api_url == "https://api.stepfun.com/step_plan/v1/audio/asr/sse"
def test_stepfun_api_base_appends_asr_path() -> None:
provider = StepFunTranscriptionProvider(
api_key="sk-test",
api_base="https://api.stepfun.com/step_plan/v1",
)
assert provider.api_url == "https://api.stepfun.com/step_plan/v1/audio/asr/sse"
def test_stepfun_custom_model() -> None:
provider = StepFunTranscriptionProvider(api_key="sk-test", model="stepaudio-2-asr-pro")
assert provider.model == "stepaudio-2-asr-pro"
@ -229,18 +238,18 @@ async def test_sse_empty_text_done_returns_empty(audio_file: Path) -> None:
@pytest.mark.asyncio
async def test_401_returns_empty_after_retries(audio_file: Path) -> None:
"""401 is not in the retryable set but HTTPStatusError still triggers
the retry loop; all attempts exhaust and return ""."""
async def test_401_returns_empty_without_retry(audio_file: Path) -> None:
"""401 is not retryable; bad credentials should fail immediately."""
stream_cm = _make_stream_cm(401, [])
sleep = AsyncMock()
provider = StepFunTranscriptionProvider(api_key="sk-test")
with patch("httpx.AsyncClient.stream", stream_cm), patch(
"asyncio.sleep", AsyncMock()
):
with patch("httpx.AsyncClient.stream", stream_cm), patch("asyncio.sleep", sleep):
result = await provider.transcribe(audio_file)
assert result == ""
assert stream_cm.call_count == 1
sleep.assert_not_awaited()
@pytest.mark.asyncio