mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 22:34:06 +00:00
351 lines
11 KiB
Python
351 lines
11 KiB
Python
"""Tests for MCP connection lifecycle in AgentLoop."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from contextlib import AsyncExitStack
|
|
from types import SimpleNamespace
|
|
from typing import Any
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
from mcp import types as mcp_types
|
|
from mcp.shared.exceptions import McpError
|
|
from mcp.types import ErrorData
|
|
|
|
from nanobot.agent.loop import AgentLoop
|
|
from nanobot.agent.tools import mcp as mcp_runtime
|
|
from nanobot.agent.tools.base import Tool
|
|
from nanobot.agent.tools.mcp import MCPResourceWrapper, MCPToolWrapper
|
|
from nanobot.bus.queue import MessageBus
|
|
from nanobot.config.loader import load_config, save_config
|
|
from nanobot.config.schema import MCPServerConfig
|
|
|
|
|
|
class _FakeMcpTool(Tool):
    """Minimal stand-in for an MCP-backed tool.

    Exposes the name given at construction, a constant description, an empty
    object parameter schema, and always returns ``"ok"`` when executed.
    """

    def __init__(self, name: str) -> None:
        # Kept private and surfaced through the read-only ``name`` property.
        self._name = name

    @property
    def name(self) -> str:
        """Registry name supplied at construction."""
        return self._name

    @property
    def description(self) -> str:
        """Fixed human-readable description."""
        return "fake MCP tool"

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON-schema object with no properties (the tool takes no arguments)."""
        schema: dict[str, Any] = {"type": "object", "properties": {}}
        return schema

    async def execute(self, **_kwargs: Any) -> str:
        """Ignore every keyword argument and report success."""
        return "ok"
|
|
|
|
|
|
def _make_loop(tmp_path, *, mcp_servers: dict | None = None) -> AgentLoop:
    """Build an ``AgentLoop`` wired to a fresh bus and a mocked provider.

    Args:
        tmp_path: Directory used as the loop's workspace.
        mcp_servers: MCP server configs handed to the loop. ``None`` (the
            default) substitutes a single placeholder ``"test"`` server. An
            explicit empty dict is passed through unchanged, so callers can
            build a loop with no servers configured.

    Returns:
        The constructed ``AgentLoop``.
    """
    bus = MessageBus()
    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    provider.generation.max_tokens = 4096
    # Compare against None rather than relying on truthiness. ``{}`` is falsy,
    # so ``mcp_servers or {...}`` would silently replace an explicit empty
    # mapping with the placeholder server.
    if mcp_servers is None:
        mcp_servers = {"test": object()}
    return AgentLoop(
        bus=bus,
        provider=provider,
        workspace=tmp_path,
        model="test-model",
        mcp_servers=mcp_servers,
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_connect_mcp_retries_when_no_servers_connect(tmp_path, monkeypatch: pytest.MonkeyPatch):
    """A connect pass that yields no server stacks leaves the loop retryable."""
    loop = _make_loop(tmp_path)
    calls: list[tuple[Any, Any]] = []

    async def _connect_nothing(servers, registry):
        # Record the attempt and report that no server came up.
        calls.append((servers, registry))
        return {}

    monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _connect_nothing)

    for _ in range(2):
        await loop._connect_mcp()

    assert len(calls) == 2
    assert loop._mcp_connected is False
    assert loop._mcp_stacks == {}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_reload_mcp_servers_adds_and_removes_tools_without_restart(
    tmp_path,
    monkeypatch: pytest.MonkeyPatch,
):
    """Reloading picks up added and deleted config entries on a live loop."""
    monkeypatch.setattr("nanobot.config.loader._current_config_path", tmp_path / "config.json")

    cfg = load_config()
    cfg.tools.mcp_servers["browserbase"] = MCPServerConfig(type="stdio", command="browserbase-mcp")
    save_config(cfg)

    closed: list[str] = []

    async def _record_close(server_name: str) -> None:
        closed.append(server_name)

    async def _connect_with_fake_tools(servers, registry):
        # Register one fake tool and open one exit stack per server. Closing
        # a stack records that server's name in ``closed``.
        live: dict[str, AsyncExitStack] = {}
        for server_name in servers:
            registry.register(_FakeMcpTool(f"mcp_{server_name}_navigate"))
            exit_stack = AsyncExitStack()
            await exit_stack.__aenter__()
            exit_stack.push_async_callback(_record_close, server_name)
            live[server_name] = exit_stack
        return live

    monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _connect_with_fake_tools)
    loop = _make_loop(tmp_path, mcp_servers={})

    # Phase 1: the newly configured server is connected and its tool registered.
    added = await mcp_runtime.reload_servers(loop, loop.tools)
    assert added["ok"] is True
    assert added["added"] == ["browserbase"]
    assert loop.tools.has("mcp_browserbase_navigate")
    assert "browserbase" in loop._mcp_stacks

    # Phase 2: deleting the entry unregisters the tool and closes its stack.
    cfg = load_config()
    del cfg.tools.mcp_servers["browserbase"]
    save_config(cfg)

    removed = await mcp_runtime.reload_servers(loop, loop.tools)
    assert removed["ok"] is True
    assert removed["removed"] == ["browserbase"]
    assert not loop.tools.has("mcp_browserbase_navigate")
    assert "browserbase" not in loop._mcp_stacks
    assert closed == ["browserbase"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_request_mcp_reload_reaches_runtime_control_without_restart(
    tmp_path,
    monkeypatch: pytest.MonkeyPatch,
):
    """A reload requested over the bus is applied by the runtime-control handler."""
    monkeypatch.setattr("nanobot.config.loader._current_config_path", tmp_path / "config.json")

    cfg = load_config()
    cfg.tools.mcp_servers["browserbase"] = MCPServerConfig(type="stdio", command="browserbase-mcp")
    save_config(cfg)

    closed: list[str] = []

    async def _record_close(server_name: str) -> None:
        closed.append(server_name)

    async def _connect_with_fake_tools(servers, registry):
        # Register one fake tool and open one exit stack per server. Closing
        # a stack records that server's name in ``closed``.
        live: dict[str, AsyncExitStack] = {}
        for server_name in servers:
            registry.register(_FakeMcpTool(f"mcp_{server_name}_navigate"))
            exit_stack = AsyncExitStack()
            await exit_stack.__aenter__()
            exit_stack.push_async_callback(_record_close, server_name)
            live[server_name] = exit_stack
        return live

    monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _connect_with_fake_tools)
    loop = _make_loop(tmp_path, mcp_servers={})

    async def _handle_one_runtime_control() -> None:
        # Pull exactly one inbound message and require the MCP runtime to claim it.
        inbound = await loop.bus.consume_inbound()
        claimed = await mcp_runtime.handle_runtime_control(loop, inbound, loop.tools)
        assert claimed is True

    async def _reload_over_bus() -> Any:
        # Start the consumer task before publishing the request, then wait
        # for both the reply and the consumer to finish.
        consumer = asyncio.create_task(_handle_one_runtime_control())
        outcome = await mcp_runtime.request_mcp_reload(loop.bus, timeout=2.0)
        await consumer
        return outcome

    # Round 1: a server present in config is added live.
    result = await _reload_over_bus()
    assert result["ok"] is True
    assert result["added"] == ["browserbase"]
    assert result["requires_restart"] is False
    assert loop.tools.has("mcp_browserbase_navigate")

    # Round 2: removing it from config tears it down live.
    cfg = load_config()
    del cfg.tools.mcp_servers["browserbase"]
    save_config(cfg)

    result = await _reload_over_bus()
    assert result["ok"] is True
    assert result["removed"] == ["browserbase"]
    assert result["requires_restart"] is False
    assert not loop.tools.has("mcp_browserbase_navigate")
    assert closed == ["browserbase"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_reload_mcp_servers_retries_configured_server_without_live_stack(
    tmp_path,
    monkeypatch: pytest.MonkeyPatch,
):
    """A configured server with no live stack is reported as retried, not added."""
    monkeypatch.setattr("nanobot.config.loader._current_config_path", tmp_path / "config.json")

    cfg = load_config()
    cfg.tools.mcp_servers["browserbase"] = MCPServerConfig(type="stdio", command="browserbase-mcp")
    save_config(cfg)

    async def _connect_with_fake_tools(servers, registry):
        # One fake tool and one opened exit stack per server. There is no
        # close callback because this test does not inspect teardown.
        live: dict[str, AsyncExitStack] = {}
        for server_name in servers:
            registry.register(_FakeMcpTool(f"mcp_{server_name}_navigate"))
            exit_stack = AsyncExitStack()
            await exit_stack.__aenter__()
            live[server_name] = exit_stack
        return live

    monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _connect_with_fake_tools)
    # The loop is built with the same server config that is on disk, but it
    # has not been connected, so nothing is live for that server yet.
    known_servers = {"browserbase": cfg.tools.mcp_servers["browserbase"]}
    loop = _make_loop(tmp_path, mcp_servers=known_servers)

    result = await mcp_runtime.reload_servers(loop, loop.tools)

    assert result["ok"] is True
    assert result["added"] == []
    assert result["changed"] == []
    assert result["retried"] == ["browserbase"]
    assert loop.tools.has("mcp_browserbase_navigate")

    await loop.close_mcp()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_mcp_tool_reconnects_after_session_terminated(
    tmp_path,
    monkeypatch: pytest.MonkeyPatch,
):
    """A tool call that hits "Session terminated" reconnects and retries once.

    The first connection's session raises ``McpError`` and later connections
    return a text result. Executing the stale wrapper must close the old
    stack, register a fresh wrapper, and return the retried call's output.
    """
    loop = _make_loop(tmp_path, mcp_servers={"remote": object()})
    closed: list[str] = []
    sessions: list[Any] = []
    connect_count = 0

    async def _mark_closed(name: str) -> None:
        closed.append(name)

    class _FakeSession:
        """Session whose first incarnation is dead and later ones succeed."""

        def __init__(self, index: int) -> None:
            self.index = index  # 1-based connection ordinal
            self.call_count = 0

        async def call_tool(self, _name: str, arguments: dict[str, Any]) -> Any:
            self.call_count += 1
            assert arguments == {"symbol": "AAPL"}
            if self.index == 1:
                raise McpError(ErrorData(code=-32000, message="Session terminated"))
            return SimpleNamespace(
                content=[mcp_types.TextContent(type="text", text="recovered")]
            )

    async def _fake_connect(servers, registry):
        nonlocal connect_count
        stacks = {}
        for name in servers:
            connect_count += 1
            session = _FakeSession(connect_count)
            sessions.append(session)
            tool_def = SimpleNamespace(
                name="quote",
                description="quote tool",
                inputSchema={"type": "object", "properties": {}},
            )
            registry.register(MCPToolWrapper(session, name, tool_def, tool_timeout=5))
            stack = AsyncExitStack()
            await stack.__aenter__()
            stack.push_async_callback(_mark_closed, name)
            stacks[name] = stack
        return stacks

    monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _fake_connect)

    await loop._connect_mcp()
    old_tool = loop.tools.get("mcp_remote_quote")
    assert isinstance(old_tool, MCPToolWrapper)

    output = await old_tool.execute(symbol="AAPL")

    assert output == "recovered"
    assert connect_count == 2
    assert closed == ["remote"]
    assert sessions[0].call_count == 1
    assert sessions[1].call_count == 1
    assert "remote" in loop._mcp_stacks
    assert loop.tools.get("mcp_remote_quote") is not old_tool

    # Close the reconnected server's exit stack so it does not outlive the test.
    await loop.close_mcp()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_concurrent_mcp_reconnect_reuses_fresh_session(
    tmp_path,
    monkeypatch: pytest.MonkeyPatch,
):
    """Concurrent calls on a dead session share one reconnect.

    The first connection registers two resource wrappers backed by a session
    that always raises "Session terminated". Reading both resources at once
    must trigger exactly one extra connection and one stack close, and both
    reads must be served by the fresh session.
    """
    loop = _make_loop(tmp_path, mcp_servers={"remote": object()})
    closed: list[str] = []
    connect_count = 0

    async def _mark_closed(name: str) -> None:
        closed.append(name)

    class _DeadSession:
        async def read_resource(self, _uri: str) -> Any:
            raise McpError(ErrorData(code=-32000, message="Session terminated"))

    class _LiveSession:
        async def read_resource(self, uri: str) -> Any:
            # Yield once so the two concurrent reads can interleave.
            await asyncio.sleep(0)
            return SimpleNamespace(
                contents=[
                    mcp_types.TextResourceContents(
                        uri=uri,
                        text=f"fresh:{uri.rsplit('/', maxsplit=1)[-1]}",
                    )
                ]
            )

    async def _fake_connect(servers, registry):
        nonlocal connect_count
        stacks = {}
        for name in servers:
            connect_count += 1
            session = _DeadSession() if connect_count == 1 else _LiveSession()
            for resource_name in ("alpha", "beta"):
                resource_def = SimpleNamespace(
                    name=resource_name,
                    uri=f"file:///{resource_name}",
                    description=f"{resource_name} resource",
                )
                registry.register(MCPResourceWrapper(session, name, resource_def))
            stack = AsyncExitStack()
            await stack.__aenter__()
            stack.push_async_callback(_mark_closed, name)
            stacks[name] = stack
        return stacks

    monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _fake_connect)

    await loop._connect_mcp()
    old_alpha = loop.tools.get("mcp_remote_resource_alpha")
    old_beta = loop.tools.get("mcp_remote_resource_beta")
    assert isinstance(old_alpha, MCPResourceWrapper)
    assert isinstance(old_beta, MCPResourceWrapper)

    outputs = await asyncio.gather(old_alpha.execute(), old_beta.execute())

    assert outputs == ["fresh:alpha", "fresh:beta"]
    assert connect_count == 2
    assert closed == ["remote"]

    # Close the reconnected server's exit stack so it does not outlive the test.
    await loop.close_mcp()
|