diff --git a/nanobot/channels/discord.py b/nanobot/channels/discord.py index 9e68bb46b..c50b4ff19 100644 --- a/nanobot/channels/discord.py +++ b/nanobot/channels/discord.py @@ -22,6 +22,7 @@ from nanobot.utils.helpers import safe_filename, split_message DISCORD_AVAILABLE = importlib.util.find_spec("discord") is not None if TYPE_CHECKING: + import aiohttp import discord from discord import app_commands from discord.abc import Messageable @@ -58,6 +59,9 @@ class DiscordConfig(Base): working_emoji: str = "🔧" working_emoji_delay: float = 2.0 streaming: bool = True + proxy: str | None = None + proxy_username: str | None = None + proxy_password: str | None = None if DISCORD_AVAILABLE: @@ -65,8 +69,15 @@ if DISCORD_AVAILABLE: class DiscordBotClient(discord.Client): """discord.py client that forwards events to the channel.""" - def __init__(self, channel: DiscordChannel, *, intents: discord.Intents) -> None: - super().__init__(intents=intents) + def __init__( + self, + channel: DiscordChannel, + *, + intents: discord.Intents, + proxy: str | None = None, + proxy_auth: aiohttp.BasicAuth | None = None, + ) -> None: + super().__init__(intents=intents, proxy=proxy, proxy_auth=proxy_auth) self._channel = channel self.tree = app_commands.CommandTree(self) self._register_app_commands() @@ -130,6 +141,7 @@ if DISCORD_AVAILABLE: ) for name, description, command_text in commands: + @self.tree.command(name=name, description=description) async def command_handler( interaction: discord.Interaction, @@ -186,7 +198,9 @@ if DISCORD_AVAILABLE: else: failed_media.append(Path(media_path).name) - for index, chunk in enumerate(self._build_chunks(msg.content or "", failed_media, sent_media)): + for index, chunk in enumerate( + self._build_chunks(msg.content or "", failed_media, sent_media) + ): kwargs: dict[str, Any] = {"content": chunk} if index == 0 and reference is not None and not sent_media: kwargs["reference"] = reference @@ -292,7 +306,22 @@ class DiscordChannel(BaseChannel): try: intents = discord.Intents.none() intents.value = self.config.intents - self._client = DiscordBotClient(self, intents=intents) + + proxy_auth = None + if self.config.proxy_username and self.config.proxy_password: + import aiohttp + + proxy_auth = aiohttp.BasicAuth( + login=self.config.proxy_username, + password=self.config.proxy_password, + ) + + self._client = DiscordBotClient( + self, + intents=intents, + proxy=self.config.proxy, + proxy_auth=proxy_auth, + ) except Exception as e: logger.error("Failed to initialize Discord client: {}", e) self._client = None @@ -335,7 +364,9 @@ class DiscordChannel(BaseChannel): await self._stop_typing(msg.chat_id) await self._clear_reactions(msg.chat_id) - async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None: + async def send_delta( + self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None + ) -> None: """Progressive Discord delivery: send once, then edit until the stream ends.""" client = self._client if client is None or not client.is_ready(): @@ -355,7 +386,9 @@ class DiscordChannel(BaseChannel): return buf = self._stream_bufs.get(chat_id) - if buf is None or (stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id): + if buf is None or ( + stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id + ): buf = _StreamBuf(stream_id=stream_id) self._stream_bufs[chat_id] = buf elif buf.stream_id is None: @@ -534,7 +567,11 @@ class DiscordChannel(BaseChannel): @staticmethod def _build_inbound_metadata(message: discord.Message) -> dict[str, str | None]: """Build metadata for inbound Discord messages.""" - reply_to = str(message.reference.message_id) if message.reference and message.reference.message_id else None + reply_to = ( + str(message.reference.message_id) + if message.reference and message.reference.message_id + else None + ) return { "message_id": str(message.id), "guild_id": str(message.guild.id) if message.guild else None, @@ -549,7 +586,9 @@ class DiscordChannel(BaseChannel): if self.config.group_policy == "mention": bot_user_id = self._bot_user_id if bot_user_id is None: - logger.debug("Discord message in {} ignored (bot identity unavailable)", message.channel.id) + logger.debug( + "Discord message in {} ignored (bot identity unavailable)", message.channel.id + ) return False if any(str(user.id) == bot_user_id for user in message.mentions): @@ -591,7 +630,6 @@ class DiscordChannel(BaseChannel): except asyncio.CancelledError: pass - async def _clear_reactions(self, chat_id: str) -> None: """Remove all pending reactions after bot replies.""" # Cancel delayed working emoji if it hasn't fired yet diff --git a/tests/channels/test_discord_channel.py b/tests/channels/test_discord_channel.py index 09b80740f..3f0f3388a 100644 --- a/tests/channels/test_discord_channel.py +++ b/tests/channels/test_discord_channel.py @@ -5,11 +5,17 @@ from pathlib import Path from types import SimpleNamespace import pytest + discord = pytest.importorskip("discord") from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus -from nanobot.channels.discord import MAX_MESSAGE_LEN, DiscordBotClient, DiscordChannel, DiscordConfig +from nanobot.channels.discord import ( + MAX_MESSAGE_LEN, + DiscordBotClient, + DiscordChannel, + DiscordConfig, +) from nanobot.command.builtin import build_help_text @@ -18,9 +24,11 @@ class _FakeDiscordClient: instances: list["_FakeDiscordClient"] = [] start_error: Exception | None = None - def __init__(self, owner, *, intents) -> None: + def __init__(self, owner, *, intents, proxy=None, proxy_auth=None) -> None: self.owner = owner self.intents = intents + self.proxy = proxy + self.proxy_auth = proxy_auth self.closed = False self.ready = True self.channels: dict[int, object] = {} @@ -53,7 +61,9 @@ class _FakeDiscordClient: class _FakeAttachment: # Attachment double that can simulate successful or failing save() calls. - def __init__(self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False) -> None: + def __init__( + self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False + ) -> None: self.id = attachment_id self.filename = filename self.size = size @@ -211,7 +221,7 @@ async def test_start_handles_client_construction_failure(monkeypatch) -> None: MessageBus(), ) - def _boom(owner, *, intents): + def _boom(owner, *, intents, proxy=None, proxy_auth=None): raise RuntimeError("bad client") monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _boom) @@ -514,9 +524,7 @@ async def test_slash_new_forwards_when_user_is_allowlisted() -> None: assert new_cmd is not None await new_cmd.callback(interaction) - assert interaction.response.messages == [ - {"content": "Processing /new...", "ephemeral": True} - ] + assert interaction.response.messages == [{"content": "Processing /new...", "ephemeral": True}] assert len(handled) == 1 assert handled[0]["content"] == "/new" assert handled[0]["sender_id"] == "123" @@ -590,9 +598,7 @@ async def test_slash_help_returns_ephemeral_help_text() -> None: assert help_cmd is not None await help_cmd.callback(interaction) - assert interaction.response.messages == [ - {"content": build_help_text(), "ephemeral": True} - ] + assert interaction.response.messages == [{"content": build_help_text(), "ephemeral": True}] assert handled == [] @@ -727,11 +733,13 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() -> def typing(self): async def _waiter(): await release.wait() + # Hold the loop so task remains active until explicitly stopped. class _Ctx(_TypingCtx): async def __aenter__(self): await super().__aenter__() await _waiter() + return _Ctx() typing_channel = _NoTriggerChannel(channel_id=123) @@ -745,3 +753,95 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() -> await asyncio.sleep(0) assert channel._typing_tasks == {} + + +def test_config_accepts_proxy_fields() -> None: + config = DiscordConfig( + enabled=True, + token="token", + allow_from=["*"], + proxy="http://127.0.0.1:7890", + proxy_username="user", + proxy_password="pass", + ) + assert config.proxy == "http://127.0.0.1:7890" + assert config.proxy_username == "user" + assert config.proxy_password == "pass" + + +def test_config_proxy_defaults_to_none() -> None: + config = DiscordConfig(enabled=True, token="token", allow_from=["*"]) + assert config.proxy is None + assert config.proxy_username is None + assert config.proxy_password is None + + +@pytest.mark.asyncio +async def test_start_passes_proxy_to_client(monkeypatch) -> None: + _FakeDiscordClient.instances.clear() + channel = DiscordChannel( + DiscordConfig( + enabled=True, + token="token", + allow_from=["*"], + proxy="http://127.0.0.1:7890", + ), + MessageBus(), + ) + monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient) + + await channel.start() + + assert channel.is_running is False + assert len(_FakeDiscordClient.instances) == 1 + assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890" + assert _FakeDiscordClient.instances[0].proxy_auth is None + + +@pytest.mark.asyncio +async def test_start_passes_proxy_auth_when_credentials_provided(monkeypatch) -> None: + aiohttp = pytest.importorskip("aiohttp") + _FakeDiscordClient.instances.clear() + channel = DiscordChannel( + DiscordConfig( + enabled=True, + token="token", + allow_from=["*"], + proxy="http://127.0.0.1:7890", + proxy_username="user", + proxy_password="pass", + ), + MessageBus(), + ) + monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient) + + await channel.start() + + assert channel.is_running is False + assert len(_FakeDiscordClient.instances) == 1 + assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890" + assert _FakeDiscordClient.instances[0].proxy_auth is not None + assert isinstance(_FakeDiscordClient.instances[0].proxy_auth, aiohttp.BasicAuth) + assert _FakeDiscordClient.instances[0].proxy_auth.login == "user" + assert _FakeDiscordClient.instances[0].proxy_auth.password == "pass" + + +@pytest.mark.asyncio +async def test_start_no_proxy_auth_when_only_username(monkeypatch) -> None: + _FakeDiscordClient.instances.clear() + channel = DiscordChannel( + DiscordConfig( + enabled=True, + token="token", + allow_from=["*"], + proxy="http://127.0.0.1:7890", + proxy_username="user", + ), + MessageBus(), + ) + monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient) + + await channel.start() + + assert channel.is_running is False + assert _FakeDiscordClient.instances[0].proxy_auth is None