mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-10 21:23:39 +00:00
feat(channel): add proxy support for Discord channel
- Add proxy, proxy_username, proxy_password fields to DiscordConfig - Pass proxy and proxy_auth to discord.Client - Add aiohttp.BasicAuth when credentials are provided - Add tests for proxy configuration scenarios
This commit is contained in:
parent
0e6331b66d
commit
7506af7104
@ -22,6 +22,7 @@ from nanobot.utils.helpers import safe_filename, split_message
|
||||
|
||||
DISCORD_AVAILABLE = importlib.util.find_spec("discord") is not None
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.abc import Messageable
|
||||
@ -58,6 +59,9 @@ class DiscordConfig(Base):
|
||||
working_emoji: str = "🔧"
|
||||
working_emoji_delay: float = 2.0
|
||||
streaming: bool = True
|
||||
proxy: str | None = None
|
||||
proxy_username: str | None = None
|
||||
proxy_password: str | None = None
|
||||
|
||||
|
||||
if DISCORD_AVAILABLE:
|
||||
@ -65,8 +69,15 @@ if DISCORD_AVAILABLE:
|
||||
class DiscordBotClient(discord.Client):
|
||||
"""discord.py client that forwards events to the channel."""
|
||||
|
||||
def __init__(self, channel: DiscordChannel, *, intents: discord.Intents) -> None:
|
||||
super().__init__(intents=intents)
|
||||
def __init__(
|
||||
self,
|
||||
channel: DiscordChannel,
|
||||
*,
|
||||
intents: discord.Intents,
|
||||
proxy: str | None = None,
|
||||
proxy_auth: aiohttp.BasicAuth | None = None,
|
||||
) -> None:
|
||||
super().__init__(intents=intents, proxy=proxy, proxy_auth=proxy_auth)
|
||||
self._channel = channel
|
||||
self.tree = app_commands.CommandTree(self)
|
||||
self._register_app_commands()
|
||||
@ -130,6 +141,7 @@ if DISCORD_AVAILABLE:
|
||||
)
|
||||
|
||||
for name, description, command_text in commands:
|
||||
|
||||
@self.tree.command(name=name, description=description)
|
||||
async def command_handler(
|
||||
interaction: discord.Interaction,
|
||||
@ -186,7 +198,9 @@ if DISCORD_AVAILABLE:
|
||||
else:
|
||||
failed_media.append(Path(media_path).name)
|
||||
|
||||
for index, chunk in enumerate(self._build_chunks(msg.content or "", failed_media, sent_media)):
|
||||
for index, chunk in enumerate(
|
||||
self._build_chunks(msg.content or "", failed_media, sent_media)
|
||||
):
|
||||
kwargs: dict[str, Any] = {"content": chunk}
|
||||
if index == 0 and reference is not None and not sent_media:
|
||||
kwargs["reference"] = reference
|
||||
@ -292,7 +306,22 @@ class DiscordChannel(BaseChannel):
|
||||
try:
|
||||
intents = discord.Intents.none()
|
||||
intents.value = self.config.intents
|
||||
self._client = DiscordBotClient(self, intents=intents)
|
||||
|
||||
proxy_auth = None
|
||||
if self.config.proxy_username and self.config.proxy_password:
|
||||
import aiohttp
|
||||
|
||||
proxy_auth = aiohttp.BasicAuth(
|
||||
login=self.config.proxy_username,
|
||||
password=self.config.proxy_password,
|
||||
)
|
||||
|
||||
self._client = DiscordBotClient(
|
||||
self,
|
||||
intents=intents,
|
||||
proxy=self.config.proxy,
|
||||
proxy_auth=proxy_auth,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to initialize Discord client: {}", e)
|
||||
self._client = None
|
||||
@ -335,7 +364,9 @@ class DiscordChannel(BaseChannel):
|
||||
await self._stop_typing(msg.chat_id)
|
||||
await self._clear_reactions(msg.chat_id)
|
||||
|
||||
async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
|
||||
async def send_delta(
|
||||
self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None
|
||||
) -> None:
|
||||
"""Progressive Discord delivery: send once, then edit until the stream ends."""
|
||||
client = self._client
|
||||
if client is None or not client.is_ready():
|
||||
@ -355,7 +386,9 @@ class DiscordChannel(BaseChannel):
|
||||
return
|
||||
|
||||
buf = self._stream_bufs.get(chat_id)
|
||||
if buf is None or (stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id):
|
||||
if buf is None or (
|
||||
stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id
|
||||
):
|
||||
buf = _StreamBuf(stream_id=stream_id)
|
||||
self._stream_bufs[chat_id] = buf
|
||||
elif buf.stream_id is None:
|
||||
@ -534,7 +567,11 @@ class DiscordChannel(BaseChannel):
|
||||
@staticmethod
|
||||
def _build_inbound_metadata(message: discord.Message) -> dict[str, str | None]:
|
||||
"""Build metadata for inbound Discord messages."""
|
||||
reply_to = str(message.reference.message_id) if message.reference and message.reference.message_id else None
|
||||
reply_to = (
|
||||
str(message.reference.message_id)
|
||||
if message.reference and message.reference.message_id
|
||||
else None
|
||||
)
|
||||
return {
|
||||
"message_id": str(message.id),
|
||||
"guild_id": str(message.guild.id) if message.guild else None,
|
||||
@ -549,7 +586,9 @@ class DiscordChannel(BaseChannel):
|
||||
if self.config.group_policy == "mention":
|
||||
bot_user_id = self._bot_user_id
|
||||
if bot_user_id is None:
|
||||
logger.debug("Discord message in {} ignored (bot identity unavailable)", message.channel.id)
|
||||
logger.debug(
|
||||
"Discord message in {} ignored (bot identity unavailable)", message.channel.id
|
||||
)
|
||||
return False
|
||||
|
||||
if any(str(user.id) == bot_user_id for user in message.mentions):
|
||||
@ -591,7 +630,6 @@ class DiscordChannel(BaseChannel):
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
async def _clear_reactions(self, chat_id: str) -> None:
|
||||
"""Remove all pending reactions after bot replies."""
|
||||
# Cancel delayed working emoji if it hasn't fired yet
|
||||
|
||||
@ -5,11 +5,17 @@ from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
discord = pytest.importorskip("discord")
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.discord import MAX_MESSAGE_LEN, DiscordBotClient, DiscordChannel, DiscordConfig
|
||||
from nanobot.channels.discord import (
|
||||
MAX_MESSAGE_LEN,
|
||||
DiscordBotClient,
|
||||
DiscordChannel,
|
||||
DiscordConfig,
|
||||
)
|
||||
from nanobot.command.builtin import build_help_text
|
||||
|
||||
|
||||
@ -18,9 +24,11 @@ class _FakeDiscordClient:
|
||||
instances: list["_FakeDiscordClient"] = []
|
||||
start_error: Exception | None = None
|
||||
|
||||
def __init__(self, owner, *, intents) -> None:
|
||||
def __init__(self, owner, *, intents, proxy=None, proxy_auth=None) -> None:
|
||||
self.owner = owner
|
||||
self.intents = intents
|
||||
self.proxy = proxy
|
||||
self.proxy_auth = proxy_auth
|
||||
self.closed = False
|
||||
self.ready = True
|
||||
self.channels: dict[int, object] = {}
|
||||
@ -53,7 +61,9 @@ class _FakeDiscordClient:
|
||||
|
||||
class _FakeAttachment:
|
||||
# Attachment double that can simulate successful or failing save() calls.
|
||||
def __init__(self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False) -> None:
|
||||
def __init__(
|
||||
self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False
|
||||
) -> None:
|
||||
self.id = attachment_id
|
||||
self.filename = filename
|
||||
self.size = size
|
||||
@ -211,7 +221,7 @@ async def test_start_handles_client_construction_failure(monkeypatch) -> None:
|
||||
MessageBus(),
|
||||
)
|
||||
|
||||
def _boom(owner, *, intents):
|
||||
def _boom(owner, *, intents, proxy=None, proxy_auth=None):
|
||||
raise RuntimeError("bad client")
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _boom)
|
||||
@ -514,9 +524,7 @@ async def test_slash_new_forwards_when_user_is_allowlisted() -> None:
|
||||
assert new_cmd is not None
|
||||
await new_cmd.callback(interaction)
|
||||
|
||||
assert interaction.response.messages == [
|
||||
{"content": "Processing /new...", "ephemeral": True}
|
||||
]
|
||||
assert interaction.response.messages == [{"content": "Processing /new...", "ephemeral": True}]
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["content"] == "/new"
|
||||
assert handled[0]["sender_id"] == "123"
|
||||
@ -590,9 +598,7 @@ async def test_slash_help_returns_ephemeral_help_text() -> None:
|
||||
assert help_cmd is not None
|
||||
await help_cmd.callback(interaction)
|
||||
|
||||
assert interaction.response.messages == [
|
||||
{"content": build_help_text(), "ephemeral": True}
|
||||
]
|
||||
assert interaction.response.messages == [{"content": build_help_text(), "ephemeral": True}]
|
||||
assert handled == []
|
||||
|
||||
|
||||
@ -727,11 +733,13 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() ->
|
||||
def typing(self):
|
||||
async def _waiter():
|
||||
await release.wait()
|
||||
|
||||
# Hold the loop so task remains active until explicitly stopped.
|
||||
class _Ctx(_TypingCtx):
|
||||
async def __aenter__(self):
|
||||
await super().__aenter__()
|
||||
await _waiter()
|
||||
|
||||
return _Ctx()
|
||||
|
||||
typing_channel = _NoTriggerChannel(channel_id=123)
|
||||
@ -745,3 +753,95 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() ->
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert channel._typing_tasks == {}
|
||||
|
||||
|
||||
def test_config_accepts_proxy_fields() -> None:
|
||||
config = DiscordConfig(
|
||||
enabled=True,
|
||||
token="token",
|
||||
allow_from=["*"],
|
||||
proxy="http://127.0.0.1:7890",
|
||||
proxy_username="user",
|
||||
proxy_password="pass",
|
||||
)
|
||||
assert config.proxy == "http://127.0.0.1:7890"
|
||||
assert config.proxy_username == "user"
|
||||
assert config.proxy_password == "pass"
|
||||
|
||||
|
||||
def test_config_proxy_defaults_to_none() -> None:
|
||||
config = DiscordConfig(enabled=True, token="token", allow_from=["*"])
|
||||
assert config.proxy is None
|
||||
assert config.proxy_username is None
|
||||
assert config.proxy_password is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_passes_proxy_to_client(monkeypatch) -> None:
|
||||
_FakeDiscordClient.instances.clear()
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(
|
||||
enabled=True,
|
||||
token="token",
|
||||
allow_from=["*"],
|
||||
proxy="http://127.0.0.1:7890",
|
||||
),
|
||||
MessageBus(),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
|
||||
|
||||
await channel.start()
|
||||
|
||||
assert channel.is_running is False
|
||||
assert len(_FakeDiscordClient.instances) == 1
|
||||
assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890"
|
||||
assert _FakeDiscordClient.instances[0].proxy_auth is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_passes_proxy_auth_when_credentials_provided(monkeypatch) -> None:
|
||||
aiohttp = pytest.importorskip("aiohttp")
|
||||
_FakeDiscordClient.instances.clear()
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(
|
||||
enabled=True,
|
||||
token="token",
|
||||
allow_from=["*"],
|
||||
proxy="http://127.0.0.1:7890",
|
||||
proxy_username="user",
|
||||
proxy_password="pass",
|
||||
),
|
||||
MessageBus(),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
|
||||
|
||||
await channel.start()
|
||||
|
||||
assert channel.is_running is False
|
||||
assert len(_FakeDiscordClient.instances) == 1
|
||||
assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890"
|
||||
assert _FakeDiscordClient.instances[0].proxy_auth is not None
|
||||
assert isinstance(_FakeDiscordClient.instances[0].proxy_auth, aiohttp.BasicAuth)
|
||||
assert _FakeDiscordClient.instances[0].proxy_auth.login == "user"
|
||||
assert _FakeDiscordClient.instances[0].proxy_auth.password == "pass"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_no_proxy_auth_when_only_username(monkeypatch) -> None:
|
||||
_FakeDiscordClient.instances.clear()
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(
|
||||
enabled=True,
|
||||
token="token",
|
||||
allow_from=["*"],
|
||||
proxy="http://127.0.0.1:7890",
|
||||
proxy_username="user",
|
||||
),
|
||||
MessageBus(),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
|
||||
|
||||
await channel.start()
|
||||
|
||||
assert channel.is_running is False
|
||||
assert _FakeDiscordClient.instances[0].proxy_auth is None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user