From d0eba7cd9dd3e1fb6fe62942d12e9379c7aadea9 Mon Sep 17 00:00:00 2001 From: chengyongru Date: Wed, 3 Jun 2026 14:12:45 +0800 Subject: [PATCH] fix: cover MCP reconnect edge cases maintainer edit: handle prompt sessions that report Connection closed outside McpError, and match reconnect registration prefixes with the same sanitization used by MCP wrapper names. --- nanobot/agent/tools/mcp.py | 12 ++++-- tests/agent/test_mcp_connection.py | 49 +++++++++++++++++++++++++ tests/agent/test_mcp_transient_retry.py | 26 +++++++++++++ 3 files changed, 83 insertions(+), 4 deletions(-) diff --git a/nanobot/agent/tools/mcp.py b/nanobot/agent/tools/mcp.py index 114772f0a..59a41127e 100644 --- a/nanobot/agent/tools/mcp.py +++ b/nanobot/agent/tools/mcp.py @@ -514,6 +514,13 @@ class MCPPromptWrapper(_MCPWrapperBase): ) return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])" except Exception as exc: + if await self._refresh_session_after_termination( + exc, + refreshed_session, + "prompt", + ): + refreshed_session = True + continue if _is_transient(exc): if not retried_transient: retried_transient = True @@ -1066,10 +1073,7 @@ def _server_signature(cfg: Any) -> Any: def _tool_prefix(server_name: str) -> str: - safe_name = "".join(ch if ch.isalnum() or ch in {"_", "-"} else "_" for ch in server_name) - while "__" in safe_name: - safe_name = safe_name.replace("__", "_") - return f"mcp_{safe_name}_" + return _sanitize_name(f"mcp_{server_name}_") def _unregister_server_tools(state: Any, registry: ToolRegistry, server_name: str) -> int: diff --git a/tests/agent/test_mcp_connection.py b/tests/agent/test_mcp_connection.py index 78dde4187..d9b32520c 100644 --- a/tests/agent/test_mcp_connection.py +++ b/tests/agent/test_mcp_connection.py @@ -288,6 +288,55 @@ async def test_mcp_tool_reconnects_after_session_terminated( assert loop.tools.get("mcp_remote_quote") is not old_tool +@pytest.mark.asyncio +async def test_mcp_reconnect_handler_uses_sanitized_server_prefix( + tmp_path, + monkeypatch: pytest.MonkeyPatch, +): + loop = _make_loop(tmp_path, mcp_servers={"remote_": object()}) + connect_count = 0 + + class _FakeSession: + def __init__(self, index: int) -> None: + self.index = index + + async def call_tool(self, _name: str, arguments: dict[str, Any]) -> Any: + assert arguments == {} + if self.index == 1: + raise McpError(ErrorData(code=-32000, message="Session terminated")) + return SimpleNamespace( + content=[mcp_types.TextContent(type="text", text="recovered")] + ) + + async def _fake_connect(servers, registry): + nonlocal connect_count + stacks = {} + for name in servers: + connect_count += 1 + tool_def = SimpleNamespace( + name="quote", + description="quote tool", + inputSchema={"type": "object", "properties": {}}, + ) + registry.register(MCPToolWrapper(_FakeSession(connect_count), name, tool_def)) + stack = AsyncExitStack() + await stack.__aenter__() + stacks[name] = stack + return stacks + + monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _fake_connect) + + await loop._connect_mcp() + old_tool = loop.tools.get("mcp_remote_quote") + assert isinstance(old_tool, MCPToolWrapper) + + output = await old_tool.execute() + + assert output == "recovered" + assert connect_count == 2 + assert loop.tools.get("mcp_remote_quote") is not old_tool + + @pytest.mark.asyncio async def test_concurrent_mcp_reconnect_reuses_fresh_session( tmp_path, diff --git a/tests/agent/test_mcp_transient_retry.py b/tests/agent/test_mcp_transient_retry.py index 573c43d08..8a0246aca 100644 --- a/tests/agent/test_mcp_transient_retry.py +++ b/tests/agent/test_mcp_transient_retry.py @@ -464,3 +464,29 @@ async def test_prompt_reconnects_on_session_terminated(): assert output == "fresh prompt" assert old_session.get_prompt.call_count == 1 assert new_session.get_prompt.call_count == 1 + + +@pytest.mark.asyncio +async def test_prompt_reconnects_on_connection_closed_exception(): + """Prompt should reconnect when the SDK reports a closed session as a generic exception.""" + old_session = AsyncMock() + old_session.get_prompt = AsyncMock(side_effect=RuntimeError("Connection closed")) + new_session = AsyncMock() + new_session.get_prompt = AsyncMock(return_value=_make_prompt_result("fresh prompt")) + + wrapper = MCPPromptWrapper(old_session, "test_server", _make_prompt_def()) + replacement = MCPPromptWrapper(new_session, "test_server", _make_prompt_def()) + + async def reconnect(server_name: str, tool_name: str, stale_tool): + assert server_name == "test_server" + assert tool_name == "mcp_test_server_prompt_test_prompt" + assert stale_tool is wrapper + return replacement + + wrapper.set_reconnect_handler(reconnect) + + output = await wrapper.execute() + + assert output == "fresh prompt" + assert old_session.get_prompt.call_count == 1 + assert new_session.get_prompt.call_count == 1