mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-06 01:36:17 +00:00
agent: use ContextVar for tool routing context
This commit is contained in:
parent
82aa9efc02
commit
ff8c28d5a8
@ -58,14 +58,14 @@ class CronTool(Tool):
|
||||
def __init__(self, cron_service: CronService, default_timezone: str = "UTC"):
|
||||
self._cron = cron_service
|
||||
self._default_timezone = default_timezone
|
||||
self._channel = ""
|
||||
self._chat_id = ""
|
||||
self._channel: ContextVar[str] = ContextVar("cron_channel", default="")
|
||||
self._chat_id: ContextVar[str] = ContextVar("cron_chat_id", default="")
|
||||
self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False)
|
||||
|
||||
def set_context(self, channel: str, chat_id: str) -> None:
|
||||
"""Set the current session context for delivery."""
|
||||
self._channel = channel
|
||||
self._chat_id = chat_id
|
||||
self._channel.set(channel)
|
||||
self._chat_id.set(chat_id)
|
||||
|
||||
def set_cron_context(self, active: bool):
|
||||
"""Mark whether the tool is executing inside a cron job callback."""
|
||||
@ -155,7 +155,9 @@ class CronTool(Tool):
|
||||
"describing what to do when the job triggers "
|
||||
"(e.g. the reminder text). Retry including message=\"...\"."
|
||||
)
|
||||
if not self._channel or not self._chat_id:
|
||||
channel = self._channel.get()
|
||||
chat_id = self._chat_id.get()
|
||||
if not channel or not chat_id:
|
||||
return "Error: no session context (channel/chat_id)"
|
||||
if tz and not cron_expr:
|
||||
return "Error: tz can only be used with cron_expr"
|
||||
@ -194,8 +196,8 @@ class CronTool(Tool):
|
||||
schedule=schedule,
|
||||
message=message,
|
||||
deliver=deliver,
|
||||
channel=self._channel,
|
||||
to=self._chat_id,
|
||||
channel=channel,
|
||||
to=chat_id,
|
||||
delete_after_run=delete_after,
|
||||
)
|
||||
return f"Created job '{job.name}' (id: {job.id})"
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
"""Message tool for sending messages to users."""
|
||||
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||
@ -30,16 +31,19 @@ class MessageTool(Tool):
|
||||
default_message_id: str | None = None,
|
||||
):
|
||||
self._send_callback = send_callback
|
||||
self._default_channel = default_channel
|
||||
self._default_chat_id = default_chat_id
|
||||
self._default_message_id = default_message_id
|
||||
self._sent_in_turn: bool = False
|
||||
self._default_channel: ContextVar[str] = ContextVar("message_default_channel", default=default_channel)
|
||||
self._default_chat_id: ContextVar[str] = ContextVar("message_default_chat_id", default=default_chat_id)
|
||||
self._default_message_id: ContextVar[str | None] = ContextVar(
|
||||
"message_default_message_id",
|
||||
default=default_message_id,
|
||||
)
|
||||
self._sent_in_turn_var: ContextVar[bool] = ContextVar("message_sent_in_turn", default=False)
|
||||
|
||||
def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
||||
"""Set the current message context."""
|
||||
self._default_channel = channel
|
||||
self._default_chat_id = chat_id
|
||||
self._default_message_id = message_id
|
||||
self._default_channel.set(channel)
|
||||
self._default_chat_id.set(chat_id)
|
||||
self._default_message_id.set(message_id)
|
||||
|
||||
def set_send_callback(self, callback: Callable[[OutboundMessage], Awaitable[None]]) -> None:
|
||||
"""Set the callback for sending messages."""
|
||||
@ -49,6 +53,14 @@ class MessageTool(Tool):
|
||||
"""Reset per-turn send tracking."""
|
||||
self._sent_in_turn = False
|
||||
|
||||
@property
|
||||
def _sent_in_turn(self) -> bool:
|
||||
return self._sent_in_turn_var.get()
|
||||
|
||||
@_sent_in_turn.setter
|
||||
def _sent_in_turn(self, value: bool) -> None:
|
||||
self._sent_in_turn_var.set(value)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "message"
|
||||
@ -73,16 +85,19 @@ class MessageTool(Tool):
|
||||
) -> str:
|
||||
from nanobot.utils.helpers import strip_think
|
||||
content = strip_think(content)
|
||||
|
||||
channel = channel or self._default_channel
|
||||
chat_id = chat_id or self._default_chat_id
|
||||
|
||||
default_channel = self._default_channel.get()
|
||||
default_chat_id = self._default_chat_id.get()
|
||||
|
||||
channel = channel or default_channel
|
||||
chat_id = chat_id or default_chat_id
|
||||
# Only inherit default message_id when targeting the same channel+chat.
|
||||
# Cross-chat sends must not carry the original message_id, because
|
||||
# some channels (e.g. Feishu) use it to determine the target
|
||||
# conversation via their Reply API, which would route the message
|
||||
# to the wrong chat entirely.
|
||||
if channel == self._default_channel and chat_id == self._default_chat_id:
|
||||
message_id = message_id or self._default_message_id
|
||||
if channel == default_channel and chat_id == default_chat_id:
|
||||
message_id = message_id or self._default_message_id.get()
|
||||
else:
|
||||
message_id = None
|
||||
|
||||
@ -104,7 +119,7 @@ class MessageTool(Tool):
|
||||
|
||||
try:
|
||||
await self._send_callback(msg)
|
||||
if channel == self._default_channel and chat_id == self._default_chat_id:
|
||||
if channel == default_channel and chat_id == default_chat_id:
|
||||
self._sent_in_turn = True
|
||||
media_info = f" with {len(media)} attachments" if media else ""
|
||||
return f"Message sent to {channel}:{chat_id}{media_info}"
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
"""Spawn tool for creating background subagents."""
|
||||
|
||||
from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||
@ -21,15 +22,15 @@ class SpawnTool(Tool):
|
||||
|
||||
def __init__(self, manager: "SubagentManager"):
|
||||
self._manager = manager
|
||||
self._origin_channel = "cli"
|
||||
self._origin_chat_id = "direct"
|
||||
self._session_key = "cli:direct"
|
||||
self._origin_channel: ContextVar[str] = ContextVar("spawn_origin_channel", default="cli")
|
||||
self._origin_chat_id: ContextVar[str] = ContextVar("spawn_origin_chat_id", default="direct")
|
||||
self._session_key: ContextVar[str] = ContextVar("spawn_session_key", default="cli:direct")
|
||||
|
||||
def set_context(self, channel: str, chat_id: str, effective_key: str | None = None) -> None:
|
||||
"""Set the origin context for subagent announcements."""
|
||||
self._origin_channel = channel
|
||||
self._origin_chat_id = chat_id
|
||||
self._session_key = effective_key or f"{channel}:{chat_id}"
|
||||
self._origin_channel.set(channel)
|
||||
self._origin_chat_id.set(chat_id)
|
||||
self._session_key.set(effective_key or f"{channel}:{chat_id}")
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@ -50,7 +51,7 @@ class SpawnTool(Tool):
|
||||
return await self._manager.spawn(
|
||||
task=task,
|
||||
label=label,
|
||||
origin_channel=self._origin_channel,
|
||||
origin_chat_id=self._origin_chat_id,
|
||||
session_key=self._session_key,
|
||||
origin_channel=self._origin_channel.get(),
|
||||
origin_chat_id=self._origin_chat_id.get(),
|
||||
session_key=self._session_key.get(),
|
||||
)
|
||||
|
||||
103
tests/test_tool_contextvars.py
Normal file
103
tests/test_tool_contextvars.py
Normal file
@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.agent.tools.spawn import SpawnTool
|
||||
from nanobot.cron.service import CronService
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_tool_keeps_task_local_context() -> None:
|
||||
seen: list[tuple[str, str, str]] = []
|
||||
entered = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def send_callback(msg):
|
||||
seen.append((msg.channel, msg.chat_id, msg.content))
|
||||
return None
|
||||
|
||||
tool = MessageTool(send_callback=send_callback)
|
||||
|
||||
async def task_one() -> str:
|
||||
tool.set_context("feishu", "chat-a")
|
||||
entered.set()
|
||||
await release.wait()
|
||||
return await tool.execute(content="one")
|
||||
|
||||
async def task_two() -> str:
|
||||
await entered.wait()
|
||||
tool.set_context("email", "chat-b")
|
||||
release.set()
|
||||
return await tool.execute(content="two")
|
||||
|
||||
result_one, result_two = await asyncio.gather(task_one(), task_two())
|
||||
|
||||
assert result_one == "Message sent to feishu:chat-a"
|
||||
assert result_two == "Message sent to email:chat-b"
|
||||
assert ("feishu", "chat-a", "one") in seen
|
||||
assert ("email", "chat-b", "two") in seen
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spawn_tool_keeps_task_local_context() -> None:
|
||||
seen: list[tuple[str, str, str]] = []
|
||||
entered = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
|
||||
class _Manager:
|
||||
async def spawn(self, *, task: str, label: str | None, origin_channel: str, origin_chat_id: str, session_key: str) -> str:
|
||||
seen.append((origin_channel, origin_chat_id, session_key))
|
||||
return f"{origin_channel}:{origin_chat_id}:{task}"
|
||||
|
||||
tool = SpawnTool(_Manager())
|
||||
|
||||
async def task_one() -> str:
|
||||
tool.set_context("whatsapp", "chat-a")
|
||||
entered.set()
|
||||
await release.wait()
|
||||
return await tool.execute(task="one")
|
||||
|
||||
async def task_two() -> str:
|
||||
await entered.wait()
|
||||
tool.set_context("telegram", "chat-b")
|
||||
release.set()
|
||||
return await tool.execute(task="two")
|
||||
|
||||
result_one, result_two = await asyncio.gather(task_one(), task_two())
|
||||
|
||||
assert result_one == "whatsapp:chat-a:one"
|
||||
assert result_two == "telegram:chat-b:two"
|
||||
assert ("whatsapp", "chat-a", "whatsapp:chat-a") in seen
|
||||
assert ("telegram", "chat-b", "telegram:chat-b") in seen
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cron_tool_keeps_task_local_context(tmp_path) -> None:
|
||||
tool = CronTool(CronService(tmp_path / "jobs.json"))
|
||||
entered = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def task_one() -> str:
|
||||
tool.set_context("feishu", "chat-a")
|
||||
entered.set()
|
||||
await release.wait()
|
||||
return await tool.execute(action="add", message="first", every_seconds=60)
|
||||
|
||||
async def task_two() -> str:
|
||||
await entered.wait()
|
||||
tool.set_context("email", "chat-b")
|
||||
release.set()
|
||||
return await tool.execute(action="add", message="second", every_seconds=60)
|
||||
|
||||
result_one, result_two = await asyncio.gather(task_one(), task_two())
|
||||
|
||||
assert result_one.startswith("Created job")
|
||||
assert result_two.startswith("Created job")
|
||||
|
||||
jobs = tool._cron.list_jobs()
|
||||
assert {job.payload.channel for job in jobs} == {"feishu", "email"}
|
||||
assert {job.payload.to for job in jobs} == {"chat-a", "chat-b"}
|
||||
Loading…
x
Reference in New Issue
Block a user