nanobot/tests/test_tool_contextvars.py
chengyongru c30e4d86f3 refactor(agent): simplify subagent concurrency with rejection over semaphore
Replace the asyncio.Semaphore queueing approach with a simple count
check in SpawnTool.execute(). When the concurrency limit is reached,
the tool returns an error string so the agent can perceive the reason
and adjust its behavior instead of silently queueing.

- Remove max_concurrent_subagents parameter threading through
  AgentLoop, commands.py, and nanobot.py
- SubagentManager reads the limit directly from AgentDefaults
- SpawnTool checks get_running_count() before calling spawn()
- Simplify tests to verify rejection behavior
2026-05-05 22:22:04 +08:00

242 lines
7.3 KiB
Python

from __future__ import annotations
import asyncio
import pytest
from nanobot.agent.tools.cron import CronTool
from nanobot.agent.tools.message import MessageTool
from nanobot.agent.tools.spawn import SpawnTool
from nanobot.cron.service import CronService
@pytest.mark.asyncio
async def test_message_tool_keeps_task_local_context() -> None:
    """Two concurrent tasks sharing one MessageTool must each keep their own target.

    task_one sets its context and then parks on ``release``. While it waits,
    task_two sets a different context on the same tool instance. If the
    context were shared across tasks rather than kept per task, task_one
    would send to task_two's target.
    """
    seen: list[tuple[str, str, str]] = []
    entered = asyncio.Event()
    release = asyncio.Event()

    async def send_callback(msg):
        seen.append((msg.channel, msg.chat_id, msg.content))
        return None

    tool = MessageTool(send_callback=send_callback)

    async def task_one() -> str:
        tool.set_context("feishu", "chat-a")
        entered.set()
        # Hold until task_two has called set_context in its own task.
        await release.wait()
        return await tool.execute(content="one")

    async def task_two() -> str:
        await entered.wait()
        tool.set_context("email", "chat-b")
        release.set()
        return await tool.execute(content="two")

    result_one, result_two = await asyncio.gather(task_one(), task_two())

    assert result_one == "Message sent to feishu:chat-a"
    assert result_two == "Message sent to email:chat-b"
    # Compare the full multiset of sends (the order between the two tasks is
    # not pinned) so duplicate or spurious sends are caught, not just missing ones.
    assert sorted(seen) == sorted(
        [("feishu", "chat-a", "one"), ("email", "chat-b", "two")]
    )
@pytest.mark.asyncio
async def test_spawn_tool_keeps_task_local_context() -> None:
    """Concurrent spawns through one SpawnTool must each carry their own origin.

    task_one sets its context and parks on ``release``. task_two then sets a
    different context on the same tool instance. Each spawn must still report
    the origin channel/chat_id and session key of the task that issued it.
    """
    seen: list[tuple[str, str, str]] = []
    entered = asyncio.Event()
    release = asyncio.Event()

    class _Manager:
        # Stand-in for the subagent manager that records every spawn() call.
        # get_running_count() always reports 0, so the concurrency check that
        # SpawnTool performs before spawn() should not reject either of the
        # two spawns here.
        max_concurrent_subagents = 1

        def get_running_count(self) -> int:
            return 0

        async def spawn(
            self,
            *,
            task: str,
            label: str | None,
            origin_channel: str,
            origin_chat_id: str,
            session_key: str,
            origin_message_id: str | None = None,
        ) -> str:
            seen.append((origin_channel, origin_chat_id, session_key))
            return f"{origin_channel}:{origin_chat_id}:{task}"

    tool = SpawnTool(_Manager())

    async def task_one() -> str:
        tool.set_context("whatsapp", "chat-a")
        entered.set()
        # Hold until task_two has called set_context in its own task.
        await release.wait()
        return await tool.execute(task="one")

    async def task_two() -> str:
        await entered.wait()
        tool.set_context("telegram", "chat-b")
        release.set()
        return await tool.execute(task="two")

    result_one, result_two = await asyncio.gather(task_one(), task_two())

    assert result_one == "whatsapp:chat-a:one"
    assert result_two == "telegram:chat-b:two"
    # Exact multiset of spawn calls (order between tasks is not pinned): this
    # also rejects duplicate or extra spawns, which membership checks would miss.
    assert sorted(seen) == sorted(
        [
            ("whatsapp", "chat-a", "whatsapp:chat-a"),
            ("telegram", "chat-b", "telegram:chat-b"),
        ]
    )
@pytest.mark.asyncio
async def test_cron_tool_keeps_task_local_context(tmp_path) -> None:
    """Jobs added from concurrent tasks must each target their own task's context.

    task_one sets its context and parks on ``release``. task_two then sets a
    different context on the same CronTool. Each created job's payload must
    carry the channel *and* recipient of the task that created it.
    """
    tool = CronTool(CronService(tmp_path / "jobs.json"))
    entered = asyncio.Event()
    release = asyncio.Event()

    async def task_one() -> str:
        tool.set_context("feishu", "chat-a")
        entered.set()
        # Hold until task_two has called set_context in its own task.
        await release.wait()
        return await tool.execute(action="add", message="first", every_seconds=60)

    async def task_two() -> str:
        await entered.wait()
        tool.set_context("email", "chat-b")
        release.set()
        return await tool.execute(action="add", message="second", every_seconds=60)

    result_one, result_two = await asyncio.gather(task_one(), task_two())

    assert result_one.startswith("Created job")
    assert result_two.startswith("Created job")

    jobs = tool._cron.list_jobs()
    assert len(jobs) == 2
    # Compare (channel, to) as pairs. Two independent per-field sets would
    # still pass if one task's channel were combined with the other task's
    # chat_id, which is exactly the cross-task leak this test guards against.
    assert {(job.payload.channel, job.payload.to) for job in jobs} == {
        ("feishu", "chat-a"),
        ("email", "chat-b"),
    }
# --- Basic single-task regression tests ---
@pytest.mark.asyncio
async def test_message_tool_basic_set_context_and_execute() -> None:
    """A single task that calls set_context has execute() routed to that target."""
    sent: list[tuple[str, str, str]] = []

    async def record(message):
        # Capture exactly what the tool hands to the transport layer.
        sent.append((message.channel, message.chat_id, message.content))

    message_tool = MessageTool(send_callback=record)
    message_tool.set_context("telegram", "chat-123", "msg-456")

    outcome = await message_tool.execute(content="hello")

    assert outcome == "Message sent to telegram:chat-123"
    assert sent == [("telegram", "chat-123", "hello")]
@pytest.mark.asyncio
async def test_message_tool_default_values_without_set_context() -> None:
    """With no set_context call, the constructor's default target is used."""
    sent: list[tuple[str, str, str]] = []

    async def record(message):
        sent.append((message.channel, message.chat_id, message.content))

    message_tool = MessageTool(
        send_callback=record,
        default_channel="discord",
        default_chat_id="general",
    )

    outcome = await message_tool.execute(content="hi")

    assert outcome == "Message sent to discord:general"
    assert sent == [("discord", "general", "hi")]
@pytest.mark.asyncio
async def test_spawn_tool_basic_set_context_and_execute() -> None:
    """A single task that calls set_context passes that origin to spawn()."""
    calls: list[tuple[str, str, str]] = []

    class _FakeManager:
        # Reports no running subagents and records each spawn() invocation.
        max_concurrent_subagents = 1

        def get_running_count(self) -> int:
            return 0

        async def spawn(
            self,
            *,
            task,
            label,
            origin_channel,
            origin_chat_id,
            session_key,
            origin_message_id=None,
        ):
            calls.append((origin_channel, origin_chat_id, session_key))
            return f"ok: {task}"

    spawn_tool = SpawnTool(_FakeManager())
    spawn_tool.set_context("feishu", "chat-abc")

    outcome = await spawn_tool.execute(task="do something")

    assert outcome == "ok: do something"
    assert calls == [("feishu", "chat-abc", "feishu:chat-abc")]
@pytest.mark.asyncio
async def test_spawn_tool_default_values_without_set_context() -> None:
    """With no set_context call, spawn() receives the cli:direct origin."""
    calls: list[tuple[str, str, str]] = []

    class _FakeManager:
        # Reports no running subagents and records each spawn() invocation.
        max_concurrent_subagents = 1

        def get_running_count(self) -> int:
            return 0

        async def spawn(
            self,
            *,
            task,
            label,
            origin_channel,
            origin_chat_id,
            session_key,
            origin_message_id=None,
        ):
            calls.append((origin_channel, origin_chat_id, session_key))
            return "ok"

    spawn_tool = SpawnTool(_FakeManager())
    await spawn_tool.execute(task="test")

    assert calls == [("cli", "direct", "cli:direct")]
@pytest.mark.asyncio
async def test_cron_tool_basic_set_context_and_execute(tmp_path) -> None:
    """A single task that calls set_context creates a job aimed at that target."""
    cron_tool = CronTool(CronService(tmp_path / "jobs.json"))
    cron_tool.set_context("wechat", "user-789")

    reply = await cron_tool.execute(action="add", message="standup", every_seconds=300)
    assert reply.startswith("Created job")

    stored = cron_tool._cron.list_jobs()
    assert len(stored) == 1
    payload = stored[0].payload
    assert payload.channel == "wechat"
    assert payload.to == "user-789"
@pytest.mark.asyncio
async def test_cron_tool_no_context_returns_error(tmp_path) -> None:
    """Adding a job before any set_context call yields an explicit error string."""
    cron_tool = CronTool(CronService(tmp_path / "jobs.json"))

    reply = await cron_tool.execute(action="add", message="test", every_seconds=60)

    expected = "Error: no session context (channel/chat_id)"
    assert reply == expected