mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 07:14:08 +00:00
test: cover MCP redirect guard wiring
Maintainer edit: make the unsafe redirect regression go through connect_mcp_servers so both SSE and streamable HTTP prove that the request hook is attached to the MCP clients before redirects are followed.
This commit is contained in:
parent
a73924f77e
commit
06d454a225
@ -522,11 +522,24 @@ async def test_connect_mcp_servers_rejects_unsafe_http_urls_before_probe(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_http_request_hook_rejects_unsafe_redirect_targets(
|
||||
@pytest.mark.parametrize(
|
||||
("config", "expected_transport"),
|
||||
[
|
||||
(MCPServerConfig(type="sse", url="https://mcp.example.com/sse"), "sse"),
|
||||
(
|
||||
MCPServerConfig(type="streamableHttp", url="https://mcp.example.com/mcp"),
|
||||
"streamableHttp",
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_connect_mcp_servers_http_clients_reject_unsafe_redirect_targets(
|
||||
config: MCPServerConfig,
|
||||
expected_transport: str,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
checked_urls: list[str] = []
|
||||
sent_urls: list[str] = []
|
||||
used_transports: list[str] = []
|
||||
|
||||
def _validate(url: str) -> tuple[bool, str]:
|
||||
checked_urls.append(url)
|
||||
@ -534,6 +547,9 @@ async def test_mcp_http_request_hook_rejects_unsafe_redirect_targets(
|
||||
return False, "loopback blocked"
|
||||
return True, ""
|
||||
|
||||
async def _reachable(_url: str) -> bool:
|
||||
return True
|
||||
|
||||
def _handler(request: httpx.Request) -> httpx.Response:
|
||||
sent_urls.append(str(request.url))
|
||||
if str(request.url) == "https://example.com/start":
|
||||
@ -544,17 +560,45 @@ async def test_mcp_http_request_hook_rejects_unsafe_redirect_targets(
|
||||
)
|
||||
raise AssertionError("unsafe redirect target should be blocked before transport")
|
||||
|
||||
monkeypatch.setattr(mcp_mod, "validate_url_target", _validate)
|
||||
original_async_client = httpx.AsyncClient
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
event_hooks={"request": [mcp_mod._validate_mcp_request_url]},
|
||||
follow_redirects=True,
|
||||
transport=httpx.MockTransport(_handler),
|
||||
) as client:
|
||||
with pytest.raises(httpx.RequestError, match="loopback blocked"):
|
||||
def _async_client_with_mock_transport(*args: object, **kwargs: object) -> httpx.AsyncClient:
|
||||
kwargs.setdefault("transport", httpx.MockTransport(_handler))
|
||||
return original_async_client(*args, **kwargs)
|
||||
|
||||
@asynccontextmanager
|
||||
async def _fake_sse_client(_url: str, httpx_client_factory=None):
|
||||
assert httpx_client_factory is not None
|
||||
used_transports.append("sse")
|
||||
async with httpx_client_factory() as client:
|
||||
await client.get("https://example.com/start")
|
||||
yield object(), object()
|
||||
|
||||
@asynccontextmanager
|
||||
async def _fake_streamable_http_client(_url: str, http_client=None):
|
||||
assert http_client is not None
|
||||
used_transports.append("streamableHttp")
|
||||
await http_client.get("https://example.com/start")
|
||||
yield object(), object(), object()
|
||||
|
||||
monkeypatch.setattr(mcp_mod, "validate_url_target", _validate)
|
||||
monkeypatch.setattr(mcp_mod, "_probe_http_url", _reachable)
|
||||
monkeypatch.setattr(mcp_mod.httpx, "AsyncClient", _async_client_with_mock_transport)
|
||||
monkeypatch.setattr(sys.modules["mcp.client.sse"], "sse_client", _fake_sse_client)
|
||||
monkeypatch.setattr(
|
||||
sys.modules["mcp.client.streamable_http"],
|
||||
"streamable_http_client",
|
||||
_fake_streamable_http_client,
|
||||
)
|
||||
|
||||
registry = ToolRegistry()
|
||||
stacks = await connect_mcp_servers({"remote": config}, registry)
|
||||
|
||||
assert stacks == {}
|
||||
assert registry.tool_names == []
|
||||
assert used_transports == [expected_transport]
|
||||
assert checked_urls == [
|
||||
config.url,
|
||||
"https://example.com/start",
|
||||
"http://127.0.0.1/private",
|
||||
]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user