mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 08:02:30 +00:00
fix(cron): stream cron reminders with stream_id and turn_end
This commit is contained in:
parent
de13e72e15
commit
ca17292768
@ -5,6 +5,7 @@ import os
|
||||
import select
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from contextlib import nullcontext, suppress
|
||||
from pathlib import Path
|
||||
@ -747,13 +748,57 @@ def _run_gateway(
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_record_token = message_tool.set_record_channel_delivery(True)
|
||||
|
||||
channel_name = job.payload.channel or "cli"
|
||||
chat_id = job.payload.to or "direct"
|
||||
try:
|
||||
target_channel = channels.channels.get(channel_name)
|
||||
except NameError:
|
||||
target_channel = None
|
||||
wants_stream = target_channel is not None and target_channel.supports_streaming
|
||||
|
||||
stream_base_id = None
|
||||
stream_segment = 0
|
||||
|
||||
def _current_stream_id() -> str:
|
||||
return f"{stream_base_id}:{stream_segment}"
|
||||
|
||||
async def _on_stream(delta: str) -> None:
|
||||
meta = dict(job.payload.channel_meta)
|
||||
meta["_stream_delta"] = True
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel=channel_name,
|
||||
chat_id=chat_id,
|
||||
content=delta,
|
||||
metadata=meta,
|
||||
))
|
||||
|
||||
async def _on_stream_end(*, resuming: bool = False) -> None:
|
||||
nonlocal stream_segment
|
||||
meta = dict(job.payload.channel_meta)
|
||||
meta["_stream_end"] = True
|
||||
meta["_resuming"] = resuming
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel=channel_name,
|
||||
chat_id=chat_id,
|
||||
content="",
|
||||
metadata=meta,
|
||||
))
|
||||
stream_segment += 1
|
||||
|
||||
if wants_stream:
|
||||
stream_base_id = f"cron:{job.id}:{time.time_ns()}"
|
||||
|
||||
try:
|
||||
resp = await agent.process_direct(
|
||||
reminder_note,
|
||||
session_key=f"cron:{job.id}",
|
||||
channel=job.payload.channel or "cli",
|
||||
chat_id=job.payload.to or "direct",
|
||||
channel=channel_name,
|
||||
chat_id=chat_id,
|
||||
on_progress=_silent,
|
||||
on_stream=_on_stream if wants_stream else None,
|
||||
on_stream_end=_on_stream_end if wants_stream else None,
|
||||
)
|
||||
finally:
|
||||
if isinstance(cron_tool, CronTool) and cron_token is not None:
|
||||
@ -764,6 +809,13 @@ def _run_gateway(
|
||||
response = resp.content if resp else ""
|
||||
|
||||
if job.payload.deliver and isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
|
||||
if wants_stream:
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel=channel_name,
|
||||
chat_id=chat_id,
|
||||
content="",
|
||||
metadata={**job.payload.channel_meta, "_turn_end": True},
|
||||
))
|
||||
return response
|
||||
|
||||
if job.payload.deliver and job.payload.to and response:
|
||||
@ -771,16 +823,26 @@ def _run_gateway(
|
||||
response, reminder_note, agent.provider, agent.model,
|
||||
)
|
||||
if should_notify:
|
||||
meta = dict(job.payload.channel_meta)
|
||||
if wants_stream:
|
||||
meta["_streamed"] = True
|
||||
await _deliver_to_channel(
|
||||
OutboundMessage(
|
||||
channel=job.payload.channel or "cli",
|
||||
chat_id=job.payload.to,
|
||||
channel=channel_name,
|
||||
chat_id=chat_id,
|
||||
content=response,
|
||||
metadata=dict(job.payload.channel_meta),
|
||||
metadata=meta,
|
||||
),
|
||||
record=True,
|
||||
session_key=job.payload.session_key,
|
||||
)
|
||||
if wants_stream:
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel=channel_name,
|
||||
chat_id=chat_id,
|
||||
content="",
|
||||
metadata={**job.payload.channel_meta, "_turn_end": True},
|
||||
))
|
||||
return response
|
||||
|
||||
cron.on_job = on_cron_job
|
||||
|
||||
@ -1340,6 +1340,139 @@ def test_gateway_cron_job_suppresses_intermediate_progress(
|
||||
bus.publish_outbound.assert_not_awaited()
|
||||
|
||||
|
||||
def test_gateway_cron_job_streams_when_channel_supports_it(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
"""Cron jobs on streaming channels must emit deltas with stream_id and turn_end."""
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
bus = MagicMock()
|
||||
bus.publish_outbound = AsyncMock()
|
||||
seen: dict[str, object] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.providers.factory.make_provider", lambda _config: _fake_provider())
|
||||
monkeypatch.setattr(
|
||||
"nanobot.providers.factory.build_provider_snapshot",
|
||||
lambda _config: _test_provider_snapshot(object(), _config),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.providers.factory.load_provider_snapshot",
|
||||
lambda _config_path=None: _test_provider_snapshot(object(), config),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus)
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
||||
|
||||
class _FakeStreamingChannel:
|
||||
supports_streaming = True
|
||||
|
||||
class _FakeChannelManager:
|
||||
def __init__(self, *_args, **_kwargs) -> None:
|
||||
self.channels = {"websocket": _FakeStreamingChannel()}
|
||||
self.enabled_channels = ["websocket"]
|
||||
|
||||
async def start_all(self):
|
||||
pass
|
||||
|
||||
async def stop_all(self):
|
||||
pass
|
||||
|
||||
class _FakeCron:
|
||||
def __init__(self, _store_path: Path) -> None:
|
||||
self.on_job = None
|
||||
seen["cron"] = self
|
||||
|
||||
def status(self):
|
||||
return {"enabled": True, "jobs": 0, "next_wake_at_ms": None}
|
||||
|
||||
def register_system_job(self, job):
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
pass
|
||||
|
||||
class _FakeAgentLoop:
|
||||
@classmethod
|
||||
def from_config(cls, config, bus=None, **extra):
|
||||
return cls(**extra)
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
self.model = "test-model"
|
||||
self.provider = object()
|
||||
self.tools = {}
|
||||
self.dream = MagicMock()
|
||||
self.sessions = MagicMock()
|
||||
|
||||
async def process_direct(self, *_args, on_stream=None, on_stream_end=None, **_kwargs):
|
||||
seen["on_stream"] = on_stream
|
||||
seen["on_stream_end"] = on_stream_end
|
||||
if on_stream:
|
||||
await on_stream("Hello")
|
||||
await on_stream(" world")
|
||||
if on_stream_end:
|
||||
await on_stream_end(resuming=False)
|
||||
return OutboundMessage(
|
||||
channel="websocket",
|
||||
chat_id="user-1",
|
||||
content="Hello world",
|
||||
)
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
return None
|
||||
|
||||
async def run(self) -> None:
|
||||
return None
|
||||
|
||||
def stop(self) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
|
||||
monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _FakeChannelManager)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
assert result.exit_code == 0
|
||||
|
||||
cron = seen["cron"]
|
||||
job = CronJob(
|
||||
id="cron-stream-test",
|
||||
name="test-stream",
|
||||
payload=CronPayload(
|
||||
message="Say hello.",
|
||||
deliver=True,
|
||||
channel="websocket",
|
||||
to="user-1",
|
||||
),
|
||||
)
|
||||
response = asyncio.run(cron.on_job(job))
|
||||
|
||||
assert response == "Hello world"
|
||||
assert seen["on_stream"] is not None
|
||||
assert seen["on_stream_end"] is not None
|
||||
|
||||
calls = bus.publish_outbound.await_args_list
|
||||
# First two calls are streaming deltas
|
||||
assert calls[0].args[0].metadata.get("_stream_delta") is True
|
||||
assert calls[0].args[0].metadata.get("_stream_id") is not None
|
||||
assert calls[0].args[0].content == "Hello"
|
||||
assert calls[1].args[0].metadata.get("_stream_delta") is True
|
||||
assert calls[1].args[0].metadata.get("_stream_id") == calls[0].args[0].metadata["_stream_id"]
|
||||
assert calls[1].args[0].content == " world"
|
||||
# Third call is stream_end
|
||||
assert calls[2].args[0].metadata.get("_stream_end") is True
|
||||
assert calls[2].args[0].metadata.get("_stream_id") == calls[0].args[0].metadata["_stream_id"]
|
||||
# Fourth call is the final message with _streamed marker
|
||||
assert calls[3].args[0].metadata.get("_streamed") is True
|
||||
assert calls[3].args[0].content == "Hello world"
|
||||
# Fifth call is turn_end
|
||||
assert calls[4].args[0].metadata.get("_turn_end") is True
|
||||
|
||||
|
||||
def test_gateway_workspace_override_does_not_migrate_legacy_cron(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user