mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-04 00:35:58 +00:00
fix(transcription): honor api_base for OpenAI transcription provider
Complete the symmetry left by #3214: ChannelManager._resolve_transcription_base already resolves providers.openai.api_base, but BaseChannel.transcribe_audio instantiated OpenAITranscriptionProvider without forwarding it, and the provider __init__ did not accept the parameter. Self-hosted OpenAI-compatible Whisper endpoints (LiteLLM, vLLM, etc.) configured via config.json were therefore ignored for the OpenAI backend. - OpenAITranscriptionProvider.__init__ now accepts api_base with env fallback (OPENAI_TRANSCRIPTION_BASE_URL) matching the Groq pattern. - BaseChannel.transcribe_audio forwards self.transcription_api_base to OpenAI. - Tests mirror the existing Groq coverage: manager propagation for provider "openai", BaseChannel-to-provider argument passing, and provider default vs override for api_url. Fully backward-compatible: when api_base is None and the env var is unset, the default https://api.openai.com/v1/audio/transcriptions is used. Refs #3213, follow-up to #3214.
This commit is contained in:
parent
d57af5c1d1
commit
ce5272c153
@ -45,7 +45,10 @@ class BaseChannel(ABC):
|
|||||||
try:
|
try:
|
||||||
if self.transcription_provider == "openai":
|
if self.transcription_provider == "openai":
|
||||||
from nanobot.providers.transcription import OpenAITranscriptionProvider
|
from nanobot.providers.transcription import OpenAITranscriptionProvider
|
||||||
provider = OpenAITranscriptionProvider(api_key=self.transcription_api_key)
|
provider = OpenAITranscriptionProvider(
|
||||||
|
api_key=self.transcription_api_key,
|
||||||
|
api_base=self.transcription_api_base or None,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
from nanobot.providers.transcription import GroqTranscriptionProvider
|
from nanobot.providers.transcription import GroqTranscriptionProvider
|
||||||
provider = GroqTranscriptionProvider(
|
provider = GroqTranscriptionProvider(
|
||||||
|
|||||||
@ -10,9 +10,13 @@ from loguru import logger
|
|||||||
class OpenAITranscriptionProvider:
|
class OpenAITranscriptionProvider:
|
||||||
"""Voice transcription provider using OpenAI's Whisper API."""
|
"""Voice transcription provider using OpenAI's Whisper API."""
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None):
|
def __init__(self, api_key: str | None = None, api_base: str | None = None):
|
||||||
self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
||||||
self.api_url = "https://api.openai.com/v1/audio/transcriptions"
|
self.api_url = (
|
||||||
|
api_base
|
||||||
|
or os.environ.get("OPENAI_TRANSCRIPTION_BASE_URL")
|
||||||
|
or "https://api.openai.com/v1/audio/transcriptions"
|
||||||
|
)
|
||||||
|
|
||||||
async def transcribe(self, file_path: str | Path) -> str:
|
async def transcribe(self, file_path: str | Path) -> str:
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
|
|||||||
@ -225,6 +225,81 @@ async def test_manager_propagates_groq_transcription_api_base_to_channels():
|
|||||||
assert channel.transcription_api_base == "http://proxy.local/v1/audio/transcriptions"
|
assert channel.transcription_api_base == "http://proxy.local/v1/audio/transcriptions"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_manager_propagates_openai_transcription_api_base_to_channels():
|
||||||
|
from nanobot.channels.manager import ChannelManager
|
||||||
|
|
||||||
|
fake_config = SimpleNamespace(
|
||||||
|
channels=ChannelsConfig.model_validate({
|
||||||
|
"fakeplugin": {"enabled": True, "allowFrom": ["*"]},
|
||||||
|
"transcriptionProvider": "openai",
|
||||||
|
}),
|
||||||
|
providers=SimpleNamespace(
|
||||||
|
openai=SimpleNamespace(
|
||||||
|
api_key="openai-key",
|
||||||
|
api_base="http://proxy.local/v1/audio/transcriptions",
|
||||||
|
),
|
||||||
|
groq=SimpleNamespace(api_key="groq-key", api_base=""),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"nanobot.channels.registry.discover_all",
|
||||||
|
return_value={"fakeplugin": _FakePlugin},
|
||||||
|
):
|
||||||
|
mgr = ChannelManager.__new__(ChannelManager)
|
||||||
|
mgr.config = fake_config
|
||||||
|
mgr.bus = MessageBus()
|
||||||
|
mgr.channels = {}
|
||||||
|
mgr._dispatch_task = None
|
||||||
|
mgr._init_channels()
|
||||||
|
|
||||||
|
channel = mgr.channels["fakeplugin"]
|
||||||
|
assert channel.transcription_provider == "openai"
|
||||||
|
assert channel.transcription_api_key == "openai-key"
|
||||||
|
assert channel.transcription_api_base == "http://proxy.local/v1/audio/transcriptions"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_base_channel_passes_api_base_to_openai_transcription_provider():
|
||||||
|
"""BaseChannel.transcribe_audio must forward transcription_api_base to OpenAI."""
|
||||||
|
from nanobot.providers import transcription as transcription_mod
|
||||||
|
|
||||||
|
channel = _FakePlugin({"enabled": True, "allowFrom": ["*"]}, MessageBus())
|
||||||
|
channel.transcription_provider = "openai"
|
||||||
|
channel.transcription_api_key = "k"
|
||||||
|
channel.transcription_api_base = "http://override/v1/audio/transcriptions"
|
||||||
|
|
||||||
|
captured: dict[str, object] = {}
|
||||||
|
|
||||||
|
class _StubOpenAI:
|
||||||
|
def __init__(self, api_key=None, api_base=None):
|
||||||
|
captured["api_key"] = api_key
|
||||||
|
captured["api_base"] = api_base
|
||||||
|
|
||||||
|
async def transcribe(self, file_path):
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
with patch.object(transcription_mod, "OpenAITranscriptionProvider", _StubOpenAI):
|
||||||
|
result = await channel.transcribe_audio("/tmp/does-not-matter.wav")
|
||||||
|
|
||||||
|
assert result == "ok"
|
||||||
|
assert captured["api_key"] == "k"
|
||||||
|
assert captured["api_base"] == "http://override/v1/audio/transcriptions"
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_transcription_provider_honors_api_base_argument():
|
||||||
|
from nanobot.providers.transcription import OpenAITranscriptionProvider
|
||||||
|
|
||||||
|
default = OpenAITranscriptionProvider(api_key="k")
|
||||||
|
assert default.api_url == "https://api.openai.com/v1/audio/transcriptions"
|
||||||
|
|
||||||
|
custom = OpenAITranscriptionProvider(
|
||||||
|
api_key="k", api_base="http://override/v1/audio/transcriptions"
|
||||||
|
)
|
||||||
|
assert custom.api_url == "http://override/v1/audio/transcriptions"
|
||||||
|
|
||||||
|
|
||||||
def test_channels_login_uses_discovered_plugin_class(monkeypatch):
|
def test_channels_login_uses_discovered_plugin_class(monkeypatch):
|
||||||
from nanobot.cli.commands import app
|
from nanobot.cli.commands import app
|
||||||
from nanobot.config.schema import Config
|
from nanobot.config.schema import Config
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user