mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-07 02:05:51 +00:00
fix(agent): soften SSRF guard recovery
Keep private URL access blocked at the tool boundary, but return a clear non-retryable hint so the agent can recover conversationally instead of aborting the turn.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
d97e177981
commit
db14685a69
@ -776,8 +776,6 @@ class AgentRunner:
|
|||||||
handled = self._classify_violation(
|
handled = self._classify_violation(
|
||||||
raw_text=prep_error,
|
raw_text=prep_error,
|
||||||
soft_payload=prep_error + hint,
|
soft_payload=prep_error + hint,
|
||||||
ssrf_payload=prep_error,
|
|
||||||
ssrf_error=RuntimeError(prep_error),
|
|
||||||
event=event,
|
event=event,
|
||||||
tool_call=tool_call,
|
tool_call=tool_call,
|
||||||
workspace_violation_counts=workspace_violation_counts,
|
workspace_violation_counts=workspace_violation_counts,
|
||||||
@ -808,8 +806,6 @@ class AgentRunner:
|
|||||||
raw_text=str(exc),
|
raw_text=str(exc),
|
||||||
# Preserve legacy exception payloads without the retry hint.
|
# Preserve legacy exception payloads without the retry hint.
|
||||||
soft_payload=payload,
|
soft_payload=payload,
|
||||||
ssrf_payload=payload,
|
|
||||||
ssrf_error=exc,
|
|
||||||
event=event,
|
event=event,
|
||||||
tool_call=tool_call,
|
tool_call=tool_call,
|
||||||
workspace_violation_counts=workspace_violation_counts,
|
workspace_violation_counts=workspace_violation_counts,
|
||||||
@ -829,8 +825,6 @@ class AgentRunner:
|
|||||||
handled = self._classify_violation(
|
handled = self._classify_violation(
|
||||||
raw_text=result,
|
raw_text=result,
|
||||||
soft_payload=result + hint,
|
soft_payload=result + hint,
|
||||||
ssrf_payload=result,
|
|
||||||
ssrf_error=RuntimeError(result),
|
|
||||||
event=event,
|
event=event,
|
||||||
tool_call=tool_call,
|
tool_call=tool_call,
|
||||||
workspace_violation_counts=workspace_violation_counts,
|
workspace_violation_counts=workspace_violation_counts,
|
||||||
@ -849,8 +843,21 @@ class AgentRunner:
|
|||||||
detail = detail[:120] + "..."
|
detail = detail[:120] + "..."
|
||||||
return result, {"name": tool_call.name, "status": "ok", "detail": detail}, None
|
return result, {"name": tool_call.name, "status": "ok", "detail": detail}, None
|
||||||
|
|
||||||
# SSRF remains fatal; workspace path boundaries are soft + throttled.
|
# SSRF is a hard security block at the tool boundary, but the agent turn
|
||||||
_SSRF_MARKER: str = "internal/private url detected"
|
# should recover conversationally instead of aborting the runtime.
|
||||||
|
_SSRF_MARKERS: tuple[str, ...] = (
|
||||||
|
"internal/private url detected",
|
||||||
|
"private/internal address",
|
||||||
|
"private address",
|
||||||
|
)
|
||||||
|
_SSRF_BOUNDARY_NOTE: str = (
|
||||||
|
"This is a non-bypassable security boundary. Stop trying to access "
|
||||||
|
"private/internal URLs. Do not retry with curl, wget, encoded IPs, "
|
||||||
|
"alternate DNS, redirects, proxies, or another tool. Ask the user for "
|
||||||
|
"local files, logs, screenshots, or an explicit safe public URL instead. "
|
||||||
|
"If the user explicitly trusts this private URL, ask them to whitelist "
|
||||||
|
"the exact IP/CIDR via tools.ssrfWhitelist."
|
||||||
|
)
|
||||||
|
|
||||||
# Non-SSRF boundary markers returned to the LLM as recoverable tool errors.
|
# Non-SSRF boundary markers returned to the LLM as recoverable tool errors.
|
||||||
_WORKSPACE_VIOLATION_MARKERS: tuple[str, ...] = (
|
_WORKSPACE_VIOLATION_MARKERS: tuple[str, ...] = (
|
||||||
@ -864,7 +871,10 @@ class AgentRunner:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _is_ssrf_violation(cls, text: str) -> bool:
|
def _is_ssrf_violation(cls, text: str) -> bool:
|
||||||
return bool(text) and cls._SSRF_MARKER in text.lower()
|
if not text:
|
||||||
|
return False
|
||||||
|
lowered = text.lower()
|
||||||
|
return any(marker in lowered for marker in cls._SSRF_MARKERS)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _is_workspace_violation(cls, text: str) -> bool:
|
def _is_workspace_violation(cls, text: str) -> bool:
|
||||||
@ -872,7 +882,7 @@ class AgentRunner:
|
|||||||
if not text:
|
if not text:
|
||||||
return False
|
return False
|
||||||
lowered = text.lower()
|
lowered = text.lower()
|
||||||
if cls._SSRF_MARKER in lowered:
|
if cls._is_ssrf_violation(lowered):
|
||||||
return True
|
return True
|
||||||
return any(marker in lowered for marker in cls._WORKSPACE_VIOLATION_MARKERS)
|
return any(marker in lowered for marker in cls._WORKSPACE_VIOLATION_MARKERS)
|
||||||
|
|
||||||
@ -881,8 +891,6 @@ class AgentRunner:
|
|||||||
*,
|
*,
|
||||||
raw_text: str,
|
raw_text: str,
|
||||||
soft_payload: str,
|
soft_payload: str,
|
||||||
ssrf_payload: str,
|
|
||||||
ssrf_error: BaseException,
|
|
||||||
event: dict[str, str],
|
event: dict[str, str],
|
||||||
tool_call: ToolCallRequest,
|
tool_call: ToolCallRequest,
|
||||||
workspace_violation_counts: dict[str, int],
|
workspace_violation_counts: dict[str, int],
|
||||||
@ -890,12 +898,12 @@ class AgentRunner:
|
|||||||
"""Classify safety-boundary failures, or return ``None`` to pass through."""
|
"""Classify safety-boundary failures, or return ``None`` to pass through."""
|
||||||
if self._is_ssrf_violation(raw_text):
|
if self._is_ssrf_violation(raw_text):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Tool {} blocked by SSRF guard; aborting turn: {}",
|
"Tool {} blocked by SSRF guard; returning non-retryable tool error: {}",
|
||||||
tool_call.name,
|
tool_call.name,
|
||||||
raw_text.replace("\n", " ").strip()[:200],
|
raw_text.replace("\n", " ").strip()[:200],
|
||||||
)
|
)
|
||||||
event["detail"] = self._event_detail("workspace_violation: ", raw_text)
|
event["detail"] = self._event_detail("ssrf_violation: ", raw_text)
|
||||||
return ssrf_payload, event, ssrf_error
|
return self._ssrf_soft_payload(raw_text), event, None
|
||||||
|
|
||||||
if self._is_workspace_violation(raw_text):
|
if self._is_workspace_violation(raw_text):
|
||||||
escalation = repeated_workspace_violation_error(
|
escalation = repeated_workspace_violation_error(
|
||||||
@ -918,6 +926,11 @@ class AgentRunner:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _ssrf_soft_payload(cls, raw_text: str) -> str:
|
||||||
|
text = raw_text.strip() or "Error: request blocked by SSRF guard"
|
||||||
|
return f"{text}\n\n{cls._SSRF_BOUNDARY_NOTE}"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _event_detail(prefix: str, text: str, limit: int = 160) -> str:
|
def _event_detail(prefix: str, text: str, limit: int = 160) -> str:
|
||||||
return (prefix + text.replace("\n", " ").strip())[:limit]
|
return (prefix + text.replace("\n", " ").strip())[:limit]
|
||||||
|
|||||||
@ -321,7 +321,7 @@ class ExecTool(Tool):
|
|||||||
|
|
||||||
from nanobot.security.network import contains_internal_url
|
from nanobot.security.network import contains_internal_url
|
||||||
if contains_internal_url(cmd):
|
if contains_internal_url(cmd):
|
||||||
# SSRF stays fatal in the runner, so keep this marker direct.
|
# The runner turns this marker into a non-retryable security hint.
|
||||||
return "Error: Command blocked by safety guard (internal/private URL detected)"
|
return "Error: Command blocked by safety guard (internal/private URL detected)"
|
||||||
|
|
||||||
if self.restrict_to_workspace:
|
if self.restrict_to_workspace:
|
||||||
|
|||||||
@ -364,17 +364,15 @@ async def test_runner_does_not_abort_on_workspace_violation_anymore():
|
|||||||
assert "workspace_violation" in result.tool_events[0]["detail"]
|
assert "workspace_violation" in result.tool_events[0]["detail"]
|
||||||
|
|
||||||
|
|
||||||
def test_is_ssrf_violation_remains_fatal():
|
def test_is_ssrf_violation_recognizes_private_url_blocks():
|
||||||
"""SSRF rejections are the only marker that stays turn-fatal.
|
"""SSRF rejections are classified separately from workspace boundaries."""
|
||||||
|
|
||||||
A single successful internal-URL fetch can leak cloud metadata, so we
|
|
||||||
never let the LLM "retry" with a different URL phrasing -- contrast
|
|
||||||
this with workspace-bound rejections which are soft + throttled in v2.
|
|
||||||
"""
|
|
||||||
from nanobot.agent.runner import AgentRunner
|
from nanobot.agent.runner import AgentRunner
|
||||||
|
|
||||||
ssrf_msg = "Error: Command blocked by safety guard (internal/private URL detected)"
|
ssrf_msg = "Error: Command blocked by safety guard (internal/private URL detected)"
|
||||||
assert AgentRunner._is_ssrf_violation(ssrf_msg) is True
|
assert AgentRunner._is_ssrf_violation(ssrf_msg) is True
|
||||||
|
assert AgentRunner._is_ssrf_violation(
|
||||||
|
"URL validation failed: Blocked: host resolves to private/internal address 192.168.1.2"
|
||||||
|
) is True
|
||||||
|
|
||||||
# Workspace-bound markers are NOT classified as SSRF.
|
# Workspace-bound markers are NOT classified as SSRF.
|
||||||
assert AgentRunner._is_ssrf_violation(
|
assert AgentRunner._is_ssrf_violation(
|
||||||
@ -390,8 +388,8 @@ def test_is_ssrf_violation_remains_fatal():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_runner_aborts_on_ssrf_violation():
|
async def test_runner_returns_non_retryable_hint_on_ssrf_violation():
|
||||||
"""SSRF still fatal-aborts the turn even though workspace ones are soft."""
|
"""SSRF stays blocked, but the runtime gives the LLM a final chance to recover."""
|
||||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||||
|
|
||||||
provider = MagicMock()
|
provider = MagicMock()
|
||||||
@ -404,7 +402,10 @@ async def test_runner_aborts_on_ssrf_violation():
|
|||||||
arguments={"command": "curl http://169.254.169.254"},
|
arguments={"command": "curl http://169.254.169.254"},
|
||||||
)],
|
)],
|
||||||
),
|
),
|
||||||
LLMResponse(content="should NOT be reached", tool_calls=[]),
|
LLMResponse(
|
||||||
|
content="I cannot access that private URL. Please share local files.",
|
||||||
|
tool_calls=[],
|
||||||
|
),
|
||||||
])
|
])
|
||||||
tools = MagicMock()
|
tools = MagicMock()
|
||||||
tools.get_definitions.return_value = []
|
tools.get_definitions.return_value = []
|
||||||
@ -421,9 +422,16 @@ async def test_runner_aborts_on_ssrf_violation():
|
|||||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
))
|
))
|
||||||
|
|
||||||
assert provider.chat_with_retry.await_count == 1, "SSRF must abort immediately"
|
assert provider.chat_with_retry.await_count == 2
|
||||||
assert result.stop_reason == "tool_error"
|
assert result.stop_reason == "completed"
|
||||||
assert "internal/private url detected" in (result.error or "").lower()
|
assert result.error is None
|
||||||
|
assert result.final_content == "I cannot access that private URL. Please share local files."
|
||||||
|
assert result.tool_events and result.tool_events[0]["detail"].startswith("ssrf_violation:")
|
||||||
|
tool_messages = [m for m in result.messages if m.get("role") == "tool"]
|
||||||
|
assert tool_messages
|
||||||
|
assert "non-bypassable security boundary" in tool_messages[0]["content"]
|
||||||
|
assert "Do not retry" in tool_messages[0]["content"]
|
||||||
|
assert "tools.ssrfWhitelist" in tool_messages[0]["content"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -1290,7 +1298,7 @@ async def test_streamed_flag_not_set_on_llm_error(tmp_path):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_streamed_flag_not_set_on_tool_error(tmp_path):
|
async def test_ssrf_soft_block_can_finalize_after_streamed_tool_call(tmp_path):
|
||||||
from nanobot.agent.loop import AgentLoop
|
from nanobot.agent.loop import AgentLoop
|
||||||
from nanobot.bus.events import InboundMessage
|
from nanobot.bus.events import InboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
@ -1307,7 +1315,14 @@ async def test_streamed_flag_not_set_on_tool_error(tmp_path):
|
|||||||
)],
|
)],
|
||||||
usage={},
|
usage={},
|
||||||
)
|
)
|
||||||
provider.chat_stream_with_retry = AsyncMock(return_value=tool_call_resp)
|
provider.chat_stream_with_retry = AsyncMock(side_effect=[
|
||||||
|
tool_call_resp,
|
||||||
|
LLMResponse(
|
||||||
|
content="I cannot access private URLs. Please share the local file.",
|
||||||
|
tool_calls=[],
|
||||||
|
usage={},
|
||||||
|
),
|
||||||
|
])
|
||||||
|
|
||||||
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
|
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
@ -1323,9 +1338,8 @@ async def test_streamed_flag_not_set_on_tool_error(tmp_path):
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert "internal/private URL detected" in result.content
|
assert result.content == "I cannot access private URLs. Please share the local file."
|
||||||
assert not result.metadata.get("_streamed"), \
|
assert result.metadata.get("_streamed") is True
|
||||||
"_streamed must not be set when stop_reason is tool_error"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user