From e9145b7acd92c32e0143530233c0f38e95b9cb5b Mon Sep 17 00:00:00 2001 From: chengyongru Date: Wed, 3 Jun 2026 13:26:28 +0800 Subject: [PATCH] fix(mcp): reconnect terminated sessions --- nanobot/agent/tools/mcp.py | 180 ++++++++++++++++++++++-- tests/agent/test_mcp_connection.py | 130 +++++++++++++++++ tests/agent/test_mcp_transient_retry.py | 98 +++++++++++++ 3 files changed, 393 insertions(+), 15 deletions(-) diff --git a/nanobot/agent/tools/mcp.py b/nanobot/agent/tools/mcp.py index e26a434db..114772f0a 100644 --- a/nanobot/agent/tools/mcp.py +++ b/nanobot/agent/tools/mcp.py @@ -5,6 +5,7 @@ import os import re import shutil import urllib.parse +from collections.abc import Awaitable, Callable from contextlib import AsyncExitStack, suppress from typing import Any, Mapping from weakref import WeakKeyDictionary @@ -41,6 +42,7 @@ _WINDOWS_SHELL_LAUNCHERS: frozenset[str] = frozenset(("npx", "npm", "pnpm", "yar # Replace anything outside [a-zA-Z0-9_-] with underscore and collapse runs. _SANITIZE_RE = re.compile(r"_+") _RELOAD_LOCKS: WeakKeyDictionary[Any, asyncio.Lock] = WeakKeyDictionary() +_ReconnectCallback = Callable[[str, str, Tool], Awaitable[Tool | None]] def _sanitize_name(name: str) -> str: @@ -53,6 +55,19 @@ def _is_transient(exc: BaseException) -> bool: return type(exc).__name__ in _TRANSIENT_EXC_NAMES +def _is_session_terminated(exc: BaseException) -> bool: + """Return True when the MCP SDK reports a dead client session.""" + messages = [str(exc)] + error = getattr(exc, "error", None) + if error is not None: + messages.append(str(getattr(error, "message", ""))) + return any( + marker in message.lower() + for marker in ("session terminated", "connection closed") + for message in messages + ) + + async def _probe_http_url(url: str, timeout: float = 3.0) -> bool: """Quick TCP probe to check if an HTTP MCP server is reachable. @@ -68,7 +83,8 @@ async def _probe_http_url(url: str, timeout: float = 3.0) -> bool: port = 443 if parsed.scheme == "https" else 80 try: reader, writer = await asyncio.wait_for( - asyncio.open_connection(host, port), timeout=timeout, + asyncio.open_connection(host, port), + timeout=timeout, ) writer.close() await writer.wait_closed() @@ -174,13 +190,54 @@ def _normalize_schema_for_openai(schema: Any) -> dict[str, Any]: return normalized -class MCPToolWrapper(Tool): +class _MCPWrapperBase(Tool): + """Common reconnect handling for wrappers bound to one MCP server session.""" + + _plugin_discoverable = False + + def _set_mcp_connection(self, session: Any, server_name: str) -> None: + self._session = session + self._server_name = server_name + self._reconnect: _ReconnectCallback | None = None + + def set_reconnect_handler(self, reconnect: _ReconnectCallback) -> None: + self._reconnect = reconnect + + async def _refresh_session_after_termination( + self, + exc: BaseException, + already_refreshed: bool, + capability_kind: str, + ) -> bool: + if already_refreshed or not _is_session_terminated(exc) or self._reconnect is None: + return False + logger.warning( + "MCP {} '{}' session terminated; reconnecting server '{}' before retry", + capability_kind, + self._name, + self._server_name, + ) + refreshed_tool = await self._reconnect(self._server_name, self._name, self) + refreshed_session = getattr(refreshed_tool, "_session", None) + if refreshed_session is None: + logger.warning( + "MCP {} '{}' could not refresh session for server '{}'", + capability_kind, + self._name, + self._server_name, + ) + return False + self._session = refreshed_session + return True + + +class MCPToolWrapper(_MCPWrapperBase): """Wraps a single MCP server tool as a nanobot Tool.""" _plugin_discoverable = False def __init__(self, session, server_name: str, tool_def, tool_timeout: int = 30): - self._session = session + self._set_mcp_connection(session, server_name) self._original_name = tool_def.name self._name = _sanitize_name(f"mcp_{server_name}_{tool_def.name}") self._description = tool_def.description or tool_def.name @@ -203,7 +260,9 @@ class MCPToolWrapper(Tool): async def execute(self, **kwargs: Any) -> str: from mcp import types - for attempt in range(2): # At most 1 retry + retried_transient = False + refreshed_session = False + while True: try: result = await asyncio.wait_for( self._session.call_tool(self._original_name, arguments=kwargs), @@ -223,8 +282,16 @@ class MCPToolWrapper(Tool): logger.warning("MCP tool '{}' was cancelled by server/SDK", self._name) return "(MCP tool call was cancelled)" except Exception as exc: + if await self._refresh_session_after_termination( + exc, + refreshed_session, + "tool", + ): + refreshed_session = True + continue if _is_transient(exc): - if attempt == 0: + if not retried_transient: + retried_transient = True logger.warning( "MCP tool '{}' hit transient error ({}), retrying once...", self._name, @@ -259,13 +326,13 @@ class MCPToolWrapper(Tool): return "(MCP tool call failed)" # Unreachable, but satisfies type checkers -class MCPResourceWrapper(Tool): +class MCPResourceWrapper(_MCPWrapperBase): """Wraps an MCP resource URI as a read-only nanobot Tool.""" _plugin_discoverable = False def __init__(self, session, server_name: str, resource_def, resource_timeout: int = 30): - self._session = session + self._set_mcp_connection(session, server_name) self._uri = resource_def.uri self._name = _sanitize_name(f"mcp_{server_name}_resource_{resource_def.name}") desc = resource_def.description or resource_def.name @@ -296,7 +363,9 @@ class MCPResourceWrapper(Tool): async def execute(self, **kwargs: Any) -> str: from mcp import types - for attempt in range(2): + retried_transient = False + refreshed_session = False + while True: try: result = await asyncio.wait_for( self._session.read_resource(self._uri), @@ -314,8 +383,16 @@ class MCPResourceWrapper(Tool): logger.warning("MCP resource '{}' was cancelled by server/SDK", self._name) return "(MCP resource read was cancelled)" except Exception as exc: + if await self._refresh_session_after_termination( + exc, + refreshed_session, + "resource", + ): + refreshed_session = True + continue if _is_transient(exc): - if attempt == 0: + if not retried_transient: + retried_transient = True logger.warning( "MCP resource '{}' hit transient error ({}), retrying once...", self._name, @@ -350,13 +427,13 @@ class MCPResourceWrapper(Tool): return "(MCP resource read failed)" # Unreachable -class MCPPromptWrapper(Tool): +class MCPPromptWrapper(_MCPWrapperBase): """Wraps an MCP prompt as a read-only nanobot Tool.""" _plugin_discoverable = False def __init__(self, session, server_name: str, prompt_def, prompt_timeout: int = 30): - self._session = session + self._set_mcp_connection(session, server_name) self._prompt_name = prompt_def.name self._name = _sanitize_name(f"mcp_{server_name}_prompt_{prompt_def.name}") desc = prompt_def.description or prompt_def.name @@ -402,7 +479,9 @@ class MCPPromptWrapper(Tool): from mcp import types from mcp.shared.exceptions import McpError - for attempt in range(2): + retried_transient = False + refreshed_session = False + while True: try: result = await asyncio.wait_for( self._session.get_prompt(self._prompt_name, arguments=kwargs), @@ -420,6 +499,13 @@ class MCPPromptWrapper(Tool): logger.warning("MCP prompt '{}' was cancelled by server/SDK", self._name) return "(MCP prompt call was cancelled)" except McpError as exc: + if await self._refresh_session_after_termination( + exc, + refreshed_session, + "prompt", + ): + refreshed_session = True + continue logger.exception( "MCP prompt '{}' failed: code={} message={}", self._name, @@ -429,7 +515,8 @@ class MCPPromptWrapper(Tool): return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])" except Exception as exc: if _is_transient(exc): - if attempt == 0: + if not retried_transient: + retried_transient = True logger.warning( "MCP prompt '{}' hit transient error ({}), retrying once...", self._name, @@ -747,6 +834,7 @@ async def connect_missing_servers(state: Any, registry: ToolRegistry) -> None: try: connected = await connect_mcp_servers(missing_servers, registry) state._mcp_stacks.update(connected) + _attach_reconnect_handlers(state, registry, connected) state._mcp_connected = bool(state._mcp_stacks) if connected: logger.info("MCP connected servers: {}", sorted(connected)) @@ -766,8 +854,7 @@ async def reload_servers(state: Any, registry: ToolRegistry) -> dict[str, Any]: """Reconcile live MCP connections with the current config file.""" async with _reload_lock(state): try: - from nanobot.config.loader import (load_config, - resolve_config_env_vars) + from nanobot.config.loader import load_config, resolve_config_env_vars config = resolve_config_env_vars(load_config()) next_servers = dict(config.tools.mcp_servers) @@ -808,6 +895,7 @@ async def reload_servers(state: Any, registry: ToolRegistry) -> dict[str, Any]: if to_connect: connected = await connect_mcp_servers(to_connect, registry) state._mcp_stacks.update(connected) + _attach_reconnect_handlers(state, registry, connected) state._mcp_connected = bool(state._mcp_stacks) failed = sorted(set(to_connect) - set(connected)) @@ -909,6 +997,68 @@ def _reload_lock(state: Any) -> asyncio.Lock: return lock +def _attach_reconnect_handlers( + state: Any, + registry: ToolRegistry, + server_names: Mapping[str, Any] | set[str] | list[str] | tuple[str, ...], +) -> None: + async def reconnect(server_name: str, tool_name: str, stale_tool: Tool) -> Tool | None: + return await _refresh_terminated_server( + state, + registry, + server_name, + tool_name, + stale_tool, + ) + + for server_name in server_names: + prefix = _tool_prefix(server_name) + for tool_name in list(registry.tool_names): + if not tool_name.startswith(prefix): + continue + tool = registry.get(tool_name) + if isinstance(tool, _MCPWrapperBase): + tool.set_reconnect_handler(reconnect) + + +async def _refresh_terminated_server( + state: Any, + registry: ToolRegistry, + server_name: str, + tool_name: str, + stale_tool: Tool, +) -> Tool | None: + async with _reload_lock(state): + cfg = state._mcp_servers.get(server_name) + if cfg is None: + logger.warning( + "MCP server '{}' session terminated but is no longer configured", + server_name, + ) + return None + + current_tool = registry.get(tool_name) + if ( + current_tool is not None + and current_tool is not stale_tool + and server_name in state._mcp_stacks + ): + return current_tool + + logger.warning("MCP server '{}' session terminated; refreshing connection", server_name) + _unregister_server_tools(state, registry, server_name) + await _close_server(state, server_name) + + connected = await connect_mcp_servers({server_name: cfg}, registry) + state._mcp_stacks.update(connected) + _attach_reconnect_handlers(state, registry, connected) + state._mcp_connected = bool(state._mcp_stacks) + if server_name not in connected: + logger.warning("MCP server '{}' reconnect failed after session termination", server_name) + return None + return registry.get(tool_name) + + def _server_signature(cfg: Any) -> Any: if hasattr(cfg, "model_dump"): return cfg.model_dump(mode="json") diff --git a/tests/agent/test_mcp_connection.py b/tests/agent/test_mcp_connection.py index 18d118c25..78dde4187 100644 --- a/tests/agent/test_mcp_connection.py +++ b/tests/agent/test_mcp_connection.py @@ -4,14 +4,19 @@ from __future__ import annotations import asyncio from contextlib import AsyncExitStack +from types import SimpleNamespace from typing import Any from unittest.mock import MagicMock import pytest +from mcp import types as mcp_types +from mcp.shared.exceptions import McpError +from mcp.types import ErrorData from nanobot.agent.loop import AgentLoop from nanobot.agent.tools import mcp as mcp_runtime from nanobot.agent.tools.base import Tool +from nanobot.agent.tools.mcp import MCPResourceWrapper, MCPToolWrapper from nanobot.bus.queue import MessageBus from nanobot.config.loader import load_config, save_config from nanobot.config.schema import MCPServerConfig @@ -218,3 +223,128 @@ async def test_reload_mcp_servers_retries_configured_server_without_live_stack( assert result["retried"] == ["browserbase"] assert loop.tools.has("mcp_browserbase_navigate") await loop.close_mcp() + + +@pytest.mark.asyncio +async def test_mcp_tool_reconnects_after_session_terminated( + tmp_path, + monkeypatch: pytest.MonkeyPatch, +): + loop = _make_loop(tmp_path, mcp_servers={"remote": object()}) + closed: list[str] = [] + sessions: list[Any] = [] + connect_count = 0 + + async def _mark_closed(name: str) -> None: + closed.append(name) + + class _FakeSession: + def __init__(self, index: int) -> None: + self.index = index + self.call_count = 0 + + async def call_tool(self, _name: str, arguments: dict[str, Any]) -> Any: + self.call_count += 1 + assert arguments == {"symbol": "AAPL"} + if self.index == 1: + raise McpError(ErrorData(code=-32000, message="Session terminated")) + return SimpleNamespace( + content=[mcp_types.TextContent(type="text", text="recovered")] + ) + + async def _fake_connect(servers, registry): + nonlocal connect_count + stacks = {} + for name in servers: + connect_count += 1 + session = _FakeSession(connect_count) + sessions.append(session) + tool_def = SimpleNamespace( + name="quote", + description="quote tool", + inputSchema={"type": "object", "properties": {}}, + ) + registry.register(MCPToolWrapper(session, name, tool_def, tool_timeout=5)) + stack = AsyncExitStack() + await stack.__aenter__() + stack.push_async_callback(_mark_closed, name) + stacks[name] = stack + return stacks + + monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _fake_connect) + + await loop._connect_mcp() + old_tool = loop.tools.get("mcp_remote_quote") + assert isinstance(old_tool, MCPToolWrapper) + + output = await old_tool.execute(symbol="AAPL") + + assert output == "recovered" + assert connect_count == 2 + assert closed == ["remote"] + assert sessions[0].call_count == 1 + assert sessions[1].call_count == 1 + assert "remote" in loop._mcp_stacks + assert loop.tools.get("mcp_remote_quote") is not old_tool + + +@pytest.mark.asyncio +async def test_concurrent_mcp_reconnect_reuses_fresh_session( + tmp_path, + monkeypatch: pytest.MonkeyPatch, +): + loop = _make_loop(tmp_path, mcp_servers={"remote": object()}) + closed: list[str] = [] + connect_count = 0 + + async def _mark_closed(name: str) -> None: + closed.append(name) + + class _DeadSession: + async def read_resource(self, _uri: str) -> Any: + raise McpError(ErrorData(code=-32000, message="Session terminated")) + + class _LiveSession: + async def read_resource(self, uri: str) -> Any: + await asyncio.sleep(0) + return SimpleNamespace( + contents=[ + mcp_types.TextResourceContents( + uri=uri, + text=f"fresh:{uri.rsplit('/', maxsplit=1)[-1]}", + ) + ] + ) + + async def _fake_connect(servers, registry): + nonlocal connect_count + stacks = {} + for name in servers: + connect_count += 1 + session = _DeadSession() if connect_count == 1 else _LiveSession() + for resource_name in ("alpha", "beta"): + resource_def = SimpleNamespace( + name=resource_name, + uri=f"file:///{resource_name}", + description=f"{resource_name} resource", + ) + registry.register(MCPResourceWrapper(session, name, resource_def)) + stack = AsyncExitStack() + await stack.__aenter__() + stack.push_async_callback(_mark_closed, name) + stacks[name] = stack + return stacks + + monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _fake_connect) + + await loop._connect_mcp() + old_alpha = loop.tools.get("mcp_remote_resource_alpha") + old_beta = loop.tools.get("mcp_remote_resource_beta") + assert isinstance(old_alpha, MCPResourceWrapper) + assert isinstance(old_beta, MCPResourceWrapper) + + outputs = await asyncio.gather(old_alpha.execute(), old_beta.execute()) + + assert outputs == ["fresh:alpha", "fresh:beta"] + assert connect_count == 2 + assert closed == ["remote"] diff --git a/tests/agent/test_mcp_transient_retry.py b/tests/agent/test_mcp_transient_retry.py index c7ef0c4ab..573c43d08 100644 --- a/tests/agent/test_mcp_transient_retry.py +++ b/tests/agent/test_mcp_transient_retry.py @@ -13,6 +13,7 @@ from nanobot.agent.tools.mcp import ( MCPPromptWrapper, MCPResourceWrapper, MCPToolWrapper, + _is_session_terminated, _is_transient, ) @@ -35,6 +36,14 @@ class _FakeEndOfStreamError(Exception): _FakeEndOfStreamError.__name__ = "EndOfStream" +def _session_terminated_error() -> McpError: + return McpError(ErrorData(code=-32000, message="Session terminated")) + + +def _connection_closed_error() -> McpError: + return McpError(ErrorData(code=-32000, message="Connection closed")) + + def test_is_transient_recognizes_closed_resource(): assert _is_transient(_FakeClosedResourceError("gone")) @@ -67,6 +76,14 @@ def test_is_transient_rejects_timeout(): assert not _is_transient(TimeoutError("timeout")) +def test_is_session_terminated_recognizes_mcp_error(): + assert _is_session_terminated(_session_terminated_error()) + + +def test_is_session_terminated_recognizes_connection_closed_mcp_error(): + assert _is_session_terminated(_connection_closed_error()) + + # --------------------------------------------------------------------------- # MCPToolWrapper retry behaviour # --------------------------------------------------------------------------- @@ -219,6 +236,35 @@ async def test_tool_retry_on_end_of_stream(): assert session.call_tool.call_count == 2 +@pytest.mark.asyncio +async def test_tool_reconnects_when_transient_retry_reveals_terminated_session(): + """Tool should reconnect if a stale session reports termination after transient retry.""" + old_session = AsyncMock() + old_session.call_tool = AsyncMock( + side_effect=[_FakeClosedResourceError("closed"), _session_terminated_error()] + ) + new_session = AsyncMock() + new_session.call_tool = AsyncMock(return_value=_make_tool_result("fresh")) + + wrapper = MCPToolWrapper(old_session, "test_server", _make_tool_def(), tool_timeout=5) + replacement = MCPToolWrapper(new_session, "test_server", _make_tool_def(), tool_timeout=5) + + async def reconnect(server_name: str, tool_name: str, stale_tool): + assert server_name == "test_server" + assert tool_name == "mcp_test_server_test_tool" + assert stale_tool is wrapper + return replacement + + wrapper.set_reconnect_handler(reconnect) + + with patch("nanobot.agent.tools.mcp.asyncio.sleep", new_callable=AsyncMock): + output = await wrapper.execute(foo="bar") + + assert output == "fresh" + assert old_session.call_tool.call_count == 2 + assert new_session.call_tool.call_count == 1 + + # --------------------------------------------------------------------------- # MCPResourceWrapper retry behaviour # --------------------------------------------------------------------------- @@ -284,6 +330,32 @@ async def test_resource_no_retry_on_non_transient(): assert session.read_resource.call_count == 1 +@pytest.mark.asyncio +async def test_resource_reconnects_on_session_terminated(): + """Resource should reconnect once when the MCP SDK reports a dead session.""" + old_session = AsyncMock() + old_session.read_resource = AsyncMock(side_effect=_session_terminated_error()) + new_session = AsyncMock() + new_session.read_resource = AsyncMock(return_value=_make_resource_result("fresh")) + + wrapper = MCPResourceWrapper(old_session, "test_server", _make_resource_def()) + replacement = MCPResourceWrapper(new_session, "test_server", _make_resource_def()) + + async def reconnect(server_name: str, tool_name: str, stale_tool): + assert server_name == "test_server" + assert tool_name == "mcp_test_server_resource_test_resource" + assert stale_tool is wrapper + return replacement + + wrapper.set_reconnect_handler(reconnect) + + output = await wrapper.execute() + + assert output == "fresh" + assert old_session.read_resource.call_count == 1 + assert new_session.read_resource.call_count == 1 + + # --------------------------------------------------------------------------- # MCPPromptWrapper retry behaviour # --------------------------------------------------------------------------- @@ -366,3 +438,29 @@ async def test_prompt_no_retry_on_non_transient(): assert "RuntimeError" in output assert session.get_prompt.call_count == 1 + + +@pytest.mark.asyncio +async def test_prompt_reconnects_on_session_terminated(): + """Prompt should reconnect once before falling back to McpError handling.""" + old_session = AsyncMock() + old_session.get_prompt = AsyncMock(side_effect=_session_terminated_error()) + new_session = AsyncMock() + new_session.get_prompt = AsyncMock(return_value=_make_prompt_result("fresh prompt")) + + wrapper = MCPPromptWrapper(old_session, "test_server", _make_prompt_def()) + replacement = MCPPromptWrapper(new_session, "test_server", _make_prompt_def()) + + async def reconnect(server_name: str, tool_name: str, stale_tool): + assert server_name == "test_server" + assert tool_name == "mcp_test_server_prompt_test_prompt" + assert stale_tool is wrapper + return replacement + + wrapper.set_reconnect_handler(reconnect) + + output = await wrapper.execute() + + assert output == "fresh prompt" + assert old_session.get_prompt.call_count == 1 + assert new_session.get_prompt.call_count == 1