mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-16 15:54:10 +00:00
feat(transcription): add SiliconFlow as transcription provider
- Register SiliconFlow in the transcription registry with default model FunAudioLLM/SenseVoiceSmall and alias 'silicon'.
- Reuse the existing OpenAITranscriptionProvider adapter (Whisper-compatible).
- Add generic key/base resolution: fall back to the registry env_key and default_api_base when the provider config is absent.
- Add tests for the registry entry, alias, adapter, default model, and config resolution with env var fallback.
This commit is contained in:
parent
ddbd7ca39e
commit
9ed638ad70
@ -8,6 +8,7 @@ HTTP details; those live in ``nanobot.providers.transcription``.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -19,6 +20,7 @@ from nanobot.audio.transcription_registry import (
|
|||||||
get_transcription_provider,
|
get_transcription_provider,
|
||||||
resolve_transcription_provider,
|
resolve_transcription_provider,
|
||||||
)
|
)
|
||||||
|
from nanobot.providers.registry import find_by_name
|
||||||
from nanobot.config.paths import get_media_dir
|
from nanobot.config.paths import get_media_dir
|
||||||
from nanobot.utils.media_decode import FileSizeExceeded, save_base64_data_url
|
from nanobot.utils.media_decode import FileSizeExceeded, save_base64_data_url
|
||||||
|
|
||||||
@ -74,6 +76,33 @@ def _provider_config(config: Any, provider: str) -> Any:
|
|||||||
return getattr(getattr(config, "providers", None), provider, None)
|
return getattr(getattr(config, "providers", None), provider, None)
|
||||||
|
|
||||||
|
|
||||||
|
def _provider_default_api_base(provider: str) -> str | None:
    """Return the registry's default API base for *provider*.

    Looks the provider up with ``find_by_name`` and returns the spec's
    ``default_api_base``; returns ``None`` when no spec is found.
    """
    spec = find_by_name(provider)
    if not spec:
        return None
    return spec.default_api_base
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_transcription_api_key(provider: str, provider_cfg: Any) -> str:
|
||||||
|
api_key = getattr(provider_cfg, "api_key", None) if provider_cfg else None
|
||||||
|
if api_key:
|
||||||
|
return api_key
|
||||||
|
|
||||||
|
spec = find_by_name(provider)
|
||||||
|
if provider == "siliconflow":
|
||||||
|
env_key = os.environ.get("SILICONFLOW_API_KEY")
|
||||||
|
if env_key:
|
||||||
|
return env_key
|
||||||
|
|
||||||
|
env_key = spec.env_key if spec else ""
|
||||||
|
return os.environ.get(env_key) if env_key else ""
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_transcription_api_base(provider: str, provider_cfg: Any) -> str:
|
||||||
|
api_base = getattr(provider_cfg, "api_base", None) if provider_cfg else None
|
||||||
|
if api_base:
|
||||||
|
return api_base
|
||||||
|
return _provider_default_api_base(provider) or ""
|
||||||
|
|
||||||
|
|
||||||
def _extract_data_url_mime(url: str) -> str | None:
|
def _extract_data_url_mime(url: str) -> str | None:
|
||||||
header, _, _ = url.partition(",")
|
header, _, _ = url.partition(",")
|
||||||
if not header.startswith("data:") or ";base64" not in header:
|
if not header.startswith("data:") or ";base64" not in header:
|
||||||
@ -102,8 +131,8 @@ def resolve_transcription_config(config: Any) -> EffectiveTranscriptionConfig:
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
model=(getattr(top, "model", None) or default_model).strip(),
|
model=(getattr(top, "model", None) or default_model).strip(),
|
||||||
language=getattr(top, "language", None) or getattr(channels, "transcription_language", None),
|
language=getattr(top, "language", None) or getattr(channels, "transcription_language", None),
|
||||||
api_key=getattr(provider_cfg, "api_key", None) or "",
|
api_key=_resolve_transcription_api_key(provider, provider_cfg),
|
||||||
api_base=getattr(provider_cfg, "api_base", None) or "",
|
api_base=_resolve_transcription_api_base(provider, provider_cfg),
|
||||||
max_duration_sec=int(getattr(top, "max_duration_sec", 120)),
|
max_duration_sec=int(getattr(top, "max_duration_sec", 120)),
|
||||||
max_upload_mb=int(getattr(top, "max_upload_mb", 25)),
|
max_upload_mb=int(getattr(top, "max_upload_mb", 25)),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -74,6 +74,12 @@ TRANSCRIPTION_PROVIDERS: tuple[TranscriptionProviderSpec, ...] = (
|
|||||||
default_model="universal-3-pro,universal-2",
|
default_model="universal-3-pro,universal-2",
|
||||||
adapter="nanobot.providers.transcription:AssemblyAITranscriptionProvider",
|
adapter="nanobot.providers.transcription:AssemblyAITranscriptionProvider",
|
||||||
),
|
),
|
||||||
|
TranscriptionProviderSpec(
|
||||||
|
name="siliconflow",
|
||||||
|
default_model="FunAudioLLM/SenseVoiceSmall",
|
||||||
|
adapter="nanobot.providers.transcription:OpenAITranscriptionProvider",
|
||||||
|
aliases=("silicon",),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
_BY_NAME = {spec.name: spec for spec in TRANSCRIPTION_PROVIDERS}
|
_BY_NAME = {spec.name: spec for spec in TRANSCRIPTION_PROVIDERS}
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
@ -114,6 +115,48 @@ def test_resolver_supports_openrouter_transcription_provider() -> None:
|
|||||||
assert resolved.api_base == "https://openrouter.ai/api/v1"
|
assert resolved.api_base == "https://openrouter.ai/api/v1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolver_supports_siliconflow_transcription_provider() -> None:
    """Explicitly configured SiliconFlow settings come back from the resolver as given."""
    config = Config()
    config.transcription.provider = "siliconflow"
    config.transcription.model = "TeleAI/TeleSpeechASR"
    config.transcription.language = "zh"
    config.providers.siliconflow.api_key = "sf-test"
    config.providers.siliconflow.api_base = "https://api.siliconflow.cn/v1"

    resolved = resolve_transcription_config(config)

    actual = (
        resolved.provider,
        resolved.model,
        resolved.language,
        resolved.api_key,
        resolved.api_base,
    )
    assert actual == (
        "siliconflow",
        "TeleAI/TeleSpeechASR",
        "zh",
        "sf-test",
        "https://api.siliconflow.cn/v1",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolver_defaults_siliconflow_transcription_api_base() -> None:
    """With only an API key configured, model and API base resolve to the defaults."""
    config = Config()
    config.transcription.provider = "siliconflow"
    config.providers.siliconflow.api_key = "sf-test"

    resolved = resolve_transcription_config(config)

    actual = (
        resolved.provider,
        resolved.model,
        resolved.api_key,
        resolved.api_base,
    )
    assert actual == (
        "siliconflow",
        "FunAudioLLM/SenseVoiceSmall",
        "sf-test",
        "https://api.siliconflow.cn/v1",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolver_supports_siliconflow_transcription_api_key_env() -> None:
    """The API key is read from SILICONFLOW_API_KEY when the config carries none."""
    config = Config()
    config.transcription.provider = "siliconflow"

    # clear=True empties the rest of the environment for the duration of the call.
    env = {"SILICONFLOW_API_KEY": "sf-env-key"}
    with patch.dict(os.environ, env, clear=True):
        resolved = resolve_transcription_config(config)

    actual = (resolved.provider, resolved.api_key, resolved.api_base)
    assert actual == (
        "siliconflow",
        "sf-env-key",
        "https://api.siliconflow.cn/v1",
    )
|
||||||
|
|
||||||
|
|
||||||
def test_resolver_supports_xiaomi_mimo_transcription_provider() -> None:
|
def test_resolver_supports_xiaomi_mimo_transcription_provider() -> None:
|
||||||
config = Config()
|
config = Config()
|
||||||
config.transcription.provider = "xiaomi_mimo"
|
config.transcription.provider = "xiaomi_mimo"
|
||||||
@ -146,6 +189,13 @@ def test_resolver_accepts_legacy_xiaomi_transcription_alias() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_transcription_registry_lists_providers_and_aliases() -> None:
|
def test_transcription_registry_lists_providers_and_aliases() -> None:
|
||||||
|
siliconflow = get_transcription_provider("siliconflow")
|
||||||
|
assert siliconflow is not None
|
||||||
|
assert siliconflow.adapter == "nanobot.providers.transcription:OpenAITranscriptionProvider"
|
||||||
|
assert siliconflow.load_adapter() is OpenAITranscriptionProvider
|
||||||
|
assert siliconflow.default_model == "FunAudioLLM/SenseVoiceSmall"
|
||||||
|
assert resolve_transcription_provider("silicon").name == "siliconflow"
|
||||||
|
|
||||||
assert "assemblyai" in transcription_provider_names()
|
assert "assemblyai" in transcription_provider_names()
|
||||||
assert get_transcription_provider("assemblyai").default_model == "universal-3-pro,universal-2"
|
assert get_transcription_provider("assemblyai").default_model == "universal-3-pro,universal-2"
|
||||||
assert resolve_transcription_provider("mimo").name == "xiaomi_mimo"
|
assert resolve_transcription_provider("mimo").name == "xiaomi_mimo"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user