mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-27 13:25:52 +00:00
feat: auto-fallback to other transcription provider on failure
When the primary transcription provider fails (bad key, API error, etc.), automatically try the other provider if its API key is available. Made-with: Cursor
This commit is contained in:
parent
35dde8a30e
commit
3bf1fa5225
@ -24,6 +24,7 @@ class BaseChannel(ABC):
|
|||||||
display_name: str = "Base"
|
display_name: str = "Base"
|
||||||
transcription_provider: str = "groq"
|
transcription_provider: str = "groq"
|
||||||
transcription_api_key: str = ""
|
transcription_api_key: str = ""
|
||||||
|
_transcription_fallback_key: str = ""
|
||||||
|
|
||||||
def __init__(self, config: Any, bus: MessageBus):
|
def __init__(self, config: Any, bus: MessageBus):
|
||||||
"""
|
"""
|
||||||
@ -38,19 +39,30 @@ class BaseChannel(ABC):
|
|||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
async def transcribe_audio(self, file_path: str | Path) -> str:
|
async def transcribe_audio(self, file_path: str | Path) -> str:
|
||||||
"""Transcribe an audio file via Whisper (OpenAI or Groq). Returns empty string on failure."""
|
"""Transcribe an audio file via Whisper. Falls back to the other provider on failure."""
|
||||||
if not self.transcription_api_key:
|
if not self.transcription_api_key:
|
||||||
return ""
|
return ""
|
||||||
|
result = await self._try_transcribe(self.transcription_provider, self.transcription_api_key, file_path)
|
||||||
|
if result:
|
||||||
|
return result
|
||||||
|
fallback = "groq" if self.transcription_provider == "openai" else "openai"
|
||||||
|
if self._transcription_fallback_key:
|
||||||
|
logger.info("{}: trying {} fallback for transcription", self.name, fallback)
|
||||||
|
return await self._try_transcribe(fallback, self._transcription_fallback_key, file_path)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
async def _try_transcribe(self, provider: str, api_key: str, file_path: str | Path) -> str:
|
||||||
|
"""Attempt transcription with a single provider. Returns empty string on failure."""
|
||||||
try:
|
try:
|
||||||
if self.transcription_provider == "openai":
|
if provider == "openai":
|
||||||
from nanobot.providers.transcription import OpenAITranscriptionProvider
|
from nanobot.providers.transcription import OpenAITranscriptionProvider
|
||||||
provider = OpenAITranscriptionProvider(api_key=self.transcription_api_key)
|
p = OpenAITranscriptionProvider(api_key=api_key)
|
||||||
else:
|
else:
|
||||||
from nanobot.providers.transcription import GroqTranscriptionProvider
|
from nanobot.providers.transcription import GroqTranscriptionProvider
|
||||||
provider = GroqTranscriptionProvider(api_key=self.transcription_api_key)
|
p = GroqTranscriptionProvider(api_key=api_key)
|
||||||
return await provider.transcribe(file_path)
|
return await p.transcribe(file_path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("{}: audio transcription failed: {}", self.name, e)
|
logger.warning("{}: {} transcription failed: {}", self.name, provider, e)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
async def login(self, force: bool = False) -> bool:
|
async def login(self, force: bool = False) -> bool:
|
||||||
|
|||||||
@ -41,6 +41,8 @@ class ChannelManager:
|
|||||||
|
|
||||||
transcription_provider = self.config.channels.transcription_provider
|
transcription_provider = self.config.channels.transcription_provider
|
||||||
transcription_key = self._resolve_transcription_key(transcription_provider)
|
transcription_key = self._resolve_transcription_key(transcription_provider)
|
||||||
|
fallback_provider = "groq" if transcription_provider == "openai" else "openai"
|
||||||
|
fallback_key = self._resolve_transcription_key(fallback_provider)
|
||||||
|
|
||||||
for name, cls in discover_all().items():
|
for name, cls in discover_all().items():
|
||||||
section = getattr(self.config.channels, name, None)
|
section = getattr(self.config.channels, name, None)
|
||||||
@ -57,6 +59,7 @@ class ChannelManager:
|
|||||||
channel = cls(section, self.bus)
|
channel = cls(section, self.bus)
|
||||||
channel.transcription_provider = transcription_provider
|
channel.transcription_provider = transcription_provider
|
||||||
channel.transcription_api_key = transcription_key
|
channel.transcription_api_key = transcription_key
|
||||||
|
channel._transcription_fallback_key = fallback_key
|
||||||
self.channels[name] = channel
|
self.channels[name] = channel
|
||||||
logger.info("{} channel enabled", cls.display_name)
|
logger.info("{} channel enabled", cls.display_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -66,9 +69,12 @@ class ChannelManager:
|
|||||||
|
|
||||||
def _resolve_transcription_key(self, provider: str) -> str:
|
def _resolve_transcription_key(self, provider: str) -> str:
|
||||||
"""Pick the API key for the configured transcription provider."""
|
"""Pick the API key for the configured transcription provider."""
|
||||||
if provider == "openai":
|
try:
|
||||||
return self.config.providers.openai.api_key
|
if provider == "openai":
|
||||||
return self.config.providers.groq.api_key
|
return self.config.providers.openai.api_key
|
||||||
|
return self.config.providers.groq.api_key
|
||||||
|
except AttributeError:
|
||||||
|
return ""
|
||||||
|
|
||||||
def _validate_allow_from(self) -> None:
|
def _validate_allow_from(self) -> None:
|
||||||
for name, ch in self.channels.items():
|
for name, ch in self.channels.items():
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user