From 06d454a225ca45af9081e1f70db0ce869a15bcca Mon Sep 17 00:00:00 2001 From: chengyongru <2755839590@qq.com> Date: Sun, 7 Jun 2026 21:53:58 +0800 Subject: [PATCH] test: cover MCP redirect guard wiring Maintainer edit: make the unsafe redirect regression go through connect_mcp_servers so both SSE and streamable HTTP prove that the request hook is attached to the MCP clients before redirects are followed. --- tests/tools/test_mcp_tool.py | 60 +++++++++++++++++++++++++++++++----- 1 file changed, 52 insertions(+), 8 deletions(-) diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index d69fc03bc..949f4eec8 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -522,11 +522,24 @@ async def test_connect_mcp_servers_rejects_unsafe_http_urls_before_probe( @pytest.mark.asyncio -async def test_mcp_http_request_hook_rejects_unsafe_redirect_targets( +@pytest.mark.parametrize( + ("config", "expected_transport"), + [ + (MCPServerConfig(type="sse", url="https://mcp.example.com/sse"), "sse"), + ( + MCPServerConfig(type="streamableHttp", url="https://mcp.example.com/mcp"), + "streamableHttp", + ), + ], +) +async def test_connect_mcp_servers_http_clients_reject_unsafe_redirect_targets( + config: MCPServerConfig, + expected_transport: str, monkeypatch: pytest.MonkeyPatch, ) -> None: checked_urls: list[str] = [] sent_urls: list[str] = [] + used_transports: list[str] = [] def _validate(url: str) -> tuple[bool, str]: checked_urls.append(url) @@ -534,6 +547,9 @@ async def test_mcp_http_request_hook_rejects_unsafe_redirect_targets( return False, "loopback blocked" return True, "" + async def _reachable(_url: str) -> bool: + return True + def _handler(request: httpx.Request) -> httpx.Response: sent_urls.append(str(request.url)) if str(request.url) == "https://example.com/start": @@ -544,17 +560,45 @@ async def test_mcp_http_request_hook_rejects_unsafe_redirect_targets( ) raise AssertionError("unsafe redirect target should be blocked before transport") - monkeypatch.setattr(mcp_mod, "validate_url_target", _validate) + original_async_client = httpx.AsyncClient - async with httpx.AsyncClient( - event_hooks={"request": [mcp_mod._validate_mcp_request_url]}, - follow_redirects=True, - transport=httpx.MockTransport(_handler), - ) as client: - with pytest.raises(httpx.RequestError, match="loopback blocked"): + def _async_client_with_mock_transport(*args: object, **kwargs: object) -> httpx.AsyncClient: + kwargs.setdefault("transport", httpx.MockTransport(_handler)) + return original_async_client(*args, **kwargs) + + @asynccontextmanager + async def _fake_sse_client(_url: str, httpx_client_factory=None): + assert httpx_client_factory is not None + used_transports.append("sse") + async with httpx_client_factory() as client: await client.get("https://example.com/start") + yield object(), object() + @asynccontextmanager + async def _fake_streamable_http_client(_url: str, http_client=None): + assert http_client is not None + used_transports.append("streamableHttp") + await http_client.get("https://example.com/start") + yield object(), object(), object() + + monkeypatch.setattr(mcp_mod, "validate_url_target", _validate) + monkeypatch.setattr(mcp_mod, "_probe_http_url", _reachable) + monkeypatch.setattr(mcp_mod.httpx, "AsyncClient", _async_client_with_mock_transport) + monkeypatch.setattr(sys.modules["mcp.client.sse"], "sse_client", _fake_sse_client) + monkeypatch.setattr( + sys.modules["mcp.client.streamable_http"], + "streamable_http_client", + _fake_streamable_http_client, + ) + + registry = ToolRegistry() + stacks = await connect_mcp_servers({"remote": config}, registry) + + assert stacks == {} + assert registry.tool_names == [] + assert used_transports == [expected_transport] assert checked_urls == [ + config.url, "https://example.com/start", "http://127.0.0.1/private", ]