fix(mcp): reconnect terminated sessions

This commit is contained in:
chengyongru 2026-06-03 13:26:28 +08:00 committed by Xubin Ren
parent 7c3808327f
commit e9145b7acd
3 changed files with 393 additions and 15 deletions

View File

@ -5,6 +5,7 @@ import os
import re
import shutil
import urllib.parse
from collections.abc import Awaitable, Callable
from contextlib import AsyncExitStack, suppress
from typing import Any, Mapping
from weakref import WeakKeyDictionary
@ -41,6 +42,7 @@ _WINDOWS_SHELL_LAUNCHERS: frozenset[str] = frozenset(("npx", "npm", "pnpm", "yar
# Replace anything outside [a-zA-Z0-9_-] with underscore and collapse runs.
_SANITIZE_RE = re.compile(r"_+")
_RELOAD_LOCKS: WeakKeyDictionary[Any, asyncio.Lock] = WeakKeyDictionary()
_ReconnectCallback = Callable[[str, str, Tool], Awaitable[Tool | None]]
def _sanitize_name(name: str) -> str:
@ -53,6 +55,19 @@ def _is_transient(exc: BaseException) -> bool:
return type(exc).__name__ in _TRANSIENT_EXC_NAMES
def _is_session_terminated(exc: BaseException) -> bool:
"""Return True when the MCP SDK reports a dead client session."""
messages = [str(exc)]
error = getattr(exc, "error", None)
if error is not None:
messages.append(str(getattr(error, "message", "")))
return any(
marker in message.lower()
for marker in ("session terminated", "connection closed")
for message in messages
)
async def _probe_http_url(url: str, timeout: float = 3.0) -> bool:
"""Quick TCP probe to check if an HTTP MCP server is reachable.
@ -68,7 +83,8 @@ async def _probe_http_url(url: str, timeout: float = 3.0) -> bool:
port = 443 if parsed.scheme == "https" else 80
try:
reader, writer = await asyncio.wait_for(
asyncio.open_connection(host, port), timeout=timeout,
asyncio.open_connection(host, port),
timeout=timeout,
)
writer.close()
await writer.wait_closed()
@ -174,13 +190,54 @@ def _normalize_schema_for_openai(schema: Any) -> dict[str, Any]:
return normalized
class MCPToolWrapper(Tool):
class _MCPWrapperBase(Tool):
"""Common reconnect handling for wrappers bound to one MCP server session."""
_plugin_discoverable = False
def _set_mcp_connection(self, session: Any, server_name: str) -> None:
self._session = session
self._server_name = server_name
self._reconnect: _ReconnectCallback | None = None
def set_reconnect_handler(self, reconnect: _ReconnectCallback) -> None:
self._reconnect = reconnect
async def _refresh_session_after_termination(
self,
exc: BaseException,
already_refreshed: bool,
capability_kind: str,
) -> bool:
if already_refreshed or not _is_session_terminated(exc) or self._reconnect is None:
return False
logger.warning(
"MCP {} '{}' session terminated; reconnecting server '{}' before retry",
capability_kind,
self._name,
self._server_name,
)
refreshed_tool = await self._reconnect(self._server_name, self._name, self)
refreshed_session = getattr(refreshed_tool, "_session", None)
if refreshed_session is None:
logger.warning(
"MCP {} '{}' could not refresh session for server '{}'",
capability_kind,
self._name,
self._server_name,
)
return False
self._session = refreshed_session
return True
class MCPToolWrapper(_MCPWrapperBase):
"""Wraps a single MCP server tool as a nanobot Tool."""
_plugin_discoverable = False
def __init__(self, session, server_name: str, tool_def, tool_timeout: int = 30):
self._session = session
self._set_mcp_connection(session, server_name)
self._original_name = tool_def.name
self._name = _sanitize_name(f"mcp_{server_name}_{tool_def.name}")
self._description = tool_def.description or tool_def.name
@ -203,7 +260,9 @@ class MCPToolWrapper(Tool):
async def execute(self, **kwargs: Any) -> str:
from mcp import types
for attempt in range(2): # At most 1 retry
retried_transient = False
refreshed_session = False
while True:
try:
result = await asyncio.wait_for(
self._session.call_tool(self._original_name, arguments=kwargs),
@ -223,8 +282,16 @@ class MCPToolWrapper(Tool):
logger.warning("MCP tool '{}' was cancelled by server/SDK", self._name)
return "(MCP tool call was cancelled)"
except Exception as exc:
if await self._refresh_session_after_termination(
exc,
refreshed_session,
"tool",
):
refreshed_session = True
continue
if _is_transient(exc):
if attempt == 0:
if not retried_transient:
retried_transient = True
logger.warning(
"MCP tool '{}' hit transient error ({}), retrying once...",
self._name,
@ -259,13 +326,13 @@ class MCPToolWrapper(Tool):
return "(MCP tool call failed)" # Unreachable, but satisfies type checkers
class MCPResourceWrapper(Tool):
class MCPResourceWrapper(_MCPWrapperBase):
"""Wraps an MCP resource URI as a read-only nanobot Tool."""
_plugin_discoverable = False
def __init__(self, session, server_name: str, resource_def, resource_timeout: int = 30):
self._session = session
self._set_mcp_connection(session, server_name)
self._uri = resource_def.uri
self._name = _sanitize_name(f"mcp_{server_name}_resource_{resource_def.name}")
desc = resource_def.description or resource_def.name
@ -296,7 +363,9 @@ class MCPResourceWrapper(Tool):
async def execute(self, **kwargs: Any) -> str:
from mcp import types
for attempt in range(2):
retried_transient = False
refreshed_session = False
while True:
try:
result = await asyncio.wait_for(
self._session.read_resource(self._uri),
@ -314,8 +383,16 @@ class MCPResourceWrapper(Tool):
logger.warning("MCP resource '{}' was cancelled by server/SDK", self._name)
return "(MCP resource read was cancelled)"
except Exception as exc:
if await self._refresh_session_after_termination(
exc,
refreshed_session,
"resource",
):
refreshed_session = True
continue
if _is_transient(exc):
if attempt == 0:
if not retried_transient:
retried_transient = True
logger.warning(
"MCP resource '{}' hit transient error ({}), retrying once...",
self._name,
@ -350,13 +427,13 @@ class MCPResourceWrapper(Tool):
return "(MCP resource read failed)" # Unreachable
class MCPPromptWrapper(Tool):
class MCPPromptWrapper(_MCPWrapperBase):
"""Wraps an MCP prompt as a read-only nanobot Tool."""
_plugin_discoverable = False
def __init__(self, session, server_name: str, prompt_def, prompt_timeout: int = 30):
self._session = session
self._set_mcp_connection(session, server_name)
self._prompt_name = prompt_def.name
self._name = _sanitize_name(f"mcp_{server_name}_prompt_{prompt_def.name}")
desc = prompt_def.description or prompt_def.name
@ -402,7 +479,9 @@ class MCPPromptWrapper(Tool):
from mcp import types
from mcp.shared.exceptions import McpError
for attempt in range(2):
retried_transient = False
refreshed_session = False
while True:
try:
result = await asyncio.wait_for(
self._session.get_prompt(self._prompt_name, arguments=kwargs),
@ -420,6 +499,13 @@ class MCPPromptWrapper(Tool):
logger.warning("MCP prompt '{}' was cancelled by server/SDK", self._name)
return "(MCP prompt call was cancelled)"
except McpError as exc:
if await self._refresh_session_after_termination(
exc,
refreshed_session,
"prompt",
):
refreshed_session = True
continue
logger.exception(
"MCP prompt '{}' failed: code={} message={}",
self._name,
@ -429,7 +515,8 @@ class MCPPromptWrapper(Tool):
return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])"
except Exception as exc:
if _is_transient(exc):
if attempt == 0:
if not retried_transient:
retried_transient = True
logger.warning(
"MCP prompt '{}' hit transient error ({}), retrying once...",
self._name,
@ -747,6 +834,7 @@ async def connect_missing_servers(state: Any, registry: ToolRegistry) -> None:
try:
connected = await connect_mcp_servers(missing_servers, registry)
state._mcp_stacks.update(connected)
_attach_reconnect_handlers(state, registry, connected)
state._mcp_connected = bool(state._mcp_stacks)
if connected:
logger.info("MCP connected servers: {}", sorted(connected))
@ -766,8 +854,7 @@ async def reload_servers(state: Any, registry: ToolRegistry) -> dict[str, Any]:
"""Reconcile live MCP connections with the current config file."""
async with _reload_lock(state):
try:
from nanobot.config.loader import (load_config,
resolve_config_env_vars)
from nanobot.config.loader import load_config, resolve_config_env_vars
config = resolve_config_env_vars(load_config())
next_servers = dict(config.tools.mcp_servers)
@ -808,6 +895,7 @@ async def reload_servers(state: Any, registry: ToolRegistry) -> dict[str, Any]:
if to_connect:
connected = await connect_mcp_servers(to_connect, registry)
state._mcp_stacks.update(connected)
_attach_reconnect_handlers(state, registry, connected)
state._mcp_connected = bool(state._mcp_stacks)
failed = sorted(set(to_connect) - set(connected))
@ -909,6 +997,68 @@ def _reload_lock(state: Any) -> asyncio.Lock:
return lock
def _attach_reconnect_handlers(
state: Any,
registry: ToolRegistry,
server_names: Mapping[str, Any] | set[str] | list[str] | tuple[str, ...],
) -> None:
async def reconnect(server_name: str, tool_name: str, stale_tool: Tool) -> Tool | None:
return await _refresh_terminated_server(
state,
registry,
server_name,
tool_name,
stale_tool,
)
for server_name in server_names:
prefix = _tool_prefix(server_name)
for tool_name in list(registry.tool_names):
if not tool_name.startswith(prefix):
continue
tool = registry.get(tool_name)
if isinstance(tool, _MCPWrapperBase):
tool.set_reconnect_handler(reconnect)
async def _refresh_terminated_server(
state: Any,
registry: ToolRegistry,
server_name: str,
tool_name: str,
stale_tool: Tool,
) -> Tool | None:
async with _reload_lock(state):
cfg = state._mcp_servers.get(server_name)
if cfg is None:
logger.warning(
"MCP server '{}' session terminated but is no longer configured",
server_name,
)
return None
current_tool = registry.get(tool_name)
if (
current_tool is not None
and current_tool is not stale_tool
and server_name in state._mcp_stacks
):
return current_tool
logger.warning("MCP server '{}' session terminated; refreshing connection", server_name)
_unregister_server_tools(state, registry, server_name)
await _close_server(state, server_name)
connected = await connect_mcp_servers({server_name: cfg}, registry)
state._mcp_stacks.update(connected)
_attach_reconnect_handlers(state, registry, connected)
state._mcp_connected = bool(state._mcp_stacks)
if server_name not in connected:
logger.warning("MCP server '{}' reconnect failed after session termination", server_name)
return None
return registry.get(tool_name)
def _server_signature(cfg: Any) -> Any:
if hasattr(cfg, "model_dump"):
return cfg.model_dump(mode="json")

View File

@ -4,14 +4,19 @@ from __future__ import annotations
import asyncio
from contextlib import AsyncExitStack
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock
import pytest
from mcp import types as mcp_types
from mcp.shared.exceptions import McpError
from mcp.types import ErrorData
from nanobot.agent.loop import AgentLoop
from nanobot.agent.tools import mcp as mcp_runtime
from nanobot.agent.tools.base import Tool
from nanobot.agent.tools.mcp import MCPResourceWrapper, MCPToolWrapper
from nanobot.bus.queue import MessageBus
from nanobot.config.loader import load_config, save_config
from nanobot.config.schema import MCPServerConfig
@ -218,3 +223,128 @@ async def test_reload_mcp_servers_retries_configured_server_without_live_stack(
assert result["retried"] == ["browserbase"]
assert loop.tools.has("mcp_browserbase_navigate")
await loop.close_mcp()
@pytest.mark.asyncio
async def test_mcp_tool_reconnects_after_session_terminated(
tmp_path,
monkeypatch: pytest.MonkeyPatch,
):
loop = _make_loop(tmp_path, mcp_servers={"remote": object()})
closed: list[str] = []
sessions: list[Any] = []
connect_count = 0
async def _mark_closed(name: str) -> None:
closed.append(name)
class _FakeSession:
def __init__(self, index: int) -> None:
self.index = index
self.call_count = 0
async def call_tool(self, _name: str, arguments: dict[str, Any]) -> Any:
self.call_count += 1
assert arguments == {"symbol": "AAPL"}
if self.index == 1:
raise McpError(ErrorData(code=-32000, message="Session terminated"))
return SimpleNamespace(
content=[mcp_types.TextContent(type="text", text="recovered")]
)
async def _fake_connect(servers, registry):
nonlocal connect_count
stacks = {}
for name in servers:
connect_count += 1
session = _FakeSession(connect_count)
sessions.append(session)
tool_def = SimpleNamespace(
name="quote",
description="quote tool",
inputSchema={"type": "object", "properties": {}},
)
registry.register(MCPToolWrapper(session, name, tool_def, tool_timeout=5))
stack = AsyncExitStack()
await stack.__aenter__()
stack.push_async_callback(_mark_closed, name)
stacks[name] = stack
return stacks
monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _fake_connect)
await loop._connect_mcp()
old_tool = loop.tools.get("mcp_remote_quote")
assert isinstance(old_tool, MCPToolWrapper)
output = await old_tool.execute(symbol="AAPL")
assert output == "recovered"
assert connect_count == 2
assert closed == ["remote"]
assert sessions[0].call_count == 1
assert sessions[1].call_count == 1
assert "remote" in loop._mcp_stacks
assert loop.tools.get("mcp_remote_quote") is not old_tool
@pytest.mark.asyncio
async def test_concurrent_mcp_reconnect_reuses_fresh_session(
tmp_path,
monkeypatch: pytest.MonkeyPatch,
):
loop = _make_loop(tmp_path, mcp_servers={"remote": object()})
closed: list[str] = []
connect_count = 0
async def _mark_closed(name: str) -> None:
closed.append(name)
class _DeadSession:
async def read_resource(self, _uri: str) -> Any:
raise McpError(ErrorData(code=-32000, message="Session terminated"))
class _LiveSession:
async def read_resource(self, uri: str) -> Any:
await asyncio.sleep(0)
return SimpleNamespace(
contents=[
mcp_types.TextResourceContents(
uri=uri,
text=f"fresh:{uri.rsplit('/', maxsplit=1)[-1]}",
)
]
)
async def _fake_connect(servers, registry):
nonlocal connect_count
stacks = {}
for name in servers:
connect_count += 1
session = _DeadSession() if connect_count == 1 else _LiveSession()
for resource_name in ("alpha", "beta"):
resource_def = SimpleNamespace(
name=resource_name,
uri=f"file:///{resource_name}",
description=f"{resource_name} resource",
)
registry.register(MCPResourceWrapper(session, name, resource_def))
stack = AsyncExitStack()
await stack.__aenter__()
stack.push_async_callback(_mark_closed, name)
stacks[name] = stack
return stacks
monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _fake_connect)
await loop._connect_mcp()
old_alpha = loop.tools.get("mcp_remote_resource_alpha")
old_beta = loop.tools.get("mcp_remote_resource_beta")
assert isinstance(old_alpha, MCPResourceWrapper)
assert isinstance(old_beta, MCPResourceWrapper)
outputs = await asyncio.gather(old_alpha.execute(), old_beta.execute())
assert outputs == ["fresh:alpha", "fresh:beta"]
assert connect_count == 2
assert closed == ["remote"]

View File

@ -13,6 +13,7 @@ from nanobot.agent.tools.mcp import (
MCPPromptWrapper,
MCPResourceWrapper,
MCPToolWrapper,
_is_session_terminated,
_is_transient,
)
@ -35,6 +36,14 @@ class _FakeEndOfStreamError(Exception):
_FakeEndOfStreamError.__name__ = "EndOfStream"
def _session_terminated_error() -> McpError:
return McpError(ErrorData(code=-32000, message="Session terminated"))
def _connection_closed_error() -> McpError:
return McpError(ErrorData(code=-32000, message="Connection closed"))
def test_is_transient_recognizes_closed_resource():
assert _is_transient(_FakeClosedResourceError("gone"))
@ -67,6 +76,14 @@ def test_is_transient_rejects_timeout():
assert not _is_transient(TimeoutError("timeout"))
def test_is_session_terminated_recognizes_mcp_error():
assert _is_session_terminated(_session_terminated_error())
def test_is_session_terminated_recognizes_connection_closed_mcp_error():
assert _is_session_terminated(_connection_closed_error())
# ---------------------------------------------------------------------------
# MCPToolWrapper retry behaviour
# ---------------------------------------------------------------------------
@ -219,6 +236,35 @@ async def test_tool_retry_on_end_of_stream():
assert session.call_tool.call_count == 2
@pytest.mark.asyncio
async def test_tool_reconnects_when_transient_retry_reveals_terminated_session():
"""Tool should reconnect if a stale session reports termination after transient retry."""
old_session = AsyncMock()
old_session.call_tool = AsyncMock(
side_effect=[_FakeClosedResourceError("closed"), _session_terminated_error()]
)
new_session = AsyncMock()
new_session.call_tool = AsyncMock(return_value=_make_tool_result("fresh"))
wrapper = MCPToolWrapper(old_session, "test_server", _make_tool_def(), tool_timeout=5)
replacement = MCPToolWrapper(new_session, "test_server", _make_tool_def(), tool_timeout=5)
async def reconnect(server_name: str, tool_name: str, stale_tool):
assert server_name == "test_server"
assert tool_name == "mcp_test_server_test_tool"
assert stale_tool is wrapper
return replacement
wrapper.set_reconnect_handler(reconnect)
with patch("nanobot.agent.tools.mcp.asyncio.sleep", new_callable=AsyncMock):
output = await wrapper.execute(foo="bar")
assert output == "fresh"
assert old_session.call_tool.call_count == 2
assert new_session.call_tool.call_count == 1
# ---------------------------------------------------------------------------
# MCPResourceWrapper retry behaviour
# ---------------------------------------------------------------------------
@ -284,6 +330,32 @@ async def test_resource_no_retry_on_non_transient():
assert session.read_resource.call_count == 1
@pytest.mark.asyncio
async def test_resource_reconnects_on_session_terminated():
"""Resource should reconnect once when the MCP SDK reports a dead session."""
old_session = AsyncMock()
old_session.read_resource = AsyncMock(side_effect=_session_terminated_error())
new_session = AsyncMock()
new_session.read_resource = AsyncMock(return_value=_make_resource_result("fresh"))
wrapper = MCPResourceWrapper(old_session, "test_server", _make_resource_def())
replacement = MCPResourceWrapper(new_session, "test_server", _make_resource_def())
async def reconnect(server_name: str, tool_name: str, stale_tool):
assert server_name == "test_server"
assert tool_name == "mcp_test_server_resource_test_resource"
assert stale_tool is wrapper
return replacement
wrapper.set_reconnect_handler(reconnect)
output = await wrapper.execute()
assert output == "fresh"
assert old_session.read_resource.call_count == 1
assert new_session.read_resource.call_count == 1
# ---------------------------------------------------------------------------
# MCPPromptWrapper retry behaviour
# ---------------------------------------------------------------------------
@ -366,3 +438,29 @@ async def test_prompt_no_retry_on_non_transient():
assert "RuntimeError" in output
assert session.get_prompt.call_count == 1
@pytest.mark.asyncio
async def test_prompt_reconnects_on_session_terminated():
"""Prompt should reconnect once before falling back to McpError handling."""
old_session = AsyncMock()
old_session.get_prompt = AsyncMock(side_effect=_session_terminated_error())
new_session = AsyncMock()
new_session.get_prompt = AsyncMock(return_value=_make_prompt_result("fresh prompt"))
wrapper = MCPPromptWrapper(old_session, "test_server", _make_prompt_def())
replacement = MCPPromptWrapper(new_session, "test_server", _make_prompt_def())
async def reconnect(server_name: str, tool_name: str, stale_tool):
assert server_name == "test_server"
assert tool_name == "mcp_test_server_prompt_test_prompt"
assert stale_tool is wrapper
return replacement
wrapper.set_reconnect_handler(reconnect)
output = await wrapper.execute()
assert output == "fresh prompt"
assert old_session.get_prompt.call_count == 1
assert new_session.get_prompt.call_count == 1