From ca172927684cdb7b6cacbf065db023f88e17252c Mon Sep 17 00:00:00 2001 From: chengyongru Date: Sat, 9 May 2026 18:30:40 +0800 Subject: [PATCH] fix(cron): stream cron reminders with stream_id and turn_end --- nanobot/cli/commands.py | 72 ++++++++++++++++++-- tests/cli/test_commands.py | 133 +++++++++++++++++++++++++++++++++++++ 2 files changed, 200 insertions(+), 5 deletions(-) diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index a610f256f..eac065994 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -5,6 +5,7 @@ import os import select import signal import sys +import time from collections.abc import Callable from contextlib import nullcontext, suppress from pathlib import Path @@ -747,13 +748,57 @@ def _run_gateway( if isinstance(message_tool, MessageTool): message_record_token = message_tool.set_record_channel_delivery(True) + channel_name = job.payload.channel or "cli" + chat_id = job.payload.to or "direct" + try: + target_channel = channels.channels.get(channel_name) + except NameError: + target_channel = None + wants_stream = target_channel is not None and target_channel.supports_streaming + + stream_base_id = None + stream_segment = 0 + + def _current_stream_id() -> str: + return f"{stream_base_id}:{stream_segment}" + + async def _on_stream(delta: str) -> None: + meta = dict(job.payload.channel_meta) + meta["_stream_delta"] = True + meta["_stream_id"] = _current_stream_id() + await bus.publish_outbound(OutboundMessage( + channel=channel_name, + chat_id=chat_id, + content=delta, + metadata=meta, + )) + + async def _on_stream_end(*, resuming: bool = False) -> None: + nonlocal stream_segment + meta = dict(job.payload.channel_meta) + meta["_stream_end"] = True + meta["_resuming"] = resuming + meta["_stream_id"] = _current_stream_id() + await bus.publish_outbound(OutboundMessage( + channel=channel_name, + chat_id=chat_id, + content="", + metadata=meta, + )) + stream_segment += 1 + + if wants_stream: + stream_base_id = f"cron:{job.id}:{time.time_ns()}" + try: resp = await agent.process_direct( reminder_note, session_key=f"cron:{job.id}", - channel=job.payload.channel or "cli", - chat_id=job.payload.to or "direct", + channel=channel_name, + chat_id=chat_id, on_progress=_silent, + on_stream=_on_stream if wants_stream else None, + on_stream_end=_on_stream_end if wants_stream else None, ) finally: if isinstance(cron_tool, CronTool) and cron_token is not None: @@ -764,6 +809,13 @@ def _run_gateway( response = resp.content if resp else "" if job.payload.deliver and isinstance(message_tool, MessageTool) and message_tool._sent_in_turn: + if wants_stream: + await bus.publish_outbound(OutboundMessage( + channel=channel_name, + chat_id=chat_id, + content="", + metadata={**job.payload.channel_meta, "_turn_end": True}, + )) return response if job.payload.deliver and job.payload.to and response: @@ -771,16 +823,26 @@ def _run_gateway( response, reminder_note, agent.provider, agent.model, ) if should_notify: + meta = dict(job.payload.channel_meta) + if wants_stream: + meta["_streamed"] = True await _deliver_to_channel( OutboundMessage( - channel=job.payload.channel or "cli", - chat_id=job.payload.to, + channel=channel_name, + chat_id=chat_id, content=response, - metadata=dict(job.payload.channel_meta), + metadata=meta, ), record=True, session_key=job.payload.session_key, ) + if wants_stream: + await bus.publish_outbound(OutboundMessage( + channel=channel_name, + chat_id=chat_id, + content="", + metadata={**job.payload.channel_meta, "_turn_end": True}, + )) return response cron.on_job = on_cron_job diff --git a/tests/cli/test_commands.py b/tests/cli/test_commands.py index b0c3c43ee..b1f4ae81b 100644 --- a/tests/cli/test_commands.py +++ b/tests/cli/test_commands.py @@ -1340,6 +1340,139 @@ def test_gateway_cron_job_suppresses_intermediate_progress( bus.publish_outbound.assert_not_awaited() +def test_gateway_cron_job_streams_when_channel_supports_it( + monkeypatch, tmp_path: Path +) -> None: + """Cron jobs on streaming channels must emit deltas with stream_id and turn_end.""" + config_file = tmp_path / "instance" / "config.json" + config_file.parent.mkdir(parents=True) + config_file.write_text("{}") + + config = Config() + config.agents.defaults.workspace = str(tmp_path / "config-workspace") + bus = MagicMock() + bus.publish_outbound = AsyncMock() + seen: dict[str, object] = {} + + monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) + monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) + monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) + monkeypatch.setattr("nanobot.providers.factory.make_provider", lambda _config: _fake_provider()) + monkeypatch.setattr( + "nanobot.providers.factory.build_provider_snapshot", + lambda _config: _test_provider_snapshot(object(), _config), + ) + monkeypatch.setattr( + "nanobot.providers.factory.load_provider_snapshot", + lambda _config_path=None: _test_provider_snapshot(object(), config), + ) + monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus) + monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object()) + + class _FakeStreamingChannel: + supports_streaming = True + + class _FakeChannelManager: + def __init__(self, *_args, **_kwargs) -> None: + self.channels = {"websocket": _FakeStreamingChannel()} + self.enabled_channels = ["websocket"] + + async def start_all(self): + pass + + async def stop_all(self): + pass + + class _FakeCron: + def __init__(self, _store_path: Path) -> None: + self.on_job = None + seen["cron"] = self + + def status(self): + return {"enabled": True, "jobs": 0, "next_wake_at_ms": None} + + def register_system_job(self, job): + pass + + def stop(self): + pass + + class _FakeAgentLoop: + @classmethod + def from_config(cls, config, bus=None, **extra): + return cls(**extra) + def __init__(self, *args, **kwargs) -> None: + self.model = "test-model" + self.provider = object() + self.tools = {} + self.dream = MagicMock() + self.sessions = MagicMock() + + async def process_direct(self, *_args, on_stream=None, on_stream_end=None, **_kwargs): + seen["on_stream"] = on_stream + seen["on_stream_end"] = on_stream_end + if on_stream: + await on_stream("Hello") + await on_stream(" world") + if on_stream_end: + await on_stream_end(resuming=False) + return OutboundMessage( + channel="websocket", + chat_id="user-1", + content="Hello world", + ) + + async def close_mcp(self) -> None: + return None + + async def run(self) -> None: + return None + + def stop(self) -> None: + return None + + monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron) + monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop) + monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _FakeChannelManager) + + result = runner.invoke(app, ["gateway", "--config", str(config_file)]) + assert result.exit_code == 0 + + cron = seen["cron"] + job = CronJob( + id="cron-stream-test", + name="test-stream", + payload=CronPayload( + message="Say hello.", + deliver=True, + channel="websocket", + to="user-1", + ), + ) + response = asyncio.run(cron.on_job(job)) + + assert response == "Hello world" + assert seen["on_stream"] is not None + assert seen["on_stream_end"] is not None + + calls = bus.publish_outbound.await_args_list + # First two calls are streaming deltas + assert calls[0].args[0].metadata.get("_stream_delta") is True + assert calls[0].args[0].metadata.get("_stream_id") is not None + assert calls[0].args[0].content == "Hello" + assert calls[1].args[0].metadata.get("_stream_delta") is True + assert calls[1].args[0].metadata.get("_stream_id") == calls[0].args[0].metadata["_stream_id"] + assert calls[1].args[0].content == " world" + # Third call is stream_end + assert calls[2].args[0].metadata.get("_stream_end") is True + assert calls[2].args[0].metadata.get("_stream_id") == calls[0].args[0].metadata["_stream_id"] + # Fourth call is the final message with _streamed marker + assert calls[3].args[0].metadata.get("_streamed") is True + assert calls[3].args[0].content == "Hello world" + # Fifth call is turn_end + assert calls[4].args[0].metadata.get("_turn_end") is True + + def test_gateway_workspace_override_does_not_migrate_legacy_cron( monkeypatch, tmp_path: Path ) -> None: