refactor: store cron origin delivery context

This commit is contained in:
chengyongru 2026-06-12 15:07:25 +08:00
parent b232a52794
commit 5ae907bc2f
10 changed files with 171 additions and 96 deletions

View File

@ -58,6 +58,12 @@ class CronTool(Tool, ContextAware):
self._cron = cron_service self._cron = cron_service
self._default_timezone = default_timezone self._default_timezone = default_timezone
self._session_key: ContextVar[str] = ContextVar("cron_session_key", default="") self._session_key: ContextVar[str] = ContextVar("cron_session_key", default="")
self._origin_channel: ContextVar[str] = ContextVar("cron_origin_channel", default="")
self._origin_chat_id: ContextVar[str] = ContextVar("cron_origin_chat_id", default="")
self._origin_metadata: ContextVar[dict[str, Any] | None] = ContextVar(
"cron_origin_metadata",
default=None,
)
self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False) self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False)
@classmethod @classmethod
@ -74,6 +80,9 @@ class CronTool(Tool, ContextAware):
self._session_key.set( self._session_key.set(
raw_key if ctx.session_key == UNIFIED_SESSION_KEY else (ctx.session_key or "") raw_key if ctx.session_key == UNIFIED_SESSION_KEY else (ctx.session_key or "")
) )
self._origin_channel.set(ctx.channel or "")
self._origin_chat_id.set(ctx.chat_id or "")
self._origin_metadata.set(dict(ctx.metadata or {}))
def set_cron_context(self, active: bool): def set_cron_context(self, active: bool):
"""Mark whether the tool is executing inside a cron job callback.""" """Mark whether the tool is executing inside a cron job callback."""
@ -165,6 +174,10 @@ class CronTool(Tool, ContextAware):
session_key = self._session_key.get() session_key = self._session_key.get()
if not session_key: if not session_key:
return "Error: scheduled cron jobs must be created from a chat session" return "Error: scheduled cron jobs must be created from a chat session"
origin_channel = self._origin_channel.get()
origin_chat_id = self._origin_chat_id.get()
if not origin_channel or not origin_chat_id:
return "Error: scheduled cron jobs must be created from a chat session"
if tz and not cron_expr: if tz and not cron_expr:
return "Error: tz can only be used with cron_expr" return "Error: tz can only be used with cron_expr"
if tz: if tz:
@ -203,6 +216,9 @@ class CronTool(Tool, ContextAware):
message=message, message=message,
delete_after_run=delete_after, delete_after_run=delete_after,
session_key=session_key, session_key=session_key,
origin_channel=origin_channel,
origin_chat_id=origin_chat_id,
origin_metadata=dict(self._origin_metadata.get() or {}),
) )
return f"Created job '{job.name}' (id: {job.id})" return f"Created job '{job.name}' (id: {job.id})"

View File

@ -980,7 +980,7 @@ def _run_gateway(
from nanobot.bus.runtime_events import RuntimeEventBus from nanobot.bus.runtime_events import RuntimeEventBus
from nanobot.channels.manager import ChannelManager from nanobot.channels.manager import ChannelManager
from nanobot.cron.service import CronService from nanobot.cron.service import CronService
from nanobot.cron.session_delivery import bound_session_inbound_context from nanobot.cron.session_delivery import origin_delivery_context
from nanobot.cron.session_turns import ( from nanobot.cron.session_turns import (
CRON_DEFER_UNTIL_IDLE_META, CRON_DEFER_UNTIL_IDLE_META,
CRON_TRIGGER_META, CRON_TRIGGER_META,
@ -1046,12 +1046,12 @@ def _run_gateway(
) )
def _bound_session_delivery_context( def _bound_session_delivery_context(
session_key: str, job: CronJob,
*, *,
turn_seed: str, turn_seed: str,
source_label: str | None, source_label: str | None,
) -> tuple[str, str, dict[str, Any]]: ) -> tuple[str, str, dict[str, Any]]:
channel, chat_id, metadata = bound_session_inbound_context(session_key) channel, chat_id, metadata = origin_delivery_context(job)
if channel == "websocket": if channel == "websocket":
metadata["webui"] = True metadata["webui"] = True
@ -1086,7 +1086,7 @@ def _run_gateway(
prompt_ref = _cron_prompt_ref(prompt) prompt_ref = _cron_prompt_ref(prompt)
run_id = f"{job.id}:{int(time.time() * 1000)}:{uuid.uuid4().hex[:8]}" run_id = f"{job.id}:{int(time.time() * 1000)}:{uuid.uuid4().hex[:8]}"
channel, chat_id, metadata = _bound_session_delivery_context( channel, chat_id, metadata = _bound_session_delivery_context(
session_key, job,
turn_seed=f"cron:{job.id}", turn_seed=f"cron:{job.id}",
source_label=job.name, source_label=job.name,
) )

View File

@ -138,6 +138,19 @@ class CronService:
or {} or {}
), ),
session_key=j["payload"].get("sessionKey") or j["payload"].get("session_key"), session_key=j["payload"].get("sessionKey") or j["payload"].get("session_key"),
origin_channel=(
j["payload"].get("originChannel")
or j["payload"].get("origin_channel")
),
origin_chat_id=(
j["payload"].get("originChatId")
or j["payload"].get("origin_chat_id")
),
origin_metadata=(
j["payload"].get("originMetadata")
or j["payload"].get("origin_metadata")
or {}
),
), ),
state=CronJobState( state=CronJobState(
next_run_at_ms=j.get("state", {}).get("nextRunAtMs"), next_run_at_ms=j.get("state", {}).get("nextRunAtMs"),
@ -268,6 +281,9 @@ class CronService:
"to": j.payload.to, "to": j.payload.to,
"channelMeta": j.payload.channel_meta, "channelMeta": j.payload.channel_meta,
"sessionKey": j.payload.session_key, "sessionKey": j.payload.session_key,
"originChannel": j.payload.origin_channel,
"originChatId": j.payload.origin_chat_id,
"originMetadata": j.payload.origin_metadata,
}, },
"state": { "state": {
"nextRunAtMs": j.state.next_run_at_ms, "nextRunAtMs": j.state.next_run_at_ms,
@ -524,6 +540,9 @@ class CronService:
delete_after_run: bool = False, delete_after_run: bool = False,
channel_meta: dict | None = None, channel_meta: dict | None = None,
session_key: str | None = None, session_key: str | None = None,
origin_channel: str | None = None,
origin_chat_id: str | None = None,
origin_metadata: dict | None = None,
) -> CronJob: ) -> CronJob:
"""Add a new job.""" """Add a new job."""
_validate_schedule_for_add(schedule) _validate_schedule_for_add(schedule)
@ -542,6 +561,9 @@ class CronService:
to=to, to=to,
channel_meta=channel_meta or {}, channel_meta=channel_meta or {},
session_key=session_key, session_key=session_key,
origin_channel=origin_channel,
origin_chat_id=origin_chat_id,
origin_metadata=origin_metadata or {},
), ),
state=CronJobState(next_run_at_ms=_compute_next_run(schedule, now)), state=CronJobState(next_run_at_ms=_compute_next_run(schedule, now)),
created_at_ms=now, created_at_ms=now,

View File

@ -4,54 +4,12 @@ from __future__ import annotations
from typing import Any from typing import Any
from nanobot.cron.types import CronJob
def bound_session_inbound_context(session_key: str) -> tuple[str, str, dict[str, Any]]:
"""Return ``(channel, chat_id, metadata)`` for a bound cron session key."""
if ":" not in session_key:
raise ValueError(f"bound cron session_key is invalid: {session_key!r}")
channel, rest = session_key.split(":", 1)
if not channel or not rest:
raise ValueError(f"bound cron session_key is invalid: {session_key!r}")
metadata: dict[str, Any] = {} def origin_delivery_context(job: CronJob) -> tuple[str, str, dict[str, Any]]:
"""Return ``(channel, chat_id, metadata)`` for a session-bound cron job."""
if channel == "discord" and ":thread:" in rest: payload = job.payload
parent_channel_id, thread_id = rest.split(":thread:", 1) if not payload.origin_channel or not payload.origin_chat_id:
if parent_channel_id and thread_id: raise ValueError(f"cron job {job.id} is missing origin delivery context")
metadata.update({ return payload.origin_channel, payload.origin_chat_id, dict(payload.origin_metadata or {})
"context_chat_id": parent_channel_id,
"parent_channel_id": parent_channel_id,
"thread_id": thread_id,
})
return channel, thread_id, metadata
if channel == "feishu" and ":" in rest:
chat_id, thread_id = rest.split(":", 1)
if chat_id and thread_id:
metadata.update({
"chat_type": "group",
"message_id": thread_id,
"thread_id": thread_id,
})
return channel, chat_id, metadata
if channel == "slack" and ":" in rest:
chat_id, thread_ts = rest.split(":", 1)
if thread_ts:
metadata["slack"] = {"thread_ts": thread_ts}
return channel, chat_id, metadata
if channel == "telegram" and ":topic:" in rest:
chat_id, thread_id = rest.split(":topic:", 1)
if thread_id:
metadata["message_thread_id"] = (
int(thread_id) if thread_id.isdigit() else thread_id
)
return channel, chat_id, metadata
if channel == "dingtalk" and rest.startswith("group:"):
parts = rest.split(":", 2)
if len(parts) >= 2 and parts[1]:
return channel, f"group:{parts[1]}", metadata
return channel, rest, metadata

View File

@ -1,7 +1,7 @@
"""Cron types.""" """Cron types."""
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Literal from typing import Any, Literal
@dataclass @dataclass
@ -23,12 +23,15 @@ class CronPayload:
"""What to do when the job runs.""" """What to do when the job runs."""
kind: Literal["system_event", "agent_turn"] = "agent_turn" kind: Literal["system_event", "agent_turn"] = "agent_turn"
message: str = "" message: str = ""
# Deliver response to channel # Legacy delivery fields used by pre-session-bound cron jobs.
deliver: bool = False deliver: bool = False
channel: str | None = None # e.g. "whatsapp" channel: str | None = None # e.g. "whatsapp"
to: str | None = None # e.g. phone number to: str | None = None # e.g. phone number
channel_meta: dict = field(default_factory=dict) # channel-specific routing (e.g. Slack thread_ts) channel_meta: dict[str, Any] = field(default_factory=dict)
session_key: str | None = None # original session key for correct session recording session_key: str | None = None # original session key for correct session recording
origin_channel: str | None = None
origin_chat_id: str | None = None
origin_metadata: dict[str, Any] = field(default_factory=dict)
@dataclass @dataclass

View File

@ -1611,6 +1611,8 @@ def test_gateway_bound_cron_runs_as_session_turn(
payload=CronPayload( payload=CronPayload(
message="Check repository health.", message="Check repository health.",
session_key="websocket:chat-1", session_key="websocket:chat-1",
origin_channel="websocket",
origin_chat_id="chat-1",
), ),
) )
@ -1646,6 +1648,13 @@ def test_gateway_bound_cron_runs_as_session_turn(
payload=CronPayload( payload=CronPayload(
message="Check the Discord thread.", message="Check the Discord thread.",
session_key="discord:456:thread:777", session_key="discord:456:thread:777",
origin_channel="discord",
origin_chat_id="777",
origin_metadata={
"context_chat_id": "456",
"parent_channel_id": "456",
"thread_id": "777",
},
), ),
) )
@ -1667,6 +1676,9 @@ def test_gateway_bound_cron_runs_as_session_turn(
payload=CronPayload( payload=CronPayload(
message="Check the Telegram topic.", message="Check the Telegram topic.",
session_key="telegram:-100123:topic:42", session_key="telegram:-100123:topic:42",
origin_channel="telegram",
origin_chat_id="-100123",
origin_metadata={"message_thread_id": 42},
), ),
) )
@ -1686,6 +1698,13 @@ def test_gateway_bound_cron_runs_as_session_turn(
payload=CronPayload( payload=CronPayload(
message="Check the Feishu topic.", message="Check the Feishu topic.",
session_key="feishu:oc_abc:om_root123", session_key="feishu:oc_abc:om_root123",
origin_channel="feishu",
origin_chat_id="oc_abc",
origin_metadata={
"chat_type": "group",
"message_id": "om_root123",
"thread_id": "om_root123",
},
), ),
) )

View File

@ -87,6 +87,37 @@ def test_list_bound_agent_jobs_excludes_legacy_delivery_payloads(tmp_path) -> No
assert service.list_bound_cron_jobs_for_session("websocket:chat-1") == [bound] assert service.list_bound_cron_jobs_for_session("websocket:chat-1") == [bound]
def test_add_job_preserves_origin_delivery_context(tmp_path) -> None:
service = CronService(tmp_path / "cron" / "jobs.json")
metadata = {"slack": {"thread_ts": "1234567890.123456", "channel_type": "channel"}}
job = service.add_job(
name="bound thread",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello",
session_key="slack:C123:1234567890.123456",
origin_channel="slack",
origin_chat_id="C123",
origin_metadata=metadata,
)
assert job.payload.origin_channel == "slack"
assert job.payload.origin_chat_id == "C123"
assert job.payload.origin_metadata == metadata
raw = json.loads((tmp_path / "cron" / "action.jsonl").read_text(encoding="utf-8"))
payload = raw["params"]["payload"]
assert payload["origin_channel"] == "slack"
assert payload["origin_chat_id"] == "C123"
assert payload["origin_metadata"] == metadata
reloaded = service.get_job(job.id)
assert reloaded is not None
assert reloaded.payload.origin_channel == "slack"
assert reloaded.payload.origin_chat_id == "C123"
assert reloaded.payload.origin_metadata == metadata
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_channel_meta_and_session_key_survive_store_reload(tmp_path) -> None: async def test_channel_meta_and_session_key_survive_store_reload(tmp_path) -> None:
store_path = tmp_path / "cron" / "jobs.json" store_path = tmp_path / "cron" / "jobs.json"
@ -103,6 +134,9 @@ async def test_channel_meta_and_session_key_survive_store_reload(tmp_path) -> No
to="C123", to="C123",
channel_meta=meta, channel_meta=meta,
session_key="slack:C123:1234567890.123456", session_key="slack:C123:1234567890.123456",
origin_channel="slack",
origin_chat_id="C123",
origin_metadata=meta,
) )
finally: finally:
service.stop() service.stop()
@ -111,11 +145,17 @@ async def test_channel_meta_and_session_key_survive_store_reload(tmp_path) -> No
payload = raw["jobs"][0]["payload"] payload = raw["jobs"][0]["payload"]
assert payload["channelMeta"] == meta assert payload["channelMeta"] == meta
assert payload["sessionKey"] == "slack:C123:1234567890.123456" assert payload["sessionKey"] == "slack:C123:1234567890.123456"
assert payload["originChannel"] == "slack"
assert payload["originChatId"] == "C123"
assert payload["originMetadata"] == meta
reloaded = CronService(store_path).get_job(job.id) reloaded = CronService(store_path).get_job(job.id)
assert reloaded is not None assert reloaded is not None
assert reloaded.payload.channel_meta == meta assert reloaded.payload.channel_meta == meta
assert reloaded.payload.session_key == "slack:C123:1234567890.123456" assert reloaded.payload.session_key == "slack:C123:1234567890.123456"
assert reloaded.payload.origin_channel == "slack"
assert reloaded.payload.origin_chat_id == "C123"
assert reloaded.payload.origin_metadata == meta
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -339,6 +339,9 @@ def test_add_job_binds_current_session_key(tmp_path) -> None:
assert result.startswith("Created job") assert result.startswith("Created job")
job = tool._cron.list_jobs()[0] job = tool._cron.list_jobs()[0]
assert job.payload.session_key == "telegram:chat-1" assert job.payload.session_key == "telegram:chat-1"
assert job.payload.origin_channel == "telegram"
assert job.payload.origin_chat_id == "chat-1"
assert job.payload.origin_metadata == {}
assert job.payload.channel is None assert job.payload.channel is None
assert job.payload.to is None assert job.payload.to is None
@ -392,8 +395,8 @@ def test_add_job_empty_message_returns_actionable_error(tmp_path) -> None:
assert "Retry including message=" in result assert "Retry including message=" in result
def test_add_job_captures_only_session_key(tmp_path) -> None: def test_add_job_captures_owner_and_origin_without_legacy_delivery_fields(tmp_path) -> None:
"""CronTool stores the canonical session key without legacy delivery fields.""" """CronTool stores owner/session identity separately from origin delivery context."""
tool = _make_tool(tmp_path) tool = _make_tool(tmp_path)
meta = {"slack": {"thread_ts": "111.222", "channel_type": "channel"}} meta = {"slack": {"thread_ts": "111.222", "channel_type": "channel"}}
tool.set_context(RequestContext( tool.set_context(RequestContext(
@ -406,6 +409,9 @@ def test_add_job_captures_only_session_key(tmp_path) -> None:
jobs = tool._cron.list_jobs() jobs = tool._cron.list_jobs()
assert len(jobs) == 1 assert len(jobs) == 1
assert jobs[0].payload.session_key == "slack:C99:111.222" assert jobs[0].payload.session_key == "slack:C99:111.222"
assert jobs[0].payload.origin_channel == "slack"
assert jobs[0].payload.origin_chat_id == "C99"
assert jobs[0].payload.origin_metadata == meta
assert jobs[0].payload.channel is None assert jobs[0].payload.channel is None
assert jobs[0].payload.to is None assert jobs[0].payload.to is None
assert jobs[0].payload.channel_meta == {} assert jobs[0].payload.channel_meta == {}

View File

@ -1,45 +1,44 @@
import pytest import pytest
from nanobot.cron.session_delivery import bound_session_inbound_context from nanobot.cron.session_delivery import origin_delivery_context
from nanobot.cron.types import CronJob, CronPayload
@pytest.mark.parametrize( def test_origin_delivery_context_uses_explicit_origin_fields() -> None:
("session_key", "expected"), metadata = {
[ "context_chat_id": "456",
("websocket:chat-1", ("websocket", "chat-1", {})), "parent_channel_id": "456",
( "thread_id": "777",
"discord:456:thread:777", }
( job = CronJob(
"discord", id="thread-check",
"777", name="Thread check",
{ payload=CronPayload(
"context_chat_id": "456", message="check",
"parent_channel_id": "456", session_key="discord:456:thread:777",
"thread_id": "777", origin_channel="discord",
}, origin_chat_id="777",
), origin_metadata=metadata,
), ),
( )
"feishu:oc_abc:om_root123",
( channel, chat_id, returned_metadata = origin_delivery_context(job)
"feishu",
"oc_abc", assert channel == "discord"
{ assert chat_id == "777"
"chat_type": "group", assert returned_metadata == metadata
"message_id": "om_root123", assert returned_metadata is not metadata
"thread_id": "om_root123",
},
), def test_origin_delivery_context_rejects_missing_origin_fields() -> None:
job = CronJob(
id="old-bound",
name="Old bound job",
payload=CronPayload(
message="check",
session_key="websocket:chat-1",
), ),
("slack:C123:1700.42", ("slack", "C123", {"slack": {"thread_ts": "1700.42"}})), )
("telegram:-100123:topic:42", ("telegram", "-100123", {"message_thread_id": 42})),
("dingtalk:group:conv-1:user-1", ("dingtalk", "group:conv-1", {})),
],
)
def test_bound_session_inbound_context(session_key, expected) -> None:
assert bound_session_inbound_context(session_key) == expected
with pytest.raises(ValueError, match="missing origin delivery context"):
def test_bound_session_inbound_context_rejects_invalid_key() -> None: origin_delivery_context(job)
with pytest.raises(ValueError):
bound_session_inbound_context("unified")

View File

@ -123,6 +123,10 @@ async def test_cron_tool_keeps_task_local_context(tmp_path) -> None:
jobs = tool._cron.list_jobs() jobs = tool._cron.list_jobs()
assert {job.payload.session_key for job in jobs} == {"feishu:chat-a", "email:chat-b"} assert {job.payload.session_key for job in jobs} == {"feishu:chat-a", "email:chat-b"}
assert {(job.payload.origin_channel, job.payload.origin_chat_id) for job in jobs} == {
("feishu", "chat-a"),
("email", "chat-b"),
}
# --- Basic single-task regression tests --- # --- Basic single-task regression tests ---
@ -243,6 +247,8 @@ async def test_cron_tool_basic_set_context_and_execute(tmp_path) -> None:
jobs = tool._cron.list_jobs() jobs = tool._cron.list_jobs()
assert len(jobs) == 1 assert len(jobs) == 1
assert jobs[0].payload.session_key == "wechat:user-789" assert jobs[0].payload.session_key == "wechat:user-789"
assert jobs[0].payload.origin_channel == "wechat"
assert jobs[0].payload.origin_chat_id == "user-789"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -272,6 +278,9 @@ async def test_webui_cron_tool_uses_origin_session_when_unified_enabled(tmp_path
jobs = tool._cron.list_jobs() jobs = tool._cron.list_jobs()
assert len(jobs) == 1 assert len(jobs) == 1
assert jobs[0].payload.session_key == "websocket:chat-123" assert jobs[0].payload.session_key == "websocket:chat-123"
assert jobs[0].payload.origin_channel == "websocket"
assert jobs[0].payload.origin_chat_id == "chat-123"
assert jobs[0].payload.origin_metadata == {"webui": True}
@pytest.mark.asyncio @pytest.mark.asyncio
@ -293,6 +302,9 @@ async def test_cron_tool_preserves_thread_scoped_session_key(tmp_path) -> None:
jobs = tool._cron.list_jobs() jobs = tool._cron.list_jobs()
assert len(jobs) == 1 assert len(jobs) == 1
assert jobs[0].payload.session_key == "slack:C123:1700.42" assert jobs[0].payload.session_key == "slack:C123:1700.42"
assert jobs[0].payload.origin_channel == "slack"
assert jobs[0].payload.origin_chat_id == "C123"
assert jobs[0].payload.origin_metadata == {"slack": {"thread_ts": "1700.42"}}
@pytest.mark.asyncio @pytest.mark.asyncio