fix(web): validate redirect targets before fetching

This commit is contained in:
hinotoi-agent 2026-05-20 16:44:01 +08:00 committed by Xubin Ren
parent b1140f6aee
commit ff173045fe
2 changed files with 132 additions and 50 deletions

View File

@ -8,7 +8,7 @@ import json
import os import os
import re import re
from typing import Any, Callable from typing import Any, Callable
from urllib.parse import quote, urlparse from urllib.parse import quote, urljoin, urlparse
import httpx import httpx
from loguru import logger from loguru import logger
@ -78,9 +78,41 @@ def _validate_url(url: str) -> tuple[bool, str]:
def _validate_url_safe(url: str) -> tuple[bool, str]: def _validate_url_safe(url: str) -> tuple[bool, str]:
"""Validate URL with SSRF protection: scheme, domain, and resolved IP check.""" """Validate URL with SSRF protection: scheme, domain, and resolved IP check."""
from nanobot.security.network import validate_url_target from nanobot.security.network import validate_url_target
return validate_url_target(url) return validate_url_target(url)
async def _get_with_safe_redirects(
client: httpx.AsyncClient,
url: str,
headers: dict[str, str] | None = None,
) -> tuple[httpx.Response | None, str | None]:
"""GET a URL while validating every redirect target before requesting it."""
current_url = url
for _ in range(MAX_REDIRECTS + 1):
is_valid, error_msg = _validate_url_safe(current_url)
if not is_valid:
return None, f"Redirect blocked: {error_msg}"
response = await client.get(current_url, headers=headers, follow_redirects=False)
if not response.is_redirect:
return response, None
location = response.headers.get("location")
if not location:
return response, None
next_url = urljoin(str(response.url), location)
is_valid, error_msg = _validate_url_safe(next_url)
if not is_valid:
await response.aclose()
return None, f"Redirect blocked: {error_msg}"
await response.aclose()
current_url = next_url
return None, f"Too many redirects: exceeded limit of {MAX_REDIRECTS}"
def _format_results(query: str, items: list[dict[str, Any]], n: int) -> str: def _format_results(query: str, items: list[dict[str, Any]], n: int) -> str:
"""Format provider results into shared plaintext output.""" """Format provider results into shared plaintext output."""
if not items: if not items:
@ -488,18 +520,21 @@ class WebFetchTool(Tool):
# Detect and fetch images directly to avoid Jina's textual image captioning # Detect and fetch images directly to avoid Jina's textual image captioning
try: try:
async with httpx.AsyncClient(proxy=self.proxy, follow_redirects=True, max_redirects=MAX_REDIRECTS, timeout=15.0) as client: async with httpx.AsyncClient(proxy=self.proxy, timeout=15.0) as client:
async with client.stream("GET", url, headers={"User-Agent": self.user_agent}) as r: r, redirect_error = await _get_with_safe_redirects(
from nanobot.security.network import validate_resolved_url client,
url,
redir_ok, redir_err = validate_resolved_url(str(r.url)) headers={"User-Agent": self.user_agent},
if not redir_ok: )
return json.dumps({"error": f"Redirect blocked: {redir_err}", "url": url}, ensure_ascii=False) if redirect_error:
return json.dumps({"error": redirect_error, "url": url}, ensure_ascii=False)
if r is None:
return json.dumps({"error": "Fetch failed", "url": url}, ensure_ascii=False)
ctype = r.headers.get("content-type", "") ctype = r.headers.get("content-type", "")
if ctype.startswith("image/"): if ctype.startswith("image/"):
r.raise_for_status() r.raise_for_status()
raw = await r.aread() raw = r.content
return build_image_content_blocks(raw, ctype, url, f"(Image fetched from: {url})") return build_image_content_blocks(raw, ctype, url, f"(Image fetched from: {url})")
except Exception as e: except Exception as e:
logger.debug("Pre-fetch image detection failed for {}: {}", url, e) logger.debug("Pre-fetch image detection failed for {}: {}", url, e)
@ -553,19 +588,20 @@ class WebFetchTool(Tool):
try: try:
async with httpx.AsyncClient( async with httpx.AsyncClient(
follow_redirects=True,
max_redirects=MAX_REDIRECTS,
timeout=30.0, timeout=30.0,
proxy=self.proxy, proxy=self.proxy,
) as client: ) as client:
r = await client.get(url, headers={"User-Agent": self.user_agent}) r, redirect_error = await _get_with_safe_redirects(
client,
url,
headers={"User-Agent": self.user_agent},
)
if redirect_error:
return json.dumps({"error": redirect_error, "url": url}, ensure_ascii=False)
if r is None:
return json.dumps({"error": "Fetch failed", "url": url}, ensure_ascii=False)
r.raise_for_status() r.raise_for_status()
from nanobot.security.network import validate_resolved_url
redir_ok, redir_err = validate_resolved_url(str(r.url))
if not redir_ok:
return json.dumps({"error": f"Redirect blocked: {redir_err}", "url": url}, ensure_ascii=False)
ctype = r.headers.get("content-type", "") ctype = r.headers.get("content-type", "")
if ctype.startswith("image/"): if ctype.startswith("image/"):
return build_image_content_blocks(r.content, ctype, url, f"(Image fetched from: {url})") return build_image_content_blocks(r.content, ctype, url, f"(Image fetched from: {url})")

View File

@ -6,11 +6,15 @@ import json
import socket import socket
from unittest.mock import patch from unittest.mock import patch
import httpx
import pytest import pytest
from nanobot.agent.tools import web as web_module
from nanobot.agent.tools.web import WebFetchTool from nanobot.agent.tools.web import WebFetchTool
from nanobot.config.schema import WebFetchConfig from nanobot.config.schema import WebFetchConfig
_REAL_GETADDRINFO = socket.getaddrinfo
def _fake_resolve_private(hostname, port, family=0, type_=0): def _fake_resolve_private(hostname, port, family=0, type_=0):
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("169.254.169.254", 0))] return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("169.254.169.254", 0))]
@ -54,6 +58,7 @@ async def test_web_fetch_result_contains_untrusted_flag():
url = "https://example.com/page" url = "https://example.com/page"
text = fake_html text = fake_html
headers = {"content-type": "text/html"} headers = {"content-type": "text/html"}
is_redirect = False
def raise_for_status(self): pass def raise_for_status(self): pass
def json(self): return {} def json(self): return {}
@ -95,6 +100,7 @@ async def test_web_fetch_can_skip_jina_and_use_custom_user_agent(monkeypatch):
url = "https://example.com/page" url = "https://example.com/page"
text = "<html><head><title>Test</title></head><body><p>Hello world</p></body></html>" text = "<html><head><title>Test</title></head><body><p>Hello world</p></body></html>"
headers = {"content-type": "text/html"} headers = {"content-type": "text/html"}
is_redirect = False
def raise_for_status(self): def raise_for_status(self):
return None return None
@ -113,7 +119,7 @@ async def test_web_fetch_can_skip_jina_and_use_custom_user_agent(monkeypatch):
seen_headers.append(headers or {}) seen_headers.append(headers or {})
return FakeStreamResponse() return FakeStreamResponse()
async def get(self, url, headers=None): async def get(self, url, headers=None, **kwargs):
seen_headers.append(headers or {}) seen_headers.append(headers or {})
return FakeResponse() return FakeResponse()
@ -133,43 +139,83 @@ async def test_web_fetch_can_skip_jina_and_use_custom_user_agent(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_web_fetch_blocks_private_redirect_before_returning_image(monkeypatch): async def test_web_fetch_blocks_private_redirect_before_returning_image(monkeypatch):
tool = WebFetchTool() tool = WebFetchTool(config=WebFetchConfig(use_jina_reader=False))
class FakeStreamResponse: def handler(request: httpx.Request) -> httpx.Response:
headers = {"content-type": "image/png"} if str(request.url) == "https://example.com/image.png":
url = "http://127.0.0.1/secret.png" return httpx.Response(
content = b"\x89PNG\r\n\x1a\n" 302,
headers={"Location": "http://127.0.0.1/secret.png"},
request=request,
)
if str(request.url) == "http://127.0.0.1/secret.png":
return httpx.Response(
200,
headers={"content-type": "image/png"},
content=b"\x89PNG\r\n\x1a\n",
request=request,
)
return httpx.Response(404, request=request)
async def __aenter__(self): transport = httpx.MockTransport(handler)
return self real_async_client = httpx.AsyncClient
async def __aexit__(self, exc_type, exc, tb): class TransportAsyncClient(real_async_client):
return False
async def aread(self):
return self.content
def raise_for_status(self):
return None
class FakeClient:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass kwargs.pop("proxy", None)
super().__init__(*args, transport=transport, **kwargs)
async def __aenter__(self): monkeypatch.setattr("nanobot.agent.tools.web.httpx.AsyncClient", TransportAsyncClient)
return self
async def __aexit__(self, exc_type, exc, tb): def resolve_public_start_only(hostname, port, family=0, type_=0):
return False if hostname == "example.com":
return _fake_resolve_public(hostname, port, family, type_)
return _REAL_GETADDRINFO(hostname, port, family, type_)
def stream(self, method, url, headers=None): with patch("nanobot.security.network.socket.getaddrinfo", resolve_public_start_only):
return FakeStreamResponse()
monkeypatch.setattr("nanobot.agent.tools.web.httpx.AsyncClient", FakeClient)
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_public):
result = await tool.execute(url="https://example.com/image.png") result = await tool.execute(url="https://example.com/image.png")
data = json.loads(result) data = json.loads(result)
assert "error" in data assert "error" in data
assert "redirect blocked" in data["error"].lower() assert "redirect blocked" in data["error"].lower()
@pytest.mark.asyncio
async def test_web_fetch_does_not_request_private_redirect_target(monkeypatch):
tool = WebFetchTool(config=WebFetchConfig(use_jina_reader=False))
requested: list[str] = []
def handler(request: httpx.Request) -> httpx.Response:
requested.append(str(request.url))
if str(request.url) == "https://attacker.example/start":
return httpx.Response(
302,
headers={"Location": "http://127.0.0.1:8765/metadata"},
request=request,
)
if str(request.url) == "http://127.0.0.1:8765/metadata":
return httpx.Response(200, content=b"internal secret", request=request)
return httpx.Response(404, request=request)
transport = httpx.MockTransport(handler)
real_async_client = httpx.AsyncClient
class TransportAsyncClient(real_async_client):
def __init__(self, *args, **kwargs):
kwargs["transport"] = transport
super().__init__(*args, **kwargs)
monkeypatch.setattr(web_module.httpx, "AsyncClient", TransportAsyncClient)
def resolve_public_start_only(hostname, port, family=0, type_=0):
if hostname == "attacker.example":
return _fake_resolve_public(hostname, port, family, type_)
return _REAL_GETADDRINFO(hostname, port, family, type_)
with patch("nanobot.security.network.socket.getaddrinfo", resolve_public_start_only):
result = await tool.execute(url="https://attacker.example/start")
data = json.loads(result)
assert "error" in data
assert "redirect blocked" in data["error"].lower()
assert requested == ["https://attacker.example/start"]