mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 22:34:06 +00:00
fix(mcp): reconnect terminated sessions
This commit is contained in:
parent
7c3808327f
commit
e9145b7acd
@ -5,6 +5,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
from contextlib import AsyncExitStack, suppress
|
from contextlib import AsyncExitStack, suppress
|
||||||
from typing import Any, Mapping
|
from typing import Any, Mapping
|
||||||
from weakref import WeakKeyDictionary
|
from weakref import WeakKeyDictionary
|
||||||
@ -41,6 +42,7 @@ _WINDOWS_SHELL_LAUNCHERS: frozenset[str] = frozenset(("npx", "npm", "pnpm", "yar
|
|||||||
# Replace anything outside [a-zA-Z0-9_-] with underscore and collapse runs.
|
# Replace anything outside [a-zA-Z0-9_-] with underscore and collapse runs.
|
||||||
_SANITIZE_RE = re.compile(r"_+")
|
_SANITIZE_RE = re.compile(r"_+")
|
||||||
_RELOAD_LOCKS: WeakKeyDictionary[Any, asyncio.Lock] = WeakKeyDictionary()
|
_RELOAD_LOCKS: WeakKeyDictionary[Any, asyncio.Lock] = WeakKeyDictionary()
|
||||||
|
_ReconnectCallback = Callable[[str, str, Tool], Awaitable[Tool | None]]
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_name(name: str) -> str:
|
def _sanitize_name(name: str) -> str:
|
||||||
@ -53,6 +55,19 @@ def _is_transient(exc: BaseException) -> bool:
|
|||||||
return type(exc).__name__ in _TRANSIENT_EXC_NAMES
|
return type(exc).__name__ in _TRANSIENT_EXC_NAMES
|
||||||
|
|
||||||
|
|
||||||
|
def _is_session_terminated(exc: BaseException) -> bool:
|
||||||
|
"""Return True when the MCP SDK reports a dead client session."""
|
||||||
|
messages = [str(exc)]
|
||||||
|
error = getattr(exc, "error", None)
|
||||||
|
if error is not None:
|
||||||
|
messages.append(str(getattr(error, "message", "")))
|
||||||
|
return any(
|
||||||
|
marker in message.lower()
|
||||||
|
for marker in ("session terminated", "connection closed")
|
||||||
|
for message in messages
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _probe_http_url(url: str, timeout: float = 3.0) -> bool:
|
async def _probe_http_url(url: str, timeout: float = 3.0) -> bool:
|
||||||
"""Quick TCP probe to check if an HTTP MCP server is reachable.
|
"""Quick TCP probe to check if an HTTP MCP server is reachable.
|
||||||
|
|
||||||
@ -68,7 +83,8 @@ async def _probe_http_url(url: str, timeout: float = 3.0) -> bool:
|
|||||||
port = 443 if parsed.scheme == "https" else 80
|
port = 443 if parsed.scheme == "https" else 80
|
||||||
try:
|
try:
|
||||||
reader, writer = await asyncio.wait_for(
|
reader, writer = await asyncio.wait_for(
|
||||||
asyncio.open_connection(host, port), timeout=timeout,
|
asyncio.open_connection(host, port),
|
||||||
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
writer.close()
|
writer.close()
|
||||||
await writer.wait_closed()
|
await writer.wait_closed()
|
||||||
@ -174,13 +190,54 @@ def _normalize_schema_for_openai(schema: Any) -> dict[str, Any]:
|
|||||||
return normalized
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
class MCPToolWrapper(Tool):
|
class _MCPWrapperBase(Tool):
    """Common reconnect handling for wrappers bound to one MCP server session.

    Subclasses call ``_set_mcp_connection`` from their constructor and are
    expected to define ``self._name`` (it is used here for logging and is
    passed to the reconnect callback).
    """

    _plugin_discoverable = False

    def _set_mcp_connection(self, session: Any, server_name: str) -> None:
        """Bind this wrapper to *session* / *server_name*.

        Also resets the reconnect callback to ``None``; a handler has to be
        installed afterwards with ``set_reconnect_handler``.
        """
        self._session = session
        self._server_name = server_name
        self._reconnect: _ReconnectCallback | None = None

    def set_reconnect_handler(self, reconnect: _ReconnectCallback) -> None:
        """Install the coroutine function used to obtain a fresh wrapper."""
        self._reconnect = reconnect

    async def _refresh_session_after_termination(
        self,
        exc: BaseException,
        already_refreshed: bool,
        capability_kind: str,
    ) -> bool:
        """Try to swap in a fresh session after a terminated-session error.

        Args:
            exc: The exception raised by the failed MCP call.
            already_refreshed: True when the current ``execute`` call has
                already refreshed once; prevents reconnect loops.
            capability_kind: Label used in log messages ("tool",
                "resource", "prompt").

        Returns:
            True when ``self._session`` was replaced and the caller should
            retry; False when no refresh was attempted or it did not yield a
            usable session, in which case the caller should handle *exc* as
            an ordinary failure.
        """
        if already_refreshed or not _is_session_terminated(exc) or self._reconnect is None:
            return False
        logger.warning(
            "MCP {} '{}' session terminated; reconnecting server '{}' before retry",
            capability_kind,
            self._name,
            self._server_name,
        )
        try:
            refreshed_tool = await self._reconnect(self._server_name, self._name, self)
        except Exception as reconnect_exc:
            # This method runs inside the caller's ``except`` handler.  Letting
            # a reconnect failure escape would replace the original error and
            # make ``execute`` raise instead of reporting the failure, so log
            # it and let the caller fall back to handling *exc*.
            # (asyncio.CancelledError is a BaseException and still propagates.)
            logger.warning(
                "MCP {} '{}' reconnect to server '{}' raised {}: {}",
                capability_kind,
                self._name,
                self._server_name,
                type(reconnect_exc).__name__,
                reconnect_exc,
            )
            return False
        # The callback may return None (or an object without a session) when
        # the server could not be brought back.
        refreshed_session = getattr(refreshed_tool, "_session", None)
        if refreshed_session is None:
            logger.warning(
                "MCP {} '{}' could not refresh session for server '{}'",
                capability_kind,
                self._name,
                self._server_name,
            )
            return False
        self._session = refreshed_session
        return True
|
||||||
|
|
||||||
|
|
||||||
|
class MCPToolWrapper(_MCPWrapperBase):
|
||||||
"""Wraps a single MCP server tool as a nanobot Tool."""
|
"""Wraps a single MCP server tool as a nanobot Tool."""
|
||||||
|
|
||||||
_plugin_discoverable = False
|
_plugin_discoverable = False
|
||||||
|
|
||||||
def __init__(self, session, server_name: str, tool_def, tool_timeout: int = 30):
|
def __init__(self, session, server_name: str, tool_def, tool_timeout: int = 30):
|
||||||
self._session = session
|
self._set_mcp_connection(session, server_name)
|
||||||
self._original_name = tool_def.name
|
self._original_name = tool_def.name
|
||||||
self._name = _sanitize_name(f"mcp_{server_name}_{tool_def.name}")
|
self._name = _sanitize_name(f"mcp_{server_name}_{tool_def.name}")
|
||||||
self._description = tool_def.description or tool_def.name
|
self._description = tool_def.description or tool_def.name
|
||||||
@ -203,7 +260,9 @@ class MCPToolWrapper(Tool):
|
|||||||
async def execute(self, **kwargs: Any) -> str:
|
async def execute(self, **kwargs: Any) -> str:
|
||||||
from mcp import types
|
from mcp import types
|
||||||
|
|
||||||
for attempt in range(2): # At most 1 retry
|
retried_transient = False
|
||||||
|
refreshed_session = False
|
||||||
|
while True:
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(
|
||||||
self._session.call_tool(self._original_name, arguments=kwargs),
|
self._session.call_tool(self._original_name, arguments=kwargs),
|
||||||
@ -223,8 +282,16 @@ class MCPToolWrapper(Tool):
|
|||||||
logger.warning("MCP tool '{}' was cancelled by server/SDK", self._name)
|
logger.warning("MCP tool '{}' was cancelled by server/SDK", self._name)
|
||||||
return "(MCP tool call was cancelled)"
|
return "(MCP tool call was cancelled)"
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
if await self._refresh_session_after_termination(
|
||||||
|
exc,
|
||||||
|
refreshed_session,
|
||||||
|
"tool",
|
||||||
|
):
|
||||||
|
refreshed_session = True
|
||||||
|
continue
|
||||||
if _is_transient(exc):
|
if _is_transient(exc):
|
||||||
if attempt == 0:
|
if not retried_transient:
|
||||||
|
retried_transient = True
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"MCP tool '{}' hit transient error ({}), retrying once...",
|
"MCP tool '{}' hit transient error ({}), retrying once...",
|
||||||
self._name,
|
self._name,
|
||||||
@ -259,13 +326,13 @@ class MCPToolWrapper(Tool):
|
|||||||
return "(MCP tool call failed)" # Unreachable, but satisfies type checkers
|
return "(MCP tool call failed)" # Unreachable, but satisfies type checkers
|
||||||
|
|
||||||
|
|
||||||
class MCPResourceWrapper(Tool):
|
class MCPResourceWrapper(_MCPWrapperBase):
|
||||||
"""Wraps an MCP resource URI as a read-only nanobot Tool."""
|
"""Wraps an MCP resource URI as a read-only nanobot Tool."""
|
||||||
|
|
||||||
_plugin_discoverable = False
|
_plugin_discoverable = False
|
||||||
|
|
||||||
def __init__(self, session, server_name: str, resource_def, resource_timeout: int = 30):
|
def __init__(self, session, server_name: str, resource_def, resource_timeout: int = 30):
|
||||||
self._session = session
|
self._set_mcp_connection(session, server_name)
|
||||||
self._uri = resource_def.uri
|
self._uri = resource_def.uri
|
||||||
self._name = _sanitize_name(f"mcp_{server_name}_resource_{resource_def.name}")
|
self._name = _sanitize_name(f"mcp_{server_name}_resource_{resource_def.name}")
|
||||||
desc = resource_def.description or resource_def.name
|
desc = resource_def.description or resource_def.name
|
||||||
@ -296,7 +363,9 @@ class MCPResourceWrapper(Tool):
|
|||||||
async def execute(self, **kwargs: Any) -> str:
|
async def execute(self, **kwargs: Any) -> str:
|
||||||
from mcp import types
|
from mcp import types
|
||||||
|
|
||||||
for attempt in range(2):
|
retried_transient = False
|
||||||
|
refreshed_session = False
|
||||||
|
while True:
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(
|
||||||
self._session.read_resource(self._uri),
|
self._session.read_resource(self._uri),
|
||||||
@ -314,8 +383,16 @@ class MCPResourceWrapper(Tool):
|
|||||||
logger.warning("MCP resource '{}' was cancelled by server/SDK", self._name)
|
logger.warning("MCP resource '{}' was cancelled by server/SDK", self._name)
|
||||||
return "(MCP resource read was cancelled)"
|
return "(MCP resource read was cancelled)"
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
if await self._refresh_session_after_termination(
|
||||||
|
exc,
|
||||||
|
refreshed_session,
|
||||||
|
"resource",
|
||||||
|
):
|
||||||
|
refreshed_session = True
|
||||||
|
continue
|
||||||
if _is_transient(exc):
|
if _is_transient(exc):
|
||||||
if attempt == 0:
|
if not retried_transient:
|
||||||
|
retried_transient = True
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"MCP resource '{}' hit transient error ({}), retrying once...",
|
"MCP resource '{}' hit transient error ({}), retrying once...",
|
||||||
self._name,
|
self._name,
|
||||||
@ -350,13 +427,13 @@ class MCPResourceWrapper(Tool):
|
|||||||
return "(MCP resource read failed)" # Unreachable
|
return "(MCP resource read failed)" # Unreachable
|
||||||
|
|
||||||
|
|
||||||
class MCPPromptWrapper(Tool):
|
class MCPPromptWrapper(_MCPWrapperBase):
|
||||||
"""Wraps an MCP prompt as a read-only nanobot Tool."""
|
"""Wraps an MCP prompt as a read-only nanobot Tool."""
|
||||||
|
|
||||||
_plugin_discoverable = False
|
_plugin_discoverable = False
|
||||||
|
|
||||||
def __init__(self, session, server_name: str, prompt_def, prompt_timeout: int = 30):
|
def __init__(self, session, server_name: str, prompt_def, prompt_timeout: int = 30):
|
||||||
self._session = session
|
self._set_mcp_connection(session, server_name)
|
||||||
self._prompt_name = prompt_def.name
|
self._prompt_name = prompt_def.name
|
||||||
self._name = _sanitize_name(f"mcp_{server_name}_prompt_{prompt_def.name}")
|
self._name = _sanitize_name(f"mcp_{server_name}_prompt_{prompt_def.name}")
|
||||||
desc = prompt_def.description or prompt_def.name
|
desc = prompt_def.description or prompt_def.name
|
||||||
@ -402,7 +479,9 @@ class MCPPromptWrapper(Tool):
|
|||||||
from mcp import types
|
from mcp import types
|
||||||
from mcp.shared.exceptions import McpError
|
from mcp.shared.exceptions import McpError
|
||||||
|
|
||||||
for attempt in range(2):
|
retried_transient = False
|
||||||
|
refreshed_session = False
|
||||||
|
while True:
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(
|
||||||
self._session.get_prompt(self._prompt_name, arguments=kwargs),
|
self._session.get_prompt(self._prompt_name, arguments=kwargs),
|
||||||
@ -420,6 +499,13 @@ class MCPPromptWrapper(Tool):
|
|||||||
logger.warning("MCP prompt '{}' was cancelled by server/SDK", self._name)
|
logger.warning("MCP prompt '{}' was cancelled by server/SDK", self._name)
|
||||||
return "(MCP prompt call was cancelled)"
|
return "(MCP prompt call was cancelled)"
|
||||||
except McpError as exc:
|
except McpError as exc:
|
||||||
|
if await self._refresh_session_after_termination(
|
||||||
|
exc,
|
||||||
|
refreshed_session,
|
||||||
|
"prompt",
|
||||||
|
):
|
||||||
|
refreshed_session = True
|
||||||
|
continue
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"MCP prompt '{}' failed: code={} message={}",
|
"MCP prompt '{}' failed: code={} message={}",
|
||||||
self._name,
|
self._name,
|
||||||
@ -429,7 +515,8 @@ class MCPPromptWrapper(Tool):
|
|||||||
return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])"
|
return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])"
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
if _is_transient(exc):
|
if _is_transient(exc):
|
||||||
if attempt == 0:
|
if not retried_transient:
|
||||||
|
retried_transient = True
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"MCP prompt '{}' hit transient error ({}), retrying once...",
|
"MCP prompt '{}' hit transient error ({}), retrying once...",
|
||||||
self._name,
|
self._name,
|
||||||
@ -747,6 +834,7 @@ async def connect_missing_servers(state: Any, registry: ToolRegistry) -> None:
|
|||||||
try:
|
try:
|
||||||
connected = await connect_mcp_servers(missing_servers, registry)
|
connected = await connect_mcp_servers(missing_servers, registry)
|
||||||
state._mcp_stacks.update(connected)
|
state._mcp_stacks.update(connected)
|
||||||
|
_attach_reconnect_handlers(state, registry, connected)
|
||||||
state._mcp_connected = bool(state._mcp_stacks)
|
state._mcp_connected = bool(state._mcp_stacks)
|
||||||
if connected:
|
if connected:
|
||||||
logger.info("MCP connected servers: {}", sorted(connected))
|
logger.info("MCP connected servers: {}", sorted(connected))
|
||||||
@ -766,8 +854,7 @@ async def reload_servers(state: Any, registry: ToolRegistry) -> dict[str, Any]:
|
|||||||
"""Reconcile live MCP connections with the current config file."""
|
"""Reconcile live MCP connections with the current config file."""
|
||||||
async with _reload_lock(state):
|
async with _reload_lock(state):
|
||||||
try:
|
try:
|
||||||
from nanobot.config.loader import (load_config,
|
from nanobot.config.loader import load_config, resolve_config_env_vars
|
||||||
resolve_config_env_vars)
|
|
||||||
|
|
||||||
config = resolve_config_env_vars(load_config())
|
config = resolve_config_env_vars(load_config())
|
||||||
next_servers = dict(config.tools.mcp_servers)
|
next_servers = dict(config.tools.mcp_servers)
|
||||||
@ -808,6 +895,7 @@ async def reload_servers(state: Any, registry: ToolRegistry) -> dict[str, Any]:
|
|||||||
if to_connect:
|
if to_connect:
|
||||||
connected = await connect_mcp_servers(to_connect, registry)
|
connected = await connect_mcp_servers(to_connect, registry)
|
||||||
state._mcp_stacks.update(connected)
|
state._mcp_stacks.update(connected)
|
||||||
|
_attach_reconnect_handlers(state, registry, connected)
|
||||||
|
|
||||||
state._mcp_connected = bool(state._mcp_stacks)
|
state._mcp_connected = bool(state._mcp_stacks)
|
||||||
failed = sorted(set(to_connect) - set(connected))
|
failed = sorted(set(to_connect) - set(connected))
|
||||||
@ -909,6 +997,68 @@ def _reload_lock(state: Any) -> asyncio.Lock:
|
|||||||
return lock
|
return lock
|
||||||
|
|
||||||
|
|
||||||
|
def _attach_reconnect_handlers(
    state: Any,
    registry: ToolRegistry,
    server_names: Mapping[str, Any] | set[str] | list[str] | tuple[str, ...],
) -> None:
    """Give every registered MCP wrapper of the named servers a reconnect hook.

    For each name yielded by iterating *server_names* (a mapping yields its
    keys), the registry's tool names are filtered by the prefix returned from
    ``_tool_prefix``; each matching tool that is an ``_MCPWrapperBase``
    instance receives a callback that delegates to
    ``_refresh_terminated_server`` with this *state* and *registry*.
    """

    async def _on_session_terminated(
        server: str, registered_name: str, stale: Tool
    ) -> Tool | None:
        return await _refresh_terminated_server(
            state, registry, server, registered_name, stale
        )

    for configured_server in server_names:
        name_prefix = _tool_prefix(configured_server)
        # Materialise the matching names before touching the registry.
        matching_names = [
            candidate_name
            for candidate_name in registry.tool_names
            if candidate_name.startswith(name_prefix)
        ]
        for registered_name in matching_names:
            wrapper = registry.get(registered_name)
            if isinstance(wrapper, _MCPWrapperBase):
                wrapper.set_reconnect_handler(_on_session_terminated)
|
||||||
|
|
||||||
|
|
||||||
|
async def _refresh_terminated_server(
    state: Any,
    registry: ToolRegistry,
    server_name: str,
    tool_name: str,
    stale_tool: Tool,
) -> Tool | None:
    """Reconnect one MCP server and return the wrapper now registered as *tool_name*.

    Runs under the per-state reload lock.  Returns ``None`` when the server
    has no entry in ``state._mcp_servers`` or when the reconnect did not
    produce a stack for it.  If the registry already holds a different
    wrapper under *tool_name* and the server has a live stack, that wrapper
    is returned without reconnecting.
    """
    async with _reload_lock(state):
        server_cfg = state._mcp_servers.get(server_name)
        if server_cfg is None:
            logger.warning(
                "MCP server '{}' session terminated but is no longer configured",
                server_name,
            )
            return None

        # A different wrapper under the same name, with a stack present for
        # the server, means the registration was already replaced: reuse it.
        registered = registry.get(tool_name)
        if registered is not None and registered is not stale_tool:
            if server_name in state._mcp_stacks:
                return registered

        logger.warning("MCP server '{}' session terminated; refreshing connection", server_name)
        _unregister_server_tools(state, registry, server_name)
        await _close_server(state, server_name)

        fresh_stacks = await connect_mcp_servers({server_name: server_cfg}, registry)
        state._mcp_stacks.update(fresh_stacks)
        _attach_reconnect_handlers(state, registry, fresh_stacks)
        state._mcp_connected = bool(state._mcp_stacks)

        if server_name in fresh_stacks:
            return registry.get(tool_name)

        logger.warning("MCP server '{}' reconnect failed after session termination", server_name)
        return None
|
||||||
|
|
||||||
|
|
||||||
def _server_signature(cfg: Any) -> Any:
|
def _server_signature(cfg: Any) -> Any:
|
||||||
if hasattr(cfg, "model_dump"):
|
if hasattr(cfg, "model_dump"):
|
||||||
return cfg.model_dump(mode="json")
|
return cfg.model_dump(mode="json")
|
||||||
|
|||||||
@ -4,14 +4,19 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
|
from types import SimpleNamespace
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from mcp import types as mcp_types
|
||||||
|
from mcp.shared.exceptions import McpError
|
||||||
|
from mcp.types import ErrorData
|
||||||
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
from nanobot.agent.loop import AgentLoop
|
||||||
from nanobot.agent.tools import mcp as mcp_runtime
|
from nanobot.agent.tools import mcp as mcp_runtime
|
||||||
from nanobot.agent.tools.base import Tool
|
from nanobot.agent.tools.base import Tool
|
||||||
|
from nanobot.agent.tools.mcp import MCPResourceWrapper, MCPToolWrapper
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.config.loader import load_config, save_config
|
from nanobot.config.loader import load_config, save_config
|
||||||
from nanobot.config.schema import MCPServerConfig
|
from nanobot.config.schema import MCPServerConfig
|
||||||
@ -218,3 +223,128 @@ async def test_reload_mcp_servers_retries_configured_server_without_live_stack(
|
|||||||
assert result["retried"] == ["browserbase"]
|
assert result["retried"] == ["browserbase"]
|
||||||
assert loop.tools.has("mcp_browserbase_navigate")
|
assert loop.tools.has("mcp_browserbase_navigate")
|
||||||
await loop.close_mcp()
|
await loop.close_mcp()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_mcp_tool_reconnects_after_session_terminated(
    tmp_path,
    monkeypatch: pytest.MonkeyPatch,
):
    """A tool call hitting "Session terminated" reconnects once and succeeds.

    The first fake session always raises the MCP error; the second one
    answers normally.  The test checks that exactly one reconnect happened,
    the old stack was closed, and the registry now holds a new wrapper.
    """
    loop = _make_loop(tmp_path, mcp_servers={"remote": object()})
    closed_servers: list[str] = []
    created_sessions: list[Any] = []
    connect_calls = 0

    async def _record_close(server: str) -> None:
        closed_servers.append(server)

    class _FakeSession:
        def __init__(self, index: int) -> None:
            self.index = index
            self.call_count = 0

        async def call_tool(self, _name: str, arguments: dict[str, Any]) -> Any:
            self.call_count += 1
            assert arguments == {"symbol": "AAPL"}
            # Only the very first session is dead.
            if self.index == 1:
                raise McpError(ErrorData(code=-32000, message="Session terminated"))
            reply = mcp_types.TextContent(type="text", text="recovered")
            return SimpleNamespace(content=[reply])

    async def _fake_connect(servers, registry):
        nonlocal connect_calls
        opened = {}
        for server in servers:
            connect_calls += 1
            fake = _FakeSession(connect_calls)
            created_sessions.append(fake)
            quote_def = SimpleNamespace(
                name="quote",
                description="quote tool",
                inputSchema={"type": "object", "properties": {}},
            )
            registry.register(MCPToolWrapper(fake, server, quote_def, tool_timeout=5))
            exit_stack = AsyncExitStack()
            await exit_stack.__aenter__()
            exit_stack.push_async_callback(_record_close, server)
            opened[server] = exit_stack
        return opened

    monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _fake_connect)

    await loop._connect_mcp()
    original_wrapper = loop.tools.get("mcp_remote_quote")
    assert isinstance(original_wrapper, MCPToolWrapper)

    result = await original_wrapper.execute(symbol="AAPL")

    assert result == "recovered"
    assert connect_calls == 2
    assert closed_servers == ["remote"]
    first_session, second_session = created_sessions[0], created_sessions[1]
    assert first_session.call_count == 1
    assert second_session.call_count == 1
    assert "remote" in loop._mcp_stacks
    assert loop.tools.get("mcp_remote_quote") is not original_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_concurrent_mcp_reconnect_reuses_fresh_session(
    tmp_path,
    monkeypatch: pytest.MonkeyPatch,
):
    """Two wrappers failing at once trigger only a single reconnect.

    Both resources start on a dead session and are executed concurrently;
    the test expects one extra connect, one close, and fresh output from
    both calls.
    """
    loop = _make_loop(tmp_path, mcp_servers={"remote": object()})
    closed_servers: list[str] = []
    connect_calls = 0

    async def _record_close(server: str) -> None:
        closed_servers.append(server)

    class _DeadSession:
        async def read_resource(self, _uri: str) -> Any:
            raise McpError(ErrorData(code=-32000, message="Session terminated"))

    class _LiveSession:
        async def read_resource(self, uri: str) -> Any:
            await asyncio.sleep(0)
            body = mcp_types.TextResourceContents(
                uri=uri,
                text=f"fresh:{uri.rsplit('/', maxsplit=1)[-1]}",
            )
            return SimpleNamespace(contents=[body])

    async def _fake_connect(servers, registry):
        nonlocal connect_calls
        opened = {}
        for server in servers:
            connect_calls += 1
            # First connection is the dead one; every later one is healthy.
            if connect_calls == 1:
                fake = _DeadSession()
            else:
                fake = _LiveSession()
            for label in ("alpha", "beta"):
                definition = SimpleNamespace(
                    name=label,
                    uri=f"file:///{label}",
                    description=f"{label} resource",
                )
                registry.register(MCPResourceWrapper(fake, server, definition))
            exit_stack = AsyncExitStack()
            await exit_stack.__aenter__()
            exit_stack.push_async_callback(_record_close, server)
            opened[server] = exit_stack
        return opened

    monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _fake_connect)

    await loop._connect_mcp()
    alpha_wrapper = loop.tools.get("mcp_remote_resource_alpha")
    beta_wrapper = loop.tools.get("mcp_remote_resource_beta")
    assert isinstance(alpha_wrapper, MCPResourceWrapper)
    assert isinstance(beta_wrapper, MCPResourceWrapper)

    results = await asyncio.gather(alpha_wrapper.execute(), beta_wrapper.execute())

    assert results == ["fresh:alpha", "fresh:beta"]
    assert connect_calls == 2
    assert closed_servers == ["remote"]
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from nanobot.agent.tools.mcp import (
|
|||||||
MCPPromptWrapper,
|
MCPPromptWrapper,
|
||||||
MCPResourceWrapper,
|
MCPResourceWrapper,
|
||||||
MCPToolWrapper,
|
MCPToolWrapper,
|
||||||
|
_is_session_terminated,
|
||||||
_is_transient,
|
_is_transient,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -35,6 +36,14 @@ class _FakeEndOfStreamError(Exception):
|
|||||||
_FakeEndOfStreamError.__name__ = "EndOfStream"
|
_FakeEndOfStreamError.__name__ = "EndOfStream"
|
||||||
|
|
||||||
|
|
||||||
|
def _session_terminated_error() -> McpError:
    """Build the McpError (code -32000) whose message is "Session terminated"."""
    payload = ErrorData(code=-32000, message="Session terminated")
    return McpError(payload)
|
||||||
|
|
||||||
|
|
||||||
|
def _connection_closed_error() -> McpError:
    """Build the McpError (code -32000) whose message is "Connection closed"."""
    payload = ErrorData(code=-32000, message="Connection closed")
    return McpError(payload)
|
||||||
|
|
||||||
|
|
||||||
def test_is_transient_recognizes_closed_resource():
|
def test_is_transient_recognizes_closed_resource():
|
||||||
assert _is_transient(_FakeClosedResourceError("gone"))
|
assert _is_transient(_FakeClosedResourceError("gone"))
|
||||||
|
|
||||||
@ -67,6 +76,14 @@ def test_is_transient_rejects_timeout():
|
|||||||
assert not _is_transient(TimeoutError("timeout"))
|
assert not _is_transient(TimeoutError("timeout"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_session_terminated_recognizes_mcp_error():
    """An McpError carrying "Session terminated" is classified as a dead session."""
    error = _session_terminated_error()
    assert _is_session_terminated(error)
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_session_terminated_recognizes_connection_closed_mcp_error():
    """An McpError carrying "Connection closed" is classified as a dead session."""
    error = _connection_closed_error()
    assert _is_session_terminated(error)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# MCPToolWrapper retry behaviour
|
# MCPToolWrapper retry behaviour
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@ -219,6 +236,35 @@ async def test_tool_retry_on_end_of_stream():
|
|||||||
assert session.call_tool.call_count == 2
|
assert session.call_tool.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_tool_reconnects_when_transient_retry_reveals_terminated_session():
    """Tool should reconnect if a stale session reports termination after transient retry."""
    # The stale session fails twice: first transiently, then as terminated.
    stale_session = AsyncMock()
    stale_failures = [_FakeClosedResourceError("closed"), _session_terminated_error()]
    stale_session.call_tool = AsyncMock(side_effect=stale_failures)

    fresh_session = AsyncMock()
    fresh_session.call_tool = AsyncMock(return_value=_make_tool_result("fresh"))

    wrapper = MCPToolWrapper(stale_session, "test_server", _make_tool_def(), tool_timeout=5)
    replacement = MCPToolWrapper(fresh_session, "test_server", _make_tool_def(), tool_timeout=5)

    async def reconnect(server_name: str, tool_name: str, stale_tool):
        # The wrapper must identify itself and its server to the handler.
        assert server_name == "test_server"
        assert tool_name == "mcp_test_server_test_tool"
        assert stale_tool is wrapper
        return replacement

    wrapper.set_reconnect_handler(reconnect)

    # Patch out the sleep so the test does not actually wait.
    with patch("nanobot.agent.tools.mcp.asyncio.sleep", new_callable=AsyncMock):
        result = await wrapper.execute(foo="bar")

    assert result == "fresh"
    assert stale_session.call_tool.call_count == 2
    assert fresh_session.call_tool.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# MCPResourceWrapper retry behaviour
|
# MCPResourceWrapper retry behaviour
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@ -284,6 +330,32 @@ async def test_resource_no_retry_on_non_transient():
|
|||||||
assert session.read_resource.call_count == 1
|
assert session.read_resource.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_resource_reconnects_on_session_terminated():
    """Resource should reconnect once when the MCP SDK reports a dead session."""
    stale_session = AsyncMock()
    stale_session.read_resource = AsyncMock(side_effect=_session_terminated_error())

    fresh_session = AsyncMock()
    fresh_session.read_resource = AsyncMock(return_value=_make_resource_result("fresh"))

    wrapper = MCPResourceWrapper(stale_session, "test_server", _make_resource_def())
    replacement = MCPResourceWrapper(fresh_session, "test_server", _make_resource_def())

    async def reconnect(server_name: str, tool_name: str, stale_tool):
        # The wrapper must identify itself and its server to the handler.
        assert server_name == "test_server"
        assert tool_name == "mcp_test_server_resource_test_resource"
        assert stale_tool is wrapper
        return replacement

    wrapper.set_reconnect_handler(reconnect)

    result = await wrapper.execute()

    assert result == "fresh"
    # One failed read on the old session, one successful read on the new one.
    assert stale_session.read_resource.call_count == 1
    assert fresh_session.read_resource.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# MCPPromptWrapper retry behaviour
|
# MCPPromptWrapper retry behaviour
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@ -366,3 +438,29 @@ async def test_prompt_no_retry_on_non_transient():
|
|||||||
|
|
||||||
assert "RuntimeError" in output
|
assert "RuntimeError" in output
|
||||||
assert session.get_prompt.call_count == 1
|
assert session.get_prompt.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_prompt_reconnects_on_session_terminated():
    """Prompt should reconnect once before falling back to McpError handling."""
    stale_session = AsyncMock()
    stale_session.get_prompt = AsyncMock(side_effect=_session_terminated_error())

    fresh_session = AsyncMock()
    fresh_session.get_prompt = AsyncMock(return_value=_make_prompt_result("fresh prompt"))

    wrapper = MCPPromptWrapper(stale_session, "test_server", _make_prompt_def())
    replacement = MCPPromptWrapper(fresh_session, "test_server", _make_prompt_def())

    async def reconnect(server_name: str, tool_name: str, stale_tool):
        # The wrapper must identify itself and its server to the handler.
        assert server_name == "test_server"
        assert tool_name == "mcp_test_server_prompt_test_prompt"
        assert stale_tool is wrapper
        return replacement

    wrapper.set_reconnect_handler(reconnect)

    result = await wrapper.execute()

    assert result == "fresh prompt"
    # One failed call on the old session, one successful call on the new one.
    assert stale_session.get_prompt.call_count == 1
    assert fresh_session.get_prompt.call_count == 1
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user