diff --git a/nanobot/channels/base.py b/nanobot/channels/base.py index 5b5922430..a59b31e20 100644 --- a/nanobot/channels/base.py +++ b/nanobot/channels/base.py @@ -45,7 +45,10 @@ class BaseChannel(ABC): try: if self.transcription_provider == "openai": from nanobot.providers.transcription import OpenAITranscriptionProvider - provider = OpenAITranscriptionProvider(api_key=self.transcription_api_key) + provider = OpenAITranscriptionProvider( + api_key=self.transcription_api_key, + api_base=self.transcription_api_base or None, + ) else: from nanobot.providers.transcription import GroqTranscriptionProvider provider = GroqTranscriptionProvider( diff --git a/nanobot/providers/transcription.py b/nanobot/providers/transcription.py index 8968c92ff..617fd3eb1 100644 --- a/nanobot/providers/transcription.py +++ b/nanobot/providers/transcription.py @@ -10,9 +10,13 @@ from loguru import logger class OpenAITranscriptionProvider: """Voice transcription provider using OpenAI's Whisper API.""" - def __init__(self, api_key: str | None = None): + def __init__(self, api_key: str | None = None, api_base: str | None = None): self.api_key = api_key or os.environ.get("OPENAI_API_KEY") - self.api_url = "https://api.openai.com/v1/audio/transcriptions" + self.api_url = ( + api_base + or os.environ.get("OPENAI_TRANSCRIPTION_BASE_URL") + or "https://api.openai.com/v1/audio/transcriptions" + ) async def transcribe(self, file_path: str | Path) -> str: if not self.api_key: diff --git a/tests/channels/test_channel_plugins.py b/tests/channels/test_channel_plugins.py index 11d6aa0af..a6959f937 100644 --- a/tests/channels/test_channel_plugins.py +++ b/tests/channels/test_channel_plugins.py @@ -225,6 +225,81 @@ async def test_manager_propagates_groq_transcription_api_base_to_channels(): assert channel.transcription_api_base == "http://proxy.local/v1/audio/transcriptions" +@pytest.mark.asyncio +async def test_manager_propagates_openai_transcription_api_base_to_channels(): + from nanobot.channels.manager import ChannelManager + + fake_config = SimpleNamespace( + channels=ChannelsConfig.model_validate({ + "fakeplugin": {"enabled": True, "allowFrom": ["*"]}, + "transcriptionProvider": "openai", + }), + providers=SimpleNamespace( + openai=SimpleNamespace( + api_key="openai-key", + api_base="http://proxy.local/v1/audio/transcriptions", + ), + groq=SimpleNamespace(api_key="groq-key", api_base=""), + ), + ) + + with patch( + "nanobot.channels.registry.discover_all", + return_value={"fakeplugin": _FakePlugin}, + ): + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {} + mgr._dispatch_task = None + mgr._init_channels() + + channel = mgr.channels["fakeplugin"] + assert channel.transcription_provider == "openai" + assert channel.transcription_api_key == "openai-key" + assert channel.transcription_api_base == "http://proxy.local/v1/audio/transcriptions" + + +@pytest.mark.asyncio +async def test_base_channel_passes_api_base_to_openai_transcription_provider(): + """BaseChannel.transcribe_audio must forward transcription_api_base to OpenAI.""" + from nanobot.providers import transcription as transcription_mod + + channel = _FakePlugin({"enabled": True, "allowFrom": ["*"]}, MessageBus()) + channel.transcription_provider = "openai" + channel.transcription_api_key = "k" + channel.transcription_api_base = "http://override/v1/audio/transcriptions" + + captured: dict[str, object] = {} + + class _StubOpenAI: + def __init__(self, api_key=None, api_base=None): + captured["api_key"] = api_key + captured["api_base"] = api_base + + async def transcribe(self, file_path): + return "ok" + + with patch.object(transcription_mod, "OpenAITranscriptionProvider", _StubOpenAI): + result = await channel.transcribe_audio("/tmp/does-not-matter.wav") + + assert result == "ok" + assert captured["api_key"] == "k" + assert captured["api_base"] == "http://override/v1/audio/transcriptions" + + +def test_openai_transcription_provider_honors_api_base_argument(): + from nanobot.providers.transcription import OpenAITranscriptionProvider + + default = OpenAITranscriptionProvider(api_key="k") + assert default.api_url == "https://api.openai.com/v1/audio/transcriptions" + + custom = OpenAITranscriptionProvider( + api_key="k", api_base="http://override/v1/audio/transcriptions" + ) + assert custom.api_url == "http://override/v1/audio/transcriptions" + + def test_channels_login_uses_discovered_plugin_class(monkeypatch): from nanobot.cli.commands import app from nanobot.config.schema import Config