mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 15:24:06 +00:00
fix(asr): normalize StepFun transcription endpoint
This commit is contained in:
parent
7930058348
commit
62a35c21b8
@ -20,6 +20,7 @@ from loguru import logger
|
||||
|
||||
_CHAT_COMPLETIONS_PATH = "chat/completions"
|
||||
_TRANSCRIPTIONS_PATH = "audio/transcriptions"
|
||||
_STEPFUN_ASR_PATH = "audio/asr/sse"
|
||||
_ASSEMBLYAI_DEFAULT_API_BASE = "https://api.assemblyai.com/v2"
|
||||
_ASSEMBLYAI_POLL_ATTEMPTS = 60
|
||||
_ASSEMBLYAI_POLL_INTERVAL_S = 2.0
|
||||
@ -72,6 +73,13 @@ def _resolve_api_path(api_base: str | None, default_base: str, path: str) -> str
|
||||
return f"{base}/{path.lstrip('/')}"
|
||||
|
||||
|
||||
def _resolve_stepfun_asr_url(api_base: str | None) -> str:
|
||||
base = (api_base or "https://api.stepfun.com/v1").rstrip("/")
|
||||
if base.endswith(_STEPFUN_ASR_PATH):
|
||||
return base
|
||||
return f"{base}/{_STEPFUN_ASR_PATH}"
|
||||
|
||||
|
||||
def _audio_mime_type(path: Path) -> str:
|
||||
return (
|
||||
_AUDIO_MIME_OVERRIDES.get(path.suffix.lower())
|
||||
@ -401,14 +409,15 @@ async def _post_stepfun_asr_with_retry(
|
||||
_MAX_RETRIES + 1,
|
||||
)
|
||||
return ""
|
||||
except httpx.HTTPStatusError:
|
||||
if attempt < _MAX_RETRIES:
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code in _RETRYABLE_STATUS and attempt < _MAX_RETRIES:
|
||||
await asyncio.sleep(_BACKOFF_S[attempt])
|
||||
continue
|
||||
logger.exception(
|
||||
"{} transcription failed after {} attempts",
|
||||
logger.error(
|
||||
"{} transcription HTTP {}{}",
|
||||
provider_label,
|
||||
_MAX_RETRIES + 1,
|
||||
e.response.status_code,
|
||||
f" {e.response.reason_phrase}" if e.response.reason_phrase else "",
|
||||
)
|
||||
return ""
|
||||
except (httpx.RequestError, Exception):
|
||||
@ -792,10 +801,8 @@ class StepFunTranscriptionProvider:
|
||||
model: str | None = None,
|
||||
):
|
||||
self.api_key = api_key or os.environ.get("STEPFUN_API_KEY")
|
||||
# api_base is used verbatim; users can point to the Plan endpoint
|
||||
# (https://api.stepfun.com/step_plan/v1/audio/asr/sse) or any
|
||||
# compatible proxy.
|
||||
self.api_url = api_base or self._DEFAULT_URL
|
||||
# api_base accepts either a StepFun base URL or the full SSE endpoint.
|
||||
self.api_url = _resolve_stepfun_asr_url(api_base)
|
||||
self.language = language or None
|
||||
self.model = model or "stepaudio-2.5-asr"
|
||||
logger.debug("StepFun transcription endpoint: {}", self.api_url)
|
||||
|
||||
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
@ -43,6 +44,14 @@ def test_stepfun_api_base_overrides_url() -> None:
|
||||
assert provider.api_url == "https://api.stepfun.com/step_plan/v1/audio/asr/sse"
|
||||
|
||||
|
||||
def test_stepfun_api_base_appends_asr_path() -> None:
|
||||
provider = StepFunTranscriptionProvider(
|
||||
api_key="sk-test",
|
||||
api_base="https://api.stepfun.com/step_plan/v1",
|
||||
)
|
||||
assert provider.api_url == "https://api.stepfun.com/step_plan/v1/audio/asr/sse"
|
||||
|
||||
|
||||
def test_stepfun_custom_model() -> None:
|
||||
provider = StepFunTranscriptionProvider(api_key="sk-test", model="stepaudio-2-asr-pro")
|
||||
assert provider.model == "stepaudio-2-asr-pro"
|
||||
@ -229,18 +238,18 @@ async def test_sse_empty_text_done_returns_empty(audio_file: Path) -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_401_returns_empty_after_retries(audio_file: Path) -> None:
|
||||
"""401 is not in the retryable set but HTTPStatusError still triggers
|
||||
the retry loop; all attempts exhaust and return ""."""
|
||||
async def test_401_returns_empty_without_retry(audio_file: Path) -> None:
|
||||
"""401 is not retryable; bad credentials should fail immediately."""
|
||||
stream_cm = _make_stream_cm(401, [])
|
||||
sleep = AsyncMock()
|
||||
|
||||
provider = StepFunTranscriptionProvider(api_key="sk-test")
|
||||
with patch("httpx.AsyncClient.stream", stream_cm), patch(
|
||||
"asyncio.sleep", AsyncMock()
|
||||
):
|
||||
with patch("httpx.AsyncClient.stream", stream_cm), patch("asyncio.sleep", sleep):
|
||||
result = await provider.transcribe(audio_file)
|
||||
|
||||
assert result == ""
|
||||
assert stream_cm.call_count == 1
|
||||
sleep.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user