diff --git a/nanobot/audio/transcription.py b/nanobot/audio/transcription.py
index fa46dbb23..3f942d925 100644
--- a/nanobot/audio/transcription.py
+++ b/nanobot/audio/transcription.py
@@ -8,6 +8,7 @@ HTTP details; those live in ``nanobot.providers.transcription``.
 
 from __future__ import annotations
 
+import os
 from contextlib import suppress
 from dataclasses import dataclass, field
 from pathlib import Path
@@ -19,6 +20,7 @@ from nanobot.audio.transcription_registry import (
     get_transcription_provider,
     resolve_transcription_provider,
 )
 from nanobot.config.paths import get_media_dir
+from nanobot.providers.registry import find_by_name
 from nanobot.utils.media_decode import FileSizeExceeded, save_base64_data_url
 
@@ -74,6 +76,41 @@ def _provider_config(config: Any, provider: str) -> Any:
     return getattr(getattr(config, "providers", None), provider, None)
 
 
+def _provider_default_api_base(provider: str) -> str | None:
+    """Return the provider registry's default API base for *provider*, if any."""
+    spec = find_by_name(provider)
+    return spec.default_api_base if spec else None
+
+
+def _resolve_transcription_api_key(provider: str, provider_cfg: Any) -> str:
+    """Return the API key from provider config, else from the environment.
+
+    Always returns a string; ``""`` means no key was found.
+    """
+    api_key = getattr(provider_cfg, "api_key", None) if provider_cfg else None
+    if api_key:
+        return api_key
+
+    if provider == "siliconflow":
+        env_value = os.environ.get("SILICONFLOW_API_KEY")
+        if env_value:
+            return env_value
+
+    spec = find_by_name(provider)
+    env_name = spec.env_key if spec else ""
+    # ``os.environ.get`` returns None for an unset variable; default to ""
+    # so the declared ``-> str`` contract holds.
+    return os.environ.get(env_name, "") if env_name else ""
+
+
+def _resolve_transcription_api_base(provider: str, provider_cfg: Any) -> str:
+    """Return the configured API base, else the registry default, else ``""``."""
+    api_base = getattr(provider_cfg, "api_base", None) if provider_cfg else None
+    if api_base:
+        return api_base
+    return _provider_default_api_base(provider) or ""
+
+
 def _extract_data_url_mime(url: str) -> str | None:
     header, _, _ = url.partition(",")
     if not header.startswith("data:") or ";base64" not in header:
@@ -102,8 +139,8 @@ def resolve_transcription_config(config: Any) -> EffectiveTranscriptionConfig:
         provider=provider,
         model=(getattr(top, "model", None) or default_model).strip(),
         language=getattr(top, "language", None) or getattr(channels, "transcription_language", None),
-        api_key=getattr(provider_cfg, "api_key", None) or "",
-        api_base=getattr(provider_cfg, "api_base", None) or "",
+        api_key=_resolve_transcription_api_key(provider, provider_cfg),
+        api_base=_resolve_transcription_api_base(provider, provider_cfg),
         max_duration_sec=int(getattr(top, "max_duration_sec", 120)),
         max_upload_mb=int(getattr(top, "max_upload_mb", 25)),
     )
diff --git a/nanobot/audio/transcription_registry.py b/nanobot/audio/transcription_registry.py
index ed4208a1a..a044abd60 100644
--- a/nanobot/audio/transcription_registry.py
+++ b/nanobot/audio/transcription_registry.py
@@ -74,6 +74,12 @@ TRANSCRIPTION_PROVIDERS: tuple[TranscriptionProviderSpec, ...] = (
         default_model="universal-3-pro,universal-2",
         adapter="nanobot.providers.transcription:AssemblyAITranscriptionProvider",
     ),
+    TranscriptionProviderSpec(
+        name="siliconflow",
+        default_model="FunAudioLLM/SenseVoiceSmall",
+        adapter="nanobot.providers.transcription:OpenAITranscriptionProvider",
+        aliases=("silicon",),
+    ),
 )
 
 _BY_NAME = {spec.name: spec for spec in TRANSCRIPTION_PROVIDERS}
diff --git a/tests/providers/test_transcription.py b/tests/providers/test_transcription.py
index dadf59440..c0acae59a 100644
--- a/tests/providers/test_transcription.py
+++ b/tests/providers/test_transcription.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import base64
+import os
 from pathlib import Path
 from unittest.mock import AsyncMock, patch
 
@@ -114,6 +115,58 @@ def test_resolver_supports_openrouter_transcription_provider() -> None:
     assert resolved.api_base == "https://openrouter.ai/api/v1"
 
 
+def test_resolver_supports_siliconflow_transcription_provider() -> None:
+    config = Config()
+    config.transcription.provider = "siliconflow"
+    config.transcription.model = "TeleAI/TeleSpeechASR"
+    config.transcription.language = "zh"
+    config.providers.siliconflow.api_key = "sf-test"
+    config.providers.siliconflow.api_base = "https://api.siliconflow.cn/v1"
+
+    resolved = resolve_transcription_config(config)
+
+    assert resolved.provider == "siliconflow"
+    assert resolved.model == "TeleAI/TeleSpeechASR"
+    assert resolved.language == "zh"
+    assert resolved.api_key == "sf-test"
+    assert resolved.api_base == "https://api.siliconflow.cn/v1"
+
+
+def test_resolver_defaults_siliconflow_transcription_api_base() -> None:
+    config = Config()
+    config.transcription.provider = "siliconflow"
+    config.providers.siliconflow.api_key = "sf-test"
+
+    resolved = resolve_transcription_config(config)
+
+    assert resolved.provider == "siliconflow"
+    assert resolved.model == "FunAudioLLM/SenseVoiceSmall"
+    assert resolved.api_key == "sf-test"
+    assert resolved.api_base == "https://api.siliconflow.cn/v1"
+
+
+def test_resolver_supports_siliconflow_transcription_api_key_env() -> None:
+    config = Config()
+    config.transcription.provider = "siliconflow"
+
+    with patch.dict(os.environ, {"SILICONFLOW_API_KEY": "sf-env-key"}, clear=True):
+        resolved = resolve_transcription_config(config)
+
+    assert resolved.provider == "siliconflow"
+    assert resolved.api_key == "sf-env-key"
+    assert resolved.api_base == "https://api.siliconflow.cn/v1"
+
+
+def test_resolver_returns_empty_api_key_when_siliconflow_env_unset() -> None:
+    config = Config()
+    config.transcription.provider = "siliconflow"
+
+    with patch.dict(os.environ, {}, clear=True):
+        resolved = resolve_transcription_config(config)
+
+    assert resolved.api_key == ""
+
+
 def test_resolver_supports_xiaomi_mimo_transcription_provider() -> None:
     config = Config()
     config.transcription.provider = "xiaomi_mimo"
@@ -146,6 +199,13 @@ def test_resolver_accepts_legacy_xiaomi_transcription_alias() -> None:
 
 
 def test_transcription_registry_lists_providers_and_aliases() -> None:
+    siliconflow = get_transcription_provider("siliconflow")
+    assert siliconflow is not None
+    assert siliconflow.adapter == "nanobot.providers.transcription:OpenAITranscriptionProvider"
+    assert siliconflow.load_adapter() is OpenAITranscriptionProvider
+    assert siliconflow.default_model == "FunAudioLLM/SenseVoiceSmall"
+    assert resolve_transcription_provider("silicon").name == "siliconflow"
+
     assert "assemblyai" in transcription_provider_names()
     assert get_transcription_provider("assemblyai").default_model == "universal-3-pro,universal-2"
     assert resolve_transcription_provider("mimo").name == "xiaomi_mimo"