fix(cron): stream cron reminders with stream_id and turn_end

This commit is contained in:
chengyongru 2026-05-09 18:30:40 +08:00
parent de13e72e15
commit ca17292768
2 changed files with 200 additions and 5 deletions

View File

@ -5,6 +5,7 @@ import os
import select
import signal
import sys
import time
from collections.abc import Callable
from contextlib import nullcontext, suppress
from pathlib import Path
@ -747,13 +748,57 @@ def _run_gateway(
if isinstance(message_tool, MessageTool):
message_record_token = message_tool.set_record_channel_delivery(True)
channel_name = job.payload.channel or "cli"
chat_id = job.payload.to or "direct"
try:
target_channel = channels.channels.get(channel_name)
except NameError:
target_channel = None
wants_stream = target_channel is not None and target_channel.supports_streaming
stream_base_id = None
stream_segment = 0
def _current_stream_id() -> str:
return f"{stream_base_id}:{stream_segment}"
async def _on_stream(delta: str) -> None:
meta = dict(job.payload.channel_meta)
meta["_stream_delta"] = True
meta["_stream_id"] = _current_stream_id()
await bus.publish_outbound(OutboundMessage(
channel=channel_name,
chat_id=chat_id,
content=delta,
metadata=meta,
))
async def _on_stream_end(*, resuming: bool = False) -> None:
nonlocal stream_segment
meta = dict(job.payload.channel_meta)
meta["_stream_end"] = True
meta["_resuming"] = resuming
meta["_stream_id"] = _current_stream_id()
await bus.publish_outbound(OutboundMessage(
channel=channel_name,
chat_id=chat_id,
content="",
metadata=meta,
))
stream_segment += 1
if wants_stream:
stream_base_id = f"cron:{job.id}:{time.time_ns()}"
try:
resp = await agent.process_direct(
reminder_note,
session_key=f"cron:{job.id}",
channel=job.payload.channel or "cli",
chat_id=job.payload.to or "direct",
channel=channel_name,
chat_id=chat_id,
on_progress=_silent,
on_stream=_on_stream if wants_stream else None,
on_stream_end=_on_stream_end if wants_stream else None,
)
finally:
if isinstance(cron_tool, CronTool) and cron_token is not None:
@ -764,6 +809,13 @@ def _run_gateway(
response = resp.content if resp else ""
if job.payload.deliver and isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
if wants_stream:
await bus.publish_outbound(OutboundMessage(
channel=channel_name,
chat_id=chat_id,
content="",
metadata={**job.payload.channel_meta, "_turn_end": True},
))
return response
if job.payload.deliver and job.payload.to and response:
@ -771,16 +823,26 @@ def _run_gateway(
response, reminder_note, agent.provider, agent.model,
)
if should_notify:
meta = dict(job.payload.channel_meta)
if wants_stream:
meta["_streamed"] = True
await _deliver_to_channel(
OutboundMessage(
channel=job.payload.channel or "cli",
chat_id=job.payload.to,
channel=channel_name,
chat_id=chat_id,
content=response,
metadata=dict(job.payload.channel_meta),
metadata=meta,
),
record=True,
session_key=job.payload.session_key,
)
if wants_stream:
await bus.publish_outbound(OutboundMessage(
channel=channel_name,
chat_id=chat_id,
content="",
metadata={**job.payload.channel_meta, "_turn_end": True},
))
return response
cron.on_job = on_cron_job

View File

@ -1340,6 +1340,139 @@ def test_gateway_cron_job_suppresses_intermediate_progress(
bus.publish_outbound.assert_not_awaited()
def test_gateway_cron_job_streams_when_channel_supports_it(
monkeypatch, tmp_path: Path
) -> None:
"""Cron jobs on streaming channels must emit deltas with stream_id and turn_end."""
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
bus = MagicMock()
bus.publish_outbound = AsyncMock()
seen: dict[str, object] = {}
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr("nanobot.providers.factory.make_provider", lambda _config: _fake_provider())
monkeypatch.setattr(
"nanobot.providers.factory.build_provider_snapshot",
lambda _config: _test_provider_snapshot(object(), _config),
)
monkeypatch.setattr(
"nanobot.providers.factory.load_provider_snapshot",
lambda _config_path=None: _test_provider_snapshot(object(), config),
)
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus)
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
class _FakeStreamingChannel:
supports_streaming = True
class _FakeChannelManager:
def __init__(self, *_args, **_kwargs) -> None:
self.channels = {"websocket": _FakeStreamingChannel()}
self.enabled_channels = ["websocket"]
async def start_all(self):
pass
async def stop_all(self):
pass
class _FakeCron:
def __init__(self, _store_path: Path) -> None:
self.on_job = None
seen["cron"] = self
def status(self):
return {"enabled": True, "jobs": 0, "next_wake_at_ms": None}
def register_system_job(self, job):
pass
def stop(self):
pass
class _FakeAgentLoop:
@classmethod
def from_config(cls, config, bus=None, **extra):
return cls(**extra)
def __init__(self, *args, **kwargs) -> None:
self.model = "test-model"
self.provider = object()
self.tools = {}
self.dream = MagicMock()
self.sessions = MagicMock()
async def process_direct(self, *_args, on_stream=None, on_stream_end=None, **_kwargs):
seen["on_stream"] = on_stream
seen["on_stream_end"] = on_stream_end
if on_stream:
await on_stream("Hello")
await on_stream(" world")
if on_stream_end:
await on_stream_end(resuming=False)
return OutboundMessage(
channel="websocket",
chat_id="user-1",
content="Hello world",
)
async def close_mcp(self) -> None:
return None
async def run(self) -> None:
return None
def stop(self) -> None:
return None
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop)
monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _FakeChannelManager)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert result.exit_code == 0
cron = seen["cron"]
job = CronJob(
id="cron-stream-test",
name="test-stream",
payload=CronPayload(
message="Say hello.",
deliver=True,
channel="websocket",
to="user-1",
),
)
response = asyncio.run(cron.on_job(job))
assert response == "Hello world"
assert seen["on_stream"] is not None
assert seen["on_stream_end"] is not None
calls = bus.publish_outbound.await_args_list
# First two calls are streaming deltas
assert calls[0].args[0].metadata.get("_stream_delta") is True
assert calls[0].args[0].metadata.get("_stream_id") is not None
assert calls[0].args[0].content == "Hello"
assert calls[1].args[0].metadata.get("_stream_delta") is True
assert calls[1].args[0].metadata.get("_stream_id") == calls[0].args[0].metadata["_stream_id"]
assert calls[1].args[0].content == " world"
# Third call is stream_end
assert calls[2].args[0].metadata.get("_stream_end") is True
assert calls[2].args[0].metadata.get("_stream_id") == calls[0].args[0].metadata["_stream_id"]
# Fourth call is the final message with _streamed marker
assert calls[3].args[0].metadata.get("_streamed") is True
assert calls[3].args[0].content == "Hello world"
# Fifth call is turn_end
assert calls[4].args[0].metadata.get("_turn_end") is True
def test_gateway_workspace_override_does_not_migrate_legacy_cron(
monkeypatch, tmp_path: Path
) -> None: