agent: use ContextVar for tool routing context

This commit is contained in:
jr_blue_551 2026-03-18 19:10:57 +00:00 committed by Xubin Ren
parent 82aa9efc02
commit ff8c28d5a8
4 changed files with 150 additions and 29 deletions

View File

@ -58,14 +58,14 @@ class CronTool(Tool):
def __init__(self, cron_service: CronService, default_timezone: str = "UTC"):
self._cron = cron_service
self._default_timezone = default_timezone
self._channel = ""
self._chat_id = ""
self._channel: ContextVar[str] = ContextVar("cron_channel", default="")
self._chat_id: ContextVar[str] = ContextVar("cron_chat_id", default="")
self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False)
def set_context(self, channel: str, chat_id: str) -> None:
"""Set the current session context for delivery."""
self._channel = channel
self._chat_id = chat_id
self._channel.set(channel)
self._chat_id.set(chat_id)
def set_cron_context(self, active: bool):
"""Mark whether the tool is executing inside a cron job callback."""
@ -155,7 +155,9 @@ class CronTool(Tool):
"describing what to do when the job triggers "
"(e.g. the reminder text). Retry including message=\"...\"."
)
if not self._channel or not self._chat_id:
channel = self._channel.get()
chat_id = self._chat_id.get()
if not channel or not chat_id:
return "Error: no session context (channel/chat_id)"
if tz and not cron_expr:
return "Error: tz can only be used with cron_expr"
@ -194,8 +196,8 @@ class CronTool(Tool):
schedule=schedule,
message=message,
deliver=deliver,
channel=self._channel,
to=self._chat_id,
channel=channel,
to=chat_id,
delete_after_run=delete_after,
)
return f"Created job '{job.name}' (id: {job.id})"

View File

@ -1,5 +1,6 @@
"""Message tool for sending messages to users."""
from contextvars import ContextVar
from typing import Any, Awaitable, Callable
from nanobot.agent.tools.base import Tool, tool_parameters
@ -30,16 +31,19 @@ class MessageTool(Tool):
default_message_id: str | None = None,
):
self._send_callback = send_callback
self._default_channel = default_channel
self._default_chat_id = default_chat_id
self._default_message_id = default_message_id
self._sent_in_turn: bool = False
self._default_channel: ContextVar[str] = ContextVar("message_default_channel", default=default_channel)
self._default_chat_id: ContextVar[str] = ContextVar("message_default_chat_id", default=default_chat_id)
self._default_message_id: ContextVar[str | None] = ContextVar(
"message_default_message_id",
default=default_message_id,
)
self._sent_in_turn_var: ContextVar[bool] = ContextVar("message_sent_in_turn", default=False)
def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
"""Set the current message context."""
self._default_channel = channel
self._default_chat_id = chat_id
self._default_message_id = message_id
self._default_channel.set(channel)
self._default_chat_id.set(chat_id)
self._default_message_id.set(message_id)
def set_send_callback(self, callback: Callable[[OutboundMessage], Awaitable[None]]) -> None:
"""Set the callback for sending messages."""
@ -49,6 +53,14 @@ class MessageTool(Tool):
"""Reset per-turn send tracking."""
self._sent_in_turn = False
@property
def _sent_in_turn(self) -> bool:
return self._sent_in_turn_var.get()
@_sent_in_turn.setter
def _sent_in_turn(self, value: bool) -> None:
self._sent_in_turn_var.set(value)
@property
def name(self) -> str:
return "message"
@ -73,16 +85,19 @@ class MessageTool(Tool):
) -> str:
from nanobot.utils.helpers import strip_think
content = strip_think(content)
channel = channel or self._default_channel
chat_id = chat_id or self._default_chat_id
default_channel = self._default_channel.get()
default_chat_id = self._default_chat_id.get()
channel = channel or default_channel
chat_id = chat_id or default_chat_id
# Only inherit default message_id when targeting the same channel+chat.
# Cross-chat sends must not carry the original message_id, because
# some channels (e.g. Feishu) use it to determine the target
# conversation via their Reply API, which would route the message
# to the wrong chat entirely.
if channel == self._default_channel and chat_id == self._default_chat_id:
message_id = message_id or self._default_message_id
if channel == default_channel and chat_id == default_chat_id:
message_id = message_id or self._default_message_id.get()
else:
message_id = None
@ -104,7 +119,7 @@ class MessageTool(Tool):
try:
await self._send_callback(msg)
if channel == self._default_channel and chat_id == self._default_chat_id:
if channel == default_channel and chat_id == default_chat_id:
self._sent_in_turn = True
media_info = f" with {len(media)} attachments" if media else ""
return f"Message sent to {channel}:{chat_id}{media_info}"

View File

@ -1,5 +1,6 @@
"""Spawn tool for creating background subagents."""
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any
from nanobot.agent.tools.base import Tool, tool_parameters
@ -21,15 +22,15 @@ class SpawnTool(Tool):
def __init__(self, manager: "SubagentManager"):
self._manager = manager
self._origin_channel = "cli"
self._origin_chat_id = "direct"
self._session_key = "cli:direct"
self._origin_channel: ContextVar[str] = ContextVar("spawn_origin_channel", default="cli")
self._origin_chat_id: ContextVar[str] = ContextVar("spawn_origin_chat_id", default="direct")
self._session_key: ContextVar[str] = ContextVar("spawn_session_key", default="cli:direct")
def set_context(self, channel: str, chat_id: str, effective_key: str | None = None) -> None:
"""Set the origin context for subagent announcements."""
self._origin_channel = channel
self._origin_chat_id = chat_id
self._session_key = effective_key or f"{channel}:{chat_id}"
self._origin_channel.set(channel)
self._origin_chat_id.set(chat_id)
self._session_key.set(effective_key or f"{channel}:{chat_id}")
@property
def name(self) -> str:
@ -50,7 +51,7 @@ class SpawnTool(Tool):
return await self._manager.spawn(
task=task,
label=label,
origin_channel=self._origin_channel,
origin_chat_id=self._origin_chat_id,
session_key=self._session_key,
origin_channel=self._origin_channel.get(),
origin_chat_id=self._origin_chat_id.get(),
session_key=self._session_key.get(),
)

View File

@ -0,0 +1,103 @@
from __future__ import annotations
import asyncio
import pytest
from nanobot.agent.tools.cron import CronTool
from nanobot.agent.tools.message import MessageTool
from nanobot.agent.tools.spawn import SpawnTool
from nanobot.cron.service import CronService
@pytest.mark.asyncio
async def test_message_tool_keeps_task_local_context() -> None:
seen: list[tuple[str, str, str]] = []
entered = asyncio.Event()
release = asyncio.Event()
async def send_callback(msg):
seen.append((msg.channel, msg.chat_id, msg.content))
return None
tool = MessageTool(send_callback=send_callback)
async def task_one() -> str:
tool.set_context("feishu", "chat-a")
entered.set()
await release.wait()
return await tool.execute(content="one")
async def task_two() -> str:
await entered.wait()
tool.set_context("email", "chat-b")
release.set()
return await tool.execute(content="two")
result_one, result_two = await asyncio.gather(task_one(), task_two())
assert result_one == "Message sent to feishu:chat-a"
assert result_two == "Message sent to email:chat-b"
assert ("feishu", "chat-a", "one") in seen
assert ("email", "chat-b", "two") in seen
@pytest.mark.asyncio
async def test_spawn_tool_keeps_task_local_context() -> None:
seen: list[tuple[str, str, str]] = []
entered = asyncio.Event()
release = asyncio.Event()
class _Manager:
async def spawn(self, *, task: str, label: str | None, origin_channel: str, origin_chat_id: str, session_key: str) -> str:
seen.append((origin_channel, origin_chat_id, session_key))
return f"{origin_channel}:{origin_chat_id}:{task}"
tool = SpawnTool(_Manager())
async def task_one() -> str:
tool.set_context("whatsapp", "chat-a")
entered.set()
await release.wait()
return await tool.execute(task="one")
async def task_two() -> str:
await entered.wait()
tool.set_context("telegram", "chat-b")
release.set()
return await tool.execute(task="two")
result_one, result_two = await asyncio.gather(task_one(), task_two())
assert result_one == "whatsapp:chat-a:one"
assert result_two == "telegram:chat-b:two"
assert ("whatsapp", "chat-a", "whatsapp:chat-a") in seen
assert ("telegram", "chat-b", "telegram:chat-b") in seen
@pytest.mark.asyncio
async def test_cron_tool_keeps_task_local_context(tmp_path) -> None:
tool = CronTool(CronService(tmp_path / "jobs.json"))
entered = asyncio.Event()
release = asyncio.Event()
async def task_one() -> str:
tool.set_context("feishu", "chat-a")
entered.set()
await release.wait()
return await tool.execute(action="add", message="first", every_seconds=60)
async def task_two() -> str:
await entered.wait()
tool.set_context("email", "chat-b")
release.set()
return await tool.execute(action="add", message="second", every_seconds=60)
result_one, result_two = await asyncio.gather(task_one(), task_two())
assert result_one.startswith("Created job")
assert result_two.startswith("Created job")
jobs = tool._cron.list_jobs()
assert {job.payload.channel for job in jobs} == {"feishu", "email"}
assert {job.payload.to for job in jobs} == {"chat-a", "chat-b"}