test: cover MCP redirect guard wiring

Maintainer edit: make the unsafe redirect regression go through connect_mcp_servers so both SSE and streamable HTTP prove that the request hook is attached to the MCP clients before redirects are followed.
This commit is contained in:
chengyongru 2026-06-07 21:53:58 +08:00 committed by Xubin Ren
parent a73924f77e
commit 06d454a225

View File

@ -522,11 +522,24 @@ async def test_connect_mcp_servers_rejects_unsafe_http_urls_before_probe(
@pytest.mark.asyncio
async def test_mcp_http_request_hook_rejects_unsafe_redirect_targets(
@pytest.mark.parametrize(
("config", "expected_transport"),
[
(MCPServerConfig(type="sse", url="https://mcp.example.com/sse"), "sse"),
(
MCPServerConfig(type="streamableHttp", url="https://mcp.example.com/mcp"),
"streamableHttp",
),
],
)
async def test_connect_mcp_servers_http_clients_reject_unsafe_redirect_targets(
config: MCPServerConfig,
expected_transport: str,
monkeypatch: pytest.MonkeyPatch,
) -> None:
checked_urls: list[str] = []
sent_urls: list[str] = []
used_transports: list[str] = []
def _validate(url: str) -> tuple[bool, str]:
checked_urls.append(url)
@ -534,6 +547,9 @@ async def test_mcp_http_request_hook_rejects_unsafe_redirect_targets(
return False, "loopback blocked"
return True, ""
async def _reachable(_url: str) -> bool:
return True
def _handler(request: httpx.Request) -> httpx.Response:
sent_urls.append(str(request.url))
if str(request.url) == "https://example.com/start":
@ -544,17 +560,45 @@ async def test_mcp_http_request_hook_rejects_unsafe_redirect_targets(
)
raise AssertionError("unsafe redirect target should be blocked before transport")
monkeypatch.setattr(mcp_mod, "validate_url_target", _validate)
original_async_client = httpx.AsyncClient
async with httpx.AsyncClient(
event_hooks={"request": [mcp_mod._validate_mcp_request_url]},
follow_redirects=True,
transport=httpx.MockTransport(_handler),
) as client:
with pytest.raises(httpx.RequestError, match="loopback blocked"):
def _async_client_with_mock_transport(*args: object, **kwargs: object) -> httpx.AsyncClient:
kwargs.setdefault("transport", httpx.MockTransport(_handler))
return original_async_client(*args, **kwargs)
@asynccontextmanager
async def _fake_sse_client(_url: str, httpx_client_factory=None):
assert httpx_client_factory is not None
used_transports.append("sse")
async with httpx_client_factory() as client:
await client.get("https://example.com/start")
yield object(), object()
@asynccontextmanager
async def _fake_streamable_http_client(_url: str, http_client=None):
assert http_client is not None
used_transports.append("streamableHttp")
await http_client.get("https://example.com/start")
yield object(), object(), object()
monkeypatch.setattr(mcp_mod, "validate_url_target", _validate)
monkeypatch.setattr(mcp_mod, "_probe_http_url", _reachable)
monkeypatch.setattr(mcp_mod.httpx, "AsyncClient", _async_client_with_mock_transport)
monkeypatch.setattr(sys.modules["mcp.client.sse"], "sse_client", _fake_sse_client)
monkeypatch.setattr(
sys.modules["mcp.client.streamable_http"],
"streamable_http_client",
_fake_streamable_http_client,
)
registry = ToolRegistry()
stacks = await connect_mcp_servers({"remote": config}, registry)
assert stacks == {}
assert registry.tool_names == []
assert used_transports == [expected_transport]
assert checked_urls == [
config.url,
"https://example.com/start",
"http://127.0.0.1/private",
]