fix(web): keep safe fetch preflight streaming

This commit is contained in:
Xubin Ren 2026-05-22 23:02:15 +08:00
parent 25d00b1ea4
commit 545294c62c
2 changed files with 119 additions and 9 deletions

View File

@ -114,6 +114,46 @@ async def _get_with_safe_redirects(
return None, f"Too many redirects: exceeded limit of {MAX_REDIRECTS}"
async def _stream_with_safe_redirects(
    client: httpx.AsyncClient,
    url: str,
    headers: dict[str, str] | None = None,
) -> tuple[httpx.Response | None, Any | None, str | None]:
    """Open a streamed GET response while validating every redirect target first.

    Redirects are followed manually (``follow_redirects=False``) so that each
    URL is checked with ``_validate_url_safe`` before a request is sent to it.

    Args:
        client: The HTTP client used to open the stream.
        url: The initial URL to request.
        headers: Optional request headers sent with every hop.

    Returns:
        A ``(response, stream, error)`` triple.

        * On success: the response, the still-open stream context manager, and
          ``None``. The caller owns the stream and must release it with
          ``await stream.__aexit__(None, None, None)``.
        * On failure (blocked URL or too many redirects): ``(None, None, msg)``.
          Any stream opened along the way has already been closed.
    """
    current_url = url
    for _ in range(MAX_REDIRECTS + 1):
        is_valid, error_msg = _validate_url_safe(current_url)
        if not is_valid:
            return None, None, f"Redirect blocked: {error_msg}"

        stream = client.stream(
            "GET",
            current_url,
            headers=headers,
            follow_redirects=False,
        )
        response = await stream.__aenter__()

        # Ownership of the stream passes to the caller only when it is returned.
        # On every other exit from this iteration -- including an exception
        # raised while parsing or validating the Location header -- the
        # ``finally`` clause releases it so the connection is not leaked.
        hand_off = False
        try:
            location = None
            if 300 <= response.status_code < 400:
                location = response.headers.get("location")
            if not location:
                # Final response, or a 3xx without a Location header: there is
                # nothing to follow, so give the open stream to the caller.
                hand_off = True
                return response, stream, None

            next_url = urljoin(str(response.url), location)
            is_valid, error_msg = _validate_url_safe(next_url)
            if not is_valid:
                return None, None, f"Redirect blocked: {error_msg}"
        finally:
            if not hand_off:
                await stream.__aexit__(None, None, None)

        current_url = next_url

    return None, None, f"Too many redirects: exceeded limit of {MAX_REDIRECTS}"
def _format_results(query: str, items: list[dict[str, Any]], n: int) -> str:
"""Format provider results into shared plaintext output."""
if not items:
@ -522,7 +562,7 @@ class WebFetchTool(Tool):
# Detect and fetch images directly to avoid Jina's textual image captioning
try:
async with httpx.AsyncClient(proxy=self.proxy, timeout=15.0) as client:
r, redirect_error = await _get_with_safe_redirects(
r, stream, redirect_error = await _stream_with_safe_redirects(
client,
url,
headers={"User-Agent": self.user_agent},
@ -532,11 +572,15 @@ class WebFetchTool(Tool):
if r is None:
return json.dumps({"error": "Fetch failed", "url": url}, ensure_ascii=False)
ctype = r.headers.get("content-type", "")
if ctype.startswith("image/"):
r.raise_for_status()
raw = r.content
return build_image_content_blocks(raw, ctype, url, f"(Image fetched from: {url})")
try:
ctype = r.headers.get("content-type", "")
if ctype.startswith("image/"):
r.raise_for_status()
raw = await r.aread()
return build_image_content_blocks(raw, ctype, url, f"(Image fetched from: {url})")
finally:
if stream is not None:
await stream.__aexit__(None, None, None)
except Exception as e:
logger.debug("Pre-fetch image detection failed for {}: {}", url, e)
@ -585,8 +629,6 @@ class WebFetchTool(Tool):
async def _fetch_readability(self, url: str, extract_mode: str, max_chars: int) -> Any:
"""Local fallback using readability-lxml."""
from readability import Document
try:
async with httpx.AsyncClient(
timeout=30.0,
@ -610,6 +652,8 @@ class WebFetchTool(Tool):
if "application/json" in ctype:
text, extractor = json.dumps(r.json(), indent=2, ensure_ascii=False), "json"
elif "text/html" in ctype or r.text[:256].lower().startswith(("<!doctype", "<html")):
from readability import Document
doc = Document(r.text)
content = self._to_markdown(doc.summary()) if extract_mode == "markdown" else _strip_tags(doc.summary())
text = f"# {doc.title()}\n\n{content}" if doc.title() else content

View File

@ -86,6 +86,7 @@ async def test_web_fetch_can_skip_jina_and_use_custom_user_agent(monkeypatch):
raise AssertionError("Jina Reader should be skipped when disabled")
class FakeStreamResponse:
status_code = 200
headers = {"content-type": "text/html"}
url = "https://example.com/page"
@ -95,6 +96,9 @@ async def test_web_fetch_can_skip_jina_and_use_custom_user_agent(monkeypatch):
async def __aexit__(self, exc_type, exc, tb):
return False
async def aread(self):
raise AssertionError("non-image prefetch body should not be read")
class FakeResponse:
status_code = 200
url = "https://example.com/page"
@ -115,7 +119,7 @@ async def test_web_fetch_can_skip_jina_and_use_custom_user_agent(monkeypatch):
async def __aexit__(self, exc_type, exc, tb):
return False
def stream(self, method, url, headers=None):
def stream(self, method, url, headers=None, **kwargs):
seen_headers.append(headers or {})
return FakeStreamResponse()
@ -137,6 +141,68 @@ async def test_web_fetch_can_skip_jina_and_use_custom_user_agent(monkeypatch):
]
@pytest.mark.asyncio
async def test_web_fetch_blocks_private_redirect_before_readability_request(monkeypatch):
    """A 302 to a loopback address returned by ``get`` must be refused, not requested."""
    start_url = "https://attacker.example/start"
    private_url = "http://127.0.0.1:8765/metadata"

    tool = WebFetchTool(config=WebFetchConfig(use_jina_reader=False))
    get_calls: list[str] = []

    class FakeStreamResponse:
        """Acts as both the stream context manager and the response it yields."""

        status_code = 200
        headers = {"content-type": "text/html"}
        url = start_url

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return False

        async def aread(self):
            # The streamed response is text/html, so its body must stay unread.
            raise AssertionError("non-image prefetch body should not be read")

    class FakeRedirectResponse:
        """A 302 response whose Location header points at a loopback address."""

        status_code = 302
        headers = {"location": private_url}
        url = start_url

        async def aclose(self):
            return None

    class FakeClient:
        """Replacement for ``httpx.AsyncClient`` that records every ``get`` URL."""

        def __init__(self, *args, **kwargs):
            return None

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return False

        def stream(self, method, url, headers=None, **kwargs):
            return FakeStreamResponse()

        async def get(self, url, headers=None, **kwargs):
            get_calls.append(url)
            if url == private_url:
                raise AssertionError("private redirect target should not be requested")
            return FakeRedirectResponse()

    monkeypatch.setattr(web_module.httpx, "AsyncClient", FakeClient)

    def resolve_public_start_only(hostname, port, family=0, type_=0):
        # Only the attacker's hostname gets the fake public answer; every other
        # host goes through the real resolver.
        resolver = _fake_resolve_public if hostname == "attacker.example" else _REAL_GETADDRINFO
        return resolver(hostname, port, family, type_)

    with patch("nanobot.security.network.socket.getaddrinfo", resolve_public_start_only):
        result = await tool.execute(url=start_url)

    payload = json.loads(result)
    assert "error" in payload
    assert "redirect blocked" in payload["error"].lower()
    assert get_calls == [start_url]
@pytest.mark.asyncio
async def test_web_fetch_blocks_private_redirect_before_returning_image(monkeypatch):
tool = WebFetchTool(config=WebFetchConfig(use_jina_reader=False))