fix(cron): stream cron reminders with stream_id and turn_end

This commit is contained in:
chengyongru 2026-05-09 18:30:40 +08:00
parent de13e72e15
commit ca17292768
2 changed files with 200 additions and 5 deletions

View File

@ -5,6 +5,7 @@ import os
import select import select
import signal import signal
import sys import sys
import time
from collections.abc import Callable from collections.abc import Callable
from contextlib import nullcontext, suppress from contextlib import nullcontext, suppress
from pathlib import Path from pathlib import Path
@ -747,13 +748,57 @@ def _run_gateway(
if isinstance(message_tool, MessageTool): if isinstance(message_tool, MessageTool):
message_record_token = message_tool.set_record_channel_delivery(True) message_record_token = message_tool.set_record_channel_delivery(True)
channel_name = job.payload.channel or "cli"
chat_id = job.payload.to or "direct"
try:
target_channel = channels.channels.get(channel_name)
except NameError:
target_channel = None
wants_stream = target_channel is not None and target_channel.supports_streaming
stream_base_id = None
stream_segment = 0
def _current_stream_id() -> str:
return f"{stream_base_id}:{stream_segment}"
async def _on_stream(delta: str) -> None:
meta = dict(job.payload.channel_meta)
meta["_stream_delta"] = True
meta["_stream_id"] = _current_stream_id()
await bus.publish_outbound(OutboundMessage(
channel=channel_name,
chat_id=chat_id,
content=delta,
metadata=meta,
))
async def _on_stream_end(*, resuming: bool = False) -> None:
nonlocal stream_segment
meta = dict(job.payload.channel_meta)
meta["_stream_end"] = True
meta["_resuming"] = resuming
meta["_stream_id"] = _current_stream_id()
await bus.publish_outbound(OutboundMessage(
channel=channel_name,
chat_id=chat_id,
content="",
metadata=meta,
))
stream_segment += 1
if wants_stream:
stream_base_id = f"cron:{job.id}:{time.time_ns()}"
try: try:
resp = await agent.process_direct( resp = await agent.process_direct(
reminder_note, reminder_note,
session_key=f"cron:{job.id}", session_key=f"cron:{job.id}",
channel=job.payload.channel or "cli", channel=channel_name,
chat_id=job.payload.to or "direct", chat_id=chat_id,
on_progress=_silent, on_progress=_silent,
on_stream=_on_stream if wants_stream else None,
on_stream_end=_on_stream_end if wants_stream else None,
) )
finally: finally:
if isinstance(cron_tool, CronTool) and cron_token is not None: if isinstance(cron_tool, CronTool) and cron_token is not None:
@ -764,6 +809,13 @@ def _run_gateway(
response = resp.content if resp else "" response = resp.content if resp else ""
if job.payload.deliver and isinstance(message_tool, MessageTool) and message_tool._sent_in_turn: if job.payload.deliver and isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
if wants_stream:
await bus.publish_outbound(OutboundMessage(
channel=channel_name,
chat_id=chat_id,
content="",
metadata={**job.payload.channel_meta, "_turn_end": True},
))
return response return response
if job.payload.deliver and job.payload.to and response: if job.payload.deliver and job.payload.to and response:
@ -771,16 +823,26 @@ def _run_gateway(
response, reminder_note, agent.provider, agent.model, response, reminder_note, agent.provider, agent.model,
) )
if should_notify: if should_notify:
meta = dict(job.payload.channel_meta)
if wants_stream:
meta["_streamed"] = True
await _deliver_to_channel( await _deliver_to_channel(
OutboundMessage( OutboundMessage(
channel=job.payload.channel or "cli", channel=channel_name,
chat_id=job.payload.to, chat_id=chat_id,
content=response, content=response,
metadata=dict(job.payload.channel_meta), metadata=meta,
), ),
record=True, record=True,
session_key=job.payload.session_key, session_key=job.payload.session_key,
) )
if wants_stream:
await bus.publish_outbound(OutboundMessage(
channel=channel_name,
chat_id=chat_id,
content="",
metadata={**job.payload.channel_meta, "_turn_end": True},
))
return response return response
cron.on_job = on_cron_job cron.on_job = on_cron_job

View File

@ -1340,6 +1340,139 @@ def test_gateway_cron_job_suppresses_intermediate_progress(
bus.publish_outbound.assert_not_awaited() bus.publish_outbound.assert_not_awaited()
def test_gateway_cron_job_streams_when_channel_supports_it(
monkeypatch, tmp_path: Path
) -> None:
"""Cron jobs on streaming channels must emit deltas with stream_id and turn_end."""
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
bus = MagicMock()
bus.publish_outbound = AsyncMock()
seen: dict[str, object] = {}
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr("nanobot.providers.factory.make_provider", lambda _config: _fake_provider())
monkeypatch.setattr(
"nanobot.providers.factory.build_provider_snapshot",
lambda _config: _test_provider_snapshot(object(), _config),
)
monkeypatch.setattr(
"nanobot.providers.factory.load_provider_snapshot",
lambda _config_path=None: _test_provider_snapshot(object(), config),
)
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus)
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
class _FakeStreamingChannel:
supports_streaming = True
class _FakeChannelManager:
def __init__(self, *_args, **_kwargs) -> None:
self.channels = {"websocket": _FakeStreamingChannel()}
self.enabled_channels = ["websocket"]
async def start_all(self):
pass
async def stop_all(self):
pass
class _FakeCron:
def __init__(self, _store_path: Path) -> None:
self.on_job = None
seen["cron"] = self
def status(self):
return {"enabled": True, "jobs": 0, "next_wake_at_ms": None}
def register_system_job(self, job):
pass
def stop(self):
pass
class _FakeAgentLoop:
@classmethod
def from_config(cls, config, bus=None, **extra):
return cls(**extra)
def __init__(self, *args, **kwargs) -> None:
self.model = "test-model"
self.provider = object()
self.tools = {}
self.dream = MagicMock()
self.sessions = MagicMock()
async def process_direct(self, *_args, on_stream=None, on_stream_end=None, **_kwargs):
seen["on_stream"] = on_stream
seen["on_stream_end"] = on_stream_end
if on_stream:
await on_stream("Hello")
await on_stream(" world")
if on_stream_end:
await on_stream_end(resuming=False)
return OutboundMessage(
channel="websocket",
chat_id="user-1",
content="Hello world",
)
async def close_mcp(self) -> None:
return None
async def run(self) -> None:
return None
def stop(self) -> None:
return None
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop)
monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _FakeChannelManager)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert result.exit_code == 0
cron = seen["cron"]
job = CronJob(
id="cron-stream-test",
name="test-stream",
payload=CronPayload(
message="Say hello.",
deliver=True,
channel="websocket",
to="user-1",
),
)
response = asyncio.run(cron.on_job(job))
assert response == "Hello world"
assert seen["on_stream"] is not None
assert seen["on_stream_end"] is not None
calls = bus.publish_outbound.await_args_list
# First two calls are streaming deltas
assert calls[0].args[0].metadata.get("_stream_delta") is True
assert calls[0].args[0].metadata.get("_stream_id") is not None
assert calls[0].args[0].content == "Hello"
assert calls[1].args[0].metadata.get("_stream_delta") is True
assert calls[1].args[0].metadata.get("_stream_id") == calls[0].args[0].metadata["_stream_id"]
assert calls[1].args[0].content == " world"
# Third call is stream_end
assert calls[2].args[0].metadata.get("_stream_end") is True
assert calls[2].args[0].metadata.get("_stream_id") == calls[0].args[0].metadata["_stream_id"]
# Fourth call is the final message with _streamed marker
assert calls[3].args[0].metadata.get("_streamed") is True
assert calls[3].args[0].content == "Hello world"
# Fifth call is turn_end
assert calls[4].args[0].metadata.get("_turn_end") is True
def test_gateway_workspace_override_does_not_migrate_legacy_cron( def test_gateway_workspace_override_does_not_migrate_legacy_cron(
monkeypatch, tmp_path: Path monkeypatch, tmp_path: Path
) -> None: ) -> None: