diff --git a/nanobot/agent/tools/mcp.py b/nanobot/agent/tools/mcp.py index 59a41127e..181c4e9f8 100644 --- a/nanobot/agent/tools/mcp.py +++ b/nanobot/agent/tools/mcp.py @@ -21,6 +21,7 @@ from nanobot.bus.events import ( RUNTIME_CONTROL_MCP_RELOAD, InboundMessage, ) +from nanobot.security.network import validate_url_target # Transient connection errors that warrant a single retry. # These typically happen when an MCP server restarts or a network @@ -87,12 +88,23 @@ async def _probe_http_url(url: str, timeout: float = 3.0) -> bool: timeout=timeout, ) writer.close() - await writer.wait_closed() + with suppress(OSError, asyncio.TimeoutError): + await asyncio.wait_for(writer.wait_closed(), timeout=0.2) return True except (OSError, asyncio.TimeoutError): return False +async def _validate_mcp_request_url(request: httpx.Request) -> None: + """Validate each outgoing MCP HTTP request, including redirect targets.""" + ok, error = validate_url_target(str(request.url)) + if not ok: + raise httpx.RequestError( + f"Blocked unsafe MCP URL {request.url} ({error})", + request=request, + ) + + def _windows_command_basename(command: str) -> str: """Return the lowercase basename for a Windows command or path.""" return command.replace("\\", "/").rsplit("/", maxsplit=1)[-1].lower() @@ -595,6 +607,18 @@ async def connect_mcp_servers( await server_stack.aclose() return name, None + if transport_type in {"sse", "streamableHttp"}: + ok, error = validate_url_target(cfg.url) + if not ok: + logger.warning( + "MCP server '{}': blocked unsafe URL {} ({})", + name, + cfg.url, + error, + ) + await server_stack.aclose() + return name, None + if transport_type == "stdio": command, args, env = _normalize_windows_stdio_command( cfg.command, @@ -626,6 +650,7 @@ async def connect_mcp_servers( } return httpx.AsyncClient( headers=merged_headers or None, + event_hooks={"request": [_validate_mcp_request_url]}, follow_redirects=True, timeout=timeout, auth=auth, @@ -643,6 +668,7 @@ async def connect_mcp_servers( http_client = await server_stack.enter_async_context( httpx.AsyncClient( headers=cfg.headers or None, + event_hooks={"request": [_validate_mcp_request_url]}, follow_redirects=True, timeout=None, ) diff --git a/tests/tools/test_mcp_probe.py b/tests/tools/test_mcp_probe.py index 38dc8fe7e..818895a75 100644 --- a/tests/tools/test_mcp_probe.py +++ b/tests/tools/test_mcp_probe.py @@ -16,9 +16,11 @@ from nanobot.agent.tools.registry import ToolRegistry @pytest.mark.asyncio async def test_probe_returns_true_for_open_port(tmp_path): """Start a trivial TCP server, probe should return True.""" - server = await asyncio.start_server( - lambda r, w: None, "127.0.0.1", 0, - ) + async def _close_connection(_reader, writer): + writer.close() + await writer.wait_closed() + + server = await asyncio.start_server(_close_connection, "127.0.0.1", 0) port = server.sockets[0].getsockname()[1] try: assert await _probe_http_url(f"http://127.0.0.1:{port}/mcp") is True @@ -59,9 +61,13 @@ def _make_http_cfg(url: str, transport: str = "streamableHttp"): @pytest.mark.asyncio async def test_connect_skips_unreachable_streamable_http(): """Unreachable streamableHttp server should be skipped with a warning, no crash.""" + async def _unreachable(_url: str) -> bool: + return False + registry = ToolRegistry() - servers = {"dead": _make_http_cfg("http://127.0.0.1:19999/mcp")} - stacks = await connect_mcp_servers(servers, registry) + servers = {"dead": _make_http_cfg("http://93.184.216.34:19999/mcp")} + with patch("nanobot.agent.tools.mcp._probe_http_url", _unreachable): + stacks = await connect_mcp_servers(servers, registry) assert stacks == {} assert len(registry._tools) == 0 @@ -69,9 +75,13 @@ async def test_connect_skips_unreachable_streamable_http(): @pytest.mark.asyncio async def test_connect_skips_unreachable_sse(): """Unreachable SSE server should be skipped with a warning, no crash.""" + async def _unreachable(_url: str) -> bool: + return False + registry = ToolRegistry() - servers = {"dead": _make_http_cfg("http://127.0.0.1:19999/sse", transport="sse")} - stacks = await connect_mcp_servers(servers, registry) + servers = {"dead": _make_http_cfg("http://93.184.216.34:19999/sse", transport="sse")} + with patch("nanobot.agent.tools.mcp._probe_http_url", _unreachable): + stacks = await connect_mcp_servers(servers, registry) assert stacks == {} assert len(registry._tools) == 0 diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index 68fadce44..d69fc03bc 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -5,6 +5,7 @@ import sys from contextlib import asynccontextmanager from types import ModuleType, SimpleNamespace +import httpx import pytest import nanobot.agent.tools.mcp as mcp_mod @@ -486,6 +487,80 @@ async def test_connect_mcp_servers_logs_stdio_pollution_hint( assert "stderr" in messages[-1] +@pytest.mark.asyncio +@pytest.mark.parametrize( + "config", + [ + MCPServerConfig(url="http://127.0.0.1:9/sse"), + MCPServerConfig(type="streamableHttp", url="http://127.0.0.1:9/mcp"), + ], +) +async def test_connect_mcp_servers_rejects_unsafe_http_urls_before_probe( + config: MCPServerConfig, + monkeypatch: pytest.MonkeyPatch, +) -> None: + attempted_connections: list[tuple[object, ...]] = [] + warnings: list[str] = [] + + async def _open_connection(*args: object, **_kwargs: object): + attempted_connections.append(args) + raise AssertionError("unsafe MCP URL should be rejected before TCP probe") + + def _warning(message: str, *args: object) -> None: + warnings.append(message.format(*args)) + + monkeypatch.setattr(mcp_mod.asyncio, "open_connection", _open_connection) + monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning) + + registry = ToolRegistry() + stacks = await connect_mcp_servers({"local": config}, registry) + + assert stacks == {} + assert registry.tool_names == [] + assert attempted_connections == [] + assert any("blocked unsafe URL" in warning for warning in warnings) + + +@pytest.mark.asyncio +async def test_mcp_http_request_hook_rejects_unsafe_redirect_targets( + monkeypatch: pytest.MonkeyPatch, +) -> None: + checked_urls: list[str] = [] + sent_urls: list[str] = [] + + def _validate(url: str) -> tuple[bool, str]: + checked_urls.append(url) + if url == "http://127.0.0.1/private": + return False, "loopback blocked" + return True, "" + + def _handler(request: httpx.Request) -> httpx.Response: + sent_urls.append(str(request.url)) + if str(request.url) == "https://example.com/start": + return httpx.Response( + 302, + headers={"Location": "http://127.0.0.1/private"}, + request=request, + ) + raise AssertionError("unsafe redirect target should be blocked before transport") + + monkeypatch.setattr(mcp_mod, "validate_url_target", _validate) + + async with httpx.AsyncClient( + event_hooks={"request": [mcp_mod._validate_mcp_request_url]}, + follow_redirects=True, + transport=httpx.MockTransport(_handler), + ) as client: + with pytest.raises(httpx.RequestError, match="loopback blocked"): + await client.get("https://example.com/start") + + assert checked_urls == [ + "https://example.com/start", + "http://127.0.0.1/private", + ] + assert sent_urls == ["https://example.com/start"] + + @pytest.mark.asyncio async def test_connect_mcp_servers_one_failure_does_not_block_others( monkeypatch: pytest.MonkeyPatch,