diff --git a/nanobot/agent/tools/mcp.py b/nanobot/agent/tools/mcp.py index 4cc5bdf55..73c0850d5 100644 --- a/nanobot/agent/tools/mcp.py +++ b/nanobot/agent/tools/mcp.py @@ -4,6 +4,7 @@ import asyncio import os import re import shutil +import urllib.parse from contextlib import AsyncExitStack, suppress from typing import Any @@ -44,6 +45,30 @@ def _is_transient(exc: BaseException) -> bool: return type(exc).__name__ in _TRANSIENT_EXC_NAMES +async def _probe_http_url(url: str, timeout: float = 3.0) -> bool: + """Quick TCP probe to check if an HTTP MCP server is reachable. + + Avoids entering ``streamable_http_client`` / ``sse_client`` when the port is + closed — those transports use anyio task groups whose cleanup can raise + ``RuntimeError`` / ``ExceptionGroup`` that escape the caller's try/except + and crash the event loop. + """ + parsed = urllib.parse.urlparse(url) + host = parsed.hostname or "127.0.0.1" + port = parsed.port + if not port: + port = 443 if parsed.scheme == "https" else 80 + try: + reader, writer = await asyncio.wait_for( + asyncio.open_connection(host, port), timeout=timeout, + ) + writer.close() + await writer.wait_closed() + return True + except (OSError, asyncio.TimeoutError): + return False + + def _windows_command_basename(command: str) -> str: """Return the lowercase basename for a Windows command or path.""" return command.replace("\\", "/").rsplit("/", maxsplit=1)[-1].lower() @@ -481,6 +506,10 @@ async def connect_mcp_servers( ) read, write = await server_stack.enter_async_context(stdio_client(params)) elif transport_type == "sse": + if not await _probe_http_url(cfg.url): + logger.warning("MCP server '{}': {} unreachable, skipping", name, cfg.url) + await server_stack.aclose() + return name, None def httpx_client_factory( headers: dict[str, str] | None = None, @@ -503,6 +532,11 @@ async def connect_mcp_servers( sse_client(cfg.url, httpx_client_factory=httpx_client_factory) ) elif transport_type == "streamableHttp": + if not await _probe_http_url(cfg.url): + logger.warning("MCP server '{}': {} unreachable, skipping", name, cfg.url) + await server_stack.aclose() + return name, None + http_client = await server_stack.enter_async_context( httpx.AsyncClient( headers=cfg.headers or None, diff --git a/tests/tools/test_mcp_probe.py b/tests/tools/test_mcp_probe.py new file mode 100644 index 000000000..f8fcea031 --- /dev/null +++ b/tests/tools/test_mcp_probe.py @@ -0,0 +1,106 @@ +"""Tests for MCP HTTP probe guard (prevents event-loop crash on unreachable servers).""" +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from nanobot.agent.tools.mcp import _probe_http_url, connect_mcp_servers +from nanobot.agent.tools.registry import ToolRegistry + + +# --------------------------------------------------------------------------- +# _probe_http_url unit tests +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_probe_returns_true_for_open_port(tmp_path): + """Start a trivial TCP server, probe should return True.""" + server = await asyncio.start_server( + lambda r, w: None, "127.0.0.1", 0, + ) + port = server.sockets[0].getsockname()[1] + try: + assert await _probe_http_url(f"http://127.0.0.1:{port}/mcp") is True + finally: + server.close() + await server.wait_closed() + + +@pytest.mark.asyncio +async def test_probe_returns_false_for_closed_port(): + """Port 19999 is almost certainly not listening.""" + assert await _probe_http_url("http://127.0.0.1:19999/mcp") is False + + +@pytest.mark.asyncio +async def test_probe_uses_default_port_for_http(): + """When no port in URL, should default to 80 (will fail -> False).""" + assert await _probe_http_url("http://unreachable-host.test/mcp") is False + + +# --------------------------------------------------------------------------- +# connect_mcp_servers skips unreachable HTTP servers +# --------------------------------------------------------------------------- + +def _make_http_cfg(url: str, transport: str = "streamableHttp"): + cfg = MagicMock() + cfg.type = transport + cfg.url = url + cfg.command = None + cfg.args = [] + cfg.env = {} + cfg.headers = None + cfg.tool_timeout = 30 + cfg.enabled_tools = ["*"] + return cfg + + +@pytest.mark.asyncio +async def test_connect_skips_unreachable_streamable_http(): + """Unreachable streamableHttp server should be skipped with a warning, no crash.""" + registry = ToolRegistry() + servers = {"dead": _make_http_cfg("http://127.0.0.1:19999/mcp")} + stacks = await connect_mcp_servers(servers, registry) + assert stacks == {} + assert len(registry._tools) == 0 + + +@pytest.mark.asyncio +async def test_connect_skips_unreachable_sse(): + """Unreachable SSE server should be skipped with a warning, no crash.""" + registry = ToolRegistry() + servers = {"dead": _make_http_cfg("http://127.0.0.1:19999/sse", transport="sse")} + stacks = await connect_mcp_servers(servers, registry) + assert stacks == {} + assert len(registry._tools) == 0 + + +@pytest.mark.asyncio +async def test_probe_not_called_for_stdio(): + """stdio transport should not be probed — it spawns a local process.""" + called = False + original_probe = _probe_http_url + + async def _spy_probe(url, **kw): + nonlocal called + called = True + return await original_probe(url, **kw) + + with patch("nanobot.agent.tools.mcp._probe_http_url", _spy_probe): + cfg = MagicMock() + cfg.type = "stdio" + cfg.url = None + cfg.command = "nonexistent-command-xyz" + cfg.args = [] + cfg.env = None + cfg.headers = None + cfg.tool_timeout = 30 + cfg.enabled_tools = ["*"] + registry = ToolRegistry() + await connect_mcp_servers({"s": cfg}, registry) + + assert not called, "probe should not be called for stdio transport" + + +import asyncio