agent: use ContextVar for tool routing context

This commit is contained in:
jr_blue_551 2026-03-18 19:10:57 +00:00 committed by Xubin Ren
parent 82aa9efc02
commit ff8c28d5a8
4 changed files with 150 additions and 29 deletions

View File

@ -58,14 +58,14 @@ class CronTool(Tool):
def __init__(self, cron_service: CronService, default_timezone: str = "UTC"): def __init__(self, cron_service: CronService, default_timezone: str = "UTC"):
self._cron = cron_service self._cron = cron_service
self._default_timezone = default_timezone self._default_timezone = default_timezone
self._channel = "" self._channel: ContextVar[str] = ContextVar("cron_channel", default="")
self._chat_id = "" self._chat_id: ContextVar[str] = ContextVar("cron_chat_id", default="")
self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False) self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False)
def set_context(self, channel: str, chat_id: str) -> None: def set_context(self, channel: str, chat_id: str) -> None:
"""Set the current session context for delivery.""" """Set the current session context for delivery."""
self._channel = channel self._channel.set(channel)
self._chat_id = chat_id self._chat_id.set(chat_id)
def set_cron_context(self, active: bool): def set_cron_context(self, active: bool):
"""Mark whether the tool is executing inside a cron job callback.""" """Mark whether the tool is executing inside a cron job callback."""
@ -155,7 +155,9 @@ class CronTool(Tool):
"describing what to do when the job triggers " "describing what to do when the job triggers "
"(e.g. the reminder text). Retry including message=\"...\"." "(e.g. the reminder text). Retry including message=\"...\"."
) )
if not self._channel or not self._chat_id: channel = self._channel.get()
chat_id = self._chat_id.get()
if not channel or not chat_id:
return "Error: no session context (channel/chat_id)" return "Error: no session context (channel/chat_id)"
if tz and not cron_expr: if tz and not cron_expr:
return "Error: tz can only be used with cron_expr" return "Error: tz can only be used with cron_expr"
@ -194,8 +196,8 @@ class CronTool(Tool):
schedule=schedule, schedule=schedule,
message=message, message=message,
deliver=deliver, deliver=deliver,
channel=self._channel, channel=channel,
to=self._chat_id, to=chat_id,
delete_after_run=delete_after, delete_after_run=delete_after,
) )
return f"Created job '{job.name}' (id: {job.id})" return f"Created job '{job.name}' (id: {job.id})"

View File

@ -1,5 +1,6 @@
"""Message tool for sending messages to users.""" """Message tool for sending messages to users."""
from contextvars import ContextVar
from typing import Any, Awaitable, Callable from typing import Any, Awaitable, Callable
from nanobot.agent.tools.base import Tool, tool_parameters from nanobot.agent.tools.base import Tool, tool_parameters
@ -30,16 +31,19 @@ class MessageTool(Tool):
default_message_id: str | None = None, default_message_id: str | None = None,
): ):
self._send_callback = send_callback self._send_callback = send_callback
self._default_channel = default_channel self._default_channel: ContextVar[str] = ContextVar("message_default_channel", default=default_channel)
self._default_chat_id = default_chat_id self._default_chat_id: ContextVar[str] = ContextVar("message_default_chat_id", default=default_chat_id)
self._default_message_id = default_message_id self._default_message_id: ContextVar[str | None] = ContextVar(
self._sent_in_turn: bool = False "message_default_message_id",
default=default_message_id,
)
self._sent_in_turn_var: ContextVar[bool] = ContextVar("message_sent_in_turn", default=False)
def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None: def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
"""Set the current message context.""" """Set the current message context."""
self._default_channel = channel self._default_channel.set(channel)
self._default_chat_id = chat_id self._default_chat_id.set(chat_id)
self._default_message_id = message_id self._default_message_id.set(message_id)
def set_send_callback(self, callback: Callable[[OutboundMessage], Awaitable[None]]) -> None: def set_send_callback(self, callback: Callable[[OutboundMessage], Awaitable[None]]) -> None:
"""Set the callback for sending messages.""" """Set the callback for sending messages."""
@ -49,6 +53,14 @@ class MessageTool(Tool):
"""Reset per-turn send tracking.""" """Reset per-turn send tracking."""
self._sent_in_turn = False self._sent_in_turn = False
@property
def _sent_in_turn(self) -> bool:
return self._sent_in_turn_var.get()
@_sent_in_turn.setter
def _sent_in_turn(self, value: bool) -> None:
self._sent_in_turn_var.set(value)
@property @property
def name(self) -> str: def name(self) -> str:
return "message" return "message"
@ -73,16 +85,19 @@ class MessageTool(Tool):
) -> str: ) -> str:
from nanobot.utils.helpers import strip_think from nanobot.utils.helpers import strip_think
content = strip_think(content) content = strip_think(content)
channel = channel or self._default_channel default_channel = self._default_channel.get()
chat_id = chat_id or self._default_chat_id default_chat_id = self._default_chat_id.get()
channel = channel or default_channel
chat_id = chat_id or default_chat_id
# Only inherit default message_id when targeting the same channel+chat. # Only inherit default message_id when targeting the same channel+chat.
# Cross-chat sends must not carry the original message_id, because # Cross-chat sends must not carry the original message_id, because
# some channels (e.g. Feishu) use it to determine the target # some channels (e.g. Feishu) use it to determine the target
# conversation via their Reply API, which would route the message # conversation via their Reply API, which would route the message
# to the wrong chat entirely. # to the wrong chat entirely.
if channel == self._default_channel and chat_id == self._default_chat_id: if channel == default_channel and chat_id == default_chat_id:
message_id = message_id or self._default_message_id message_id = message_id or self._default_message_id.get()
else: else:
message_id = None message_id = None
@ -104,7 +119,7 @@ class MessageTool(Tool):
try: try:
await self._send_callback(msg) await self._send_callback(msg)
if channel == self._default_channel and chat_id == self._default_chat_id: if channel == default_channel and chat_id == default_chat_id:
self._sent_in_turn = True self._sent_in_turn = True
media_info = f" with {len(media)} attachments" if media else "" media_info = f" with {len(media)} attachments" if media else ""
return f"Message sent to {channel}:{chat_id}{media_info}" return f"Message sent to {channel}:{chat_id}{media_info}"

View File

@ -1,5 +1,6 @@
"""Spawn tool for creating background subagents.""" """Spawn tool for creating background subagents."""
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from nanobot.agent.tools.base import Tool, tool_parameters from nanobot.agent.tools.base import Tool, tool_parameters
@ -21,15 +22,15 @@ class SpawnTool(Tool):
def __init__(self, manager: "SubagentManager"): def __init__(self, manager: "SubagentManager"):
self._manager = manager self._manager = manager
self._origin_channel = "cli" self._origin_channel: ContextVar[str] = ContextVar("spawn_origin_channel", default="cli")
self._origin_chat_id = "direct" self._origin_chat_id: ContextVar[str] = ContextVar("spawn_origin_chat_id", default="direct")
self._session_key = "cli:direct" self._session_key: ContextVar[str] = ContextVar("spawn_session_key", default="cli:direct")
def set_context(self, channel: str, chat_id: str, effective_key: str | None = None) -> None: def set_context(self, channel: str, chat_id: str, effective_key: str | None = None) -> None:
"""Set the origin context for subagent announcements.""" """Set the origin context for subagent announcements."""
self._origin_channel = channel self._origin_channel.set(channel)
self._origin_chat_id = chat_id self._origin_chat_id.set(chat_id)
self._session_key = effective_key or f"{channel}:{chat_id}" self._session_key.set(effective_key or f"{channel}:{chat_id}")
@property @property
def name(self) -> str: def name(self) -> str:
@ -50,7 +51,7 @@ class SpawnTool(Tool):
return await self._manager.spawn( return await self._manager.spawn(
task=task, task=task,
label=label, label=label,
origin_channel=self._origin_channel, origin_channel=self._origin_channel.get(),
origin_chat_id=self._origin_chat_id, origin_chat_id=self._origin_chat_id.get(),
session_key=self._session_key, session_key=self._session_key.get(),
) )

View File

@ -0,0 +1,103 @@
from __future__ import annotations
import asyncio
import pytest
from nanobot.agent.tools.cron import CronTool
from nanobot.agent.tools.message import MessageTool
from nanobot.agent.tools.spawn import SpawnTool
from nanobot.cron.service import CronService
@pytest.mark.asyncio
async def test_message_tool_keeps_task_local_context() -> None:
seen: list[tuple[str, str, str]] = []
entered = asyncio.Event()
release = asyncio.Event()
async def send_callback(msg):
seen.append((msg.channel, msg.chat_id, msg.content))
return None
tool = MessageTool(send_callback=send_callback)
async def task_one() -> str:
tool.set_context("feishu", "chat-a")
entered.set()
await release.wait()
return await tool.execute(content="one")
async def task_two() -> str:
await entered.wait()
tool.set_context("email", "chat-b")
release.set()
return await tool.execute(content="two")
result_one, result_two = await asyncio.gather(task_one(), task_two())
assert result_one == "Message sent to feishu:chat-a"
assert result_two == "Message sent to email:chat-b"
assert ("feishu", "chat-a", "one") in seen
assert ("email", "chat-b", "two") in seen
@pytest.mark.asyncio
async def test_spawn_tool_keeps_task_local_context() -> None:
seen: list[tuple[str, str, str]] = []
entered = asyncio.Event()
release = asyncio.Event()
class _Manager:
async def spawn(self, *, task: str, label: str | None, origin_channel: str, origin_chat_id: str, session_key: str) -> str:
seen.append((origin_channel, origin_chat_id, session_key))
return f"{origin_channel}:{origin_chat_id}:{task}"
tool = SpawnTool(_Manager())
async def task_one() -> str:
tool.set_context("whatsapp", "chat-a")
entered.set()
await release.wait()
return await tool.execute(task="one")
async def task_two() -> str:
await entered.wait()
tool.set_context("telegram", "chat-b")
release.set()
return await tool.execute(task="two")
result_one, result_two = await asyncio.gather(task_one(), task_two())
assert result_one == "whatsapp:chat-a:one"
assert result_two == "telegram:chat-b:two"
assert ("whatsapp", "chat-a", "whatsapp:chat-a") in seen
assert ("telegram", "chat-b", "telegram:chat-b") in seen
@pytest.mark.asyncio
async def test_cron_tool_keeps_task_local_context(tmp_path) -> None:
tool = CronTool(CronService(tmp_path / "jobs.json"))
entered = asyncio.Event()
release = asyncio.Event()
async def task_one() -> str:
tool.set_context("feishu", "chat-a")
entered.set()
await release.wait()
return await tool.execute(action="add", message="first", every_seconds=60)
async def task_two() -> str:
await entered.wait()
tool.set_context("email", "chat-b")
release.set()
return await tool.execute(action="add", message="second", every_seconds=60)
result_one, result_two = await asyncio.gather(task_one(), task_two())
assert result_one.startswith("Created job")
assert result_two.startswith("Created job")
jobs = tool._cron.list_jobs()
assert {job.payload.channel for job in jobs} == {"feishu", "email"}
assert {job.payload.to for job in jobs} == {"chat-a", "chat-b"}