mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 00:22:31 +00:00
fix(mcp): probe HTTP port before connecting to prevent event-loop crash
When an MCP server configured as streamableHttp or SSE is unreachable, streamable_http_client's anyio task group cleanup raises RuntimeError / ExceptionGroup that escapes the caller's try/except and crashes the event loop with "Unhandled exception in event loop". Fix: add a lightweight TCP probe (_probe_http_url) before entering the MCP SDK transport. If the port is closed, the server is skipped with a warning instead of crashing. stdio transport is not probed (local process). Closes #3739
This commit is contained in:
parent
921fe259f4
commit
6a4ed255de
@ -4,6 +4,7 @@ import asyncio
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
|
import urllib.parse
|
||||||
from contextlib import AsyncExitStack, suppress
|
from contextlib import AsyncExitStack, suppress
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@ -44,6 +45,30 @@ def _is_transient(exc: BaseException) -> bool:
|
|||||||
return type(exc).__name__ in _TRANSIENT_EXC_NAMES
|
return type(exc).__name__ in _TRANSIENT_EXC_NAMES
|
||||||
|
|
||||||
|
|
||||||
|
async def _probe_http_url(url: str, timeout: float = 3.0) -> bool:
|
||||||
|
"""Quick TCP probe to check if an HTTP MCP server is reachable.
|
||||||
|
|
||||||
|
Avoids entering ``streamable_http_client`` / ``sse_client`` when the port is
|
||||||
|
closed — those transports use anyio task groups whose cleanup can raise
|
||||||
|
``RuntimeError`` / ``ExceptionGroup`` that escape the caller's try/except
|
||||||
|
and crash the event loop.
|
||||||
|
"""
|
||||||
|
parsed = urllib.parse.urlparse(url)
|
||||||
|
host = parsed.hostname or "127.0.0.1"
|
||||||
|
port = parsed.port
|
||||||
|
if not port:
|
||||||
|
port = 443 if parsed.scheme == "https" else 80
|
||||||
|
try:
|
||||||
|
reader, writer = await asyncio.wait_for(
|
||||||
|
asyncio.open_connection(host, port), timeout=timeout,
|
||||||
|
)
|
||||||
|
writer.close()
|
||||||
|
await writer.wait_closed()
|
||||||
|
return True
|
||||||
|
except (OSError, asyncio.TimeoutError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _windows_command_basename(command: str) -> str:
|
def _windows_command_basename(command: str) -> str:
|
||||||
"""Return the lowercase basename for a Windows command or path."""
|
"""Return the lowercase basename for a Windows command or path."""
|
||||||
return command.replace("\\", "/").rsplit("/", maxsplit=1)[-1].lower()
|
return command.replace("\\", "/").rsplit("/", maxsplit=1)[-1].lower()
|
||||||
@ -481,6 +506,10 @@ async def connect_mcp_servers(
|
|||||||
)
|
)
|
||||||
read, write = await server_stack.enter_async_context(stdio_client(params))
|
read, write = await server_stack.enter_async_context(stdio_client(params))
|
||||||
elif transport_type == "sse":
|
elif transport_type == "sse":
|
||||||
|
if not await _probe_http_url(cfg.url):
|
||||||
|
logger.warning("MCP server '{}': {} unreachable, skipping", name, cfg.url)
|
||||||
|
await server_stack.aclose()
|
||||||
|
return name, None
|
||||||
|
|
||||||
def httpx_client_factory(
|
def httpx_client_factory(
|
||||||
headers: dict[str, str] | None = None,
|
headers: dict[str, str] | None = None,
|
||||||
@ -503,6 +532,11 @@ async def connect_mcp_servers(
|
|||||||
sse_client(cfg.url, httpx_client_factory=httpx_client_factory)
|
sse_client(cfg.url, httpx_client_factory=httpx_client_factory)
|
||||||
)
|
)
|
||||||
elif transport_type == "streamableHttp":
|
elif transport_type == "streamableHttp":
|
||||||
|
if not await _probe_http_url(cfg.url):
|
||||||
|
logger.warning("MCP server '{}': {} unreachable, skipping", name, cfg.url)
|
||||||
|
await server_stack.aclose()
|
||||||
|
return name, None
|
||||||
|
|
||||||
http_client = await server_stack.enter_async_context(
|
http_client = await server_stack.enter_async_context(
|
||||||
httpx.AsyncClient(
|
httpx.AsyncClient(
|
||||||
headers=cfg.headers or None,
|
headers=cfg.headers or None,
|
||||||
|
|||||||
106
tests/tools/test_mcp_probe.py
Normal file
106
tests/tools/test_mcp_probe.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
"""Tests for MCP HTTP probe guard (prevents event-loop crash on unreachable servers)."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.tools.mcp import _probe_http_url, connect_mcp_servers
|
||||||
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _probe_http_url unit tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_probe_returns_true_for_open_port(tmp_path):
|
||||||
|
"""Start a trivial TCP server, probe should return True."""
|
||||||
|
server = await asyncio.start_server(
|
||||||
|
lambda r, w: None, "127.0.0.1", 0,
|
||||||
|
)
|
||||||
|
port = server.sockets[0].getsockname()[1]
|
||||||
|
try:
|
||||||
|
assert await _probe_http_url(f"http://127.0.0.1:{port}/mcp") is True
|
||||||
|
finally:
|
||||||
|
server.close()
|
||||||
|
await server.wait_closed()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_probe_returns_false_for_closed_port():
|
||||||
|
"""Port 19999 is almost certainly not listening."""
|
||||||
|
assert await _probe_http_url("http://127.0.0.1:19999/mcp") is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_probe_uses_default_port_for_http():
|
||||||
|
"""When no port in URL, should default to 80 (will fail -> False)."""
|
||||||
|
assert await _probe_http_url("http://unreachable-host.test/mcp") is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# connect_mcp_servers skips unreachable HTTP servers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_http_cfg(url: str, transport: str = "streamableHttp"):
|
||||||
|
cfg = MagicMock()
|
||||||
|
cfg.type = transport
|
||||||
|
cfg.url = url
|
||||||
|
cfg.command = None
|
||||||
|
cfg.args = []
|
||||||
|
cfg.env = {}
|
||||||
|
cfg.headers = None
|
||||||
|
cfg.tool_timeout = 30
|
||||||
|
cfg.enabled_tools = ["*"]
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_connect_skips_unreachable_streamable_http():
|
||||||
|
"""Unreachable streamableHttp server should be skipped with a warning, no crash."""
|
||||||
|
registry = ToolRegistry()
|
||||||
|
servers = {"dead": _make_http_cfg("http://127.0.0.1:19999/mcp")}
|
||||||
|
stacks = await connect_mcp_servers(servers, registry)
|
||||||
|
assert stacks == {}
|
||||||
|
assert len(registry._tools) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_connect_skips_unreachable_sse():
|
||||||
|
"""Unreachable SSE server should be skipped with a warning, no crash."""
|
||||||
|
registry = ToolRegistry()
|
||||||
|
servers = {"dead": _make_http_cfg("http://127.0.0.1:19999/sse", transport="sse")}
|
||||||
|
stacks = await connect_mcp_servers(servers, registry)
|
||||||
|
assert stacks == {}
|
||||||
|
assert len(registry._tools) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_probe_not_called_for_stdio():
|
||||||
|
"""stdio transport should not be probed — it spawns a local process."""
|
||||||
|
called = False
|
||||||
|
original_probe = _probe_http_url
|
||||||
|
|
||||||
|
async def _spy_probe(url, **kw):
|
||||||
|
nonlocal called
|
||||||
|
called = True
|
||||||
|
return await original_probe(url, **kw)
|
||||||
|
|
||||||
|
with patch("nanobot.agent.tools.mcp._probe_http_url", _spy_probe):
|
||||||
|
cfg = MagicMock()
|
||||||
|
cfg.type = "stdio"
|
||||||
|
cfg.url = None
|
||||||
|
cfg.command = "nonexistent-command-xyz"
|
||||||
|
cfg.args = []
|
||||||
|
cfg.env = None
|
||||||
|
cfg.headers = None
|
||||||
|
cfg.tool_timeout = 30
|
||||||
|
cfg.enabled_tools = ["*"]
|
||||||
|
registry = ToolRegistry()
|
||||||
|
await connect_mcp_servers({"s": cfg}, registry)
|
||||||
|
|
||||||
|
assert not called, "probe should not be called for stdio transport"
|
||||||
|
|
||||||
|
|
||||||
|
import asyncio
|
||||||
Loading…
x
Reference in New Issue
Block a user