mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 22:34:06 +00:00
fix(mcp): reject unsafe HTTP URLs before probe
This commit is contained in:
parent
6e6470daa0
commit
ed0aeb1ea9
@ -21,6 +21,7 @@ from nanobot.bus.events import (
|
|||||||
RUNTIME_CONTROL_MCP_RELOAD,
|
RUNTIME_CONTROL_MCP_RELOAD,
|
||||||
InboundMessage,
|
InboundMessage,
|
||||||
)
|
)
|
||||||
|
from nanobot.security.network import validate_url_target
|
||||||
|
|
||||||
# Transient connection errors that warrant a single retry.
|
# Transient connection errors that warrant a single retry.
|
||||||
# These typically happen when an MCP server restarts or a network
|
# These typically happen when an MCP server restarts or a network
|
||||||
@ -87,12 +88,23 @@ async def _probe_http_url(url: str, timeout: float = 3.0) -> bool:
|
|||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
writer.close()
|
writer.close()
|
||||||
await writer.wait_closed()
|
with suppress(OSError, asyncio.TimeoutError):
|
||||||
|
await asyncio.wait_for(writer.wait_closed(), timeout=0.2)
|
||||||
return True
|
return True
|
||||||
except (OSError, asyncio.TimeoutError):
|
except (OSError, asyncio.TimeoutError):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def _validate_mcp_request_url(request: httpx.Request) -> None:
|
||||||
|
"""Validate each outgoing MCP HTTP request, including redirect targets."""
|
||||||
|
ok, error = validate_url_target(str(request.url))
|
||||||
|
if not ok:
|
||||||
|
raise httpx.RequestError(
|
||||||
|
f"Blocked unsafe MCP URL {request.url} ({error})",
|
||||||
|
request=request,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _windows_command_basename(command: str) -> str:
|
def _windows_command_basename(command: str) -> str:
|
||||||
"""Return the lowercase basename for a Windows command or path."""
|
"""Return the lowercase basename for a Windows command or path."""
|
||||||
return command.replace("\\", "/").rsplit("/", maxsplit=1)[-1].lower()
|
return command.replace("\\", "/").rsplit("/", maxsplit=1)[-1].lower()
|
||||||
@ -595,6 +607,18 @@ async def connect_mcp_servers(
|
|||||||
await server_stack.aclose()
|
await server_stack.aclose()
|
||||||
return name, None
|
return name, None
|
||||||
|
|
||||||
|
if transport_type in {"sse", "streamableHttp"}:
|
||||||
|
ok, error = validate_url_target(cfg.url)
|
||||||
|
if not ok:
|
||||||
|
logger.warning(
|
||||||
|
"MCP server '{}': blocked unsafe URL {} ({})",
|
||||||
|
name,
|
||||||
|
cfg.url,
|
||||||
|
error,
|
||||||
|
)
|
||||||
|
await server_stack.aclose()
|
||||||
|
return name, None
|
||||||
|
|
||||||
if transport_type == "stdio":
|
if transport_type == "stdio":
|
||||||
command, args, env = _normalize_windows_stdio_command(
|
command, args, env = _normalize_windows_stdio_command(
|
||||||
cfg.command,
|
cfg.command,
|
||||||
@ -626,6 +650,7 @@ async def connect_mcp_servers(
|
|||||||
}
|
}
|
||||||
return httpx.AsyncClient(
|
return httpx.AsyncClient(
|
||||||
headers=merged_headers or None,
|
headers=merged_headers or None,
|
||||||
|
event_hooks={"request": [_validate_mcp_request_url]},
|
||||||
follow_redirects=True,
|
follow_redirects=True,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
auth=auth,
|
auth=auth,
|
||||||
@ -643,6 +668,7 @@ async def connect_mcp_servers(
|
|||||||
http_client = await server_stack.enter_async_context(
|
http_client = await server_stack.enter_async_context(
|
||||||
httpx.AsyncClient(
|
httpx.AsyncClient(
|
||||||
headers=cfg.headers or None,
|
headers=cfg.headers or None,
|
||||||
|
event_hooks={"request": [_validate_mcp_request_url]},
|
||||||
follow_redirects=True,
|
follow_redirects=True,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -16,9 +16,11 @@ from nanobot.agent.tools.registry import ToolRegistry
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_probe_returns_true_for_open_port(tmp_path):
|
async def test_probe_returns_true_for_open_port(tmp_path):
|
||||||
"""Start a trivial TCP server, probe should return True."""
|
"""Start a trivial TCP server, probe should return True."""
|
||||||
server = await asyncio.start_server(
|
async def _close_connection(_reader, writer):
|
||||||
lambda r, w: None, "127.0.0.1", 0,
|
writer.close()
|
||||||
)
|
await writer.wait_closed()
|
||||||
|
|
||||||
|
server = await asyncio.start_server(_close_connection, "127.0.0.1", 0)
|
||||||
port = server.sockets[0].getsockname()[1]
|
port = server.sockets[0].getsockname()[1]
|
||||||
try:
|
try:
|
||||||
assert await _probe_http_url(f"http://127.0.0.1:{port}/mcp") is True
|
assert await _probe_http_url(f"http://127.0.0.1:{port}/mcp") is True
|
||||||
@ -59,9 +61,13 @@ def _make_http_cfg(url: str, transport: str = "streamableHttp"):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_connect_skips_unreachable_streamable_http():
|
async def test_connect_skips_unreachable_streamable_http():
|
||||||
"""Unreachable streamableHttp server should be skipped with a warning, no crash."""
|
"""Unreachable streamableHttp server should be skipped with a warning, no crash."""
|
||||||
|
async def _unreachable(_url: str) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
servers = {"dead": _make_http_cfg("http://127.0.0.1:19999/mcp")}
|
servers = {"dead": _make_http_cfg("http://93.184.216.34:19999/mcp")}
|
||||||
stacks = await connect_mcp_servers(servers, registry)
|
with patch("nanobot.agent.tools.mcp._probe_http_url", _unreachable):
|
||||||
|
stacks = await connect_mcp_servers(servers, registry)
|
||||||
assert stacks == {}
|
assert stacks == {}
|
||||||
assert len(registry._tools) == 0
|
assert len(registry._tools) == 0
|
||||||
|
|
||||||
@ -69,9 +75,13 @@ async def test_connect_skips_unreachable_streamable_http():
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_connect_skips_unreachable_sse():
|
async def test_connect_skips_unreachable_sse():
|
||||||
"""Unreachable SSE server should be skipped with a warning, no crash."""
|
"""Unreachable SSE server should be skipped with a warning, no crash."""
|
||||||
|
async def _unreachable(_url: str) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
servers = {"dead": _make_http_cfg("http://127.0.0.1:19999/sse", transport="sse")}
|
servers = {"dead": _make_http_cfg("http://93.184.216.34:19999/sse", transport="sse")}
|
||||||
stacks = await connect_mcp_servers(servers, registry)
|
with patch("nanobot.agent.tools.mcp._probe_http_url", _unreachable):
|
||||||
|
stacks = await connect_mcp_servers(servers, registry)
|
||||||
assert stacks == {}
|
assert stacks == {}
|
||||||
assert len(registry._tools) == 0
|
assert len(registry._tools) == 0
|
||||||
|
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import sys
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from types import ModuleType, SimpleNamespace
|
from types import ModuleType, SimpleNamespace
|
||||||
|
|
||||||
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import nanobot.agent.tools.mcp as mcp_mod
|
import nanobot.agent.tools.mcp as mcp_mod
|
||||||
@ -486,6 +487,80 @@ async def test_connect_mcp_servers_logs_stdio_pollution_hint(
|
|||||||
assert "stderr" in messages[-1]
|
assert "stderr" in messages[-1]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"config",
|
||||||
|
[
|
||||||
|
MCPServerConfig(url="http://127.0.0.1:9/sse"),
|
||||||
|
MCPServerConfig(type="streamableHttp", url="http://127.0.0.1:9/mcp"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_connect_mcp_servers_rejects_unsafe_http_urls_before_probe(
|
||||||
|
config: MCPServerConfig,
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
attempted_connections: list[tuple[object, ...]] = []
|
||||||
|
warnings: list[str] = []
|
||||||
|
|
||||||
|
async def _open_connection(*args: object, **_kwargs: object):
|
||||||
|
attempted_connections.append(args)
|
||||||
|
raise AssertionError("unsafe MCP URL should be rejected before TCP probe")
|
||||||
|
|
||||||
|
def _warning(message: str, *args: object) -> None:
|
||||||
|
warnings.append(message.format(*args))
|
||||||
|
|
||||||
|
monkeypatch.setattr(mcp_mod.asyncio, "open_connection", _open_connection)
|
||||||
|
monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning)
|
||||||
|
|
||||||
|
registry = ToolRegistry()
|
||||||
|
stacks = await connect_mcp_servers({"local": config}, registry)
|
||||||
|
|
||||||
|
assert stacks == {}
|
||||||
|
assert registry.tool_names == []
|
||||||
|
assert attempted_connections == []
|
||||||
|
assert any("blocked unsafe URL" in warning for warning in warnings)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mcp_http_request_hook_rejects_unsafe_redirect_targets(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
checked_urls: list[str] = []
|
||||||
|
sent_urls: list[str] = []
|
||||||
|
|
||||||
|
def _validate(url: str) -> tuple[bool, str]:
|
||||||
|
checked_urls.append(url)
|
||||||
|
if url == "http://127.0.0.1/private":
|
||||||
|
return False, "loopback blocked"
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
def _handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
sent_urls.append(str(request.url))
|
||||||
|
if str(request.url) == "https://example.com/start":
|
||||||
|
return httpx.Response(
|
||||||
|
302,
|
||||||
|
headers={"Location": "http://127.0.0.1/private"},
|
||||||
|
request=request,
|
||||||
|
)
|
||||||
|
raise AssertionError("unsafe redirect target should be blocked before transport")
|
||||||
|
|
||||||
|
monkeypatch.setattr(mcp_mod, "validate_url_target", _validate)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(
|
||||||
|
event_hooks={"request": [mcp_mod._validate_mcp_request_url]},
|
||||||
|
follow_redirects=True,
|
||||||
|
transport=httpx.MockTransport(_handler),
|
||||||
|
) as client:
|
||||||
|
with pytest.raises(httpx.RequestError, match="loopback blocked"):
|
||||||
|
await client.get("https://example.com/start")
|
||||||
|
|
||||||
|
assert checked_urls == [
|
||||||
|
"https://example.com/start",
|
||||||
|
"http://127.0.0.1/private",
|
||||||
|
]
|
||||||
|
assert sent_urls == ["https://example.com/start"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_connect_mcp_servers_one_failure_does_not_block_others(
|
async def test_connect_mcp_servers_one_failure_does_not_block_others(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user