mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-12 05:44:03 +00:00
feat(channel): add proxy support for Discord channel
- Add proxy, proxy_username, proxy_password fields to DiscordConfig - Pass proxy and proxy_auth to discord.Client - Add aiohttp.BasicAuth when credentials are provided - Add tests for proxy configuration scenarios
This commit is contained in:
parent
0e6331b66d
commit
7506af7104
@ -22,6 +22,7 @@ from nanobot.utils.helpers import safe_filename, split_message
|
|||||||
|
|
||||||
DISCORD_AVAILABLE = importlib.util.find_spec("discord") is not None
|
DISCORD_AVAILABLE = importlib.util.find_spec("discord") is not None
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
import aiohttp
|
||||||
import discord
|
import discord
|
||||||
from discord import app_commands
|
from discord import app_commands
|
||||||
from discord.abc import Messageable
|
from discord.abc import Messageable
|
||||||
@ -58,6 +59,9 @@ class DiscordConfig(Base):
|
|||||||
working_emoji: str = "🔧"
|
working_emoji: str = "🔧"
|
||||||
working_emoji_delay: float = 2.0
|
working_emoji_delay: float = 2.0
|
||||||
streaming: bool = True
|
streaming: bool = True
|
||||||
|
proxy: str | None = None
|
||||||
|
proxy_username: str | None = None
|
||||||
|
proxy_password: str | None = None
|
||||||
|
|
||||||
|
|
||||||
if DISCORD_AVAILABLE:
|
if DISCORD_AVAILABLE:
|
||||||
@ -65,8 +69,15 @@ if DISCORD_AVAILABLE:
|
|||||||
class DiscordBotClient(discord.Client):
|
class DiscordBotClient(discord.Client):
|
||||||
"""discord.py client that forwards events to the channel."""
|
"""discord.py client that forwards events to the channel."""
|
||||||
|
|
||||||
def __init__(self, channel: DiscordChannel, *, intents: discord.Intents) -> None:
|
def __init__(
|
||||||
super().__init__(intents=intents)
|
self,
|
||||||
|
channel: DiscordChannel,
|
||||||
|
*,
|
||||||
|
intents: discord.Intents,
|
||||||
|
proxy: str | None = None,
|
||||||
|
proxy_auth: aiohttp.BasicAuth | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(intents=intents, proxy=proxy, proxy_auth=proxy_auth)
|
||||||
self._channel = channel
|
self._channel = channel
|
||||||
self.tree = app_commands.CommandTree(self)
|
self.tree = app_commands.CommandTree(self)
|
||||||
self._register_app_commands()
|
self._register_app_commands()
|
||||||
@ -130,6 +141,7 @@ if DISCORD_AVAILABLE:
|
|||||||
)
|
)
|
||||||
|
|
||||||
for name, description, command_text in commands:
|
for name, description, command_text in commands:
|
||||||
|
|
||||||
@self.tree.command(name=name, description=description)
|
@self.tree.command(name=name, description=description)
|
||||||
async def command_handler(
|
async def command_handler(
|
||||||
interaction: discord.Interaction,
|
interaction: discord.Interaction,
|
||||||
@ -186,7 +198,9 @@ if DISCORD_AVAILABLE:
|
|||||||
else:
|
else:
|
||||||
failed_media.append(Path(media_path).name)
|
failed_media.append(Path(media_path).name)
|
||||||
|
|
||||||
for index, chunk in enumerate(self._build_chunks(msg.content or "", failed_media, sent_media)):
|
for index, chunk in enumerate(
|
||||||
|
self._build_chunks(msg.content or "", failed_media, sent_media)
|
||||||
|
):
|
||||||
kwargs: dict[str, Any] = {"content": chunk}
|
kwargs: dict[str, Any] = {"content": chunk}
|
||||||
if index == 0 and reference is not None and not sent_media:
|
if index == 0 and reference is not None and not sent_media:
|
||||||
kwargs["reference"] = reference
|
kwargs["reference"] = reference
|
||||||
@ -292,7 +306,22 @@ class DiscordChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
intents = discord.Intents.none()
|
intents = discord.Intents.none()
|
||||||
intents.value = self.config.intents
|
intents.value = self.config.intents
|
||||||
self._client = DiscordBotClient(self, intents=intents)
|
|
||||||
|
proxy_auth = None
|
||||||
|
if self.config.proxy_username and self.config.proxy_password:
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
proxy_auth = aiohttp.BasicAuth(
|
||||||
|
login=self.config.proxy_username,
|
||||||
|
password=self.config.proxy_password,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._client = DiscordBotClient(
|
||||||
|
self,
|
||||||
|
intents=intents,
|
||||||
|
proxy=self.config.proxy,
|
||||||
|
proxy_auth=proxy_auth,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to initialize Discord client: {}", e)
|
logger.error("Failed to initialize Discord client: {}", e)
|
||||||
self._client = None
|
self._client = None
|
||||||
@ -335,7 +364,9 @@ class DiscordChannel(BaseChannel):
|
|||||||
await self._stop_typing(msg.chat_id)
|
await self._stop_typing(msg.chat_id)
|
||||||
await self._clear_reactions(msg.chat_id)
|
await self._clear_reactions(msg.chat_id)
|
||||||
|
|
||||||
async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
|
async def send_delta(
|
||||||
|
self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None
|
||||||
|
) -> None:
|
||||||
"""Progressive Discord delivery: send once, then edit until the stream ends."""
|
"""Progressive Discord delivery: send once, then edit until the stream ends."""
|
||||||
client = self._client
|
client = self._client
|
||||||
if client is None or not client.is_ready():
|
if client is None or not client.is_ready():
|
||||||
@ -355,7 +386,9 @@ class DiscordChannel(BaseChannel):
|
|||||||
return
|
return
|
||||||
|
|
||||||
buf = self._stream_bufs.get(chat_id)
|
buf = self._stream_bufs.get(chat_id)
|
||||||
if buf is None or (stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id):
|
if buf is None or (
|
||||||
|
stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id
|
||||||
|
):
|
||||||
buf = _StreamBuf(stream_id=stream_id)
|
buf = _StreamBuf(stream_id=stream_id)
|
||||||
self._stream_bufs[chat_id] = buf
|
self._stream_bufs[chat_id] = buf
|
||||||
elif buf.stream_id is None:
|
elif buf.stream_id is None:
|
||||||
@ -534,7 +567,11 @@ class DiscordChannel(BaseChannel):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_inbound_metadata(message: discord.Message) -> dict[str, str | None]:
|
def _build_inbound_metadata(message: discord.Message) -> dict[str, str | None]:
|
||||||
"""Build metadata for inbound Discord messages."""
|
"""Build metadata for inbound Discord messages."""
|
||||||
reply_to = str(message.reference.message_id) if message.reference and message.reference.message_id else None
|
reply_to = (
|
||||||
|
str(message.reference.message_id)
|
||||||
|
if message.reference and message.reference.message_id
|
||||||
|
else None
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"message_id": str(message.id),
|
"message_id": str(message.id),
|
||||||
"guild_id": str(message.guild.id) if message.guild else None,
|
"guild_id": str(message.guild.id) if message.guild else None,
|
||||||
@ -549,7 +586,9 @@ class DiscordChannel(BaseChannel):
|
|||||||
if self.config.group_policy == "mention":
|
if self.config.group_policy == "mention":
|
||||||
bot_user_id = self._bot_user_id
|
bot_user_id = self._bot_user_id
|
||||||
if bot_user_id is None:
|
if bot_user_id is None:
|
||||||
logger.debug("Discord message in {} ignored (bot identity unavailable)", message.channel.id)
|
logger.debug(
|
||||||
|
"Discord message in {} ignored (bot identity unavailable)", message.channel.id
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if any(str(user.id) == bot_user_id for user in message.mentions):
|
if any(str(user.id) == bot_user_id for user in message.mentions):
|
||||||
@ -591,7 +630,6 @@ class DiscordChannel(BaseChannel):
|
|||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
async def _clear_reactions(self, chat_id: str) -> None:
|
async def _clear_reactions(self, chat_id: str) -> None:
|
||||||
"""Remove all pending reactions after bot replies."""
|
"""Remove all pending reactions after bot replies."""
|
||||||
# Cancel delayed working emoji if it hasn't fired yet
|
# Cancel delayed working emoji if it hasn't fired yet
|
||||||
|
|||||||
@ -5,11 +5,17 @@ from pathlib import Path
|
|||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
discord = pytest.importorskip("discord")
|
discord = pytest.importorskip("discord")
|
||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.discord import MAX_MESSAGE_LEN, DiscordBotClient, DiscordChannel, DiscordConfig
|
from nanobot.channels.discord import (
|
||||||
|
MAX_MESSAGE_LEN,
|
||||||
|
DiscordBotClient,
|
||||||
|
DiscordChannel,
|
||||||
|
DiscordConfig,
|
||||||
|
)
|
||||||
from nanobot.command.builtin import build_help_text
|
from nanobot.command.builtin import build_help_text
|
||||||
|
|
||||||
|
|
||||||
@ -18,9 +24,11 @@ class _FakeDiscordClient:
|
|||||||
instances: list["_FakeDiscordClient"] = []
|
instances: list["_FakeDiscordClient"] = []
|
||||||
start_error: Exception | None = None
|
start_error: Exception | None = None
|
||||||
|
|
||||||
def __init__(self, owner, *, intents) -> None:
|
def __init__(self, owner, *, intents, proxy=None, proxy_auth=None) -> None:
|
||||||
self.owner = owner
|
self.owner = owner
|
||||||
self.intents = intents
|
self.intents = intents
|
||||||
|
self.proxy = proxy
|
||||||
|
self.proxy_auth = proxy_auth
|
||||||
self.closed = False
|
self.closed = False
|
||||||
self.ready = True
|
self.ready = True
|
||||||
self.channels: dict[int, object] = {}
|
self.channels: dict[int, object] = {}
|
||||||
@ -53,7 +61,9 @@ class _FakeDiscordClient:
|
|||||||
|
|
||||||
class _FakeAttachment:
|
class _FakeAttachment:
|
||||||
# Attachment double that can simulate successful or failing save() calls.
|
# Attachment double that can simulate successful or failing save() calls.
|
||||||
def __init__(self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False) -> None:
|
def __init__(
|
||||||
|
self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False
|
||||||
|
) -> None:
|
||||||
self.id = attachment_id
|
self.id = attachment_id
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
self.size = size
|
self.size = size
|
||||||
@ -211,7 +221,7 @@ async def test_start_handles_client_construction_failure(monkeypatch) -> None:
|
|||||||
MessageBus(),
|
MessageBus(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _boom(owner, *, intents):
|
def _boom(owner, *, intents, proxy=None, proxy_auth=None):
|
||||||
raise RuntimeError("bad client")
|
raise RuntimeError("bad client")
|
||||||
|
|
||||||
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _boom)
|
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _boom)
|
||||||
@ -514,9 +524,7 @@ async def test_slash_new_forwards_when_user_is_allowlisted() -> None:
|
|||||||
assert new_cmd is not None
|
assert new_cmd is not None
|
||||||
await new_cmd.callback(interaction)
|
await new_cmd.callback(interaction)
|
||||||
|
|
||||||
assert interaction.response.messages == [
|
assert interaction.response.messages == [{"content": "Processing /new...", "ephemeral": True}]
|
||||||
{"content": "Processing /new...", "ephemeral": True}
|
|
||||||
]
|
|
||||||
assert len(handled) == 1
|
assert len(handled) == 1
|
||||||
assert handled[0]["content"] == "/new"
|
assert handled[0]["content"] == "/new"
|
||||||
assert handled[0]["sender_id"] == "123"
|
assert handled[0]["sender_id"] == "123"
|
||||||
@ -590,9 +598,7 @@ async def test_slash_help_returns_ephemeral_help_text() -> None:
|
|||||||
assert help_cmd is not None
|
assert help_cmd is not None
|
||||||
await help_cmd.callback(interaction)
|
await help_cmd.callback(interaction)
|
||||||
|
|
||||||
assert interaction.response.messages == [
|
assert interaction.response.messages == [{"content": build_help_text(), "ephemeral": True}]
|
||||||
{"content": build_help_text(), "ephemeral": True}
|
|
||||||
]
|
|
||||||
assert handled == []
|
assert handled == []
|
||||||
|
|
||||||
|
|
||||||
@ -727,11 +733,13 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() ->
|
|||||||
def typing(self):
|
def typing(self):
|
||||||
async def _waiter():
|
async def _waiter():
|
||||||
await release.wait()
|
await release.wait()
|
||||||
|
|
||||||
# Hold the loop so task remains active until explicitly stopped.
|
# Hold the loop so task remains active until explicitly stopped.
|
||||||
class _Ctx(_TypingCtx):
|
class _Ctx(_TypingCtx):
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
await super().__aenter__()
|
await super().__aenter__()
|
||||||
await _waiter()
|
await _waiter()
|
||||||
|
|
||||||
return _Ctx()
|
return _Ctx()
|
||||||
|
|
||||||
typing_channel = _NoTriggerChannel(channel_id=123)
|
typing_channel = _NoTriggerChannel(channel_id=123)
|
||||||
@ -745,3 +753,95 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() ->
|
|||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
assert channel._typing_tasks == {}
|
assert channel._typing_tasks == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_accepts_proxy_fields() -> None:
|
||||||
|
config = DiscordConfig(
|
||||||
|
enabled=True,
|
||||||
|
token="token",
|
||||||
|
allow_from=["*"],
|
||||||
|
proxy="http://127.0.0.1:7890",
|
||||||
|
proxy_username="user",
|
||||||
|
proxy_password="pass",
|
||||||
|
)
|
||||||
|
assert config.proxy == "http://127.0.0.1:7890"
|
||||||
|
assert config.proxy_username == "user"
|
||||||
|
assert config.proxy_password == "pass"
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_proxy_defaults_to_none() -> None:
|
||||||
|
config = DiscordConfig(enabled=True, token="token", allow_from=["*"])
|
||||||
|
assert config.proxy is None
|
||||||
|
assert config.proxy_username is None
|
||||||
|
assert config.proxy_password is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_passes_proxy_to_client(monkeypatch) -> None:
|
||||||
|
_FakeDiscordClient.instances.clear()
|
||||||
|
channel = DiscordChannel(
|
||||||
|
DiscordConfig(
|
||||||
|
enabled=True,
|
||||||
|
token="token",
|
||||||
|
allow_from=["*"],
|
||||||
|
proxy="http://127.0.0.1:7890",
|
||||||
|
),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
|
||||||
|
|
||||||
|
await channel.start()
|
||||||
|
|
||||||
|
assert channel.is_running is False
|
||||||
|
assert len(_FakeDiscordClient.instances) == 1
|
||||||
|
assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890"
|
||||||
|
assert _FakeDiscordClient.instances[0].proxy_auth is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_passes_proxy_auth_when_credentials_provided(monkeypatch) -> None:
|
||||||
|
aiohttp = pytest.importorskip("aiohttp")
|
||||||
|
_FakeDiscordClient.instances.clear()
|
||||||
|
channel = DiscordChannel(
|
||||||
|
DiscordConfig(
|
||||||
|
enabled=True,
|
||||||
|
token="token",
|
||||||
|
allow_from=["*"],
|
||||||
|
proxy="http://127.0.0.1:7890",
|
||||||
|
proxy_username="user",
|
||||||
|
proxy_password="pass",
|
||||||
|
),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
|
||||||
|
|
||||||
|
await channel.start()
|
||||||
|
|
||||||
|
assert channel.is_running is False
|
||||||
|
assert len(_FakeDiscordClient.instances) == 1
|
||||||
|
assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890"
|
||||||
|
assert _FakeDiscordClient.instances[0].proxy_auth is not None
|
||||||
|
assert isinstance(_FakeDiscordClient.instances[0].proxy_auth, aiohttp.BasicAuth)
|
||||||
|
assert _FakeDiscordClient.instances[0].proxy_auth.login == "user"
|
||||||
|
assert _FakeDiscordClient.instances[0].proxy_auth.password == "pass"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_no_proxy_auth_when_only_username(monkeypatch) -> None:
|
||||||
|
_FakeDiscordClient.instances.clear()
|
||||||
|
channel = DiscordChannel(
|
||||||
|
DiscordConfig(
|
||||||
|
enabled=True,
|
||||||
|
token="token",
|
||||||
|
allow_from=["*"],
|
||||||
|
proxy="http://127.0.0.1:7890",
|
||||||
|
proxy_username="user",
|
||||||
|
),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
|
||||||
|
|
||||||
|
await channel.start()
|
||||||
|
|
||||||
|
assert channel.is_running is False
|
||||||
|
assert _FakeDiscordClient.instances[0].proxy_auth is None
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user