fix(asr): normalize StepFun transcription endpoint

This commit is contained in:
Xubin Ren 2026-06-10 15:06:37 +08:00
parent 7930058348
commit 62a35c21b8
2 changed files with 31 additions and 15 deletions

View File

@ -20,6 +20,7 @@ from loguru import logger
_CHAT_COMPLETIONS_PATH = "chat/completions" _CHAT_COMPLETIONS_PATH = "chat/completions"
_TRANSCRIPTIONS_PATH = "audio/transcriptions" _TRANSCRIPTIONS_PATH = "audio/transcriptions"
_STEPFUN_ASR_PATH = "audio/asr/sse"
_ASSEMBLYAI_DEFAULT_API_BASE = "https://api.assemblyai.com/v2" _ASSEMBLYAI_DEFAULT_API_BASE = "https://api.assemblyai.com/v2"
_ASSEMBLYAI_POLL_ATTEMPTS = 60 _ASSEMBLYAI_POLL_ATTEMPTS = 60
_ASSEMBLYAI_POLL_INTERVAL_S = 2.0 _ASSEMBLYAI_POLL_INTERVAL_S = 2.0
@ -72,6 +73,13 @@ def _resolve_api_path(api_base: str | None, default_base: str, path: str) -> str
return f"{base}/{path.lstrip('/')}" return f"{base}/{path.lstrip('/')}"
def _resolve_stepfun_asr_url(api_base: str | None) -> str:
base = (api_base or "https://api.stepfun.com/v1").rstrip("/")
if base.endswith(_STEPFUN_ASR_PATH):
return base
return f"{base}/{_STEPFUN_ASR_PATH}"
def _audio_mime_type(path: Path) -> str: def _audio_mime_type(path: Path) -> str:
return ( return (
_AUDIO_MIME_OVERRIDES.get(path.suffix.lower()) _AUDIO_MIME_OVERRIDES.get(path.suffix.lower())
@ -401,14 +409,15 @@ async def _post_stepfun_asr_with_retry(
_MAX_RETRIES + 1, _MAX_RETRIES + 1,
) )
return "" return ""
except httpx.HTTPStatusError: except httpx.HTTPStatusError as e:
if attempt < _MAX_RETRIES: if e.response.status_code in _RETRYABLE_STATUS and attempt < _MAX_RETRIES:
await asyncio.sleep(_BACKOFF_S[attempt]) await asyncio.sleep(_BACKOFF_S[attempt])
continue continue
logger.exception( logger.error(
"{} transcription failed after {} attempts", "{} transcription HTTP {}{}",
provider_label, provider_label,
_MAX_RETRIES + 1, e.response.status_code,
f" {e.response.reason_phrase}" if e.response.reason_phrase else "",
) )
return "" return ""
except (httpx.RequestError, Exception): except (httpx.RequestError, Exception):
@ -792,10 +801,8 @@ class StepFunTranscriptionProvider:
model: str | None = None, model: str | None = None,
): ):
self.api_key = api_key or os.environ.get("STEPFUN_API_KEY") self.api_key = api_key or os.environ.get("STEPFUN_API_KEY")
# api_base is used verbatim; users can point to the Plan endpoint # api_base accepts either a StepFun base URL or the full SSE endpoint.
# (https://api.stepfun.com/step_plan/v1/audio/asr/sse) or any self.api_url = _resolve_stepfun_asr_url(api_base)
# compatible proxy.
self.api_url = api_base or self._DEFAULT_URL
self.language = language or None self.language = language or None
self.model = model or "stepaudio-2.5-asr" self.model = model or "stepaudio-2.5-asr"
logger.debug("StepFun transcription endpoint: {}", self.api_url) logger.debug("StepFun transcription endpoint: {}", self.api_url)

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import json import json
from pathlib import Path from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import httpx import httpx
@ -43,6 +44,14 @@ def test_stepfun_api_base_overrides_url() -> None:
assert provider.api_url == "https://api.stepfun.com/step_plan/v1/audio/asr/sse" assert provider.api_url == "https://api.stepfun.com/step_plan/v1/audio/asr/sse"
def test_stepfun_api_base_appends_asr_path() -> None:
provider = StepFunTranscriptionProvider(
api_key="sk-test",
api_base="https://api.stepfun.com/step_plan/v1",
)
assert provider.api_url == "https://api.stepfun.com/step_plan/v1/audio/asr/sse"
def test_stepfun_custom_model() -> None: def test_stepfun_custom_model() -> None:
provider = StepFunTranscriptionProvider(api_key="sk-test", model="stepaudio-2-asr-pro") provider = StepFunTranscriptionProvider(api_key="sk-test", model="stepaudio-2-asr-pro")
assert provider.model == "stepaudio-2-asr-pro" assert provider.model == "stepaudio-2-asr-pro"
@ -229,18 +238,18 @@ async def test_sse_empty_text_done_returns_empty(audio_file: Path) -> None:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_401_returns_empty_after_retries(audio_file: Path) -> None: async def test_401_returns_empty_without_retry(audio_file: Path) -> None:
"""401 is not in the retryable set but HTTPStatusError still triggers """401 is not retryable; bad credentials should fail immediately."""
the retry loop; all attempts exhaust and return ""."""
stream_cm = _make_stream_cm(401, []) stream_cm = _make_stream_cm(401, [])
sleep = AsyncMock()
provider = StepFunTranscriptionProvider(api_key="sk-test") provider = StepFunTranscriptionProvider(api_key="sk-test")
with patch("httpx.AsyncClient.stream", stream_cm), patch( with patch("httpx.AsyncClient.stream", stream_cm), patch("asyncio.sleep", sleep):
"asyncio.sleep", AsyncMock()
):
result = await provider.transcribe(audio_file) result = await provider.transcribe(audio_file)
assert result == "" assert result == ""
assert stream_cm.call_count == 1
sleep.assert_not_awaited()
@pytest.mark.asyncio @pytest.mark.asyncio