mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 07:14:08 +00:00
refactor: store cron origin delivery context
This commit is contained in:
parent
b232a52794
commit
5ae907bc2f
@ -58,6 +58,12 @@ class CronTool(Tool, ContextAware):
|
||||
self._cron = cron_service
|
||||
self._default_timezone = default_timezone
|
||||
self._session_key: ContextVar[str] = ContextVar("cron_session_key", default="")
|
||||
self._origin_channel: ContextVar[str] = ContextVar("cron_origin_channel", default="")
|
||||
self._origin_chat_id: ContextVar[str] = ContextVar("cron_origin_chat_id", default="")
|
||||
self._origin_metadata: ContextVar[dict[str, Any] | None] = ContextVar(
|
||||
"cron_origin_metadata",
|
||||
default=None,
|
||||
)
|
||||
self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False)
|
||||
|
||||
@classmethod
|
||||
@ -74,6 +80,9 @@ class CronTool(Tool, ContextAware):
|
||||
self._session_key.set(
|
||||
raw_key if ctx.session_key == UNIFIED_SESSION_KEY else (ctx.session_key or "")
|
||||
)
|
||||
self._origin_channel.set(ctx.channel or "")
|
||||
self._origin_chat_id.set(ctx.chat_id or "")
|
||||
self._origin_metadata.set(dict(ctx.metadata or {}))
|
||||
|
||||
def set_cron_context(self, active: bool):
|
||||
"""Mark whether the tool is executing inside a cron job callback."""
|
||||
@ -165,6 +174,10 @@ class CronTool(Tool, ContextAware):
|
||||
session_key = self._session_key.get()
|
||||
if not session_key:
|
||||
return "Error: scheduled cron jobs must be created from a chat session"
|
||||
origin_channel = self._origin_channel.get()
|
||||
origin_chat_id = self._origin_chat_id.get()
|
||||
if not origin_channel or not origin_chat_id:
|
||||
return "Error: scheduled cron jobs must be created from a chat session"
|
||||
if tz and not cron_expr:
|
||||
return "Error: tz can only be used with cron_expr"
|
||||
if tz:
|
||||
@ -203,6 +216,9 @@ class CronTool(Tool, ContextAware):
|
||||
message=message,
|
||||
delete_after_run=delete_after,
|
||||
session_key=session_key,
|
||||
origin_channel=origin_channel,
|
||||
origin_chat_id=origin_chat_id,
|
||||
origin_metadata=dict(self._origin_metadata.get() or {}),
|
||||
)
|
||||
return f"Created job '{job.name}' (id: {job.id})"
|
||||
|
||||
|
||||
@ -980,7 +980,7 @@ def _run_gateway(
|
||||
from nanobot.bus.runtime_events import RuntimeEventBus
|
||||
from nanobot.channels.manager import ChannelManager
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.cron.session_delivery import bound_session_inbound_context
|
||||
from nanobot.cron.session_delivery import origin_delivery_context
|
||||
from nanobot.cron.session_turns import (
|
||||
CRON_DEFER_UNTIL_IDLE_META,
|
||||
CRON_TRIGGER_META,
|
||||
@ -1046,12 +1046,12 @@ def _run_gateway(
|
||||
)
|
||||
|
||||
def _bound_session_delivery_context(
|
||||
session_key: str,
|
||||
job: CronJob,
|
||||
*,
|
||||
turn_seed: str,
|
||||
source_label: str | None,
|
||||
) -> tuple[str, str, dict[str, Any]]:
|
||||
channel, chat_id, metadata = bound_session_inbound_context(session_key)
|
||||
channel, chat_id, metadata = origin_delivery_context(job)
|
||||
|
||||
if channel == "websocket":
|
||||
metadata["webui"] = True
|
||||
@ -1086,7 +1086,7 @@ def _run_gateway(
|
||||
prompt_ref = _cron_prompt_ref(prompt)
|
||||
run_id = f"{job.id}:{int(time.time() * 1000)}:{uuid.uuid4().hex[:8]}"
|
||||
channel, chat_id, metadata = _bound_session_delivery_context(
|
||||
session_key,
|
||||
job,
|
||||
turn_seed=f"cron:{job.id}",
|
||||
source_label=job.name,
|
||||
)
|
||||
|
||||
@ -138,6 +138,19 @@ class CronService:
|
||||
or {}
|
||||
),
|
||||
session_key=j["payload"].get("sessionKey") or j["payload"].get("session_key"),
|
||||
origin_channel=(
|
||||
j["payload"].get("originChannel")
|
||||
or j["payload"].get("origin_channel")
|
||||
),
|
||||
origin_chat_id=(
|
||||
j["payload"].get("originChatId")
|
||||
or j["payload"].get("origin_chat_id")
|
||||
),
|
||||
origin_metadata=(
|
||||
j["payload"].get("originMetadata")
|
||||
or j["payload"].get("origin_metadata")
|
||||
or {}
|
||||
),
|
||||
),
|
||||
state=CronJobState(
|
||||
next_run_at_ms=j.get("state", {}).get("nextRunAtMs"),
|
||||
@ -268,6 +281,9 @@ class CronService:
|
||||
"to": j.payload.to,
|
||||
"channelMeta": j.payload.channel_meta,
|
||||
"sessionKey": j.payload.session_key,
|
||||
"originChannel": j.payload.origin_channel,
|
||||
"originChatId": j.payload.origin_chat_id,
|
||||
"originMetadata": j.payload.origin_metadata,
|
||||
},
|
||||
"state": {
|
||||
"nextRunAtMs": j.state.next_run_at_ms,
|
||||
@ -524,6 +540,9 @@ class CronService:
|
||||
delete_after_run: bool = False,
|
||||
channel_meta: dict | None = None,
|
||||
session_key: str | None = None,
|
||||
origin_channel: str | None = None,
|
||||
origin_chat_id: str | None = None,
|
||||
origin_metadata: dict | None = None,
|
||||
) -> CronJob:
|
||||
"""Add a new job."""
|
||||
_validate_schedule_for_add(schedule)
|
||||
@ -542,6 +561,9 @@ class CronService:
|
||||
to=to,
|
||||
channel_meta=channel_meta or {},
|
||||
session_key=session_key,
|
||||
origin_channel=origin_channel,
|
||||
origin_chat_id=origin_chat_id,
|
||||
origin_metadata=origin_metadata or {},
|
||||
),
|
||||
state=CronJobState(next_run_at_ms=_compute_next_run(schedule, now)),
|
||||
created_at_ms=now,
|
||||
|
||||
@ -4,54 +4,12 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from nanobot.cron.types import CronJob
|
||||
|
||||
def bound_session_inbound_context(session_key: str) -> tuple[str, str, dict[str, Any]]:
|
||||
"""Return ``(channel, chat_id, metadata)`` for a bound cron session key."""
|
||||
if ":" not in session_key:
|
||||
raise ValueError(f"bound cron session_key is invalid: {session_key!r}")
|
||||
channel, rest = session_key.split(":", 1)
|
||||
if not channel or not rest:
|
||||
raise ValueError(f"bound cron session_key is invalid: {session_key!r}")
|
||||
|
||||
metadata: dict[str, Any] = {}
|
||||
|
||||
if channel == "discord" and ":thread:" in rest:
|
||||
parent_channel_id, thread_id = rest.split(":thread:", 1)
|
||||
if parent_channel_id and thread_id:
|
||||
metadata.update({
|
||||
"context_chat_id": parent_channel_id,
|
||||
"parent_channel_id": parent_channel_id,
|
||||
"thread_id": thread_id,
|
||||
})
|
||||
return channel, thread_id, metadata
|
||||
|
||||
if channel == "feishu" and ":" in rest:
|
||||
chat_id, thread_id = rest.split(":", 1)
|
||||
if chat_id and thread_id:
|
||||
metadata.update({
|
||||
"chat_type": "group",
|
||||
"message_id": thread_id,
|
||||
"thread_id": thread_id,
|
||||
})
|
||||
return channel, chat_id, metadata
|
||||
|
||||
if channel == "slack" and ":" in rest:
|
||||
chat_id, thread_ts = rest.split(":", 1)
|
||||
if thread_ts:
|
||||
metadata["slack"] = {"thread_ts": thread_ts}
|
||||
return channel, chat_id, metadata
|
||||
|
||||
if channel == "telegram" and ":topic:" in rest:
|
||||
chat_id, thread_id = rest.split(":topic:", 1)
|
||||
if thread_id:
|
||||
metadata["message_thread_id"] = (
|
||||
int(thread_id) if thread_id.isdigit() else thread_id
|
||||
)
|
||||
return channel, chat_id, metadata
|
||||
|
||||
if channel == "dingtalk" and rest.startswith("group:"):
|
||||
parts = rest.split(":", 2)
|
||||
if len(parts) >= 2 and parts[1]:
|
||||
return channel, f"group:{parts[1]}", metadata
|
||||
|
||||
return channel, rest, metadata
|
||||
def origin_delivery_context(job: CronJob) -> tuple[str, str, dict[str, Any]]:
|
||||
"""Return ``(channel, chat_id, metadata)`` for a session-bound cron job."""
|
||||
payload = job.payload
|
||||
if not payload.origin_channel or not payload.origin_chat_id:
|
||||
raise ValueError(f"cron job {job.id} is missing origin delivery context")
|
||||
return payload.origin_channel, payload.origin_chat_id, dict(payload.origin_metadata or {})
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
"""Cron types."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -23,12 +23,15 @@ class CronPayload:
|
||||
"""What to do when the job runs."""
|
||||
kind: Literal["system_event", "agent_turn"] = "agent_turn"
|
||||
message: str = ""
|
||||
# Deliver response to channel
|
||||
# Legacy delivery fields used by pre-session-bound cron jobs.
|
||||
deliver: bool = False
|
||||
channel: str | None = None # e.g. "whatsapp"
|
||||
to: str | None = None # e.g. phone number
|
||||
channel_meta: dict = field(default_factory=dict) # channel-specific routing (e.g. Slack thread_ts)
|
||||
channel_meta: dict[str, Any] = field(default_factory=dict)
|
||||
session_key: str | None = None # original session key for correct session recording
|
||||
origin_channel: str | None = None
|
||||
origin_chat_id: str | None = None
|
||||
origin_metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -1611,6 +1611,8 @@ def test_gateway_bound_cron_runs_as_session_turn(
|
||||
payload=CronPayload(
|
||||
message="Check repository health.",
|
||||
session_key="websocket:chat-1",
|
||||
origin_channel="websocket",
|
||||
origin_chat_id="chat-1",
|
||||
),
|
||||
)
|
||||
|
||||
@ -1646,6 +1648,13 @@ def test_gateway_bound_cron_runs_as_session_turn(
|
||||
payload=CronPayload(
|
||||
message="Check the Discord thread.",
|
||||
session_key="discord:456:thread:777",
|
||||
origin_channel="discord",
|
||||
origin_chat_id="777",
|
||||
origin_metadata={
|
||||
"context_chat_id": "456",
|
||||
"parent_channel_id": "456",
|
||||
"thread_id": "777",
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
@ -1667,6 +1676,9 @@ def test_gateway_bound_cron_runs_as_session_turn(
|
||||
payload=CronPayload(
|
||||
message="Check the Telegram topic.",
|
||||
session_key="telegram:-100123:topic:42",
|
||||
origin_channel="telegram",
|
||||
origin_chat_id="-100123",
|
||||
origin_metadata={"message_thread_id": 42},
|
||||
),
|
||||
)
|
||||
|
||||
@ -1686,6 +1698,13 @@ def test_gateway_bound_cron_runs_as_session_turn(
|
||||
payload=CronPayload(
|
||||
message="Check the Feishu topic.",
|
||||
session_key="feishu:oc_abc:om_root123",
|
||||
origin_channel="feishu",
|
||||
origin_chat_id="oc_abc",
|
||||
origin_metadata={
|
||||
"chat_type": "group",
|
||||
"message_id": "om_root123",
|
||||
"thread_id": "om_root123",
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -87,6 +87,37 @@ def test_list_bound_agent_jobs_excludes_legacy_delivery_payloads(tmp_path) -> No
|
||||
assert service.list_bound_cron_jobs_for_session("websocket:chat-1") == [bound]
|
||||
|
||||
|
||||
def test_add_job_preserves_origin_delivery_context(tmp_path) -> None:
|
||||
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||
metadata = {"slack": {"thread_ts": "1234567890.123456", "channel_type": "channel"}}
|
||||
|
||||
job = service.add_job(
|
||||
name="bound thread",
|
||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||
message="hello",
|
||||
session_key="slack:C123:1234567890.123456",
|
||||
origin_channel="slack",
|
||||
origin_chat_id="C123",
|
||||
origin_metadata=metadata,
|
||||
)
|
||||
|
||||
assert job.payload.origin_channel == "slack"
|
||||
assert job.payload.origin_chat_id == "C123"
|
||||
assert job.payload.origin_metadata == metadata
|
||||
|
||||
raw = json.loads((tmp_path / "cron" / "action.jsonl").read_text(encoding="utf-8"))
|
||||
payload = raw["params"]["payload"]
|
||||
assert payload["origin_channel"] == "slack"
|
||||
assert payload["origin_chat_id"] == "C123"
|
||||
assert payload["origin_metadata"] == metadata
|
||||
|
||||
reloaded = service.get_job(job.id)
|
||||
assert reloaded is not None
|
||||
assert reloaded.payload.origin_channel == "slack"
|
||||
assert reloaded.payload.origin_chat_id == "C123"
|
||||
assert reloaded.payload.origin_metadata == metadata
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_channel_meta_and_session_key_survive_store_reload(tmp_path) -> None:
|
||||
store_path = tmp_path / "cron" / "jobs.json"
|
||||
@ -103,6 +134,9 @@ async def test_channel_meta_and_session_key_survive_store_reload(tmp_path) -> No
|
||||
to="C123",
|
||||
channel_meta=meta,
|
||||
session_key="slack:C123:1234567890.123456",
|
||||
origin_channel="slack",
|
||||
origin_chat_id="C123",
|
||||
origin_metadata=meta,
|
||||
)
|
||||
finally:
|
||||
service.stop()
|
||||
@ -111,11 +145,17 @@ async def test_channel_meta_and_session_key_survive_store_reload(tmp_path) -> No
|
||||
payload = raw["jobs"][0]["payload"]
|
||||
assert payload["channelMeta"] == meta
|
||||
assert payload["sessionKey"] == "slack:C123:1234567890.123456"
|
||||
assert payload["originChannel"] == "slack"
|
||||
assert payload["originChatId"] == "C123"
|
||||
assert payload["originMetadata"] == meta
|
||||
|
||||
reloaded = CronService(store_path).get_job(job.id)
|
||||
assert reloaded is not None
|
||||
assert reloaded.payload.channel_meta == meta
|
||||
assert reloaded.payload.session_key == "slack:C123:1234567890.123456"
|
||||
assert reloaded.payload.origin_channel == "slack"
|
||||
assert reloaded.payload.origin_chat_id == "C123"
|
||||
assert reloaded.payload.origin_metadata == meta
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -339,6 +339,9 @@ def test_add_job_binds_current_session_key(tmp_path) -> None:
|
||||
assert result.startswith("Created job")
|
||||
job = tool._cron.list_jobs()[0]
|
||||
assert job.payload.session_key == "telegram:chat-1"
|
||||
assert job.payload.origin_channel == "telegram"
|
||||
assert job.payload.origin_chat_id == "chat-1"
|
||||
assert job.payload.origin_metadata == {}
|
||||
assert job.payload.channel is None
|
||||
assert job.payload.to is None
|
||||
|
||||
@ -392,8 +395,8 @@ def test_add_job_empty_message_returns_actionable_error(tmp_path) -> None:
|
||||
assert "Retry including message=" in result
|
||||
|
||||
|
||||
def test_add_job_captures_only_session_key(tmp_path) -> None:
|
||||
"""CronTool stores the canonical session key without legacy delivery fields."""
|
||||
def test_add_job_captures_owner_and_origin_without_legacy_delivery_fields(tmp_path) -> None:
|
||||
"""CronTool stores owner/session identity separately from origin delivery context."""
|
||||
tool = _make_tool(tmp_path)
|
||||
meta = {"slack": {"thread_ts": "111.222", "channel_type": "channel"}}
|
||||
tool.set_context(RequestContext(
|
||||
@ -406,6 +409,9 @@ def test_add_job_captures_only_session_key(tmp_path) -> None:
|
||||
jobs = tool._cron.list_jobs()
|
||||
assert len(jobs) == 1
|
||||
assert jobs[0].payload.session_key == "slack:C99:111.222"
|
||||
assert jobs[0].payload.origin_channel == "slack"
|
||||
assert jobs[0].payload.origin_chat_id == "C99"
|
||||
assert jobs[0].payload.origin_metadata == meta
|
||||
assert jobs[0].payload.channel is None
|
||||
assert jobs[0].payload.to is None
|
||||
assert jobs[0].payload.channel_meta == {}
|
||||
|
||||
@ -1,45 +1,44 @@
|
||||
import pytest
|
||||
|
||||
from nanobot.cron.session_delivery import bound_session_inbound_context
|
||||
from nanobot.cron.session_delivery import origin_delivery_context
|
||||
from nanobot.cron.types import CronJob, CronPayload
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("session_key", "expected"),
|
||||
[
|
||||
("websocket:chat-1", ("websocket", "chat-1", {})),
|
||||
(
|
||||
"discord:456:thread:777",
|
||||
(
|
||||
"discord",
|
||||
"777",
|
||||
{
|
||||
"context_chat_id": "456",
|
||||
"parent_channel_id": "456",
|
||||
"thread_id": "777",
|
||||
},
|
||||
),
|
||||
def test_origin_delivery_context_uses_explicit_origin_fields() -> None:
|
||||
metadata = {
|
||||
"context_chat_id": "456",
|
||||
"parent_channel_id": "456",
|
||||
"thread_id": "777",
|
||||
}
|
||||
job = CronJob(
|
||||
id="thread-check",
|
||||
name="Thread check",
|
||||
payload=CronPayload(
|
||||
message="check",
|
||||
session_key="discord:456:thread:777",
|
||||
origin_channel="discord",
|
||||
origin_chat_id="777",
|
||||
origin_metadata=metadata,
|
||||
),
|
||||
(
|
||||
"feishu:oc_abc:om_root123",
|
||||
(
|
||||
"feishu",
|
||||
"oc_abc",
|
||||
{
|
||||
"chat_type": "group",
|
||||
"message_id": "om_root123",
|
||||
"thread_id": "om_root123",
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
channel, chat_id, returned_metadata = origin_delivery_context(job)
|
||||
|
||||
assert channel == "discord"
|
||||
assert chat_id == "777"
|
||||
assert returned_metadata == metadata
|
||||
assert returned_metadata is not metadata
|
||||
|
||||
|
||||
def test_origin_delivery_context_rejects_missing_origin_fields() -> None:
|
||||
job = CronJob(
|
||||
id="old-bound",
|
||||
name="Old bound job",
|
||||
payload=CronPayload(
|
||||
message="check",
|
||||
session_key="websocket:chat-1",
|
||||
),
|
||||
("slack:C123:1700.42", ("slack", "C123", {"slack": {"thread_ts": "1700.42"}})),
|
||||
("telegram:-100123:topic:42", ("telegram", "-100123", {"message_thread_id": 42})),
|
||||
("dingtalk:group:conv-1:user-1", ("dingtalk", "group:conv-1", {})),
|
||||
],
|
||||
)
|
||||
def test_bound_session_inbound_context(session_key, expected) -> None:
|
||||
assert bound_session_inbound_context(session_key) == expected
|
||||
)
|
||||
|
||||
|
||||
def test_bound_session_inbound_context_rejects_invalid_key() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
bound_session_inbound_context("unified")
|
||||
with pytest.raises(ValueError, match="missing origin delivery context"):
|
||||
origin_delivery_context(job)
|
||||
|
||||
@ -123,6 +123,10 @@ async def test_cron_tool_keeps_task_local_context(tmp_path) -> None:
|
||||
|
||||
jobs = tool._cron.list_jobs()
|
||||
assert {job.payload.session_key for job in jobs} == {"feishu:chat-a", "email:chat-b"}
|
||||
assert {(job.payload.origin_channel, job.payload.origin_chat_id) for job in jobs} == {
|
||||
("feishu", "chat-a"),
|
||||
("email", "chat-b"),
|
||||
}
|
||||
|
||||
|
||||
# --- Basic single-task regression tests ---
|
||||
@ -243,6 +247,8 @@ async def test_cron_tool_basic_set_context_and_execute(tmp_path) -> None:
|
||||
jobs = tool._cron.list_jobs()
|
||||
assert len(jobs) == 1
|
||||
assert jobs[0].payload.session_key == "wechat:user-789"
|
||||
assert jobs[0].payload.origin_channel == "wechat"
|
||||
assert jobs[0].payload.origin_chat_id == "user-789"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -272,6 +278,9 @@ async def test_webui_cron_tool_uses_origin_session_when_unified_enabled(tmp_path
|
||||
jobs = tool._cron.list_jobs()
|
||||
assert len(jobs) == 1
|
||||
assert jobs[0].payload.session_key == "websocket:chat-123"
|
||||
assert jobs[0].payload.origin_channel == "websocket"
|
||||
assert jobs[0].payload.origin_chat_id == "chat-123"
|
||||
assert jobs[0].payload.origin_metadata == {"webui": True}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -293,6 +302,9 @@ async def test_cron_tool_preserves_thread_scoped_session_key(tmp_path) -> None:
|
||||
jobs = tool._cron.list_jobs()
|
||||
assert len(jobs) == 1
|
||||
assert jobs[0].payload.session_key == "slack:C123:1700.42"
|
||||
assert jobs[0].payload.origin_channel == "slack"
|
||||
assert jobs[0].payload.origin_chat_id == "C123"
|
||||
assert jobs[0].payload.origin_metadata == {"slack": {"thread_ts": "1700.42"}}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user