mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-11 05:33:36 +00:00
fix(discord): enable streaming replies
This commit is contained in:
parent
715f2a79be
commit
e49b6c0c96
@ -394,7 +394,8 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso
|
||||
"enabled": true,
|
||||
"token": "YOUR_BOT_TOKEN",
|
||||
"allowFrom": ["YOUR_USER_ID"],
|
||||
"groupPolicy": "mention"
|
||||
"groupPolicy": "mention",
|
||||
"streaming": true
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -405,6 +406,7 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso
|
||||
> - `"open"` — Respond to all messages
|
||||
> DMs always respond when the sender is in `allowFrom`.
|
||||
> - If you set group policy to open create new threads as private threads and then @ the bot into it. Otherwise the thread itself and the channel in which you spawned it will spawn a bot session.
|
||||
> `streaming` defaults to `true`. Disable it only if you explicitly want non-streaming replies.
|
||||
|
||||
**5. Invite the bot**
|
||||
- OAuth2 → URL Generator
|
||||
|
||||
@ -4,6 +4,8 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
@ -34,6 +36,16 @@ MAX_MESSAGE_LEN = 2000 # Discord message character limit
|
||||
TYPING_INTERVAL_S = 8
|
||||
|
||||
|
||||
@dataclass
|
||||
class _StreamBuf:
|
||||
"""Per-chat streaming accumulator for progressive Discord message edits."""
|
||||
|
||||
text: str = ""
|
||||
message: Any | None = None
|
||||
last_edit: float = 0.0
|
||||
stream_id: str | None = None
|
||||
|
||||
|
||||
class DiscordConfig(Base):
|
||||
"""Discord channel configuration."""
|
||||
|
||||
@ -45,6 +57,7 @@ class DiscordConfig(Base):
|
||||
read_receipt_emoji: str = "👀"
|
||||
working_emoji: str = "🔧"
|
||||
working_emoji_delay: float = 2.0
|
||||
streaming: bool = True
|
||||
|
||||
|
||||
if DISCORD_AVAILABLE:
|
||||
@ -242,6 +255,7 @@ class DiscordChannel(BaseChannel):
|
||||
|
||||
name = "discord"
|
||||
display_name = "Discord"
|
||||
_STREAM_EDIT_INTERVAL = 0.8
|
||||
|
||||
@classmethod
|
||||
def default_config(cls) -> dict[str, Any]:
|
||||
@ -263,6 +277,7 @@ class DiscordChannel(BaseChannel):
|
||||
self._bot_user_id: str | None = None
|
||||
self._pending_reactions: dict[str, Any] = {} # chat_id -> message object
|
||||
self._working_emoji_tasks: dict[str, asyncio.Task[None]] = {}
|
||||
self._stream_bufs: dict[str, _StreamBuf] = {}
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Discord client."""
|
||||
@ -320,6 +335,61 @@ class DiscordChannel(BaseChannel):
|
||||
await self._stop_typing(msg.chat_id)
|
||||
await self._clear_reactions(msg.chat_id)
|
||||
|
||||
async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
|
||||
"""Progressive Discord delivery: send once, then edit until the stream ends."""
|
||||
client = self._client
|
||||
if client is None or not client.is_ready():
|
||||
logger.warning("Discord client not ready; dropping stream delta")
|
||||
return
|
||||
|
||||
meta = metadata or {}
|
||||
stream_id = meta.get("_stream_id")
|
||||
|
||||
if meta.get("_stream_end"):
|
||||
buf = self._stream_bufs.get(chat_id)
|
||||
if not buf or buf.message is None or not buf.text:
|
||||
return
|
||||
if stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id:
|
||||
return
|
||||
await self._finalize_stream(chat_id, buf)
|
||||
return
|
||||
|
||||
buf = self._stream_bufs.get(chat_id)
|
||||
if buf is None or (stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id):
|
||||
buf = _StreamBuf(stream_id=stream_id)
|
||||
self._stream_bufs[chat_id] = buf
|
||||
elif buf.stream_id is None:
|
||||
buf.stream_id = stream_id
|
||||
|
||||
buf.text += delta
|
||||
if not buf.text.strip():
|
||||
return
|
||||
|
||||
target = await self._resolve_channel(chat_id)
|
||||
if target is None:
|
||||
logger.warning("Discord stream target {} unavailable", chat_id)
|
||||
return
|
||||
|
||||
now = time.monotonic()
|
||||
if buf.message is None:
|
||||
try:
|
||||
buf.message = await target.send(content=buf.text)
|
||||
buf.last_edit = now
|
||||
except Exception as e:
|
||||
logger.warning("Discord stream initial send failed: {}", e)
|
||||
raise
|
||||
return
|
||||
|
||||
if (now - buf.last_edit) < self._STREAM_EDIT_INTERVAL:
|
||||
return
|
||||
|
||||
try:
|
||||
await buf.message.edit(content=DiscordBotClient._build_chunks(buf.text, [], False)[0])
|
||||
buf.last_edit = now
|
||||
except Exception as e:
|
||||
logger.warning("Discord stream edit failed: {}", e)
|
||||
raise
|
||||
|
||||
async def _handle_discord_message(self, message: discord.Message) -> None:
|
||||
"""Handle incoming Discord messages from discord.py."""
|
||||
if message.author.bot:
|
||||
@ -373,6 +443,47 @@ class DiscordChannel(BaseChannel):
|
||||
"""Backward-compatible alias for legacy tests/callers."""
|
||||
await self._handle_discord_message(message)
|
||||
|
||||
async def _resolve_channel(self, chat_id: str) -> Any | None:
|
||||
"""Resolve a Discord channel from cache first, then network fetch."""
|
||||
client = self._client
|
||||
if client is None or not client.is_ready():
|
||||
return None
|
||||
channel_id = int(chat_id)
|
||||
channel = client.get_channel(channel_id)
|
||||
if channel is not None:
|
||||
return channel
|
||||
try:
|
||||
return await client.fetch_channel(channel_id)
|
||||
except Exception as e:
|
||||
logger.warning("Discord channel {} unavailable: {}", chat_id, e)
|
||||
return None
|
||||
|
||||
async def _finalize_stream(self, chat_id: str, buf: _StreamBuf) -> None:
|
||||
"""Commit the final streamed content and flush overflow chunks."""
|
||||
chunks = DiscordBotClient._build_chunks(buf.text, [], False)
|
||||
if not chunks:
|
||||
self._stream_bufs.pop(chat_id, None)
|
||||
return
|
||||
|
||||
try:
|
||||
await buf.message.edit(content=chunks[0])
|
||||
except Exception as e:
|
||||
logger.warning("Discord final stream edit failed: {}", e)
|
||||
raise
|
||||
|
||||
target = getattr(buf.message, "channel", None) or await self._resolve_channel(chat_id)
|
||||
if target is None:
|
||||
logger.warning("Discord stream follow-up target {} unavailable", chat_id)
|
||||
self._stream_bufs.pop(chat_id, None)
|
||||
return
|
||||
|
||||
for extra_chunk in chunks[1:]:
|
||||
await target.send(content=extra_chunk)
|
||||
|
||||
self._stream_bufs.pop(chat_id, None)
|
||||
await self._stop_typing(chat_id)
|
||||
await self._clear_reactions(chat_id)
|
||||
|
||||
def _should_accept_inbound(
|
||||
self,
|
||||
message: discord.Message,
|
||||
@ -507,6 +618,7 @@ class DiscordChannel(BaseChannel):
|
||||
async def _reset_runtime_state(self, close_client: bool) -> None:
|
||||
"""Reset client and typing state."""
|
||||
await self._cancel_all_typing()
|
||||
self._stream_bufs.clear()
|
||||
if close_client and self._client is not None and not self._client.is_closed():
|
||||
try:
|
||||
await self._client.close()
|
||||
|
||||
@ -71,11 +71,25 @@ class _FakePartialMessage:
|
||||
self.id = message_id
|
||||
|
||||
|
||||
class _FakeSentMessage:
|
||||
# Sent-message double supporting edit() for streaming tests.
|
||||
def __init__(self, channel, content: str) -> None:
|
||||
self.channel = channel
|
||||
self.content = content
|
||||
self.edits: list[dict] = []
|
||||
|
||||
async def edit(self, **kwargs) -> None:
|
||||
self.edits.append(dict(kwargs))
|
||||
if "content" in kwargs:
|
||||
self.content = kwargs["content"]
|
||||
|
||||
|
||||
class _FakeChannel:
|
||||
# Channel double that records outbound payloads and typing activity.
|
||||
def __init__(self, channel_id: int = 123) -> None:
|
||||
self.id = channel_id
|
||||
self.sent_payloads: list[dict] = []
|
||||
self.sent_messages: list[_FakeSentMessage] = []
|
||||
self.trigger_typing_calls = 0
|
||||
self.typing_enter_hook = None
|
||||
|
||||
@ -85,6 +99,9 @@ class _FakeChannel:
|
||||
payload["file_name"] = payload["file"].filename
|
||||
del payload["file"]
|
||||
self.sent_payloads.append(payload)
|
||||
message = _FakeSentMessage(self, payload.get("content", ""))
|
||||
self.sent_messages.append(message)
|
||||
return message
|
||||
|
||||
def get_partial_message(self, message_id: int) -> _FakePartialMessage:
|
||||
return _FakePartialMessage(message_id)
|
||||
@ -427,6 +444,33 @@ async def test_send_fetches_channel_when_not_cached() -> None:
|
||||
assert target.sent_payloads == [{"content": "hello"}]
|
||||
|
||||
|
||||
def test_supports_streaming_enabled_by_default() -> None:
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
|
||||
assert channel.supports_streaming is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_streams_by_editing_message(monkeypatch) -> None:
|
||||
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
client = _FakeDiscordClient(owner, intents=None)
|
||||
owner._client = client
|
||||
owner._running = True
|
||||
target = _FakeChannel(channel_id=123)
|
||||
client.channels[123] = target
|
||||
|
||||
times = iter([1.0, 3.0, 5.0])
|
||||
monkeypatch.setattr("nanobot.channels.discord.time.monotonic", lambda: next(times, 5.0))
|
||||
|
||||
await owner.send_delta("123", "hel", {"_stream_delta": True, "_stream_id": "s1"})
|
||||
await owner.send_delta("123", "lo", {"_stream_delta": True, "_stream_id": "s1"})
|
||||
await owner.send_delta("123", "", {"_stream_end": True, "_stream_id": "s1"})
|
||||
|
||||
assert target.sent_payloads[0] == {"content": "hel"}
|
||||
assert target.sent_messages[0].edits == [{"content": "hello"}, {"content": "hello"}]
|
||||
assert owner._stream_bufs == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slash_new_forwards_when_user_is_allowlisted() -> None:
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus())
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user