feat(transcription): add SiliconFlow as transcription provider

- Register SiliconFlow in transcription registry with default model
  FunAudioLLM/SenseVoiceSmall and alias 'silicon'
- Reuse existing OpenAITranscriptionProvider adapter (Whisper-compatible)
- Add generic key/base resolution: fallback to registry env_key and
  default_api_base when provider config is absent
- Add tests for registry entry, alias, adapter, default model, and
  config resolution with env var fallback
This commit is contained in:
moran 2026-06-10 22:16:53 +08:00 committed by Xubin Ren
parent ddbd7ca39e
commit 9ed638ad70
3 changed files with 87 additions and 2 deletions

View File

@ -8,6 +8,7 @@ HTTP details; those live in ``nanobot.providers.transcription``.
from __future__ import annotations from __future__ import annotations
import os
from contextlib import suppress from contextlib import suppress
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
@ -19,6 +20,7 @@ from nanobot.audio.transcription_registry import (
get_transcription_provider, get_transcription_provider,
resolve_transcription_provider, resolve_transcription_provider,
) )
from nanobot.providers.registry import find_by_name
from nanobot.config.paths import get_media_dir from nanobot.config.paths import get_media_dir
from nanobot.utils.media_decode import FileSizeExceeded, save_base64_data_url from nanobot.utils.media_decode import FileSizeExceeded, save_base64_data_url
@ -74,6 +76,33 @@ def _provider_config(config: Any, provider: str) -> Any:
return getattr(getattr(config, "providers", None), provider, None) return getattr(getattr(config, "providers", None), provider, None)
def _provider_default_api_base(provider: str) -> str | None:
spec = find_by_name(provider)
return spec.default_api_base if spec else None
def _resolve_transcription_api_key(provider: str, provider_cfg: Any) -> str:
api_key = getattr(provider_cfg, "api_key", None) if provider_cfg else None
if api_key:
return api_key
spec = find_by_name(provider)
if provider == "siliconflow":
env_key = os.environ.get("SILICONFLOW_API_KEY")
if env_key:
return env_key
env_key = spec.env_key if spec else ""
return os.environ.get(env_key) if env_key else ""
def _resolve_transcription_api_base(provider: str, provider_cfg: Any) -> str:
api_base = getattr(provider_cfg, "api_base", None) if provider_cfg else None
if api_base:
return api_base
return _provider_default_api_base(provider) or ""
def _extract_data_url_mime(url: str) -> str | None: def _extract_data_url_mime(url: str) -> str | None:
header, _, _ = url.partition(",") header, _, _ = url.partition(",")
if not header.startswith("data:") or ";base64" not in header: if not header.startswith("data:") or ";base64" not in header:
@ -102,8 +131,8 @@ def resolve_transcription_config(config: Any) -> EffectiveTranscriptionConfig:
provider=provider, provider=provider,
model=(getattr(top, "model", None) or default_model).strip(), model=(getattr(top, "model", None) or default_model).strip(),
language=getattr(top, "language", None) or getattr(channels, "transcription_language", None), language=getattr(top, "language", None) or getattr(channels, "transcription_language", None),
api_key=getattr(provider_cfg, "api_key", None) or "", api_key=_resolve_transcription_api_key(provider, provider_cfg),
api_base=getattr(provider_cfg, "api_base", None) or "", api_base=_resolve_transcription_api_base(provider, provider_cfg),
max_duration_sec=int(getattr(top, "max_duration_sec", 120)), max_duration_sec=int(getattr(top, "max_duration_sec", 120)),
max_upload_mb=int(getattr(top, "max_upload_mb", 25)), max_upload_mb=int(getattr(top, "max_upload_mb", 25)),
) )

View File

@ -74,6 +74,12 @@ TRANSCRIPTION_PROVIDERS: tuple[TranscriptionProviderSpec, ...] = (
default_model="universal-3-pro,universal-2", default_model="universal-3-pro,universal-2",
adapter="nanobot.providers.transcription:AssemblyAITranscriptionProvider", adapter="nanobot.providers.transcription:AssemblyAITranscriptionProvider",
), ),
TranscriptionProviderSpec(
name="siliconflow",
default_model="FunAudioLLM/SenseVoiceSmall",
adapter="nanobot.providers.transcription:OpenAITranscriptionProvider",
aliases=("silicon",),
),
) )
_BY_NAME = {spec.name: spec for spec in TRANSCRIPTION_PROVIDERS} _BY_NAME = {spec.name: spec for spec in TRANSCRIPTION_PROVIDERS}

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import base64 import base64
import os
from pathlib import Path from pathlib import Path
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
@ -114,6 +115,48 @@ def test_resolver_supports_openrouter_transcription_provider() -> None:
assert resolved.api_base == "https://openrouter.ai/api/v1" assert resolved.api_base == "https://openrouter.ai/api/v1"
def test_resolver_supports_siliconflow_transcription_provider() -> None:
config = Config()
config.transcription.provider = "siliconflow"
config.transcription.model = "TeleAI/TeleSpeechASR"
config.transcription.language = "zh"
config.providers.siliconflow.api_key = "sf-test"
config.providers.siliconflow.api_base = "https://api.siliconflow.cn/v1"
resolved = resolve_transcription_config(config)
assert resolved.provider == "siliconflow"
assert resolved.model == "TeleAI/TeleSpeechASR"
assert resolved.language == "zh"
assert resolved.api_key == "sf-test"
assert resolved.api_base == "https://api.siliconflow.cn/v1"
def test_resolver_defaults_siliconflow_transcription_api_base() -> None:
config = Config()
config.transcription.provider = "siliconflow"
config.providers.siliconflow.api_key = "sf-test"
resolved = resolve_transcription_config(config)
assert resolved.provider == "siliconflow"
assert resolved.model == "FunAudioLLM/SenseVoiceSmall"
assert resolved.api_key == "sf-test"
assert resolved.api_base == "https://api.siliconflow.cn/v1"
def test_resolver_supports_siliconflow_transcription_api_key_env() -> None:
config = Config()
config.transcription.provider = "siliconflow"
with patch.dict(os.environ, {"SILICONFLOW_API_KEY": "sf-env-key"}, clear=True):
resolved = resolve_transcription_config(config)
assert resolved.provider == "siliconflow"
assert resolved.api_key == "sf-env-key"
assert resolved.api_base == "https://api.siliconflow.cn/v1"
def test_resolver_supports_xiaomi_mimo_transcription_provider() -> None: def test_resolver_supports_xiaomi_mimo_transcription_provider() -> None:
config = Config() config = Config()
config.transcription.provider = "xiaomi_mimo" config.transcription.provider = "xiaomi_mimo"
@ -146,6 +189,13 @@ def test_resolver_accepts_legacy_xiaomi_transcription_alias() -> None:
def test_transcription_registry_lists_providers_and_aliases() -> None: def test_transcription_registry_lists_providers_and_aliases() -> None:
siliconflow = get_transcription_provider("siliconflow")
assert siliconflow is not None
assert siliconflow.adapter == "nanobot.providers.transcription:OpenAITranscriptionProvider"
assert siliconflow.load_adapter() is OpenAITranscriptionProvider
assert siliconflow.default_model == "FunAudioLLM/SenseVoiceSmall"
assert resolve_transcription_provider("silicon").name == "siliconflow"
assert "assemblyai" in transcription_provider_names() assert "assemblyai" in transcription_provider_names()
assert get_transcription_provider("assemblyai").default_model == "universal-3-pro,universal-2" assert get_transcription_provider("assemblyai").default_model == "universal-3-pro,universal-2"
assert resolve_transcription_provider("mimo").name == "xiaomi_mimo" assert resolve_transcription_provider("mimo").name == "xiaomi_mimo"