mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 07:14:08 +00:00
Improve tool call validation strictness (#4190)
* Improve tool call validation strictness. Reject near-miss tool names without executing suggested tools. Require object-shaped tool parameters while preserving only lossless JSON wire-shape normalization.
* Tighten tool call argument validation
* Simplify tool argument validation tests
* Improve tool name suggestions
* Simplify tool suggestion helpers
* Limit tool suggestions to canonical matches
* Allow repair only for tool history replay
* Clarify non-object tool argument errors
* Inline replay tool argument normalization
* Track only successful tool executions
* Reject JSON null tool arguments
This commit is contained in:
parent
f3eb2aa08b
commit
0a396aa6e2
@ -399,7 +399,6 @@ class AgentRunner:
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
)
|
||||
messages.append(assistant_message)
|
||||
tools_used.extend(tc.name for tc in response.tool_calls)
|
||||
await self._emit_checkpoint(
|
||||
spec,
|
||||
{
|
||||
@ -421,6 +420,11 @@ class AgentRunner:
|
||||
workspace_violation_counts,
|
||||
)
|
||||
tool_events.extend(new_events)
|
||||
tools_used.extend(
|
||||
tool_call.name
|
||||
for tool_call, event in zip(response.tool_calls, new_events)
|
||||
if event.get("status") == "ok"
|
||||
)
|
||||
context.tool_results = list(results)
|
||||
context.tool_events = list(new_events)
|
||||
completed_tool_results: list[dict[str, Any]] = []
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
"""Tool registry for dynamic tool management."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
@ -30,6 +31,24 @@ class ToolRegistry:
|
||||
"""Get a tool by name."""
|
||||
return self._tools.get(name)
|
||||
|
||||
@staticmethod
|
||||
def _lookup_key(name: str) -> str:
|
||||
"""Normalize names for suggestions only; never for execution."""
|
||||
return "".join(ch.lower() for ch in name if ch.isalnum())
|
||||
|
||||
def _suggest_name(self, name: str) -> str | None:
|
||||
key = self._lookup_key(str(name or ""))
|
||||
if not key:
|
||||
return None
|
||||
matches = [
|
||||
registered
|
||||
for registered in self._tools
|
||||
if self._lookup_key(registered) == key
|
||||
]
|
||||
if len(matches) == 1:
|
||||
return matches[0]
|
||||
return None
|
||||
|
||||
def has(self, name: str) -> bool:
|
||||
"""Check if a tool is registered."""
|
||||
return name in self._tools
|
||||
@ -73,20 +92,23 @@ class ToolRegistry:
|
||||
def prepare_call(
|
||||
self,
|
||||
name: str,
|
||||
params: dict[str, Any],
|
||||
) -> tuple[Tool | None, dict[str, Any], str | None]:
|
||||
params: Any,
|
||||
) -> tuple[Tool | None, Any, str | None]:
|
||||
"""Resolve, cast, and validate one tool call."""
|
||||
# Guard against invalid parameter types (e.g., list instead of dict)
|
||||
if not isinstance(params, dict) and name in ('write_file', 'read_file'):
|
||||
return None, params, (
|
||||
f"Error: Tool '{name}' parameters must be a JSON object, got {type(params).__name__}. "
|
||||
"Use named parameters: tool_name(param1=\"value1\", param2=\"value2\")"
|
||||
)
|
||||
|
||||
tool = self._tools.get(name)
|
||||
if not tool:
|
||||
suggestion = self._suggest_name(str(name))
|
||||
hint = f" Did you mean '{suggestion}'? Tool names must match exactly." if suggestion else ""
|
||||
return None, params, (
|
||||
f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
|
||||
f"Error: Tool '{name}' not found.{hint} Available: {', '.join(self.tool_names)}"
|
||||
)
|
||||
|
||||
params = self._coerce_params(tool, params)
|
||||
if not isinstance(params, dict):
|
||||
return tool, params, (
|
||||
f"Error: Tool '{name}' parameters must be a JSON object, got "
|
||||
f"{type(params).__name__}. Use named parameters like "
|
||||
'tool_name(param1="value1", param2="value2") matching the tool schema.'
|
||||
)
|
||||
|
||||
cast_params = tool.cast_params(params)
|
||||
@ -97,21 +119,56 @@ class ToolRegistry:
|
||||
)
|
||||
return tool, cast_params, None
|
||||
|
||||
async def execute(self, name: str, params: dict[str, Any]) -> Any:
|
||||
@classmethod
|
||||
def _coerce_argument_value(cls, value: Any) -> Any:
|
||||
if value is None:
|
||||
return {}
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
|
||||
stripped = value.strip()
|
||||
if not stripped:
|
||||
return {}
|
||||
|
||||
if not stripped.startswith(("{", "[")):
|
||||
return value
|
||||
|
||||
try:
|
||||
parsed = json.loads(stripped)
|
||||
except Exception:
|
||||
return value
|
||||
|
||||
return parsed
|
||||
|
||||
@classmethod
|
||||
def _coerce_params(cls, tool: Tool, params: Any) -> Any:
|
||||
params = cls._coerce_argument_value(params)
|
||||
return cls._unwrap_arguments_payload(tool, params)
|
||||
|
||||
@classmethod
|
||||
def _unwrap_arguments_payload(cls, tool: Tool, params: Any) -> Any:
|
||||
if not isinstance(params, dict) or set(params) != {"arguments"}:
|
||||
return params
|
||||
properties = (tool.parameters or {}).get("properties", {})
|
||||
if isinstance(properties, dict) and "arguments" in properties:
|
||||
return params
|
||||
return cls._coerce_argument_value(params.get("arguments"))
|
||||
|
||||
async def execute(self, name: str, params: Any) -> Any:
|
||||
"""Execute a tool by name with given parameters."""
|
||||
_HINT = "\n\n[Analyze the error above and try a different approach.]"
|
||||
hint = "\n\n[Analyze the error above and try a different approach.]"
|
||||
tool, params, error = self.prepare_call(name, params)
|
||||
if error:
|
||||
return error + _HINT
|
||||
return error + hint
|
||||
|
||||
try:
|
||||
assert tool is not None # guarded by prepare_call()
|
||||
result = await tool.execute(**params)
|
||||
if isinstance(result, str) and result.startswith("Error"):
|
||||
return result + _HINT
|
||||
return result + hint
|
||||
return result
|
||||
except Exception as e:
|
||||
return f"Error executing {name}: {str(e)}" + _HINT
|
||||
return f"Error executing {name}: {str(e)}" + hint
|
||||
|
||||
@property
|
||||
def tool_names(self) -> list[str]:
|
||||
|
||||
@ -10,9 +10,12 @@ import string
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import json_repair
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.base import (
|
||||
LLMProvider,
|
||||
LLMResponse,
|
||||
ToolCallRequest,
|
||||
tool_arguments_object_for_replay,
|
||||
)
|
||||
|
||||
_ALNUM = string.ascii_letters + string.digits
|
||||
|
||||
@ -207,13 +210,11 @@ class AnthropicProvider(LLMProvider):
|
||||
continue
|
||||
func = tc.get("function", {})
|
||||
args = func.get("arguments", "{}")
|
||||
if isinstance(args, str):
|
||||
args = json_repair.loads(args)
|
||||
blocks.append({
|
||||
"type": "tool_use",
|
||||
"id": tc.get("id") or _gen_tool_id(),
|
||||
"name": func.get("name", ""),
|
||||
"input": args,
|
||||
"input": tool_arguments_object_for_replay(args),
|
||||
})
|
||||
|
||||
return blocks or [{"type": "text", "text": ""}]
|
||||
@ -509,7 +510,7 @@ class AnthropicProvider(LLMProvider):
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=block.id,
|
||||
name=block.name,
|
||||
arguments=block.input if isinstance(block.input, dict) else {},
|
||||
arguments=block.input,
|
||||
))
|
||||
elif block.type == "thinking":
|
||||
thinking_blocks.append({
|
||||
|
||||
@ -11,6 +11,7 @@ from datetime import datetime, timezone
|
||||
from email.utils import parsedate_to_datetime
|
||||
from typing import Any
|
||||
|
||||
import json_repair
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.utils.helpers import image_placeholder_text
|
||||
@ -21,19 +22,24 @@ class ToolCallRequest:
|
||||
"""A tool call request from the LLM."""
|
||||
id: str
|
||||
name: str
|
||||
arguments: dict[str, Any]
|
||||
arguments: Any
|
||||
extra_content: dict[str, Any] | None = None
|
||||
provider_specific_fields: dict[str, Any] | None = None
|
||||
function_provider_specific_fields: dict[str, Any] | None = None
|
||||
|
||||
def to_openai_tool_call(self) -> dict[str, Any]:
|
||||
"""Serialize to an OpenAI-style tool_call payload."""
|
||||
arguments = (
|
||||
self.arguments
|
||||
if isinstance(self.arguments, str)
|
||||
else json.dumps(self.arguments, ensure_ascii=False)
|
||||
)
|
||||
tool_call = {
|
||||
"id": self.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"arguments": json.dumps(self.arguments, ensure_ascii=False),
|
||||
"arguments": arguments,
|
||||
},
|
||||
}
|
||||
if self.extra_content:
|
||||
@ -45,6 +51,62 @@ class ToolCallRequest:
|
||||
return tool_call
|
||||
|
||||
|
||||
def parse_tool_arguments(arguments: Any) -> Any:
    """Parse provider tool arguments without guessing executable parameters.

    ``None`` and blank strings become ``{}`` (a no-argument call). Non-string
    values are returned unchanged. A string holding valid JSON is decoded and
    the decoded value is returned whatever its shape (object, array or
    scalar). Malformed JSON and a JSON ``null`` are returned as the original
    string. Nothing is repaired or coerced into an object here.
    """
    if arguments is None:
        return {}
    if not isinstance(arguments, str):
        return arguments

    stripped = arguments.strip()
    if not stripped:
        return {}

    try:
        parsed = json.loads(stripped)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError; pathological nesting raises
        # RecursionError. Keep the raw text rather than repairing it.
        return arguments
    # A decoded null must not turn into None (which reads as "no arguments"
    # elsewhere), so the raw text is kept for that case.
    return arguments if parsed is None else parsed
|
||||
|
||||
|
||||
def tool_arguments_object_for_replay(arguments: Any) -> dict[str, Any]:
    """Return object-shaped arguments for provider history replay only.

    This compatibility path may repair malformed JSON because it only shapes
    existing conversation history for provider protocols. Do not use it for
    newly generated tool calls that are about to execute.

    A dict is returned as-is. A string is parsed strictly first and, if that
    fails, with ``json_repair``. Anything that does not end up as a dict
    (``None``, other types, blank or unrepairable text, JSON arrays and
    scalars) yields ``{}``.
    """
    if isinstance(arguments, dict):
        return arguments
    if not isinstance(arguments, str):
        # Covers None as well as lists, numbers and other non-text values.
        return {}

    stripped = arguments.strip()
    if not stripped:
        return {}

    try:
        parsed = json.loads(stripped)
    except (ValueError, RecursionError):
        # Strict parsing failed; fall back to best-effort repair. The repair
        # library's failure modes are not known here, so stay broad.
        try:
            parsed = json_repair.loads(stripped)
        except Exception:
            return {}
    return parsed if isinstance(parsed, dict) else {}
|
||||
|
||||
|
||||
def tool_arguments_json_for_replay(arguments: Any) -> str:
    """Return JSON object string arguments for provider history replay only.

    Shapes ``arguments`` with ``tool_arguments_object_for_replay`` and
    serializes the resulting dict without escaping non-ASCII characters.
    """
    replay_object = tool_arguments_object_for_replay(arguments)
    return json.dumps(replay_object, ensure_ascii=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""Response from an LLM provider."""
|
||||
|
||||
@ -10,9 +10,13 @@ import re
|
||||
from collections.abc import Awaitable, Callable, Iterator
|
||||
from typing import Any
|
||||
|
||||
import json_repair
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.base import (
|
||||
LLMProvider,
|
||||
LLMResponse,
|
||||
ToolCallRequest,
|
||||
parse_tool_arguments,
|
||||
tool_arguments_object_for_replay,
|
||||
)
|
||||
|
||||
_IMAGE_DATA_URL = re.compile(r"^data:image/([a-zA-Z0-9.+-]+);base64,(.*)$", re.DOTALL)
|
||||
_TEXT_BLOCK_TYPES = {"text", "input_text", "output_text"}
|
||||
@ -176,14 +180,7 @@ class BedrockProvider(LLMProvider):
|
||||
function = tool_call.get("function")
|
||||
if not isinstance(function, dict):
|
||||
return None
|
||||
args = function.get("arguments", {})
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
args = json_repair.loads(args) if args.strip() else {}
|
||||
except Exception:
|
||||
args = {}
|
||||
if not isinstance(args, dict):
|
||||
args = {}
|
||||
args = tool_arguments_object_for_replay(function.get("arguments", {}))
|
||||
return {
|
||||
"toolUse": {
|
||||
"toolUseId": str(tool_call.get("id") or ""),
|
||||
@ -491,7 +488,7 @@ class BedrockProvider(LLMProvider):
|
||||
content_parts.append(block["text"])
|
||||
tool_use = block.get("toolUse")
|
||||
if isinstance(tool_use, dict):
|
||||
arguments = tool_use.get("input") if isinstance(tool_use.get("input"), dict) else {}
|
||||
arguments = tool_use.get("input", {})
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=str(tool_use.get("toolUseId") or ""),
|
||||
name=str(tool_use.get("name") or ""),
|
||||
@ -616,14 +613,11 @@ class BedrockProvider(LLMProvider):
|
||||
for buf in tool_buffers.values():
|
||||
args: Any = {}
|
||||
if buf.get("input"):
|
||||
try:
|
||||
args = json_repair.loads(buf["input"])
|
||||
except Exception:
|
||||
args = {}
|
||||
args = parse_tool_arguments(buf["input"])
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=buf.get("id") or "",
|
||||
name=buf.get("name") or "",
|
||||
arguments=args if isinstance(args, dict) else {},
|
||||
arguments=args,
|
||||
))
|
||||
return LLMResponse(
|
||||
content="".join(content_parts) or None,
|
||||
|
||||
@ -17,10 +17,15 @@ from ipaddress import ip_address
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import json_repair
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.base import (
|
||||
LLMProvider,
|
||||
LLMResponse,
|
||||
ToolCallRequest,
|
||||
parse_tool_arguments,
|
||||
tool_arguments_json_for_replay,
|
||||
)
|
||||
from nanobot.providers.openai_responses import (
|
||||
consume_sdk_stream,
|
||||
convert_messages,
|
||||
@ -478,24 +483,6 @@ class OpenAICompatProvider(LLMProvider):
|
||||
"""Return True for providers that reject normal OpenAI tool call IDs."""
|
||||
return bool(self._spec and self._spec.name == "mistral")
|
||||
|
||||
@staticmethod
|
||||
def _normalize_tool_call_arguments(arguments: Any) -> str:
|
||||
"""Force function.arguments into a valid JSON object string."""
|
||||
if isinstance(arguments, str):
|
||||
stripped = arguments.strip()
|
||||
if not stripped:
|
||||
return "{}"
|
||||
try:
|
||||
parsed = json_repair.loads(stripped)
|
||||
except Exception:
|
||||
return "{}"
|
||||
if isinstance(parsed, dict):
|
||||
return json.dumps(parsed, ensure_ascii=False)
|
||||
return "{}"
|
||||
if isinstance(arguments, dict):
|
||||
return json.dumps(arguments, ensure_ascii=False)
|
||||
return "{}"
|
||||
|
||||
@staticmethod
|
||||
def _coerce_content_to_string(content: Any) -> str | None:
|
||||
"""Coerce block/list content into plain text for strict string-only APIs."""
|
||||
@ -572,7 +559,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
if isinstance(function, dict):
|
||||
function_clean = dict(function)
|
||||
if "arguments" in function_clean:
|
||||
function_clean["arguments"] = self._normalize_tool_call_arguments(
|
||||
function_clean["arguments"] = tool_arguments_json_for_replay(
|
||||
function_clean.get("arguments")
|
||||
)
|
||||
else:
|
||||
@ -1021,14 +1008,12 @@ class OpenAICompatProvider(LLMProvider):
|
||||
for tc in raw_tool_calls:
|
||||
tc_map = self._maybe_mapping(tc) or {}
|
||||
fn = self._maybe_mapping(tc_map.get("function")) or {}
|
||||
args = fn.get("arguments", {})
|
||||
if isinstance(args, str):
|
||||
args = json_repair.loads(args)
|
||||
args = parse_tool_arguments(fn.get("arguments", {}))
|
||||
ec, prov, fn_prov = _extract_tc_extras(tc)
|
||||
parsed_tool_calls.append(ToolCallRequest(
|
||||
id=str(tc_map.get("id") or _short_tool_id()),
|
||||
name=str(fn.get("name") or ""),
|
||||
arguments=args if isinstance(args, dict) else {},
|
||||
arguments=args,
|
||||
extra_content=ec,
|
||||
provider_specific_fields=prov,
|
||||
function_provider_specific_fields=fn_prov,
|
||||
@ -1064,9 +1049,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
|
||||
tool_calls = []
|
||||
for tc in raw_tool_calls:
|
||||
args = tc.function.arguments
|
||||
if isinstance(args, str):
|
||||
args = json_repair.loads(args)
|
||||
args = parse_tool_arguments(tc.function.arguments)
|
||||
ec, prov, fn_prov = _extract_tc_extras(tc)
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=str(getattr(tc, "id", None) or _short_tool_id()),
|
||||
@ -1207,7 +1190,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
ToolCallRequest(
|
||||
id=b["id"] or _short_tool_id(),
|
||||
name=b["name"],
|
||||
arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {},
|
||||
arguments=parse_tool_arguments(b["arguments"]),
|
||||
extra_content=b.get("extra_content"),
|
||||
provider_specific_fields=b.get("prov"),
|
||||
function_provider_specific_fields=b.get("fn_prov"),
|
||||
|
||||
@ -5,6 +5,8 @@ from __future__ import annotations
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from nanobot.providers.base import tool_arguments_json_for_replay
|
||||
|
||||
|
||||
def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
|
||||
"""Convert Chat Completions messages to Responses API input items.
|
||||
@ -46,7 +48,7 @@ def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str
|
||||
"id": response_item_id,
|
||||
"call_id": call_id or f"call_{idx}",
|
||||
"name": fn.get("name"),
|
||||
"arguments": fn.get("arguments") or "{}",
|
||||
"arguments": tool_arguments_json_for_replay(fn.get("arguments")),
|
||||
})
|
||||
continue
|
||||
|
||||
|
||||
@ -7,10 +7,9 @@ from collections.abc import Awaitable, Callable
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import httpx
|
||||
import json_repair
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest, parse_tool_arguments
|
||||
|
||||
FINISH_REASON_MAP = {
|
||||
"completed": "stop",
|
||||
@ -44,6 +43,27 @@ def _usage_from_response_obj(response: Any) -> dict[str, int]:
|
||||
}
|
||||
|
||||
|
||||
def _parse_tool_call_arguments(args_raw: Any, name: str | None) -> Any:
    """Parse raw tool-call arguments, warning when text comes back unchanged.

    Delegates to ``parse_tool_arguments``. If ``args_raw`` is a non-blank
    string and the parsed value equals it, a warning with the tool ``name``
    and the first 200 characters of the raw text is logged. The parsed value
    is returned either way.
    """
    parsed = parse_tool_arguments(args_raw)
    # Cheap type/blank checks first; the equality test only matters for text.
    if isinstance(args_raw, str) and args_raw.strip() and parsed == args_raw:
        logger.warning(
            "Failed to parse tool call arguments for '{}': {}",
            name,
            args_raw[:200],
        )
    return parsed
|
||||
|
||||
|
||||
def _tool_arguments_source(*values: Any) -> Any:
|
||||
for value in values:
|
||||
if value is None:
|
||||
continue
|
||||
if isinstance(value, str) and not value.strip():
|
||||
continue
|
||||
return value
|
||||
return "{}"
|
||||
|
||||
|
||||
async def iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
|
||||
"""Yield parsed JSON events from a Responses API SSE stream."""
|
||||
buffer: list[str] = []
|
||||
@ -116,10 +136,11 @@ async def consume_sse_with_reasoning(
|
||||
call_id = item.get("call_id")
|
||||
if not call_id:
|
||||
continue
|
||||
arguments = item.get("arguments")
|
||||
tool_call_buffers[call_id] = {
|
||||
"id": item.get("id") or "fc_0",
|
||||
"name": item.get("name"),
|
||||
"arguments": item.get("arguments") or "",
|
||||
"arguments": "" if arguments is None else arguments,
|
||||
}
|
||||
if on_tool_call_delta:
|
||||
await on_tool_call_delta({
|
||||
@ -156,7 +177,10 @@ async def consume_sse_with_reasoning(
|
||||
call_id = event.get("call_id")
|
||||
if call_id and call_id in tool_call_buffers:
|
||||
delta = event.get("delta") or ""
|
||||
tool_call_buffers[call_id]["arguments"] += delta
|
||||
current = tool_call_buffers[call_id].get("arguments")
|
||||
if not isinstance(current, str):
|
||||
current = ""
|
||||
tool_call_buffers[call_id]["arguments"] = current + delta
|
||||
if on_tool_call_delta and delta:
|
||||
await on_tool_call_delta({
|
||||
"call_id": str(call_id),
|
||||
@ -166,14 +190,14 @@ async def consume_sse_with_reasoning(
|
||||
elif event_type == "response.function_call_arguments.done":
|
||||
call_id = event.get("call_id")
|
||||
if call_id and call_id in tool_call_buffers:
|
||||
arguments = event.get("arguments") or ""
|
||||
arguments = event.get("arguments")
|
||||
tool_call_buffers[call_id]["arguments"] = arguments
|
||||
if on_tool_call_delta:
|
||||
tool_call_args_emitted.add(str(call_id))
|
||||
await on_tool_call_delta({
|
||||
"call_id": str(call_id),
|
||||
"name": str(tool_call_buffers[call_id].get("name") or ""),
|
||||
"arguments": str(arguments),
|
||||
"arguments": "" if arguments is None else str(arguments),
|
||||
})
|
||||
elif event_type == "response.output_item.done":
|
||||
item = event.get("item") or {}
|
||||
@ -182,7 +206,7 @@ async def consume_sse_with_reasoning(
|
||||
if not call_id:
|
||||
continue
|
||||
buf = tool_call_buffers.get(call_id) or {}
|
||||
args_raw = buf.get("arguments") or item.get("arguments") or "{}"
|
||||
args_raw = _tool_arguments_source(buf.get("arguments"), item.get("arguments"))
|
||||
if on_tool_call_delta and str(call_id) not in tool_call_args_emitted:
|
||||
tool_call_args_emitted.add(str(call_id))
|
||||
await on_tool_call_delta({
|
||||
@ -190,17 +214,10 @@ async def consume_sse_with_reasoning(
|
||||
"name": str(buf.get("name") or item.get("name") or ""),
|
||||
"arguments": str(args_raw),
|
||||
})
|
||||
try:
|
||||
args = json.loads(args_raw)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to parse tool call arguments for '{}': {}",
|
||||
buf.get("name") or item.get("name"),
|
||||
args_raw[:200],
|
||||
)
|
||||
args = json_repair.loads(args_raw)
|
||||
if not isinstance(args, dict):
|
||||
args = {"raw": args_raw}
|
||||
args = _parse_tool_call_arguments(
|
||||
args_raw,
|
||||
buf.get("name") or item.get("name"),
|
||||
)
|
||||
tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
|
||||
@ -283,22 +300,12 @@ def parse_response_output(response: Any) -> LLMResponse:
|
||||
elif item_type == "function_call":
|
||||
call_id = item.get("call_id") or ""
|
||||
item_id = item.get("id") or "fc_0"
|
||||
args_raw = item.get("arguments") or "{}"
|
||||
try:
|
||||
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to parse tool call arguments for '{}': {}",
|
||||
item.get("name"),
|
||||
str(args_raw)[:200],
|
||||
)
|
||||
args = json_repair.loads(args_raw) if isinstance(args_raw, str) else args_raw
|
||||
if not isinstance(args, dict):
|
||||
args = {"raw": args_raw}
|
||||
args_raw = _tool_arguments_source(item.get("arguments"))
|
||||
args = _parse_tool_call_arguments(args_raw, item.get("name"))
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=f"{call_id}|{item_id}",
|
||||
name=item.get("name") or "",
|
||||
arguments=args if isinstance(args, dict) else {},
|
||||
arguments=args,
|
||||
))
|
||||
|
||||
usage = _usage_from_response_obj(response)
|
||||
@ -337,10 +344,11 @@ async def consume_sdk_stream(
|
||||
call_id = getattr(item, "call_id", None)
|
||||
if not call_id:
|
||||
continue
|
||||
arguments = getattr(item, "arguments", None)
|
||||
tool_call_buffers[call_id] = {
|
||||
"id": getattr(item, "id", None) or "fc_0",
|
||||
"name": getattr(item, "name", None),
|
||||
"arguments": getattr(item, "arguments", None) or "",
|
||||
"arguments": "" if arguments is None else arguments,
|
||||
}
|
||||
if on_tool_call_delta:
|
||||
await on_tool_call_delta({
|
||||
@ -357,7 +365,10 @@ async def consume_sdk_stream(
|
||||
call_id = getattr(event, "call_id", None)
|
||||
if call_id and call_id in tool_call_buffers:
|
||||
delta = getattr(event, "delta", "") or ""
|
||||
tool_call_buffers[call_id]["arguments"] += delta
|
||||
current = tool_call_buffers[call_id].get("arguments")
|
||||
if not isinstance(current, str):
|
||||
current = ""
|
||||
tool_call_buffers[call_id]["arguments"] = current + delta
|
||||
if on_tool_call_delta and delta:
|
||||
await on_tool_call_delta({
|
||||
"call_id": str(call_id),
|
||||
@ -367,14 +378,14 @@ async def consume_sdk_stream(
|
||||
elif event_type == "response.function_call_arguments.done":
|
||||
call_id = getattr(event, "call_id", None)
|
||||
if call_id and call_id in tool_call_buffers:
|
||||
arguments = getattr(event, "arguments", "") or ""
|
||||
arguments = getattr(event, "arguments", None)
|
||||
tool_call_buffers[call_id]["arguments"] = arguments
|
||||
if on_tool_call_delta:
|
||||
tool_call_args_emitted.add(str(call_id))
|
||||
await on_tool_call_delta({
|
||||
"call_id": str(call_id),
|
||||
"name": str(tool_call_buffers[call_id].get("name") or ""),
|
||||
"arguments": str(arguments),
|
||||
"arguments": "" if arguments is None else str(arguments),
|
||||
})
|
||||
elif event_type == "response.output_item.done":
|
||||
item = getattr(event, "item", None)
|
||||
@ -383,7 +394,10 @@ async def consume_sdk_stream(
|
||||
if not call_id:
|
||||
continue
|
||||
buf = tool_call_buffers.get(call_id) or {}
|
||||
args_raw = buf.get("arguments") or getattr(item, "arguments", None) or "{}"
|
||||
args_raw = _tool_arguments_source(
|
||||
buf.get("arguments"),
|
||||
getattr(item, "arguments", None),
|
||||
)
|
||||
if on_tool_call_delta and str(call_id) not in tool_call_args_emitted:
|
||||
tool_call_args_emitted.add(str(call_id))
|
||||
await on_tool_call_delta({
|
||||
@ -391,17 +405,10 @@ async def consume_sdk_stream(
|
||||
"name": str(buf.get("name") or getattr(item, "name", None) or ""),
|
||||
"arguments": str(args_raw),
|
||||
})
|
||||
try:
|
||||
args = json.loads(args_raw)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to parse tool call arguments for '{}': {}",
|
||||
buf.get("name") or getattr(item, "name", None),
|
||||
str(args_raw)[:200],
|
||||
)
|
||||
args = json_repair.loads(args_raw)
|
||||
if not isinstance(args, dict):
|
||||
args = {"raw": args_raw}
|
||||
args = _parse_tool_call_arguments(
|
||||
args_raw,
|
||||
buf.get("name") or getattr(item, "name", None),
|
||||
)
|
||||
tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=f"{call_id}|{buf.get('id') or getattr(item, 'id', None) or 'fc_0'}",
|
||||
|
||||
@ -49,13 +49,18 @@ async def invoke_file_edit_progress(
|
||||
await on_progress("", file_edit_events=file_edit_events)
|
||||
|
||||
|
||||
def _tool_event_arguments(tool_call: Any) -> dict[str, Any]:
|
||||
arguments = getattr(tool_call, "arguments", {}) or {}
|
||||
return arguments if isinstance(arguments, dict) else {}
|
||||
|
||||
|
||||
def build_tool_event_start_payload(tool_call: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"version": 1,
|
||||
"phase": "start",
|
||||
"call_id": str(getattr(tool_call, "id", "") or ""),
|
||||
"name": getattr(tool_call, "name", ""),
|
||||
"arguments": getattr(tool_call, "arguments", {}) or {},
|
||||
"arguments": _tool_event_arguments(tool_call),
|
||||
"result": None,
|
||||
"error": None,
|
||||
"files": [],
|
||||
@ -86,7 +91,7 @@ def build_tool_event_finish_payloads(context: AgentHookContext) -> list[dict[str
|
||||
"phase": phase,
|
||||
"call_id": str(getattr(tool_call, "id", "") or ""),
|
||||
"name": getattr(tool_call, "name", ""),
|
||||
"arguments": getattr(tool_call, "arguments", {}) or {},
|
||||
"arguments": _tool_event_arguments(tool_call),
|
||||
"result": result if phase == "end" else None,
|
||||
"error": None,
|
||||
"files": files,
|
||||
|
||||
@ -75,8 +75,10 @@ def build_goal_continue_message(custom: str | None = None) -> dict[str, str]:
|
||||
return {"role": "user", "content": custom or SUSTAINED_GOAL_CONTINUE_PROMPT}
|
||||
|
||||
|
||||
def external_lookup_signature(tool_name: str, arguments: dict[str, Any]) -> str | None:
|
||||
def external_lookup_signature(tool_name: str, arguments: Any) -> str | None:
|
||||
"""Stable signature for repeated external lookups we want to throttle."""
|
||||
if not isinstance(arguments, dict):
|
||||
return None
|
||||
if tool_name == "web_fetch":
|
||||
url = str(arguments.get("url") or "").strip()
|
||||
if url:
|
||||
@ -90,7 +92,7 @@ def external_lookup_signature(tool_name: str, arguments: dict[str, Any]) -> str
|
||||
|
||||
def repeated_external_lookup_error(
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
arguments: Any,
|
||||
seen_counts: dict[str, int],
|
||||
) -> str | None:
|
||||
"""Block repeated external lookups after a small retry budget."""
|
||||
@ -119,9 +121,11 @@ _OUTSIDE_PATH_PATTERN = re.compile(r"(?:^|[\s|>'\"])((?:/[^\s\"'>;|<]+)|(?:~[^\s
|
||||
|
||||
def workspace_violation_signature(
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
arguments: Any,
|
||||
) -> str | None:
|
||||
"""Return a stable cross-tool signature for the outside-workspace target."""
|
||||
if not isinstance(arguments, dict):
|
||||
return None
|
||||
for key in ("path", "file_path", "target", "source", "destination"):
|
||||
val = arguments.get(key)
|
||||
if isinstance(val, str) and val.strip():
|
||||
@ -151,7 +155,7 @@ def _normalize_violation_target(raw: str) -> str:
|
||||
|
||||
def repeated_workspace_violation_error(
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
arguments: Any,
|
||||
seen_counts: dict[str, int],
|
||||
) -> str | None:
|
||||
"""Return an escalated error after repeated bypass attempts."""
|
||||
|
||||
@ -3,17 +3,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
from nanobot.providers.openai_responses.parsing import parse_response_output
|
||||
|
||||
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
|
||||
class _DelayTool(Tool):
|
||||
def __init__(
|
||||
self,
|
||||
@ -57,10 +61,45 @@ class _DelayTool(Tool):
|
||||
return self._name
|
||||
|
||||
|
||||
async def _run_optional_tool_response(response: LLMResponse):
    """Run the agent against a scripted provider with ``optional_tool`` registered.

    The fake provider returns ``response`` on its first call and a plain
    ``"done"`` reply with no tool calls on every later call. The run is
    limited to two iterations.

    Returns:
        ``(result, shared_events)``: the runner result and the list that was
        handed to the registered ``_DelayTool`` as ``shared_events``.
    """
    provider = MagicMock()
    calls = {"n": 0}

    async def chat_with_retry(*, messages, **kwargs):
        # Scripted: only the first call yields the response under test.
        calls["n"] += 1
        if calls["n"] == 1:
            return response
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry

    shared_events: list[str] = []
    tools = ToolRegistry()
    tools.register(_DelayTool(
        "optional_tool",
        delay=0,
        read_only=True,
        shared_events=shared_events,
    ))

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "try optional"}],
        tools=tools,
        model="test-model",
        max_iterations=2,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(provider).run(spec)
    return result, shared_events
|
||||
|
||||
|
||||
def _tool_message(result, tool_call_id: str) -> dict:
|
||||
return [
|
||||
msg for msg in result.messages
|
||||
if msg.get("role") == "tool" and msg.get("tool_call_id") == tool_call_id
|
||||
][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_batches_read_only_tools_before_exclusive_work():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
tools = ToolRegistry()
|
||||
shared_events: list[str] = []
|
||||
read_a = _DelayTool("read_a", delay=0.05, read_only=True, shared_events=shared_events)
|
||||
@ -98,8 +137,6 @@ async def test_runner_batches_read_only_tools_before_exclusive_work():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_does_not_batch_exclusive_read_only_tools():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
tools = ToolRegistry()
|
||||
shared_events: list[str] = []
|
||||
read_a = _DelayTool("read_a", delay=0.03, read_only=True, shared_events=shared_events)
|
||||
@ -140,9 +177,151 @@ async def test_runner_does_not_batch_exclusive_read_only_tools():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_rejects_near_miss_tool_name_without_executing():
    """A near-miss tool name is reported back, never silently executed.

    The model requests ``readFile`` while only ``read_file`` is registered.
    The run must finish with no tool executed, keep the original name in the
    assistant message (both in the result and in the history replayed to the
    provider), and answer with a "not found / did you mean" tool message.
    """
    provider = MagicMock()
    call_count = {"n": 0}
    captured_second_call: list[dict] = []

    async def chat_with_retry(*, messages, **kwargs):
        call_count["n"] += 1
        if call_count["n"] == 1:
            return LLMResponse(
                content="",
                tool_calls=[
                    ToolCallRequest(
                        id="call_1",
                        name="readFile",
                        arguments={"path": "notes.txt"},
                    )
                ],
                finish_reason="tool_calls",
                usage={},
            )
        # Keep the history sent on the follow-up call for the replay assertion.
        captured_second_call[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry
    tools = ToolRegistry()
    shared_events: list[str] = []
    tools.register(_DelayTool(
        "read_file",
        delay=0,
        read_only=True,
        shared_events=shared_events,
    ))

    runner = AgentRunner(provider)
    result = await runner.run(AgentRunSpec(
        initial_messages=[{"role": "user", "content": "read notes"}],
        tools=tools,
        model="test-model",
        max_iterations=2,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    ))

    assert result.final_content == "done"
    assert result.tools_used == []
    assert shared_events == []
    assistant_message = [
        msg for msg in result.messages
        if msg.get("role") == "assistant" and msg.get("tool_calls")
    ][0]
    assert assistant_message["tool_calls"][0]["function"]["name"] == "readFile"
    tool_message = [
        msg for msg in result.messages
        if msg.get("role") == "tool" and msg.get("tool_call_id") == "call_1"
    ][0]
    assert tool_message["name"] == "readFile"
    assert "Tool 'readFile' not found" in tool_message["content"]
    assert "Did you mean 'read_file'?" in tool_message["content"]
    replayed_assistant = [
        msg for msg in captured_second_call
        if msg.get("role") == "assistant" and msg.get("tool_calls")
    ][0]
    assert replayed_assistant["tool_calls"][0]["function"]["name"] == "readFile"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("arguments", ['{path:"notes.txt"}', "null"])
async def test_runner_rejects_openai_compat_invalid_arguments_without_executing(arguments):
    """Non-object arguments parsed by the OpenAI-compat provider are rejected.

    The parsed tool call keeps the raw argument string, the tool is not
    executed, and the tool message explains that parameters must be a JSON
    object.
    """
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        parsed = OpenAICompatProvider()._parse({
            "choices": [{
                "message": {
                    "tool_calls": [{
                        "id": "call_1",
                        "type": "function",
                        "function": {
                            "name": "optional_tool",
                            "arguments": arguments,
                        },
                    }],
                },
                "finish_reason": "tool_calls",
            }],
            "usage": {},
        })

    result, shared_events = await _run_optional_tool_response(parsed)

    assert result.final_content == "done"
    assert parsed.tool_calls[0].arguments == arguments
    assert result.tools_used == []
    assert shared_events == []
    tool_message = _tool_message(result, "call_1")
    assert "parameters must be a JSON object" in tool_message["content"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_rejects_openai_responses_malformed_arguments_without_executing():
    """Malformed JSON arguments from the Responses parser are rejected.

    The raw string ``"{bad"`` is preserved on the parsed call, nothing is
    executed, and the tool message (id ``call_1|fc_1``) asks for a JSON object.
    """
    parsed = parse_response_output({
        "output": [{
            "type": "function_call",
            "call_id": "call_1",
            "id": "fc_1",
            "name": "optional_tool",
            "arguments": "{bad",
        }],
        "status": "completed",
        "usage": {},
    })

    result, shared_events = await _run_optional_tool_response(parsed)

    assert result.final_content == "done"
    assert parsed.tool_calls[0].arguments == "{bad"
    assert result.tools_used == []
    assert shared_events == []
    tool_message = _tool_message(result, "call_1|fc_1")
    assert "parameters must be a JSON object" in tool_message["content"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_rejects_openai_responses_array_arguments_without_executing():
    """Array-shaped arguments from the Responses parser are rejected.

    The empty list is preserved on the parsed call, nothing is executed, and
    the tool message (id ``call_1|fc_1``) asks for a JSON object.
    """
    parsed = parse_response_output({
        "output": [{
            "type": "function_call",
            "call_id": "call_1",
            "id": "fc_1",
            "name": "optional_tool",
            "arguments": [],
        }],
        "status": "completed",
        "usage": {},
    })

    result, shared_events = await _run_optional_tool_response(parsed)

    assert result.final_content == "done"
    assert parsed.tool_calls[0].arguments == []
    assert result.tools_used == []
    assert shared_events == []
    tool_message = _tool_message(result, "call_1|fc_1")
    assert "parameters must be a JSON object" in tool_message["content"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_blocks_repeated_external_fetches():
|
||||
provider = MagicMock()
|
||||
captured_final_call: list[dict] = []
|
||||
call_count = {"n": 0}
|
||||
|
||||
@ -80,3 +80,17 @@ def test_convert_user_content_coerces_mixed_typeless():
|
||||
])
|
||||
assert result[0] == {"type": "text", "text": "42"}
|
||||
assert result[1] == {"type": "text", "text": str({"key": "val"})}
|
||||
|
||||
|
||||
def test_convert_assistant_message_repairs_history_tool_arguments():
    """Object-like but invalid JSON in history becomes a ``tool_use`` input dict."""
    blocks = AnthropicProvider._assistant_blocks({
        "role": "assistant",
        "content": None,
        "tool_calls": [{
            "id": "toolu_1",
            "function": {"name": "read_file", "arguments": '{path:"foo.txt"}'},
        }],
    })

    assert blocks[0]["type"] == "tool_use"
    assert blocks[0]["input"] == {"path": "foo.txt"}
|
||||
|
||||
@ -161,6 +161,16 @@ def test_build_kwargs_converts_messages_tools_and_tool_results() -> None:
|
||||
assert kwargs["toolConfig"]["toolChoice"] == {"any": {}}
|
||||
|
||||
|
||||
def test_tool_use_block_repairs_history_tool_arguments() -> None:
    """Object-like but invalid JSON in history becomes the ``toolUse`` input dict."""
    block = BedrockProvider._tool_use_block({
        "id": "toolu_1",
        "function": {"name": "read_file", "arguments": '{path:"foo.txt"}'},
    })

    assert block is not None
    assert block["toolUse"]["input"] == {"path": "foo.txt"}
|
||||
|
||||
|
||||
def test_build_kwargs_keeps_tool_config_for_historical_tool_blocks_without_tools() -> None:
|
||||
provider = BedrockProvider(region="us-east-1", client=FakeClient())
|
||||
messages = [
|
||||
|
||||
@ -54,6 +54,15 @@ def _fake_tool_call_response() -> SimpleNamespace:
|
||||
return SimpleNamespace(choices=[choice], usage=usage)
|
||||
|
||||
|
||||
def _fake_tool_call_response_with_arguments(arguments) -> SimpleNamespace:
|
||||
"""Build a minimal chat response with caller-supplied tool arguments."""
|
||||
function = SimpleNamespace(name="optional_tool", arguments=arguments)
|
||||
tool_call = SimpleNamespace(id="call_123", type="function", function=function)
|
||||
message = SimpleNamespace(content=None, tool_calls=[tool_call], reasoning_content=None)
|
||||
choice = SimpleNamespace(message=message, finish_reason="tool_calls")
|
||||
return SimpleNamespace(choices=[choice], usage=SimpleNamespace())
|
||||
|
||||
|
||||
def _fake_responses_response(content: str = "ok") -> MagicMock:
|
||||
"""Build a minimal Responses API response object."""
|
||||
resp = MagicMock()
|
||||
@ -611,6 +620,24 @@ async def test_openai_compat_preserves_extra_content_on_tool_calls() -> None:
|
||||
assert serialized["function"]["provider_specific_fields"] == {"inner": "value"}
|
||||
|
||||
|
||||
def test_openai_compat_parse_preserves_malformed_tool_arguments() -> None:
    """``_parse`` leaves a malformed argument string exactly as received."""
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        provider = OpenAICompatProvider()

    result = provider._parse(_fake_tool_call_response_with_arguments('{path:"foo.txt"}'))

    assert result.tool_calls[0].arguments == '{path:"foo.txt"}'
|
||||
|
||||
|
||||
def test_openai_compat_parse_preserves_array_tool_arguments() -> None:
    """``_parse`` decodes a JSON array string to a list without coercing it to a dict."""
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        provider = OpenAICompatProvider()

    result = provider._parse(_fake_tool_call_response_with_arguments('["foo.txt"]'))

    assert result.tool_calls[0].arguments == ["foo.txt"]
|
||||
|
||||
|
||||
def test_openai_model_passthrough() -> None:
|
||||
"""OpenAI models pass through unchanged."""
|
||||
spec = find_by_name("openai")
|
||||
@ -1110,7 +1137,7 @@ def test_openai_compat_stringifies_dict_tool_arguments() -> None:
|
||||
assert sanitized[1]["tool_calls"][0]["function"]["arguments"] == '{"cmd": "ls -la"}'
|
||||
|
||||
|
||||
def test_openai_compat_repairs_non_json_tool_arguments_string() -> None:
|
||||
def test_openai_compat_repairs_object_like_history_tool_arguments_string() -> None:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
|
||||
@ -155,6 +155,19 @@ class TestConvertMessages:
|
||||
assert items[0]["call_id"] == "call_abc"
|
||||
assert items[0]["id"] == "fc_1"
|
||||
assert items[0]["name"] == "get_weather"
|
||||
assert items[0]["arguments"] == '{"city": "SF"}'
|
||||
|
||||
def test_assistant_tool_call_history_repairs_malformed_arguments(self):
    """Replayed history arguments are emitted as valid JSON for an object.

    Method of ``TestConvertMessages`` (per the enclosing hunk header).
    """
    _, items = convert_messages([{
        "role": "assistant",
        "content": None,
        "tool_calls": [{
            "id": "call_abc|fc_1",
            "function": {"name": "read_file", "arguments": '{path:"foo.txt"}'},
        }],
    }])

    assert json.loads(items[0]["arguments"]) == {"path": "foo.txt"}
|
||||
|
||||
def test_duplicate_response_item_ids_are_made_unique(self):
|
||||
"""Codex rejects replayed Responses input items with duplicate ids."""
|
||||
@ -367,7 +380,7 @@ class TestParseResponseOutput:
|
||||
assert result.tool_calls[0].id == "call_1|fc_1"
|
||||
|
||||
def test_malformed_tool_arguments_logged(self):
|
||||
"""Malformed JSON arguments should log a warning and fallback."""
|
||||
"""Malformed JSON arguments should log a warning and remain non-object."""
|
||||
resp = {
|
||||
"output": [{
|
||||
"type": "function_call",
|
||||
@ -378,10 +391,29 @@ class TestParseResponseOutput:
|
||||
}
|
||||
with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger:
|
||||
result = parse_response_output(resp)
|
||||
assert result.tool_calls[0].arguments == {"raw": "{bad json"}
|
||||
assert result.tool_calls[0].arguments == "{bad json"
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args)
|
||||
|
||||
@pytest.mark.parametrize("arguments", [[], False, 0])
def test_falsy_non_object_tool_arguments_preserved(self, arguments):
    """Falsy non-object arguments survive ``parse_response_output`` unchanged.

    Both value and exact type are checked, so ``False`` and ``0`` cannot be
    confused with each other or replaced by a default.
    """
    resp = {
        "output": [{
            "type": "function_call",
            "call_id": "c1",
            "id": "fc1",
            "name": "f",
            "arguments": arguments,
        }],
        "status": "completed",
        "usage": {},
    }

    result = parse_response_output(resp)

    assert result.tool_calls[0].arguments == arguments
    # `False == 0` in Python, so equality alone would not catch a type swap.
    assert type(result.tool_calls[0].arguments) is type(arguments)
|
||||
|
||||
def test_reasoning_content_extracted(self):
|
||||
resp = {
|
||||
"output": [
|
||||
@ -611,6 +643,38 @@ class TestConsumeSse:
|
||||
},
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("arguments", [[], False, 0])
async def test_falsy_non_object_tool_arguments_preserved(self, arguments):
    """Falsy non-object arguments on the ``done`` SSE item are kept as-is.

    The ``added`` event carries an empty argument string; the ``done`` event
    carries the parametrized value, which must reach the tool call with the
    same value and type.
    """
    response = _SseResponse([
        {
            "type": "response.output_item.added",
            "item": {
                "type": "function_call",
                "call_id": "c1",
                "id": "fc1",
                "name": "f",
                "arguments": "",
            },
        },
        {
            "type": "response.output_item.done",
            "item": {
                "type": "function_call",
                "call_id": "c1",
                "id": "fc1",
                "name": "f",
                "arguments": arguments,
            },
        },
        {"type": "response.completed", "response": {"status": "completed"}},
    ])

    _, tool_calls, _, _, _ = await consume_sse_with_reasoning(response)

    assert tool_calls[0].arguments == arguments
    # `False == 0` in Python, so equality alone would not catch a type swap.
    assert type(tool_calls[0].arguments) is type(arguments)
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# parsing - consume_sdk_stream
|
||||
@ -764,6 +828,28 @@ class TestConsumeSdkStream:
|
||||
},
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("arguments", [[], False, 0])
async def test_falsy_non_object_tool_arguments_preserved(self, arguments):
    """Falsy non-object arguments on the ``done`` SDK event are kept as-is.

    Three mock events (item added, item done, response completed) are
    streamed; the resulting tool call must carry the parametrized value with
    the same type.
    """
    item_added = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="")
    # `name` is a MagicMock constructor argument, so the attribute is set afterwards.
    item_added.name = "f"
    ev1 = MagicMock(type="response.output_item.added", item=item_added)
    item_done = MagicMock(type="function_call", call_id="c1", id="fc1")
    item_done.name = "f"
    item_done.arguments = arguments
    ev2 = MagicMock(type="response.output_item.done", item=item_done)
    resp_obj = MagicMock(status="completed", usage=None, output=[])
    ev3 = MagicMock(type="response.completed", response=resp_obj)

    async def stream():
        for e in [ev1, ev2, ev3]:
            yield e

    _, tool_calls, _, _, _ = await consume_sdk_stream(stream())

    assert tool_calls[0].arguments == arguments
    # `False == 0` in Python, so equality alone would not catch a type swap.
    assert type(tool_calls[0].arguments) is type(arguments)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_extracted(self):
|
||||
usage_obj = MagicMock(input_tokens=10, output_tokens=5, total_tokens=15)
|
||||
@ -811,7 +897,7 @@ class TestConsumeSdkStream:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_malformed_tool_args_logged(self):
|
||||
"""Malformed JSON in streaming tool args should log a warning."""
|
||||
"""Malformed JSON in streaming tool args should log a warning and remain non-object."""
|
||||
item_added = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="")
|
||||
item_added.name = "f"
|
||||
ev1 = MagicMock(type="response.output_item.added", item=item_added)
|
||||
@ -828,6 +914,6 @@ class TestConsumeSdkStream:
|
||||
|
||||
with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger:
|
||||
_, tool_calls, _, _, _ = await consume_sdk_stream(stream())
|
||||
assert tool_calls[0].arguments == {"raw": "{bad"}
|
||||
assert tool_calls[0].arguments == "{bad"
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args)
|
||||
|
||||
30
tests/providers/test_provider_tool_arguments.py
Normal file
30
tests/providers/test_provider_tool_arguments.py
Normal file
@ -0,0 +1,30 @@
|
||||
"""Shared tool-argument parsing policy tests."""
|
||||
|
||||
from nanobot.providers.base import (
|
||||
parse_tool_arguments,
|
||||
tool_arguments_json_for_replay,
|
||||
tool_arguments_object_for_replay,
|
||||
)
|
||||
|
||||
|
||||
def test_parse_tool_arguments_preserves_malformed_executable_arguments() -> None:
    """A malformed argument string is returned unchanged rather than repaired."""
    assert parse_tool_arguments('{path:"foo.txt"}') == '{path:"foo.txt"}'
|
||||
|
||||
|
||||
def test_parse_tool_arguments_preserves_non_object_executable_arguments() -> None:
    """Valid non-object JSON is decoded but never coerced to a dict.

    An array decodes to a list and ``"false"`` to ``False``; the string
    ``"null"`` is expected back as the literal string ``"null"``.
    """
    assert parse_tool_arguments('["foo.txt"]') == ["foo.txt"]
    assert parse_tool_arguments("false") is False
    assert parse_tool_arguments("null") == "null"
|
||||
|
||||
|
||||
def test_tool_arguments_object_for_replay_repairs_object_like_history_arguments() -> None:
    """Object-like but invalid JSON is repaired into a dict for history replay."""
    assert tool_arguments_object_for_replay('{path:"foo.txt"}') == {"path": "foo.txt"}
|
||||
|
||||
|
||||
def test_tool_arguments_object_for_replay_keeps_history_object_shaped() -> None:
    """Every non-object input, encoded or already decoded, replays as ``{}``."""
    for arguments in ['["foo.txt"]', "false", "null", "0", ["foo.txt"], False, None, 0]:
        assert tool_arguments_object_for_replay(arguments) == {}
|
||||
|
||||
|
||||
def test_tool_arguments_json_for_replay_returns_object_string() -> None:
    """The JSON replay helper returns the repaired object serialized as a string."""
    assert tool_arguments_json_for_replay('{path:"foo.txt"}') == '{"path": "foo.txt"}'
|
||||
@ -7,8 +7,9 @@ from nanobot.agent.tools.registry import ToolRegistry
|
||||
|
||||
|
||||
class _FakeTool(Tool):
|
||||
def __init__(self, name: str):
|
||||
def __init__(self, name: str, schema: dict[str, Any] | None = None):
|
||||
self._name = name
|
||||
self._schema = schema
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@ -20,7 +21,7 @@ class _FakeTool(Tool):
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {"type": "object", "properties": {}}
|
||||
return self._schema or {"type": "object", "properties": {}}
|
||||
|
||||
async def execute(self, **kwargs: Any) -> Any:
|
||||
return kwargs
|
||||
@ -34,6 +35,13 @@ def _tool_names(definitions: list[dict[str, Any]]) -> list[str]:
|
||||
return names
|
||||
|
||||
|
||||
def _registry_with_names(names: list[str]) -> ToolRegistry:
    """Return a fresh ``ToolRegistry`` with one ``_FakeTool`` per name, in order."""
    registry = ToolRegistry()
    for name in names:
        registry.register(_FakeTool(name))
    return registry
|
||||
|
||||
|
||||
def test_get_definitions_orders_builtins_then_mcp_tools() -> None:
|
||||
registry = ToolRegistry()
|
||||
registry.register(_FakeTool("mcp_git_status"))
|
||||
@ -49,17 +57,167 @@ def test_get_definitions_orders_builtins_then_mcp_tools() -> None:
|
||||
]
|
||||
|
||||
|
||||
def test_prepare_call_rejects_near_miss_tool_name_with_suggestion() -> None:
    """An unknown near-miss name yields no tool, untouched params and a hint.

    The error must name the missing tool, suggest ``read_file`` and state
    that names must match exactly.
    """
    registry = ToolRegistry()
    registry.register(_FakeTool("read_file"))

    tool, params, error = registry.prepare_call("readFile", {"path": "foo.txt"})

    assert tool is None
    assert params == {"path": "foo.txt"}
    assert error is not None
    assert "Tool 'readFile' not found" in error
    assert "Did you mean 'read_file'?" in error
    assert "must match exactly" in error
|
||||
|
||||
|
||||
def test_suggest_name_handles_canonical_tool_name_variants() -> None:
    """Case and separator variants of ``read_file`` all suggest ``read_file``."""
    registry = _registry_with_names(["read_file"])
    expected = {
        "readFile": "read_file",
        "read-file": "read_file",
        "READ_FILE": "read_file",
        "read file": "read_file",
        "readfile": "read_file",
    }

    # Compare whole mappings so a failure shows every variant at once.
    assert {name: registry._suggest_name(name) for name in expected} == expected
|
||||
|
||||
|
||||
def test_suggest_name_suppresses_low_confidence_and_non_unique_matches() -> None:
    """No suggestion is made for partial names or when two tools tie.

    Empty, unrelated, prefix/suffix-only and extended names return ``None``;
    so does ``readfile`` when both ``read_file`` and ``readFile`` exist.
    """
    registry = _registry_with_names(["read_file", "write_file"])

    for name in ["", "foo", "read", "file", "readfil", "read_file_tool"]:
        assert registry._suggest_name(name) is None

    ambiguous = _registry_with_names(["read_file", "readFile"])
    assert ambiguous._suggest_name("readfile") is None
|
||||
|
||||
|
||||
def test_suggest_name_updates_after_register_and_unregister() -> None:
    """Suggestions track registry changes.

    Registering a second tool with the same normalized name makes the match
    ambiguous (no suggestion); unregistering the first leaves the second as
    the unique suggestion.
    """
    registry = _registry_with_names(["read_file"])

    assert registry._suggest_name("readFile") == "read_file"

    registry.register(_FakeTool("readFile"))
    assert registry._suggest_name("read-file") is None

    registry.unregister("read_file")
    assert registry._suggest_name("read-file") == "readFile"
|
||||
|
||||
|
||||
def test_prepare_call_read_file_rejects_non_object_params_with_actionable_hint() -> None:
    """List-shaped params resolve the tool but fail with an actionable error.

    The tool is found, the params are returned untouched, and the error
    tells the caller to pass a JSON object with named parameters.
    """
    registry = ToolRegistry()
    registry.register(_FakeTool("read_file"))

    tool, params, error = registry.prepare_call("read_file", ["foo.txt"])

    # The tool name is valid, so the tool is returned alongside the error.
    assert tool is not None
    assert params == ["foo.txt"]
    assert error is not None
    assert "must be a JSON object" in error
    assert "Use named parameters" in error
    assert 'tool_name(param1="value1", param2="value2")' in error
    assert "matching the tool schema" in error
|
||||
|
||||
|
||||
def test_prepare_call_parses_json_string_arguments() -> None:
    """A valid JSON object string is decoded into a params dict with no error."""
    registry = ToolRegistry()
    registry.register(_FakeTool("read_file"))

    tool, params, error = registry.prepare_call("read_file", '{"path":"foo.txt"}')

    assert tool is not None
    assert params == {"path": "foo.txt"}
    assert error is None
|
||||
|
||||
|
||||
def test_prepare_call_rejects_malformed_json_string_arguments() -> None:
    """A malformed JSON string is returned as-is with a JSON-object error."""
    registry = ToolRegistry()
    registry.register(_FakeTool("read_file"))

    tool, params, error = registry.prepare_call("read_file", '{path:"foo.txt"}')

    assert tool is not None
    assert params == '{path:"foo.txt"}'
    assert error is not None
    assert "parameters must be a JSON object" in error
|
||||
|
||||
|
||||
def test_prepare_call_rejects_scalar_for_single_required_parameter() -> None:
    """A bare string is not mapped onto a tool's only required parameter."""
    registry = ToolRegistry()
    registry.register(_FakeTool(
        "web_fetch",
        {
            "type": "object",
            "properties": {"url": {"type": "string"}},
            "required": ["url"],
        },
    ))

    tool, params, error = registry.prepare_call("web_fetch", "https://example.com")

    assert tool is not None
    assert params == "https://example.com"
    assert error is not None
    assert "parameters must be a JSON object" in error
|
||||
|
||||
|
||||
def test_prepare_call_rejects_unquoted_scalar_strings_before_schema_cast() -> None:
    """The string ``"true"`` is rejected and returned unchanged as params."""
    registry = ToolRegistry()
    registry.register(_FakeTool(
        "message",
        {
            "type": "object",
            "properties": {"content": {"type": "string"}},
            "required": ["content"],
        },
    ))

    tool, params, error = registry.prepare_call("message", "true")

    assert tool is not None
    assert params == "true"
    assert error is not None
    assert "parameters must be a JSON object" in error
|
||||
|
||||
|
||||
def test_prepare_call_unwraps_arguments_payload() -> None:
    """A dict wrapping a JSON string under ``"arguments"`` is unwrapped to the object."""
    registry = ToolRegistry()
    registry.register(_FakeTool(
        "read_file",
        {
            "type": "object",
            "properties": {"path": {"type": "string"}},
            "required": ["path"],
        },
    ))

    tool, params, error = registry.prepare_call(
        "read_file",
        {"arguments": '{"path":"foo.txt"}'},
    )

    assert tool is not None
    assert params == {"path": "foo.txt"}
    assert error is None
|
||||
|
||||
|
||||
def test_prepare_call_treats_none_arguments_as_empty_object() -> None:
    """Python ``None`` means "no arguments"; the JSON text ``"null"`` does not.

    ``None`` is accepted as ``{}`` with no error, while the string ``"null"``
    is returned unchanged together with a JSON-object error.
    """
    registry = ToolRegistry()
    registry.register(_FakeTool("list_exec_sessions"))

    tool, params, error = registry.prepare_call("list_exec_sessions", None)

    assert tool is not None
    assert params == {}
    assert error is None

    tool, params, error = registry.prepare_call("list_exec_sessions", "null")

    assert tool is not None
    assert params == "null"
    assert error is not None
    assert "parameters must be a JSON object" in error
|
||||
|
||||
|
||||
def test_prepare_call_other_tools_keep_generic_object_validation() -> None:
|
||||
@ -70,7 +228,11 @@ def test_prepare_call_other_tools_keep_generic_object_validation() -> None:
|
||||
|
||||
assert tool is not None
|
||||
assert params == ["TODO"]
|
||||
assert error == "Error: Invalid parameters for tool 'grep': parameters must be an object, got list"
|
||||
assert error == (
|
||||
"Error: Tool 'grep' parameters must be a JSON object, got list. "
|
||||
'Use named parameters like tool_name(param1="value1", param2="value2") '
|
||||
"matching the tool schema."
|
||||
)
|
||||
|
||||
|
||||
def test_get_definitions_returns_cached_result() -> None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user