fix(mcp): reconnect terminated sessions

This commit is contained in:
chengyongru 2026-06-03 13:26:28 +08:00 committed by Xubin Ren
parent 7c3808327f
commit e9145b7acd
3 changed files with 393 additions and 15 deletions

View File

@ -5,6 +5,7 @@ import os
import re import re
import shutil import shutil
import urllib.parse import urllib.parse
from collections.abc import Awaitable, Callable
from contextlib import AsyncExitStack, suppress from contextlib import AsyncExitStack, suppress
from typing import Any, Mapping from typing import Any, Mapping
from weakref import WeakKeyDictionary from weakref import WeakKeyDictionary
@ -41,6 +42,7 @@ _WINDOWS_SHELL_LAUNCHERS: frozenset[str] = frozenset(("npx", "npm", "pnpm", "yar
# Replace anything outside [a-zA-Z0-9_-] with underscore and collapse runs. # Replace anything outside [a-zA-Z0-9_-] with underscore and collapse runs.
_SANITIZE_RE = re.compile(r"_+") _SANITIZE_RE = re.compile(r"_+")
_RELOAD_LOCKS: WeakKeyDictionary[Any, asyncio.Lock] = WeakKeyDictionary() _RELOAD_LOCKS: WeakKeyDictionary[Any, asyncio.Lock] = WeakKeyDictionary()
_ReconnectCallback = Callable[[str, str, Tool], Awaitable[Tool | None]]
def _sanitize_name(name: str) -> str: def _sanitize_name(name: str) -> str:
@ -53,6 +55,19 @@ def _is_transient(exc: BaseException) -> bool:
return type(exc).__name__ in _TRANSIENT_EXC_NAMES return type(exc).__name__ in _TRANSIENT_EXC_NAMES
def _is_session_terminated(exc: BaseException) -> bool:
"""Return True when the MCP SDK reports a dead client session."""
messages = [str(exc)]
error = getattr(exc, "error", None)
if error is not None:
messages.append(str(getattr(error, "message", "")))
return any(
marker in message.lower()
for marker in ("session terminated", "connection closed")
for message in messages
)
async def _probe_http_url(url: str, timeout: float = 3.0) -> bool: async def _probe_http_url(url: str, timeout: float = 3.0) -> bool:
"""Quick TCP probe to check if an HTTP MCP server is reachable. """Quick TCP probe to check if an HTTP MCP server is reachable.
@ -68,7 +83,8 @@ async def _probe_http_url(url: str, timeout: float = 3.0) -> bool:
port = 443 if parsed.scheme == "https" else 80 port = 443 if parsed.scheme == "https" else 80
try: try:
reader, writer = await asyncio.wait_for( reader, writer = await asyncio.wait_for(
asyncio.open_connection(host, port), timeout=timeout, asyncio.open_connection(host, port),
timeout=timeout,
) )
writer.close() writer.close()
await writer.wait_closed() await writer.wait_closed()
@ -174,13 +190,54 @@ def _normalize_schema_for_openai(schema: Any) -> dict[str, Any]:
return normalized return normalized
class MCPToolWrapper(Tool): class _MCPWrapperBase(Tool):
"""Common reconnect handling for wrappers bound to one MCP server session."""
_plugin_discoverable = False
def _set_mcp_connection(self, session: Any, server_name: str) -> None:
self._session = session
self._server_name = server_name
self._reconnect: _ReconnectCallback | None = None
def set_reconnect_handler(self, reconnect: _ReconnectCallback) -> None:
self._reconnect = reconnect
async def _refresh_session_after_termination(
self,
exc: BaseException,
already_refreshed: bool,
capability_kind: str,
) -> bool:
if already_refreshed or not _is_session_terminated(exc) or self._reconnect is None:
return False
logger.warning(
"MCP {} '{}' session terminated; reconnecting server '{}' before retry",
capability_kind,
self._name,
self._server_name,
)
refreshed_tool = await self._reconnect(self._server_name, self._name, self)
refreshed_session = getattr(refreshed_tool, "_session", None)
if refreshed_session is None:
logger.warning(
"MCP {} '{}' could not refresh session for server '{}'",
capability_kind,
self._name,
self._server_name,
)
return False
self._session = refreshed_session
return True
class MCPToolWrapper(_MCPWrapperBase):
"""Wraps a single MCP server tool as a nanobot Tool.""" """Wraps a single MCP server tool as a nanobot Tool."""
_plugin_discoverable = False _plugin_discoverable = False
def __init__(self, session, server_name: str, tool_def, tool_timeout: int = 30): def __init__(self, session, server_name: str, tool_def, tool_timeout: int = 30):
self._session = session self._set_mcp_connection(session, server_name)
self._original_name = tool_def.name self._original_name = tool_def.name
self._name = _sanitize_name(f"mcp_{server_name}_{tool_def.name}") self._name = _sanitize_name(f"mcp_{server_name}_{tool_def.name}")
self._description = tool_def.description or tool_def.name self._description = tool_def.description or tool_def.name
@ -203,7 +260,9 @@ class MCPToolWrapper(Tool):
async def execute(self, **kwargs: Any) -> str: async def execute(self, **kwargs: Any) -> str:
from mcp import types from mcp import types
for attempt in range(2): # At most 1 retry retried_transient = False
refreshed_session = False
while True:
try: try:
result = await asyncio.wait_for( result = await asyncio.wait_for(
self._session.call_tool(self._original_name, arguments=kwargs), self._session.call_tool(self._original_name, arguments=kwargs),
@ -223,8 +282,16 @@ class MCPToolWrapper(Tool):
logger.warning("MCP tool '{}' was cancelled by server/SDK", self._name) logger.warning("MCP tool '{}' was cancelled by server/SDK", self._name)
return "(MCP tool call was cancelled)" return "(MCP tool call was cancelled)"
except Exception as exc: except Exception as exc:
if await self._refresh_session_after_termination(
exc,
refreshed_session,
"tool",
):
refreshed_session = True
continue
if _is_transient(exc): if _is_transient(exc):
if attempt == 0: if not retried_transient:
retried_transient = True
logger.warning( logger.warning(
"MCP tool '{}' hit transient error ({}), retrying once...", "MCP tool '{}' hit transient error ({}), retrying once...",
self._name, self._name,
@ -259,13 +326,13 @@ class MCPToolWrapper(Tool):
return "(MCP tool call failed)" # Unreachable, but satisfies type checkers return "(MCP tool call failed)" # Unreachable, but satisfies type checkers
class MCPResourceWrapper(Tool): class MCPResourceWrapper(_MCPWrapperBase):
"""Wraps an MCP resource URI as a read-only nanobot Tool.""" """Wraps an MCP resource URI as a read-only nanobot Tool."""
_plugin_discoverable = False _plugin_discoverable = False
def __init__(self, session, server_name: str, resource_def, resource_timeout: int = 30): def __init__(self, session, server_name: str, resource_def, resource_timeout: int = 30):
self._session = session self._set_mcp_connection(session, server_name)
self._uri = resource_def.uri self._uri = resource_def.uri
self._name = _sanitize_name(f"mcp_{server_name}_resource_{resource_def.name}") self._name = _sanitize_name(f"mcp_{server_name}_resource_{resource_def.name}")
desc = resource_def.description or resource_def.name desc = resource_def.description or resource_def.name
@ -296,7 +363,9 @@ class MCPResourceWrapper(Tool):
async def execute(self, **kwargs: Any) -> str: async def execute(self, **kwargs: Any) -> str:
from mcp import types from mcp import types
for attempt in range(2): retried_transient = False
refreshed_session = False
while True:
try: try:
result = await asyncio.wait_for( result = await asyncio.wait_for(
self._session.read_resource(self._uri), self._session.read_resource(self._uri),
@ -314,8 +383,16 @@ class MCPResourceWrapper(Tool):
logger.warning("MCP resource '{}' was cancelled by server/SDK", self._name) logger.warning("MCP resource '{}' was cancelled by server/SDK", self._name)
return "(MCP resource read was cancelled)" return "(MCP resource read was cancelled)"
except Exception as exc: except Exception as exc:
if await self._refresh_session_after_termination(
exc,
refreshed_session,
"resource",
):
refreshed_session = True
continue
if _is_transient(exc): if _is_transient(exc):
if attempt == 0: if not retried_transient:
retried_transient = True
logger.warning( logger.warning(
"MCP resource '{}' hit transient error ({}), retrying once...", "MCP resource '{}' hit transient error ({}), retrying once...",
self._name, self._name,
@ -350,13 +427,13 @@ class MCPResourceWrapper(Tool):
return "(MCP resource read failed)" # Unreachable return "(MCP resource read failed)" # Unreachable
class MCPPromptWrapper(Tool): class MCPPromptWrapper(_MCPWrapperBase):
"""Wraps an MCP prompt as a read-only nanobot Tool.""" """Wraps an MCP prompt as a read-only nanobot Tool."""
_plugin_discoverable = False _plugin_discoverable = False
def __init__(self, session, server_name: str, prompt_def, prompt_timeout: int = 30): def __init__(self, session, server_name: str, prompt_def, prompt_timeout: int = 30):
self._session = session self._set_mcp_connection(session, server_name)
self._prompt_name = prompt_def.name self._prompt_name = prompt_def.name
self._name = _sanitize_name(f"mcp_{server_name}_prompt_{prompt_def.name}") self._name = _sanitize_name(f"mcp_{server_name}_prompt_{prompt_def.name}")
desc = prompt_def.description or prompt_def.name desc = prompt_def.description or prompt_def.name
@ -402,7 +479,9 @@ class MCPPromptWrapper(Tool):
from mcp import types from mcp import types
from mcp.shared.exceptions import McpError from mcp.shared.exceptions import McpError
for attempt in range(2): retried_transient = False
refreshed_session = False
while True:
try: try:
result = await asyncio.wait_for( result = await asyncio.wait_for(
self._session.get_prompt(self._prompt_name, arguments=kwargs), self._session.get_prompt(self._prompt_name, arguments=kwargs),
@ -420,6 +499,13 @@ class MCPPromptWrapper(Tool):
logger.warning("MCP prompt '{}' was cancelled by server/SDK", self._name) logger.warning("MCP prompt '{}' was cancelled by server/SDK", self._name)
return "(MCP prompt call was cancelled)" return "(MCP prompt call was cancelled)"
except McpError as exc: except McpError as exc:
if await self._refresh_session_after_termination(
exc,
refreshed_session,
"prompt",
):
refreshed_session = True
continue
logger.exception( logger.exception(
"MCP prompt '{}' failed: code={} message={}", "MCP prompt '{}' failed: code={} message={}",
self._name, self._name,
@ -429,7 +515,8 @@ class MCPPromptWrapper(Tool):
return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])" return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])"
except Exception as exc: except Exception as exc:
if _is_transient(exc): if _is_transient(exc):
if attempt == 0: if not retried_transient:
retried_transient = True
logger.warning( logger.warning(
"MCP prompt '{}' hit transient error ({}), retrying once...", "MCP prompt '{}' hit transient error ({}), retrying once...",
self._name, self._name,
@ -747,6 +834,7 @@ async def connect_missing_servers(state: Any, registry: ToolRegistry) -> None:
try: try:
connected = await connect_mcp_servers(missing_servers, registry) connected = await connect_mcp_servers(missing_servers, registry)
state._mcp_stacks.update(connected) state._mcp_stacks.update(connected)
_attach_reconnect_handlers(state, registry, connected)
state._mcp_connected = bool(state._mcp_stacks) state._mcp_connected = bool(state._mcp_stacks)
if connected: if connected:
logger.info("MCP connected servers: {}", sorted(connected)) logger.info("MCP connected servers: {}", sorted(connected))
@ -766,8 +854,7 @@ async def reload_servers(state: Any, registry: ToolRegistry) -> dict[str, Any]:
"""Reconcile live MCP connections with the current config file.""" """Reconcile live MCP connections with the current config file."""
async with _reload_lock(state): async with _reload_lock(state):
try: try:
from nanobot.config.loader import (load_config, from nanobot.config.loader import load_config, resolve_config_env_vars
resolve_config_env_vars)
config = resolve_config_env_vars(load_config()) config = resolve_config_env_vars(load_config())
next_servers = dict(config.tools.mcp_servers) next_servers = dict(config.tools.mcp_servers)
@ -808,6 +895,7 @@ async def reload_servers(state: Any, registry: ToolRegistry) -> dict[str, Any]:
if to_connect: if to_connect:
connected = await connect_mcp_servers(to_connect, registry) connected = await connect_mcp_servers(to_connect, registry)
state._mcp_stacks.update(connected) state._mcp_stacks.update(connected)
_attach_reconnect_handlers(state, registry, connected)
state._mcp_connected = bool(state._mcp_stacks) state._mcp_connected = bool(state._mcp_stacks)
failed = sorted(set(to_connect) - set(connected)) failed = sorted(set(to_connect) - set(connected))
@ -909,6 +997,68 @@ def _reload_lock(state: Any) -> asyncio.Lock:
return lock return lock
def _attach_reconnect_handlers(
state: Any,
registry: ToolRegistry,
server_names: Mapping[str, Any] | set[str] | list[str] | tuple[str, ...],
) -> None:
async def reconnect(server_name: str, tool_name: str, stale_tool: Tool) -> Tool | None:
return await _refresh_terminated_server(
state,
registry,
server_name,
tool_name,
stale_tool,
)
for server_name in server_names:
prefix = _tool_prefix(server_name)
for tool_name in list(registry.tool_names):
if not tool_name.startswith(prefix):
continue
tool = registry.get(tool_name)
if isinstance(tool, _MCPWrapperBase):
tool.set_reconnect_handler(reconnect)
async def _refresh_terminated_server(
state: Any,
registry: ToolRegistry,
server_name: str,
tool_name: str,
stale_tool: Tool,
) -> Tool | None:
async with _reload_lock(state):
cfg = state._mcp_servers.get(server_name)
if cfg is None:
logger.warning(
"MCP server '{}' session terminated but is no longer configured",
server_name,
)
return None
current_tool = registry.get(tool_name)
if (
current_tool is not None
and current_tool is not stale_tool
and server_name in state._mcp_stacks
):
return current_tool
logger.warning("MCP server '{}' session terminated; refreshing connection", server_name)
_unregister_server_tools(state, registry, server_name)
await _close_server(state, server_name)
connected = await connect_mcp_servers({server_name: cfg}, registry)
state._mcp_stacks.update(connected)
_attach_reconnect_handlers(state, registry, connected)
state._mcp_connected = bool(state._mcp_stacks)
if server_name not in connected:
logger.warning("MCP server '{}' reconnect failed after session termination", server_name)
return None
return registry.get(tool_name)
def _server_signature(cfg: Any) -> Any: def _server_signature(cfg: Any) -> Any:
if hasattr(cfg, "model_dump"): if hasattr(cfg, "model_dump"):
return cfg.model_dump(mode="json") return cfg.model_dump(mode="json")

View File

@ -4,14 +4,19 @@ from __future__ import annotations
import asyncio import asyncio
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from types import SimpleNamespace
from typing import Any from typing import Any
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
from mcp import types as mcp_types
from mcp.shared.exceptions import McpError
from mcp.types import ErrorData
from nanobot.agent.loop import AgentLoop from nanobot.agent.loop import AgentLoop
from nanobot.agent.tools import mcp as mcp_runtime from nanobot.agent.tools import mcp as mcp_runtime
from nanobot.agent.tools.base import Tool from nanobot.agent.tools.base import Tool
from nanobot.agent.tools.mcp import MCPResourceWrapper, MCPToolWrapper
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.config.loader import load_config, save_config from nanobot.config.loader import load_config, save_config
from nanobot.config.schema import MCPServerConfig from nanobot.config.schema import MCPServerConfig
@ -218,3 +223,128 @@ async def test_reload_mcp_servers_retries_configured_server_without_live_stack(
assert result["retried"] == ["browserbase"] assert result["retried"] == ["browserbase"]
assert loop.tools.has("mcp_browserbase_navigate") assert loop.tools.has("mcp_browserbase_navigate")
await loop.close_mcp() await loop.close_mcp()
@pytest.mark.asyncio
async def test_mcp_tool_reconnects_after_session_terminated(
tmp_path,
monkeypatch: pytest.MonkeyPatch,
):
loop = _make_loop(tmp_path, mcp_servers={"remote": object()})
closed: list[str] = []
sessions: list[Any] = []
connect_count = 0
async def _mark_closed(name: str) -> None:
closed.append(name)
class _FakeSession:
def __init__(self, index: int) -> None:
self.index = index
self.call_count = 0
async def call_tool(self, _name: str, arguments: dict[str, Any]) -> Any:
self.call_count += 1
assert arguments == {"symbol": "AAPL"}
if self.index == 1:
raise McpError(ErrorData(code=-32000, message="Session terminated"))
return SimpleNamespace(
content=[mcp_types.TextContent(type="text", text="recovered")]
)
async def _fake_connect(servers, registry):
nonlocal connect_count
stacks = {}
for name in servers:
connect_count += 1
session = _FakeSession(connect_count)
sessions.append(session)
tool_def = SimpleNamespace(
name="quote",
description="quote tool",
inputSchema={"type": "object", "properties": {}},
)
registry.register(MCPToolWrapper(session, name, tool_def, tool_timeout=5))
stack = AsyncExitStack()
await stack.__aenter__()
stack.push_async_callback(_mark_closed, name)
stacks[name] = stack
return stacks
monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _fake_connect)
await loop._connect_mcp()
old_tool = loop.tools.get("mcp_remote_quote")
assert isinstance(old_tool, MCPToolWrapper)
output = await old_tool.execute(symbol="AAPL")
assert output == "recovered"
assert connect_count == 2
assert closed == ["remote"]
assert sessions[0].call_count == 1
assert sessions[1].call_count == 1
assert "remote" in loop._mcp_stacks
assert loop.tools.get("mcp_remote_quote") is not old_tool
@pytest.mark.asyncio
async def test_concurrent_mcp_reconnect_reuses_fresh_session(
tmp_path,
monkeypatch: pytest.MonkeyPatch,
):
loop = _make_loop(tmp_path, mcp_servers={"remote": object()})
closed: list[str] = []
connect_count = 0
async def _mark_closed(name: str) -> None:
closed.append(name)
class _DeadSession:
async def read_resource(self, _uri: str) -> Any:
raise McpError(ErrorData(code=-32000, message="Session terminated"))
class _LiveSession:
async def read_resource(self, uri: str) -> Any:
await asyncio.sleep(0)
return SimpleNamespace(
contents=[
mcp_types.TextResourceContents(
uri=uri,
text=f"fresh:{uri.rsplit('/', maxsplit=1)[-1]}",
)
]
)
async def _fake_connect(servers, registry):
nonlocal connect_count
stacks = {}
for name in servers:
connect_count += 1
session = _DeadSession() if connect_count == 1 else _LiveSession()
for resource_name in ("alpha", "beta"):
resource_def = SimpleNamespace(
name=resource_name,
uri=f"file:///{resource_name}",
description=f"{resource_name} resource",
)
registry.register(MCPResourceWrapper(session, name, resource_def))
stack = AsyncExitStack()
await stack.__aenter__()
stack.push_async_callback(_mark_closed, name)
stacks[name] = stack
return stacks
monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _fake_connect)
await loop._connect_mcp()
old_alpha = loop.tools.get("mcp_remote_resource_alpha")
old_beta = loop.tools.get("mcp_remote_resource_beta")
assert isinstance(old_alpha, MCPResourceWrapper)
assert isinstance(old_beta, MCPResourceWrapper)
outputs = await asyncio.gather(old_alpha.execute(), old_beta.execute())
assert outputs == ["fresh:alpha", "fresh:beta"]
assert connect_count == 2
assert closed == ["remote"]

View File

@ -13,6 +13,7 @@ from nanobot.agent.tools.mcp import (
MCPPromptWrapper, MCPPromptWrapper,
MCPResourceWrapper, MCPResourceWrapper,
MCPToolWrapper, MCPToolWrapper,
_is_session_terminated,
_is_transient, _is_transient,
) )
@ -35,6 +36,14 @@ class _FakeEndOfStreamError(Exception):
_FakeEndOfStreamError.__name__ = "EndOfStream" _FakeEndOfStreamError.__name__ = "EndOfStream"
def _session_terminated_error() -> McpError:
return McpError(ErrorData(code=-32000, message="Session terminated"))
def _connection_closed_error() -> McpError:
return McpError(ErrorData(code=-32000, message="Connection closed"))
def test_is_transient_recognizes_closed_resource(): def test_is_transient_recognizes_closed_resource():
assert _is_transient(_FakeClosedResourceError("gone")) assert _is_transient(_FakeClosedResourceError("gone"))
@ -67,6 +76,14 @@ def test_is_transient_rejects_timeout():
assert not _is_transient(TimeoutError("timeout")) assert not _is_transient(TimeoutError("timeout"))
def test_is_session_terminated_recognizes_mcp_error():
assert _is_session_terminated(_session_terminated_error())
def test_is_session_terminated_recognizes_connection_closed_mcp_error():
assert _is_session_terminated(_connection_closed_error())
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# MCPToolWrapper retry behaviour # MCPToolWrapper retry behaviour
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -219,6 +236,35 @@ async def test_tool_retry_on_end_of_stream():
assert session.call_tool.call_count == 2 assert session.call_tool.call_count == 2
@pytest.mark.asyncio
async def test_tool_reconnects_when_transient_retry_reveals_terminated_session():
"""Tool should reconnect if a stale session reports termination after transient retry."""
old_session = AsyncMock()
old_session.call_tool = AsyncMock(
side_effect=[_FakeClosedResourceError("closed"), _session_terminated_error()]
)
new_session = AsyncMock()
new_session.call_tool = AsyncMock(return_value=_make_tool_result("fresh"))
wrapper = MCPToolWrapper(old_session, "test_server", _make_tool_def(), tool_timeout=5)
replacement = MCPToolWrapper(new_session, "test_server", _make_tool_def(), tool_timeout=5)
async def reconnect(server_name: str, tool_name: str, stale_tool):
assert server_name == "test_server"
assert tool_name == "mcp_test_server_test_tool"
assert stale_tool is wrapper
return replacement
wrapper.set_reconnect_handler(reconnect)
with patch("nanobot.agent.tools.mcp.asyncio.sleep", new_callable=AsyncMock):
output = await wrapper.execute(foo="bar")
assert output == "fresh"
assert old_session.call_tool.call_count == 2
assert new_session.call_tool.call_count == 1
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# MCPResourceWrapper retry behaviour # MCPResourceWrapper retry behaviour
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -284,6 +330,32 @@ async def test_resource_no_retry_on_non_transient():
assert session.read_resource.call_count == 1 assert session.read_resource.call_count == 1
@pytest.mark.asyncio
async def test_resource_reconnects_on_session_terminated():
"""Resource should reconnect once when the MCP SDK reports a dead session."""
old_session = AsyncMock()
old_session.read_resource = AsyncMock(side_effect=_session_terminated_error())
new_session = AsyncMock()
new_session.read_resource = AsyncMock(return_value=_make_resource_result("fresh"))
wrapper = MCPResourceWrapper(old_session, "test_server", _make_resource_def())
replacement = MCPResourceWrapper(new_session, "test_server", _make_resource_def())
async def reconnect(server_name: str, tool_name: str, stale_tool):
assert server_name == "test_server"
assert tool_name == "mcp_test_server_resource_test_resource"
assert stale_tool is wrapper
return replacement
wrapper.set_reconnect_handler(reconnect)
output = await wrapper.execute()
assert output == "fresh"
assert old_session.read_resource.call_count == 1
assert new_session.read_resource.call_count == 1
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# MCPPromptWrapper retry behaviour # MCPPromptWrapper retry behaviour
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -366,3 +438,29 @@ async def test_prompt_no_retry_on_non_transient():
assert "RuntimeError" in output assert "RuntimeError" in output
assert session.get_prompt.call_count == 1 assert session.get_prompt.call_count == 1
@pytest.mark.asyncio
async def test_prompt_reconnects_on_session_terminated():
"""Prompt should reconnect once before falling back to McpError handling."""
old_session = AsyncMock()
old_session.get_prompt = AsyncMock(side_effect=_session_terminated_error())
new_session = AsyncMock()
new_session.get_prompt = AsyncMock(return_value=_make_prompt_result("fresh prompt"))
wrapper = MCPPromptWrapper(old_session, "test_server", _make_prompt_def())
replacement = MCPPromptWrapper(new_session, "test_server", _make_prompt_def())
async def reconnect(server_name: str, tool_name: str, stale_tool):
assert server_name == "test_server"
assert tool_name == "mcp_test_server_prompt_test_prompt"
assert stale_tool is wrapper
return replacement
wrapper.set_reconnect_handler(reconnect)
output = await wrapper.execute()
assert output == "fresh prompt"
assert old_session.get_prompt.call_count == 1
assert new_session.get_prompt.call_count == 1