From c4b64a4caf2b3a13507477b987f72350eaf8f6cc Mon Sep 17 00:00:00 2001 From: chengyongru Date: Fri, 12 Jun 2026 14:21:09 +0800 Subject: [PATCH] refactor: preserve origin session routing for cron --- nanobot/agent/tools/cron.py | 12 ++--- nanobot/channels/manager.py | 1 - nanobot/cli/commands.py | 19 ++----- nanobot/cron/session_delivery.py | 57 ++++++++++++++++++++ nanobot/webui/gateway_services.py | 2 - nanobot/webui/ws_http.py | 8 +-- tests/channels/test_websocket_http_routes.py | 8 +-- tests/cli/test_commands.py | 48 +++++++++++++++-- tests/cron/test_session_delivery.py | 45 ++++++++++++++++ tests/test_tool_contextvars.py | 21 ++++++++ 10 files changed, 179 insertions(+), 42 deletions(-) create mode 100644 nanobot/cron/session_delivery.py create mode 100644 tests/cron/test_session_delivery.py diff --git a/nanobot/agent/tools/cron.py b/nanobot/agent/tools/cron.py index 100b64486..6f554d7bd 100644 --- a/nanobot/agent/tools/cron.py +++ b/nanobot/agent/tools/cron.py @@ -15,6 +15,7 @@ from nanobot.agent.tools.schema import ( ) from nanobot.cron.service import CronService from nanobot.cron.types import CronJob, CronJobState, CronSchedule +from nanobot.session.keys import UNIFIED_SESSION_KEY _CRON_PARAMETERS = tool_parameters_schema( action=StringSchema("Action to perform", enum=["add", "list", "remove"]), @@ -56,9 +57,6 @@ class CronTool(Tool, ContextAware): def __init__(self, cron_service: CronService, default_timezone: str = "UTC"): self._cron = cron_service self._default_timezone = default_timezone - self._channel: ContextVar[str] = ContextVar("cron_channel", default="") - self._chat_id: ContextVar[str] = ContextVar("cron_chat_id", default="") - self._metadata: ContextVar[dict] = ContextVar("cron_metadata", default={}) self._session_key: ContextVar[str] = ContextVar("cron_session_key", default="") self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False) @@ -72,10 +70,10 @@ class CronTool(Tool, ContextAware): def set_context(self, ctx: RequestContext) -> None: """Set the current session context for scheduled cron job ownership.""" - self._channel.set(ctx.channel) - self._chat_id.set(ctx.chat_id) - self._metadata.set(ctx.metadata) - self._session_key.set(f"{ctx.channel}:{ctx.chat_id}" if ctx.channel and ctx.chat_id else "") + raw_key = f"{ctx.channel}:{ctx.chat_id}" if ctx.channel and ctx.chat_id else "" + self._session_key.set( + raw_key if ctx.session_key in {None, "", UNIFIED_SESSION_KEY} else ctx.session_key + ) def set_cron_context(self, active: bool): """Mark whether the tool is executing inside a cron job callback.""" diff --git a/nanobot/channels/manager.py b/nanobot/channels/manager.py index cc5c62b1a..b59925232 100644 --- a/nanobot/channels/manager.py +++ b/nanobot/channels/manager.py @@ -125,7 +125,6 @@ class ChannelManager: runtime_model_name=self._webui_runtime_model_name, runtime_surface=self._webui_runtime_surface, runtime_capabilities_overrides=self._webui_runtime_capabilities, - unified_session=self.config.agents.defaults.unified_session, cron_service=self._cron_service, logger=logger, ) diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 048925ae7..ea6d6cdf2 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -980,6 +980,7 @@ def _run_gateway( from nanobot.bus.runtime_events import RuntimeEventBus from nanobot.channels.manager import ChannelManager from nanobot.cron.service import CronService + from nanobot.cron.session_delivery import bound_session_inbound_context from nanobot.cron.session_turns import ( CRON_DEFER_UNTIL_IDLE_META, CRON_TRIGGER_META, @@ -1050,13 +1051,7 @@ def _run_gateway( turn_seed: str, source_label: str | None, ) -> tuple[str, str, dict[str, Any]]: - if ":" not in session_key: - raise ValueError(f"bound cron session_key is invalid: {session_key!r}") - channel, rest = session_key.split(":", 1) - if not channel or not rest: - raise ValueError(f"bound cron session_key is invalid: {session_key!r}") - - metadata: dict[str, Any] = {} + channel, chat_id, metadata = bound_session_inbound_context(session_key) if channel == "websocket": metadata["webui"] = True @@ -1068,15 +1063,8 @@ def _run_gateway( source_label=source_label, ) ) - return channel, rest, metadata - if channel == "slack" and ":" in rest: - chat_id, thread_ts = rest.split(":", 1) - if thread_ts: - metadata["slack"] = {"thread_ts": thread_ts} - return channel, chat_id, metadata - - return channel, rest, metadata + return channel, chat_id, metadata def _cron_prompt_ref(prompt: str) -> dict[str, Any]: return { @@ -1141,6 +1129,7 @@ def _run_gateway( chat_id=chat_id, content=prompt, metadata=metadata, + session_key_override=session_key, ) ) except (Exception, asyncio.CancelledError) as exc: diff --git a/nanobot/cron/session_delivery.py b/nanobot/cron/session_delivery.py new file mode 100644 index 000000000..1d10bb890 --- /dev/null +++ b/nanobot/cron/session_delivery.py @@ -0,0 +1,57 @@ +"""Helpers for routing bound cron turns back through their origin session.""" + +from __future__ import annotations + +from typing import Any + + +def bound_session_inbound_context(session_key: str) -> tuple[str, str, dict[str, Any]]: + """Return ``(channel, chat_id, metadata)`` for a bound cron session key.""" + if ":" not in session_key: + raise ValueError(f"bound cron session_key is invalid: {session_key!r}") + channel, rest = session_key.split(":", 1) + if not channel or not rest: + raise ValueError(f"bound cron session_key is invalid: {session_key!r}") + + metadata: dict[str, Any] = {} + + if channel == "discord" and ":thread:" in rest: + parent_channel_id, thread_id = rest.split(":thread:", 1) + if parent_channel_id and thread_id: + metadata.update({ + "context_chat_id": parent_channel_id, + "parent_channel_id": parent_channel_id, + "thread_id": thread_id, + }) + return channel, thread_id, metadata + + if channel == "feishu" and ":" in rest: + chat_id, thread_id = rest.split(":", 1) + if chat_id and thread_id: + metadata.update({ + "chat_type": "group", + "message_id": thread_id, + "thread_id": thread_id, + }) + return channel, chat_id, metadata + + if channel == "slack" and ":" in rest: + chat_id, thread_ts = rest.split(":", 1) + if thread_ts: + metadata["slack"] = {"thread_ts": thread_ts} + return channel, chat_id, metadata + + if channel == "telegram" and ":topic:" in rest: + chat_id, thread_id = rest.split(":topic:", 1) + if thread_id: + metadata["message_thread_id"] = ( + int(thread_id) if thread_id.isdigit() else thread_id + ) + return channel, chat_id, metadata + + if channel == "dingtalk" and rest.startswith("group:"): + parts = rest.split(":", 2) + if len(parts) >= 2 and parts[1]: + return channel, f"group:{parts[1]}", metadata + + return channel, rest, metadata diff --git a/nanobot/webui/gateway_services.py b/nanobot/webui/gateway_services.py index 53d3f0db1..15649d08d 100644 --- a/nanobot/webui/gateway_services.py +++ b/nanobot/webui/gateway_services.py @@ -39,7 +39,6 @@ def build_gateway_services( runtime_model_name: Any | None, runtime_surface: str, runtime_capabilities_overrides: dict[str, Any] | None, - unified_session: bool = False, disabled_skills: set[str] | None = None, cron_service: Any | None = None, logger: Any = default_logger, @@ -62,7 +61,6 @@ def build_gateway_services( runtime_model_name=runtime_model_name, runtime_surface=runtime_surface, runtime_capabilities_overrides=runtime_capabilities_overrides, - unified_session=unified_session, bus=bus, tokens=tokens, media=media, diff --git a/nanobot/webui/ws_http.py b/nanobot/webui/ws_http.py index f88ec4916..70e19e01b 100644 --- a/nanobot/webui/ws_http.py +++ b/nanobot/webui/ws_http.py @@ -139,7 +139,6 @@ class GatewayHTTPHandler: runtime_model_name: Callable[[], str | None] | None, runtime_surface: str, runtime_capabilities_overrides: dict[str, Any] | None, - unified_session: bool = False, bus: MessageBus, tokens: GatewayTokenStore, media: WebUIMediaGateway, @@ -162,7 +161,6 @@ class GatewayHTTPHandler: self.cron_service = cron_service self._log = log self._runtime_surface = runtime_surface - self._unified_session = unified_session from nanobot.webui.settings_api import runtime_capabilities as _rc from nanobot.webui.settings_routes import WebUISettingsRouter @@ -439,7 +437,7 @@ class GatewayHTTPHandler: if not _is_websocket_channel_session_key(decoded_key): return _http_error(404, "session not found") return _http_json_response( - session_automations_payload(self.cron_service, self._automation_display_key(decoded_key)) + session_automations_payload(self.cron_service, decoded_key) ) def _handle_session_delete(self, request: WsRequest, key: str) -> Response: @@ -470,10 +468,6 @@ class GatewayHTTPHandler: delete_webui_thread(decoded_key) return _http_json_response({"deleted": bool(deleted)}) - def _automation_display_key(self, session_key: str) -> str: - """Return the cron ownership key shown for this WebUI thread.""" - return session_key - # -- Media routes ------------------------------------------------------- def _dispatch_media_routes(self, request: WsRequest, got: str) -> Response | None: diff --git a/tests/channels/test_websocket_http_routes.py b/tests/channels/test_websocket_http_routes.py index 96d40f767..d8c137630 100644 --- a/tests/channels/test_websocket_http_routes.py +++ b/tests/channels/test_websocket_http_routes.py @@ -30,7 +30,6 @@ def _make_handler( workspace_path: Path | None = None, runtime_model_name: Any | None = None, cron_service: CronService | None = None, - unified_session: bool = False, ) -> GatewayServices: config = WebSocketConfig.model_validate(cfg) if isinstance(cfg, dict) else cfg workspace = workspace_path or Path.cwd() @@ -44,7 +43,6 @@ def _make_handler( runtime_model_name=runtime_model_name, runtime_surface="browser", runtime_capabilities_overrides=None, - unified_session=unified_session, cron_service=cron_service, ) @@ -58,7 +56,6 @@ def _ch( port: int = _PORT, runtime_model_name: Any | None = None, cron_service: CronService | None = None, - unified_session: bool = False, **extra: Any, ) -> WebSocketChannel: cfg: dict[str, Any] = { @@ -77,7 +74,6 @@ def _ch( workspace_path=workspace_path, runtime_model_name=runtime_model_name, cron_service=cron_service, - unified_session=unified_session, ) return WebSocketChannel(cfg, bus, gateway=gateway) @@ -243,7 +239,7 @@ async def test_session_automations_route_filters_by_webui_session( @pytest.mark.asyncio -async def test_session_automations_route_uses_origin_owner_when_unified_enabled( +async def test_session_automations_route_ignores_unified_owner( bus: MagicMock, tmp_path: Path ) -> None: cron = CronService(tmp_path / "cron" / "jobs.json") @@ -264,7 +260,6 @@ async def test_session_automations_route_uses_origin_owner_when_unified_enabled( bus, session_manager=_seed_session(tmp_path, key="websocket:abc"), cron_service=cron, - unified_session=True, port=29917, ) server_task = asyncio.create_task(channel.start()) @@ -823,7 +818,6 @@ async def test_session_delete_blocks_origin_automation_when_unified_enabled( bus, session_manager=sm, cron_service=cron, - unified_session=True, port=29918, ) server_task = asyncio.create_task(channel.start()) diff --git a/tests/cli/test_commands.py b/tests/cli/test_commands.py index 99c1bb399..84dd0d170 100644 --- a/tests/cli/test_commands.py +++ b/tests/cli/test_commands.py @@ -1622,7 +1622,7 @@ def test_gateway_bound_cron_runs_as_session_turn( assert msg.channel == "websocket" assert msg.chat_id == "chat-1" assert msg.sender_id == "cron" - assert msg.session_key_override is None + assert msg.session_key_override == "websocket:chat-1" assert "Cron job: Check repository health." in msg.content assert msg.metadata["webui"] is True assert msg.metadata[WEBUI_MESSAGE_SOURCE_METADATA_KEY] == { @@ -1645,7 +1645,7 @@ def test_gateway_bound_cron_runs_as_session_turn( name="Thread check", payload=CronPayload( message="Check the Discord thread.", - session_key="discord:777", + session_key="discord:456:thread:777", ), ) @@ -1656,7 +1656,49 @@ def test_gateway_bound_cron_runs_as_session_turn( assert isinstance(msg, InboundMessage) assert msg.channel == "discord" assert msg.chat_id == "777" - assert msg.session_key_override is None + assert msg.session_key_override == "discord:456:thread:777" + assert msg.metadata["context_chat_id"] == "456" + assert msg.metadata["parent_channel_id"] == "456" + assert msg.metadata["thread_id"] == "777" + + telegram_job = CronJob( + id="telegram-topic", + name="Telegram topic", + payload=CronPayload( + message="Check the Telegram topic.", + session_key="telegram:-100123:topic:42", + ), + ) + + response = asyncio.run(cron.on_job(telegram_job)) + + assert response == "Checked the repo." + msg = seen["cron_msg"] + assert isinstance(msg, InboundMessage) + assert msg.channel == "telegram" + assert msg.chat_id == "-100123" + assert msg.session_key_override == "telegram:-100123:topic:42" + assert msg.metadata["message_thread_id"] == 42 + + feishu_job = CronJob( + id="feishu-topic", + name="Feishu topic", + payload=CronPayload( + message="Check the Feishu topic.", + session_key="feishu:oc_abc:om_root123", + ), + ) + + response = asyncio.run(cron.on_job(feishu_job)) + + assert response == "Checked the repo." + msg = seen["cron_msg"] + assert isinstance(msg, InboundMessage) + assert msg.channel == "feishu" + assert msg.chat_id == "oc_abc" + assert msg.session_key_override == "feishu:oc_abc:om_root123" + assert msg.metadata["message_id"] == "om_root123" + assert msg.metadata["thread_id"] == "om_root123" def test_gateway_cron_job_suppresses_intermediate_progress( diff --git a/tests/cron/test_session_delivery.py b/tests/cron/test_session_delivery.py new file mode 100644 index 000000000..02948a3fa --- /dev/null +++ b/tests/cron/test_session_delivery.py @@ -0,0 +1,45 @@ +import pytest + +from nanobot.cron.session_delivery import bound_session_inbound_context + + +@pytest.mark.parametrize( + ("session_key", "expected"), + [ + ("websocket:chat-1", ("websocket", "chat-1", {})), + ( + "discord:456:thread:777", + ( + "discord", + "777", + { + "context_chat_id": "456", + "parent_channel_id": "456", + "thread_id": "777", + }, + ), + ), + ( + "feishu:oc_abc:om_root123", + ( + "feishu", + "oc_abc", + { + "chat_type": "group", + "message_id": "om_root123", + "thread_id": "om_root123", + }, + ), + ), + ("slack:C123:1700.42", ("slack", "C123", {"slack": {"thread_ts": "1700.42"}})), + ("telegram:-100123:topic:42", ("telegram", "-100123", {"message_thread_id": 42})), + ("dingtalk:group:conv-1:user-1", ("dingtalk", "group:conv-1", {})), + ], +) +def test_bound_session_inbound_context(session_key, expected) -> None: + assert bound_session_inbound_context(session_key) == expected + + +def test_bound_session_inbound_context_rejects_invalid_key() -> None: + with pytest.raises(ValueError): + bound_session_inbound_context("unified") diff --git a/tests/test_tool_contextvars.py b/tests/test_tool_contextvars.py index 4dd70c527..ea9b4753e 100644 --- a/tests/test_tool_contextvars.py +++ b/tests/test_tool_contextvars.py @@ -274,6 +274,27 @@ async def test_webui_cron_tool_uses_origin_session_when_unified_enabled(tmp_path assert jobs[0].payload.session_key == "websocket:chat-123" +@pytest.mark.asyncio +async def test_cron_tool_preserves_thread_scoped_session_key(tmp_path) -> None: + """Channel-provided thread session keys should remain the cron owner.""" + tool = CronTool(CronService(tmp_path / "jobs.json")) + tool.set_context( + RequestContext( + channel="slack", + chat_id="C123", + metadata={"slack": {"thread_ts": "1700.42"}}, + session_key="slack:C123:1700.42", + ) + ) + + result = await tool.execute(action="add", message="check thread", every_seconds=300) + assert result.startswith("Created job") + + jobs = tool._cron.list_jobs() + assert len(jobs) == 1 + assert jobs[0].payload.session_key == "slack:C123:1700.42" + + @pytest.mark.asyncio async def test_cron_tool_no_context_returns_error(tmp_path) -> None: """Without set_context, add should fail with a clear error."""