"""Tests for channel plugin discovery, merging, and config compatibility.""" from __future__ import annotations import asyncio from types import SimpleNamespace from unittest.mock import AsyncMock, patch import pytest from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel from nanobot.channels.manager import ChannelManager from nanobot.config.schema import ChannelsConfig from nanobot.utils.restart import RestartNotice # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- class _FakePlugin(BaseChannel): name = "fakeplugin" display_name = "Fake Plugin" def __init__(self, config, bus): super().__init__(config, bus) self.login_calls: list[bool] = [] async def start(self) -> None: pass async def stop(self) -> None: pass async def send(self, msg: OutboundMessage) -> None: pass async def login(self, force: bool = False) -> bool: self.login_calls.append(force) return True class _FakeTelegram(BaseChannel): """Plugin that tries to shadow built-in telegram.""" name = "telegram" display_name = "Fake Telegram" async def start(self) -> None: pass async def stop(self) -> None: pass async def send(self, msg: OutboundMessage) -> None: pass def _make_entry_point(name: str, cls: type): """Create a mock entry point that returns *cls* on load().""" ep = SimpleNamespace(name=name, load=lambda _cls=cls: _cls) return ep # --------------------------------------------------------------------------- # ChannelsConfig extra="allow" # --------------------------------------------------------------------------- def test_channels_config_accepts_unknown_keys(): cfg = ChannelsConfig.model_validate({ "myplugin": {"enabled": True, "token": "abc"}, }) extra = cfg.model_extra assert extra is not None assert extra["myplugin"]["enabled"] is True assert extra["myplugin"]["token"] == "abc" def test_channels_config_getattr_returns_extra(): cfg = ChannelsConfig.model_validate({"myplugin": {"enabled": True}}) section = getattr(cfg, "myplugin", None) assert isinstance(section, dict) assert section["enabled"] is True def test_channels_config_builtin_fields_removed(): """After decoupling, ChannelsConfig has no explicit channel fields.""" cfg = ChannelsConfig() assert not hasattr(cfg, "telegram") assert cfg.send_progress is True assert cfg.send_tool_hints is False # --------------------------------------------------------------------------- # discover_plugins # --------------------------------------------------------------------------- _EP_TARGET = "importlib.metadata.entry_points" def test_discover_plugins_loads_entry_points(): from nanobot.channels.registry import discover_plugins ep = _make_entry_point("line", _FakePlugin) with patch(_EP_TARGET, return_value=[ep]): result = discover_plugins() assert "line" in result assert result["line"] is _FakePlugin def test_discover_plugins_handles_load_error(): from nanobot.channels.registry import discover_plugins def _boom(): raise RuntimeError("broken") ep = SimpleNamespace(name="broken", load=_boom) with patch(_EP_TARGET, return_value=[ep]): result = discover_plugins() assert "broken" not in result # --------------------------------------------------------------------------- # discover_all — merge & priority # --------------------------------------------------------------------------- def test_discover_all_includes_builtins(): from nanobot.channels.registry import discover_all, discover_channel_names with patch(_EP_TARGET, return_value=[]): result = discover_all() # discover_all() only returns channels that are actually available (dependencies installed) # discover_channel_names() returns all built-in channel names # So we check that all actually loaded channels are in the result for name in result: assert name in discover_channel_names() def test_discover_all_includes_external_plugin(): from nanobot.channels.registry import discover_all ep = _make_entry_point("line", _FakePlugin) with patch(_EP_TARGET, return_value=[ep]): result = discover_all() assert "line" in result assert result["line"] is _FakePlugin def test_discover_all_builtin_shadows_plugin(): from nanobot.channels.registry import discover_all ep = _make_entry_point("telegram", _FakeTelegram) with patch(_EP_TARGET, return_value=[ep]): result = discover_all() assert "telegram" in result assert result["telegram"] is not _FakeTelegram # --------------------------------------------------------------------------- # Manager _init_channels with dict config (plugin scenario) # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_manager_loads_plugin_from_dict_config(): """ChannelManager should instantiate a plugin channel from a raw dict config.""" from nanobot.channels.manager import ChannelManager fake_config = SimpleNamespace( channels=ChannelsConfig.model_validate({ "fakeplugin": {"enabled": True, "allowFrom": ["*"]}, }), providers=SimpleNamespace(groq=SimpleNamespace(api_key="", api_base="")), ) with patch( "nanobot.channels.registry.discover_all", return_value={"fakeplugin": _FakePlugin}, ): mgr = ChannelManager.__new__(ChannelManager) mgr.config = fake_config mgr.bus = MessageBus() mgr.channels = {} mgr._dispatch_task = None mgr._init_channels() assert "fakeplugin" in mgr.channels assert isinstance(mgr.channels["fakeplugin"], _FakePlugin) @pytest.mark.asyncio async def test_manager_propagates_groq_transcription_api_base_to_channels(): from nanobot.channels.manager import ChannelManager fake_config = SimpleNamespace( channels=ChannelsConfig.model_validate({ "fakeplugin": {"enabled": True, "allowFrom": ["*"]}, "transcriptionLanguage": "en", }), providers=SimpleNamespace( groq=SimpleNamespace(api_key="groq-key", api_base="http://proxy.local/v1/audio/transcriptions"), openai=SimpleNamespace(api_key="openai-key", api_base="https://api.openai.com/v1/audio/transcriptions"), ), ) with patch( "nanobot.channels.registry.discover_all", return_value={"fakeplugin": _FakePlugin}, ): mgr = ChannelManager.__new__(ChannelManager) mgr.config = fake_config mgr.bus = MessageBus() mgr.channels = {} mgr._dispatch_task = None mgr._init_channels() channel = mgr.channels["fakeplugin"] assert channel.transcription_provider == "groq" assert channel.transcription_api_key == "groq-key" assert channel.transcription_api_base == "http://proxy.local/v1/audio/transcriptions" assert channel.transcription_language == "en" @pytest.mark.asyncio async def test_manager_propagates_openai_transcription_api_base_to_channels(): from nanobot.channels.manager import ChannelManager fake_config = SimpleNamespace( channels=ChannelsConfig.model_validate({ "fakeplugin": {"enabled": True, "allowFrom": ["*"]}, "transcriptionProvider": "openai", }), providers=SimpleNamespace( openai=SimpleNamespace( api_key="openai-key", api_base="http://proxy.local/v1/audio/transcriptions", ), groq=SimpleNamespace(api_key="groq-key", api_base=""), ), ) with patch( "nanobot.channels.registry.discover_all", return_value={"fakeplugin": _FakePlugin}, ): mgr = ChannelManager.__new__(ChannelManager) mgr.config = fake_config mgr.bus = MessageBus() mgr.channels = {} mgr._dispatch_task = None mgr._init_channels() channel = mgr.channels["fakeplugin"] assert channel.transcription_provider == "openai" assert channel.transcription_api_key == "openai-key" assert channel.transcription_api_base == "http://proxy.local/v1/audio/transcriptions" @pytest.mark.asyncio async def test_base_channel_passes_api_base_to_openai_transcription_provider(): """BaseChannel.transcribe_audio must forward transcription_api_base to OpenAI.""" from nanobot.providers import transcription as transcription_mod channel = _FakePlugin({"enabled": True, "allowFrom": ["*"]}, MessageBus()) channel.transcription_provider = "openai" channel.transcription_api_key = "k" channel.transcription_api_base = "http://override/v1/audio/transcriptions" channel.transcription_language = "en" captured: dict[str, object] = {} class _StubOpenAI: def __init__(self, api_key=None, api_base=None, language=None): captured["api_key"] = api_key captured["api_base"] = api_base captured["language"] = language async def transcribe(self, file_path): return "ok" with patch.object(transcription_mod, "OpenAITranscriptionProvider", _StubOpenAI): result = await channel.transcribe_audio("/tmp/does-not-matter.wav") assert result == "ok" assert captured["api_key"] == "k" assert captured["api_base"] == "http://override/v1/audio/transcriptions" assert captured["language"] == "en" def test_openai_transcription_provider_honors_api_base_argument(): from nanobot.providers.transcription import OpenAITranscriptionProvider default = OpenAITranscriptionProvider(api_key="k") assert default.api_url == "https://api.openai.com/v1/audio/transcriptions" custom = OpenAITranscriptionProvider( api_key="k", api_base="http://override/v1/audio/transcriptions" ) assert custom.api_url == "http://override/v1/audio/transcriptions" @pytest.mark.asyncio async def test_base_channel_passes_language_to_groq_transcription_provider(): """BaseChannel.transcribe_audio must forward transcription_language to Groq.""" from nanobot.providers import transcription as transcription_mod channel = _FakePlugin({"enabled": True, "allowFrom": ["*"]}, MessageBus()) channel.transcription_provider = "groq" channel.transcription_api_key = "k" channel.transcription_api_base = "http://override/v1/audio/transcriptions" channel.transcription_language = "ko" captured: dict[str, object] = {} class _StubGroq: def __init__(self, api_key=None, api_base=None, language=None): captured["api_key"] = api_key captured["api_base"] = api_base captured["language"] = language async def transcribe(self, file_path): return "ok" with patch.object(transcription_mod, "GroqTranscriptionProvider", _StubGroq): result = await channel.transcribe_audio("/tmp/does-not-matter.wav") assert result == "ok" assert captured["api_key"] == "k" assert captured["api_base"] == "http://override/v1/audio/transcriptions" assert captured["language"] == "ko" # --------------------------------------------------------------------------- # Transcription provider HTTP tests # --------------------------------------------------------------------------- from nanobot.providers.transcription import GroqTranscriptionProvider as _GroqProvider from nanobot.providers.transcription import OpenAITranscriptionProvider as _OpenAIProvider class _StubResponse: def raise_for_status(self): return None def json(self): return {"text": "hello"} def _stub_async_client(captured: dict[str, object]): """Return an httpx.AsyncClient stub that records POST calls into *captured*.""" class _AsyncClient: async def __aenter__(self): return self async def __aexit__(self, exc_type, exc, tb): return False async def post(self, url, headers=None, files=None, timeout=None): captured["files"] = files return _StubResponse() return _AsyncClient() @pytest.mark.parametrize( "provider_cls,language", [(_GroqProvider, "ko"), (_OpenAIProvider, "en")], ids=["groq", "openai"], ) @pytest.mark.asyncio async def test_transcription_provider_includes_language(tmp_path, provider_cls, language): """Provider must include the 'language' field in multipart body when set.""" audio = tmp_path / "sample.wav" audio.write_bytes(b"audio") captured: dict[str, object] = {} with patch("nanobot.providers.transcription.httpx.AsyncClient", return_value=_stub_async_client(captured)): provider = provider_cls(api_key="k", language=language) result = await provider.transcribe(audio) assert result == "hello" assert captured["files"]["language"] == (None, language) @pytest.mark.parametrize( "provider_cls", [_GroqProvider, _OpenAIProvider], ids=["groq", "openai"], ) @pytest.mark.asyncio async def test_transcription_provider_omits_language_when_none(tmp_path, provider_cls): """When language is not set, the 'language' key must be absent from the multipart body.""" audio = tmp_path / "sample.wav" audio.write_bytes(b"audio") captured: dict[str, object] = {} with patch("nanobot.providers.transcription.httpx.AsyncClient", return_value=_stub_async_client(captured)): provider = provider_cls(api_key="k") result = await provider.transcribe(audio) assert result == "hello" assert "language" not in captured["files"] def test_channels_login_uses_discovered_plugin_class(monkeypatch): from typer.testing import CliRunner from nanobot.cli.commands import app from nanobot.config.schema import Config runner = CliRunner() seen: dict[str, object] = {} class _LoginPlugin(_FakePlugin): display_name = "Login Plugin" async def login(self, force: bool = False) -> bool: seen["force"] = force seen["config"] = self.config return True monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config()) monkeypatch.setattr( "nanobot.channels.registry.discover_all", lambda: {"fakeplugin": _LoginPlugin}, ) result = runner.invoke(app, ["channels", "login", "fakeplugin", "--force"]) assert result.exit_code == 0 assert seen["force"] is True def test_channels_login_sets_custom_config_path(monkeypatch, tmp_path): from typer.testing import CliRunner from nanobot.cli.commands import app from nanobot.config.schema import Config runner = CliRunner() seen: dict[str, object] = {} config_path = tmp_path / "custom-config.json" class _LoginPlugin(_FakePlugin): async def login(self, force: bool = False) -> bool: return True monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config()) monkeypatch.setattr( "nanobot.config.loader.set_config_path", lambda path: seen.__setitem__("config_path", path), ) monkeypatch.setattr( "nanobot.channels.registry.discover_all", lambda: {"fakeplugin": _LoginPlugin}, ) result = runner.invoke(app, ["channels", "login", "fakeplugin", "--config", str(config_path)]) assert result.exit_code == 0 assert seen["config_path"] == config_path.resolve() def test_channels_status_sets_custom_config_path(monkeypatch, tmp_path): from typer.testing import CliRunner from nanobot.cli.commands import app from nanobot.config.schema import Config runner = CliRunner() seen: dict[str, object] = {} config_path = tmp_path / "custom-config.json" monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config()) monkeypatch.setattr( "nanobot.config.loader.set_config_path", lambda path: seen.__setitem__("config_path", path), ) monkeypatch.setattr("nanobot.channels.registry.discover_all", lambda: {}) result = runner.invoke(app, ["channels", "status", "--config", str(config_path)]) assert result.exit_code == 0 assert seen["config_path"] == config_path.resolve() @pytest.mark.asyncio async def test_manager_skips_disabled_plugin(): fake_config = SimpleNamespace( channels=ChannelsConfig.model_validate({ "fakeplugin": {"enabled": False}, }), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), ) with patch( "nanobot.channels.registry.discover_all", return_value={"fakeplugin": _FakePlugin}, ): mgr = ChannelManager.__new__(ChannelManager) mgr.config = fake_config mgr.bus = MessageBus() mgr.channels = {} mgr._dispatch_task = None mgr._init_channels() assert "fakeplugin" not in mgr.channels # --------------------------------------------------------------------------- # Built-in channel default_config() and dict->Pydantic conversion # --------------------------------------------------------------------------- def test_builtin_channel_default_config(): """Built-in channels expose default_config() returning a dict with 'enabled': False.""" from nanobot.channels.telegram import TelegramChannel cfg = TelegramChannel.default_config() assert isinstance(cfg, dict) assert cfg["enabled"] is False assert "token" in cfg def test_builtin_channel_init_from_dict(): """Built-in channels accept a raw dict and convert to Pydantic internally.""" from nanobot.channels.telegram import TelegramChannel bus = MessageBus() ch = TelegramChannel({"enabled": False, "token": "test-tok", "allowFrom": ["*"]}, bus) assert ch.config.token == "test-tok" assert ch.config.allow_from == ["*"] def test_channels_config_send_max_retries_default(): """ChannelsConfig should have send_max_retries with default value of 3.""" cfg = ChannelsConfig() assert hasattr(cfg, 'send_max_retries') assert cfg.send_max_retries == 3 def test_channels_config_send_max_retries_upper_bound(): """send_max_retries should be bounded to prevent resource exhaustion.""" from pydantic import ValidationError # Value too high should be rejected with pytest.raises(ValidationError): ChannelsConfig(send_max_retries=100) # Negative should be rejected with pytest.raises(ValidationError): ChannelsConfig(send_max_retries=-1) # Boundary values should be allowed cfg_min = ChannelsConfig(send_max_retries=0) assert cfg_min.send_max_retries == 0 cfg_max = ChannelsConfig(send_max_retries=10) assert cfg_max.send_max_retries == 10 # Value above upper bound should be rejected with pytest.raises(ValidationError): ChannelsConfig(send_max_retries=11) def test_channels_config_transcription_language_pattern(): """transcription_language must match ISO-639 format (2-3 lowercase letters) or be None.""" from pydantic import ValidationError # Valid values assert ChannelsConfig(transcription_language="en").transcription_language == "en" assert ChannelsConfig(transcription_language="kor").transcription_language == "kor" assert ChannelsConfig(transcription_language=None).transcription_language is None # Invalid values with pytest.raises(ValidationError): ChannelsConfig(transcription_language="EN") # uppercase with pytest.raises(ValidationError): ChannelsConfig(transcription_language="english") # full word with pytest.raises(ValidationError): ChannelsConfig(transcription_language="en-US") # BCP 47 tag # --------------------------------------------------------------------------- # _send_with_retry # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_send_with_retry_succeeds_first_try(): """_send_with_retry should succeed on first try and not retry.""" call_count = 0 class _FailingChannel(BaseChannel): name = "failing" display_name = "Failing" async def start(self) -> None: pass async def stop(self) -> None: pass async def send(self, msg: OutboundMessage) -> None: nonlocal call_count call_count += 1 # Succeeds on first try fake_config = SimpleNamespace( channels=ChannelsConfig(send_max_retries=3), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), ) mgr = ChannelManager.__new__(ChannelManager) mgr.config = fake_config mgr.bus = MessageBus() mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)} mgr._dispatch_task = None msg = OutboundMessage(channel="failing", chat_id="123", content="test") await mgr._send_with_retry(mgr.channels["failing"], msg) assert call_count == 1 @pytest.mark.asyncio async def test_send_with_retry_retries_on_failure(): """_send_with_retry should retry on failure up to max_retries times.""" call_count = 0 class _FailingChannel(BaseChannel): name = "failing" display_name = "Failing" async def start(self) -> None: pass async def stop(self) -> None: pass async def send(self, msg: OutboundMessage) -> None: nonlocal call_count call_count += 1 raise RuntimeError("simulated failure") fake_config = SimpleNamespace( channels=ChannelsConfig(send_max_retries=3), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), ) mgr = ChannelManager.__new__(ChannelManager) mgr.config = fake_config mgr.bus = MessageBus() mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)} mgr._dispatch_task = None msg = OutboundMessage(channel="failing", chat_id="123", content="test") # Patch asyncio.sleep to avoid actual delays with patch("nanobot.channels.manager.asyncio.sleep", new_callable=AsyncMock) as mock_sleep: await mgr._send_with_retry(mgr.channels["failing"], msg) assert call_count == 3 # 3 total attempts (initial + 2 retries) assert mock_sleep.call_count == 2 # 2 sleeps between retries @pytest.mark.asyncio async def test_send_with_retry_no_retry_when_max_is_zero(): """_send_with_retry should not retry when send_max_retries is 0.""" call_count = 0 class _FailingChannel(BaseChannel): name = "failing" display_name = "Failing" async def start(self) -> None: pass async def stop(self) -> None: pass async def send(self, msg: OutboundMessage) -> None: nonlocal call_count call_count += 1 raise RuntimeError("simulated failure") fake_config = SimpleNamespace( channels=ChannelsConfig(send_max_retries=0), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), ) mgr = ChannelManager.__new__(ChannelManager) mgr.config = fake_config mgr.bus = MessageBus() mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)} mgr._dispatch_task = None msg = OutboundMessage(channel="failing", chat_id="123", content="test") with patch("nanobot.channels.manager.asyncio.sleep", new_callable=AsyncMock): await mgr._send_with_retry(mgr.channels["failing"], msg) assert call_count == 1 # Called once but no retry (max(0, 1) = 1) @pytest.mark.asyncio async def test_send_with_retry_calls_send_delta(): """_send_with_retry should call send_delta when metadata has _stream_delta.""" send_delta_called = False class _StreamingChannel(BaseChannel): name = "streaming" display_name = "Streaming" async def start(self) -> None: pass async def stop(self) -> None: pass async def send(self, msg: OutboundMessage) -> None: pass # Should not be called async def send_delta(self, chat_id: str, delta: str, metadata: dict | None = None) -> None: nonlocal send_delta_called send_delta_called = True fake_config = SimpleNamespace( channels=ChannelsConfig(send_max_retries=3), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), ) mgr = ChannelManager.__new__(ChannelManager) mgr.config = fake_config mgr.bus = MessageBus() mgr.channels = {"streaming": _StreamingChannel(fake_config, mgr.bus)} mgr._dispatch_task = None msg = OutboundMessage( channel="streaming", chat_id="123", content="test delta", metadata={"_stream_delta": True} ) await mgr._send_with_retry(mgr.channels["streaming"], msg) assert send_delta_called is True @pytest.mark.asyncio async def test_send_with_retry_skips_send_when_streamed(): """_send_with_retry should not call send when metadata has _streamed flag.""" send_called = False send_delta_called = False class _StreamedChannel(BaseChannel): name = "streamed" display_name = "Streamed" async def start(self) -> None: pass async def stop(self) -> None: pass async def send(self, msg: OutboundMessage) -> None: nonlocal send_called send_called = True async def send_delta(self, chat_id: str, delta: str, metadata: dict | None = None) -> None: nonlocal send_delta_called send_delta_called = True fake_config = SimpleNamespace( channels=ChannelsConfig(send_max_retries=3), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), ) mgr = ChannelManager.__new__(ChannelManager) mgr.config = fake_config mgr.bus = MessageBus() mgr.channels = {"streamed": _StreamedChannel(fake_config, mgr.bus)} mgr._dispatch_task = None # _streamed means message was already sent via send_delta, so skip send msg = OutboundMessage( channel="streamed", chat_id="123", content="test", metadata={"_streamed": True} ) await mgr._send_with_retry(mgr.channels["streamed"], msg) assert send_called is False assert send_delta_called is False @pytest.mark.asyncio async def test_send_with_retry_propagates_cancelled_error(): """_send_with_retry should re-raise CancelledError for graceful shutdown.""" class _CancellingChannel(BaseChannel): name = "cancelling" display_name = "Cancelling" async def start(self) -> None: pass async def stop(self) -> None: pass async def send(self, msg: OutboundMessage) -> None: raise asyncio.CancelledError("simulated cancellation") fake_config = SimpleNamespace( channels=ChannelsConfig(send_max_retries=3), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), ) mgr = ChannelManager.__new__(ChannelManager) mgr.config = fake_config mgr.bus = MessageBus() mgr.channels = {"cancelling": _CancellingChannel(fake_config, mgr.bus)} mgr._dispatch_task = None msg = OutboundMessage(channel="cancelling", chat_id="123", content="test") with pytest.raises(asyncio.CancelledError): await mgr._send_with_retry(mgr.channels["cancelling"], msg) @pytest.mark.asyncio async def test_send_with_retry_propagates_cancelled_error_during_sleep(): """_send_with_retry should re-raise CancelledError during sleep.""" call_count = 0 class _FailingChannel(BaseChannel): name = "failing" display_name = "Failing" async def start(self) -> None: pass async def stop(self) -> None: pass async def send(self, msg: OutboundMessage) -> None: nonlocal call_count call_count += 1 raise RuntimeError("simulated failure") fake_config = SimpleNamespace( channels=ChannelsConfig(send_max_retries=3), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), ) mgr = ChannelManager.__new__(ChannelManager) mgr.config = fake_config mgr.bus = MessageBus() mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)} mgr._dispatch_task = None msg = OutboundMessage(channel="failing", chat_id="123", content="test") # Mock sleep to raise CancelledError async def cancel_during_sleep(_): raise asyncio.CancelledError("cancelled during sleep") with patch("nanobot.channels.manager.asyncio.sleep", side_effect=cancel_during_sleep): with pytest.raises(asyncio.CancelledError): await mgr._send_with_retry(mgr.channels["failing"], msg) # Should have attempted once before sleep was cancelled assert call_count == 1 # --------------------------------------------------------------------------- # ChannelManager - lifecycle and getters # --------------------------------------------------------------------------- class _ChannelWithAllowFrom(BaseChannel): """Channel with configurable allow_from.""" name = "withallow" display_name = "With Allow" def __init__(self, config, bus, allow_from): super().__init__(config, bus) if isinstance(self.config, dict): self.config["allow_from"] = allow_from else: self.config.allow_from = allow_from async def start(self) -> None: pass async def stop(self) -> None: pass async def send(self, msg: OutboundMessage) -> None: pass class _StartableChannel(BaseChannel): """Channel that tracks start/stop calls.""" name = "startable" display_name = "Startable" def __init__(self, config, bus): super().__init__(config, bus) self.started = False self.stopped = False async def start(self) -> None: self.started = True async def stop(self) -> None: self.stopped = True async def send(self, msg: OutboundMessage) -> None: pass @pytest.mark.asyncio async def test_validate_allow_from_raises_on_empty_list(): """_validate_allow_from should raise SystemExit when allow_from is empty list.""" fake_config = SimpleNamespace( channels=ChannelsConfig(), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), ) mgr = ChannelManager.__new__(ChannelManager) mgr.config = fake_config mgr.channels = {"test": _ChannelWithAllowFrom(fake_config, None, [])} mgr._dispatch_task = None with pytest.raises(SystemExit) as exc_info: mgr._validate_allow_from() assert "empty allowFrom" in str(exc_info.value) @pytest.mark.asyncio async def test_validate_allow_from_passes_with_asterisk(): """_validate_allow_from should not raise when allow_from contains '*'.""" fake_config = SimpleNamespace( channels=ChannelsConfig(), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), ) mgr = ChannelManager.__new__(ChannelManager) mgr.config = fake_config mgr.channels = {"test": _ChannelWithAllowFrom(fake_config, None, ["*"])} mgr._dispatch_task = None # Should not raise mgr._validate_allow_from() @pytest.mark.asyncio async def test_validate_allow_from_raises_on_empty_dict_allow_from(): """_validate_allow_from should reject empty dict-backed allow_from lists.""" fake_config = SimpleNamespace( channels=ChannelsConfig(), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), ) mgr = ChannelManager.__new__(ChannelManager) mgr.config = fake_config mgr.channels = {"test": _ChannelWithAllowFrom({"enabled": True}, None, [])} mgr._dispatch_task = None with pytest.raises(SystemExit) as exc_info: mgr._validate_allow_from() assert "empty allowFrom" in str(exc_info.value) @pytest.mark.asyncio async def test_get_channel_returns_channel_if_exists(): """get_channel should return the channel if it exists.""" fake_config = SimpleNamespace( channels=ChannelsConfig(), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), ) mgr = ChannelManager.__new__(ChannelManager) mgr.config = fake_config mgr.bus = MessageBus() mgr.channels = {"telegram": _StartableChannel(fake_config, mgr.bus)} mgr._dispatch_task = None assert mgr.get_channel("telegram") is not None assert mgr.get_channel("nonexistent") is None @pytest.mark.asyncio async def test_get_status_returns_running_state(): """get_status should return enabled and running state for each channel.""" fake_config = SimpleNamespace( channels=ChannelsConfig(), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), ) mgr = ChannelManager.__new__(ChannelManager) mgr.config = fake_config mgr.bus = MessageBus() ch = _StartableChannel(fake_config, mgr.bus) mgr.channels = {"startable": ch} mgr._dispatch_task = None status = mgr.get_status() assert status["startable"]["enabled"] is True assert status["startable"]["running"] is False # Not started yet @pytest.mark.asyncio async def test_enabled_channels_returns_channel_names(): """enabled_channels should return list of enabled channel names.""" fake_config = SimpleNamespace( channels=ChannelsConfig(), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), ) mgr = ChannelManager.__new__(ChannelManager) mgr.config = fake_config mgr.bus = MessageBus() mgr.channels = { "telegram": _StartableChannel(fake_config, mgr.bus), "slack": _StartableChannel(fake_config, mgr.bus), } mgr._dispatch_task = None enabled = mgr.enabled_channels assert "telegram" in enabled assert "slack" in enabled assert len(enabled) == 2 @pytest.mark.asyncio async def test_stop_all_cancels_dispatcher_and_stops_channels(): """stop_all should cancel the dispatch task and stop all channels.""" fake_config = SimpleNamespace( channels=ChannelsConfig(), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), ) mgr = ChannelManager.__new__(ChannelManager) mgr.config = fake_config mgr.bus = MessageBus() ch = _StartableChannel(fake_config, mgr.bus) mgr.channels = {"startable": ch} # Create a real cancelled task async def dummy_task(): while True: await asyncio.sleep(1) dispatch_task = asyncio.create_task(dummy_task()) mgr._dispatch_task = dispatch_task await mgr.stop_all() # Task should be cancelled assert dispatch_task.cancelled() # Channel should be stopped assert ch.stopped is True @pytest.mark.asyncio async def test_start_channel_logs_error_on_failure(): """_start_channel should log error when channel start fails.""" class _FailingChannel(BaseChannel): name = "failing" display_name = "Failing" async def start(self) -> None: raise RuntimeError("connection failed") async def stop(self) -> None: pass async def send(self, msg: OutboundMessage) -> None: pass fake_config = SimpleNamespace( channels=ChannelsConfig(), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), ) mgr = ChannelManager.__new__(ChannelManager) mgr.config = fake_config mgr.bus = MessageBus() mgr.channels = {} mgr._dispatch_task = None ch = _FailingChannel(fake_config, mgr.bus) # Should not raise, just log error await mgr._start_channel("failing", ch) @pytest.mark.asyncio async def test_stop_all_handles_channel_exception(): """stop_all should handle exceptions when stopping channels gracefully.""" class _StopFailingChannel(BaseChannel): name = "stopfailing" display_name = "Stop Failing" async def start(self) -> None: pass async def stop(self) -> None: raise RuntimeError("stop failed") async def send(self, msg: OutboundMessage) -> None: pass fake_config = SimpleNamespace( channels=ChannelsConfig(), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), ) mgr = ChannelManager.__new__(ChannelManager) mgr.config = fake_config mgr.bus = MessageBus() mgr.channels = {"stopfailing": _StopFailingChannel(fake_config, mgr.bus)} mgr._dispatch_task = None # Should not raise even if channel.stop() raises await mgr.stop_all() @pytest.mark.asyncio async def test_start_all_no_channels_logs_warning(): """start_all should log warning when no channels are enabled.""" fake_config = SimpleNamespace( channels=ChannelsConfig(), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), ) mgr = ChannelManager.__new__(ChannelManager) mgr.config = fake_config mgr.bus = MessageBus() mgr.channels = {} # No channels mgr._dispatch_task = None # Should return early without creating dispatch task await mgr.start_all() assert mgr._dispatch_task is None @pytest.mark.asyncio async def test_start_all_creates_dispatch_task(): """start_all should create the dispatch task when channels exist.""" fake_config = SimpleNamespace( channels=ChannelsConfig(), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), ) mgr = ChannelManager.__new__(ChannelManager) mgr.config = fake_config mgr.bus = MessageBus() ch = _StartableChannel(fake_config, mgr.bus) mgr.channels = {"startable": ch} mgr._dispatch_task = None # Cancel immediately after start to avoid running forever async def cancel_after_start(): await asyncio.sleep(0.01) if mgr._dispatch_task: mgr._dispatch_task.cancel() cancel_task = asyncio.create_task(cancel_after_start()) try: await mgr.start_all() except asyncio.CancelledError: pass finally: cancel_task.cancel() try: await cancel_task except asyncio.CancelledError: pass # Dispatch task should have been created assert mgr._dispatch_task is not None @pytest.mark.asyncio async def test_notify_restart_done_enqueues_outbound_message(): """Restart notice should schedule send_with_retry for target channel.""" fake_config = SimpleNamespace( channels=ChannelsConfig(), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), ) mgr = ChannelManager.__new__(ChannelManager) mgr.config = fake_config mgr.bus = MessageBus() mgr.channels = {"feishu": _StartableChannel(fake_config, mgr.bus)} mgr._dispatch_task = None mgr._send_with_retry = AsyncMock() notice = RestartNotice(channel="feishu", chat_id="oc_123", started_at_raw="100.0") with patch("nanobot.channels.manager.consume_restart_notice_from_env", return_value=notice): mgr._notify_restart_done_if_needed() await asyncio.sleep(0) mgr._send_with_retry.assert_awaited_once() sent_channel, sent_msg = mgr._send_with_retry.await_args.args assert sent_channel is mgr.channels["feishu"] assert sent_msg.channel == "feishu" assert sent_msg.chat_id == "oc_123" assert sent_msg.content.startswith("Restart completed")