mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-03 16:25:53 +00:00
Merge branch 'main' into fix/session-history-timestamps
Made-with: Cursor
This commit is contained in:
commit
4a4ba1efc1
@ -87,6 +87,11 @@ ruff check nanobot/
|
||||
ruff format nanobot/
|
||||
```
|
||||
|
||||
## Contribution License
|
||||
|
||||
By submitting a contribution, you confirm that you have the right to submit it
|
||||
and agree that it will be licensed under the project's MIT License.
|
||||
|
||||
## Code Style
|
||||
|
||||
We care about more than passing lint. We want nanobot to stay small, calm, and readable.
|
||||
|
||||
2
LICENSE
2
LICENSE
@ -1,6 +1,6 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 nanobot contributors
|
||||
Copyright (c) 2025-present Xubin Ren and the nanobot contributors
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
||||
@ -282,6 +282,10 @@ PRs welcome! The codebase is intentionally small and readable. 🤗
|
||||
- **More integrations** — Calendar and more
|
||||
- **Self-improvement** — Learn from feedback and mistakes
|
||||
|
||||
## Contact
|
||||
|
||||
This project was started by [Xubin Ren](https://github.com/re-bin) as a personal open-source project and continues to be maintained in an individual capacity using personal resources, with contributions from the open-source community. Feel free to contact [xubinrencs@gmail.com](mailto:xubinrencs@gmail.com) for questions, ideas, or collaboration.
|
||||
|
||||
### Contributors
|
||||
|
||||
<a href="https://github.com/HKUDS/nanobot/graphs/contributors">
|
||||
|
||||
@ -434,7 +434,7 @@ Uses **Socket Mode** — no public URL required.
|
||||
|
||||
**2. Configure the app**
|
||||
- **Socket Mode**: Toggle ON → Generate an **App-Level Token** with `connections:write` scope → copy it (`xapp-...`)
|
||||
- **OAuth & Permissions**: Add bot scopes: `chat:write`, `reactions:write`, `app_mentions:read`
|
||||
- **OAuth & Permissions**: Add bot scopes: `chat:write`, `reactions:write`, `app_mentions:read`, `channels:history`, `groups:history`, `im:history`, `mpim:history`
|
||||
- **Event Subscriptions**: Toggle ON → Subscribe to bot events: `message.im`, `message.channels`, `app_mention` → Save Changes
|
||||
- **App Home**: Scroll to **Show Tabs** → Enable **Messages Tab** → Check **"Allow users to send Slash commands and messages from the messages tab"**
|
||||
- **Install App**: Click **Install to Workspace** → Authorize → copy the **Bot Token** (`xoxb-...`)
|
||||
|
||||
@ -76,6 +76,8 @@ class _LoopHook(AgentHook):
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
message_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
session_key: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(reraise=True)
|
||||
self._loop = agent_loop
|
||||
@ -85,6 +87,8 @@ class _LoopHook(AgentHook):
|
||||
self._channel = channel
|
||||
self._chat_id = chat_id
|
||||
self._message_id = message_id
|
||||
self._metadata = metadata or {}
|
||||
self._session_key = session_key
|
||||
self._stream_buf = ""
|
||||
|
||||
def wants_streaming(self) -> bool:
|
||||
@ -127,7 +131,13 @@ class _LoopHook(AgentHook):
|
||||
for tc in context.tool_calls:
|
||||
args_str = json.dumps(tc.arguments, ensure_ascii=False)
|
||||
logger.info("Tool call: {}({})", tc.name, args_str[:200])
|
||||
self._loop._set_tool_context(self._channel, self._chat_id, self._message_id)
|
||||
self._loop._set_tool_context(
|
||||
self._channel,
|
||||
self._chat_id,
|
||||
self._message_id,
|
||||
self._metadata,
|
||||
session_key=self._session_key,
|
||||
)
|
||||
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
||||
if (
|
||||
@ -387,18 +397,24 @@ class AgentLoop:
|
||||
finally:
|
||||
self._mcp_connecting = False
|
||||
|
||||
def _set_tool_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
||||
def _set_tool_context(
|
||||
self, channel: str, chat_id: str,
|
||||
message_id: str | None = None, metadata: dict | None = None,
|
||||
session_key: str | None = None,
|
||||
) -> None:
|
||||
"""Update context for all tools that need routing info."""
|
||||
# Compute the effective session key (accounts for unified sessions)
|
||||
# so that subagent results route to the correct pending queue.
|
||||
effective_key = UNIFIED_SESSION_KEY if self._unified_session else f"{channel}:{chat_id}"
|
||||
for name in ("message", "spawn", "cron", "my"):
|
||||
if tool := self.tools.get(name):
|
||||
if hasattr(tool, "set_context"):
|
||||
if name == "spawn":
|
||||
tool.set_context(channel, chat_id, effective_key=effective_key)
|
||||
elif name == "cron":
|
||||
tool.set_context(channel, chat_id, metadata=metadata, session_key=session_key)
|
||||
elif name == "message":
|
||||
tool.set_context(channel, chat_id, message_id, metadata=metadata)
|
||||
else:
|
||||
tool.set_context(channel, chat_id, *([message_id] if name == "message" else []))
|
||||
tool.set_context(channel, chat_id)
|
||||
|
||||
@staticmethod
|
||||
def _strip_think(text: str | None) -> str | None:
|
||||
@ -464,6 +480,8 @@ class AgentLoop:
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
message_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
session_key: str | None = None,
|
||||
pending_queue: asyncio.Queue | None = None,
|
||||
) -> tuple[str | None, list[str], list[dict], str, bool]:
|
||||
"""Run the agent iteration loop.
|
||||
@ -483,6 +501,8 @@ class AgentLoop:
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
metadata=metadata,
|
||||
session_key=session_key,
|
||||
)
|
||||
hook: AgentHook = (
|
||||
CompositeHook([loop_hook] + self._extra_hooks) if self._extra_hooks else loop_hook
|
||||
@ -831,7 +851,10 @@ class AgentLoop:
|
||||
is_subagent = msg.sender_id == "subagent"
|
||||
if is_subagent and self._persist_subagent_followup(session, msg):
|
||||
self.sessions.save(session)
|
||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||
self._set_tool_context(
|
||||
channel, chat_id, msg.metadata.get("message_id"),
|
||||
msg.metadata, session_key=key,
|
||||
)
|
||||
history = session.get_history(max_messages=0, include_timestamps=True)
|
||||
current_role = "assistant" if is_subagent else "user"
|
||||
|
||||
@ -848,6 +871,8 @@ class AgentLoop:
|
||||
final_content, _, all_msgs, stop_reason, _ = await self._run_agent_loop(
|
||||
messages, session=session, channel=channel, chat_id=chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
metadata=msg.metadata,
|
||||
session_key=key,
|
||||
pending_queue=pending_queue,
|
||||
)
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
@ -896,7 +921,10 @@ class AgentLoop:
|
||||
session_summary=pending,
|
||||
)
|
||||
|
||||
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
||||
self._set_tool_context(
|
||||
msg.channel, msg.chat_id, msg.metadata.get("message_id"),
|
||||
msg.metadata, session_key=key,
|
||||
)
|
||||
if message_tool := self.tools.get("message"):
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_tool.start_turn()
|
||||
@ -978,6 +1006,8 @@ class AgentLoop:
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
metadata=msg.metadata,
|
||||
session_key=key,
|
||||
pending_queue=pending_queue,
|
||||
)
|
||||
|
||||
|
||||
@ -60,12 +60,19 @@ class CronTool(Tool):
|
||||
self._default_timezone = default_timezone
|
||||
self._channel: ContextVar[str] = ContextVar("cron_channel", default="")
|
||||
self._chat_id: ContextVar[str] = ContextVar("cron_chat_id", default="")
|
||||
self._metadata: ContextVar[dict] = ContextVar("cron_metadata", default={})
|
||||
self._session_key: ContextVar[str] = ContextVar("cron_session_key", default="")
|
||||
self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False)
|
||||
|
||||
def set_context(self, channel: str, chat_id: str) -> None:
|
||||
def set_context(
|
||||
self, channel: str, chat_id: str,
|
||||
metadata: dict | None = None, session_key: str | None = None,
|
||||
) -> None:
|
||||
"""Set the current session context for delivery."""
|
||||
self._channel.set(channel)
|
||||
self._chat_id.set(chat_id)
|
||||
self._metadata.set(metadata or {})
|
||||
self._session_key.set(session_key or f"{channel}:{chat_id}")
|
||||
|
||||
def set_cron_context(self, active: bool):
|
||||
"""Mark whether the tool is executing inside a cron job callback."""
|
||||
@ -199,6 +206,8 @@ class CronTool(Tool):
|
||||
channel=channel,
|
||||
to=chat_id,
|
||||
delete_after_run=delete_after,
|
||||
channel_meta=self._metadata.get(),
|
||||
session_key=self._session_key.get() or None,
|
||||
)
|
||||
return f"Created job '{job.name}' (id: {job.id})"
|
||||
|
||||
|
||||
@ -41,17 +41,28 @@ class MessageTool(Tool):
|
||||
"message_default_message_id",
|
||||
default=default_message_id,
|
||||
)
|
||||
self._default_metadata: ContextVar[dict[str, Any]] = ContextVar(
|
||||
"message_default_metadata",
|
||||
default={},
|
||||
)
|
||||
self._sent_in_turn_var: ContextVar[bool] = ContextVar("message_sent_in_turn", default=False)
|
||||
self._record_channel_delivery_var: ContextVar[bool] = ContextVar(
|
||||
"message_record_channel_delivery",
|
||||
default=False,
|
||||
)
|
||||
|
||||
def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
||||
def set_context(
|
||||
self,
|
||||
channel: str,
|
||||
chat_id: str,
|
||||
message_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Set the current message context."""
|
||||
self._default_channel.set(channel)
|
||||
self._default_chat_id.set(chat_id)
|
||||
self._default_message_id.set(message_id)
|
||||
self._default_metadata.set(metadata or {})
|
||||
|
||||
def set_send_callback(self, callback: Callable[[OutboundMessage], Awaitable[None]]) -> None:
|
||||
"""Set the callback for sending messages."""
|
||||
@ -118,7 +129,8 @@ class MessageTool(Tool):
|
||||
# some channels (e.g. Feishu) use it to determine the target
|
||||
# conversation via their Reply API, which would route the message
|
||||
# to the wrong chat entirely.
|
||||
if channel == default_channel and chat_id == default_chat_id:
|
||||
same_target = channel == default_channel and chat_id == default_chat_id
|
||||
if same_target:
|
||||
message_id = message_id or self._default_message_id.get()
|
||||
else:
|
||||
message_id = None
|
||||
@ -129,9 +141,9 @@ class MessageTool(Tool):
|
||||
if not self._send_callback:
|
||||
return "Error: Message sending not configured"
|
||||
|
||||
metadata = {
|
||||
"message_id": message_id,
|
||||
} if message_id else {}
|
||||
metadata = dict(self._default_metadata.get()) if same_target else {}
|
||||
if message_id:
|
||||
metadata["message_id"] = message_id
|
||||
if self._record_channel_delivery_var.get():
|
||||
metadata["_record_channel_delivery"] = True
|
||||
|
||||
|
||||
@ -38,6 +38,8 @@ class SlackConfig(Base):
|
||||
reply_in_thread: bool = True
|
||||
react_emoji: str = "eyes"
|
||||
done_emoji: str = "white_check_mark"
|
||||
include_thread_context: bool = True
|
||||
thread_context_limit: int = 20
|
||||
allow_from: list[str] = Field(default_factory=list)
|
||||
group_policy: str = "mention"
|
||||
group_allow_from: list[str] = Field(default_factory=list)
|
||||
@ -66,6 +68,7 @@ class SlackChannel(BaseChannel):
|
||||
self._socket_client: SocketModeClient | None = None
|
||||
self._bot_user_id: str | None = None
|
||||
self._target_cache: dict[str, str] = {}
|
||||
self._thread_context_attempted: set[str] = set()
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Slack Socket Mode client."""
|
||||
@ -327,9 +330,11 @@ class SlackChannel(BaseChannel):
|
||||
|
||||
text = self._strip_bot_mention(text)
|
||||
|
||||
thread_ts = event.get("thread_ts")
|
||||
event_ts = event.get("ts")
|
||||
raw_thread_ts = event.get("thread_ts")
|
||||
thread_ts = raw_thread_ts
|
||||
if self.config.reply_in_thread and not thread_ts:
|
||||
thread_ts = event.get("ts")
|
||||
thread_ts = event_ts
|
||||
# Add :eyes: reaction to the triggering message (best-effort)
|
||||
try:
|
||||
if self._web_client and event.get("ts"):
|
||||
@ -343,12 +348,20 @@ class SlackChannel(BaseChannel):
|
||||
|
||||
# Thread-scoped session key for channel/group messages
|
||||
session_key = f"slack:{chat_id}:{thread_ts}" if thread_ts and channel_type != "im" else None
|
||||
content = await self._with_thread_context(
|
||||
text,
|
||||
chat_id=chat_id,
|
||||
channel_type=channel_type,
|
||||
thread_ts=thread_ts,
|
||||
raw_thread_ts=raw_thread_ts,
|
||||
current_ts=event_ts,
|
||||
)
|
||||
|
||||
try:
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=chat_id,
|
||||
content=text,
|
||||
content=content,
|
||||
metadata={
|
||||
"slack": {
|
||||
"event": event,
|
||||
@ -361,6 +374,66 @@ class SlackChannel(BaseChannel):
|
||||
except Exception:
|
||||
logger.exception("Error handling Slack message from {}", sender_id)
|
||||
|
||||
async def _with_thread_context(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
chat_id: str,
|
||||
channel_type: str,
|
||||
thread_ts: str | None,
|
||||
raw_thread_ts: str | None,
|
||||
current_ts: str | None,
|
||||
) -> str:
|
||||
"""Include thread history the first time the bot is pulled into a Slack thread."""
|
||||
if (
|
||||
not self.config.include_thread_context
|
||||
or not self._web_client
|
||||
or channel_type == "im"
|
||||
or not raw_thread_ts
|
||||
or not thread_ts
|
||||
or current_ts == thread_ts
|
||||
):
|
||||
return text
|
||||
|
||||
key = f"{chat_id}:{thread_ts}"
|
||||
if key in self._thread_context_attempted:
|
||||
return text
|
||||
self._thread_context_attempted.add(key)
|
||||
|
||||
try:
|
||||
response = await self._web_client.conversations_replies(
|
||||
channel=chat_id,
|
||||
ts=thread_ts,
|
||||
limit=max(1, self.config.thread_context_limit),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Slack thread context unavailable for {}: {}", key, e)
|
||||
return text
|
||||
|
||||
lines = self._format_thread_context(
|
||||
response.get("messages", []),
|
||||
current_ts=current_ts,
|
||||
)
|
||||
if not lines:
|
||||
return text
|
||||
return "Slack thread context before this mention:\n" + "\n".join(lines) + f"\n\nCurrent message:\n{text}"
|
||||
|
||||
def _format_thread_context(self, messages: list[dict[str, Any]], *, current_ts: str | None) -> list[str]:
|
||||
lines: list[str] = []
|
||||
for item in messages:
|
||||
if item.get("ts") == current_ts:
|
||||
continue
|
||||
if item.get("subtype"):
|
||||
continue
|
||||
sender = str(item.get("user") or item.get("bot_id") or "unknown")
|
||||
if self._bot_user_id and sender == self._bot_user_id:
|
||||
continue
|
||||
text = str(item.get("text") or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
lines.append(f"- <@{sender}>: {self._strip_bot_mention(text)}")
|
||||
return lines
|
||||
|
||||
async def _update_react_emoji(self, chat_id: str, ts: str | None) -> None:
|
||||
"""Remove the in-progress reaction and optionally add a done reaction."""
|
||||
if not self._web_client or not ts:
|
||||
|
||||
@ -666,7 +666,9 @@ def _run_gateway(
|
||||
else f"{channel}:{chat_id}"
|
||||
)
|
||||
|
||||
async def _deliver_to_channel(msg: OutboundMessage, *, record: bool = False) -> None:
|
||||
async def _deliver_to_channel(
|
||||
msg: OutboundMessage, *, record: bool = False, session_key: str | None = None,
|
||||
) -> None:
|
||||
"""Publish a user-visible message and mirror it into that channel's session."""
|
||||
metadata = dict(msg.metadata or {})
|
||||
record = record or bool(metadata.pop("_record_channel_delivery", False))
|
||||
@ -687,7 +689,8 @@ def _run_gateway(
|
||||
and hasattr(session_manager, "get_or_create")
|
||||
and hasattr(session_manager, "save")
|
||||
):
|
||||
session = session_manager.get_or_create(_channel_session_key(msg.channel, msg.chat_id))
|
||||
key = session_key or _channel_session_key(msg.channel, msg.chat_id)
|
||||
session = session_manager.get_or_create(key)
|
||||
session.add_message("assistant", msg.content, _channel_delivery=True)
|
||||
session_manager.save(session)
|
||||
await bus.publish_outbound(msg)
|
||||
@ -757,8 +760,10 @@ def _run_gateway(
|
||||
channel=job.payload.channel or "cli",
|
||||
chat_id=job.payload.to,
|
||||
content=response,
|
||||
metadata=dict(job.payload.channel_meta),
|
||||
),
|
||||
record=True,
|
||||
session_key=job.payload.session_key,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@ -379,6 +379,8 @@ class CronService:
|
||||
channel: str | None = None,
|
||||
to: str | None = None,
|
||||
delete_after_run: bool = False,
|
||||
channel_meta: dict | None = None,
|
||||
session_key: str | None = None,
|
||||
) -> CronJob:
|
||||
"""Add a new job."""
|
||||
_validate_schedule_for_add(schedule)
|
||||
@ -395,6 +397,8 @@ class CronService:
|
||||
deliver=deliver,
|
||||
channel=channel,
|
||||
to=to,
|
||||
channel_meta=channel_meta or {},
|
||||
session_key=session_key,
|
||||
),
|
||||
state=CronJobState(next_run_at_ms=_compute_next_run(schedule, now)),
|
||||
created_at_ms=now,
|
||||
|
||||
@ -27,6 +27,8 @@ class CronPayload:
|
||||
deliver: bool = False
|
||||
channel: str | None = None # e.g. "whatsapp"
|
||||
to: str | None = None # e.g. phone number
|
||||
channel_meta: dict = field(default_factory=dict) # channel-specific routing (e.g. Slack thread_ts)
|
||||
session_key: str | None = None # original session key for correct session recording
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
90
tests/agent/test_loop_tool_context.py
Normal file
90
tests/agent/test_loop_tool_context.py
Normal file
@ -0,0 +1,90 @@
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
class _ContextRecordingTool:
|
||||
name = "cron"
|
||||
concurrency_safe = False
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.contexts: list[dict] = []
|
||||
|
||||
def set_context(
|
||||
self,
|
||||
channel: str,
|
||||
chat_id: str,
|
||||
metadata: dict | None = None,
|
||||
session_key: str | None = None,
|
||||
) -> None:
|
||||
self.contexts.append({
|
||||
"channel": channel,
|
||||
"chat_id": chat_id,
|
||||
"metadata": metadata,
|
||||
"session_key": session_key,
|
||||
})
|
||||
|
||||
async def execute(self, **_kwargs) -> str:
|
||||
return "created"
|
||||
|
||||
|
||||
class _Tools:
|
||||
def __init__(self, tool: _ContextRecordingTool) -> None:
|
||||
self.tool = tool
|
||||
|
||||
def get(self, name: str):
|
||||
return self.tool if name == "cron" else None
|
||||
|
||||
def get_definitions(self) -> list:
|
||||
return []
|
||||
|
||||
def prepare_call(self, name: str, arguments: dict):
|
||||
return (self.tool, arguments, None) if name == "cron" else (None, arguments, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loop_hook_preserves_metadata_when_resetting_tool_context(tmp_path: Path) -> None:
|
||||
provider = MagicMock()
|
||||
calls = {"n": 0}
|
||||
|
||||
async def chat_with_retry(**_kwargs):
|
||||
calls["n"] += 1
|
||||
if calls["n"] == 1:
|
||||
return LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="cron", arguments={"action": "add"})],
|
||||
)
|
||||
return LLMResponse(content="done", tool_calls=[])
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=MessageBus(),
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
)
|
||||
cron = _ContextRecordingTool()
|
||||
loop.tools = _Tools(cron)
|
||||
|
||||
metadata = {"slack": {"thread_ts": "111.222", "channel_type": "channel"}}
|
||||
await loop._run_agent_loop(
|
||||
[],
|
||||
channel="slack",
|
||||
chat_id="C123",
|
||||
metadata=metadata,
|
||||
session_key="slack:C123:111.222",
|
||||
)
|
||||
|
||||
assert cron.contexts[-1] == {
|
||||
"channel": "slack",
|
||||
"chat_id": "C123",
|
||||
"metadata": metadata,
|
||||
"session_key": "slack:C123:111.222",
|
||||
}
|
||||
@ -20,9 +20,11 @@ class _FakeAsyncWebClient:
|
||||
self.reactions_add_calls: list[dict[str, object | None]] = []
|
||||
self.reactions_remove_calls: list[dict[str, object | None]] = []
|
||||
self.conversations_list_calls: list[dict[str, object | None]] = []
|
||||
self.conversations_replies_calls: list[dict[str, object | None]] = []
|
||||
self.users_list_calls: list[dict[str, object | None]] = []
|
||||
self.conversations_open_calls: list[dict[str, object | None]] = []
|
||||
self._conversations_pages: list[dict[str, object]] = []
|
||||
self._conversations_replies_response: dict[str, object] = {"messages": []}
|
||||
self._users_pages: list[dict[str, object]] = []
|
||||
self._open_dm_response: dict[str, object] = {"channel": {"id": "D_OPENED"}}
|
||||
|
||||
@ -92,6 +94,10 @@ class _FakeAsyncWebClient:
|
||||
return self._conversations_pages.pop(0)
|
||||
return {"channels": [], "response_metadata": {"next_cursor": ""}}
|
||||
|
||||
async def conversations_replies(self, **kwargs):
|
||||
self.conversations_replies_calls.append(kwargs)
|
||||
return self._conversations_replies_response
|
||||
|
||||
async def users_list(self, **kwargs):
|
||||
self.users_list_calls.append(kwargs)
|
||||
if self._users_pages:
|
||||
@ -316,3 +322,47 @@ async def test_send_raises_when_named_target_cannot_be_resolved() -> None:
|
||||
content="hello",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_with_thread_context_fetches_root_once() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||
channel._bot_user_id = "UBOT"
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
fake_web._conversations_replies_response = {
|
||||
"messages": [
|
||||
{"ts": "111.000", "user": "UROOT", "text": "drink water"},
|
||||
{"ts": "112.000", "user": "U2", "text": "good idea"},
|
||||
{"ts": "113.000", "user": "U3", "text": "<@UBOT> what did you see?"},
|
||||
]
|
||||
}
|
||||
channel._web_client = fake_web
|
||||
|
||||
content = await channel._with_thread_context(
|
||||
"what did you see?",
|
||||
chat_id="C123",
|
||||
channel_type="channel",
|
||||
thread_ts="111.000",
|
||||
raw_thread_ts="111.000",
|
||||
current_ts="113.000",
|
||||
)
|
||||
|
||||
assert fake_web.conversations_replies_calls == [
|
||||
{"channel": "C123", "ts": "111.000", "limit": 20}
|
||||
]
|
||||
assert "Slack thread context before this mention:" in content
|
||||
assert "- <@UROOT>: drink water" in content
|
||||
assert "- <@U2>: good idea" in content
|
||||
assert "U3" not in content
|
||||
assert content.endswith("Current message:\nwhat did you see?")
|
||||
|
||||
second = await channel._with_thread_context(
|
||||
"again",
|
||||
chat_id="C123",
|
||||
channel_type="channel",
|
||||
thread_ts="111.000",
|
||||
raw_thread_ts="111.000",
|
||||
current_ts="114.000",
|
||||
)
|
||||
assert second == "again"
|
||||
assert len(fake_web.conversations_replies_calls) == 1
|
||||
|
||||
@ -43,6 +43,28 @@ def test_add_job_accepts_valid_timezone(tmp_path) -> None:
|
||||
assert job.state.next_run_at_ms is not None
|
||||
|
||||
|
||||
def test_add_job_preserves_channel_meta_and_session_key(tmp_path) -> None:
|
||||
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||
meta = {"slack": {"thread_ts": "1234567890.123456", "channel_type": "channel"}}
|
||||
job = service.add_job(
|
||||
name="thread test",
|
||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||
message="hello",
|
||||
deliver=True,
|
||||
channel="slack",
|
||||
to="C123",
|
||||
channel_meta=meta,
|
||||
session_key="slack:C123:1234567890.123456",
|
||||
)
|
||||
assert job.payload.channel_meta == meta
|
||||
assert job.payload.session_key == "slack:C123:1234567890.123456"
|
||||
|
||||
reloaded = service.get_job(job.id)
|
||||
assert reloaded is not None
|
||||
assert reloaded.payload.channel_meta == meta
|
||||
assert reloaded.payload.session_key == "slack:C123:1234567890.123456"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_job_records_run_history(tmp_path) -> None:
|
||||
store_path = tmp_path / "cron" / "jobs.json"
|
||||
|
||||
@ -382,6 +382,21 @@ def test_add_job_empty_message_returns_actionable_error(tmp_path) -> None:
|
||||
assert "Retry including message=" in result
|
||||
|
||||
|
||||
def test_add_job_captures_metadata_and_session_key(tmp_path) -> None:
|
||||
"""CronTool stores channel metadata and session_key when adding a job."""
|
||||
tool = _make_tool(tmp_path)
|
||||
meta = {"slack": {"thread_ts": "111.222", "channel_type": "channel"}}
|
||||
tool.set_context("slack", "C99", metadata=meta, session_key="slack:C99:111.222")
|
||||
|
||||
result = tool._add_job("test", "say hi", 60, None, None, None)
|
||||
assert "Created job" in result
|
||||
|
||||
jobs = tool._cron.list_jobs()
|
||||
assert len(jobs) == 1
|
||||
assert jobs[0].payload.channel_meta == meta
|
||||
assert jobs[0].payload.session_key == "slack:C99:111.222"
|
||||
|
||||
|
||||
def test_list_excludes_disabled_jobs(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
job = tool._cron.add_job(
|
||||
|
||||
@ -50,3 +50,38 @@ async def test_message_tool_marks_channel_delivery_only_when_enabled() -> None:
|
||||
|
||||
assert sent[0].metadata == {}
|
||||
assert sent[1].metadata == {"_record_channel_delivery": True}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_tool_inherits_metadata_for_same_target() -> None:
|
||||
sent: list[OutboundMessage] = []
|
||||
|
||||
async def _send(msg: OutboundMessage) -> None:
|
||||
sent.append(msg)
|
||||
|
||||
tool = MessageTool(send_callback=_send)
|
||||
slack_meta = {"slack": {"thread_ts": "111.222", "channel_type": "channel"}}
|
||||
tool.set_context("slack", "C123", metadata=slack_meta)
|
||||
|
||||
await tool.execute(content="thread reply")
|
||||
|
||||
assert sent[0].metadata == slack_meta
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_tool_does_not_inherit_metadata_for_cross_target() -> None:
|
||||
sent: list[OutboundMessage] = []
|
||||
|
||||
async def _send(msg: OutboundMessage) -> None:
|
||||
sent.append(msg)
|
||||
|
||||
tool = MessageTool(send_callback=_send)
|
||||
tool.set_context(
|
||||
"slack",
|
||||
"C123",
|
||||
metadata={"slack": {"thread_ts": "111.222", "channel_type": "channel"}},
|
||||
)
|
||||
|
||||
await tool.execute(content="channel reply", channel="slack", chat_id="C999")
|
||||
|
||||
assert sent[0].metadata == {}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user