mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 07:14:08 +00:00
refactor: store cron origin delivery context
This commit is contained in:
parent
b232a52794
commit
5ae907bc2f
@ -58,6 +58,12 @@ class CronTool(Tool, ContextAware):
|
|||||||
self._cron = cron_service
|
self._cron = cron_service
|
||||||
self._default_timezone = default_timezone
|
self._default_timezone = default_timezone
|
||||||
self._session_key: ContextVar[str] = ContextVar("cron_session_key", default="")
|
self._session_key: ContextVar[str] = ContextVar("cron_session_key", default="")
|
||||||
|
self._origin_channel: ContextVar[str] = ContextVar("cron_origin_channel", default="")
|
||||||
|
self._origin_chat_id: ContextVar[str] = ContextVar("cron_origin_chat_id", default="")
|
||||||
|
self._origin_metadata: ContextVar[dict[str, Any] | None] = ContextVar(
|
||||||
|
"cron_origin_metadata",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False)
|
self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -74,6 +80,9 @@ class CronTool(Tool, ContextAware):
|
|||||||
self._session_key.set(
|
self._session_key.set(
|
||||||
raw_key if ctx.session_key == UNIFIED_SESSION_KEY else (ctx.session_key or "")
|
raw_key if ctx.session_key == UNIFIED_SESSION_KEY else (ctx.session_key or "")
|
||||||
)
|
)
|
||||||
|
self._origin_channel.set(ctx.channel or "")
|
||||||
|
self._origin_chat_id.set(ctx.chat_id or "")
|
||||||
|
self._origin_metadata.set(dict(ctx.metadata or {}))
|
||||||
|
|
||||||
def set_cron_context(self, active: bool):
|
def set_cron_context(self, active: bool):
|
||||||
"""Mark whether the tool is executing inside a cron job callback."""
|
"""Mark whether the tool is executing inside a cron job callback."""
|
||||||
@ -165,6 +174,10 @@ class CronTool(Tool, ContextAware):
|
|||||||
session_key = self._session_key.get()
|
session_key = self._session_key.get()
|
||||||
if not session_key:
|
if not session_key:
|
||||||
return "Error: scheduled cron jobs must be created from a chat session"
|
return "Error: scheduled cron jobs must be created from a chat session"
|
||||||
|
origin_channel = self._origin_channel.get()
|
||||||
|
origin_chat_id = self._origin_chat_id.get()
|
||||||
|
if not origin_channel or not origin_chat_id:
|
||||||
|
return "Error: scheduled cron jobs must be created from a chat session"
|
||||||
if tz and not cron_expr:
|
if tz and not cron_expr:
|
||||||
return "Error: tz can only be used with cron_expr"
|
return "Error: tz can only be used with cron_expr"
|
||||||
if tz:
|
if tz:
|
||||||
@ -203,6 +216,9 @@ class CronTool(Tool, ContextAware):
|
|||||||
message=message,
|
message=message,
|
||||||
delete_after_run=delete_after,
|
delete_after_run=delete_after,
|
||||||
session_key=session_key,
|
session_key=session_key,
|
||||||
|
origin_channel=origin_channel,
|
||||||
|
origin_chat_id=origin_chat_id,
|
||||||
|
origin_metadata=dict(self._origin_metadata.get() or {}),
|
||||||
)
|
)
|
||||||
return f"Created job '{job.name}' (id: {job.id})"
|
return f"Created job '{job.name}' (id: {job.id})"
|
||||||
|
|
||||||
|
|||||||
@ -980,7 +980,7 @@ def _run_gateway(
|
|||||||
from nanobot.bus.runtime_events import RuntimeEventBus
|
from nanobot.bus.runtime_events import RuntimeEventBus
|
||||||
from nanobot.channels.manager import ChannelManager
|
from nanobot.channels.manager import ChannelManager
|
||||||
from nanobot.cron.service import CronService
|
from nanobot.cron.service import CronService
|
||||||
from nanobot.cron.session_delivery import bound_session_inbound_context
|
from nanobot.cron.session_delivery import origin_delivery_context
|
||||||
from nanobot.cron.session_turns import (
|
from nanobot.cron.session_turns import (
|
||||||
CRON_DEFER_UNTIL_IDLE_META,
|
CRON_DEFER_UNTIL_IDLE_META,
|
||||||
CRON_TRIGGER_META,
|
CRON_TRIGGER_META,
|
||||||
@ -1046,12 +1046,12 @@ def _run_gateway(
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _bound_session_delivery_context(
|
def _bound_session_delivery_context(
|
||||||
session_key: str,
|
job: CronJob,
|
||||||
*,
|
*,
|
||||||
turn_seed: str,
|
turn_seed: str,
|
||||||
source_label: str | None,
|
source_label: str | None,
|
||||||
) -> tuple[str, str, dict[str, Any]]:
|
) -> tuple[str, str, dict[str, Any]]:
|
||||||
channel, chat_id, metadata = bound_session_inbound_context(session_key)
|
channel, chat_id, metadata = origin_delivery_context(job)
|
||||||
|
|
||||||
if channel == "websocket":
|
if channel == "websocket":
|
||||||
metadata["webui"] = True
|
metadata["webui"] = True
|
||||||
@ -1086,7 +1086,7 @@ def _run_gateway(
|
|||||||
prompt_ref = _cron_prompt_ref(prompt)
|
prompt_ref = _cron_prompt_ref(prompt)
|
||||||
run_id = f"{job.id}:{int(time.time() * 1000)}:{uuid.uuid4().hex[:8]}"
|
run_id = f"{job.id}:{int(time.time() * 1000)}:{uuid.uuid4().hex[:8]}"
|
||||||
channel, chat_id, metadata = _bound_session_delivery_context(
|
channel, chat_id, metadata = _bound_session_delivery_context(
|
||||||
session_key,
|
job,
|
||||||
turn_seed=f"cron:{job.id}",
|
turn_seed=f"cron:{job.id}",
|
||||||
source_label=job.name,
|
source_label=job.name,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -138,6 +138,19 @@ class CronService:
|
|||||||
or {}
|
or {}
|
||||||
),
|
),
|
||||||
session_key=j["payload"].get("sessionKey") or j["payload"].get("session_key"),
|
session_key=j["payload"].get("sessionKey") or j["payload"].get("session_key"),
|
||||||
|
origin_channel=(
|
||||||
|
j["payload"].get("originChannel")
|
||||||
|
or j["payload"].get("origin_channel")
|
||||||
|
),
|
||||||
|
origin_chat_id=(
|
||||||
|
j["payload"].get("originChatId")
|
||||||
|
or j["payload"].get("origin_chat_id")
|
||||||
|
),
|
||||||
|
origin_metadata=(
|
||||||
|
j["payload"].get("originMetadata")
|
||||||
|
or j["payload"].get("origin_metadata")
|
||||||
|
or {}
|
||||||
|
),
|
||||||
),
|
),
|
||||||
state=CronJobState(
|
state=CronJobState(
|
||||||
next_run_at_ms=j.get("state", {}).get("nextRunAtMs"),
|
next_run_at_ms=j.get("state", {}).get("nextRunAtMs"),
|
||||||
@ -268,6 +281,9 @@ class CronService:
|
|||||||
"to": j.payload.to,
|
"to": j.payload.to,
|
||||||
"channelMeta": j.payload.channel_meta,
|
"channelMeta": j.payload.channel_meta,
|
||||||
"sessionKey": j.payload.session_key,
|
"sessionKey": j.payload.session_key,
|
||||||
|
"originChannel": j.payload.origin_channel,
|
||||||
|
"originChatId": j.payload.origin_chat_id,
|
||||||
|
"originMetadata": j.payload.origin_metadata,
|
||||||
},
|
},
|
||||||
"state": {
|
"state": {
|
||||||
"nextRunAtMs": j.state.next_run_at_ms,
|
"nextRunAtMs": j.state.next_run_at_ms,
|
||||||
@ -524,6 +540,9 @@ class CronService:
|
|||||||
delete_after_run: bool = False,
|
delete_after_run: bool = False,
|
||||||
channel_meta: dict | None = None,
|
channel_meta: dict | None = None,
|
||||||
session_key: str | None = None,
|
session_key: str | None = None,
|
||||||
|
origin_channel: str | None = None,
|
||||||
|
origin_chat_id: str | None = None,
|
||||||
|
origin_metadata: dict | None = None,
|
||||||
) -> CronJob:
|
) -> CronJob:
|
||||||
"""Add a new job."""
|
"""Add a new job."""
|
||||||
_validate_schedule_for_add(schedule)
|
_validate_schedule_for_add(schedule)
|
||||||
@ -542,6 +561,9 @@ class CronService:
|
|||||||
to=to,
|
to=to,
|
||||||
channel_meta=channel_meta or {},
|
channel_meta=channel_meta or {},
|
||||||
session_key=session_key,
|
session_key=session_key,
|
||||||
|
origin_channel=origin_channel,
|
||||||
|
origin_chat_id=origin_chat_id,
|
||||||
|
origin_metadata=origin_metadata or {},
|
||||||
),
|
),
|
||||||
state=CronJobState(next_run_at_ms=_compute_next_run(schedule, now)),
|
state=CronJobState(next_run_at_ms=_compute_next_run(schedule, now)),
|
||||||
created_at_ms=now,
|
created_at_ms=now,
|
||||||
|
|||||||
@ -4,54 +4,12 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from nanobot.cron.types import CronJob
|
||||||
|
|
||||||
def bound_session_inbound_context(session_key: str) -> tuple[str, str, dict[str, Any]]:
|
|
||||||
"""Return ``(channel, chat_id, metadata)`` for a bound cron session key."""
|
|
||||||
if ":" not in session_key:
|
|
||||||
raise ValueError(f"bound cron session_key is invalid: {session_key!r}")
|
|
||||||
channel, rest = session_key.split(":", 1)
|
|
||||||
if not channel or not rest:
|
|
||||||
raise ValueError(f"bound cron session_key is invalid: {session_key!r}")
|
|
||||||
|
|
||||||
metadata: dict[str, Any] = {}
|
def origin_delivery_context(job: CronJob) -> tuple[str, str, dict[str, Any]]:
|
||||||
|
"""Return ``(channel, chat_id, metadata)`` for a session-bound cron job."""
|
||||||
if channel == "discord" and ":thread:" in rest:
|
payload = job.payload
|
||||||
parent_channel_id, thread_id = rest.split(":thread:", 1)
|
if not payload.origin_channel or not payload.origin_chat_id:
|
||||||
if parent_channel_id and thread_id:
|
raise ValueError(f"cron job {job.id} is missing origin delivery context")
|
||||||
metadata.update({
|
return payload.origin_channel, payload.origin_chat_id, dict(payload.origin_metadata or {})
|
||||||
"context_chat_id": parent_channel_id,
|
|
||||||
"parent_channel_id": parent_channel_id,
|
|
||||||
"thread_id": thread_id,
|
|
||||||
})
|
|
||||||
return channel, thread_id, metadata
|
|
||||||
|
|
||||||
if channel == "feishu" and ":" in rest:
|
|
||||||
chat_id, thread_id = rest.split(":", 1)
|
|
||||||
if chat_id and thread_id:
|
|
||||||
metadata.update({
|
|
||||||
"chat_type": "group",
|
|
||||||
"message_id": thread_id,
|
|
||||||
"thread_id": thread_id,
|
|
||||||
})
|
|
||||||
return channel, chat_id, metadata
|
|
||||||
|
|
||||||
if channel == "slack" and ":" in rest:
|
|
||||||
chat_id, thread_ts = rest.split(":", 1)
|
|
||||||
if thread_ts:
|
|
||||||
metadata["slack"] = {"thread_ts": thread_ts}
|
|
||||||
return channel, chat_id, metadata
|
|
||||||
|
|
||||||
if channel == "telegram" and ":topic:" in rest:
|
|
||||||
chat_id, thread_id = rest.split(":topic:", 1)
|
|
||||||
if thread_id:
|
|
||||||
metadata["message_thread_id"] = (
|
|
||||||
int(thread_id) if thread_id.isdigit() else thread_id
|
|
||||||
)
|
|
||||||
return channel, chat_id, metadata
|
|
||||||
|
|
||||||
if channel == "dingtalk" and rest.startswith("group:"):
|
|
||||||
parts = rest.split(":", 2)
|
|
||||||
if len(parts) >= 2 and parts[1]:
|
|
||||||
return channel, f"group:{parts[1]}", metadata
|
|
||||||
|
|
||||||
return channel, rest, metadata
|
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
"""Cron types."""
|
"""Cron types."""
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -23,12 +23,15 @@ class CronPayload:
|
|||||||
"""What to do when the job runs."""
|
"""What to do when the job runs."""
|
||||||
kind: Literal["system_event", "agent_turn"] = "agent_turn"
|
kind: Literal["system_event", "agent_turn"] = "agent_turn"
|
||||||
message: str = ""
|
message: str = ""
|
||||||
# Deliver response to channel
|
# Legacy delivery fields used by pre-session-bound cron jobs.
|
||||||
deliver: bool = False
|
deliver: bool = False
|
||||||
channel: str | None = None # e.g. "whatsapp"
|
channel: str | None = None # e.g. "whatsapp"
|
||||||
to: str | None = None # e.g. phone number
|
to: str | None = None # e.g. phone number
|
||||||
channel_meta: dict = field(default_factory=dict) # channel-specific routing (e.g. Slack thread_ts)
|
channel_meta: dict[str, Any] = field(default_factory=dict)
|
||||||
session_key: str | None = None # original session key for correct session recording
|
session_key: str | None = None # original session key for correct session recording
|
||||||
|
origin_channel: str | None = None
|
||||||
|
origin_chat_id: str | None = None
|
||||||
|
origin_metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@ -1611,6 +1611,8 @@ def test_gateway_bound_cron_runs_as_session_turn(
|
|||||||
payload=CronPayload(
|
payload=CronPayload(
|
||||||
message="Check repository health.",
|
message="Check repository health.",
|
||||||
session_key="websocket:chat-1",
|
session_key="websocket:chat-1",
|
||||||
|
origin_channel="websocket",
|
||||||
|
origin_chat_id="chat-1",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1646,6 +1648,13 @@ def test_gateway_bound_cron_runs_as_session_turn(
|
|||||||
payload=CronPayload(
|
payload=CronPayload(
|
||||||
message="Check the Discord thread.",
|
message="Check the Discord thread.",
|
||||||
session_key="discord:456:thread:777",
|
session_key="discord:456:thread:777",
|
||||||
|
origin_channel="discord",
|
||||||
|
origin_chat_id="777",
|
||||||
|
origin_metadata={
|
||||||
|
"context_chat_id": "456",
|
||||||
|
"parent_channel_id": "456",
|
||||||
|
"thread_id": "777",
|
||||||
|
},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1667,6 +1676,9 @@ def test_gateway_bound_cron_runs_as_session_turn(
|
|||||||
payload=CronPayload(
|
payload=CronPayload(
|
||||||
message="Check the Telegram topic.",
|
message="Check the Telegram topic.",
|
||||||
session_key="telegram:-100123:topic:42",
|
session_key="telegram:-100123:topic:42",
|
||||||
|
origin_channel="telegram",
|
||||||
|
origin_chat_id="-100123",
|
||||||
|
origin_metadata={"message_thread_id": 42},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1686,6 +1698,13 @@ def test_gateway_bound_cron_runs_as_session_turn(
|
|||||||
payload=CronPayload(
|
payload=CronPayload(
|
||||||
message="Check the Feishu topic.",
|
message="Check the Feishu topic.",
|
||||||
session_key="feishu:oc_abc:om_root123",
|
session_key="feishu:oc_abc:om_root123",
|
||||||
|
origin_channel="feishu",
|
||||||
|
origin_chat_id="oc_abc",
|
||||||
|
origin_metadata={
|
||||||
|
"chat_type": "group",
|
||||||
|
"message_id": "om_root123",
|
||||||
|
"thread_id": "om_root123",
|
||||||
|
},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -87,6 +87,37 @@ def test_list_bound_agent_jobs_excludes_legacy_delivery_payloads(tmp_path) -> No
|
|||||||
assert service.list_bound_cron_jobs_for_session("websocket:chat-1") == [bound]
|
assert service.list_bound_cron_jobs_for_session("websocket:chat-1") == [bound]
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_job_preserves_origin_delivery_context(tmp_path) -> None:
|
||||||
|
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||||
|
metadata = {"slack": {"thread_ts": "1234567890.123456", "channel_type": "channel"}}
|
||||||
|
|
||||||
|
job = service.add_job(
|
||||||
|
name="bound thread",
|
||||||
|
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||||
|
message="hello",
|
||||||
|
session_key="slack:C123:1234567890.123456",
|
||||||
|
origin_channel="slack",
|
||||||
|
origin_chat_id="C123",
|
||||||
|
origin_metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert job.payload.origin_channel == "slack"
|
||||||
|
assert job.payload.origin_chat_id == "C123"
|
||||||
|
assert job.payload.origin_metadata == metadata
|
||||||
|
|
||||||
|
raw = json.loads((tmp_path / "cron" / "action.jsonl").read_text(encoding="utf-8"))
|
||||||
|
payload = raw["params"]["payload"]
|
||||||
|
assert payload["origin_channel"] == "slack"
|
||||||
|
assert payload["origin_chat_id"] == "C123"
|
||||||
|
assert payload["origin_metadata"] == metadata
|
||||||
|
|
||||||
|
reloaded = service.get_job(job.id)
|
||||||
|
assert reloaded is not None
|
||||||
|
assert reloaded.payload.origin_channel == "slack"
|
||||||
|
assert reloaded.payload.origin_chat_id == "C123"
|
||||||
|
assert reloaded.payload.origin_metadata == metadata
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_channel_meta_and_session_key_survive_store_reload(tmp_path) -> None:
|
async def test_channel_meta_and_session_key_survive_store_reload(tmp_path) -> None:
|
||||||
store_path = tmp_path / "cron" / "jobs.json"
|
store_path = tmp_path / "cron" / "jobs.json"
|
||||||
@ -103,6 +134,9 @@ async def test_channel_meta_and_session_key_survive_store_reload(tmp_path) -> No
|
|||||||
to="C123",
|
to="C123",
|
||||||
channel_meta=meta,
|
channel_meta=meta,
|
||||||
session_key="slack:C123:1234567890.123456",
|
session_key="slack:C123:1234567890.123456",
|
||||||
|
origin_channel="slack",
|
||||||
|
origin_chat_id="C123",
|
||||||
|
origin_metadata=meta,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
service.stop()
|
service.stop()
|
||||||
@ -111,11 +145,17 @@ async def test_channel_meta_and_session_key_survive_store_reload(tmp_path) -> No
|
|||||||
payload = raw["jobs"][0]["payload"]
|
payload = raw["jobs"][0]["payload"]
|
||||||
assert payload["channelMeta"] == meta
|
assert payload["channelMeta"] == meta
|
||||||
assert payload["sessionKey"] == "slack:C123:1234567890.123456"
|
assert payload["sessionKey"] == "slack:C123:1234567890.123456"
|
||||||
|
assert payload["originChannel"] == "slack"
|
||||||
|
assert payload["originChatId"] == "C123"
|
||||||
|
assert payload["originMetadata"] == meta
|
||||||
|
|
||||||
reloaded = CronService(store_path).get_job(job.id)
|
reloaded = CronService(store_path).get_job(job.id)
|
||||||
assert reloaded is not None
|
assert reloaded is not None
|
||||||
assert reloaded.payload.channel_meta == meta
|
assert reloaded.payload.channel_meta == meta
|
||||||
assert reloaded.payload.session_key == "slack:C123:1234567890.123456"
|
assert reloaded.payload.session_key == "slack:C123:1234567890.123456"
|
||||||
|
assert reloaded.payload.origin_channel == "slack"
|
||||||
|
assert reloaded.payload.origin_chat_id == "C123"
|
||||||
|
assert reloaded.payload.origin_metadata == meta
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@ -339,6 +339,9 @@ def test_add_job_binds_current_session_key(tmp_path) -> None:
|
|||||||
assert result.startswith("Created job")
|
assert result.startswith("Created job")
|
||||||
job = tool._cron.list_jobs()[0]
|
job = tool._cron.list_jobs()[0]
|
||||||
assert job.payload.session_key == "telegram:chat-1"
|
assert job.payload.session_key == "telegram:chat-1"
|
||||||
|
assert job.payload.origin_channel == "telegram"
|
||||||
|
assert job.payload.origin_chat_id == "chat-1"
|
||||||
|
assert job.payload.origin_metadata == {}
|
||||||
assert job.payload.channel is None
|
assert job.payload.channel is None
|
||||||
assert job.payload.to is None
|
assert job.payload.to is None
|
||||||
|
|
||||||
@ -392,8 +395,8 @@ def test_add_job_empty_message_returns_actionable_error(tmp_path) -> None:
|
|||||||
assert "Retry including message=" in result
|
assert "Retry including message=" in result
|
||||||
|
|
||||||
|
|
||||||
def test_add_job_captures_only_session_key(tmp_path) -> None:
|
def test_add_job_captures_owner_and_origin_without_legacy_delivery_fields(tmp_path) -> None:
|
||||||
"""CronTool stores the canonical session key without legacy delivery fields."""
|
"""CronTool stores owner/session identity separately from origin delivery context."""
|
||||||
tool = _make_tool(tmp_path)
|
tool = _make_tool(tmp_path)
|
||||||
meta = {"slack": {"thread_ts": "111.222", "channel_type": "channel"}}
|
meta = {"slack": {"thread_ts": "111.222", "channel_type": "channel"}}
|
||||||
tool.set_context(RequestContext(
|
tool.set_context(RequestContext(
|
||||||
@ -406,6 +409,9 @@ def test_add_job_captures_only_session_key(tmp_path) -> None:
|
|||||||
jobs = tool._cron.list_jobs()
|
jobs = tool._cron.list_jobs()
|
||||||
assert len(jobs) == 1
|
assert len(jobs) == 1
|
||||||
assert jobs[0].payload.session_key == "slack:C99:111.222"
|
assert jobs[0].payload.session_key == "slack:C99:111.222"
|
||||||
|
assert jobs[0].payload.origin_channel == "slack"
|
||||||
|
assert jobs[0].payload.origin_chat_id == "C99"
|
||||||
|
assert jobs[0].payload.origin_metadata == meta
|
||||||
assert jobs[0].payload.channel is None
|
assert jobs[0].payload.channel is None
|
||||||
assert jobs[0].payload.to is None
|
assert jobs[0].payload.to is None
|
||||||
assert jobs[0].payload.channel_meta == {}
|
assert jobs[0].payload.channel_meta == {}
|
||||||
|
|||||||
@ -1,45 +1,44 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from nanobot.cron.session_delivery import bound_session_inbound_context
|
from nanobot.cron.session_delivery import origin_delivery_context
|
||||||
|
from nanobot.cron.types import CronJob, CronPayload
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
def test_origin_delivery_context_uses_explicit_origin_fields() -> None:
|
||||||
("session_key", "expected"),
|
metadata = {
|
||||||
[
|
"context_chat_id": "456",
|
||||||
("websocket:chat-1", ("websocket", "chat-1", {})),
|
"parent_channel_id": "456",
|
||||||
(
|
"thread_id": "777",
|
||||||
"discord:456:thread:777",
|
}
|
||||||
(
|
job = CronJob(
|
||||||
"discord",
|
id="thread-check",
|
||||||
"777",
|
name="Thread check",
|
||||||
{
|
payload=CronPayload(
|
||||||
"context_chat_id": "456",
|
message="check",
|
||||||
"parent_channel_id": "456",
|
session_key="discord:456:thread:777",
|
||||||
"thread_id": "777",
|
origin_channel="discord",
|
||||||
},
|
origin_chat_id="777",
|
||||||
),
|
origin_metadata=metadata,
|
||||||
),
|
),
|
||||||
(
|
)
|
||||||
"feishu:oc_abc:om_root123",
|
|
||||||
(
|
channel, chat_id, returned_metadata = origin_delivery_context(job)
|
||||||
"feishu",
|
|
||||||
"oc_abc",
|
assert channel == "discord"
|
||||||
{
|
assert chat_id == "777"
|
||||||
"chat_type": "group",
|
assert returned_metadata == metadata
|
||||||
"message_id": "om_root123",
|
assert returned_metadata is not metadata
|
||||||
"thread_id": "om_root123",
|
|
||||||
},
|
|
||||||
),
|
def test_origin_delivery_context_rejects_missing_origin_fields() -> None:
|
||||||
|
job = CronJob(
|
||||||
|
id="old-bound",
|
||||||
|
name="Old bound job",
|
||||||
|
payload=CronPayload(
|
||||||
|
message="check",
|
||||||
|
session_key="websocket:chat-1",
|
||||||
),
|
),
|
||||||
("slack:C123:1700.42", ("slack", "C123", {"slack": {"thread_ts": "1700.42"}})),
|
)
|
||||||
("telegram:-100123:topic:42", ("telegram", "-100123", {"message_thread_id": 42})),
|
|
||||||
("dingtalk:group:conv-1:user-1", ("dingtalk", "group:conv-1", {})),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_bound_session_inbound_context(session_key, expected) -> None:
|
|
||||||
assert bound_session_inbound_context(session_key) == expected
|
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="missing origin delivery context"):
|
||||||
def test_bound_session_inbound_context_rejects_invalid_key() -> None:
|
origin_delivery_context(job)
|
||||||
with pytest.raises(ValueError):
|
|
||||||
bound_session_inbound_context("unified")
|
|
||||||
|
|||||||
@ -123,6 +123,10 @@ async def test_cron_tool_keeps_task_local_context(tmp_path) -> None:
|
|||||||
|
|
||||||
jobs = tool._cron.list_jobs()
|
jobs = tool._cron.list_jobs()
|
||||||
assert {job.payload.session_key for job in jobs} == {"feishu:chat-a", "email:chat-b"}
|
assert {job.payload.session_key for job in jobs} == {"feishu:chat-a", "email:chat-b"}
|
||||||
|
assert {(job.payload.origin_channel, job.payload.origin_chat_id) for job in jobs} == {
|
||||||
|
("feishu", "chat-a"),
|
||||||
|
("email", "chat-b"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# --- Basic single-task regression tests ---
|
# --- Basic single-task regression tests ---
|
||||||
@ -243,6 +247,8 @@ async def test_cron_tool_basic_set_context_and_execute(tmp_path) -> None:
|
|||||||
jobs = tool._cron.list_jobs()
|
jobs = tool._cron.list_jobs()
|
||||||
assert len(jobs) == 1
|
assert len(jobs) == 1
|
||||||
assert jobs[0].payload.session_key == "wechat:user-789"
|
assert jobs[0].payload.session_key == "wechat:user-789"
|
||||||
|
assert jobs[0].payload.origin_channel == "wechat"
|
||||||
|
assert jobs[0].payload.origin_chat_id == "user-789"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -272,6 +278,9 @@ async def test_webui_cron_tool_uses_origin_session_when_unified_enabled(tmp_path
|
|||||||
jobs = tool._cron.list_jobs()
|
jobs = tool._cron.list_jobs()
|
||||||
assert len(jobs) == 1
|
assert len(jobs) == 1
|
||||||
assert jobs[0].payload.session_key == "websocket:chat-123"
|
assert jobs[0].payload.session_key == "websocket:chat-123"
|
||||||
|
assert jobs[0].payload.origin_channel == "websocket"
|
||||||
|
assert jobs[0].payload.origin_chat_id == "chat-123"
|
||||||
|
assert jobs[0].payload.origin_metadata == {"webui": True}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -293,6 +302,9 @@ async def test_cron_tool_preserves_thread_scoped_session_key(tmp_path) -> None:
|
|||||||
jobs = tool._cron.list_jobs()
|
jobs = tool._cron.list_jobs()
|
||||||
assert len(jobs) == 1
|
assert len(jobs) == 1
|
||||||
assert jobs[0].payload.session_key == "slack:C123:1700.42"
|
assert jobs[0].payload.session_key == "slack:C123:1700.42"
|
||||||
|
assert jobs[0].payload.origin_channel == "slack"
|
||||||
|
assert jobs[0].payload.origin_chat_id == "C123"
|
||||||
|
assert jobs[0].payload.origin_metadata == {"slack": {"thread_ts": "1700.42"}}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user