Improve tool call validation strictness (#4190)

* Improve tool call validation strictness

Reject near-miss tool names without executing suggested tools. Require object-shaped tool parameters while preserving only lossless JSON wire-shape normalization.

* Tighten tool call argument validation

* Simplify tool argument validation tests

* Improve tool name suggestions

* Simplify tool suggestion helpers

* Limit tool suggestions to canonical matches

* Allow repair only for tool history replay

* Clarify non-object tool argument errors

* Inline replay tool argument normalization

* Track only successful tool executions

* Reject JSON null tool arguments
This commit is contained in:
chengyongru 2026-06-09 14:50:40 +08:00 committed by GitHub
parent f3eb2aa08b
commit 0a396aa6e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 769 additions and 142 deletions

View File

@ -399,7 +399,6 @@ class AgentRunner:
thinking_blocks=response.thinking_blocks,
)
messages.append(assistant_message)
tools_used.extend(tc.name for tc in response.tool_calls)
await self._emit_checkpoint(
spec,
{
@ -421,6 +420,11 @@ class AgentRunner:
workspace_violation_counts,
)
tool_events.extend(new_events)
tools_used.extend(
tool_call.name
for tool_call, event in zip(response.tool_calls, new_events)
if event.get("status") == "ok"
)
context.tool_results = list(results)
context.tool_events = list(new_events)
completed_tool_results: list[dict[str, Any]] = []

View File

@ -1,5 +1,6 @@
"""Tool registry for dynamic tool management."""
import json
from typing import Any
from nanobot.agent.tools.base import Tool
@ -30,6 +31,24 @@ class ToolRegistry:
"""Get a tool by name."""
return self._tools.get(name)
@staticmethod
def _lookup_key(name: str) -> str:
"""Normalize names for suggestions only; never for execution."""
return "".join(ch.lower() for ch in name if ch.isalnum())
def _suggest_name(self, name: str) -> str | None:
key = self._lookup_key(str(name or ""))
if not key:
return None
matches = [
registered
for registered in self._tools
if self._lookup_key(registered) == key
]
if len(matches) == 1:
return matches[0]
return None
def has(self, name: str) -> bool:
"""Check if a tool is registered."""
return name in self._tools
@ -73,20 +92,23 @@ class ToolRegistry:
def prepare_call(
self,
name: str,
params: dict[str, Any],
) -> tuple[Tool | None, dict[str, Any], str | None]:
params: Any,
) -> tuple[Tool | None, Any, str | None]:
"""Resolve, cast, and validate one tool call."""
# Guard against invalid parameter types (e.g., list instead of dict)
if not isinstance(params, dict) and name in ('write_file', 'read_file'):
return None, params, (
f"Error: Tool '{name}' parameters must be a JSON object, got {type(params).__name__}. "
"Use named parameters: tool_name(param1=\"value1\", param2=\"value2\")"
)
tool = self._tools.get(name)
if not tool:
suggestion = self._suggest_name(str(name))
hint = f" Did you mean '{suggestion}'? Tool names must match exactly." if suggestion else ""
return None, params, (
f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
f"Error: Tool '{name}' not found.{hint} Available: {', '.join(self.tool_names)}"
)
params = self._coerce_params(tool, params)
if not isinstance(params, dict):
return tool, params, (
f"Error: Tool '{name}' parameters must be a JSON object, got "
f"{type(params).__name__}. Use named parameters like "
'tool_name(param1="value1", param2="value2") matching the tool schema.'
)
cast_params = tool.cast_params(params)
@ -97,21 +119,56 @@ class ToolRegistry:
)
return tool, cast_params, None
async def execute(self, name: str, params: dict[str, Any]) -> Any:
@classmethod
def _coerce_argument_value(cls, value: Any) -> Any:
if value is None:
return {}
if not isinstance(value, str):
return value
stripped = value.strip()
if not stripped:
return {}
if not stripped.startswith(("{", "[")):
return value
try:
parsed = json.loads(stripped)
except Exception:
return value
return parsed
@classmethod
def _coerce_params(cls, tool: Tool, params: Any) -> Any:
params = cls._coerce_argument_value(params)
return cls._unwrap_arguments_payload(tool, params)
@classmethod
def _unwrap_arguments_payload(cls, tool: Tool, params: Any) -> Any:
if not isinstance(params, dict) or set(params) != {"arguments"}:
return params
properties = (tool.parameters or {}).get("properties", {})
if isinstance(properties, dict) and "arguments" in properties:
return params
return cls._coerce_argument_value(params.get("arguments"))
async def execute(self, name: str, params: Any) -> Any:
"""Execute a tool by name with given parameters."""
_HINT = "\n\n[Analyze the error above and try a different approach.]"
hint = "\n\n[Analyze the error above and try a different approach.]"
tool, params, error = self.prepare_call(name, params)
if error:
return error + _HINT
return error + hint
try:
assert tool is not None # guarded by prepare_call()
result = await tool.execute(**params)
if isinstance(result, str) and result.startswith("Error"):
return result + _HINT
return result + hint
return result
except Exception as e:
return f"Error executing {name}: {str(e)}" + _HINT
return f"Error executing {name}: {str(e)}" + hint
@property
def tool_names(self) -> list[str]:

View File

@ -10,9 +10,12 @@ import string
from collections.abc import Awaitable, Callable
from typing import Any
import json_repair
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.base import (
LLMProvider,
LLMResponse,
ToolCallRequest,
tool_arguments_object_for_replay,
)
_ALNUM = string.ascii_letters + string.digits
@ -207,13 +210,11 @@ class AnthropicProvider(LLMProvider):
continue
func = tc.get("function", {})
args = func.get("arguments", "{}")
if isinstance(args, str):
args = json_repair.loads(args)
blocks.append({
"type": "tool_use",
"id": tc.get("id") or _gen_tool_id(),
"name": func.get("name", ""),
"input": args,
"input": tool_arguments_object_for_replay(args),
})
return blocks or [{"type": "text", "text": ""}]
@ -509,7 +510,7 @@ class AnthropicProvider(LLMProvider):
tool_calls.append(ToolCallRequest(
id=block.id,
name=block.name,
arguments=block.input if isinstance(block.input, dict) else {},
arguments=block.input,
))
elif block.type == "thinking":
thinking_blocks.append({

View File

@ -11,6 +11,7 @@ from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from typing import Any
import json_repair
from loguru import logger
from nanobot.utils.helpers import image_placeholder_text
@ -21,19 +22,24 @@ class ToolCallRequest:
"""A tool call request from the LLM."""
id: str
name: str
arguments: dict[str, Any]
arguments: Any
extra_content: dict[str, Any] | None = None
provider_specific_fields: dict[str, Any] | None = None
function_provider_specific_fields: dict[str, Any] | None = None
def to_openai_tool_call(self) -> dict[str, Any]:
"""Serialize to an OpenAI-style tool_call payload."""
arguments = (
self.arguments
if isinstance(self.arguments, str)
else json.dumps(self.arguments, ensure_ascii=False)
)
tool_call = {
"id": self.id,
"type": "function",
"function": {
"name": self.name,
"arguments": json.dumps(self.arguments, ensure_ascii=False),
"arguments": arguments,
},
}
if self.extra_content:
@ -45,6 +51,62 @@ class ToolCallRequest:
return tool_call
def parse_tool_arguments(arguments: Any) -> Any:
    """Decode provider tool arguments without inventing executable parameters.

    ``None`` and blank strings denote a no-argument call and yield ``{}``.
    Non-string values pass through untouched. A string holding valid JSON is
    returned decoded, except JSON ``null``, which is handed back as the
    original string. Strings that fail to decode are likewise returned
    unchanged so ToolRegistry can reject them before execution.
    """
    if arguments is None:
        return {}
    if isinstance(arguments, str):
        text = arguments.strip()
        if not text:
            return {}
        try:
            decoded = json.loads(text)
        except Exception:
            return arguments
        if decoded is not None:
            return decoded
    return arguments
def tool_arguments_object_for_replay(arguments: Any) -> dict[str, Any]:
    """Shape stored tool arguments into a dict for provider history replay.

    This is the only path allowed to repair malformed JSON, because it merely
    reshapes conversation history that already exists. It must not be used
    for freshly generated tool calls that are about to execute.

    Dicts are returned as-is. Strings are decoded with ``json`` first and
    ``json_repair`` as a fallback. Anything that does not end up as a dict
    (``None``, other types, blank text, non-object JSON) becomes ``{}``.
    """
    if isinstance(arguments, dict):
        return arguments
    if not isinstance(arguments, str):
        # None and every other non-string shape replay as an empty object.
        return {}
    text = arguments.strip()
    if not text:
        return {}
    try:
        decoded = json.loads(text)
    except Exception:
        try:
            decoded = json_repair.loads(text)
        except Exception:
            return {}
    if isinstance(decoded, dict):
        return decoded
    return {}
def tool_arguments_json_for_replay(arguments: Any) -> str:
    """Serialize replay-shaped tool arguments as a JSON object string.

    For provider history replay only; non-ASCII characters are kept verbatim.
    """
    replay_object = tool_arguments_object_for_replay(arguments)
    return json.dumps(replay_object, ensure_ascii=False)
@dataclass
class LLMResponse:
"""Response from an LLM provider."""

View File

@ -10,9 +10,13 @@ import re
from collections.abc import Awaitable, Callable, Iterator
from typing import Any
import json_repair
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.base import (
LLMProvider,
LLMResponse,
ToolCallRequest,
parse_tool_arguments,
tool_arguments_object_for_replay,
)
_IMAGE_DATA_URL = re.compile(r"^data:image/([a-zA-Z0-9.+-]+);base64,(.*)$", re.DOTALL)
_TEXT_BLOCK_TYPES = {"text", "input_text", "output_text"}
@ -176,14 +180,7 @@ class BedrockProvider(LLMProvider):
function = tool_call.get("function")
if not isinstance(function, dict):
return None
args = function.get("arguments", {})
if isinstance(args, str):
try:
args = json_repair.loads(args) if args.strip() else {}
except Exception:
args = {}
if not isinstance(args, dict):
args = {}
args = tool_arguments_object_for_replay(function.get("arguments", {}))
return {
"toolUse": {
"toolUseId": str(tool_call.get("id") or ""),
@ -491,7 +488,7 @@ class BedrockProvider(LLMProvider):
content_parts.append(block["text"])
tool_use = block.get("toolUse")
if isinstance(tool_use, dict):
arguments = tool_use.get("input") if isinstance(tool_use.get("input"), dict) else {}
arguments = tool_use.get("input", {})
tool_calls.append(ToolCallRequest(
id=str(tool_use.get("toolUseId") or ""),
name=str(tool_use.get("name") or ""),
@ -616,14 +613,11 @@ class BedrockProvider(LLMProvider):
for buf in tool_buffers.values():
args: Any = {}
if buf.get("input"):
try:
args = json_repair.loads(buf["input"])
except Exception:
args = {}
args = parse_tool_arguments(buf["input"])
tool_calls.append(ToolCallRequest(
id=buf.get("id") or "",
name=buf.get("name") or "",
arguments=args if isinstance(args, dict) else {},
arguments=args,
))
return LLMResponse(
content="".join(content_parts) or None,

View File

@ -17,10 +17,15 @@ from ipaddress import ip_address
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
import json_repair
from loguru import logger
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.base import (
LLMProvider,
LLMResponse,
ToolCallRequest,
parse_tool_arguments,
tool_arguments_json_for_replay,
)
from nanobot.providers.openai_responses import (
consume_sdk_stream,
convert_messages,
@ -478,24 +483,6 @@ class OpenAICompatProvider(LLMProvider):
"""Return True for providers that reject normal OpenAI tool call IDs."""
return bool(self._spec and self._spec.name == "mistral")
@staticmethod
def _normalize_tool_call_arguments(arguments: Any) -> str:
"""Force function.arguments into a valid JSON object string."""
if isinstance(arguments, str):
stripped = arguments.strip()
if not stripped:
return "{}"
try:
parsed = json_repair.loads(stripped)
except Exception:
return "{}"
if isinstance(parsed, dict):
return json.dumps(parsed, ensure_ascii=False)
return "{}"
if isinstance(arguments, dict):
return json.dumps(arguments, ensure_ascii=False)
return "{}"
@staticmethod
def _coerce_content_to_string(content: Any) -> str | None:
"""Coerce block/list content into plain text for strict string-only APIs."""
@ -572,7 +559,7 @@ class OpenAICompatProvider(LLMProvider):
if isinstance(function, dict):
function_clean = dict(function)
if "arguments" in function_clean:
function_clean["arguments"] = self._normalize_tool_call_arguments(
function_clean["arguments"] = tool_arguments_json_for_replay(
function_clean.get("arguments")
)
else:
@ -1021,14 +1008,12 @@ class OpenAICompatProvider(LLMProvider):
for tc in raw_tool_calls:
tc_map = self._maybe_mapping(tc) or {}
fn = self._maybe_mapping(tc_map.get("function")) or {}
args = fn.get("arguments", {})
if isinstance(args, str):
args = json_repair.loads(args)
args = parse_tool_arguments(fn.get("arguments", {}))
ec, prov, fn_prov = _extract_tc_extras(tc)
parsed_tool_calls.append(ToolCallRequest(
id=str(tc_map.get("id") or _short_tool_id()),
name=str(fn.get("name") or ""),
arguments=args if isinstance(args, dict) else {},
arguments=args,
extra_content=ec,
provider_specific_fields=prov,
function_provider_specific_fields=fn_prov,
@ -1064,9 +1049,7 @@ class OpenAICompatProvider(LLMProvider):
tool_calls = []
for tc in raw_tool_calls:
args = tc.function.arguments
if isinstance(args, str):
args = json_repair.loads(args)
args = parse_tool_arguments(tc.function.arguments)
ec, prov, fn_prov = _extract_tc_extras(tc)
tool_calls.append(ToolCallRequest(
id=str(getattr(tc, "id", None) or _short_tool_id()),
@ -1207,7 +1190,7 @@ class OpenAICompatProvider(LLMProvider):
ToolCallRequest(
id=b["id"] or _short_tool_id(),
name=b["name"],
arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {},
arguments=parse_tool_arguments(b["arguments"]),
extra_content=b.get("extra_content"),
provider_specific_fields=b.get("prov"),
function_provider_specific_fields=b.get("fn_prov"),

View File

@ -5,6 +5,8 @@ from __future__ import annotations
import json
from typing import Any
from nanobot.providers.base import tool_arguments_json_for_replay
def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
"""Convert Chat Completions messages to Responses API input items.
@ -46,7 +48,7 @@ def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str
"id": response_item_id,
"call_id": call_id or f"call_{idx}",
"name": fn.get("name"),
"arguments": fn.get("arguments") or "{}",
"arguments": tool_arguments_json_for_replay(fn.get("arguments")),
})
continue

View File

@ -7,10 +7,9 @@ from collections.abc import Awaitable, Callable
from typing import Any, AsyncGenerator
import httpx
import json_repair
from loguru import logger
from nanobot.providers.base import LLMResponse, ToolCallRequest
from nanobot.providers.base import LLMResponse, ToolCallRequest, parse_tool_arguments
FINISH_REASON_MAP = {
"completed": "stop",
@ -44,6 +43,27 @@ def _usage_from_response_obj(response: Any) -> dict[str, int]:
}
def _parse_tool_call_arguments(args_raw: Any, name: str | None) -> Any:
    """Parse tool-call arguments and warn when non-blank text came back as-is.

    The value from ``parse_tool_arguments`` is always returned. A warning
    carrying the tool name and the first 200 characters is logged when the
    input was a non-blank string and the parser returned it unchanged.
    """
    parsed = parse_tool_arguments(args_raw)
    returned_unchanged = parsed == args_raw
    if returned_unchanged and isinstance(args_raw, str) and args_raw.strip():
        preview = args_raw[:200]
        logger.warning(
            "Failed to parse tool call arguments for '{}': {}",
            name,
            preview,
        )
    return parsed
def _tool_arguments_source(*values: Any) -> Any:
for value in values:
if value is None:
continue
if isinstance(value, str) and not value.strip():
continue
return value
return "{}"
async def iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
"""Yield parsed JSON events from a Responses API SSE stream."""
buffer: list[str] = []
@ -116,10 +136,11 @@ async def consume_sse_with_reasoning(
call_id = item.get("call_id")
if not call_id:
continue
arguments = item.get("arguments")
tool_call_buffers[call_id] = {
"id": item.get("id") or "fc_0",
"name": item.get("name"),
"arguments": item.get("arguments") or "",
"arguments": "" if arguments is None else arguments,
}
if on_tool_call_delta:
await on_tool_call_delta({
@ -156,7 +177,10 @@ async def consume_sse_with_reasoning(
call_id = event.get("call_id")
if call_id and call_id in tool_call_buffers:
delta = event.get("delta") or ""
tool_call_buffers[call_id]["arguments"] += delta
current = tool_call_buffers[call_id].get("arguments")
if not isinstance(current, str):
current = ""
tool_call_buffers[call_id]["arguments"] = current + delta
if on_tool_call_delta and delta:
await on_tool_call_delta({
"call_id": str(call_id),
@ -166,14 +190,14 @@ async def consume_sse_with_reasoning(
elif event_type == "response.function_call_arguments.done":
call_id = event.get("call_id")
if call_id and call_id in tool_call_buffers:
arguments = event.get("arguments") or ""
arguments = event.get("arguments")
tool_call_buffers[call_id]["arguments"] = arguments
if on_tool_call_delta:
tool_call_args_emitted.add(str(call_id))
await on_tool_call_delta({
"call_id": str(call_id),
"name": str(tool_call_buffers[call_id].get("name") or ""),
"arguments": str(arguments),
"arguments": "" if arguments is None else str(arguments),
})
elif event_type == "response.output_item.done":
item = event.get("item") or {}
@ -182,7 +206,7 @@ async def consume_sse_with_reasoning(
if not call_id:
continue
buf = tool_call_buffers.get(call_id) or {}
args_raw = buf.get("arguments") or item.get("arguments") or "{}"
args_raw = _tool_arguments_source(buf.get("arguments"), item.get("arguments"))
if on_tool_call_delta and str(call_id) not in tool_call_args_emitted:
tool_call_args_emitted.add(str(call_id))
await on_tool_call_delta({
@ -190,17 +214,10 @@ async def consume_sse_with_reasoning(
"name": str(buf.get("name") or item.get("name") or ""),
"arguments": str(args_raw),
})
try:
args = json.loads(args_raw)
except Exception:
logger.warning(
"Failed to parse tool call arguments for '{}': {}",
buf.get("name") or item.get("name"),
args_raw[:200],
)
args = json_repair.loads(args_raw)
if not isinstance(args, dict):
args = {"raw": args_raw}
args = _parse_tool_call_arguments(
args_raw,
buf.get("name") or item.get("name"),
)
tool_calls.append(
ToolCallRequest(
id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
@ -283,22 +300,12 @@ def parse_response_output(response: Any) -> LLMResponse:
elif item_type == "function_call":
call_id = item.get("call_id") or ""
item_id = item.get("id") or "fc_0"
args_raw = item.get("arguments") or "{}"
try:
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
except Exception:
logger.warning(
"Failed to parse tool call arguments for '{}': {}",
item.get("name"),
str(args_raw)[:200],
)
args = json_repair.loads(args_raw) if isinstance(args_raw, str) else args_raw
if not isinstance(args, dict):
args = {"raw": args_raw}
args_raw = _tool_arguments_source(item.get("arguments"))
args = _parse_tool_call_arguments(args_raw, item.get("name"))
tool_calls.append(ToolCallRequest(
id=f"{call_id}|{item_id}",
name=item.get("name") or "",
arguments=args if isinstance(args, dict) else {},
arguments=args,
))
usage = _usage_from_response_obj(response)
@ -337,10 +344,11 @@ async def consume_sdk_stream(
call_id = getattr(item, "call_id", None)
if not call_id:
continue
arguments = getattr(item, "arguments", None)
tool_call_buffers[call_id] = {
"id": getattr(item, "id", None) or "fc_0",
"name": getattr(item, "name", None),
"arguments": getattr(item, "arguments", None) or "",
"arguments": "" if arguments is None else arguments,
}
if on_tool_call_delta:
await on_tool_call_delta({
@ -357,7 +365,10 @@ async def consume_sdk_stream(
call_id = getattr(event, "call_id", None)
if call_id and call_id in tool_call_buffers:
delta = getattr(event, "delta", "") or ""
tool_call_buffers[call_id]["arguments"] += delta
current = tool_call_buffers[call_id].get("arguments")
if not isinstance(current, str):
current = ""
tool_call_buffers[call_id]["arguments"] = current + delta
if on_tool_call_delta and delta:
await on_tool_call_delta({
"call_id": str(call_id),
@ -367,14 +378,14 @@ async def consume_sdk_stream(
elif event_type == "response.function_call_arguments.done":
call_id = getattr(event, "call_id", None)
if call_id and call_id in tool_call_buffers:
arguments = getattr(event, "arguments", "") or ""
arguments = getattr(event, "arguments", None)
tool_call_buffers[call_id]["arguments"] = arguments
if on_tool_call_delta:
tool_call_args_emitted.add(str(call_id))
await on_tool_call_delta({
"call_id": str(call_id),
"name": str(tool_call_buffers[call_id].get("name") or ""),
"arguments": str(arguments),
"arguments": "" if arguments is None else str(arguments),
})
elif event_type == "response.output_item.done":
item = getattr(event, "item", None)
@ -383,7 +394,10 @@ async def consume_sdk_stream(
if not call_id:
continue
buf = tool_call_buffers.get(call_id) or {}
args_raw = buf.get("arguments") or getattr(item, "arguments", None) or "{}"
args_raw = _tool_arguments_source(
buf.get("arguments"),
getattr(item, "arguments", None),
)
if on_tool_call_delta and str(call_id) not in tool_call_args_emitted:
tool_call_args_emitted.add(str(call_id))
await on_tool_call_delta({
@ -391,17 +405,10 @@ async def consume_sdk_stream(
"name": str(buf.get("name") or getattr(item, "name", None) or ""),
"arguments": str(args_raw),
})
try:
args = json.loads(args_raw)
except Exception:
logger.warning(
"Failed to parse tool call arguments for '{}': {}",
buf.get("name") or getattr(item, "name", None),
str(args_raw)[:200],
)
args = json_repair.loads(args_raw)
if not isinstance(args, dict):
args = {"raw": args_raw}
args = _parse_tool_call_arguments(
args_raw,
buf.get("name") or getattr(item, "name", None),
)
tool_calls.append(
ToolCallRequest(
id=f"{call_id}|{buf.get('id') or getattr(item, 'id', None) or 'fc_0'}",

View File

@ -49,13 +49,18 @@ async def invoke_file_edit_progress(
await on_progress("", file_edit_events=file_edit_events)
def _tool_event_arguments(tool_call: Any) -> dict[str, Any]:
arguments = getattr(tool_call, "arguments", {}) or {}
return arguments if isinstance(arguments, dict) else {}
def build_tool_event_start_payload(tool_call: Any) -> dict[str, Any]:
return {
"version": 1,
"phase": "start",
"call_id": str(getattr(tool_call, "id", "") or ""),
"name": getattr(tool_call, "name", ""),
"arguments": getattr(tool_call, "arguments", {}) or {},
"arguments": _tool_event_arguments(tool_call),
"result": None,
"error": None,
"files": [],
@ -86,7 +91,7 @@ def build_tool_event_finish_payloads(context: AgentHookContext) -> list[dict[str
"phase": phase,
"call_id": str(getattr(tool_call, "id", "") or ""),
"name": getattr(tool_call, "name", ""),
"arguments": getattr(tool_call, "arguments", {}) or {},
"arguments": _tool_event_arguments(tool_call),
"result": result if phase == "end" else None,
"error": None,
"files": files,

View File

@ -75,8 +75,10 @@ def build_goal_continue_message(custom: str | None = None) -> dict[str, str]:
return {"role": "user", "content": custom or SUSTAINED_GOAL_CONTINUE_PROMPT}
def external_lookup_signature(tool_name: str, arguments: dict[str, Any]) -> str | None:
def external_lookup_signature(tool_name: str, arguments: Any) -> str | None:
"""Stable signature for repeated external lookups we want to throttle."""
if not isinstance(arguments, dict):
return None
if tool_name == "web_fetch":
url = str(arguments.get("url") or "").strip()
if url:
@ -90,7 +92,7 @@ def external_lookup_signature(tool_name: str, arguments: dict[str, Any]) -> str
def repeated_external_lookup_error(
tool_name: str,
arguments: dict[str, Any],
arguments: Any,
seen_counts: dict[str, int],
) -> str | None:
"""Block repeated external lookups after a small retry budget."""
@ -119,9 +121,11 @@ _OUTSIDE_PATH_PATTERN = re.compile(r"(?:^|[\s|>'\"])((?:/[^\s\"'>;|<]+)|(?:~[^\s
def workspace_violation_signature(
tool_name: str,
arguments: dict[str, Any],
arguments: Any,
) -> str | None:
"""Return a stable cross-tool signature for the outside-workspace target."""
if not isinstance(arguments, dict):
return None
for key in ("path", "file_path", "target", "source", "destination"):
val = arguments.get(key)
if isinstance(val, str) and val.strip():
@ -151,7 +155,7 @@ def _normalize_violation_target(raw: str) -> str:
def repeated_workspace_violation_error(
tool_name: str,
arguments: dict[str, Any],
arguments: Any,
seen_counts: dict[str, int],
) -> str | None:
"""Return an escalated error after repeated bypass attempts."""

View File

@ -3,17 +3,21 @@
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from nanobot.agent.runner import AgentRunner, AgentRunSpec
from nanobot.agent.tools.base import Tool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.config.schema import AgentDefaults
from nanobot.providers.base import LLMResponse, ToolCallRequest
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
from nanobot.providers.openai_responses.parsing import parse_response_output
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
class _DelayTool(Tool):
def __init__(
self,
@ -57,10 +61,45 @@ class _DelayTool(Tool):
return self._name
async def _run_optional_tool_response(response: LLMResponse):
    """Run the agent loop with *response* as the model's first reply.

    Every later model turn answers with plain ``"done"``. One read-only
    ``optional_tool`` is registered. The list handed to that tool as
    ``shared_events`` is returned together with the run result.
    """
    scripted = [response]

    async def chat_with_retry(*, messages, **kwargs):
        # First turn replays the supplied response; afterwards the model is done.
        if scripted:
            return scripted.pop()
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = ToolRegistry()
    shared_events: list[str] = []
    optional_tool = _DelayTool(
        "optional_tool",
        delay=0,
        read_only=True,
        shared_events=shared_events,
    )
    tools.register(optional_tool)

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "try optional"}],
        tools=tools,
        model="test-model",
        max_iterations=2,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await runner.run(spec)
    return result, shared_events
def _tool_message(result, tool_call_id: str) -> dict:
return [
msg for msg in result.messages
if msg.get("role") == "tool" and msg.get("tool_call_id") == tool_call_id
][0]
@pytest.mark.asyncio
async def test_runner_batches_read_only_tools_before_exclusive_work():
from nanobot.agent.runner import AgentRunSpec, AgentRunner
tools = ToolRegistry()
shared_events: list[str] = []
read_a = _DelayTool("read_a", delay=0.05, read_only=True, shared_events=shared_events)
@ -98,8 +137,6 @@ async def test_runner_batches_read_only_tools_before_exclusive_work():
@pytest.mark.asyncio
async def test_runner_does_not_batch_exclusive_read_only_tools():
from nanobot.agent.runner import AgentRunSpec, AgentRunner
tools = ToolRegistry()
shared_events: list[str] = []
read_a = _DelayTool("read_a", delay=0.03, read_only=True, shared_events=shared_events)
@ -140,9 +177,151 @@ async def test_runner_does_not_batch_exclusive_read_only_tools():
@pytest.mark.asyncio
async def test_runner_blocks_repeated_external_fetches():
from nanobot.agent.runner import AgentRunSpec, AgentRunner
async def test_runner_rejects_near_miss_tool_name_without_executing():
    """A near-miss tool name is rejected with a hint; the real tool never runs.

    The model asks for ``readFile`` while only ``read_file`` is registered.
    The run must finish without recording any tool use, the error returned to
    the model must suggest ``read_file``, and the misspelled name must be kept
    verbatim in both the stored and the replayed assistant message.
    """
    provider = MagicMock()
    call_count = {"n": 0}
    captured_second_call: list[dict] = []

    async def chat_with_retry(*, messages, **kwargs):
        call_count["n"] += 1
        if call_count["n"] == 1:
            return LLMResponse(
                content="",
                tool_calls=[
                    ToolCallRequest(
                        id="call_1",
                        name="readFile",
                        arguments={"path": "notes.txt"},
                    )
                ],
                finish_reason="tool_calls",
                usage={},
            )
        # Keep what the runner replays to the model on the second turn.
        captured_second_call[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry
    tools = ToolRegistry()
    shared_events: list[str] = []
    tools.register(_DelayTool(
        "read_file",
        delay=0,
        read_only=True,
        shared_events=shared_events,
    ))
    runner = AgentRunner(provider)
    result = await runner.run(AgentRunSpec(
        initial_messages=[{"role": "user", "content": "read notes"}],
        tools=tools,
        model="test-model",
        max_iterations=2,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    ))

    assert result.final_content == "done"
    assert result.tools_used == []
    assert shared_events == []

    assistant_message = [
        msg for msg in result.messages
        if msg.get("role") == "assistant" and msg.get("tool_calls")
    ][0]
    assert assistant_message["tool_calls"][0]["function"]["name"] == "readFile"

    # Use the shared lookup helper, as the sibling rejection tests do.
    tool_message = _tool_message(result, "call_1")
    assert tool_message["name"] == "readFile"
    assert "Tool 'readFile' not found" in tool_message["content"]
    assert "Did you mean 'read_file'?" in tool_message["content"]

    replayed_assistant = [
        msg for msg in captured_second_call
        if msg.get("role") == "assistant" and msg.get("tool_calls")
    ][0]
    assert replayed_assistant["tool_calls"][0]["function"]["name"] == "readFile"
@pytest.mark.asyncio
@pytest.mark.parametrize("arguments", ['{path:"notes.txt"}', "null"])
async def test_runner_rejects_openai_compat_invalid_arguments_without_executing(arguments):
"""Invalid OpenAI-compat argument strings stay raw and are rejected by the runner.

Parametrized over an object-like string with an unquoted key and the JSON
literal ``null``.
"""
# AsyncOpenAI is swapped for a mock while the provider is constructed.
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
parsed = OpenAICompatProvider()._parse({
"choices": [{
"message": {
"tool_calls": [{
"id": "call_1",
"type": "function",
"function": {
"name": "optional_tool",
"arguments": arguments,
},
}],
},
"finish_reason": "tool_calls",
}],
"usage": {},
})
result, shared_events = await _run_optional_tool_response(parsed)
assert result.final_content == "done"
# The provider must hand the raw string through instead of repairing it.
assert parsed.tool_calls[0].arguments == arguments
# Nothing is recorded as used and the tool's shared event list stays empty.
assert result.tools_used == []
assert shared_events == []
tool_message = _tool_message(result, "call_1")
assert "parameters must be a JSON object" in tool_message["content"]
@pytest.mark.asyncio
async def test_runner_rejects_openai_responses_malformed_arguments_without_executing():
    """Malformed Responses-API arguments stay verbatim and are rejected unexecuted."""
    function_call = {
        "type": "function_call",
        "call_id": "call_1",
        "id": "fc_1",
        "name": "optional_tool",
        "arguments": "{bad",
    }
    parsed = parse_response_output({
        "output": [function_call],
        "status": "completed",
        "usage": {},
    })

    outcome = await _run_optional_tool_response(parsed)
    result, shared_events = outcome

    assert result.final_content == "done"
    first_call = parsed.tool_calls[0]
    assert first_call.arguments == "{bad"
    assert result.tools_used == []
    assert shared_events == []
    rejection = _tool_message(result, "call_1|fc_1")["content"]
    assert "parameters must be a JSON object" in rejection
@pytest.mark.asyncio
async def test_runner_rejects_openai_responses_array_arguments_without_executing():
    """Array-shaped Responses-API arguments stay a list and are rejected unexecuted."""
    function_call = {
        "type": "function_call",
        "call_id": "call_1",
        "id": "fc_1",
        "name": "optional_tool",
        "arguments": [],
    }
    parsed = parse_response_output({
        "output": [function_call],
        "status": "completed",
        "usage": {},
    })

    outcome = await _run_optional_tool_response(parsed)
    result, shared_events = outcome

    assert result.final_content == "done"
    first_call = parsed.tool_calls[0]
    assert first_call.arguments == []
    assert result.tools_used == []
    assert shared_events == []
    rejection = _tool_message(result, "call_1|fc_1")["content"]
    assert "parameters must be a JSON object" in rejection
@pytest.mark.asyncio
async def test_runner_blocks_repeated_external_fetches():
provider = MagicMock()
captured_final_call: list[dict] = []
call_count = {"n": 0}

View File

@ -80,3 +80,17 @@ def test_convert_user_content_coerces_mixed_typeless():
])
assert result[0] == {"type": "text", "text": "42"}
assert result[1] == {"type": "text", "text": str({"key": "val"})}
def test_convert_assistant_message_repairs_history_tool_arguments():
    """Object-like but invalid history arguments become a dict ``tool_use`` input."""
    history_call = {
        "id": "toolu_1",
        "function": {"name": "read_file", "arguments": '{path:"foo.txt"}'},
    }
    assistant_message = {
        "role": "assistant",
        "content": None,
        "tool_calls": [history_call],
    }

    blocks = AnthropicProvider._assistant_blocks(assistant_message)

    first_block = blocks[0]
    assert first_block["type"] == "tool_use"
    assert first_block["input"] == {"path": "foo.txt"}

View File

@ -161,6 +161,16 @@ def test_build_kwargs_converts_messages_tools_and_tool_results() -> None:
assert kwargs["toolConfig"]["toolChoice"] == {"any": {}}
def test_tool_use_block_repairs_history_tool_arguments() -> None:
    """Object-like but invalid history arguments become a dict ``toolUse`` input."""
    history_call = {
        "id": "toolu_1",
        "function": {"name": "read_file", "arguments": '{path:"foo.txt"}'},
    }

    block = BedrockProvider._tool_use_block(history_call)

    assert block is not None
    tool_use = block["toolUse"]
    assert tool_use["input"] == {"path": "foo.txt"}
def test_build_kwargs_keeps_tool_config_for_historical_tool_blocks_without_tools() -> None:
provider = BedrockProvider(region="us-east-1", client=FakeClient())
messages = [

View File

@ -54,6 +54,15 @@ def _fake_tool_call_response() -> SimpleNamespace:
return SimpleNamespace(choices=[choice], usage=usage)
def _fake_tool_call_response_with_arguments(arguments) -> SimpleNamespace:
    """Build a minimal chat response with caller-supplied tool arguments.

    The response holds one choice with ``finish_reason="tool_calls"`` whose
    message carries a single ``optional_tool`` function call. ``arguments``
    is stored on that call exactly as given.
    """
    return SimpleNamespace(
        choices=[
            SimpleNamespace(
                message=SimpleNamespace(
                    content=None,
                    tool_calls=[
                        SimpleNamespace(
                            id="call_123",
                            type="function",
                            function=SimpleNamespace(
                                name="optional_tool",
                                arguments=arguments,
                            ),
                        ),
                    ],
                    reasoning_content=None,
                ),
                finish_reason="tool_calls",
            ),
        ],
        usage=SimpleNamespace(),
    )
def _fake_responses_response(content: str = "ok") -> MagicMock:
"""Build a minimal Responses API response object."""
resp = MagicMock()
@ -611,6 +620,24 @@ async def test_openai_compat_preserves_extra_content_on_tool_calls() -> None:
assert serialized["function"]["provider_specific_fields"] == {"inner": "value"}
def test_openai_compat_parse_preserves_malformed_tool_arguments() -> None:
    """``_parse`` hands back an unparseable arguments string exactly as received."""
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        provider = OpenAICompatProvider()
        response = _fake_tool_call_response_with_arguments('{path:"foo.txt"}')
        result = provider._parse(response)

    first_call = result.tool_calls[0]
    assert first_call.arguments == '{path:"foo.txt"}'
def test_openai_compat_parse_preserves_array_tool_arguments() -> None:
    """``_parse`` decodes a JSON-array arguments string to a list and does not wrap it."""
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        provider = OpenAICompatProvider()
        response = _fake_tool_call_response_with_arguments('["foo.txt"]')
        result = provider._parse(response)

    first_call = result.tool_calls[0]
    assert first_call.arguments == ["foo.txt"]
def test_openai_model_passthrough() -> None:
"""OpenAI models pass through unchanged."""
spec = find_by_name("openai")
@ -1110,7 +1137,7 @@ def test_openai_compat_stringifies_dict_tool_arguments() -> None:
assert sanitized[1]["tool_calls"][0]["function"]["arguments"] == '{"cmd": "ls -la"}'
def test_openai_compat_repairs_non_json_tool_arguments_string() -> None:
def test_openai_compat_repairs_object_like_history_tool_arguments_string() -> None:
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider()

View File

@ -155,6 +155,19 @@ class TestConvertMessages:
assert items[0]["call_id"] == "call_abc"
assert items[0]["id"] == "fc_1"
assert items[0]["name"] == "get_weather"
assert items[0]["arguments"] == '{"city": "SF"}'
def test_assistant_tool_call_history_repairs_malformed_arguments(self):
    """Replayed history arguments ``{path:"foo.txt"}`` are emitted as valid JSON."""
    history_call = {
        "id": "call_abc|fc_1",
        "function": {"name": "read_file", "arguments": '{path:"foo.txt"}'},
    }
    assistant_message = {
        "role": "assistant",
        "content": None,
        "tool_calls": [history_call],
    }

    _, items = convert_messages([assistant_message])

    # The emitted arguments must round-trip through a strict JSON parser.
    decoded = json.loads(items[0]["arguments"])
    assert decoded == {"path": "foo.txt"}
def test_duplicate_response_item_ids_are_made_unique(self):
"""Codex rejects replayed Responses input items with duplicate ids."""
@ -367,7 +380,7 @@ class TestParseResponseOutput:
assert result.tool_calls[0].id == "call_1|fc_1"
def test_malformed_tool_arguments_logged(self):
"""Malformed JSON arguments should log a warning and fallback."""
"""Malformed JSON arguments should log a warning and remain non-object."""
resp = {
"output": [{
"type": "function_call",
@ -378,10 +391,29 @@ class TestParseResponseOutput:
}
with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger:
result = parse_response_output(resp)
assert result.tool_calls[0].arguments == {"raw": "{bad json"}
assert result.tool_calls[0].arguments == "{bad json"
mock_logger.warning.assert_called_once()
assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args)
@pytest.mark.parametrize("arguments", [[], False, 0])
def test_falsy_non_object_tool_arguments_preserved(self, arguments):
    """Falsy non-object arguments come out of ``parse_response_output`` unchanged."""
    function_call_item = {
        "type": "function_call",
        "call_id": "c1",
        "id": "fc1",
        "name": "f",
        "arguments": arguments,
    }

    result = parse_response_output(
        {"output": [function_call_item], "status": "completed", "usage": {}}
    )

    parsed_arguments = result.tool_calls[0].arguments
    assert parsed_arguments == arguments
    # ``False == 0`` in Python, so equality alone cannot tell the cases apart.
    assert type(parsed_arguments) is type(arguments)
def test_reasoning_content_extracted(self):
resp = {
"output": [
@ -611,6 +643,38 @@ class TestConsumeSse:
},
]
@pytest.mark.asyncio
@pytest.mark.parametrize("arguments", [[], False, 0])
async def test_falsy_non_object_tool_arguments_preserved(self, arguments):
    """Falsy non-object arguments on the ``done`` SSE item reach the tool call unchanged."""

    def function_call_item(item_arguments):
        # Same call identity for the "added" and "done" events; only arguments differ.
        return {
            "type": "function_call",
            "call_id": "c1",
            "id": "fc1",
            "name": "f",
            "arguments": item_arguments,
        }

    events = [
        {"type": "response.output_item.added", "item": function_call_item("")},
        {"type": "response.output_item.done", "item": function_call_item(arguments)},
        {"type": "response.completed", "response": {"status": "completed"}},
    ]

    _, tool_calls, _, _, _ = await consume_sse_with_reasoning(_SseResponse(events))

    streamed_arguments = tool_calls[0].arguments
    assert streamed_arguments == arguments
    # ``False == 0`` in Python, so the exact type is checked as well.
    assert type(streamed_arguments) is type(arguments)
# ======================================================================
# parsing - consume_sdk_stream
@ -764,6 +828,28 @@ class TestConsumeSdkStream:
},
]
@pytest.mark.asyncio
@pytest.mark.parametrize("arguments", [[], False, 0])
async def test_falsy_non_object_tool_arguments_preserved(self, arguments):
    """Falsy non-object arguments on the ``done`` SDK item reach the tool call unchanged."""

    def function_call_item(item_arguments):
        item = MagicMock(type="function_call", call_id="c1", id="fc1")
        # ``name`` is consumed by the MagicMock constructor, so set it afterwards.
        item.name = "f"
        item.arguments = item_arguments
        return item

    events = [
        MagicMock(type="response.output_item.added", item=function_call_item("")),
        MagicMock(type="response.output_item.done", item=function_call_item(arguments)),
        MagicMock(
            type="response.completed",
            response=MagicMock(status="completed", usage=None, output=[]),
        ),
    ]

    async def stream():
        for event in events:
            yield event

    _, tool_calls, _, _, _ = await consume_sdk_stream(stream())

    streamed_arguments = tool_calls[0].arguments
    assert streamed_arguments == arguments
    # ``False == 0`` in Python, so the exact type is checked as well.
    assert type(streamed_arguments) is type(arguments)
@pytest.mark.asyncio
async def test_usage_extracted(self):
usage_obj = MagicMock(input_tokens=10, output_tokens=5, total_tokens=15)
@ -811,7 +897,7 @@ class TestConsumeSdkStream:
@pytest.mark.asyncio
async def test_malformed_tool_args_logged(self):
"""Malformed JSON in streaming tool args should log a warning."""
"""Malformed JSON in streaming tool args should log a warning and remain non-object."""
item_added = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="")
item_added.name = "f"
ev1 = MagicMock(type="response.output_item.added", item=item_added)
@ -828,6 +914,6 @@ class TestConsumeSdkStream:
with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger:
_, tool_calls, _, _, _ = await consume_sdk_stream(stream())
assert tool_calls[0].arguments == {"raw": "{bad"}
assert tool_calls[0].arguments == "{bad"
mock_logger.warning.assert_called_once()
assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args)

View File

@ -0,0 +1,30 @@
"""Shared tool-argument parsing policy tests."""
from nanobot.providers.base import (
parse_tool_arguments,
tool_arguments_json_for_replay,
tool_arguments_object_for_replay,
)
def test_parse_tool_arguments_preserves_malformed_executable_arguments() -> None:
    """An arguments string that is not valid JSON is returned as the same string."""
    malformed = '{path:"foo.txt"}'

    assert parse_tool_arguments(malformed) == malformed
def test_parse_tool_arguments_preserves_non_object_executable_arguments() -> None:
    """Array and boolean JSON decode to their values; the text ``null`` stays a string."""
    decoded_array = parse_tool_arguments('["foo.txt"]')
    assert decoded_array == ["foo.txt"]

    decoded_false = parse_tool_arguments("false")
    assert decoded_false is False

    # JSON null is not decoded to ``None``; the raw text is expected back.
    undecoded_null = parse_tool_arguments("null")
    assert undecoded_null == "null"
def test_tool_arguments_object_for_replay_repairs_object_like_history_arguments() -> None:
    """Object-like text with an unquoted key is turned into the corresponding dict."""
    repaired = tool_arguments_object_for_replay('{path:"foo.txt"}')

    assert repaired == {"path": "foo.txt"}
def test_tool_arguments_object_for_replay_keeps_history_object_shaped() -> None:
    """Every non-object input, as JSON text or as a Python value, replays as ``{}``."""
    json_text_inputs = ['["foo.txt"]', "false", "null", "0"]
    python_value_inputs = [["foo.txt"], False, None, 0]

    for non_object in json_text_inputs + python_value_inputs:
        assert tool_arguments_object_for_replay(non_object) == {}
def test_tool_arguments_json_for_replay_returns_object_string() -> None:
    """Object-like history text is re-serialized as a JSON object string."""
    replayed = tool_arguments_json_for_replay('{path:"foo.txt"}')

    assert replayed == '{"path": "foo.txt"}'

View File

@ -7,8 +7,9 @@ from nanobot.agent.tools.registry import ToolRegistry
class _FakeTool(Tool):
def __init__(self, name: str):
def __init__(self, name: str, schema: dict[str, Any] | None = None):
self._name = name
self._schema = schema
@property
def name(self) -> str:
@ -20,7 +21,7 @@ class _FakeTool(Tool):
@property
def parameters(self) -> dict[str, Any]:
return {"type": "object", "properties": {}}
return self._schema or {"type": "object", "properties": {}}
async def execute(self, **kwargs: Any) -> Any:
return kwargs
@ -34,6 +35,13 @@ def _tool_names(definitions: list[dict[str, Any]]) -> list[str]:
return names
def _registry_with_names(names: list[str]) -> ToolRegistry:
    """Return a new registry with one ``_FakeTool`` registered per name, in order."""
    populated = ToolRegistry()
    for fake_tool in map(_FakeTool, names):
        populated.register(fake_tool)
    return populated
def test_get_definitions_orders_builtins_then_mcp_tools() -> None:
registry = ToolRegistry()
registry.register(_FakeTool("mcp_git_status"))
@ -49,17 +57,167 @@ def test_get_definitions_orders_builtins_then_mcp_tools() -> None:
]
def test_prepare_call_rejects_near_miss_tool_name_with_suggestion() -> None:
    """``readFile`` does not resolve to ``read_file``; the error only suggests it."""
    registry = _registry_with_names(["read_file"])

    tool, params, error = registry.prepare_call("readFile", {"path": "foo.txt"})

    # No tool is resolved and the parameters come back untouched.
    assert tool is None
    assert params == {"path": "foo.txt"}
    assert error is not None
    expected_fragments = (
        "Tool 'readFile' not found",
        "Did you mean 'read_file'?",
        "must match exactly",
    )
    for fragment in expected_fragments:
        assert fragment in error
def test_suggest_name_handles_canonical_tool_name_variants() -> None:
    """Case, hyphen, space and missing-separator spellings all suggest ``read_file``."""
    registry = _registry_with_names(["read_file"])
    variants = ["readFile", "read-file", "READ_FILE", "read file", "readfile"]
    expected = dict.fromkeys(variants, "read_file")

    suggestions = {variant: registry._suggest_name(variant) for variant in variants}

    assert suggestions == expected
def test_suggest_name_suppresses_low_confidence_and_non_unique_matches() -> None:
    """Empty, unrelated, partial and extended names get no suggestion, nor do ties."""
    registry = _registry_with_names(["read_file", "write_file"])
    unmatched_names = ("", "foo", "read", "file", "readfil", "read_file_tool")
    for candidate in unmatched_names:
        assert registry._suggest_name(candidate) is None

    # Two registered names that both correspond to "readfile": no single winner.
    tied_registry = _registry_with_names(["read_file", "readFile"])
    assert tied_registry._suggest_name("readfile") is None
def test_suggest_name_updates_after_register_and_unregister() -> None:
    """Suggestions follow the registry contents as tools are added and removed."""
    registry = _registry_with_names(["read_file"])
    only_read_file = registry._suggest_name("readFile")
    assert only_read_file == "read_file"

    # Registering ``readFile`` next to ``read_file`` leaves "read-file" without a suggestion.
    registry.register(_FakeTool("readFile"))
    with_both = registry._suggest_name("read-file")
    assert with_both is None

    # With ``read_file`` gone, ``readFile`` is the one remaining candidate.
    registry.unregister("read_file")
    only_camel_case = registry._suggest_name("read-file")
    assert only_camel_case == "readFile"
def test_prepare_call_read_file_rejects_non_object_params_with_actionable_hint() -> None:
    """A list payload for a registered tool is refused with a named-parameter hint.

    The tool name matches exactly, so the tool is still resolved; the
    parameters are returned unchanged together with the error text.
    """
    registry = ToolRegistry()
    registry.register(_FakeTool("read_file"))

    tool, params, error = registry.prepare_call("read_file", ["foo.txt"])

    # Only the post-change expectation is kept: asserting both ``tool is None``
    # and ``tool is not None`` on the same value can never pass.
    assert tool is not None
    assert params == ["foo.txt"]
    assert error is not None
    assert "must be a JSON object" in error
    assert "Use named parameters" in error
    assert 'tool_name(param1="value1", param2="value2")' in error
    assert "matching the tool schema" in error
def test_prepare_call_parses_json_string_arguments() -> None:
    """A valid JSON-object string is decoded into a dict and accepted without error."""
    registry = _registry_with_names(["read_file"])

    tool, params, error = registry.prepare_call("read_file", '{"path":"foo.txt"}')

    assert tool is not None
    assert params == {"path": "foo.txt"}
    assert error is None
def test_prepare_call_rejects_malformed_json_string_arguments() -> None:
    """An unquoted-key string is not repaired: it is returned as-is with an error."""
    registry = _registry_with_names(["read_file"])
    malformed = '{path:"foo.txt"}'

    tool, params, error = registry.prepare_call("read_file", malformed)

    assert tool is not None
    assert params == '{path:"foo.txt"}'
    assert error is not None
    assert "parameters must be a JSON object" in error
def test_prepare_call_rejects_scalar_for_single_required_parameter() -> None:
    """A bare string is not mapped onto a tool's only required parameter."""
    url_only_schema = {
        "type": "object",
        "properties": {"url": {"type": "string"}},
        "required": ["url"],
    }
    registry = ToolRegistry()
    registry.register(_FakeTool("web_fetch", url_only_schema))

    tool, params, error = registry.prepare_call("web_fetch", "https://example.com")

    assert tool is not None
    assert params == "https://example.com"
    assert error is not None
    assert "parameters must be a JSON object" in error
def test_prepare_call_rejects_unquoted_scalar_strings_before_schema_cast() -> None:
    """The text ``true`` is refused and returned as the original string."""
    content_only_schema = {
        "type": "object",
        "properties": {"content": {"type": "string"}},
        "required": ["content"],
    }
    registry = ToolRegistry()
    registry.register(_FakeTool("message", content_only_schema))

    tool, params, error = registry.prepare_call("message", "true")

    assert tool is not None
    # Still the string "true", not a boolean and not wrapped into ``content``.
    assert params == "true"
    assert error is not None
    assert "parameters must be a JSON object" in error
def test_prepare_call_unwraps_arguments_payload() -> None:
    """A dict whose ``arguments`` key holds JSON-object text is replaced by the decoded object."""
    path_only_schema = {
        "type": "object",
        "properties": {"path": {"type": "string"}},
        "required": ["path"],
    }
    registry = ToolRegistry()
    registry.register(_FakeTool("read_file", path_only_schema))
    wrapped_payload = {"arguments": '{"path":"foo.txt"}'}

    tool, params, error = registry.prepare_call("read_file", wrapped_payload)

    assert tool is not None
    assert params == {"path": "foo.txt"}
    assert error is None
def test_prepare_call_treats_none_arguments_as_empty_object() -> None:
    """Python ``None`` means "no arguments"; the JSON text ``null`` is rejected."""
    registry = _registry_with_names(["list_exec_sessions"])

    # Absent arguments: accepted and normalized to an empty object.
    tool, params, error = registry.prepare_call("list_exec_sessions", None)
    assert tool is not None
    assert params == {}
    assert error is None

    # The string "null": returned unchanged together with an error.
    tool, params, error = registry.prepare_call("list_exec_sessions", "null")
    assert tool is not None
    assert params == "null"
    assert error is not None
    assert "parameters must be a JSON object" in error
def test_prepare_call_other_tools_keep_generic_object_validation() -> None:
@ -70,7 +228,11 @@ def test_prepare_call_other_tools_keep_generic_object_validation() -> None:
assert tool is not None
assert params == ["TODO"]
assert error == "Error: Invalid parameters for tool 'grep': parameters must be an object, got list"
assert error == (
"Error: Tool 'grep' parameters must be a JSON object, got list. "
'Use named parameters like tool_name(param1="value1", param2="value2") '
"matching the tool schema."
)
def test_get_definitions_returns_cached_result() -> None: