fix(mcp): reject unsafe HTTP URLs before probe

This commit is contained in:
Stellar鱼 2026-06-07 13:38:02 +08:00 committed by Xubin Ren
parent 6e6470daa0
commit ed0aeb1ea9
3 changed files with 119 additions and 8 deletions

View File

@ -21,6 +21,7 @@ from nanobot.bus.events import (
RUNTIME_CONTROL_MCP_RELOAD, RUNTIME_CONTROL_MCP_RELOAD,
InboundMessage, InboundMessage,
) )
from nanobot.security.network import validate_url_target
# Transient connection errors that warrant a single retry. # Transient connection errors that warrant a single retry.
# These typically happen when an MCP server restarts or a network # These typically happen when an MCP server restarts or a network
@ -87,12 +88,23 @@ async def _probe_http_url(url: str, timeout: float = 3.0) -> bool:
timeout=timeout, timeout=timeout,
) )
writer.close() writer.close()
await writer.wait_closed() with suppress(OSError, asyncio.TimeoutError):
await asyncio.wait_for(writer.wait_closed(), timeout=0.2)
return True return True
except (OSError, asyncio.TimeoutError): except (OSError, asyncio.TimeoutError):
return False return False
async def _validate_mcp_request_url(request: httpx.Request) -> None:
    """Vet the URL of an outgoing MCP HTTP request before it is sent.

    Intended for use as an httpx ``request`` event hook, so it runs for
    every request the client issues, including the follow-up requests
    generated when a redirect is followed.

    Args:
        request: The request httpx is about to send.

    Raises:
        httpx.RequestError: If ``validate_url_target`` reports the URL as
            unsafe; the validator's reason is included in the message.
    """
    is_safe, reason = validate_url_target(str(request.url))
    if is_safe:
        return
    message = f"Blocked unsafe MCP URL {request.url} ({reason})"
    raise httpx.RequestError(message, request=request)
def _windows_command_basename(command: str) -> str: def _windows_command_basename(command: str) -> str:
"""Return the lowercase basename for a Windows command or path.""" """Return the lowercase basename for a Windows command or path."""
return command.replace("\\", "/").rsplit("/", maxsplit=1)[-1].lower() return command.replace("\\", "/").rsplit("/", maxsplit=1)[-1].lower()
@ -595,6 +607,18 @@ async def connect_mcp_servers(
await server_stack.aclose() await server_stack.aclose()
return name, None return name, None
if transport_type in {"sse", "streamableHttp"}:
ok, error = validate_url_target(cfg.url)
if not ok:
logger.warning(
"MCP server '{}': blocked unsafe URL {} ({})",
name,
cfg.url,
error,
)
await server_stack.aclose()
return name, None
if transport_type == "stdio": if transport_type == "stdio":
command, args, env = _normalize_windows_stdio_command( command, args, env = _normalize_windows_stdio_command(
cfg.command, cfg.command,
@ -626,6 +650,7 @@ async def connect_mcp_servers(
} }
return httpx.AsyncClient( return httpx.AsyncClient(
headers=merged_headers or None, headers=merged_headers or None,
event_hooks={"request": [_validate_mcp_request_url]},
follow_redirects=True, follow_redirects=True,
timeout=timeout, timeout=timeout,
auth=auth, auth=auth,
@ -643,6 +668,7 @@ async def connect_mcp_servers(
http_client = await server_stack.enter_async_context( http_client = await server_stack.enter_async_context(
httpx.AsyncClient( httpx.AsyncClient(
headers=cfg.headers or None, headers=cfg.headers or None,
event_hooks={"request": [_validate_mcp_request_url]},
follow_redirects=True, follow_redirects=True,
timeout=None, timeout=None,
) )

View File

@ -16,9 +16,11 @@ from nanobot.agent.tools.registry import ToolRegistry
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_probe_returns_true_for_open_port(tmp_path): async def test_probe_returns_true_for_open_port(tmp_path):
"""Start a trivial TCP server, probe should return True.""" """Start a trivial TCP server, probe should return True."""
server = await asyncio.start_server( async def _close_connection(_reader, writer):
lambda r, w: None, "127.0.0.1", 0, writer.close()
) await writer.wait_closed()
server = await asyncio.start_server(_close_connection, "127.0.0.1", 0)
port = server.sockets[0].getsockname()[1] port = server.sockets[0].getsockname()[1]
try: try:
assert await _probe_http_url(f"http://127.0.0.1:{port}/mcp") is True assert await _probe_http_url(f"http://127.0.0.1:{port}/mcp") is True
@ -59,9 +61,13 @@ def _make_http_cfg(url: str, transport: str = "streamableHttp"):
@pytest.mark.asyncio
async def test_connect_skips_unreachable_streamable_http():
    """Unreachable streamableHttp server should be skipped with a warning, no crash."""

    async def _unreachable(_url: str) -> bool:
        # Stand-in for the TCP probe: always report the server as down so
        # the test never opens a real connection.
        return False

    registry = ToolRegistry()
    # NOTE(review): a non-loopback address is used, presumably so the URL
    # passes the unsafe-URL check and the probe path is the one exercised.
    servers = {"dead": _make_http_cfg("http://93.184.216.34:19999/mcp")}
    with patch("nanobot.agent.tools.mcp._probe_http_url", _unreachable):
        stacks = await connect_mcp_servers(servers, registry)
    assert stacks == {}
    assert len(registry._tools) == 0
@ -69,9 +75,13 @@ async def test_connect_skips_unreachable_streamable_http():
@pytest.mark.asyncio
async def test_connect_skips_unreachable_sse():
    """Unreachable SSE server should be skipped with a warning, no crash."""

    async def _unreachable(_url: str) -> bool:
        # Stand-in for the TCP probe: always report the server as down so
        # the test never opens a real connection.
        return False

    registry = ToolRegistry()
    # NOTE(review): a non-loopback address is used, presumably so the URL
    # passes the unsafe-URL check and the probe path is the one exercised.
    servers = {"dead": _make_http_cfg("http://93.184.216.34:19999/sse", transport="sse")}
    with patch("nanobot.agent.tools.mcp._probe_http_url", _unreachable):
        stacks = await connect_mcp_servers(servers, registry)
    assert stacks == {}
    assert len(registry._tools) == 0

View File

@ -5,6 +5,7 @@ import sys
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from types import ModuleType, SimpleNamespace from types import ModuleType, SimpleNamespace
import httpx
import pytest import pytest
import nanobot.agent.tools.mcp as mcp_mod import nanobot.agent.tools.mcp as mcp_mod
@ -486,6 +487,80 @@ async def test_connect_mcp_servers_logs_stdio_pollution_hint(
assert "stderr" in messages[-1] assert "stderr" in messages[-1]
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "config",
    [
        MCPServerConfig(url="http://127.0.0.1:9/sse"),
        MCPServerConfig(type="streamableHttp", url="http://127.0.0.1:9/mcp"),
    ],
)
async def test_connect_mcp_servers_rejects_unsafe_http_urls_before_probe(
    config: MCPServerConfig,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Loopback MCP URLs are refused without any TCP connection attempt.

    ``asyncio.open_connection`` is replaced with a stub that records its
    arguments and fails, and ``logger.warning`` is captured, so the test
    can assert that no connection was tried and that the "blocked unsafe
    URL" warning was emitted.
    """
    probe_calls: list[tuple[object, ...]] = []
    logged: list[str] = []

    async def _fail_on_connect(*args: object, **_kwargs: object):
        probe_calls.append(args)
        raise AssertionError("unsafe MCP URL should be rejected before TCP probe")

    def _capture_warning(message: str, *args: object) -> None:
        # Render the brace-style template the same way the logger would.
        logged.append(message.format(*args))

    monkeypatch.setattr(mcp_mod.asyncio, "open_connection", _fail_on_connect)
    monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _capture_warning)

    tool_registry = ToolRegistry()
    stacks = await connect_mcp_servers({"local": config}, tool_registry)

    assert stacks == {}
    assert tool_registry.tool_names == []
    assert probe_calls == []
    assert any("blocked unsafe URL" in entry for entry in logged)
@pytest.mark.asyncio
async def test_mcp_http_request_hook_rejects_unsafe_redirect_targets(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """The request hook blocks a redirect to an unsafe URL before it is sent.

    A mock transport answers the first request with a 302 pointing at a
    loopback URL. The stubbed validator rejects that URL, so the client
    must raise ``httpx.RequestError`` and the transport must only ever see
    the initial request.
    """
    start_url = "https://example.com/start"
    private_url = "http://127.0.0.1/private"
    validated: list[str] = []
    transported: list[str] = []

    def _fake_validate(url: str) -> tuple[bool, str]:
        validated.append(url)
        if url == private_url:
            return False, "loopback blocked"
        return True, ""

    def _transport_handler(request: httpx.Request) -> httpx.Response:
        target = str(request.url)
        transported.append(target)
        if target != start_url:
            # Reaching the transport with any other URL means the hook failed.
            raise AssertionError("unsafe redirect target should be blocked before transport")
        return httpx.Response(
            302,
            headers={"Location": private_url},
            request=request,
        )

    monkeypatch.setattr(mcp_mod, "validate_url_target", _fake_validate)

    hooked_client = httpx.AsyncClient(
        event_hooks={"request": [mcp_mod._validate_mcp_request_url]},
        follow_redirects=True,
        transport=httpx.MockTransport(_transport_handler),
    )
    async with hooked_client as client:
        with pytest.raises(httpx.RequestError, match="loopback blocked"):
            await client.get(start_url)

    assert validated == [start_url, private_url]
    assert transported == [start_url]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_connect_mcp_servers_one_failure_does_not_block_others( async def test_connect_mcp_servers_one_failure_does_not_block_others(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,