refactor: preserve origin session routing for cron

This commit is contained in:
chengyongru 2026-06-12 14:21:09 +08:00
parent bc18142650
commit c4b64a4caf
10 changed files with 179 additions and 42 deletions

View File

@ -15,6 +15,7 @@ from nanobot.agent.tools.schema import (
) )
from nanobot.cron.service import CronService from nanobot.cron.service import CronService
from nanobot.cron.types import CronJob, CronJobState, CronSchedule from nanobot.cron.types import CronJob, CronJobState, CronSchedule
from nanobot.session.keys import UNIFIED_SESSION_KEY
_CRON_PARAMETERS = tool_parameters_schema( _CRON_PARAMETERS = tool_parameters_schema(
action=StringSchema("Action to perform", enum=["add", "list", "remove"]), action=StringSchema("Action to perform", enum=["add", "list", "remove"]),
@ -56,9 +57,6 @@ class CronTool(Tool, ContextAware):
def __init__(self, cron_service: CronService, default_timezone: str = "UTC"): def __init__(self, cron_service: CronService, default_timezone: str = "UTC"):
self._cron = cron_service self._cron = cron_service
self._default_timezone = default_timezone self._default_timezone = default_timezone
self._channel: ContextVar[str] = ContextVar("cron_channel", default="")
self._chat_id: ContextVar[str] = ContextVar("cron_chat_id", default="")
self._metadata: ContextVar[dict] = ContextVar("cron_metadata", default={})
self._session_key: ContextVar[str] = ContextVar("cron_session_key", default="") self._session_key: ContextVar[str] = ContextVar("cron_session_key", default="")
self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False) self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False)
@ -72,10 +70,10 @@ class CronTool(Tool, ContextAware):
def set_context(self, ctx: RequestContext) -> None: def set_context(self, ctx: RequestContext) -> None:
"""Set the current session context for scheduled cron job ownership.""" """Set the current session context for scheduled cron job ownership."""
self._channel.set(ctx.channel) raw_key = f"{ctx.channel}:{ctx.chat_id}" if ctx.channel and ctx.chat_id else ""
self._chat_id.set(ctx.chat_id) self._session_key.set(
self._metadata.set(ctx.metadata) raw_key if ctx.session_key in {None, "", UNIFIED_SESSION_KEY} else ctx.session_key
self._session_key.set(f"{ctx.channel}:{ctx.chat_id}" if ctx.channel and ctx.chat_id else "") )
def set_cron_context(self, active: bool): def set_cron_context(self, active: bool):
"""Mark whether the tool is executing inside a cron job callback.""" """Mark whether the tool is executing inside a cron job callback."""

View File

@ -125,7 +125,6 @@ class ChannelManager:
runtime_model_name=self._webui_runtime_model_name, runtime_model_name=self._webui_runtime_model_name,
runtime_surface=self._webui_runtime_surface, runtime_surface=self._webui_runtime_surface,
runtime_capabilities_overrides=self._webui_runtime_capabilities, runtime_capabilities_overrides=self._webui_runtime_capabilities,
unified_session=self.config.agents.defaults.unified_session,
cron_service=self._cron_service, cron_service=self._cron_service,
logger=logger, logger=logger,
) )

View File

@ -980,6 +980,7 @@ def _run_gateway(
from nanobot.bus.runtime_events import RuntimeEventBus from nanobot.bus.runtime_events import RuntimeEventBus
from nanobot.channels.manager import ChannelManager from nanobot.channels.manager import ChannelManager
from nanobot.cron.service import CronService from nanobot.cron.service import CronService
from nanobot.cron.session_delivery import bound_session_inbound_context
from nanobot.cron.session_turns import ( from nanobot.cron.session_turns import (
CRON_DEFER_UNTIL_IDLE_META, CRON_DEFER_UNTIL_IDLE_META,
CRON_TRIGGER_META, CRON_TRIGGER_META,
@ -1050,13 +1051,7 @@ def _run_gateway(
turn_seed: str, turn_seed: str,
source_label: str | None, source_label: str | None,
) -> tuple[str, str, dict[str, Any]]: ) -> tuple[str, str, dict[str, Any]]:
if ":" not in session_key: channel, chat_id, metadata = bound_session_inbound_context(session_key)
raise ValueError(f"bound cron session_key is invalid: {session_key!r}")
channel, rest = session_key.split(":", 1)
if not channel or not rest:
raise ValueError(f"bound cron session_key is invalid: {session_key!r}")
metadata: dict[str, Any] = {}
if channel == "websocket": if channel == "websocket":
metadata["webui"] = True metadata["webui"] = True
@ -1068,15 +1063,8 @@ def _run_gateway(
source_label=source_label, source_label=source_label,
) )
) )
return channel, rest, metadata
if channel == "slack" and ":" in rest: return channel, chat_id, metadata
chat_id, thread_ts = rest.split(":", 1)
if thread_ts:
metadata["slack"] = {"thread_ts": thread_ts}
return channel, chat_id, metadata
return channel, rest, metadata
def _cron_prompt_ref(prompt: str) -> dict[str, Any]: def _cron_prompt_ref(prompt: str) -> dict[str, Any]:
return { return {
@ -1141,6 +1129,7 @@ def _run_gateway(
chat_id=chat_id, chat_id=chat_id,
content=prompt, content=prompt,
metadata=metadata, metadata=metadata,
session_key_override=session_key,
) )
) )
except (Exception, asyncio.CancelledError) as exc: except (Exception, asyncio.CancelledError) as exc:

View File

@ -0,0 +1,57 @@
"""Helpers for routing bound cron turns back through their origin session."""
from __future__ import annotations
from typing import Any
def bound_session_inbound_context(session_key: str) -> tuple[str, str, dict[str, Any]]:
"""Return ``(channel, chat_id, metadata)`` for a bound cron session key."""
if ":" not in session_key:
raise ValueError(f"bound cron session_key is invalid: {session_key!r}")
channel, rest = session_key.split(":", 1)
if not channel or not rest:
raise ValueError(f"bound cron session_key is invalid: {session_key!r}")
metadata: dict[str, Any] = {}
if channel == "discord" and ":thread:" in rest:
parent_channel_id, thread_id = rest.split(":thread:", 1)
if parent_channel_id and thread_id:
metadata.update({
"context_chat_id": parent_channel_id,
"parent_channel_id": parent_channel_id,
"thread_id": thread_id,
})
return channel, thread_id, metadata
if channel == "feishu" and ":" in rest:
chat_id, thread_id = rest.split(":", 1)
if chat_id and thread_id:
metadata.update({
"chat_type": "group",
"message_id": thread_id,
"thread_id": thread_id,
})
return channel, chat_id, metadata
if channel == "slack" and ":" in rest:
chat_id, thread_ts = rest.split(":", 1)
if thread_ts:
metadata["slack"] = {"thread_ts": thread_ts}
return channel, chat_id, metadata
if channel == "telegram" and ":topic:" in rest:
chat_id, thread_id = rest.split(":topic:", 1)
if thread_id:
metadata["message_thread_id"] = (
int(thread_id) if thread_id.isdigit() else thread_id
)
return channel, chat_id, metadata
if channel == "dingtalk" and rest.startswith("group:"):
parts = rest.split(":", 2)
if len(parts) >= 2 and parts[1]:
return channel, f"group:{parts[1]}", metadata
return channel, rest, metadata

View File

@ -39,7 +39,6 @@ def build_gateway_services(
runtime_model_name: Any | None, runtime_model_name: Any | None,
runtime_surface: str, runtime_surface: str,
runtime_capabilities_overrides: dict[str, Any] | None, runtime_capabilities_overrides: dict[str, Any] | None,
unified_session: bool = False,
disabled_skills: set[str] | None = None, disabled_skills: set[str] | None = None,
cron_service: Any | None = None, cron_service: Any | None = None,
logger: Any = default_logger, logger: Any = default_logger,
@ -62,7 +61,6 @@ def build_gateway_services(
runtime_model_name=runtime_model_name, runtime_model_name=runtime_model_name,
runtime_surface=runtime_surface, runtime_surface=runtime_surface,
runtime_capabilities_overrides=runtime_capabilities_overrides, runtime_capabilities_overrides=runtime_capabilities_overrides,
unified_session=unified_session,
bus=bus, bus=bus,
tokens=tokens, tokens=tokens,
media=media, media=media,

View File

@ -139,7 +139,6 @@ class GatewayHTTPHandler:
runtime_model_name: Callable[[], str | None] | None, runtime_model_name: Callable[[], str | None] | None,
runtime_surface: str, runtime_surface: str,
runtime_capabilities_overrides: dict[str, Any] | None, runtime_capabilities_overrides: dict[str, Any] | None,
unified_session: bool = False,
bus: MessageBus, bus: MessageBus,
tokens: GatewayTokenStore, tokens: GatewayTokenStore,
media: WebUIMediaGateway, media: WebUIMediaGateway,
@ -162,7 +161,6 @@ class GatewayHTTPHandler:
self.cron_service = cron_service self.cron_service = cron_service
self._log = log self._log = log
self._runtime_surface = runtime_surface self._runtime_surface = runtime_surface
self._unified_session = unified_session
from nanobot.webui.settings_api import runtime_capabilities as _rc from nanobot.webui.settings_api import runtime_capabilities as _rc
from nanobot.webui.settings_routes import WebUISettingsRouter from nanobot.webui.settings_routes import WebUISettingsRouter
@ -439,7 +437,7 @@ class GatewayHTTPHandler:
if not _is_websocket_channel_session_key(decoded_key): if not _is_websocket_channel_session_key(decoded_key):
return _http_error(404, "session not found") return _http_error(404, "session not found")
return _http_json_response( return _http_json_response(
session_automations_payload(self.cron_service, self._automation_display_key(decoded_key)) session_automations_payload(self.cron_service, decoded_key)
) )
def _handle_session_delete(self, request: WsRequest, key: str) -> Response: def _handle_session_delete(self, request: WsRequest, key: str) -> Response:
@ -470,10 +468,6 @@ class GatewayHTTPHandler:
delete_webui_thread(decoded_key) delete_webui_thread(decoded_key)
return _http_json_response({"deleted": bool(deleted)}) return _http_json_response({"deleted": bool(deleted)})
def _automation_display_key(self, session_key: str) -> str:
"""Return the cron ownership key shown for this WebUI thread."""
return session_key
# -- Media routes ------------------------------------------------------- # -- Media routes -------------------------------------------------------
def _dispatch_media_routes(self, request: WsRequest, got: str) -> Response | None: def _dispatch_media_routes(self, request: WsRequest, got: str) -> Response | None:

View File

@ -30,7 +30,6 @@ def _make_handler(
workspace_path: Path | None = None, workspace_path: Path | None = None,
runtime_model_name: Any | None = None, runtime_model_name: Any | None = None,
cron_service: CronService | None = None, cron_service: CronService | None = None,
unified_session: bool = False,
) -> GatewayServices: ) -> GatewayServices:
config = WebSocketConfig.model_validate(cfg) if isinstance(cfg, dict) else cfg config = WebSocketConfig.model_validate(cfg) if isinstance(cfg, dict) else cfg
workspace = workspace_path or Path.cwd() workspace = workspace_path or Path.cwd()
@ -44,7 +43,6 @@ def _make_handler(
runtime_model_name=runtime_model_name, runtime_model_name=runtime_model_name,
runtime_surface="browser", runtime_surface="browser",
runtime_capabilities_overrides=None, runtime_capabilities_overrides=None,
unified_session=unified_session,
cron_service=cron_service, cron_service=cron_service,
) )
@ -58,7 +56,6 @@ def _ch(
port: int = _PORT, port: int = _PORT,
runtime_model_name: Any | None = None, runtime_model_name: Any | None = None,
cron_service: CronService | None = None, cron_service: CronService | None = None,
unified_session: bool = False,
**extra: Any, **extra: Any,
) -> WebSocketChannel: ) -> WebSocketChannel:
cfg: dict[str, Any] = { cfg: dict[str, Any] = {
@ -77,7 +74,6 @@ def _ch(
workspace_path=workspace_path, workspace_path=workspace_path,
runtime_model_name=runtime_model_name, runtime_model_name=runtime_model_name,
cron_service=cron_service, cron_service=cron_service,
unified_session=unified_session,
) )
return WebSocketChannel(cfg, bus, gateway=gateway) return WebSocketChannel(cfg, bus, gateway=gateway)
@ -243,7 +239,7 @@ async def test_session_automations_route_filters_by_webui_session(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_session_automations_route_uses_origin_owner_when_unified_enabled( async def test_session_automations_route_ignores_unified_owner(
bus: MagicMock, tmp_path: Path bus: MagicMock, tmp_path: Path
) -> None: ) -> None:
cron = CronService(tmp_path / "cron" / "jobs.json") cron = CronService(tmp_path / "cron" / "jobs.json")
@ -264,7 +260,6 @@ async def test_session_automations_route_uses_origin_owner_when_unified_enabled(
bus, bus,
session_manager=_seed_session(tmp_path, key="websocket:abc"), session_manager=_seed_session(tmp_path, key="websocket:abc"),
cron_service=cron, cron_service=cron,
unified_session=True,
port=29917, port=29917,
) )
server_task = asyncio.create_task(channel.start()) server_task = asyncio.create_task(channel.start())
@ -823,7 +818,6 @@ async def test_session_delete_blocks_origin_automation_when_unified_enabled(
bus, bus,
session_manager=sm, session_manager=sm,
cron_service=cron, cron_service=cron,
unified_session=True,
port=29918, port=29918,
) )
server_task = asyncio.create_task(channel.start()) server_task = asyncio.create_task(channel.start())

View File

@ -1622,7 +1622,7 @@ def test_gateway_bound_cron_runs_as_session_turn(
assert msg.channel == "websocket" assert msg.channel == "websocket"
assert msg.chat_id == "chat-1" assert msg.chat_id == "chat-1"
assert msg.sender_id == "cron" assert msg.sender_id == "cron"
assert msg.session_key_override is None assert msg.session_key_override == "websocket:chat-1"
assert "Cron job: Check repository health." in msg.content assert "Cron job: Check repository health." in msg.content
assert msg.metadata["webui"] is True assert msg.metadata["webui"] is True
assert msg.metadata[WEBUI_MESSAGE_SOURCE_METADATA_KEY] == { assert msg.metadata[WEBUI_MESSAGE_SOURCE_METADATA_KEY] == {
@ -1645,7 +1645,7 @@ def test_gateway_bound_cron_runs_as_session_turn(
name="Thread check", name="Thread check",
payload=CronPayload( payload=CronPayload(
message="Check the Discord thread.", message="Check the Discord thread.",
session_key="discord:777", session_key="discord:456:thread:777",
), ),
) )
@ -1656,7 +1656,49 @@ def test_gateway_bound_cron_runs_as_session_turn(
assert isinstance(msg, InboundMessage) assert isinstance(msg, InboundMessage)
assert msg.channel == "discord" assert msg.channel == "discord"
assert msg.chat_id == "777" assert msg.chat_id == "777"
assert msg.session_key_override is None assert msg.session_key_override == "discord:456:thread:777"
assert msg.metadata["context_chat_id"] == "456"
assert msg.metadata["parent_channel_id"] == "456"
assert msg.metadata["thread_id"] == "777"
telegram_job = CronJob(
id="telegram-topic",
name="Telegram topic",
payload=CronPayload(
message="Check the Telegram topic.",
session_key="telegram:-100123:topic:42",
),
)
response = asyncio.run(cron.on_job(telegram_job))
assert response == "Checked the repo."
msg = seen["cron_msg"]
assert isinstance(msg, InboundMessage)
assert msg.channel == "telegram"
assert msg.chat_id == "-100123"
assert msg.session_key_override == "telegram:-100123:topic:42"
assert msg.metadata["message_thread_id"] == 42
feishu_job = CronJob(
id="feishu-topic",
name="Feishu topic",
payload=CronPayload(
message="Check the Feishu topic.",
session_key="feishu:oc_abc:om_root123",
),
)
response = asyncio.run(cron.on_job(feishu_job))
assert response == "Checked the repo."
msg = seen["cron_msg"]
assert isinstance(msg, InboundMessage)
assert msg.channel == "feishu"
assert msg.chat_id == "oc_abc"
assert msg.session_key_override == "feishu:oc_abc:om_root123"
assert msg.metadata["message_id"] == "om_root123"
assert msg.metadata["thread_id"] == "om_root123"
def test_gateway_cron_job_suppresses_intermediate_progress( def test_gateway_cron_job_suppresses_intermediate_progress(

View File

@ -0,0 +1,45 @@
import pytest
from nanobot.cron.session_delivery import bound_session_inbound_context
@pytest.mark.parametrize(
("session_key", "expected"),
[
("websocket:chat-1", ("websocket", "chat-1", {})),
(
"discord:456:thread:777",
(
"discord",
"777",
{
"context_chat_id": "456",
"parent_channel_id": "456",
"thread_id": "777",
},
),
),
(
"feishu:oc_abc:om_root123",
(
"feishu",
"oc_abc",
{
"chat_type": "group",
"message_id": "om_root123",
"thread_id": "om_root123",
},
),
),
("slack:C123:1700.42", ("slack", "C123", {"slack": {"thread_ts": "1700.42"}})),
("telegram:-100123:topic:42", ("telegram", "-100123", {"message_thread_id": 42})),
("dingtalk:group:conv-1:user-1", ("dingtalk", "group:conv-1", {})),
],
)
def test_bound_session_inbound_context(session_key, expected) -> None:
assert bound_session_inbound_context(session_key) == expected
def test_bound_session_inbound_context_rejects_invalid_key() -> None:
with pytest.raises(ValueError):
bound_session_inbound_context("unified")

View File

@ -274,6 +274,27 @@ async def test_webui_cron_tool_uses_origin_session_when_unified_enabled(tmp_path
assert jobs[0].payload.session_key == "websocket:chat-123" assert jobs[0].payload.session_key == "websocket:chat-123"
@pytest.mark.asyncio
async def test_cron_tool_preserves_thread_scoped_session_key(tmp_path) -> None:
"""Channel-provided thread session keys should remain the cron owner."""
tool = CronTool(CronService(tmp_path / "jobs.json"))
tool.set_context(
RequestContext(
channel="slack",
chat_id="C123",
metadata={"slack": {"thread_ts": "1700.42"}},
session_key="slack:C123:1700.42",
)
)
result = await tool.execute(action="add", message="check thread", every_seconds=300)
assert result.startswith("Created job")
jobs = tool._cron.list_jobs()
assert len(jobs) == 1
assert jobs[0].payload.session_key == "slack:C123:1700.42"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cron_tool_no_context_returns_error(tmp_path) -> None: async def test_cron_tool_no_context_returns_error(tmp_path) -> None:
"""Without set_context, add should fail with a clear error.""" """Without set_context, add should fail with a clear error."""