mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 14:23:58 +00:00
fix(mcp): reconnect terminated sessions
This commit is contained in:
parent
7c3808327f
commit
e9145b7acd
@ -5,6 +5,7 @@ import os
|
||||
import re
|
||||
import shutil
|
||||
import urllib.parse
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextlib import AsyncExitStack, suppress
|
||||
from typing import Any, Mapping
|
||||
from weakref import WeakKeyDictionary
|
||||
@ -41,6 +42,7 @@ _WINDOWS_SHELL_LAUNCHERS: frozenset[str] = frozenset(("npx", "npm", "pnpm", "yar
|
||||
# Replace anything outside [a-zA-Z0-9_-] with underscore and collapse runs.
|
||||
_SANITIZE_RE = re.compile(r"_+")
|
||||
_RELOAD_LOCKS: WeakKeyDictionary[Any, asyncio.Lock] = WeakKeyDictionary()
|
||||
_ReconnectCallback = Callable[[str, str, Tool], Awaitable[Tool | None]]
|
||||
|
||||
|
||||
def _sanitize_name(name: str) -> str:
|
||||
@ -53,6 +55,19 @@ def _is_transient(exc: BaseException) -> bool:
|
||||
return type(exc).__name__ in _TRANSIENT_EXC_NAMES
|
||||
|
||||
|
||||
def _is_session_terminated(exc: BaseException) -> bool:
|
||||
"""Return True when the MCP SDK reports a dead client session."""
|
||||
messages = [str(exc)]
|
||||
error = getattr(exc, "error", None)
|
||||
if error is not None:
|
||||
messages.append(str(getattr(error, "message", "")))
|
||||
return any(
|
||||
marker in message.lower()
|
||||
for marker in ("session terminated", "connection closed")
|
||||
for message in messages
|
||||
)
|
||||
|
||||
|
||||
async def _probe_http_url(url: str, timeout: float = 3.0) -> bool:
|
||||
"""Quick TCP probe to check if an HTTP MCP server is reachable.
|
||||
|
||||
@ -68,7 +83,8 @@ async def _probe_http_url(url: str, timeout: float = 3.0) -> bool:
|
||||
port = 443 if parsed.scheme == "https" else 80
|
||||
try:
|
||||
reader, writer = await asyncio.wait_for(
|
||||
asyncio.open_connection(host, port), timeout=timeout,
|
||||
asyncio.open_connection(host, port),
|
||||
timeout=timeout,
|
||||
)
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
@ -174,13 +190,54 @@ def _normalize_schema_for_openai(schema: Any) -> dict[str, Any]:
|
||||
return normalized
|
||||
|
||||
|
||||
class MCPToolWrapper(Tool):
|
||||
class _MCPWrapperBase(Tool):
|
||||
"""Common reconnect handling for wrappers bound to one MCP server session."""
|
||||
|
||||
_plugin_discoverable = False
|
||||
|
||||
def _set_mcp_connection(self, session: Any, server_name: str) -> None:
|
||||
self._session = session
|
||||
self._server_name = server_name
|
||||
self._reconnect: _ReconnectCallback | None = None
|
||||
|
||||
def set_reconnect_handler(self, reconnect: _ReconnectCallback) -> None:
|
||||
self._reconnect = reconnect
|
||||
|
||||
async def _refresh_session_after_termination(
|
||||
self,
|
||||
exc: BaseException,
|
||||
already_refreshed: bool,
|
||||
capability_kind: str,
|
||||
) -> bool:
|
||||
if already_refreshed or not _is_session_terminated(exc) or self._reconnect is None:
|
||||
return False
|
||||
logger.warning(
|
||||
"MCP {} '{}' session terminated; reconnecting server '{}' before retry",
|
||||
capability_kind,
|
||||
self._name,
|
||||
self._server_name,
|
||||
)
|
||||
refreshed_tool = await self._reconnect(self._server_name, self._name, self)
|
||||
refreshed_session = getattr(refreshed_tool, "_session", None)
|
||||
if refreshed_session is None:
|
||||
logger.warning(
|
||||
"MCP {} '{}' could not refresh session for server '{}'",
|
||||
capability_kind,
|
||||
self._name,
|
||||
self._server_name,
|
||||
)
|
||||
return False
|
||||
self._session = refreshed_session
|
||||
return True
|
||||
|
||||
|
||||
class MCPToolWrapper(_MCPWrapperBase):
|
||||
"""Wraps a single MCP server tool as a nanobot Tool."""
|
||||
|
||||
_plugin_discoverable = False
|
||||
|
||||
def __init__(self, session, server_name: str, tool_def, tool_timeout: int = 30):
|
||||
self._session = session
|
||||
self._set_mcp_connection(session, server_name)
|
||||
self._original_name = tool_def.name
|
||||
self._name = _sanitize_name(f"mcp_{server_name}_{tool_def.name}")
|
||||
self._description = tool_def.description or tool_def.name
|
||||
@ -203,7 +260,9 @@ class MCPToolWrapper(Tool):
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
from mcp import types
|
||||
|
||||
for attempt in range(2): # At most 1 retry
|
||||
retried_transient = False
|
||||
refreshed_session = False
|
||||
while True:
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._session.call_tool(self._original_name, arguments=kwargs),
|
||||
@ -223,8 +282,16 @@ class MCPToolWrapper(Tool):
|
||||
logger.warning("MCP tool '{}' was cancelled by server/SDK", self._name)
|
||||
return "(MCP tool call was cancelled)"
|
||||
except Exception as exc:
|
||||
if await self._refresh_session_after_termination(
|
||||
exc,
|
||||
refreshed_session,
|
||||
"tool",
|
||||
):
|
||||
refreshed_session = True
|
||||
continue
|
||||
if _is_transient(exc):
|
||||
if attempt == 0:
|
||||
if not retried_transient:
|
||||
retried_transient = True
|
||||
logger.warning(
|
||||
"MCP tool '{}' hit transient error ({}), retrying once...",
|
||||
self._name,
|
||||
@ -259,13 +326,13 @@ class MCPToolWrapper(Tool):
|
||||
return "(MCP tool call failed)" # Unreachable, but satisfies type checkers
|
||||
|
||||
|
||||
class MCPResourceWrapper(Tool):
|
||||
class MCPResourceWrapper(_MCPWrapperBase):
|
||||
"""Wraps an MCP resource URI as a read-only nanobot Tool."""
|
||||
|
||||
_plugin_discoverable = False
|
||||
|
||||
def __init__(self, session, server_name: str, resource_def, resource_timeout: int = 30):
|
||||
self._session = session
|
||||
self._set_mcp_connection(session, server_name)
|
||||
self._uri = resource_def.uri
|
||||
self._name = _sanitize_name(f"mcp_{server_name}_resource_{resource_def.name}")
|
||||
desc = resource_def.description or resource_def.name
|
||||
@ -296,7 +363,9 @@ class MCPResourceWrapper(Tool):
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
from mcp import types
|
||||
|
||||
for attempt in range(2):
|
||||
retried_transient = False
|
||||
refreshed_session = False
|
||||
while True:
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._session.read_resource(self._uri),
|
||||
@ -314,8 +383,16 @@ class MCPResourceWrapper(Tool):
|
||||
logger.warning("MCP resource '{}' was cancelled by server/SDK", self._name)
|
||||
return "(MCP resource read was cancelled)"
|
||||
except Exception as exc:
|
||||
if await self._refresh_session_after_termination(
|
||||
exc,
|
||||
refreshed_session,
|
||||
"resource",
|
||||
):
|
||||
refreshed_session = True
|
||||
continue
|
||||
if _is_transient(exc):
|
||||
if attempt == 0:
|
||||
if not retried_transient:
|
||||
retried_transient = True
|
||||
logger.warning(
|
||||
"MCP resource '{}' hit transient error ({}), retrying once...",
|
||||
self._name,
|
||||
@ -350,13 +427,13 @@ class MCPResourceWrapper(Tool):
|
||||
return "(MCP resource read failed)" # Unreachable
|
||||
|
||||
|
||||
class MCPPromptWrapper(Tool):
|
||||
class MCPPromptWrapper(_MCPWrapperBase):
|
||||
"""Wraps an MCP prompt as a read-only nanobot Tool."""
|
||||
|
||||
_plugin_discoverable = False
|
||||
|
||||
def __init__(self, session, server_name: str, prompt_def, prompt_timeout: int = 30):
|
||||
self._session = session
|
||||
self._set_mcp_connection(session, server_name)
|
||||
self._prompt_name = prompt_def.name
|
||||
self._name = _sanitize_name(f"mcp_{server_name}_prompt_{prompt_def.name}")
|
||||
desc = prompt_def.description or prompt_def.name
|
||||
@ -402,7 +479,9 @@ class MCPPromptWrapper(Tool):
|
||||
from mcp import types
|
||||
from mcp.shared.exceptions import McpError
|
||||
|
||||
for attempt in range(2):
|
||||
retried_transient = False
|
||||
refreshed_session = False
|
||||
while True:
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._session.get_prompt(self._prompt_name, arguments=kwargs),
|
||||
@ -420,6 +499,13 @@ class MCPPromptWrapper(Tool):
|
||||
logger.warning("MCP prompt '{}' was cancelled by server/SDK", self._name)
|
||||
return "(MCP prompt call was cancelled)"
|
||||
except McpError as exc:
|
||||
if await self._refresh_session_after_termination(
|
||||
exc,
|
||||
refreshed_session,
|
||||
"prompt",
|
||||
):
|
||||
refreshed_session = True
|
||||
continue
|
||||
logger.exception(
|
||||
"MCP prompt '{}' failed: code={} message={}",
|
||||
self._name,
|
||||
@ -429,7 +515,8 @@ class MCPPromptWrapper(Tool):
|
||||
return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])"
|
||||
except Exception as exc:
|
||||
if _is_transient(exc):
|
||||
if attempt == 0:
|
||||
if not retried_transient:
|
||||
retried_transient = True
|
||||
logger.warning(
|
||||
"MCP prompt '{}' hit transient error ({}), retrying once...",
|
||||
self._name,
|
||||
@ -747,6 +834,7 @@ async def connect_missing_servers(state: Any, registry: ToolRegistry) -> None:
|
||||
try:
|
||||
connected = await connect_mcp_servers(missing_servers, registry)
|
||||
state._mcp_stacks.update(connected)
|
||||
_attach_reconnect_handlers(state, registry, connected)
|
||||
state._mcp_connected = bool(state._mcp_stacks)
|
||||
if connected:
|
||||
logger.info("MCP connected servers: {}", sorted(connected))
|
||||
@ -766,8 +854,7 @@ async def reload_servers(state: Any, registry: ToolRegistry) -> dict[str, Any]:
|
||||
"""Reconcile live MCP connections with the current config file."""
|
||||
async with _reload_lock(state):
|
||||
try:
|
||||
from nanobot.config.loader import (load_config,
|
||||
resolve_config_env_vars)
|
||||
from nanobot.config.loader import load_config, resolve_config_env_vars
|
||||
|
||||
config = resolve_config_env_vars(load_config())
|
||||
next_servers = dict(config.tools.mcp_servers)
|
||||
@ -808,6 +895,7 @@ async def reload_servers(state: Any, registry: ToolRegistry) -> dict[str, Any]:
|
||||
if to_connect:
|
||||
connected = await connect_mcp_servers(to_connect, registry)
|
||||
state._mcp_stacks.update(connected)
|
||||
_attach_reconnect_handlers(state, registry, connected)
|
||||
|
||||
state._mcp_connected = bool(state._mcp_stacks)
|
||||
failed = sorted(set(to_connect) - set(connected))
|
||||
@ -909,6 +997,68 @@ def _reload_lock(state: Any) -> asyncio.Lock:
|
||||
return lock
|
||||
|
||||
|
||||
def _attach_reconnect_handlers(
|
||||
state: Any,
|
||||
registry: ToolRegistry,
|
||||
server_names: Mapping[str, Any] | set[str] | list[str] | tuple[str, ...],
|
||||
) -> None:
|
||||
async def reconnect(server_name: str, tool_name: str, stale_tool: Tool) -> Tool | None:
|
||||
return await _refresh_terminated_server(
|
||||
state,
|
||||
registry,
|
||||
server_name,
|
||||
tool_name,
|
||||
stale_tool,
|
||||
)
|
||||
|
||||
for server_name in server_names:
|
||||
prefix = _tool_prefix(server_name)
|
||||
for tool_name in list(registry.tool_names):
|
||||
if not tool_name.startswith(prefix):
|
||||
continue
|
||||
tool = registry.get(tool_name)
|
||||
if isinstance(tool, _MCPWrapperBase):
|
||||
tool.set_reconnect_handler(reconnect)
|
||||
|
||||
|
||||
async def _refresh_terminated_server(
|
||||
state: Any,
|
||||
registry: ToolRegistry,
|
||||
server_name: str,
|
||||
tool_name: str,
|
||||
stale_tool: Tool,
|
||||
) -> Tool | None:
|
||||
async with _reload_lock(state):
|
||||
cfg = state._mcp_servers.get(server_name)
|
||||
if cfg is None:
|
||||
logger.warning(
|
||||
"MCP server '{}' session terminated but is no longer configured",
|
||||
server_name,
|
||||
)
|
||||
return None
|
||||
|
||||
current_tool = registry.get(tool_name)
|
||||
if (
|
||||
current_tool is not None
|
||||
and current_tool is not stale_tool
|
||||
and server_name in state._mcp_stacks
|
||||
):
|
||||
return current_tool
|
||||
|
||||
logger.warning("MCP server '{}' session terminated; refreshing connection", server_name)
|
||||
_unregister_server_tools(state, registry, server_name)
|
||||
await _close_server(state, server_name)
|
||||
|
||||
connected = await connect_mcp_servers({server_name: cfg}, registry)
|
||||
state._mcp_stacks.update(connected)
|
||||
_attach_reconnect_handlers(state, registry, connected)
|
||||
state._mcp_connected = bool(state._mcp_stacks)
|
||||
if server_name not in connected:
|
||||
logger.warning("MCP server '{}' reconnect failed after session termination", server_name)
|
||||
return None
|
||||
return registry.get(tool_name)
|
||||
|
||||
|
||||
def _server_signature(cfg: Any) -> Any:
|
||||
if hasattr(cfg, "model_dump"):
|
||||
return cfg.model_dump(mode="json")
|
||||
|
||||
@ -4,14 +4,19 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import AsyncExitStack
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from mcp import types as mcp_types
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.types import ErrorData
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.tools import mcp as mcp_runtime
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.mcp import MCPResourceWrapper, MCPToolWrapper
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.config.loader import load_config, save_config
|
||||
from nanobot.config.schema import MCPServerConfig
|
||||
@ -218,3 +223,128 @@ async def test_reload_mcp_servers_retries_configured_server_without_live_stack(
|
||||
assert result["retried"] == ["browserbase"]
|
||||
assert loop.tools.has("mcp_browserbase_navigate")
|
||||
await loop.close_mcp()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_tool_reconnects_after_session_terminated(
|
||||
tmp_path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
loop = _make_loop(tmp_path, mcp_servers={"remote": object()})
|
||||
closed: list[str] = []
|
||||
sessions: list[Any] = []
|
||||
connect_count = 0
|
||||
|
||||
async def _mark_closed(name: str) -> None:
|
||||
closed.append(name)
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, index: int) -> None:
|
||||
self.index = index
|
||||
self.call_count = 0
|
||||
|
||||
async def call_tool(self, _name: str, arguments: dict[str, Any]) -> Any:
|
||||
self.call_count += 1
|
||||
assert arguments == {"symbol": "AAPL"}
|
||||
if self.index == 1:
|
||||
raise McpError(ErrorData(code=-32000, message="Session terminated"))
|
||||
return SimpleNamespace(
|
||||
content=[mcp_types.TextContent(type="text", text="recovered")]
|
||||
)
|
||||
|
||||
async def _fake_connect(servers, registry):
|
||||
nonlocal connect_count
|
||||
stacks = {}
|
||||
for name in servers:
|
||||
connect_count += 1
|
||||
session = _FakeSession(connect_count)
|
||||
sessions.append(session)
|
||||
tool_def = SimpleNamespace(
|
||||
name="quote",
|
||||
description="quote tool",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
)
|
||||
registry.register(MCPToolWrapper(session, name, tool_def, tool_timeout=5))
|
||||
stack = AsyncExitStack()
|
||||
await stack.__aenter__()
|
||||
stack.push_async_callback(_mark_closed, name)
|
||||
stacks[name] = stack
|
||||
return stacks
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _fake_connect)
|
||||
|
||||
await loop._connect_mcp()
|
||||
old_tool = loop.tools.get("mcp_remote_quote")
|
||||
assert isinstance(old_tool, MCPToolWrapper)
|
||||
|
||||
output = await old_tool.execute(symbol="AAPL")
|
||||
|
||||
assert output == "recovered"
|
||||
assert connect_count == 2
|
||||
assert closed == ["remote"]
|
||||
assert sessions[0].call_count == 1
|
||||
assert sessions[1].call_count == 1
|
||||
assert "remote" in loop._mcp_stacks
|
||||
assert loop.tools.get("mcp_remote_quote") is not old_tool
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_mcp_reconnect_reuses_fresh_session(
|
||||
tmp_path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
loop = _make_loop(tmp_path, mcp_servers={"remote": object()})
|
||||
closed: list[str] = []
|
||||
connect_count = 0
|
||||
|
||||
async def _mark_closed(name: str) -> None:
|
||||
closed.append(name)
|
||||
|
||||
class _DeadSession:
|
||||
async def read_resource(self, _uri: str) -> Any:
|
||||
raise McpError(ErrorData(code=-32000, message="Session terminated"))
|
||||
|
||||
class _LiveSession:
|
||||
async def read_resource(self, uri: str) -> Any:
|
||||
await asyncio.sleep(0)
|
||||
return SimpleNamespace(
|
||||
contents=[
|
||||
mcp_types.TextResourceContents(
|
||||
uri=uri,
|
||||
text=f"fresh:{uri.rsplit('/', maxsplit=1)[-1]}",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
async def _fake_connect(servers, registry):
|
||||
nonlocal connect_count
|
||||
stacks = {}
|
||||
for name in servers:
|
||||
connect_count += 1
|
||||
session = _DeadSession() if connect_count == 1 else _LiveSession()
|
||||
for resource_name in ("alpha", "beta"):
|
||||
resource_def = SimpleNamespace(
|
||||
name=resource_name,
|
||||
uri=f"file:///{resource_name}",
|
||||
description=f"{resource_name} resource",
|
||||
)
|
||||
registry.register(MCPResourceWrapper(session, name, resource_def))
|
||||
stack = AsyncExitStack()
|
||||
await stack.__aenter__()
|
||||
stack.push_async_callback(_mark_closed, name)
|
||||
stacks[name] = stack
|
||||
return stacks
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _fake_connect)
|
||||
|
||||
await loop._connect_mcp()
|
||||
old_alpha = loop.tools.get("mcp_remote_resource_alpha")
|
||||
old_beta = loop.tools.get("mcp_remote_resource_beta")
|
||||
assert isinstance(old_alpha, MCPResourceWrapper)
|
||||
assert isinstance(old_beta, MCPResourceWrapper)
|
||||
|
||||
outputs = await asyncio.gather(old_alpha.execute(), old_beta.execute())
|
||||
|
||||
assert outputs == ["fresh:alpha", "fresh:beta"]
|
||||
assert connect_count == 2
|
||||
assert closed == ["remote"]
|
||||
|
||||
@ -13,6 +13,7 @@ from nanobot.agent.tools.mcp import (
|
||||
MCPPromptWrapper,
|
||||
MCPResourceWrapper,
|
||||
MCPToolWrapper,
|
||||
_is_session_terminated,
|
||||
_is_transient,
|
||||
)
|
||||
|
||||
@ -35,6 +36,14 @@ class _FakeEndOfStreamError(Exception):
|
||||
_FakeEndOfStreamError.__name__ = "EndOfStream"
|
||||
|
||||
|
||||
def _session_terminated_error() -> McpError:
|
||||
return McpError(ErrorData(code=-32000, message="Session terminated"))
|
||||
|
||||
|
||||
def _connection_closed_error() -> McpError:
|
||||
return McpError(ErrorData(code=-32000, message="Connection closed"))
|
||||
|
||||
|
||||
def test_is_transient_recognizes_closed_resource():
|
||||
assert _is_transient(_FakeClosedResourceError("gone"))
|
||||
|
||||
@ -67,6 +76,14 @@ def test_is_transient_rejects_timeout():
|
||||
assert not _is_transient(TimeoutError("timeout"))
|
||||
|
||||
|
||||
def test_is_session_terminated_recognizes_mcp_error():
|
||||
assert _is_session_terminated(_session_terminated_error())
|
||||
|
||||
|
||||
def test_is_session_terminated_recognizes_connection_closed_mcp_error():
|
||||
assert _is_session_terminated(_connection_closed_error())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MCPToolWrapper retry behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -219,6 +236,35 @@ async def test_tool_retry_on_end_of_stream():
|
||||
assert session.call_tool.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_reconnects_when_transient_retry_reveals_terminated_session():
|
||||
"""Tool should reconnect if a stale session reports termination after transient retry."""
|
||||
old_session = AsyncMock()
|
||||
old_session.call_tool = AsyncMock(
|
||||
side_effect=[_FakeClosedResourceError("closed"), _session_terminated_error()]
|
||||
)
|
||||
new_session = AsyncMock()
|
||||
new_session.call_tool = AsyncMock(return_value=_make_tool_result("fresh"))
|
||||
|
||||
wrapper = MCPToolWrapper(old_session, "test_server", _make_tool_def(), tool_timeout=5)
|
||||
replacement = MCPToolWrapper(new_session, "test_server", _make_tool_def(), tool_timeout=5)
|
||||
|
||||
async def reconnect(server_name: str, tool_name: str, stale_tool):
|
||||
assert server_name == "test_server"
|
||||
assert tool_name == "mcp_test_server_test_tool"
|
||||
assert stale_tool is wrapper
|
||||
return replacement
|
||||
|
||||
wrapper.set_reconnect_handler(reconnect)
|
||||
|
||||
with patch("nanobot.agent.tools.mcp.asyncio.sleep", new_callable=AsyncMock):
|
||||
output = await wrapper.execute(foo="bar")
|
||||
|
||||
assert output == "fresh"
|
||||
assert old_session.call_tool.call_count == 2
|
||||
assert new_session.call_tool.call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MCPResourceWrapper retry behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -284,6 +330,32 @@ async def test_resource_no_retry_on_non_transient():
|
||||
assert session.read_resource.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resource_reconnects_on_session_terminated():
|
||||
"""Resource should reconnect once when the MCP SDK reports a dead session."""
|
||||
old_session = AsyncMock()
|
||||
old_session.read_resource = AsyncMock(side_effect=_session_terminated_error())
|
||||
new_session = AsyncMock()
|
||||
new_session.read_resource = AsyncMock(return_value=_make_resource_result("fresh"))
|
||||
|
||||
wrapper = MCPResourceWrapper(old_session, "test_server", _make_resource_def())
|
||||
replacement = MCPResourceWrapper(new_session, "test_server", _make_resource_def())
|
||||
|
||||
async def reconnect(server_name: str, tool_name: str, stale_tool):
|
||||
assert server_name == "test_server"
|
||||
assert tool_name == "mcp_test_server_resource_test_resource"
|
||||
assert stale_tool is wrapper
|
||||
return replacement
|
||||
|
||||
wrapper.set_reconnect_handler(reconnect)
|
||||
|
||||
output = await wrapper.execute()
|
||||
|
||||
assert output == "fresh"
|
||||
assert old_session.read_resource.call_count == 1
|
||||
assert new_session.read_resource.call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MCPPromptWrapper retry behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -366,3 +438,29 @@ async def test_prompt_no_retry_on_non_transient():
|
||||
|
||||
assert "RuntimeError" in output
|
||||
assert session.get_prompt.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_reconnects_on_session_terminated():
|
||||
"""Prompt should reconnect once before falling back to McpError handling."""
|
||||
old_session = AsyncMock()
|
||||
old_session.get_prompt = AsyncMock(side_effect=_session_terminated_error())
|
||||
new_session = AsyncMock()
|
||||
new_session.get_prompt = AsyncMock(return_value=_make_prompt_result("fresh prompt"))
|
||||
|
||||
wrapper = MCPPromptWrapper(old_session, "test_server", _make_prompt_def())
|
||||
replacement = MCPPromptWrapper(new_session, "test_server", _make_prompt_def())
|
||||
|
||||
async def reconnect(server_name: str, tool_name: str, stale_tool):
|
||||
assert server_name == "test_server"
|
||||
assert tool_name == "mcp_test_server_prompt_test_prompt"
|
||||
assert stale_tool is wrapper
|
||||
return replacement
|
||||
|
||||
wrapper.set_reconnect_handler(reconnect)
|
||||
|
||||
output = await wrapper.execute()
|
||||
|
||||
assert output == "fresh prompt"
|
||||
assert old_session.get_prompt.call_count == 1
|
||||
assert new_session.get_prompt.call_count == 1
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user