mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 15:24:06 +00:00
Improve tool call validation strictness (#4190)
* Improve tool call validation strictness

  Reject near-miss tool names without executing suggested tools.
  Require object-shaped tool parameters while preserving only lossless JSON wire-shape normalization.
* Tighten tool call argument validation
* Simplify tool argument validation tests
* Improve tool name suggestions
* Simplify tool suggestion helpers
* Limit tool suggestions to canonical matches
* Allow repair only for tool history replay
* Clarify non-object tool argument errors
* Inline replay tool argument normalization
* Track only successful tool executions
* Reject JSON null tool arguments
This commit is contained in:
parent
f3eb2aa08b
commit
0a396aa6e2
@ -399,7 +399,6 @@ class AgentRunner:
|
|||||||
thinking_blocks=response.thinking_blocks,
|
thinking_blocks=response.thinking_blocks,
|
||||||
)
|
)
|
||||||
messages.append(assistant_message)
|
messages.append(assistant_message)
|
||||||
tools_used.extend(tc.name for tc in response.tool_calls)
|
|
||||||
await self._emit_checkpoint(
|
await self._emit_checkpoint(
|
||||||
spec,
|
spec,
|
||||||
{
|
{
|
||||||
@ -421,6 +420,11 @@ class AgentRunner:
|
|||||||
workspace_violation_counts,
|
workspace_violation_counts,
|
||||||
)
|
)
|
||||||
tool_events.extend(new_events)
|
tool_events.extend(new_events)
|
||||||
|
tools_used.extend(
|
||||||
|
tool_call.name
|
||||||
|
for tool_call, event in zip(response.tool_calls, new_events)
|
||||||
|
if event.get("status") == "ok"
|
||||||
|
)
|
||||||
context.tool_results = list(results)
|
context.tool_results = list(results)
|
||||||
context.tool_events = list(new_events)
|
context.tool_events = list(new_events)
|
||||||
completed_tool_results: list[dict[str, Any]] = []
|
completed_tool_results: list[dict[str, Any]] = []
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
"""Tool registry for dynamic tool management."""
|
"""Tool registry for dynamic tool management."""
|
||||||
|
|
||||||
|
import json
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from nanobot.agent.tools.base import Tool
|
from nanobot.agent.tools.base import Tool
|
||||||
@ -30,6 +31,24 @@ class ToolRegistry:
|
|||||||
"""Get a tool by name."""
|
"""Get a tool by name."""
|
||||||
return self._tools.get(name)
|
return self._tools.get(name)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _lookup_key(name: str) -> str:
|
||||||
|
"""Normalize names for suggestions only; never for execution."""
|
||||||
|
return "".join(ch.lower() for ch in name if ch.isalnum())
|
||||||
|
|
||||||
|
def _suggest_name(self, name: str) -> str | None:
|
||||||
|
key = self._lookup_key(str(name or ""))
|
||||||
|
if not key:
|
||||||
|
return None
|
||||||
|
matches = [
|
||||||
|
registered
|
||||||
|
for registered in self._tools
|
||||||
|
if self._lookup_key(registered) == key
|
||||||
|
]
|
||||||
|
if len(matches) == 1:
|
||||||
|
return matches[0]
|
||||||
|
return None
|
||||||
|
|
||||||
def has(self, name: str) -> bool:
|
def has(self, name: str) -> bool:
|
||||||
"""Check if a tool is registered."""
|
"""Check if a tool is registered."""
|
||||||
return name in self._tools
|
return name in self._tools
|
||||||
@ -73,20 +92,23 @@ class ToolRegistry:
|
|||||||
def prepare_call(
|
def prepare_call(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
params: dict[str, Any],
|
params: Any,
|
||||||
) -> tuple[Tool | None, dict[str, Any], str | None]:
|
) -> tuple[Tool | None, Any, str | None]:
|
||||||
"""Resolve, cast, and validate one tool call."""
|
"""Resolve, cast, and validate one tool call."""
|
||||||
# Guard against invalid parameter types (e.g., list instead of dict)
|
|
||||||
if not isinstance(params, dict) and name in ('write_file', 'read_file'):
|
|
||||||
return None, params, (
|
|
||||||
f"Error: Tool '{name}' parameters must be a JSON object, got {type(params).__name__}. "
|
|
||||||
"Use named parameters: tool_name(param1=\"value1\", param2=\"value2\")"
|
|
||||||
)
|
|
||||||
|
|
||||||
tool = self._tools.get(name)
|
tool = self._tools.get(name)
|
||||||
if not tool:
|
if not tool:
|
||||||
|
suggestion = self._suggest_name(str(name))
|
||||||
|
hint = f" Did you mean '{suggestion}'? Tool names must match exactly." if suggestion else ""
|
||||||
return None, params, (
|
return None, params, (
|
||||||
f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
|
f"Error: Tool '{name}' not found.{hint} Available: {', '.join(self.tool_names)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
params = self._coerce_params(tool, params)
|
||||||
|
if not isinstance(params, dict):
|
||||||
|
return tool, params, (
|
||||||
|
f"Error: Tool '{name}' parameters must be a JSON object, got "
|
||||||
|
f"{type(params).__name__}. Use named parameters like "
|
||||||
|
'tool_name(param1="value1", param2="value2") matching the tool schema.'
|
||||||
)
|
)
|
||||||
|
|
||||||
cast_params = tool.cast_params(params)
|
cast_params = tool.cast_params(params)
|
||||||
@ -97,21 +119,56 @@ class ToolRegistry:
|
|||||||
)
|
)
|
||||||
return tool, cast_params, None
|
return tool, cast_params, None
|
||||||
|
|
||||||
async def execute(self, name: str, params: dict[str, Any]) -> Any:
|
@classmethod
|
||||||
|
def _coerce_argument_value(cls, value: Any) -> Any:
|
||||||
|
if value is None:
|
||||||
|
return {}
|
||||||
|
if not isinstance(value, str):
|
||||||
|
return value
|
||||||
|
|
||||||
|
stripped = value.strip()
|
||||||
|
if not stripped:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if not stripped.startswith(("{", "[")):
|
||||||
|
return value
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed = json.loads(stripped)
|
||||||
|
except Exception:
|
||||||
|
return value
|
||||||
|
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
@classmethod
def _coerce_params(cls, tool: Tool, params: Any) -> Any:
    """Normalize raw call parameters for *tool*.

    Runs the value through ``_coerce_argument_value`` first, then lets
    ``_unwrap_arguments_payload`` strip a lone ``{"arguments": ...}`` wrapper
    where that applies. The result may still be a non-dict value.
    """
    normalized = cls._coerce_argument_value(params)
    unwrapped = cls._unwrap_arguments_payload(tool, normalized)
    return unwrapped
|
||||||
|
|
||||||
|
@classmethod
def _unwrap_arguments_payload(cls, tool: Tool, params: Any) -> Any:
    """Unwrap a ``{"arguments": ...}`` envelope around the real parameters.

    The envelope is only removed when *params* is a dict whose sole key is
    ``"arguments"`` and the tool's schema does not itself declare a property
    with that name. The inner value is then normalized with
    ``_coerce_argument_value``. In every other case *params* is returned
    unchanged.
    """
    is_envelope = isinstance(params, dict) and set(params) == {"arguments"}
    if not is_envelope:
        return params

    schema = tool.parameters or {}
    declared = schema.get("properties", {})
    # A tool that genuinely accepts an "arguments" parameter keeps the dict.
    if isinstance(declared, dict) and "arguments" in declared:
        return params

    return cls._coerce_argument_value(params.get("arguments"))
|
||||||
|
|
||||||
|
async def execute(self, name: str, params: Any) -> Any:
|
||||||
"""Execute a tool by name with given parameters."""
|
"""Execute a tool by name with given parameters."""
|
||||||
_HINT = "\n\n[Analyze the error above and try a different approach.]"
|
hint = "\n\n[Analyze the error above and try a different approach.]"
|
||||||
tool, params, error = self.prepare_call(name, params)
|
tool, params, error = self.prepare_call(name, params)
|
||||||
if error:
|
if error:
|
||||||
return error + _HINT
|
return error + hint
|
||||||
|
|
||||||
try:
|
try:
|
||||||
assert tool is not None # guarded by prepare_call()
|
assert tool is not None # guarded by prepare_call()
|
||||||
result = await tool.execute(**params)
|
result = await tool.execute(**params)
|
||||||
if isinstance(result, str) and result.startswith("Error"):
|
if isinstance(result, str) and result.startswith("Error"):
|
||||||
return result + _HINT
|
return result + hint
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error executing {name}: {str(e)}" + _HINT
|
return f"Error executing {name}: {str(e)}" + hint
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tool_names(self) -> list[str]:
|
def tool_names(self) -> list[str]:
|
||||||
|
|||||||
@ -10,9 +10,12 @@ import string
|
|||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import json_repair
|
from nanobot.providers.base import (
|
||||||
|
LLMProvider,
|
||||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
LLMResponse,
|
||||||
|
ToolCallRequest,
|
||||||
|
tool_arguments_object_for_replay,
|
||||||
|
)
|
||||||
|
|
||||||
_ALNUM = string.ascii_letters + string.digits
|
_ALNUM = string.ascii_letters + string.digits
|
||||||
|
|
||||||
@ -207,13 +210,11 @@ class AnthropicProvider(LLMProvider):
|
|||||||
continue
|
continue
|
||||||
func = tc.get("function", {})
|
func = tc.get("function", {})
|
||||||
args = func.get("arguments", "{}")
|
args = func.get("arguments", "{}")
|
||||||
if isinstance(args, str):
|
|
||||||
args = json_repair.loads(args)
|
|
||||||
blocks.append({
|
blocks.append({
|
||||||
"type": "tool_use",
|
"type": "tool_use",
|
||||||
"id": tc.get("id") or _gen_tool_id(),
|
"id": tc.get("id") or _gen_tool_id(),
|
||||||
"name": func.get("name", ""),
|
"name": func.get("name", ""),
|
||||||
"input": args,
|
"input": tool_arguments_object_for_replay(args),
|
||||||
})
|
})
|
||||||
|
|
||||||
return blocks or [{"type": "text", "text": ""}]
|
return blocks or [{"type": "text", "text": ""}]
|
||||||
@ -509,7 +510,7 @@ class AnthropicProvider(LLMProvider):
|
|||||||
tool_calls.append(ToolCallRequest(
|
tool_calls.append(ToolCallRequest(
|
||||||
id=block.id,
|
id=block.id,
|
||||||
name=block.name,
|
name=block.name,
|
||||||
arguments=block.input if isinstance(block.input, dict) else {},
|
arguments=block.input,
|
||||||
))
|
))
|
||||||
elif block.type == "thinking":
|
elif block.type == "thinking":
|
||||||
thinking_blocks.append({
|
thinking_blocks.append({
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from datetime import datetime, timezone
|
|||||||
from email.utils import parsedate_to_datetime
|
from email.utils import parsedate_to_datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import json_repair
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.utils.helpers import image_placeholder_text
|
from nanobot.utils.helpers import image_placeholder_text
|
||||||
@ -21,19 +22,24 @@ class ToolCallRequest:
|
|||||||
"""A tool call request from the LLM."""
|
"""A tool call request from the LLM."""
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
arguments: dict[str, Any]
|
arguments: Any
|
||||||
extra_content: dict[str, Any] | None = None
|
extra_content: dict[str, Any] | None = None
|
||||||
provider_specific_fields: dict[str, Any] | None = None
|
provider_specific_fields: dict[str, Any] | None = None
|
||||||
function_provider_specific_fields: dict[str, Any] | None = None
|
function_provider_specific_fields: dict[str, Any] | None = None
|
||||||
|
|
||||||
def to_openai_tool_call(self) -> dict[str, Any]:
|
def to_openai_tool_call(self) -> dict[str, Any]:
|
||||||
"""Serialize to an OpenAI-style tool_call payload."""
|
"""Serialize to an OpenAI-style tool_call payload."""
|
||||||
|
arguments = (
|
||||||
|
self.arguments
|
||||||
|
if isinstance(self.arguments, str)
|
||||||
|
else json.dumps(self.arguments, ensure_ascii=False)
|
||||||
|
)
|
||||||
tool_call = {
|
tool_call = {
|
||||||
"id": self.id,
|
"id": self.id,
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"arguments": json.dumps(self.arguments, ensure_ascii=False),
|
"arguments": arguments,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if self.extra_content:
|
if self.extra_content:
|
||||||
@ -45,6 +51,62 @@ class ToolCallRequest:
|
|||||||
return tool_call
|
return tool_call
|
||||||
|
|
||||||
|
|
||||||
|
def parse_tool_arguments(arguments: Any) -> Any:
    """Decode provider-supplied tool arguments without inventing parameters.

    ``None`` and blank strings are treated as a call with no arguments and
    become ``{}``. A string holding valid JSON is decoded and returned,
    whatever its shape (object, array, or scalar). Strings that fail to
    decode, and strings that decode to JSON ``null``, are returned exactly as
    received so that ToolRegistry can reject them before execution.
    Non-string values pass through unchanged.
    """
    if arguments is None:
        return {}

    if isinstance(arguments, str):
        text = arguments.strip()
        if not text:
            return {}
        try:
            decoded = json.loads(text)
        except Exception:
            # Malformed JSON: keep the raw text, do not attempt a repair.
            return arguments
        if decoded is not None:
            return decoded
        # JSON null falls through and is reported as the original string.

    return arguments
|
||||||
|
|
||||||
|
|
||||||
|
def tool_arguments_object_for_replay(arguments: Any) -> dict[str, Any]:
    """Coerce stored tool arguments into a dict for history replay only.

    Because this only reshapes conversation history that already exists, it
    is allowed to repair malformed JSON via ``json_repair``. It must not be
    used for freshly generated tool calls that are about to be executed.

    Dicts are returned as-is. Strings are decoded with ``json.loads``, falling
    back to ``json_repair.loads``. Anything that is not a dict, or does not
    decode to one, yields ``{}``.
    """
    if isinstance(arguments, dict):
        return arguments
    if not isinstance(arguments, str):
        # Covers None as well as lists, numbers, and other non-object values.
        return {}

    text = arguments.strip()
    if not text:
        return {}

    try:
        candidate = json.loads(text)
    except Exception:
        # Strict decoding failed; try the lenient repair path before giving up.
        try:
            candidate = json_repair.loads(text)
        except Exception:
            return {}

    if isinstance(candidate, dict):
        return candidate
    return {}
|
||||||
|
|
||||||
|
|
||||||
|
def tool_arguments_json_for_replay(arguments: Any) -> str:
    """Serialize replay-safe tool arguments as a JSON object string.

    Intended for provider history replay only: the arguments are first
    coerced with ``tool_arguments_object_for_replay`` and then dumped without
    ASCII escaping.
    """
    replay_object = tool_arguments_object_for_replay(arguments)
    return json.dumps(replay_object, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LLMResponse:
|
class LLMResponse:
|
||||||
"""Response from an LLM provider."""
|
"""Response from an LLM provider."""
|
||||||
|
|||||||
@ -10,9 +10,13 @@ import re
|
|||||||
from collections.abc import Awaitable, Callable, Iterator
|
from collections.abc import Awaitable, Callable, Iterator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import json_repair
|
from nanobot.providers.base import (
|
||||||
|
LLMProvider,
|
||||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
LLMResponse,
|
||||||
|
ToolCallRequest,
|
||||||
|
parse_tool_arguments,
|
||||||
|
tool_arguments_object_for_replay,
|
||||||
|
)
|
||||||
|
|
||||||
_IMAGE_DATA_URL = re.compile(r"^data:image/([a-zA-Z0-9.+-]+);base64,(.*)$", re.DOTALL)
|
_IMAGE_DATA_URL = re.compile(r"^data:image/([a-zA-Z0-9.+-]+);base64,(.*)$", re.DOTALL)
|
||||||
_TEXT_BLOCK_TYPES = {"text", "input_text", "output_text"}
|
_TEXT_BLOCK_TYPES = {"text", "input_text", "output_text"}
|
||||||
@ -176,14 +180,7 @@ class BedrockProvider(LLMProvider):
|
|||||||
function = tool_call.get("function")
|
function = tool_call.get("function")
|
||||||
if not isinstance(function, dict):
|
if not isinstance(function, dict):
|
||||||
return None
|
return None
|
||||||
args = function.get("arguments", {})
|
args = tool_arguments_object_for_replay(function.get("arguments", {}))
|
||||||
if isinstance(args, str):
|
|
||||||
try:
|
|
||||||
args = json_repair.loads(args) if args.strip() else {}
|
|
||||||
except Exception:
|
|
||||||
args = {}
|
|
||||||
if not isinstance(args, dict):
|
|
||||||
args = {}
|
|
||||||
return {
|
return {
|
||||||
"toolUse": {
|
"toolUse": {
|
||||||
"toolUseId": str(tool_call.get("id") or ""),
|
"toolUseId": str(tool_call.get("id") or ""),
|
||||||
@ -491,7 +488,7 @@ class BedrockProvider(LLMProvider):
|
|||||||
content_parts.append(block["text"])
|
content_parts.append(block["text"])
|
||||||
tool_use = block.get("toolUse")
|
tool_use = block.get("toolUse")
|
||||||
if isinstance(tool_use, dict):
|
if isinstance(tool_use, dict):
|
||||||
arguments = tool_use.get("input") if isinstance(tool_use.get("input"), dict) else {}
|
arguments = tool_use.get("input", {})
|
||||||
tool_calls.append(ToolCallRequest(
|
tool_calls.append(ToolCallRequest(
|
||||||
id=str(tool_use.get("toolUseId") or ""),
|
id=str(tool_use.get("toolUseId") or ""),
|
||||||
name=str(tool_use.get("name") or ""),
|
name=str(tool_use.get("name") or ""),
|
||||||
@ -616,14 +613,11 @@ class BedrockProvider(LLMProvider):
|
|||||||
for buf in tool_buffers.values():
|
for buf in tool_buffers.values():
|
||||||
args: Any = {}
|
args: Any = {}
|
||||||
if buf.get("input"):
|
if buf.get("input"):
|
||||||
try:
|
args = parse_tool_arguments(buf["input"])
|
||||||
args = json_repair.loads(buf["input"])
|
|
||||||
except Exception:
|
|
||||||
args = {}
|
|
||||||
tool_calls.append(ToolCallRequest(
|
tool_calls.append(ToolCallRequest(
|
||||||
id=buf.get("id") or "",
|
id=buf.get("id") or "",
|
||||||
name=buf.get("name") or "",
|
name=buf.get("name") or "",
|
||||||
arguments=args if isinstance(args, dict) else {},
|
arguments=args,
|
||||||
))
|
))
|
||||||
return LLMResponse(
|
return LLMResponse(
|
||||||
content="".join(content_parts) or None,
|
content="".join(content_parts) or None,
|
||||||
|
|||||||
@ -17,10 +17,15 @@ from ipaddress import ip_address
|
|||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import json_repair
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
from nanobot.providers.base import (
|
||||||
|
LLMProvider,
|
||||||
|
LLMResponse,
|
||||||
|
ToolCallRequest,
|
||||||
|
parse_tool_arguments,
|
||||||
|
tool_arguments_json_for_replay,
|
||||||
|
)
|
||||||
from nanobot.providers.openai_responses import (
|
from nanobot.providers.openai_responses import (
|
||||||
consume_sdk_stream,
|
consume_sdk_stream,
|
||||||
convert_messages,
|
convert_messages,
|
||||||
@ -478,24 +483,6 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
"""Return True for providers that reject normal OpenAI tool call IDs."""
|
"""Return True for providers that reject normal OpenAI tool call IDs."""
|
||||||
return bool(self._spec and self._spec.name == "mistral")
|
return bool(self._spec and self._spec.name == "mistral")
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _normalize_tool_call_arguments(arguments: Any) -> str:
|
|
||||||
"""Force function.arguments into a valid JSON object string."""
|
|
||||||
if isinstance(arguments, str):
|
|
||||||
stripped = arguments.strip()
|
|
||||||
if not stripped:
|
|
||||||
return "{}"
|
|
||||||
try:
|
|
||||||
parsed = json_repair.loads(stripped)
|
|
||||||
except Exception:
|
|
||||||
return "{}"
|
|
||||||
if isinstance(parsed, dict):
|
|
||||||
return json.dumps(parsed, ensure_ascii=False)
|
|
||||||
return "{}"
|
|
||||||
if isinstance(arguments, dict):
|
|
||||||
return json.dumps(arguments, ensure_ascii=False)
|
|
||||||
return "{}"
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _coerce_content_to_string(content: Any) -> str | None:
|
def _coerce_content_to_string(content: Any) -> str | None:
|
||||||
"""Coerce block/list content into plain text for strict string-only APIs."""
|
"""Coerce block/list content into plain text for strict string-only APIs."""
|
||||||
@ -572,7 +559,7 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
if isinstance(function, dict):
|
if isinstance(function, dict):
|
||||||
function_clean = dict(function)
|
function_clean = dict(function)
|
||||||
if "arguments" in function_clean:
|
if "arguments" in function_clean:
|
||||||
function_clean["arguments"] = self._normalize_tool_call_arguments(
|
function_clean["arguments"] = tool_arguments_json_for_replay(
|
||||||
function_clean.get("arguments")
|
function_clean.get("arguments")
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -1021,14 +1008,12 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
for tc in raw_tool_calls:
|
for tc in raw_tool_calls:
|
||||||
tc_map = self._maybe_mapping(tc) or {}
|
tc_map = self._maybe_mapping(tc) or {}
|
||||||
fn = self._maybe_mapping(tc_map.get("function")) or {}
|
fn = self._maybe_mapping(tc_map.get("function")) or {}
|
||||||
args = fn.get("arguments", {})
|
args = parse_tool_arguments(fn.get("arguments", {}))
|
||||||
if isinstance(args, str):
|
|
||||||
args = json_repair.loads(args)
|
|
||||||
ec, prov, fn_prov = _extract_tc_extras(tc)
|
ec, prov, fn_prov = _extract_tc_extras(tc)
|
||||||
parsed_tool_calls.append(ToolCallRequest(
|
parsed_tool_calls.append(ToolCallRequest(
|
||||||
id=str(tc_map.get("id") or _short_tool_id()),
|
id=str(tc_map.get("id") or _short_tool_id()),
|
||||||
name=str(fn.get("name") or ""),
|
name=str(fn.get("name") or ""),
|
||||||
arguments=args if isinstance(args, dict) else {},
|
arguments=args,
|
||||||
extra_content=ec,
|
extra_content=ec,
|
||||||
provider_specific_fields=prov,
|
provider_specific_fields=prov,
|
||||||
function_provider_specific_fields=fn_prov,
|
function_provider_specific_fields=fn_prov,
|
||||||
@ -1064,9 +1049,7 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
|
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
for tc in raw_tool_calls:
|
for tc in raw_tool_calls:
|
||||||
args = tc.function.arguments
|
args = parse_tool_arguments(tc.function.arguments)
|
||||||
if isinstance(args, str):
|
|
||||||
args = json_repair.loads(args)
|
|
||||||
ec, prov, fn_prov = _extract_tc_extras(tc)
|
ec, prov, fn_prov = _extract_tc_extras(tc)
|
||||||
tool_calls.append(ToolCallRequest(
|
tool_calls.append(ToolCallRequest(
|
||||||
id=str(getattr(tc, "id", None) or _short_tool_id()),
|
id=str(getattr(tc, "id", None) or _short_tool_id()),
|
||||||
@ -1207,7 +1190,7 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
ToolCallRequest(
|
ToolCallRequest(
|
||||||
id=b["id"] or _short_tool_id(),
|
id=b["id"] or _short_tool_id(),
|
||||||
name=b["name"],
|
name=b["name"],
|
||||||
arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {},
|
arguments=parse_tool_arguments(b["arguments"]),
|
||||||
extra_content=b.get("extra_content"),
|
extra_content=b.get("extra_content"),
|
||||||
provider_specific_fields=b.get("prov"),
|
provider_specific_fields=b.get("prov"),
|
||||||
function_provider_specific_fields=b.get("fn_prov"),
|
function_provider_specific_fields=b.get("fn_prov"),
|
||||||
|
|||||||
@ -5,6 +5,8 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from nanobot.providers.base import tool_arguments_json_for_replay
|
||||||
|
|
||||||
|
|
||||||
def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
|
def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
|
||||||
"""Convert Chat Completions messages to Responses API input items.
|
"""Convert Chat Completions messages to Responses API input items.
|
||||||
@ -46,7 +48,7 @@ def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str
|
|||||||
"id": response_item_id,
|
"id": response_item_id,
|
||||||
"call_id": call_id or f"call_{idx}",
|
"call_id": call_id or f"call_{idx}",
|
||||||
"name": fn.get("name"),
|
"name": fn.get("name"),
|
||||||
"arguments": fn.get("arguments") or "{}",
|
"arguments": tool_arguments_json_for_replay(fn.get("arguments")),
|
||||||
})
|
})
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@ -7,10 +7,9 @@ from collections.abc import Awaitable, Callable
|
|||||||
from typing import Any, AsyncGenerator
|
from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import json_repair
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
from nanobot.providers.base import LLMResponse, ToolCallRequest, parse_tool_arguments
|
||||||
|
|
||||||
FINISH_REASON_MAP = {
|
FINISH_REASON_MAP = {
|
||||||
"completed": "stop",
|
"completed": "stop",
|
||||||
@ -44,6 +43,27 @@ def _usage_from_response_obj(response: Any) -> dict[str, int]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_tool_call_arguments(args_raw: Any, name: str | None) -> Any:
    """Parse raw tool-call arguments and warn when a string stays unparsed.

    Delegates to ``parse_tool_arguments``. When *args_raw* is a non-blank
    string and the result compares equal to it, a warning naming the tool and
    showing the first 200 characters of the raw text is logged. The parsed
    value is returned either way.
    """
    result = parse_tool_arguments(args_raw)
    came_back_unchanged = result == args_raw
    if came_back_unchanged and isinstance(args_raw, str) and args_raw.strip():
        logger.warning(
            "Failed to parse tool call arguments for '{}': {}",
            name,
            args_raw[:200],
        )
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_arguments_source(*values: Any) -> Any:
|
||||||
|
for value in values:
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
if isinstance(value, str) and not value.strip():
|
||||||
|
continue
|
||||||
|
return value
|
||||||
|
return "{}"
|
||||||
|
|
||||||
|
|
||||||
async def iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
|
async def iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
|
||||||
"""Yield parsed JSON events from a Responses API SSE stream."""
|
"""Yield parsed JSON events from a Responses API SSE stream."""
|
||||||
buffer: list[str] = []
|
buffer: list[str] = []
|
||||||
@ -116,10 +136,11 @@ async def consume_sse_with_reasoning(
|
|||||||
call_id = item.get("call_id")
|
call_id = item.get("call_id")
|
||||||
if not call_id:
|
if not call_id:
|
||||||
continue
|
continue
|
||||||
|
arguments = item.get("arguments")
|
||||||
tool_call_buffers[call_id] = {
|
tool_call_buffers[call_id] = {
|
||||||
"id": item.get("id") or "fc_0",
|
"id": item.get("id") or "fc_0",
|
||||||
"name": item.get("name"),
|
"name": item.get("name"),
|
||||||
"arguments": item.get("arguments") or "",
|
"arguments": "" if arguments is None else arguments,
|
||||||
}
|
}
|
||||||
if on_tool_call_delta:
|
if on_tool_call_delta:
|
||||||
await on_tool_call_delta({
|
await on_tool_call_delta({
|
||||||
@ -156,7 +177,10 @@ async def consume_sse_with_reasoning(
|
|||||||
call_id = event.get("call_id")
|
call_id = event.get("call_id")
|
||||||
if call_id and call_id in tool_call_buffers:
|
if call_id and call_id in tool_call_buffers:
|
||||||
delta = event.get("delta") or ""
|
delta = event.get("delta") or ""
|
||||||
tool_call_buffers[call_id]["arguments"] += delta
|
current = tool_call_buffers[call_id].get("arguments")
|
||||||
|
if not isinstance(current, str):
|
||||||
|
current = ""
|
||||||
|
tool_call_buffers[call_id]["arguments"] = current + delta
|
||||||
if on_tool_call_delta and delta:
|
if on_tool_call_delta and delta:
|
||||||
await on_tool_call_delta({
|
await on_tool_call_delta({
|
||||||
"call_id": str(call_id),
|
"call_id": str(call_id),
|
||||||
@ -166,14 +190,14 @@ async def consume_sse_with_reasoning(
|
|||||||
elif event_type == "response.function_call_arguments.done":
|
elif event_type == "response.function_call_arguments.done":
|
||||||
call_id = event.get("call_id")
|
call_id = event.get("call_id")
|
||||||
if call_id and call_id in tool_call_buffers:
|
if call_id and call_id in tool_call_buffers:
|
||||||
arguments = event.get("arguments") or ""
|
arguments = event.get("arguments")
|
||||||
tool_call_buffers[call_id]["arguments"] = arguments
|
tool_call_buffers[call_id]["arguments"] = arguments
|
||||||
if on_tool_call_delta:
|
if on_tool_call_delta:
|
||||||
tool_call_args_emitted.add(str(call_id))
|
tool_call_args_emitted.add(str(call_id))
|
||||||
await on_tool_call_delta({
|
await on_tool_call_delta({
|
||||||
"call_id": str(call_id),
|
"call_id": str(call_id),
|
||||||
"name": str(tool_call_buffers[call_id].get("name") or ""),
|
"name": str(tool_call_buffers[call_id].get("name") or ""),
|
||||||
"arguments": str(arguments),
|
"arguments": "" if arguments is None else str(arguments),
|
||||||
})
|
})
|
||||||
elif event_type == "response.output_item.done":
|
elif event_type == "response.output_item.done":
|
||||||
item = event.get("item") or {}
|
item = event.get("item") or {}
|
||||||
@ -182,7 +206,7 @@ async def consume_sse_with_reasoning(
|
|||||||
if not call_id:
|
if not call_id:
|
||||||
continue
|
continue
|
||||||
buf = tool_call_buffers.get(call_id) or {}
|
buf = tool_call_buffers.get(call_id) or {}
|
||||||
args_raw = buf.get("arguments") or item.get("arguments") or "{}"
|
args_raw = _tool_arguments_source(buf.get("arguments"), item.get("arguments"))
|
||||||
if on_tool_call_delta and str(call_id) not in tool_call_args_emitted:
|
if on_tool_call_delta and str(call_id) not in tool_call_args_emitted:
|
||||||
tool_call_args_emitted.add(str(call_id))
|
tool_call_args_emitted.add(str(call_id))
|
||||||
await on_tool_call_delta({
|
await on_tool_call_delta({
|
||||||
@ -190,17 +214,10 @@ async def consume_sse_with_reasoning(
|
|||||||
"name": str(buf.get("name") or item.get("name") or ""),
|
"name": str(buf.get("name") or item.get("name") or ""),
|
||||||
"arguments": str(args_raw),
|
"arguments": str(args_raw),
|
||||||
})
|
})
|
||||||
try:
|
args = _parse_tool_call_arguments(
|
||||||
args = json.loads(args_raw)
|
args_raw,
|
||||||
except Exception:
|
buf.get("name") or item.get("name"),
|
||||||
logger.warning(
|
)
|
||||||
"Failed to parse tool call arguments for '{}': {}",
|
|
||||||
buf.get("name") or item.get("name"),
|
|
||||||
args_raw[:200],
|
|
||||||
)
|
|
||||||
args = json_repair.loads(args_raw)
|
|
||||||
if not isinstance(args, dict):
|
|
||||||
args = {"raw": args_raw}
|
|
||||||
tool_calls.append(
|
tool_calls.append(
|
||||||
ToolCallRequest(
|
ToolCallRequest(
|
||||||
id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
|
id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
|
||||||
@ -283,22 +300,12 @@ def parse_response_output(response: Any) -> LLMResponse:
|
|||||||
elif item_type == "function_call":
|
elif item_type == "function_call":
|
||||||
call_id = item.get("call_id") or ""
|
call_id = item.get("call_id") or ""
|
||||||
item_id = item.get("id") or "fc_0"
|
item_id = item.get("id") or "fc_0"
|
||||||
args_raw = item.get("arguments") or "{}"
|
args_raw = _tool_arguments_source(item.get("arguments"))
|
||||||
try:
|
args = _parse_tool_call_arguments(args_raw, item.get("name"))
|
||||||
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
|
|
||||||
except Exception:
|
|
||||||
logger.warning(
|
|
||||||
"Failed to parse tool call arguments for '{}': {}",
|
|
||||||
item.get("name"),
|
|
||||||
str(args_raw)[:200],
|
|
||||||
)
|
|
||||||
args = json_repair.loads(args_raw) if isinstance(args_raw, str) else args_raw
|
|
||||||
if not isinstance(args, dict):
|
|
||||||
args = {"raw": args_raw}
|
|
||||||
tool_calls.append(ToolCallRequest(
|
tool_calls.append(ToolCallRequest(
|
||||||
id=f"{call_id}|{item_id}",
|
id=f"{call_id}|{item_id}",
|
||||||
name=item.get("name") or "",
|
name=item.get("name") or "",
|
||||||
arguments=args if isinstance(args, dict) else {},
|
arguments=args,
|
||||||
))
|
))
|
||||||
|
|
||||||
usage = _usage_from_response_obj(response)
|
usage = _usage_from_response_obj(response)
|
||||||
@ -337,10 +344,11 @@ async def consume_sdk_stream(
|
|||||||
call_id = getattr(item, "call_id", None)
|
call_id = getattr(item, "call_id", None)
|
||||||
if not call_id:
|
if not call_id:
|
||||||
continue
|
continue
|
||||||
|
arguments = getattr(item, "arguments", None)
|
||||||
tool_call_buffers[call_id] = {
|
tool_call_buffers[call_id] = {
|
||||||
"id": getattr(item, "id", None) or "fc_0",
|
"id": getattr(item, "id", None) or "fc_0",
|
||||||
"name": getattr(item, "name", None),
|
"name": getattr(item, "name", None),
|
||||||
"arguments": getattr(item, "arguments", None) or "",
|
"arguments": "" if arguments is None else arguments,
|
||||||
}
|
}
|
||||||
if on_tool_call_delta:
|
if on_tool_call_delta:
|
||||||
await on_tool_call_delta({
|
await on_tool_call_delta({
|
||||||
@ -357,7 +365,10 @@ async def consume_sdk_stream(
|
|||||||
call_id = getattr(event, "call_id", None)
|
call_id = getattr(event, "call_id", None)
|
||||||
if call_id and call_id in tool_call_buffers:
|
if call_id and call_id in tool_call_buffers:
|
||||||
delta = getattr(event, "delta", "") or ""
|
delta = getattr(event, "delta", "") or ""
|
||||||
tool_call_buffers[call_id]["arguments"] += delta
|
current = tool_call_buffers[call_id].get("arguments")
|
||||||
|
if not isinstance(current, str):
|
||||||
|
current = ""
|
||||||
|
tool_call_buffers[call_id]["arguments"] = current + delta
|
||||||
if on_tool_call_delta and delta:
|
if on_tool_call_delta and delta:
|
||||||
await on_tool_call_delta({
|
await on_tool_call_delta({
|
||||||
"call_id": str(call_id),
|
"call_id": str(call_id),
|
||||||
@ -367,14 +378,14 @@ async def consume_sdk_stream(
|
|||||||
elif event_type == "response.function_call_arguments.done":
|
elif event_type == "response.function_call_arguments.done":
|
||||||
call_id = getattr(event, "call_id", None)
|
call_id = getattr(event, "call_id", None)
|
||||||
if call_id and call_id in tool_call_buffers:
|
if call_id and call_id in tool_call_buffers:
|
||||||
arguments = getattr(event, "arguments", "") or ""
|
arguments = getattr(event, "arguments", None)
|
||||||
tool_call_buffers[call_id]["arguments"] = arguments
|
tool_call_buffers[call_id]["arguments"] = arguments
|
||||||
if on_tool_call_delta:
|
if on_tool_call_delta:
|
||||||
tool_call_args_emitted.add(str(call_id))
|
tool_call_args_emitted.add(str(call_id))
|
||||||
await on_tool_call_delta({
|
await on_tool_call_delta({
|
||||||
"call_id": str(call_id),
|
"call_id": str(call_id),
|
||||||
"name": str(tool_call_buffers[call_id].get("name") or ""),
|
"name": str(tool_call_buffers[call_id].get("name") or ""),
|
||||||
"arguments": str(arguments),
|
"arguments": "" if arguments is None else str(arguments),
|
||||||
})
|
})
|
||||||
elif event_type == "response.output_item.done":
|
elif event_type == "response.output_item.done":
|
||||||
item = getattr(event, "item", None)
|
item = getattr(event, "item", None)
|
||||||
@ -383,7 +394,10 @@ async def consume_sdk_stream(
|
|||||||
if not call_id:
|
if not call_id:
|
||||||
continue
|
continue
|
||||||
buf = tool_call_buffers.get(call_id) or {}
|
buf = tool_call_buffers.get(call_id) or {}
|
||||||
args_raw = buf.get("arguments") or getattr(item, "arguments", None) or "{}"
|
args_raw = _tool_arguments_source(
|
||||||
|
buf.get("arguments"),
|
||||||
|
getattr(item, "arguments", None),
|
||||||
|
)
|
||||||
if on_tool_call_delta and str(call_id) not in tool_call_args_emitted:
|
if on_tool_call_delta and str(call_id) not in tool_call_args_emitted:
|
||||||
tool_call_args_emitted.add(str(call_id))
|
tool_call_args_emitted.add(str(call_id))
|
||||||
await on_tool_call_delta({
|
await on_tool_call_delta({
|
||||||
@ -391,17 +405,10 @@ async def consume_sdk_stream(
|
|||||||
"name": str(buf.get("name") or getattr(item, "name", None) or ""),
|
"name": str(buf.get("name") or getattr(item, "name", None) or ""),
|
||||||
"arguments": str(args_raw),
|
"arguments": str(args_raw),
|
||||||
})
|
})
|
||||||
try:
|
args = _parse_tool_call_arguments(
|
||||||
args = json.loads(args_raw)
|
args_raw,
|
||||||
except Exception:
|
buf.get("name") or getattr(item, "name", None),
|
||||||
logger.warning(
|
)
|
||||||
"Failed to parse tool call arguments for '{}': {}",
|
|
||||||
buf.get("name") or getattr(item, "name", None),
|
|
||||||
str(args_raw)[:200],
|
|
||||||
)
|
|
||||||
args = json_repair.loads(args_raw)
|
|
||||||
if not isinstance(args, dict):
|
|
||||||
args = {"raw": args_raw}
|
|
||||||
tool_calls.append(
|
tool_calls.append(
|
||||||
ToolCallRequest(
|
ToolCallRequest(
|
||||||
id=f"{call_id}|{buf.get('id') or getattr(item, 'id', None) or 'fc_0'}",
|
id=f"{call_id}|{buf.get('id') or getattr(item, 'id', None) or 'fc_0'}",
|
||||||
|
|||||||
@ -49,13 +49,18 @@ async def invoke_file_edit_progress(
|
|||||||
await on_progress("", file_edit_events=file_edit_events)
|
await on_progress("", file_edit_events=file_edit_events)
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_event_arguments(tool_call: Any) -> dict[str, Any]:
|
||||||
|
arguments = getattr(tool_call, "arguments", {}) or {}
|
||||||
|
return arguments if isinstance(arguments, dict) else {}
|
||||||
|
|
||||||
|
|
||||||
def build_tool_event_start_payload(tool_call: Any) -> dict[str, Any]:
|
def build_tool_event_start_payload(tool_call: Any) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"version": 1,
|
"version": 1,
|
||||||
"phase": "start",
|
"phase": "start",
|
||||||
"call_id": str(getattr(tool_call, "id", "") or ""),
|
"call_id": str(getattr(tool_call, "id", "") or ""),
|
||||||
"name": getattr(tool_call, "name", ""),
|
"name": getattr(tool_call, "name", ""),
|
||||||
"arguments": getattr(tool_call, "arguments", {}) or {},
|
"arguments": _tool_event_arguments(tool_call),
|
||||||
"result": None,
|
"result": None,
|
||||||
"error": None,
|
"error": None,
|
||||||
"files": [],
|
"files": [],
|
||||||
@ -86,7 +91,7 @@ def build_tool_event_finish_payloads(context: AgentHookContext) -> list[dict[str
|
|||||||
"phase": phase,
|
"phase": phase,
|
||||||
"call_id": str(getattr(tool_call, "id", "") or ""),
|
"call_id": str(getattr(tool_call, "id", "") or ""),
|
||||||
"name": getattr(tool_call, "name", ""),
|
"name": getattr(tool_call, "name", ""),
|
||||||
"arguments": getattr(tool_call, "arguments", {}) or {},
|
"arguments": _tool_event_arguments(tool_call),
|
||||||
"result": result if phase == "end" else None,
|
"result": result if phase == "end" else None,
|
||||||
"error": None,
|
"error": None,
|
||||||
"files": files,
|
"files": files,
|
||||||
|
|||||||
@ -75,8 +75,10 @@ def build_goal_continue_message(custom: str | None = None) -> dict[str, str]:
|
|||||||
return {"role": "user", "content": custom or SUSTAINED_GOAL_CONTINUE_PROMPT}
|
return {"role": "user", "content": custom or SUSTAINED_GOAL_CONTINUE_PROMPT}
|
||||||
|
|
||||||
|
|
||||||
def external_lookup_signature(tool_name: str, arguments: dict[str, Any]) -> str | None:
|
def external_lookup_signature(tool_name: str, arguments: Any) -> str | None:
|
||||||
"""Stable signature for repeated external lookups we want to throttle."""
|
"""Stable signature for repeated external lookups we want to throttle."""
|
||||||
|
if not isinstance(arguments, dict):
|
||||||
|
return None
|
||||||
if tool_name == "web_fetch":
|
if tool_name == "web_fetch":
|
||||||
url = str(arguments.get("url") or "").strip()
|
url = str(arguments.get("url") or "").strip()
|
||||||
if url:
|
if url:
|
||||||
@ -90,7 +92,7 @@ def external_lookup_signature(tool_name: str, arguments: dict[str, Any]) -> str
|
|||||||
|
|
||||||
def repeated_external_lookup_error(
|
def repeated_external_lookup_error(
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
arguments: dict[str, Any],
|
arguments: Any,
|
||||||
seen_counts: dict[str, int],
|
seen_counts: dict[str, int],
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""Block repeated external lookups after a small retry budget."""
|
"""Block repeated external lookups after a small retry budget."""
|
||||||
@ -119,9 +121,11 @@ _OUTSIDE_PATH_PATTERN = re.compile(r"(?:^|[\s|>'\"])((?:/[^\s\"'>;|<]+)|(?:~[^\s
|
|||||||
|
|
||||||
def workspace_violation_signature(
|
def workspace_violation_signature(
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
arguments: dict[str, Any],
|
arguments: Any,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""Return a stable cross-tool signature for the outside-workspace target."""
|
"""Return a stable cross-tool signature for the outside-workspace target."""
|
||||||
|
if not isinstance(arguments, dict):
|
||||||
|
return None
|
||||||
for key in ("path", "file_path", "target", "source", "destination"):
|
for key in ("path", "file_path", "target", "source", "destination"):
|
||||||
val = arguments.get(key)
|
val = arguments.get(key)
|
||||||
if isinstance(val, str) and val.strip():
|
if isinstance(val, str) and val.strip():
|
||||||
@ -151,7 +155,7 @@ def _normalize_violation_target(raw: str) -> str:
|
|||||||
|
|
||||||
def repeated_workspace_violation_error(
|
def repeated_workspace_violation_error(
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
arguments: dict[str, Any],
|
arguments: Any,
|
||||||
seen_counts: dict[str, int],
|
seen_counts: dict[str, int],
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""Return an escalated error after repeated bypass attempts."""
|
"""Return an escalated error after repeated bypass attempts."""
|
||||||
|
|||||||
@ -3,17 +3,21 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||||
from nanobot.agent.tools.base import Tool
|
from nanobot.agent.tools.base import Tool
|
||||||
from nanobot.agent.tools.registry import ToolRegistry
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
from nanobot.config.schema import AgentDefaults
|
from nanobot.config.schema import AgentDefaults
|
||||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||||
|
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||||
|
from nanobot.providers.openai_responses.parsing import parse_response_output
|
||||||
|
|
||||||
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||||
|
|
||||||
|
|
||||||
class _DelayTool(Tool):
|
class _DelayTool(Tool):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -57,10 +61,45 @@ class _DelayTool(Tool):
|
|||||||
return self._name
|
return self._name
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_optional_tool_response(response: LLMResponse):
|
||||||
|
provider = MagicMock()
|
||||||
|
calls = {"n": 0}
|
||||||
|
|
||||||
|
async def chat_with_retry(*, messages, **kwargs):
|
||||||
|
calls["n"] += 1
|
||||||
|
if calls["n"] == 1:
|
||||||
|
return response
|
||||||
|
return LLMResponse(content="done", tool_calls=[], usage={})
|
||||||
|
|
||||||
|
provider.chat_with_retry = chat_with_retry
|
||||||
|
tools = ToolRegistry()
|
||||||
|
shared_events: list[str] = []
|
||||||
|
tools.register(_DelayTool(
|
||||||
|
"optional_tool",
|
||||||
|
delay=0,
|
||||||
|
read_only=True,
|
||||||
|
shared_events=shared_events,
|
||||||
|
))
|
||||||
|
|
||||||
|
result = await AgentRunner(provider).run(AgentRunSpec(
|
||||||
|
initial_messages=[{"role": "user", "content": "try optional"}],
|
||||||
|
tools=tools,
|
||||||
|
model="test-model",
|
||||||
|
max_iterations=2,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
))
|
||||||
|
return result, shared_events
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_message(result, tool_call_id: str) -> dict:
|
||||||
|
return [
|
||||||
|
msg for msg in result.messages
|
||||||
|
if msg.get("role") == "tool" and msg.get("tool_call_id") == tool_call_id
|
||||||
|
][0]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_runner_batches_read_only_tools_before_exclusive_work():
|
async def test_runner_batches_read_only_tools_before_exclusive_work():
|
||||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
|
||||||
|
|
||||||
tools = ToolRegistry()
|
tools = ToolRegistry()
|
||||||
shared_events: list[str] = []
|
shared_events: list[str] = []
|
||||||
read_a = _DelayTool("read_a", delay=0.05, read_only=True, shared_events=shared_events)
|
read_a = _DelayTool("read_a", delay=0.05, read_only=True, shared_events=shared_events)
|
||||||
@ -98,8 +137,6 @@ async def test_runner_batches_read_only_tools_before_exclusive_work():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_runner_does_not_batch_exclusive_read_only_tools():
|
async def test_runner_does_not_batch_exclusive_read_only_tools():
|
||||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
|
||||||
|
|
||||||
tools = ToolRegistry()
|
tools = ToolRegistry()
|
||||||
shared_events: list[str] = []
|
shared_events: list[str] = []
|
||||||
read_a = _DelayTool("read_a", delay=0.03, read_only=True, shared_events=shared_events)
|
read_a = _DelayTool("read_a", delay=0.03, read_only=True, shared_events=shared_events)
|
||||||
@ -140,9 +177,151 @@ async def test_runner_does_not_batch_exclusive_read_only_tools():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_runner_blocks_repeated_external_fetches():
|
async def test_runner_rejects_near_miss_tool_name_without_executing():
|
||||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
provider = MagicMock()
|
||||||
|
call_count = {"n": 0}
|
||||||
|
captured_second_call: list[dict] = []
|
||||||
|
|
||||||
|
async def chat_with_retry(*, messages, **kwargs):
|
||||||
|
call_count["n"] += 1
|
||||||
|
if call_count["n"] == 1:
|
||||||
|
return LLMResponse(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="call_1",
|
||||||
|
name="readFile",
|
||||||
|
arguments={"path": "notes.txt"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
finish_reason="tool_calls",
|
||||||
|
usage={},
|
||||||
|
)
|
||||||
|
captured_second_call[:] = messages
|
||||||
|
return LLMResponse(content="done", tool_calls=[], usage={})
|
||||||
|
|
||||||
|
provider.chat_with_retry = chat_with_retry
|
||||||
|
tools = ToolRegistry()
|
||||||
|
shared_events: list[str] = []
|
||||||
|
tools.register(_DelayTool(
|
||||||
|
"read_file",
|
||||||
|
delay=0,
|
||||||
|
read_only=True,
|
||||||
|
shared_events=shared_events,
|
||||||
|
))
|
||||||
|
|
||||||
|
runner = AgentRunner(provider)
|
||||||
|
result = await runner.run(AgentRunSpec(
|
||||||
|
initial_messages=[{"role": "user", "content": "read notes"}],
|
||||||
|
tools=tools,
|
||||||
|
model="test-model",
|
||||||
|
max_iterations=2,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
))
|
||||||
|
|
||||||
|
assert result.final_content == "done"
|
||||||
|
assert result.tools_used == []
|
||||||
|
assert shared_events == []
|
||||||
|
assistant_message = [
|
||||||
|
msg for msg in result.messages
|
||||||
|
if msg.get("role") == "assistant" and msg.get("tool_calls")
|
||||||
|
][0]
|
||||||
|
assert assistant_message["tool_calls"][0]["function"]["name"] == "readFile"
|
||||||
|
tool_message = [
|
||||||
|
msg for msg in result.messages
|
||||||
|
if msg.get("role") == "tool" and msg.get("tool_call_id") == "call_1"
|
||||||
|
][0]
|
||||||
|
assert tool_message["name"] == "readFile"
|
||||||
|
assert "Tool 'readFile' not found" in tool_message["content"]
|
||||||
|
assert "Did you mean 'read_file'?" in tool_message["content"]
|
||||||
|
replayed_assistant = [
|
||||||
|
msg for msg in captured_second_call
|
||||||
|
if msg.get("role") == "assistant" and msg.get("tool_calls")
|
||||||
|
][0]
|
||||||
|
assert replayed_assistant["tool_calls"][0]["function"]["name"] == "readFile"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize("arguments", ['{path:"notes.txt"}', "null"])
|
||||||
|
async def test_runner_rejects_openai_compat_invalid_arguments_without_executing(arguments):
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
parsed = OpenAICompatProvider()._parse({
|
||||||
|
"choices": [{
|
||||||
|
"message": {
|
||||||
|
"tool_calls": [{
|
||||||
|
"id": "call_1",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "optional_tool",
|
||||||
|
"arguments": arguments,
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
"finish_reason": "tool_calls",
|
||||||
|
}],
|
||||||
|
"usage": {},
|
||||||
|
})
|
||||||
|
|
||||||
|
result, shared_events = await _run_optional_tool_response(parsed)
|
||||||
|
|
||||||
|
assert result.final_content == "done"
|
||||||
|
assert parsed.tool_calls[0].arguments == arguments
|
||||||
|
assert result.tools_used == []
|
||||||
|
assert shared_events == []
|
||||||
|
tool_message = _tool_message(result, "call_1")
|
||||||
|
assert "parameters must be a JSON object" in tool_message["content"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runner_rejects_openai_responses_malformed_arguments_without_executing():
|
||||||
|
parsed = parse_response_output({
|
||||||
|
"output": [{
|
||||||
|
"type": "function_call",
|
||||||
|
"call_id": "call_1",
|
||||||
|
"id": "fc_1",
|
||||||
|
"name": "optional_tool",
|
||||||
|
"arguments": "{bad",
|
||||||
|
}],
|
||||||
|
"status": "completed",
|
||||||
|
"usage": {},
|
||||||
|
})
|
||||||
|
|
||||||
|
result, shared_events = await _run_optional_tool_response(parsed)
|
||||||
|
|
||||||
|
assert result.final_content == "done"
|
||||||
|
assert parsed.tool_calls[0].arguments == "{bad"
|
||||||
|
assert result.tools_used == []
|
||||||
|
assert shared_events == []
|
||||||
|
tool_message = _tool_message(result, "call_1|fc_1")
|
||||||
|
assert "parameters must be a JSON object" in tool_message["content"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runner_rejects_openai_responses_array_arguments_without_executing():
|
||||||
|
parsed = parse_response_output({
|
||||||
|
"output": [{
|
||||||
|
"type": "function_call",
|
||||||
|
"call_id": "call_1",
|
||||||
|
"id": "fc_1",
|
||||||
|
"name": "optional_tool",
|
||||||
|
"arguments": [],
|
||||||
|
}],
|
||||||
|
"status": "completed",
|
||||||
|
"usage": {},
|
||||||
|
})
|
||||||
|
|
||||||
|
result, shared_events = await _run_optional_tool_response(parsed)
|
||||||
|
|
||||||
|
assert result.final_content == "done"
|
||||||
|
assert parsed.tool_calls[0].arguments == []
|
||||||
|
assert result.tools_used == []
|
||||||
|
assert shared_events == []
|
||||||
|
tool_message = _tool_message(result, "call_1|fc_1")
|
||||||
|
assert "parameters must be a JSON object" in tool_message["content"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runner_blocks_repeated_external_fetches():
|
||||||
provider = MagicMock()
|
provider = MagicMock()
|
||||||
captured_final_call: list[dict] = []
|
captured_final_call: list[dict] = []
|
||||||
call_count = {"n": 0}
|
call_count = {"n": 0}
|
||||||
|
|||||||
@ -80,3 +80,17 @@ def test_convert_user_content_coerces_mixed_typeless():
|
|||||||
])
|
])
|
||||||
assert result[0] == {"type": "text", "text": "42"}
|
assert result[0] == {"type": "text", "text": "42"}
|
||||||
assert result[1] == {"type": "text", "text": str({"key": "val"})}
|
assert result[1] == {"type": "text", "text": str({"key": "val"})}
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_assistant_message_repairs_history_tool_arguments():
|
||||||
|
blocks = AnthropicProvider._assistant_blocks({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": [{
|
||||||
|
"id": "toolu_1",
|
||||||
|
"function": {"name": "read_file", "arguments": '{path:"foo.txt"}'},
|
||||||
|
}],
|
||||||
|
})
|
||||||
|
|
||||||
|
assert blocks[0]["type"] == "tool_use"
|
||||||
|
assert blocks[0]["input"] == {"path": "foo.txt"}
|
||||||
|
|||||||
@ -161,6 +161,16 @@ def test_build_kwargs_converts_messages_tools_and_tool_results() -> None:
|
|||||||
assert kwargs["toolConfig"]["toolChoice"] == {"any": {}}
|
assert kwargs["toolConfig"]["toolChoice"] == {"any": {}}
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_use_block_repairs_history_tool_arguments() -> None:
|
||||||
|
block = BedrockProvider._tool_use_block({
|
||||||
|
"id": "toolu_1",
|
||||||
|
"function": {"name": "read_file", "arguments": '{path:"foo.txt"}'},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert block is not None
|
||||||
|
assert block["toolUse"]["input"] == {"path": "foo.txt"}
|
||||||
|
|
||||||
|
|
||||||
def test_build_kwargs_keeps_tool_config_for_historical_tool_blocks_without_tools() -> None:
|
def test_build_kwargs_keeps_tool_config_for_historical_tool_blocks_without_tools() -> None:
|
||||||
provider = BedrockProvider(region="us-east-1", client=FakeClient())
|
provider = BedrockProvider(region="us-east-1", client=FakeClient())
|
||||||
messages = [
|
messages = [
|
||||||
|
|||||||
@ -54,6 +54,15 @@ def _fake_tool_call_response() -> SimpleNamespace:
|
|||||||
return SimpleNamespace(choices=[choice], usage=usage)
|
return SimpleNamespace(choices=[choice], usage=usage)
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_tool_call_response_with_arguments(arguments) -> SimpleNamespace:
|
||||||
|
"""Build a minimal chat response with caller-supplied tool arguments."""
|
||||||
|
function = SimpleNamespace(name="optional_tool", arguments=arguments)
|
||||||
|
tool_call = SimpleNamespace(id="call_123", type="function", function=function)
|
||||||
|
message = SimpleNamespace(content=None, tool_calls=[tool_call], reasoning_content=None)
|
||||||
|
choice = SimpleNamespace(message=message, finish_reason="tool_calls")
|
||||||
|
return SimpleNamespace(choices=[choice], usage=SimpleNamespace())
|
||||||
|
|
||||||
|
|
||||||
def _fake_responses_response(content: str = "ok") -> MagicMock:
|
def _fake_responses_response(content: str = "ok") -> MagicMock:
|
||||||
"""Build a minimal Responses API response object."""
|
"""Build a minimal Responses API response object."""
|
||||||
resp = MagicMock()
|
resp = MagicMock()
|
||||||
@ -611,6 +620,24 @@ async def test_openai_compat_preserves_extra_content_on_tool_calls() -> None:
|
|||||||
assert serialized["function"]["provider_specific_fields"] == {"inner": "value"}
|
assert serialized["function"]["provider_specific_fields"] == {"inner": "value"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_compat_parse_preserves_malformed_tool_arguments() -> None:
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
provider = OpenAICompatProvider()
|
||||||
|
|
||||||
|
result = provider._parse(_fake_tool_call_response_with_arguments('{path:"foo.txt"}'))
|
||||||
|
|
||||||
|
assert result.tool_calls[0].arguments == '{path:"foo.txt"}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_compat_parse_preserves_array_tool_arguments() -> None:
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
provider = OpenAICompatProvider()
|
||||||
|
|
||||||
|
result = provider._parse(_fake_tool_call_response_with_arguments('["foo.txt"]'))
|
||||||
|
|
||||||
|
assert result.tool_calls[0].arguments == ["foo.txt"]
|
||||||
|
|
||||||
|
|
||||||
def test_openai_model_passthrough() -> None:
|
def test_openai_model_passthrough() -> None:
|
||||||
"""OpenAI models pass through unchanged."""
|
"""OpenAI models pass through unchanged."""
|
||||||
spec = find_by_name("openai")
|
spec = find_by_name("openai")
|
||||||
@ -1110,7 +1137,7 @@ def test_openai_compat_stringifies_dict_tool_arguments() -> None:
|
|||||||
assert sanitized[1]["tool_calls"][0]["function"]["arguments"] == '{"cmd": "ls -la"}'
|
assert sanitized[1]["tool_calls"][0]["function"]["arguments"] == '{"cmd": "ls -la"}'
|
||||||
|
|
||||||
|
|
||||||
def test_openai_compat_repairs_non_json_tool_arguments_string() -> None:
|
def test_openai_compat_repairs_object_like_history_tool_arguments_string() -> None:
|
||||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
provider = OpenAICompatProvider()
|
provider = OpenAICompatProvider()
|
||||||
|
|
||||||
|
|||||||
@ -155,6 +155,19 @@ class TestConvertMessages:
|
|||||||
assert items[0]["call_id"] == "call_abc"
|
assert items[0]["call_id"] == "call_abc"
|
||||||
assert items[0]["id"] == "fc_1"
|
assert items[0]["id"] == "fc_1"
|
||||||
assert items[0]["name"] == "get_weather"
|
assert items[0]["name"] == "get_weather"
|
||||||
|
assert items[0]["arguments"] == '{"city": "SF"}'
|
||||||
|
|
||||||
|
def test_assistant_tool_call_history_repairs_malformed_arguments(self):
|
||||||
|
_, items = convert_messages([{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": [{
|
||||||
|
"id": "call_abc|fc_1",
|
||||||
|
"function": {"name": "read_file", "arguments": '{path:"foo.txt"}'},
|
||||||
|
}],
|
||||||
|
}])
|
||||||
|
|
||||||
|
assert json.loads(items[0]["arguments"]) == {"path": "foo.txt"}
|
||||||
|
|
||||||
def test_duplicate_response_item_ids_are_made_unique(self):
|
def test_duplicate_response_item_ids_are_made_unique(self):
|
||||||
"""Codex rejects replayed Responses input items with duplicate ids."""
|
"""Codex rejects replayed Responses input items with duplicate ids."""
|
||||||
@ -367,7 +380,7 @@ class TestParseResponseOutput:
|
|||||||
assert result.tool_calls[0].id == "call_1|fc_1"
|
assert result.tool_calls[0].id == "call_1|fc_1"
|
||||||
|
|
||||||
def test_malformed_tool_arguments_logged(self):
|
def test_malformed_tool_arguments_logged(self):
|
||||||
"""Malformed JSON arguments should log a warning and fallback."""
|
"""Malformed JSON arguments should log a warning and remain non-object."""
|
||||||
resp = {
|
resp = {
|
||||||
"output": [{
|
"output": [{
|
||||||
"type": "function_call",
|
"type": "function_call",
|
||||||
@ -378,10 +391,29 @@ class TestParseResponseOutput:
|
|||||||
}
|
}
|
||||||
with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger:
|
with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger:
|
||||||
result = parse_response_output(resp)
|
result = parse_response_output(resp)
|
||||||
assert result.tool_calls[0].arguments == {"raw": "{bad json"}
|
assert result.tool_calls[0].arguments == "{bad json"
|
||||||
mock_logger.warning.assert_called_once()
|
mock_logger.warning.assert_called_once()
|
||||||
assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args)
|
assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("arguments", [[], False, 0])
|
||||||
|
def test_falsy_non_object_tool_arguments_preserved(self, arguments):
|
||||||
|
resp = {
|
||||||
|
"output": [{
|
||||||
|
"type": "function_call",
|
||||||
|
"call_id": "c1",
|
||||||
|
"id": "fc1",
|
||||||
|
"name": "f",
|
||||||
|
"arguments": arguments,
|
||||||
|
}],
|
||||||
|
"status": "completed",
|
||||||
|
"usage": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = parse_response_output(resp)
|
||||||
|
|
||||||
|
assert result.tool_calls[0].arguments == arguments
|
||||||
|
assert type(result.tool_calls[0].arguments) is type(arguments)
|
||||||
|
|
||||||
def test_reasoning_content_extracted(self):
|
def test_reasoning_content_extracted(self):
|
||||||
resp = {
|
resp = {
|
||||||
"output": [
|
"output": [
|
||||||
@ -611,6 +643,38 @@ class TestConsumeSse:
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize("arguments", [[], False, 0])
|
||||||
|
async def test_falsy_non_object_tool_arguments_preserved(self, arguments):
|
||||||
|
response = _SseResponse([
|
||||||
|
{
|
||||||
|
"type": "response.output_item.added",
|
||||||
|
"item": {
|
||||||
|
"type": "function_call",
|
||||||
|
"call_id": "c1",
|
||||||
|
"id": "fc1",
|
||||||
|
"name": "f",
|
||||||
|
"arguments": "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "response.output_item.done",
|
||||||
|
"item": {
|
||||||
|
"type": "function_call",
|
||||||
|
"call_id": "c1",
|
||||||
|
"id": "fc1",
|
||||||
|
"name": "f",
|
||||||
|
"arguments": arguments,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{"type": "response.completed", "response": {"status": "completed"}},
|
||||||
|
])
|
||||||
|
|
||||||
|
_, tool_calls, _, _, _ = await consume_sse_with_reasoning(response)
|
||||||
|
|
||||||
|
assert tool_calls[0].arguments == arguments
|
||||||
|
assert type(tool_calls[0].arguments) is type(arguments)
|
||||||
|
|
||||||
|
|
||||||
# ======================================================================
|
# ======================================================================
|
||||||
# parsing - consume_sdk_stream
|
# parsing - consume_sdk_stream
|
||||||
@ -764,6 +828,28 @@ class TestConsumeSdkStream:
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize("arguments", [[], False, 0])
|
||||||
|
async def test_falsy_non_object_tool_arguments_preserved(self, arguments):
|
||||||
|
item_added = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="")
|
||||||
|
item_added.name = "f"
|
||||||
|
ev1 = MagicMock(type="response.output_item.added", item=item_added)
|
||||||
|
item_done = MagicMock(type="function_call", call_id="c1", id="fc1")
|
||||||
|
item_done.name = "f"
|
||||||
|
item_done.arguments = arguments
|
||||||
|
ev2 = MagicMock(type="response.output_item.done", item=item_done)
|
||||||
|
resp_obj = MagicMock(status="completed", usage=None, output=[])
|
||||||
|
ev3 = MagicMock(type="response.completed", response=resp_obj)
|
||||||
|
|
||||||
|
async def stream():
|
||||||
|
for e in [ev1, ev2, ev3]:
|
||||||
|
yield e
|
||||||
|
|
||||||
|
_, tool_calls, _, _, _ = await consume_sdk_stream(stream())
|
||||||
|
|
||||||
|
assert tool_calls[0].arguments == arguments
|
||||||
|
assert type(tool_calls[0].arguments) is type(arguments)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_usage_extracted(self):
|
async def test_usage_extracted(self):
|
||||||
usage_obj = MagicMock(input_tokens=10, output_tokens=5, total_tokens=15)
|
usage_obj = MagicMock(input_tokens=10, output_tokens=5, total_tokens=15)
|
||||||
@ -811,7 +897,7 @@ class TestConsumeSdkStream:
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_malformed_tool_args_logged(self):
|
async def test_malformed_tool_args_logged(self):
|
||||||
"""Malformed JSON in streaming tool args should log a warning."""
|
"""Malformed JSON in streaming tool args should log a warning and remain non-object."""
|
||||||
item_added = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="")
|
item_added = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="")
|
||||||
item_added.name = "f"
|
item_added.name = "f"
|
||||||
ev1 = MagicMock(type="response.output_item.added", item=item_added)
|
ev1 = MagicMock(type="response.output_item.added", item=item_added)
|
||||||
@ -828,6 +914,6 @@ class TestConsumeSdkStream:
|
|||||||
|
|
||||||
with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger:
|
with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger:
|
||||||
_, tool_calls, _, _, _ = await consume_sdk_stream(stream())
|
_, tool_calls, _, _, _ = await consume_sdk_stream(stream())
|
||||||
assert tool_calls[0].arguments == {"raw": "{bad"}
|
assert tool_calls[0].arguments == "{bad"
|
||||||
mock_logger.warning.assert_called_once()
|
mock_logger.warning.assert_called_once()
|
||||||
assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args)
|
assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args)
|
||||||
|
|||||||
30
tests/providers/test_provider_tool_arguments.py
Normal file
30
tests/providers/test_provider_tool_arguments.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
"""Shared tool-argument parsing policy tests."""
|
||||||
|
|
||||||
|
from nanobot.providers.base import (
|
||||||
|
parse_tool_arguments,
|
||||||
|
tool_arguments_json_for_replay,
|
||||||
|
tool_arguments_object_for_replay,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_tool_arguments_preserves_malformed_executable_arguments() -> None:
|
||||||
|
assert parse_tool_arguments('{path:"foo.txt"}') == '{path:"foo.txt"}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_tool_arguments_preserves_non_object_executable_arguments() -> None:
|
||||||
|
assert parse_tool_arguments('["foo.txt"]') == ["foo.txt"]
|
||||||
|
assert parse_tool_arguments("false") is False
|
||||||
|
assert parse_tool_arguments("null") == "null"
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_arguments_object_for_replay_repairs_object_like_history_arguments() -> None:
|
||||||
|
assert tool_arguments_object_for_replay('{path:"foo.txt"}') == {"path": "foo.txt"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_arguments_object_for_replay_keeps_history_object_shaped() -> None:
|
||||||
|
for arguments in ['["foo.txt"]', "false", "null", "0", ["foo.txt"], False, None, 0]:
|
||||||
|
assert tool_arguments_object_for_replay(arguments) == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_arguments_json_for_replay_returns_object_string() -> None:
|
||||||
|
assert tool_arguments_json_for_replay('{path:"foo.txt"}') == '{"path": "foo.txt"}'
|
||||||
@ -7,8 +7,9 @@ from nanobot.agent.tools.registry import ToolRegistry
|
|||||||
|
|
||||||
|
|
||||||
class _FakeTool(Tool):
|
class _FakeTool(Tool):
|
||||||
def __init__(self, name: str):
|
def __init__(self, name: str, schema: dict[str, Any] | None = None):
|
||||||
self._name = name
|
self._name = name
|
||||||
|
self._schema = schema
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
@ -20,7 +21,7 @@ class _FakeTool(Tool):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> dict[str, Any]:
|
def parameters(self) -> dict[str, Any]:
|
||||||
return {"type": "object", "properties": {}}
|
return self._schema or {"type": "object", "properties": {}}
|
||||||
|
|
||||||
async def execute(self, **kwargs: Any) -> Any:
|
async def execute(self, **kwargs: Any) -> Any:
|
||||||
return kwargs
|
return kwargs
|
||||||
@ -34,6 +35,13 @@ def _tool_names(definitions: list[dict[str, Any]]) -> list[str]:
|
|||||||
return names
|
return names
|
||||||
|
|
||||||
|
|
||||||
|
def _registry_with_names(names: list[str]) -> ToolRegistry:
|
||||||
|
registry = ToolRegistry()
|
||||||
|
for name in names:
|
||||||
|
registry.register(_FakeTool(name))
|
||||||
|
return registry
|
||||||
|
|
||||||
|
|
||||||
def test_get_definitions_orders_builtins_then_mcp_tools() -> None:
|
def test_get_definitions_orders_builtins_then_mcp_tools() -> None:
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
registry.register(_FakeTool("mcp_git_status"))
|
registry.register(_FakeTool("mcp_git_status"))
|
||||||
@ -49,17 +57,167 @@ def test_get_definitions_orders_builtins_then_mcp_tools() -> None:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_call_rejects_near_miss_tool_name_with_suggestion() -> None:
|
||||||
|
registry = ToolRegistry()
|
||||||
|
registry.register(_FakeTool("read_file"))
|
||||||
|
|
||||||
|
tool, params, error = registry.prepare_call("readFile", {"path": "foo.txt"})
|
||||||
|
|
||||||
|
assert tool is None
|
||||||
|
assert params == {"path": "foo.txt"}
|
||||||
|
assert error is not None
|
||||||
|
assert "Tool 'readFile' not found" in error
|
||||||
|
assert "Did you mean 'read_file'?" in error
|
||||||
|
assert "must match exactly" in error
|
||||||
|
|
||||||
|
|
||||||
|
def test_suggest_name_handles_canonical_tool_name_variants() -> None:
|
||||||
|
registry = _registry_with_names(["read_file"])
|
||||||
|
expected = {
|
||||||
|
"readFile": "read_file",
|
||||||
|
"read-file": "read_file",
|
||||||
|
"READ_FILE": "read_file",
|
||||||
|
"read file": "read_file",
|
||||||
|
"readfile": "read_file",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert {name: registry._suggest_name(name) for name in expected} == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_suggest_name_suppresses_low_confidence_and_non_unique_matches() -> None:
|
||||||
|
registry = _registry_with_names(["read_file", "write_file"])
|
||||||
|
|
||||||
|
for name in ["", "foo", "read", "file", "readfil", "read_file_tool"]:
|
||||||
|
assert registry._suggest_name(name) is None
|
||||||
|
|
||||||
|
ambiguous = _registry_with_names(["read_file", "readFile"])
|
||||||
|
assert ambiguous._suggest_name("readfile") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_suggest_name_updates_after_register_and_unregister() -> None:
|
||||||
|
registry = _registry_with_names(["read_file"])
|
||||||
|
|
||||||
|
assert registry._suggest_name("readFile") == "read_file"
|
||||||
|
|
||||||
|
registry.register(_FakeTool("readFile"))
|
||||||
|
assert registry._suggest_name("read-file") is None
|
||||||
|
|
||||||
|
registry.unregister("read_file")
|
||||||
|
assert registry._suggest_name("read-file") == "readFile"
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_call_read_file_rejects_non_object_params_with_actionable_hint() -> None:
|
def test_prepare_call_read_file_rejects_non_object_params_with_actionable_hint() -> None:
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
registry.register(_FakeTool("read_file"))
|
registry.register(_FakeTool("read_file"))
|
||||||
|
|
||||||
tool, params, error = registry.prepare_call("read_file", ["foo.txt"])
|
tool, params, error = registry.prepare_call("read_file", ["foo.txt"])
|
||||||
|
|
||||||
assert tool is None
|
assert tool is not None
|
||||||
assert params == ["foo.txt"]
|
assert params == ["foo.txt"]
|
||||||
assert error is not None
|
assert error is not None
|
||||||
assert "must be a JSON object" in error
|
assert "must be a JSON object" in error
|
||||||
assert "Use named parameters" in error
|
assert 'tool_name(param1="value1", param2="value2")' in error
|
||||||
|
assert "matching the tool schema" in error
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_call_parses_json_string_arguments() -> None:
|
||||||
|
registry = ToolRegistry()
|
||||||
|
registry.register(_FakeTool("read_file"))
|
||||||
|
|
||||||
|
tool, params, error = registry.prepare_call("read_file", '{"path":"foo.txt"}')
|
||||||
|
|
||||||
|
assert tool is not None
|
||||||
|
assert params == {"path": "foo.txt"}
|
||||||
|
assert error is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_call_rejects_malformed_json_string_arguments() -> None:
|
||||||
|
registry = ToolRegistry()
|
||||||
|
registry.register(_FakeTool("read_file"))
|
||||||
|
|
||||||
|
tool, params, error = registry.prepare_call("read_file", '{path:"foo.txt"}')
|
||||||
|
|
||||||
|
assert tool is not None
|
||||||
|
assert params == '{path:"foo.txt"}'
|
||||||
|
assert error is not None
|
||||||
|
assert "parameters must be a JSON object" in error
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_call_rejects_scalar_for_single_required_parameter() -> None:
|
||||||
|
registry = ToolRegistry()
|
||||||
|
registry.register(_FakeTool(
|
||||||
|
"web_fetch",
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"url": {"type": "string"}},
|
||||||
|
"required": ["url"],
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
tool, params, error = registry.prepare_call("web_fetch", "https://example.com")
|
||||||
|
|
||||||
|
assert tool is not None
|
||||||
|
assert params == "https://example.com"
|
||||||
|
assert error is not None
|
||||||
|
assert "parameters must be a JSON object" in error
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_call_rejects_unquoted_scalar_strings_before_schema_cast() -> None:
|
||||||
|
registry = ToolRegistry()
|
||||||
|
registry.register(_FakeTool(
|
||||||
|
"message",
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"content": {"type": "string"}},
|
||||||
|
"required": ["content"],
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
tool, params, error = registry.prepare_call("message", "true")
|
||||||
|
|
||||||
|
assert tool is not None
|
||||||
|
assert params == "true"
|
||||||
|
assert error is not None
|
||||||
|
assert "parameters must be a JSON object" in error
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_call_unwraps_arguments_payload() -> None:
|
||||||
|
registry = ToolRegistry()
|
||||||
|
registry.register(_FakeTool(
|
||||||
|
"read_file",
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"path": {"type": "string"}},
|
||||||
|
"required": ["path"],
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
tool, params, error = registry.prepare_call(
|
||||||
|
"read_file",
|
||||||
|
{"arguments": '{"path":"foo.txt"}'},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert tool is not None
|
||||||
|
assert params == {"path": "foo.txt"}
|
||||||
|
assert error is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_call_treats_none_arguments_as_empty_object() -> None:
|
||||||
|
registry = ToolRegistry()
|
||||||
|
registry.register(_FakeTool("list_exec_sessions"))
|
||||||
|
|
||||||
|
tool, params, error = registry.prepare_call("list_exec_sessions", None)
|
||||||
|
|
||||||
|
assert tool is not None
|
||||||
|
assert params == {}
|
||||||
|
assert error is None
|
||||||
|
|
||||||
|
tool, params, error = registry.prepare_call("list_exec_sessions", "null")
|
||||||
|
|
||||||
|
assert tool is not None
|
||||||
|
assert params == "null"
|
||||||
|
assert error is not None
|
||||||
|
assert "parameters must be a JSON object" in error
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_call_other_tools_keep_generic_object_validation() -> None:
|
def test_prepare_call_other_tools_keep_generic_object_validation() -> None:
|
||||||
@ -70,7 +228,11 @@ def test_prepare_call_other_tools_keep_generic_object_validation() -> None:
|
|||||||
|
|
||||||
assert tool is not None
|
assert tool is not None
|
||||||
assert params == ["TODO"]
|
assert params == ["TODO"]
|
||||||
assert error == "Error: Invalid parameters for tool 'grep': parameters must be an object, got list"
|
assert error == (
|
||||||
|
"Error: Tool 'grep' parameters must be a JSON object, got list. "
|
||||||
|
'Use named parameters like tool_name(param1="value1", param2="value2") '
|
||||||
|
"matching the tool schema."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_definitions_returns_cached_result() -> None:
|
def test_get_definitions_returns_cached_result() -> None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user