diff --git a/README.md b/README.md index a2ea20f8c..0747b25ed 100644 --- a/README.md +++ b/README.md @@ -394,7 +394,8 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso "enabled": true, "token": "YOUR_BOT_TOKEN", "allowFrom": ["YOUR_USER_ID"], - "groupPolicy": "mention" + "groupPolicy": "mention", + "streaming": true } } } @@ -405,6 +406,7 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso > - `"open"` — Respond to all messages > DMs always respond when the sender is in `allowFrom`. > - If you set group policy to open create new threads as private threads and then @ the bot into it. Otherwise the thread itself and the channel in which you spawned it will spawn a bot session. +> `streaming` defaults to `true`. Disable it only if you explicitly want non-streaming replies. **5. Invite the bot** - OAuth2 → URL Generator diff --git a/nanobot/channels/discord.py b/nanobot/channels/discord.py index 9bf4d919c..9e68bb46b 100644 --- a/nanobot/channels/discord.py +++ b/nanobot/channels/discord.py @@ -4,6 +4,8 @@ from __future__ import annotations import asyncio import importlib.util +import time +from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any, Literal @@ -34,6 +36,16 @@ MAX_MESSAGE_LEN = 2000 # Discord message character limit TYPING_INTERVAL_S = 8 +@dataclass +class _StreamBuf: + """Per-chat streaming accumulator for progressive Discord message edits.""" + + text: str = "" + message: Any | None = None + last_edit: float = 0.0 + stream_id: str | None = None + + class DiscordConfig(Base): """Discord channel configuration.""" @@ -45,6 +57,7 @@ class DiscordConfig(Base): read_receipt_emoji: str = "👀" working_emoji: str = "🔧" working_emoji_delay: float = 2.0 + streaming: bool = True if DISCORD_AVAILABLE: @@ -242,6 +255,7 @@ class DiscordChannel(BaseChannel): name = "discord" display_name = "Discord" + _STREAM_EDIT_INTERVAL = 0.8 @classmethod def default_config(cls) -> dict[str, Any]: @@ -263,6 +277,7 @@ class DiscordChannel(BaseChannel): self._bot_user_id: str | None = None self._pending_reactions: dict[str, Any] = {} # chat_id -> message object self._working_emoji_tasks: dict[str, asyncio.Task[None]] = {} + self._stream_bufs: dict[str, _StreamBuf] = {} async def start(self) -> None: """Start the Discord client.""" @@ -320,6 +335,61 @@ class DiscordChannel(BaseChannel): await self._stop_typing(msg.chat_id) await self._clear_reactions(msg.chat_id) + async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None: + """Progressive Discord delivery: send once, then edit until the stream ends.""" + client = self._client + if client is None or not client.is_ready(): + logger.warning("Discord client not ready; dropping stream delta") + return + + meta = metadata or {} + stream_id = meta.get("_stream_id") + + if meta.get("_stream_end"): + buf = self._stream_bufs.get(chat_id) + if not buf or buf.message is None or not buf.text: + return + if stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id: + return + await self._finalize_stream(chat_id, buf) + return + + buf = self._stream_bufs.get(chat_id) + if buf is None or (stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id): + buf = _StreamBuf(stream_id=stream_id) + self._stream_bufs[chat_id] = buf + elif buf.stream_id is None: + buf.stream_id = stream_id + + buf.text += delta + if not buf.text.strip(): + return + + target = await self._resolve_channel(chat_id) + if target is None: + logger.warning("Discord stream target {} unavailable", chat_id) + return + + now = time.monotonic() + if buf.message is None: + try: + buf.message = await target.send(content=buf.text) + buf.last_edit = now + except Exception as e: + logger.warning("Discord stream initial send failed: {}", e) + raise + return + + if (now - buf.last_edit) < self._STREAM_EDIT_INTERVAL: + return + + try: + await buf.message.edit(content=DiscordBotClient._build_chunks(buf.text, [], False)[0]) + buf.last_edit = now + except Exception as e: + logger.warning("Discord stream edit failed: {}", e) + raise + async def _handle_discord_message(self, message: discord.Message) -> None: """Handle incoming Discord messages from discord.py.""" if message.author.bot: @@ -373,6 +443,47 @@ class DiscordChannel(BaseChannel): """Backward-compatible alias for legacy tests/callers.""" await self._handle_discord_message(message) + async def _resolve_channel(self, chat_id: str) -> Any | None: + """Resolve a Discord channel from cache first, then network fetch.""" + client = self._client + if client is None or not client.is_ready(): + return None + channel_id = int(chat_id) + channel = client.get_channel(channel_id) + if channel is not None: + return channel + try: + return await client.fetch_channel(channel_id) + except Exception as e: + logger.warning("Discord channel {} unavailable: {}", chat_id, e) + return None + + async def _finalize_stream(self, chat_id: str, buf: _StreamBuf) -> None: + """Commit the final streamed content and flush overflow chunks.""" + chunks = DiscordBotClient._build_chunks(buf.text, [], False) + if not chunks: + self._stream_bufs.pop(chat_id, None) + return + + try: + await buf.message.edit(content=chunks[0]) + except Exception as e: + logger.warning("Discord final stream edit failed: {}", e) + raise + + target = getattr(buf.message, "channel", None) or await self._resolve_channel(chat_id) + if target is None: + logger.warning("Discord stream follow-up target {} unavailable", chat_id) + self._stream_bufs.pop(chat_id, None) + return + + for extra_chunk in chunks[1:]: + await target.send(content=extra_chunk) + + self._stream_bufs.pop(chat_id, None) + await self._stop_typing(chat_id) + await self._clear_reactions(chat_id) + def _should_accept_inbound( self, message: discord.Message, @@ -507,6 +618,7 @@ class DiscordChannel(BaseChannel): async def _reset_runtime_state(self, close_client: bool) -> None: """Reset client and typing state.""" await self._cancel_all_typing() + self._stream_bufs.clear() if close_client and self._client is not None and not self._client.is_closed(): try: await self._client.close() diff --git a/tests/channels/test_discord_channel.py b/tests/channels/test_discord_channel.py index 845c03c57..f588334ba 100644 --- a/tests/channels/test_discord_channel.py +++ b/tests/channels/test_discord_channel.py @@ -71,11 +71,25 @@ class _FakePartialMessage: self.id = message_id +class _FakeSentMessage: + # Sent-message double supporting edit() for streaming tests. + def __init__(self, channel, content: str) -> None: + self.channel = channel + self.content = content + self.edits: list[dict] = [] + + async def edit(self, **kwargs) -> None: + self.edits.append(dict(kwargs)) + if "content" in kwargs: + self.content = kwargs["content"] + + class _FakeChannel: # Channel double that records outbound payloads and typing activity. def __init__(self, channel_id: int = 123) -> None: self.id = channel_id self.sent_payloads: list[dict] = [] + self.sent_messages: list[_FakeSentMessage] = [] self.trigger_typing_calls = 0 self.typing_enter_hook = None @@ -85,6 +99,9 @@ class _FakeChannel: payload["file_name"] = payload["file"].filename del payload["file"] self.sent_payloads.append(payload) + message = _FakeSentMessage(self, payload.get("content", "")) + self.sent_messages.append(message) + return message def get_partial_message(self, message_id: int) -> _FakePartialMessage: return _FakePartialMessage(message_id) @@ -427,6 +444,33 @@ async def test_send_fetches_channel_when_not_cached() -> None: assert target.sent_payloads == [{"content": "hello"}] +def test_supports_streaming_enabled_by_default() -> None: + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + + assert channel.supports_streaming is True + + +@pytest.mark.asyncio +async def test_send_delta_streams_by_editing_message(monkeypatch) -> None: + owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = _FakeDiscordClient(owner, intents=None) + owner._client = client + owner._running = True + target = _FakeChannel(channel_id=123) + client.channels[123] = target + + times = iter([1.0, 3.0, 5.0]) + monkeypatch.setattr("nanobot.channels.discord.time.monotonic", lambda: next(times, 5.0)) + + await owner.send_delta("123", "hel", {"_stream_delta": True, "_stream_id": "s1"}) + await owner.send_delta("123", "lo", {"_stream_delta": True, "_stream_id": "s1"}) + await owner.send_delta("123", "", {"_stream_end": True, "_stream_id": "s1"}) + + assert target.sent_payloads[0] == {"content": "hel"} + assert target.sent_messages[0].edits == [{"content": "hello"}, {"content": "hello"}] + assert owner._stream_bufs == {} + + @pytest.mark.asyncio async def test_slash_new_forwards_when_user_is_allowlisted() -> None: channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus())