fix(discord): enable streaming replies

This commit is contained in:
SHLE1 2026-04-08 07:20:02 +00:00 committed by Xubin Ren
parent 715f2a79be
commit e49b6c0c96
3 changed files with 159 additions and 1 deletions

View File

@ -394,7 +394,8 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso
"enabled": true,
"token": "YOUR_BOT_TOKEN",
"allowFrom": ["YOUR_USER_ID"],
"groupPolicy": "mention"
"groupPolicy": "mention",
"streaming": true
}
}
}
@ -405,6 +406,7 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso
> - `"open"` — Respond to all messages
> DMs always respond when the sender is in `allowFrom`.
> - If you set group policy to open create new threads as private threads and then @ the bot into it. Otherwise the thread itself and the channel in which you spawned it will spawn a bot session.
> `streaming` defaults to `true`. Disable it only if you explicitly want non-streaming replies.
**5. Invite the bot**
- OAuth2 → URL Generator

View File

@ -4,6 +4,8 @@ from __future__ import annotations
import asyncio
import importlib.util
import time
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
@ -34,6 +36,16 @@ MAX_MESSAGE_LEN = 2000 # Discord message character limit
TYPING_INTERVAL_S = 8
@dataclass
class _StreamBuf:
"""Per-chat streaming accumulator for progressive Discord message edits."""
text: str = ""
message: Any | None = None
last_edit: float = 0.0
stream_id: str | None = None
class DiscordConfig(Base):
"""Discord channel configuration."""
@ -45,6 +57,7 @@ class DiscordConfig(Base):
read_receipt_emoji: str = "👀"
working_emoji: str = "🔧"
working_emoji_delay: float = 2.0
streaming: bool = True
if DISCORD_AVAILABLE:
@ -242,6 +255,7 @@ class DiscordChannel(BaseChannel):
name = "discord"
display_name = "Discord"
_STREAM_EDIT_INTERVAL = 0.8
@classmethod
def default_config(cls) -> dict[str, Any]:
@ -263,6 +277,7 @@ class DiscordChannel(BaseChannel):
self._bot_user_id: str | None = None
self._pending_reactions: dict[str, Any] = {} # chat_id -> message object
self._working_emoji_tasks: dict[str, asyncio.Task[None]] = {}
self._stream_bufs: dict[str, _StreamBuf] = {}
async def start(self) -> None:
"""Start the Discord client."""
@ -320,6 +335,61 @@ class DiscordChannel(BaseChannel):
await self._stop_typing(msg.chat_id)
await self._clear_reactions(msg.chat_id)
async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
"""Progressive Discord delivery: send once, then edit until the stream ends."""
client = self._client
if client is None or not client.is_ready():
logger.warning("Discord client not ready; dropping stream delta")
return
meta = metadata or {}
stream_id = meta.get("_stream_id")
if meta.get("_stream_end"):
buf = self._stream_bufs.get(chat_id)
if not buf or buf.message is None or not buf.text:
return
if stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id:
return
await self._finalize_stream(chat_id, buf)
return
buf = self._stream_bufs.get(chat_id)
if buf is None or (stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id):
buf = _StreamBuf(stream_id=stream_id)
self._stream_bufs[chat_id] = buf
elif buf.stream_id is None:
buf.stream_id = stream_id
buf.text += delta
if not buf.text.strip():
return
target = await self._resolve_channel(chat_id)
if target is None:
logger.warning("Discord stream target {} unavailable", chat_id)
return
now = time.monotonic()
if buf.message is None:
try:
buf.message = await target.send(content=buf.text)
buf.last_edit = now
except Exception as e:
logger.warning("Discord stream initial send failed: {}", e)
raise
return
if (now - buf.last_edit) < self._STREAM_EDIT_INTERVAL:
return
try:
await buf.message.edit(content=DiscordBotClient._build_chunks(buf.text, [], False)[0])
buf.last_edit = now
except Exception as e:
logger.warning("Discord stream edit failed: {}", e)
raise
async def _handle_discord_message(self, message: discord.Message) -> None:
"""Handle incoming Discord messages from discord.py."""
if message.author.bot:
@ -373,6 +443,47 @@ class DiscordChannel(BaseChannel):
"""Backward-compatible alias for legacy tests/callers."""
await self._handle_discord_message(message)
async def _resolve_channel(self, chat_id: str) -> Any | None:
"""Resolve a Discord channel from cache first, then network fetch."""
client = self._client
if client is None or not client.is_ready():
return None
channel_id = int(chat_id)
channel = client.get_channel(channel_id)
if channel is not None:
return channel
try:
return await client.fetch_channel(channel_id)
except Exception as e:
logger.warning("Discord channel {} unavailable: {}", chat_id, e)
return None
async def _finalize_stream(self, chat_id: str, buf: _StreamBuf) -> None:
"""Commit the final streamed content and flush overflow chunks."""
chunks = DiscordBotClient._build_chunks(buf.text, [], False)
if not chunks:
self._stream_bufs.pop(chat_id, None)
return
try:
await buf.message.edit(content=chunks[0])
except Exception as e:
logger.warning("Discord final stream edit failed: {}", e)
raise
target = getattr(buf.message, "channel", None) or await self._resolve_channel(chat_id)
if target is None:
logger.warning("Discord stream follow-up target {} unavailable", chat_id)
self._stream_bufs.pop(chat_id, None)
return
for extra_chunk in chunks[1:]:
await target.send(content=extra_chunk)
self._stream_bufs.pop(chat_id, None)
await self._stop_typing(chat_id)
await self._clear_reactions(chat_id)
def _should_accept_inbound(
self,
message: discord.Message,
@ -507,6 +618,7 @@ class DiscordChannel(BaseChannel):
async def _reset_runtime_state(self, close_client: bool) -> None:
"""Reset client and typing state."""
await self._cancel_all_typing()
self._stream_bufs.clear()
if close_client and self._client is not None and not self._client.is_closed():
try:
await self._client.close()

View File

@ -71,11 +71,25 @@ class _FakePartialMessage:
self.id = message_id
class _FakeSentMessage:
# Sent-message double supporting edit() for streaming tests.
def __init__(self, channel, content: str) -> None:
self.channel = channel
self.content = content
self.edits: list[dict] = []
async def edit(self, **kwargs) -> None:
self.edits.append(dict(kwargs))
if "content" in kwargs:
self.content = kwargs["content"]
class _FakeChannel:
# Channel double that records outbound payloads and typing activity.
def __init__(self, channel_id: int = 123) -> None:
self.id = channel_id
self.sent_payloads: list[dict] = []
self.sent_messages: list[_FakeSentMessage] = []
self.trigger_typing_calls = 0
self.typing_enter_hook = None
@ -85,6 +99,9 @@ class _FakeChannel:
payload["file_name"] = payload["file"].filename
del payload["file"]
self.sent_payloads.append(payload)
message = _FakeSentMessage(self, payload.get("content", ""))
self.sent_messages.append(message)
return message
def get_partial_message(self, message_id: int) -> _FakePartialMessage:
return _FakePartialMessage(message_id)
@ -427,6 +444,33 @@ async def test_send_fetches_channel_when_not_cached() -> None:
assert target.sent_payloads == [{"content": "hello"}]
def test_supports_streaming_enabled_by_default() -> None:
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
assert channel.supports_streaming is True
@pytest.mark.asyncio
async def test_send_delta_streams_by_editing_message(monkeypatch) -> None:
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
client = _FakeDiscordClient(owner, intents=None)
owner._client = client
owner._running = True
target = _FakeChannel(channel_id=123)
client.channels[123] = target
times = iter([1.0, 3.0, 5.0])
monkeypatch.setattr("nanobot.channels.discord.time.monotonic", lambda: next(times, 5.0))
await owner.send_delta("123", "hel", {"_stream_delta": True, "_stream_id": "s1"})
await owner.send_delta("123", "lo", {"_stream_delta": True, "_stream_id": "s1"})
await owner.send_delta("123", "", {"_stream_end": True, "_stream_id": "s1"})
assert target.sent_payloads[0] == {"content": "hel"}
assert target.sent_messages[0].edits == [{"content": "hello"}, {"content": "hello"}]
assert owner._stream_bufs == {}
@pytest.mark.asyncio
async def test_slash_new_forwards_when_user_is_allowlisted() -> None:
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus())