feat(channel): add proxy support for Discord channel

- Add proxy, proxy_username, proxy_password fields to DiscordConfig
- Pass proxy and proxy_auth to discord.Client
- Add aiohttp.BasicAuth when credentials are provided
- Add tests for proxy configuration scenarios
This commit is contained in:
Jonas 2026-04-09 14:24:04 +08:00 committed by Xubin Ren
parent 0e6331b66d
commit 7506af7104
2 changed files with 157 additions and 19 deletions

View File

@ -22,6 +22,7 @@ from nanobot.utils.helpers import safe_filename, split_message
DISCORD_AVAILABLE = importlib.util.find_spec("discord") is not None
if TYPE_CHECKING:
import aiohttp
import discord
from discord import app_commands
from discord.abc import Messageable
@ -58,6 +59,9 @@ class DiscordConfig(Base):
working_emoji: str = "🔧"
working_emoji_delay: float = 2.0
streaming: bool = True
proxy: str | None = None
proxy_username: str | None = None
proxy_password: str | None = None
if DISCORD_AVAILABLE:
@ -65,8 +69,15 @@ if DISCORD_AVAILABLE:
class DiscordBotClient(discord.Client):
"""discord.py client that forwards events to the channel."""
def __init__(self, channel: DiscordChannel, *, intents: discord.Intents) -> None:
super().__init__(intents=intents)
def __init__(
self,
channel: DiscordChannel,
*,
intents: discord.Intents,
proxy: str | None = None,
proxy_auth: aiohttp.BasicAuth | None = None,
) -> None:
super().__init__(intents=intents, proxy=proxy, proxy_auth=proxy_auth)
self._channel = channel
self.tree = app_commands.CommandTree(self)
self._register_app_commands()
@ -130,6 +141,7 @@ if DISCORD_AVAILABLE:
)
for name, description, command_text in commands:
@self.tree.command(name=name, description=description)
async def command_handler(
interaction: discord.Interaction,
@ -186,7 +198,9 @@ if DISCORD_AVAILABLE:
else:
failed_media.append(Path(media_path).name)
for index, chunk in enumerate(self._build_chunks(msg.content or "", failed_media, sent_media)):
for index, chunk in enumerate(
self._build_chunks(msg.content or "", failed_media, sent_media)
):
kwargs: dict[str, Any] = {"content": chunk}
if index == 0 and reference is not None and not sent_media:
kwargs["reference"] = reference
@ -292,7 +306,22 @@ class DiscordChannel(BaseChannel):
try:
intents = discord.Intents.none()
intents.value = self.config.intents
self._client = DiscordBotClient(self, intents=intents)
proxy_auth = None
if self.config.proxy_username and self.config.proxy_password:
import aiohttp
proxy_auth = aiohttp.BasicAuth(
login=self.config.proxy_username,
password=self.config.proxy_password,
)
self._client = DiscordBotClient(
self,
intents=intents,
proxy=self.config.proxy,
proxy_auth=proxy_auth,
)
except Exception as e:
logger.error("Failed to initialize Discord client: {}", e)
self._client = None
@ -335,7 +364,9 @@ class DiscordChannel(BaseChannel):
await self._stop_typing(msg.chat_id)
await self._clear_reactions(msg.chat_id)
async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
async def send_delta(
self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None
) -> None:
"""Progressive Discord delivery: send once, then edit until the stream ends."""
client = self._client
if client is None or not client.is_ready():
@ -355,7 +386,9 @@ class DiscordChannel(BaseChannel):
return
buf = self._stream_bufs.get(chat_id)
if buf is None or (stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id):
if buf is None or (
stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id
):
buf = _StreamBuf(stream_id=stream_id)
self._stream_bufs[chat_id] = buf
elif buf.stream_id is None:
@ -534,7 +567,11 @@ class DiscordChannel(BaseChannel):
@staticmethod
def _build_inbound_metadata(message: discord.Message) -> dict[str, str | None]:
"""Build metadata for inbound Discord messages."""
reply_to = str(message.reference.message_id) if message.reference and message.reference.message_id else None
reply_to = (
str(message.reference.message_id)
if message.reference and message.reference.message_id
else None
)
return {
"message_id": str(message.id),
"guild_id": str(message.guild.id) if message.guild else None,
@ -549,7 +586,9 @@ class DiscordChannel(BaseChannel):
if self.config.group_policy == "mention":
bot_user_id = self._bot_user_id
if bot_user_id is None:
logger.debug("Discord message in {} ignored (bot identity unavailable)", message.channel.id)
logger.debug(
"Discord message in {} ignored (bot identity unavailable)", message.channel.id
)
return False
if any(str(user.id) == bot_user_id for user in message.mentions):
@ -591,7 +630,6 @@ class DiscordChannel(BaseChannel):
except asyncio.CancelledError:
pass
async def _clear_reactions(self, chat_id: str) -> None:
"""Remove all pending reactions after bot replies."""
# Cancel delayed working emoji if it hasn't fired yet

View File

@ -5,11 +5,17 @@ from pathlib import Path
from types import SimpleNamespace
import pytest
discord = pytest.importorskip("discord")
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.discord import MAX_MESSAGE_LEN, DiscordBotClient, DiscordChannel, DiscordConfig
from nanobot.channels.discord import (
MAX_MESSAGE_LEN,
DiscordBotClient,
DiscordChannel,
DiscordConfig,
)
from nanobot.command.builtin import build_help_text
@ -18,9 +24,11 @@ class _FakeDiscordClient:
instances: list["_FakeDiscordClient"] = []
start_error: Exception | None = None
def __init__(self, owner, *, intents) -> None:
def __init__(self, owner, *, intents, proxy=None, proxy_auth=None) -> None:
self.owner = owner
self.intents = intents
self.proxy = proxy
self.proxy_auth = proxy_auth
self.closed = False
self.ready = True
self.channels: dict[int, object] = {}
@ -53,7 +61,9 @@ class _FakeDiscordClient:
class _FakeAttachment:
# Attachment double that can simulate successful or failing save() calls.
def __init__(self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False) -> None:
def __init__(
self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False
) -> None:
self.id = attachment_id
self.filename = filename
self.size = size
@ -211,7 +221,7 @@ async def test_start_handles_client_construction_failure(monkeypatch) -> None:
MessageBus(),
)
def _boom(owner, *, intents):
def _boom(owner, *, intents, proxy=None, proxy_auth=None):
raise RuntimeError("bad client")
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _boom)
@ -514,9 +524,7 @@ async def test_slash_new_forwards_when_user_is_allowlisted() -> None:
assert new_cmd is not None
await new_cmd.callback(interaction)
assert interaction.response.messages == [
{"content": "Processing /new...", "ephemeral": True}
]
assert interaction.response.messages == [{"content": "Processing /new...", "ephemeral": True}]
assert len(handled) == 1
assert handled[0]["content"] == "/new"
assert handled[0]["sender_id"] == "123"
@ -590,9 +598,7 @@ async def test_slash_help_returns_ephemeral_help_text() -> None:
assert help_cmd is not None
await help_cmd.callback(interaction)
assert interaction.response.messages == [
{"content": build_help_text(), "ephemeral": True}
]
assert interaction.response.messages == [{"content": build_help_text(), "ephemeral": True}]
assert handled == []
@ -727,11 +733,13 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() ->
def typing(self):
async def _waiter():
await release.wait()
# Hold the loop so task remains active until explicitly stopped.
class _Ctx(_TypingCtx):
async def __aenter__(self):
await super().__aenter__()
await _waiter()
return _Ctx()
typing_channel = _NoTriggerChannel(channel_id=123)
@ -745,3 +753,95 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() ->
await asyncio.sleep(0)
assert channel._typing_tasks == {}
def test_config_accepts_proxy_fields() -> None:
config = DiscordConfig(
enabled=True,
token="token",
allow_from=["*"],
proxy="http://127.0.0.1:7890",
proxy_username="user",
proxy_password="pass",
)
assert config.proxy == "http://127.0.0.1:7890"
assert config.proxy_username == "user"
assert config.proxy_password == "pass"
def test_config_proxy_defaults_to_none() -> None:
config = DiscordConfig(enabled=True, token="token", allow_from=["*"])
assert config.proxy is None
assert config.proxy_username is None
assert config.proxy_password is None
@pytest.mark.asyncio
async def test_start_passes_proxy_to_client(monkeypatch) -> None:
_FakeDiscordClient.instances.clear()
channel = DiscordChannel(
DiscordConfig(
enabled=True,
token="token",
allow_from=["*"],
proxy="http://127.0.0.1:7890",
),
MessageBus(),
)
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
await channel.start()
assert channel.is_running is False
assert len(_FakeDiscordClient.instances) == 1
assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890"
assert _FakeDiscordClient.instances[0].proxy_auth is None
@pytest.mark.asyncio
async def test_start_passes_proxy_auth_when_credentials_provided(monkeypatch) -> None:
aiohttp = pytest.importorskip("aiohttp")
_FakeDiscordClient.instances.clear()
channel = DiscordChannel(
DiscordConfig(
enabled=True,
token="token",
allow_from=["*"],
proxy="http://127.0.0.1:7890",
proxy_username="user",
proxy_password="pass",
),
MessageBus(),
)
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
await channel.start()
assert channel.is_running is False
assert len(_FakeDiscordClient.instances) == 1
assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890"
assert _FakeDiscordClient.instances[0].proxy_auth is not None
assert isinstance(_FakeDiscordClient.instances[0].proxy_auth, aiohttp.BasicAuth)
assert _FakeDiscordClient.instances[0].proxy_auth.login == "user"
assert _FakeDiscordClient.instances[0].proxy_auth.password == "pass"
@pytest.mark.asyncio
async def test_start_no_proxy_auth_when_only_username(monkeypatch) -> None:
_FakeDiscordClient.instances.clear()
channel = DiscordChannel(
DiscordConfig(
enabled=True,
token="token",
allow_from=["*"],
proxy="http://127.0.0.1:7890",
proxy_username="user",
),
MessageBus(),
)
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
await channel.start()
assert channel.is_running is False
assert _FakeDiscordClient.instances[0].proxy_auth is None