Merge branch 'main' into fix/session-history-timestamps

Made-with: Cursor
This commit is contained in:
Xubin Ren 2026-04-26 18:13:11 +00:00
commit 4a4ba1efc1
16 changed files with 376 additions and 20 deletions

View File

@ -87,6 +87,11 @@ ruff check nanobot/
ruff format nanobot/ ruff format nanobot/
``` ```
## Contribution License
By submitting a contribution, you confirm that you have the right to submit it
and agree that it will be licensed under the project's MIT License.
## Code Style ## Code Style
We care about more than passing lint. We want nanobot to stay small, calm, and readable. We care about more than passing lint. We want nanobot to stay small, calm, and readable.

View File

@ -1,6 +1,6 @@
MIT License MIT License
Copyright (c) 2025 nanobot contributors Copyright (c) 2025-present Xubin Ren and the nanobot contributors
Permission is hereby granted, free of charge, to any person obtaining a copy Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal of this software and associated documentation files (the "Software"), to deal

View File

@ -282,6 +282,10 @@ PRs welcome! The codebase is intentionally small and readable. 🤗
- **More integrations** — Calendar and more - **More integrations** — Calendar and more
- **Self-improvement** — Learn from feedback and mistakes - **Self-improvement** — Learn from feedback and mistakes
## Contact
This project was started by [Xubin Ren](https://github.com/re-bin) as a personal open-source project and continues to be maintained in an individual capacity using personal resources, with contributions from the open-source community. Feel free to contact [xubinrencs@gmail.com](mailto:xubinrencs@gmail.com) for questions, ideas, or collaboration.
### Contributors ### Contributors
<a href="https://github.com/HKUDS/nanobot/graphs/contributors"> <a href="https://github.com/HKUDS/nanobot/graphs/contributors">

View File

@ -434,7 +434,7 @@ Uses **Socket Mode** — no public URL required.
**2. Configure the app** **2. Configure the app**
- **Socket Mode**: Toggle ON → Generate an **App-Level Token** with `connections:write` scope → copy it (`xapp-...`) - **Socket Mode**: Toggle ON → Generate an **App-Level Token** with `connections:write` scope → copy it (`xapp-...`)
- **OAuth & Permissions**: Add bot scopes: `chat:write`, `reactions:write`, `app_mentions:read` - **OAuth & Permissions**: Add bot scopes: `chat:write`, `reactions:write`, `app_mentions:read`, `channels:history`, `groups:history`, `im:history`, `mpim:history`
- **Event Subscriptions**: Toggle ON → Subscribe to bot events: `message.im`, `message.channels`, `app_mention` → Save Changes - **Event Subscriptions**: Toggle ON → Subscribe to bot events: `message.im`, `message.channels`, `app_mention` → Save Changes
- **App Home**: Scroll to **Show Tabs** → Enable **Messages Tab** → Check **"Allow users to send Slash commands and messages from the messages tab"** - **App Home**: Scroll to **Show Tabs** → Enable **Messages Tab** → Check **"Allow users to send Slash commands and messages from the messages tab"**
- **Install App**: Click **Install to Workspace** → Authorize → copy the **Bot Token** (`xoxb-...`) - **Install App**: Click **Install to Workspace** → Authorize → copy the **Bot Token** (`xoxb-...`)

View File

@ -76,6 +76,8 @@ class _LoopHook(AgentHook):
channel: str = "cli", channel: str = "cli",
chat_id: str = "direct", chat_id: str = "direct",
message_id: str | None = None, message_id: str | None = None,
metadata: dict[str, Any] | None = None,
session_key: str | None = None,
) -> None: ) -> None:
super().__init__(reraise=True) super().__init__(reraise=True)
self._loop = agent_loop self._loop = agent_loop
@ -85,6 +87,8 @@ class _LoopHook(AgentHook):
self._channel = channel self._channel = channel
self._chat_id = chat_id self._chat_id = chat_id
self._message_id = message_id self._message_id = message_id
self._metadata = metadata or {}
self._session_key = session_key
self._stream_buf = "" self._stream_buf = ""
def wants_streaming(self) -> bool: def wants_streaming(self) -> bool:
@ -127,7 +131,13 @@ class _LoopHook(AgentHook):
for tc in context.tool_calls: for tc in context.tool_calls:
args_str = json.dumps(tc.arguments, ensure_ascii=False) args_str = json.dumps(tc.arguments, ensure_ascii=False)
logger.info("Tool call: {}({})", tc.name, args_str[:200]) logger.info("Tool call: {}({})", tc.name, args_str[:200])
self._loop._set_tool_context(self._channel, self._chat_id, self._message_id) self._loop._set_tool_context(
self._channel,
self._chat_id,
self._message_id,
self._metadata,
session_key=self._session_key,
)
async def after_iteration(self, context: AgentHookContext) -> None: async def after_iteration(self, context: AgentHookContext) -> None:
if ( if (
@ -387,18 +397,24 @@ class AgentLoop:
finally: finally:
self._mcp_connecting = False self._mcp_connecting = False
def _set_tool_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None: def _set_tool_context(
self, channel: str, chat_id: str,
message_id: str | None = None, metadata: dict | None = None,
session_key: str | None = None,
) -> None:
"""Update context for all tools that need routing info.""" """Update context for all tools that need routing info."""
# Compute the effective session key (accounts for unified sessions)
# so that subagent results route to the correct pending queue.
effective_key = UNIFIED_SESSION_KEY if self._unified_session else f"{channel}:{chat_id}" effective_key = UNIFIED_SESSION_KEY if self._unified_session else f"{channel}:{chat_id}"
for name in ("message", "spawn", "cron", "my"): for name in ("message", "spawn", "cron", "my"):
if tool := self.tools.get(name): if tool := self.tools.get(name):
if hasattr(tool, "set_context"): if hasattr(tool, "set_context"):
if name == "spawn": if name == "spawn":
tool.set_context(channel, chat_id, effective_key=effective_key) tool.set_context(channel, chat_id, effective_key=effective_key)
elif name == "cron":
tool.set_context(channel, chat_id, metadata=metadata, session_key=session_key)
elif name == "message":
tool.set_context(channel, chat_id, message_id, metadata=metadata)
else: else:
tool.set_context(channel, chat_id, *([message_id] if name == "message" else [])) tool.set_context(channel, chat_id)
@staticmethod @staticmethod
def _strip_think(text: str | None) -> str | None: def _strip_think(text: str | None) -> str | None:
@ -464,6 +480,8 @@ class AgentLoop:
channel: str = "cli", channel: str = "cli",
chat_id: str = "direct", chat_id: str = "direct",
message_id: str | None = None, message_id: str | None = None,
metadata: dict[str, Any] | None = None,
session_key: str | None = None,
pending_queue: asyncio.Queue | None = None, pending_queue: asyncio.Queue | None = None,
) -> tuple[str | None, list[str], list[dict], str, bool]: ) -> tuple[str | None, list[str], list[dict], str, bool]:
"""Run the agent iteration loop. """Run the agent iteration loop.
@ -483,6 +501,8 @@ class AgentLoop:
channel=channel, channel=channel,
chat_id=chat_id, chat_id=chat_id,
message_id=message_id, message_id=message_id,
metadata=metadata,
session_key=session_key,
) )
hook: AgentHook = ( hook: AgentHook = (
CompositeHook([loop_hook] + self._extra_hooks) if self._extra_hooks else loop_hook CompositeHook([loop_hook] + self._extra_hooks) if self._extra_hooks else loop_hook
@ -831,7 +851,10 @@ class AgentLoop:
is_subagent = msg.sender_id == "subagent" is_subagent = msg.sender_id == "subagent"
if is_subagent and self._persist_subagent_followup(session, msg): if is_subagent and self._persist_subagent_followup(session, msg):
self.sessions.save(session) self.sessions.save(session)
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id")) self._set_tool_context(
channel, chat_id, msg.metadata.get("message_id"),
msg.metadata, session_key=key,
)
history = session.get_history(max_messages=0, include_timestamps=True) history = session.get_history(max_messages=0, include_timestamps=True)
current_role = "assistant" if is_subagent else "user" current_role = "assistant" if is_subagent else "user"
@ -848,6 +871,8 @@ class AgentLoop:
final_content, _, all_msgs, stop_reason, _ = await self._run_agent_loop( final_content, _, all_msgs, stop_reason, _ = await self._run_agent_loop(
messages, session=session, channel=channel, chat_id=chat_id, messages, session=session, channel=channel, chat_id=chat_id,
message_id=msg.metadata.get("message_id"), message_id=msg.metadata.get("message_id"),
metadata=msg.metadata,
session_key=key,
pending_queue=pending_queue, pending_queue=pending_queue,
) )
self._save_turn(session, all_msgs, 1 + len(history)) self._save_turn(session, all_msgs, 1 + len(history))
@ -896,7 +921,10 @@ class AgentLoop:
session_summary=pending, session_summary=pending,
) )
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id")) self._set_tool_context(
msg.channel, msg.chat_id, msg.metadata.get("message_id"),
msg.metadata, session_key=key,
)
if message_tool := self.tools.get("message"): if message_tool := self.tools.get("message"):
if isinstance(message_tool, MessageTool): if isinstance(message_tool, MessageTool):
message_tool.start_turn() message_tool.start_turn()
@ -978,6 +1006,8 @@ class AgentLoop:
channel=msg.channel, channel=msg.channel,
chat_id=msg.chat_id, chat_id=msg.chat_id,
message_id=msg.metadata.get("message_id"), message_id=msg.metadata.get("message_id"),
metadata=msg.metadata,
session_key=key,
pending_queue=pending_queue, pending_queue=pending_queue,
) )

View File

@ -60,12 +60,19 @@ class CronTool(Tool):
self._default_timezone = default_timezone self._default_timezone = default_timezone
self._channel: ContextVar[str] = ContextVar("cron_channel", default="") self._channel: ContextVar[str] = ContextVar("cron_channel", default="")
self._chat_id: ContextVar[str] = ContextVar("cron_chat_id", default="") self._chat_id: ContextVar[str] = ContextVar("cron_chat_id", default="")
self._metadata: ContextVar[dict] = ContextVar("cron_metadata", default={})
self._session_key: ContextVar[str] = ContextVar("cron_session_key", default="")
self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False) self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False)
def set_context(self, channel: str, chat_id: str) -> None: def set_context(
self, channel: str, chat_id: str,
metadata: dict | None = None, session_key: str | None = None,
) -> None:
"""Set the current session context for delivery.""" """Set the current session context for delivery."""
self._channel.set(channel) self._channel.set(channel)
self._chat_id.set(chat_id) self._chat_id.set(chat_id)
self._metadata.set(metadata or {})
self._session_key.set(session_key or f"{channel}:{chat_id}")
def set_cron_context(self, active: bool): def set_cron_context(self, active: bool):
"""Mark whether the tool is executing inside a cron job callback.""" """Mark whether the tool is executing inside a cron job callback."""
@ -199,6 +206,8 @@ class CronTool(Tool):
channel=channel, channel=channel,
to=chat_id, to=chat_id,
delete_after_run=delete_after, delete_after_run=delete_after,
channel_meta=self._metadata.get(),
session_key=self._session_key.get() or None,
) )
return f"Created job '{job.name}' (id: {job.id})" return f"Created job '{job.name}' (id: {job.id})"

View File

@ -41,17 +41,28 @@ class MessageTool(Tool):
"message_default_message_id", "message_default_message_id",
default=default_message_id, default=default_message_id,
) )
self._default_metadata: ContextVar[dict[str, Any]] = ContextVar(
"message_default_metadata",
default={},
)
self._sent_in_turn_var: ContextVar[bool] = ContextVar("message_sent_in_turn", default=False) self._sent_in_turn_var: ContextVar[bool] = ContextVar("message_sent_in_turn", default=False)
self._record_channel_delivery_var: ContextVar[bool] = ContextVar( self._record_channel_delivery_var: ContextVar[bool] = ContextVar(
"message_record_channel_delivery", "message_record_channel_delivery",
default=False, default=False,
) )
def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None: def set_context(
self,
channel: str,
chat_id: str,
message_id: str | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
"""Set the current message context.""" """Set the current message context."""
self._default_channel.set(channel) self._default_channel.set(channel)
self._default_chat_id.set(chat_id) self._default_chat_id.set(chat_id)
self._default_message_id.set(message_id) self._default_message_id.set(message_id)
self._default_metadata.set(metadata or {})
def set_send_callback(self, callback: Callable[[OutboundMessage], Awaitable[None]]) -> None: def set_send_callback(self, callback: Callable[[OutboundMessage], Awaitable[None]]) -> None:
"""Set the callback for sending messages.""" """Set the callback for sending messages."""
@ -118,7 +129,8 @@ class MessageTool(Tool):
# some channels (e.g. Feishu) use it to determine the target # some channels (e.g. Feishu) use it to determine the target
# conversation via their Reply API, which would route the message # conversation via their Reply API, which would route the message
# to the wrong chat entirely. # to the wrong chat entirely.
if channel == default_channel and chat_id == default_chat_id: same_target = channel == default_channel and chat_id == default_chat_id
if same_target:
message_id = message_id or self._default_message_id.get() message_id = message_id or self._default_message_id.get()
else: else:
message_id = None message_id = None
@ -129,9 +141,9 @@ class MessageTool(Tool):
if not self._send_callback: if not self._send_callback:
return "Error: Message sending not configured" return "Error: Message sending not configured"
metadata = { metadata = dict(self._default_metadata.get()) if same_target else {}
"message_id": message_id, if message_id:
} if message_id else {} metadata["message_id"] = message_id
if self._record_channel_delivery_var.get(): if self._record_channel_delivery_var.get():
metadata["_record_channel_delivery"] = True metadata["_record_channel_delivery"] = True

View File

@ -38,6 +38,8 @@ class SlackConfig(Base):
reply_in_thread: bool = True reply_in_thread: bool = True
react_emoji: str = "eyes" react_emoji: str = "eyes"
done_emoji: str = "white_check_mark" done_emoji: str = "white_check_mark"
include_thread_context: bool = True
thread_context_limit: int = 20
allow_from: list[str] = Field(default_factory=list) allow_from: list[str] = Field(default_factory=list)
group_policy: str = "mention" group_policy: str = "mention"
group_allow_from: list[str] = Field(default_factory=list) group_allow_from: list[str] = Field(default_factory=list)
@ -66,6 +68,7 @@ class SlackChannel(BaseChannel):
self._socket_client: SocketModeClient | None = None self._socket_client: SocketModeClient | None = None
self._bot_user_id: str | None = None self._bot_user_id: str | None = None
self._target_cache: dict[str, str] = {} self._target_cache: dict[str, str] = {}
self._thread_context_attempted: set[str] = set()
async def start(self) -> None: async def start(self) -> None:
"""Start the Slack Socket Mode client.""" """Start the Slack Socket Mode client."""
@ -327,9 +330,11 @@ class SlackChannel(BaseChannel):
text = self._strip_bot_mention(text) text = self._strip_bot_mention(text)
thread_ts = event.get("thread_ts") event_ts = event.get("ts")
raw_thread_ts = event.get("thread_ts")
thread_ts = raw_thread_ts
if self.config.reply_in_thread and not thread_ts: if self.config.reply_in_thread and not thread_ts:
thread_ts = event.get("ts") thread_ts = event_ts
# Add :eyes: reaction to the triggering message (best-effort) # Add :eyes: reaction to the triggering message (best-effort)
try: try:
if self._web_client and event.get("ts"): if self._web_client and event.get("ts"):
@ -343,12 +348,20 @@ class SlackChannel(BaseChannel):
# Thread-scoped session key for channel/group messages # Thread-scoped session key for channel/group messages
session_key = f"slack:{chat_id}:{thread_ts}" if thread_ts and channel_type != "im" else None session_key = f"slack:{chat_id}:{thread_ts}" if thread_ts and channel_type != "im" else None
content = await self._with_thread_context(
text,
chat_id=chat_id,
channel_type=channel_type,
thread_ts=thread_ts,
raw_thread_ts=raw_thread_ts,
current_ts=event_ts,
)
try: try:
await self._handle_message( await self._handle_message(
sender_id=sender_id, sender_id=sender_id,
chat_id=chat_id, chat_id=chat_id,
content=text, content=content,
metadata={ metadata={
"slack": { "slack": {
"event": event, "event": event,
@ -361,6 +374,66 @@ class SlackChannel(BaseChannel):
except Exception: except Exception:
logger.exception("Error handling Slack message from {}", sender_id) logger.exception("Error handling Slack message from {}", sender_id)
async def _with_thread_context(
self,
text: str,
*,
chat_id: str,
channel_type: str,
thread_ts: str | None,
raw_thread_ts: str | None,
current_ts: str | None,
) -> str:
"""Include thread history the first time the bot is pulled into a Slack thread."""
if (
not self.config.include_thread_context
or not self._web_client
or channel_type == "im"
or not raw_thread_ts
or not thread_ts
or current_ts == thread_ts
):
return text
key = f"{chat_id}:{thread_ts}"
if key in self._thread_context_attempted:
return text
self._thread_context_attempted.add(key)
try:
response = await self._web_client.conversations_replies(
channel=chat_id,
ts=thread_ts,
limit=max(1, self.config.thread_context_limit),
)
except Exception as e:
logger.warning("Slack thread context unavailable for {}: {}", key, e)
return text
lines = self._format_thread_context(
response.get("messages", []),
current_ts=current_ts,
)
if not lines:
return text
return "Slack thread context before this mention:\n" + "\n".join(lines) + f"\n\nCurrent message:\n{text}"
def _format_thread_context(self, messages: list[dict[str, Any]], *, current_ts: str | None) -> list[str]:
lines: list[str] = []
for item in messages:
if item.get("ts") == current_ts:
continue
if item.get("subtype"):
continue
sender = str(item.get("user") or item.get("bot_id") or "unknown")
if self._bot_user_id and sender == self._bot_user_id:
continue
text = str(item.get("text") or "").strip()
if not text:
continue
lines.append(f"- <@{sender}>: {self._strip_bot_mention(text)}")
return lines
async def _update_react_emoji(self, chat_id: str, ts: str | None) -> None: async def _update_react_emoji(self, chat_id: str, ts: str | None) -> None:
"""Remove the in-progress reaction and optionally add a done reaction.""" """Remove the in-progress reaction and optionally add a done reaction."""
if not self._web_client or not ts: if not self._web_client or not ts:

View File

@ -666,7 +666,9 @@ def _run_gateway(
else f"{channel}:{chat_id}" else f"{channel}:{chat_id}"
) )
async def _deliver_to_channel(msg: OutboundMessage, *, record: bool = False) -> None: async def _deliver_to_channel(
msg: OutboundMessage, *, record: bool = False, session_key: str | None = None,
) -> None:
"""Publish a user-visible message and mirror it into that channel's session.""" """Publish a user-visible message and mirror it into that channel's session."""
metadata = dict(msg.metadata or {}) metadata = dict(msg.metadata or {})
record = record or bool(metadata.pop("_record_channel_delivery", False)) record = record or bool(metadata.pop("_record_channel_delivery", False))
@ -687,7 +689,8 @@ def _run_gateway(
and hasattr(session_manager, "get_or_create") and hasattr(session_manager, "get_or_create")
and hasattr(session_manager, "save") and hasattr(session_manager, "save")
): ):
session = session_manager.get_or_create(_channel_session_key(msg.channel, msg.chat_id)) key = session_key or _channel_session_key(msg.channel, msg.chat_id)
session = session_manager.get_or_create(key)
session.add_message("assistant", msg.content, _channel_delivery=True) session.add_message("assistant", msg.content, _channel_delivery=True)
session_manager.save(session) session_manager.save(session)
await bus.publish_outbound(msg) await bus.publish_outbound(msg)
@ -757,8 +760,10 @@ def _run_gateway(
channel=job.payload.channel or "cli", channel=job.payload.channel or "cli",
chat_id=job.payload.to, chat_id=job.payload.to,
content=response, content=response,
metadata=dict(job.payload.channel_meta),
), ),
record=True, record=True,
session_key=job.payload.session_key,
) )
return response return response

View File

@ -379,6 +379,8 @@ class CronService:
channel: str | None = None, channel: str | None = None,
to: str | None = None, to: str | None = None,
delete_after_run: bool = False, delete_after_run: bool = False,
channel_meta: dict | None = None,
session_key: str | None = None,
) -> CronJob: ) -> CronJob:
"""Add a new job.""" """Add a new job."""
_validate_schedule_for_add(schedule) _validate_schedule_for_add(schedule)
@ -395,6 +397,8 @@ class CronService:
deliver=deliver, deliver=deliver,
channel=channel, channel=channel,
to=to, to=to,
channel_meta=channel_meta or {},
session_key=session_key,
), ),
state=CronJobState(next_run_at_ms=_compute_next_run(schedule, now)), state=CronJobState(next_run_at_ms=_compute_next_run(schedule, now)),
created_at_ms=now, created_at_ms=now,

View File

@ -27,6 +27,8 @@ class CronPayload:
deliver: bool = False deliver: bool = False
channel: str | None = None # e.g. "whatsapp" channel: str | None = None # e.g. "whatsapp"
to: str | None = None # e.g. phone number to: str | None = None # e.g. phone number
channel_meta: dict = field(default_factory=dict) # channel-specific routing (e.g. Slack thread_ts)
session_key: str | None = None # original session key for correct session recording
@dataclass @dataclass

View File

@ -0,0 +1,90 @@
from pathlib import Path
from unittest.mock import MagicMock
import pytest
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse, ToolCallRequest
class _ContextRecordingTool:
name = "cron"
concurrency_safe = False
def __init__(self) -> None:
self.contexts: list[dict] = []
def set_context(
self,
channel: str,
chat_id: str,
metadata: dict | None = None,
session_key: str | None = None,
) -> None:
self.contexts.append({
"channel": channel,
"chat_id": chat_id,
"metadata": metadata,
"session_key": session_key,
})
async def execute(self, **_kwargs) -> str:
return "created"
class _Tools:
def __init__(self, tool: _ContextRecordingTool) -> None:
self.tool = tool
def get(self, name: str):
return self.tool if name == "cron" else None
def get_definitions(self) -> list:
return []
def prepare_call(self, name: str, arguments: dict):
return (self.tool, arguments, None) if name == "cron" else (None, arguments, None)
@pytest.mark.asyncio
async def test_loop_hook_preserves_metadata_when_resetting_tool_context(tmp_path: Path) -> None:
provider = MagicMock()
calls = {"n": 0}
async def chat_with_retry(**_kwargs):
calls["n"] += 1
if calls["n"] == 1:
return LLMResponse(
content=None,
tool_calls=[ToolCallRequest(id="call_1", name="cron", arguments={"action": "add"})],
)
return LLMResponse(content="done", tool_calls=[])
provider.chat_with_retry = chat_with_retry
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=MessageBus(),
provider=provider,
workspace=tmp_path,
model="test-model",
)
cron = _ContextRecordingTool()
loop.tools = _Tools(cron)
metadata = {"slack": {"thread_ts": "111.222", "channel_type": "channel"}}
await loop._run_agent_loop(
[],
channel="slack",
chat_id="C123",
metadata=metadata,
session_key="slack:C123:111.222",
)
assert cron.contexts[-1] == {
"channel": "slack",
"chat_id": "C123",
"metadata": metadata,
"session_key": "slack:C123:111.222",
}

View File

@ -20,9 +20,11 @@ class _FakeAsyncWebClient:
self.reactions_add_calls: list[dict[str, object | None]] = [] self.reactions_add_calls: list[dict[str, object | None]] = []
self.reactions_remove_calls: list[dict[str, object | None]] = [] self.reactions_remove_calls: list[dict[str, object | None]] = []
self.conversations_list_calls: list[dict[str, object | None]] = [] self.conversations_list_calls: list[dict[str, object | None]] = []
self.conversations_replies_calls: list[dict[str, object | None]] = []
self.users_list_calls: list[dict[str, object | None]] = [] self.users_list_calls: list[dict[str, object | None]] = []
self.conversations_open_calls: list[dict[str, object | None]] = [] self.conversations_open_calls: list[dict[str, object | None]] = []
self._conversations_pages: list[dict[str, object]] = [] self._conversations_pages: list[dict[str, object]] = []
self._conversations_replies_response: dict[str, object] = {"messages": []}
self._users_pages: list[dict[str, object]] = [] self._users_pages: list[dict[str, object]] = []
self._open_dm_response: dict[str, object] = {"channel": {"id": "D_OPENED"}} self._open_dm_response: dict[str, object] = {"channel": {"id": "D_OPENED"}}
@ -92,6 +94,10 @@ class _FakeAsyncWebClient:
return self._conversations_pages.pop(0) return self._conversations_pages.pop(0)
return {"channels": [], "response_metadata": {"next_cursor": ""}} return {"channels": [], "response_metadata": {"next_cursor": ""}}
async def conversations_replies(self, **kwargs):
self.conversations_replies_calls.append(kwargs)
return self._conversations_replies_response
async def users_list(self, **kwargs): async def users_list(self, **kwargs):
self.users_list_calls.append(kwargs) self.users_list_calls.append(kwargs)
if self._users_pages: if self._users_pages:
@ -316,3 +322,47 @@ async def test_send_raises_when_named_target_cannot_be_resolved() -> None:
content="hello", content="hello",
) )
) )
@pytest.mark.asyncio
async def test_with_thread_context_fetches_root_once() -> None:
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
channel._bot_user_id = "UBOT"
fake_web = _FakeAsyncWebClient()
fake_web._conversations_replies_response = {
"messages": [
{"ts": "111.000", "user": "UROOT", "text": "drink water"},
{"ts": "112.000", "user": "U2", "text": "good idea"},
{"ts": "113.000", "user": "U3", "text": "<@UBOT> what did you see?"},
]
}
channel._web_client = fake_web
content = await channel._with_thread_context(
"what did you see?",
chat_id="C123",
channel_type="channel",
thread_ts="111.000",
raw_thread_ts="111.000",
current_ts="113.000",
)
assert fake_web.conversations_replies_calls == [
{"channel": "C123", "ts": "111.000", "limit": 20}
]
assert "Slack thread context before this mention:" in content
assert "- <@UROOT>: drink water" in content
assert "- <@U2>: good idea" in content
assert "U3" not in content
assert content.endswith("Current message:\nwhat did you see?")
second = await channel._with_thread_context(
"again",
chat_id="C123",
channel_type="channel",
thread_ts="111.000",
raw_thread_ts="111.000",
current_ts="114.000",
)
assert second == "again"
assert len(fake_web.conversations_replies_calls) == 1

View File

@ -43,6 +43,28 @@ def test_add_job_accepts_valid_timezone(tmp_path) -> None:
assert job.state.next_run_at_ms is not None assert job.state.next_run_at_ms is not None
def test_add_job_preserves_channel_meta_and_session_key(tmp_path) -> None:
service = CronService(tmp_path / "cron" / "jobs.json")
meta = {"slack": {"thread_ts": "1234567890.123456", "channel_type": "channel"}}
job = service.add_job(
name="thread test",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello",
deliver=True,
channel="slack",
to="C123",
channel_meta=meta,
session_key="slack:C123:1234567890.123456",
)
assert job.payload.channel_meta == meta
assert job.payload.session_key == "slack:C123:1234567890.123456"
reloaded = service.get_job(job.id)
assert reloaded is not None
assert reloaded.payload.channel_meta == meta
assert reloaded.payload.session_key == "slack:C123:1234567890.123456"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_execute_job_records_run_history(tmp_path) -> None: async def test_execute_job_records_run_history(tmp_path) -> None:
store_path = tmp_path / "cron" / "jobs.json" store_path = tmp_path / "cron" / "jobs.json"

View File

@ -382,6 +382,21 @@ def test_add_job_empty_message_returns_actionable_error(tmp_path) -> None:
assert "Retry including message=" in result assert "Retry including message=" in result
def test_add_job_captures_metadata_and_session_key(tmp_path) -> None:
"""CronTool stores channel metadata and session_key when adding a job."""
tool = _make_tool(tmp_path)
meta = {"slack": {"thread_ts": "111.222", "channel_type": "channel"}}
tool.set_context("slack", "C99", metadata=meta, session_key="slack:C99:111.222")
result = tool._add_job("test", "say hi", 60, None, None, None)
assert "Created job" in result
jobs = tool._cron.list_jobs()
assert len(jobs) == 1
assert jobs[0].payload.channel_meta == meta
assert jobs[0].payload.session_key == "slack:C99:111.222"
def test_list_excludes_disabled_jobs(tmp_path) -> None: def test_list_excludes_disabled_jobs(tmp_path) -> None:
tool = _make_tool(tmp_path) tool = _make_tool(tmp_path)
job = tool._cron.add_job( job = tool._cron.add_job(

View File

@ -50,3 +50,38 @@ async def test_message_tool_marks_channel_delivery_only_when_enabled() -> None:
assert sent[0].metadata == {} assert sent[0].metadata == {}
assert sent[1].metadata == {"_record_channel_delivery": True} assert sent[1].metadata == {"_record_channel_delivery": True}
@pytest.mark.asyncio
async def test_message_tool_inherits_metadata_for_same_target() -> None:
sent: list[OutboundMessage] = []
async def _send(msg: OutboundMessage) -> None:
sent.append(msg)
tool = MessageTool(send_callback=_send)
slack_meta = {"slack": {"thread_ts": "111.222", "channel_type": "channel"}}
tool.set_context("slack", "C123", metadata=slack_meta)
await tool.execute(content="thread reply")
assert sent[0].metadata == slack_meta
@pytest.mark.asyncio
async def test_message_tool_does_not_inherit_metadata_for_cross_target() -> None:
sent: list[OutboundMessage] = []
async def _send(msg: OutboundMessage) -> None:
sent.append(msg)
tool = MessageTool(send_callback=_send)
tool.set_context(
"slack",
"C123",
metadata={"slack": {"thread_ts": "111.222", "channel_type": "channel"}},
)
await tool.execute(content="channel reply", channel="slack", chat_id="C999")
assert sent[0].metadata == {}