mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-06 17:55:59 +00:00
Replace the asyncio.Semaphore queueing approach with a simple count check in SpawnTool.execute(). When the concurrency limit is reached, the tool returns an error string so the agent can perceive the reason and adjust its behavior instead of silently queueing. - Remove max_concurrent_subagents parameter threading through AgentLoop, commands.py, and nanobot.py - SubagentManager reads the limit directly from AgentDefaults - SpawnTool checks get_running_count() before calling spawn() - Simplify tests to verify rejection behavior
242 lines
7.3 KiB
Python
242 lines
7.3 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
|
|
import pytest
|
|
|
|
from nanobot.agent.tools.cron import CronTool
|
|
from nanobot.agent.tools.message import MessageTool
|
|
from nanobot.agent.tools.spawn import SpawnTool
|
|
from nanobot.cron.service import CronService
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_message_tool_keeps_task_local_context() -> None:
|
|
seen: list[tuple[str, str, str]] = []
|
|
entered = asyncio.Event()
|
|
release = asyncio.Event()
|
|
|
|
async def send_callback(msg):
|
|
seen.append((msg.channel, msg.chat_id, msg.content))
|
|
return None
|
|
|
|
tool = MessageTool(send_callback=send_callback)
|
|
|
|
async def task_one() -> str:
|
|
tool.set_context("feishu", "chat-a")
|
|
entered.set()
|
|
await release.wait()
|
|
return await tool.execute(content="one")
|
|
|
|
async def task_two() -> str:
|
|
await entered.wait()
|
|
tool.set_context("email", "chat-b")
|
|
release.set()
|
|
return await tool.execute(content="two")
|
|
|
|
result_one, result_two = await asyncio.gather(task_one(), task_two())
|
|
|
|
assert result_one == "Message sent to feishu:chat-a"
|
|
assert result_two == "Message sent to email:chat-b"
|
|
assert ("feishu", "chat-a", "one") in seen
|
|
assert ("email", "chat-b", "two") in seen
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_spawn_tool_keeps_task_local_context() -> None:
|
|
seen: list[tuple[str, str, str]] = []
|
|
entered = asyncio.Event()
|
|
release = asyncio.Event()
|
|
|
|
class _Manager:
|
|
max_concurrent_subagents = 1
|
|
|
|
def get_running_count(self) -> int:
|
|
return 0
|
|
|
|
async def spawn(
|
|
self,
|
|
*,
|
|
task: str,
|
|
label: str | None,
|
|
origin_channel: str,
|
|
origin_chat_id: str,
|
|
session_key: str,
|
|
origin_message_id: str | None = None,
|
|
) -> str:
|
|
seen.append((origin_channel, origin_chat_id, session_key))
|
|
return f"{origin_channel}:{origin_chat_id}:{task}"
|
|
|
|
tool = SpawnTool(_Manager())
|
|
|
|
async def task_one() -> str:
|
|
tool.set_context("whatsapp", "chat-a")
|
|
entered.set()
|
|
await release.wait()
|
|
return await tool.execute(task="one")
|
|
|
|
async def task_two() -> str:
|
|
await entered.wait()
|
|
tool.set_context("telegram", "chat-b")
|
|
release.set()
|
|
return await tool.execute(task="two")
|
|
|
|
result_one, result_two = await asyncio.gather(task_one(), task_two())
|
|
|
|
assert result_one == "whatsapp:chat-a:one"
|
|
assert result_two == "telegram:chat-b:two"
|
|
assert ("whatsapp", "chat-a", "whatsapp:chat-a") in seen
|
|
assert ("telegram", "chat-b", "telegram:chat-b") in seen
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_cron_tool_keeps_task_local_context(tmp_path) -> None:
|
|
tool = CronTool(CronService(tmp_path / "jobs.json"))
|
|
entered = asyncio.Event()
|
|
release = asyncio.Event()
|
|
|
|
async def task_one() -> str:
|
|
tool.set_context("feishu", "chat-a")
|
|
entered.set()
|
|
await release.wait()
|
|
return await tool.execute(action="add", message="first", every_seconds=60)
|
|
|
|
async def task_two() -> str:
|
|
await entered.wait()
|
|
tool.set_context("email", "chat-b")
|
|
release.set()
|
|
return await tool.execute(action="add", message="second", every_seconds=60)
|
|
|
|
result_one, result_two = await asyncio.gather(task_one(), task_two())
|
|
|
|
assert result_one.startswith("Created job")
|
|
assert result_two.startswith("Created job")
|
|
|
|
jobs = tool._cron.list_jobs()
|
|
assert {job.payload.channel for job in jobs} == {"feishu", "email"}
|
|
assert {job.payload.to for job in jobs} == {"chat-a", "chat-b"}
|
|
|
|
|
|
# --- Basic single-task regression tests ---
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_message_tool_basic_set_context_and_execute() -> None:
|
|
"""Single task: set_context then execute should route correctly."""
|
|
seen: list[tuple[str, str, str]] = []
|
|
|
|
async def send_callback(msg):
|
|
seen.append((msg.channel, msg.chat_id, msg.content))
|
|
|
|
tool = MessageTool(send_callback=send_callback)
|
|
tool.set_context("telegram", "chat-123", "msg-456")
|
|
|
|
result = await tool.execute(content="hello")
|
|
assert result == "Message sent to telegram:chat-123"
|
|
assert seen == [("telegram", "chat-123", "hello")]
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_message_tool_default_values_without_set_context() -> None:
|
|
"""Without set_context, constructor defaults should be used."""
|
|
seen: list[tuple[str, str, str]] = []
|
|
|
|
async def send_callback(msg):
|
|
seen.append((msg.channel, msg.chat_id, msg.content))
|
|
|
|
tool = MessageTool(
|
|
send_callback=send_callback,
|
|
default_channel="discord",
|
|
default_chat_id="general",
|
|
)
|
|
|
|
result = await tool.execute(content="hi")
|
|
assert result == "Message sent to discord:general"
|
|
assert seen == [("discord", "general", "hi")]
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_spawn_tool_basic_set_context_and_execute() -> None:
|
|
"""Single task: set_context then execute should pass correct origin."""
|
|
seen: list[tuple[str, str, str]] = []
|
|
|
|
class _Manager:
|
|
max_concurrent_subagents = 1
|
|
|
|
def get_running_count(self) -> int:
|
|
return 0
|
|
|
|
async def spawn(
|
|
self,
|
|
*,
|
|
task,
|
|
label,
|
|
origin_channel,
|
|
origin_chat_id,
|
|
session_key,
|
|
origin_message_id=None,
|
|
):
|
|
seen.append((origin_channel, origin_chat_id, session_key))
|
|
return f"ok: {task}"
|
|
|
|
tool = SpawnTool(_Manager())
|
|
tool.set_context("feishu", "chat-abc")
|
|
|
|
result = await tool.execute(task="do something")
|
|
assert result == "ok: do something"
|
|
assert seen == [("feishu", "chat-abc", "feishu:chat-abc")]
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_spawn_tool_default_values_without_set_context() -> None:
|
|
"""Without set_context, default cli:direct should be used."""
|
|
seen: list[tuple[str, str, str]] = []
|
|
|
|
class _Manager:
|
|
max_concurrent_subagents = 1
|
|
|
|
def get_running_count(self) -> int:
|
|
return 0
|
|
|
|
async def spawn(
|
|
self,
|
|
*,
|
|
task,
|
|
label,
|
|
origin_channel,
|
|
origin_chat_id,
|
|
session_key,
|
|
origin_message_id=None,
|
|
):
|
|
seen.append((origin_channel, origin_chat_id, session_key))
|
|
return "ok"
|
|
|
|
tool = SpawnTool(_Manager())
|
|
|
|
await tool.execute(task="test")
|
|
assert seen == [("cli", "direct", "cli:direct")]
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_cron_tool_basic_set_context_and_execute(tmp_path) -> None:
|
|
"""Single task: set_context then add job should use correct target."""
|
|
tool = CronTool(CronService(tmp_path / "jobs.json"))
|
|
tool.set_context("wechat", "user-789")
|
|
|
|
result = await tool.execute(action="add", message="standup", every_seconds=300)
|
|
assert result.startswith("Created job")
|
|
|
|
jobs = tool._cron.list_jobs()
|
|
assert len(jobs) == 1
|
|
assert jobs[0].payload.channel == "wechat"
|
|
assert jobs[0].payload.to == "user-789"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_cron_tool_no_context_returns_error(tmp_path) -> None:
|
|
"""Without set_context, add should fail with a clear error."""
|
|
tool = CronTool(CronService(tmp_path / "jobs.json"))
|
|
|
|
result = await tool.execute(action="add", message="test", every_seconds=60)
|
|
assert result == "Error: no session context (channel/chat_id)"
|