diff --git a/nanobot/agent/tools/cron.py b/nanobot/agent/tools/cron.py index 7124a2a74..127ac6c90 100644 --- a/nanobot/agent/tools/cron.py +++ b/nanobot/agent/tools/cron.py @@ -58,14 +58,14 @@ class CronTool(Tool): def __init__(self, cron_service: CronService, default_timezone: str = "UTC"): self._cron = cron_service self._default_timezone = default_timezone - self._channel = "" - self._chat_id = "" + self._channel: ContextVar[str] = ContextVar("cron_channel", default="") + self._chat_id: ContextVar[str] = ContextVar("cron_chat_id", default="") self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False) def set_context(self, channel: str, chat_id: str) -> None: """Set the current session context for delivery.""" - self._channel = channel - self._chat_id = chat_id + self._channel.set(channel) + self._chat_id.set(chat_id) def set_cron_context(self, active: bool): """Mark whether the tool is executing inside a cron job callback.""" @@ -155,7 +155,9 @@ class CronTool(Tool): "describing what to do when the job triggers " "(e.g. the reminder text). Retry including message=\"...\"." ) - if not self._channel or not self._chat_id: + channel = self._channel.get() + chat_id = self._chat_id.get() + if not channel or not chat_id: return "Error: no session context (channel/chat_id)" if tz and not cron_expr: return "Error: tz can only be used with cron_expr" @@ -194,8 +196,8 @@ class CronTool(Tool): schedule=schedule, message=message, deliver=deliver, - channel=self._channel, - to=self._chat_id, + channel=channel, + to=chat_id, delete_after_run=delete_after, ) return f"Created job '{job.name}' (id: {job.id})" diff --git a/nanobot/agent/tools/message.py b/nanobot/agent/tools/message.py index 524cadcf5..ee81effbd 100644 --- a/nanobot/agent/tools/message.py +++ b/nanobot/agent/tools/message.py @@ -1,5 +1,6 @@ """Message tool for sending messages to users.""" +from contextvars import ContextVar from typing import Any, Awaitable, Callable from nanobot.agent.tools.base import Tool, tool_parameters @@ -30,16 +31,19 @@ class MessageTool(Tool): default_message_id: str | None = None, ): self._send_callback = send_callback - self._default_channel = default_channel - self._default_chat_id = default_chat_id - self._default_message_id = default_message_id - self._sent_in_turn: bool = False + self._default_channel: ContextVar[str] = ContextVar("message_default_channel", default=default_channel) + self._default_chat_id: ContextVar[str] = ContextVar("message_default_chat_id", default=default_chat_id) + self._default_message_id: ContextVar[str | None] = ContextVar( + "message_default_message_id", + default=default_message_id, + ) + self._sent_in_turn_var: ContextVar[bool] = ContextVar("message_sent_in_turn", default=False) def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None: """Set the current message context.""" - self._default_channel = channel - self._default_chat_id = chat_id - self._default_message_id = message_id + self._default_channel.set(channel) + self._default_chat_id.set(chat_id) + self._default_message_id.set(message_id) def set_send_callback(self, callback: Callable[[OutboundMessage], Awaitable[None]]) -> None: """Set the callback for sending messages.""" @@ -49,6 +53,14 @@ class MessageTool(Tool): """Reset per-turn send tracking.""" self._sent_in_turn = False + @property + def _sent_in_turn(self) -> bool: + return self._sent_in_turn_var.get() + + @_sent_in_turn.setter + def _sent_in_turn(self, value: bool) -> None: + self._sent_in_turn_var.set(value) + @property def name(self) -> str: return "message" @@ -73,16 +85,19 @@ class MessageTool(Tool): ) -> str: from nanobot.utils.helpers import strip_think content = strip_think(content) - - channel = channel or self._default_channel - chat_id = chat_id or self._default_chat_id + + default_channel = self._default_channel.get() + default_chat_id = self._default_chat_id.get() + + channel = channel or default_channel + chat_id = chat_id or default_chat_id # Only inherit default message_id when targeting the same channel+chat. # Cross-chat sends must not carry the original message_id, because # some channels (e.g. Feishu) use it to determine the target # conversation via their Reply API, which would route the message # to the wrong chat entirely. - if channel == self._default_channel and chat_id == self._default_chat_id: - message_id = message_id or self._default_message_id + if channel == default_channel and chat_id == default_chat_id: + message_id = message_id or self._default_message_id.get() else: message_id = None @@ -104,7 +119,7 @@ class MessageTool(Tool): try: await self._send_callback(msg) - if channel == self._default_channel and chat_id == self._default_chat_id: + if channel == default_channel and chat_id == default_chat_id: self._sent_in_turn = True media_info = f" with {len(media)} attachments" if media else "" return f"Message sent to {channel}:{chat_id}{media_info}" diff --git a/nanobot/agent/tools/spawn.py b/nanobot/agent/tools/spawn.py index 8ffb438bf..beda058a8 100644 --- a/nanobot/agent/tools/spawn.py +++ b/nanobot/agent/tools/spawn.py @@ -1,5 +1,6 @@ """Spawn tool for creating background subagents.""" +from contextvars import ContextVar from typing import TYPE_CHECKING, Any from nanobot.agent.tools.base import Tool, tool_parameters @@ -21,15 +22,15 @@ class SpawnTool(Tool): def __init__(self, manager: "SubagentManager"): self._manager = manager - self._origin_channel = "cli" - self._origin_chat_id = "direct" - self._session_key = "cli:direct" + self._origin_channel: ContextVar[str] = ContextVar("spawn_origin_channel", default="cli") + self._origin_chat_id: ContextVar[str] = ContextVar("spawn_origin_chat_id", default="direct") + self._session_key: ContextVar[str] = ContextVar("spawn_session_key", default="cli:direct") def set_context(self, channel: str, chat_id: str, effective_key: str | None = None) -> None: """Set the origin context for subagent announcements.""" - self._origin_channel = channel - self._origin_chat_id = chat_id - self._session_key = effective_key or f"{channel}:{chat_id}" + self._origin_channel.set(channel) + self._origin_chat_id.set(chat_id) + self._session_key.set(effective_key or f"{channel}:{chat_id}") @property def name(self) -> str: @@ -50,7 +51,7 @@ class SpawnTool(Tool): return await self._manager.spawn( task=task, label=label, - origin_channel=self._origin_channel, - origin_chat_id=self._origin_chat_id, - session_key=self._session_key, + origin_channel=self._origin_channel.get(), + origin_chat_id=self._origin_chat_id.get(), + session_key=self._session_key.get(), ) diff --git a/tests/test_tool_contextvars.py b/tests/test_tool_contextvars.py new file mode 100644 index 000000000..3cfef9515 --- /dev/null +++ b/tests/test_tool_contextvars.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import asyncio + +import pytest + +from nanobot.agent.tools.cron import CronTool +from nanobot.agent.tools.message import MessageTool +from nanobot.agent.tools.spawn import SpawnTool +from nanobot.cron.service import CronService + + +@pytest.mark.asyncio +async def test_message_tool_keeps_task_local_context() -> None: + seen: list[tuple[str, str, str]] = [] + entered = asyncio.Event() + release = asyncio.Event() + + async def send_callback(msg): + seen.append((msg.channel, msg.chat_id, msg.content)) + return None + + tool = MessageTool(send_callback=send_callback) + + async def task_one() -> str: + tool.set_context("feishu", "chat-a") + entered.set() + await release.wait() + return await tool.execute(content="one") + + async def task_two() -> str: + await entered.wait() + tool.set_context("email", "chat-b") + release.set() + return await tool.execute(content="two") + + result_one, result_two = await asyncio.gather(task_one(), task_two()) + + assert result_one == "Message sent to feishu:chat-a" + assert result_two == "Message sent to email:chat-b" + assert ("feishu", "chat-a", "one") in seen + assert ("email", "chat-b", "two") in seen + + +@pytest.mark.asyncio +async def test_spawn_tool_keeps_task_local_context() -> None: + seen: list[tuple[str, str, str]] = [] + entered = asyncio.Event() + release = asyncio.Event() + + class _Manager: + async def spawn(self, *, task: str, label: str | None, origin_channel: str, origin_chat_id: str, session_key: str) -> str: + seen.append((origin_channel, origin_chat_id, session_key)) + return f"{origin_channel}:{origin_chat_id}:{task}" + + tool = SpawnTool(_Manager()) + + async def task_one() -> str: + tool.set_context("whatsapp", "chat-a") + entered.set() + await release.wait() + return await tool.execute(task="one") + + async def task_two() -> str: + await entered.wait() + tool.set_context("telegram", "chat-b") + release.set() + return await tool.execute(task="two") + + result_one, result_two = await asyncio.gather(task_one(), task_two()) + + assert result_one == "whatsapp:chat-a:one" + assert result_two == "telegram:chat-b:two" + assert ("whatsapp", "chat-a", "whatsapp:chat-a") in seen + assert ("telegram", "chat-b", "telegram:chat-b") in seen + + +@pytest.mark.asyncio +async def test_cron_tool_keeps_task_local_context(tmp_path) -> None: + tool = CronTool(CronService(tmp_path / "jobs.json")) + entered = asyncio.Event() + release = asyncio.Event() + + async def task_one() -> str: + tool.set_context("feishu", "chat-a") + entered.set() + await release.wait() + return await tool.execute(action="add", message="first", every_seconds=60) + + async def task_two() -> str: + await entered.wait() + tool.set_context("email", "chat-b") + release.set() + return await tool.execute(action="add", message="second", every_seconds=60) + + result_one, result_two = await asyncio.gather(task_one(), task_two()) + + assert result_one.startswith("Created job") + assert result_two.startswith("Created job") + + jobs = tool._cron.list_jobs() + assert {job.payload.channel for job in jobs} == {"feishu", "email"} + assert {job.payload.to for job in jobs} == {"chat-a", "chat-b"}