diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 8746b5c27..8cffb3fdc 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -399,7 +399,6 @@ class AgentRunner: thinking_blocks=response.thinking_blocks, ) messages.append(assistant_message) - tools_used.extend(tc.name for tc in response.tool_calls) await self._emit_checkpoint( spec, { @@ -421,6 +420,11 @@ class AgentRunner: workspace_violation_counts, ) tool_events.extend(new_events) + tools_used.extend( + tool_call.name + for tool_call, event in zip(response.tool_calls, new_events) + if event.get("status") == "ok" + ) context.tool_results = list(results) context.tool_events = list(new_events) completed_tool_results: list[dict[str, Any]] = [] diff --git a/nanobot/agent/tools/registry.py b/nanobot/agent/tools/registry.py index 3d185d579..c60697adf 100644 --- a/nanobot/agent/tools/registry.py +++ b/nanobot/agent/tools/registry.py @@ -1,5 +1,6 @@ """Tool registry for dynamic tool management.""" +import json from typing import Any from nanobot.agent.tools.base import Tool @@ -30,6 +31,24 @@ class ToolRegistry: """Get a tool by name.""" return self._tools.get(name) + @staticmethod + def _lookup_key(name: str) -> str: + """Normalize names for suggestions only; never for execution.""" + return "".join(ch.lower() for ch in name if ch.isalnum()) + + def _suggest_name(self, name: str) -> str | None: + key = self._lookup_key(str(name or "")) + if not key: + return None + matches = [ + registered + for registered in self._tools + if self._lookup_key(registered) == key + ] + if len(matches) == 1: + return matches[0] + return None + def has(self, name: str) -> bool: """Check if a tool is registered.""" return name in self._tools @@ -73,20 +92,23 @@ class ToolRegistry: def prepare_call( self, name: str, - params: dict[str, Any], - ) -> tuple[Tool | None, dict[str, Any], str | None]: + params: Any, + ) -> tuple[Tool | None, Any, str | None]: """Resolve, cast, and validate one tool call.""" - # Guard against invalid parameter types (e.g., list instead of dict) - if not isinstance(params, dict) and name in ('write_file', 'read_file'): - return None, params, ( - f"Error: Tool '{name}' parameters must be a JSON object, got {type(params).__name__}. " - "Use named parameters: tool_name(param1=\"value1\", param2=\"value2\")" - ) - tool = self._tools.get(name) if not tool: + suggestion = self._suggest_name(str(name)) + hint = f" Did you mean '{suggestion}'? Tool names must match exactly." if suggestion else "" return None, params, ( - f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}" + f"Error: Tool '{name}' not found.{hint} Available: {', '.join(self.tool_names)}" + ) + + params = self._coerce_params(tool, params) + if not isinstance(params, dict): + return tool, params, ( + f"Error: Tool '{name}' parameters must be a JSON object, got " + f"{type(params).__name__}. Use named parameters like " + 'tool_name(param1="value1", param2="value2") matching the tool schema.' ) cast_params = tool.cast_params(params) @@ -97,21 +119,56 @@ class ToolRegistry: ) return tool, cast_params, None - async def execute(self, name: str, params: dict[str, Any]) -> Any: + @classmethod + def _coerce_argument_value(cls, value: Any) -> Any: + if value is None: + return {} + if not isinstance(value, str): + return value + + stripped = value.strip() + if not stripped: + return {} + + if not stripped.startswith(("{", "[")): + return value + + try: + parsed = json.loads(stripped) + except Exception: + return value + + return parsed + + @classmethod + def _coerce_params(cls, tool: Tool, params: Any) -> Any: + params = cls._coerce_argument_value(params) + return cls._unwrap_arguments_payload(tool, params) + + @classmethod + def _unwrap_arguments_payload(cls, tool: Tool, params: Any) -> Any: + if not isinstance(params, dict) or set(params) != {"arguments"}: + return params + properties = (tool.parameters or {}).get("properties", {}) + if isinstance(properties, dict) and "arguments" in properties: + return params + return cls._coerce_argument_value(params.get("arguments")) + + async def execute(self, name: str, params: Any) -> Any: """Execute a tool by name with given parameters.""" - _HINT = "\n\n[Analyze the error above and try a different approach.]" + hint = "\n\n[Analyze the error above and try a different approach.]" tool, params, error = self.prepare_call(name, params) if error: - return error + _HINT + return error + hint try: assert tool is not None # guarded by prepare_call() result = await tool.execute(**params) if isinstance(result, str) and result.startswith("Error"): - return result + _HINT + return result + hint return result except Exception as e: - return f"Error executing {name}: {str(e)}" + _HINT + return f"Error executing {name}: {str(e)}" + hint @property def tool_names(self) -> list[str]: diff --git a/nanobot/providers/anthropic_provider.py b/nanobot/providers/anthropic_provider.py index 8a59d5c42..ddeb23aed 100644 --- a/nanobot/providers/anthropic_provider.py +++ b/nanobot/providers/anthropic_provider.py @@ -10,9 +10,12 @@ import string from collections.abc import Awaitable, Callable from typing import Any -import json_repair - -from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest +from nanobot.providers.base import ( + LLMProvider, + LLMResponse, + ToolCallRequest, + tool_arguments_object_for_replay, +) _ALNUM = string.ascii_letters + string.digits @@ -207,13 +210,11 @@ class AnthropicProvider(LLMProvider): continue func = tc.get("function", {}) args = func.get("arguments", "{}") - if isinstance(args, str): - args = json_repair.loads(args) blocks.append({ "type": "tool_use", "id": tc.get("id") or _gen_tool_id(), "name": func.get("name", ""), - "input": args, + "input": tool_arguments_object_for_replay(args), }) return blocks or [{"type": "text", "text": ""}] @@ -509,7 +510,7 @@ class AnthropicProvider(LLMProvider): tool_calls.append(ToolCallRequest( id=block.id, name=block.name, - arguments=block.input if isinstance(block.input, dict) else {}, + arguments=block.input, )) elif block.type == "thinking": thinking_blocks.append({ diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index c36593cb2..4a692b424 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -11,6 +11,7 @@ from datetime import datetime, timezone from email.utils import parsedate_to_datetime from typing import Any +import json_repair from loguru import logger from nanobot.utils.helpers import image_placeholder_text @@ -21,19 +22,24 @@ class ToolCallRequest: """A tool call request from the LLM.""" id: str name: str - arguments: dict[str, Any] + arguments: Any extra_content: dict[str, Any] | None = None provider_specific_fields: dict[str, Any] | None = None function_provider_specific_fields: dict[str, Any] | None = None def to_openai_tool_call(self) -> dict[str, Any]: """Serialize to an OpenAI-style tool_call payload.""" + arguments = ( + self.arguments + if isinstance(self.arguments, str) + else json.dumps(self.arguments, ensure_ascii=False) + ) tool_call = { "id": self.id, "type": "function", "function": { "name": self.name, - "arguments": json.dumps(self.arguments, ensure_ascii=False), + "arguments": arguments, }, } if self.extra_content: @@ -45,6 +51,62 @@ class ToolCallRequest: return tool_call +def parse_tool_arguments(arguments: Any) -> Any: + """Parse provider tool arguments without guessing executable parameters. + + Valid JSON object strings become dicts. Empty strings become no-arg calls. + Malformed JSON and JSON array/scalar values are preserved so ToolRegistry + can reject them before execution. + """ + if arguments is None: + return {} + if not isinstance(arguments, str): + return arguments + + stripped = arguments.strip() + if not stripped: + return {} + + try: + parsed = json.loads(stripped) + except Exception: + return arguments + return arguments if parsed is None else parsed + + +def tool_arguments_object_for_replay(arguments: Any) -> dict[str, Any]: + """Return object-shaped arguments for provider history replay only. + + This compatibility path may repair malformed JSON because it only shapes + existing conversation history for provider protocols. Do not use it for + newly generated tool calls that are about to execute. + """ + if arguments is None: + return {} + if isinstance(arguments, dict): + return arguments + if not isinstance(arguments, str): + return {} + + stripped = arguments.strip() + if not stripped: + return {} + + try: + parsed = json.loads(stripped) + except Exception: + try: + parsed = json_repair.loads(stripped) + except Exception: + return {} + return parsed if isinstance(parsed, dict) else {} + + +def tool_arguments_json_for_replay(arguments: Any) -> str: + """Return JSON object string arguments for provider history replay only.""" + return json.dumps(tool_arguments_object_for_replay(arguments), ensure_ascii=False) + + @dataclass class LLMResponse: """Response from an LLM provider.""" diff --git a/nanobot/providers/bedrock_provider.py b/nanobot/providers/bedrock_provider.py index ff74badbc..dbac6078a 100644 --- a/nanobot/providers/bedrock_provider.py +++ b/nanobot/providers/bedrock_provider.py @@ -10,9 +10,13 @@ import re from collections.abc import Awaitable, Callable, Iterator from typing import Any -import json_repair - -from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest +from nanobot.providers.base import ( + LLMProvider, + LLMResponse, + ToolCallRequest, + parse_tool_arguments, + tool_arguments_object_for_replay, +) _IMAGE_DATA_URL = re.compile(r"^data:image/([a-zA-Z0-9.+-]+);base64,(.*)$", re.DOTALL) _TEXT_BLOCK_TYPES = {"text", "input_text", "output_text"} @@ -176,14 +180,7 @@ class BedrockProvider(LLMProvider): function = tool_call.get("function") if not isinstance(function, dict): return None - args = function.get("arguments", {}) - if isinstance(args, str): - try: - args = json_repair.loads(args) if args.strip() else {} - except Exception: - args = {} - if not isinstance(args, dict): - args = {} + args = tool_arguments_object_for_replay(function.get("arguments", {})) return { "toolUse": { "toolUseId": str(tool_call.get("id") or ""), @@ -491,7 +488,7 @@ class BedrockProvider(LLMProvider): content_parts.append(block["text"]) tool_use = block.get("toolUse") if isinstance(tool_use, dict): - arguments = tool_use.get("input") if isinstance(tool_use.get("input"), dict) else {} + arguments = tool_use.get("input", {}) tool_calls.append(ToolCallRequest( id=str(tool_use.get("toolUseId") or ""), name=str(tool_use.get("name") or ""), @@ -616,14 +613,11 @@ class BedrockProvider(LLMProvider): for buf in tool_buffers.values(): args: Any = {} if buf.get("input"): - try: - args = json_repair.loads(buf["input"]) - except Exception: - args = {} + args = parse_tool_arguments(buf["input"]) tool_calls.append(ToolCallRequest( id=buf.get("id") or "", name=buf.get("name") or "", - arguments=args if isinstance(args, dict) else {}, + arguments=args, )) return LLMResponse( content="".join(content_parts) or None, diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index a0eb35176..ee44333a6 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -17,10 +17,15 @@ from ipaddress import ip_address from typing import TYPE_CHECKING, Any from urllib.parse import urlparse -import json_repair from loguru import logger -from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest +from nanobot.providers.base import ( + LLMProvider, + LLMResponse, + ToolCallRequest, + parse_tool_arguments, + tool_arguments_json_for_replay, +) from nanobot.providers.openai_responses import ( consume_sdk_stream, convert_messages, @@ -478,24 +483,6 @@ class OpenAICompatProvider(LLMProvider): """Return True for providers that reject normal OpenAI tool call IDs.""" return bool(self._spec and self._spec.name == "mistral") - @staticmethod - def _normalize_tool_call_arguments(arguments: Any) -> str: - """Force function.arguments into a valid JSON object string.""" - if isinstance(arguments, str): - stripped = arguments.strip() - if not stripped: - return "{}" - try: - parsed = json_repair.loads(stripped) - except Exception: - return "{}" - if isinstance(parsed, dict): - return json.dumps(parsed, ensure_ascii=False) - return "{}" - if isinstance(arguments, dict): - return json.dumps(arguments, ensure_ascii=False) - return "{}" - @staticmethod def _coerce_content_to_string(content: Any) -> str | None: """Coerce block/list content into plain text for strict string-only APIs.""" @@ -572,7 +559,7 @@ class OpenAICompatProvider(LLMProvider): if isinstance(function, dict): function_clean = dict(function) if "arguments" in function_clean: - function_clean["arguments"] = self._normalize_tool_call_arguments( + function_clean["arguments"] = tool_arguments_json_for_replay( function_clean.get("arguments") ) else: @@ -1021,14 +1008,12 @@ class OpenAICompatProvider(LLMProvider): for tc in raw_tool_calls: tc_map = self._maybe_mapping(tc) or {} fn = self._maybe_mapping(tc_map.get("function")) or {} - args = fn.get("arguments", {}) - if isinstance(args, str): - args = json_repair.loads(args) + args = parse_tool_arguments(fn.get("arguments", {})) ec, prov, fn_prov = _extract_tc_extras(tc) parsed_tool_calls.append(ToolCallRequest( id=str(tc_map.get("id") or _short_tool_id()), name=str(fn.get("name") or ""), - arguments=args if isinstance(args, dict) else {}, + arguments=args, extra_content=ec, provider_specific_fields=prov, function_provider_specific_fields=fn_prov, @@ -1064,9 +1049,7 @@ class OpenAICompatProvider(LLMProvider): tool_calls = [] for tc in raw_tool_calls: - args = tc.function.arguments - if isinstance(args, str): - args = json_repair.loads(args) + args = parse_tool_arguments(tc.function.arguments) ec, prov, fn_prov = _extract_tc_extras(tc) tool_calls.append(ToolCallRequest( id=str(getattr(tc, "id", None) or _short_tool_id()), @@ -1207,7 +1190,7 @@ class OpenAICompatProvider(LLMProvider): ToolCallRequest( id=b["id"] or _short_tool_id(), name=b["name"], - arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {}, + arguments=parse_tool_arguments(b["arguments"]), extra_content=b.get("extra_content"), provider_specific_fields=b.get("prov"), function_provider_specific_fields=b.get("fn_prov"), diff --git a/nanobot/providers/openai_responses/converters.py b/nanobot/providers/openai_responses/converters.py index 27c59ab58..c8b756b14 100644 --- a/nanobot/providers/openai_responses/converters.py +++ b/nanobot/providers/openai_responses/converters.py @@ -5,6 +5,8 @@ from __future__ import annotations import json from typing import Any +from nanobot.providers.base import tool_arguments_json_for_replay + def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]: """Convert Chat Completions messages to Responses API input items. @@ -46,7 +48,7 @@ def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str "id": response_item_id, "call_id": call_id or f"call_{idx}", "name": fn.get("name"), - "arguments": fn.get("arguments") or "{}", + "arguments": tool_arguments_json_for_replay(fn.get("arguments")), }) continue diff --git a/nanobot/providers/openai_responses/parsing.py b/nanobot/providers/openai_responses/parsing.py index fbfc9813c..a16a6d620 100644 --- a/nanobot/providers/openai_responses/parsing.py +++ b/nanobot/providers/openai_responses/parsing.py @@ -7,10 +7,9 @@ from collections.abc import Awaitable, Callable from typing import Any, AsyncGenerator import httpx -import json_repair from loguru import logger -from nanobot.providers.base import LLMResponse, ToolCallRequest +from nanobot.providers.base import LLMResponse, ToolCallRequest, parse_tool_arguments FINISH_REASON_MAP = { "completed": "stop", @@ -44,6 +43,27 @@ def _usage_from_response_obj(response: Any) -> dict[str, int]: } +def _parse_tool_call_arguments(args_raw: Any, name: str | None) -> Any: + parsed = parse_tool_arguments(args_raw) + if parsed == args_raw and isinstance(args_raw, str) and args_raw.strip(): + logger.warning( + "Failed to parse tool call arguments for '{}': {}", + name, + args_raw[:200], + ) + return parsed + + +def _tool_arguments_source(*values: Any) -> Any: + for value in values: + if value is None: + continue + if isinstance(value, str) and not value.strip(): + continue + return value + return "{}" + + async def iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]: """Yield parsed JSON events from a Responses API SSE stream.""" buffer: list[str] = [] @@ -116,10 +136,11 @@ async def consume_sse_with_reasoning( call_id = item.get("call_id") if not call_id: continue + arguments = item.get("arguments") tool_call_buffers[call_id] = { "id": item.get("id") or "fc_0", "name": item.get("name"), - "arguments": item.get("arguments") or "", + "arguments": "" if arguments is None else arguments, } if on_tool_call_delta: await on_tool_call_delta({ @@ -156,7 +177,10 @@ async def consume_sse_with_reasoning( call_id = event.get("call_id") if call_id and call_id in tool_call_buffers: delta = event.get("delta") or "" - tool_call_buffers[call_id]["arguments"] += delta + current = tool_call_buffers[call_id].get("arguments") + if not isinstance(current, str): + current = "" + tool_call_buffers[call_id]["arguments"] = current + delta if on_tool_call_delta and delta: await on_tool_call_delta({ "call_id": str(call_id), @@ -166,14 +190,14 @@ async def consume_sse_with_reasoning( elif event_type == "response.function_call_arguments.done": call_id = event.get("call_id") if call_id and call_id in tool_call_buffers: - arguments = event.get("arguments") or "" + arguments = event.get("arguments") tool_call_buffers[call_id]["arguments"] = arguments if on_tool_call_delta: tool_call_args_emitted.add(str(call_id)) await on_tool_call_delta({ "call_id": str(call_id), "name": str(tool_call_buffers[call_id].get("name") or ""), - "arguments": str(arguments), + "arguments": "" if arguments is None else str(arguments), }) elif event_type == "response.output_item.done": item = event.get("item") or {} @@ -182,7 +206,7 @@ async def consume_sse_with_reasoning( if not call_id: continue buf = tool_call_buffers.get(call_id) or {} - args_raw = buf.get("arguments") or item.get("arguments") or "{}" + args_raw = _tool_arguments_source(buf.get("arguments"), item.get("arguments")) if on_tool_call_delta and str(call_id) not in tool_call_args_emitted: tool_call_args_emitted.add(str(call_id)) await on_tool_call_delta({ @@ -190,17 +214,10 @@ async def consume_sse_with_reasoning( "name": str(buf.get("name") or item.get("name") or ""), "arguments": str(args_raw), }) - try: - args = json.loads(args_raw) - except Exception: - logger.warning( - "Failed to parse tool call arguments for '{}': {}", - buf.get("name") or item.get("name"), - args_raw[:200], - ) - args = json_repair.loads(args_raw) - if not isinstance(args, dict): - args = {"raw": args_raw} + args = _parse_tool_call_arguments( + args_raw, + buf.get("name") or item.get("name"), + ) tool_calls.append( ToolCallRequest( id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}", @@ -283,22 +300,12 @@ def parse_response_output(response: Any) -> LLMResponse: elif item_type == "function_call": call_id = item.get("call_id") or "" item_id = item.get("id") or "fc_0" - args_raw = item.get("arguments") or "{}" - try: - args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw - except Exception: - logger.warning( - "Failed to parse tool call arguments for '{}': {}", - item.get("name"), - str(args_raw)[:200], - ) - args = json_repair.loads(args_raw) if isinstance(args_raw, str) else args_raw - if not isinstance(args, dict): - args = {"raw": args_raw} + args_raw = _tool_arguments_source(item.get("arguments")) + args = _parse_tool_call_arguments(args_raw, item.get("name")) tool_calls.append(ToolCallRequest( id=f"{call_id}|{item_id}", name=item.get("name") or "", - arguments=args if isinstance(args, dict) else {}, + arguments=args, )) usage = _usage_from_response_obj(response) @@ -337,10 +344,11 @@ async def consume_sdk_stream( call_id = getattr(item, "call_id", None) if not call_id: continue + arguments = getattr(item, "arguments", None) tool_call_buffers[call_id] = { "id": getattr(item, "id", None) or "fc_0", "name": getattr(item, "name", None), - "arguments": getattr(item, "arguments", None) or "", + "arguments": "" if arguments is None else arguments, } if on_tool_call_delta: await on_tool_call_delta({ @@ -357,7 +365,10 @@ async def consume_sdk_stream( call_id = getattr(event, "call_id", None) if call_id and call_id in tool_call_buffers: delta = getattr(event, "delta", "") or "" - tool_call_buffers[call_id]["arguments"] += delta + current = tool_call_buffers[call_id].get("arguments") + if not isinstance(current, str): + current = "" + tool_call_buffers[call_id]["arguments"] = current + delta if on_tool_call_delta and delta: await on_tool_call_delta({ "call_id": str(call_id), @@ -367,14 +378,14 @@ async def consume_sdk_stream( elif event_type == "response.function_call_arguments.done": call_id = getattr(event, "call_id", None) if call_id and call_id in tool_call_buffers: - arguments = getattr(event, "arguments", "") or "" + arguments = getattr(event, "arguments", None) tool_call_buffers[call_id]["arguments"] = arguments if on_tool_call_delta: tool_call_args_emitted.add(str(call_id)) await on_tool_call_delta({ "call_id": str(call_id), "name": str(tool_call_buffers[call_id].get("name") or ""), - "arguments": str(arguments), + "arguments": "" if arguments is None else str(arguments), }) elif event_type == "response.output_item.done": item = getattr(event, "item", None) @@ -383,7 +394,10 @@ async def consume_sdk_stream( if not call_id: continue buf = tool_call_buffers.get(call_id) or {} - args_raw = buf.get("arguments") or getattr(item, "arguments", None) or "{}" + args_raw = _tool_arguments_source( + buf.get("arguments"), + getattr(item, "arguments", None), + ) if on_tool_call_delta and str(call_id) not in tool_call_args_emitted: tool_call_args_emitted.add(str(call_id)) await on_tool_call_delta({ @@ -391,17 +405,10 @@ async def consume_sdk_stream( "name": str(buf.get("name") or getattr(item, "name", None) or ""), "arguments": str(args_raw), }) - try: - args = json.loads(args_raw) - except Exception: - logger.warning( - "Failed to parse tool call arguments for '{}': {}", - buf.get("name") or getattr(item, "name", None), - str(args_raw)[:200], - ) - args = json_repair.loads(args_raw) - if not isinstance(args, dict): - args = {"raw": args_raw} + args = _parse_tool_call_arguments( + args_raw, + buf.get("name") or getattr(item, "name", None), + ) tool_calls.append( ToolCallRequest( id=f"{call_id}|{buf.get('id') or getattr(item, 'id', None) or 'fc_0'}", diff --git a/nanobot/utils/progress_events.py b/nanobot/utils/progress_events.py index ccf125ec4..645a351d6 100644 --- a/nanobot/utils/progress_events.py +++ b/nanobot/utils/progress_events.py @@ -49,13 +49,18 @@ async def invoke_file_edit_progress( await on_progress("", file_edit_events=file_edit_events) +def _tool_event_arguments(tool_call: Any) -> dict[str, Any]: + arguments = getattr(tool_call, "arguments", {}) or {} + return arguments if isinstance(arguments, dict) else {} + + def build_tool_event_start_payload(tool_call: Any) -> dict[str, Any]: return { "version": 1, "phase": "start", "call_id": str(getattr(tool_call, "id", "") or ""), "name": getattr(tool_call, "name", ""), - "arguments": getattr(tool_call, "arguments", {}) or {}, + "arguments": _tool_event_arguments(tool_call), "result": None, "error": None, "files": [], @@ -86,7 +91,7 @@ def build_tool_event_finish_payloads(context: AgentHookContext) -> list[dict[str "phase": phase, "call_id": str(getattr(tool_call, "id", "") or ""), "name": getattr(tool_call, "name", ""), - "arguments": getattr(tool_call, "arguments", {}) or {}, + "arguments": _tool_event_arguments(tool_call), "result": result if phase == "end" else None, "error": None, "files": files, diff --git a/nanobot/utils/runtime.py b/nanobot/utils/runtime.py index 66783e19f..70d14c442 100644 --- a/nanobot/utils/runtime.py +++ b/nanobot/utils/runtime.py @@ -75,8 +75,10 @@ def build_goal_continue_message(custom: str | None = None) -> dict[str, str]: return {"role": "user", "content": custom or SUSTAINED_GOAL_CONTINUE_PROMPT} -def external_lookup_signature(tool_name: str, arguments: dict[str, Any]) -> str | None: +def external_lookup_signature(tool_name: str, arguments: Any) -> str | None: """Stable signature for repeated external lookups we want to throttle.""" + if not isinstance(arguments, dict): + return None if tool_name == "web_fetch": url = str(arguments.get("url") or "").strip() if url: @@ -90,7 +92,7 @@ def external_lookup_signature(tool_name: str, arguments: dict[str, Any]) -> str def repeated_external_lookup_error( tool_name: str, - arguments: dict[str, Any], + arguments: Any, seen_counts: dict[str, int], ) -> str | None: """Block repeated external lookups after a small retry budget.""" @@ -119,9 +121,11 @@ _OUTSIDE_PATH_PATTERN = re.compile(r"(?:^|[\s|>'\"])((?:/[^\s\"'>;|<]+)|(?:~[^\s def workspace_violation_signature( tool_name: str, - arguments: dict[str, Any], + arguments: Any, ) -> str | None: """Return a stable cross-tool signature for the outside-workspace target.""" + if not isinstance(arguments, dict): + return None for key in ("path", "file_path", "target", "source", "destination"): val = arguments.get(key) if isinstance(val, str) and val.strip(): @@ -151,7 +155,7 @@ def _normalize_violation_target(raw: str) -> str: def repeated_workspace_violation_error( tool_name: str, - arguments: dict[str, Any], + arguments: Any, seen_counts: dict[str, int], ) -> str | None: """Return an escalated error after repeated bypass attempts.""" diff --git a/tests/agent/test_runner_tool_execution.py b/tests/agent/test_runner_tool_execution.py index a0380e871..70e74fafe 100644 --- a/tests/agent/test_runner_tool_execution.py +++ b/tests/agent/test_runner_tool_execution.py @@ -3,17 +3,21 @@ from __future__ import annotations import asyncio -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest +from nanobot.agent.runner import AgentRunner, AgentRunSpec from nanobot.agent.tools.base import Tool from nanobot.agent.tools.registry import ToolRegistry from nanobot.config.schema import AgentDefaults from nanobot.providers.base import LLMResponse, ToolCallRequest +from nanobot.providers.openai_compat_provider import OpenAICompatProvider +from nanobot.providers.openai_responses.parsing import parse_response_output _MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars + class _DelayTool(Tool): def __init__( self, @@ -57,10 +61,45 @@ class _DelayTool(Tool): return self._name +async def _run_optional_tool_response(response: LLMResponse): + provider = MagicMock() + calls = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + calls["n"] += 1 + if calls["n"] == 1: + return response + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = ToolRegistry() + shared_events: list[str] = [] + tools.register(_DelayTool( + "optional_tool", + delay=0, + read_only=True, + shared_events=shared_events, + )) + + result = await AgentRunner(provider).run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "try optional"}], + tools=tools, + model="test-model", + max_iterations=2, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + return result, shared_events + + +def _tool_message(result, tool_call_id: str) -> dict: + return [ + msg for msg in result.messages + if msg.get("role") == "tool" and msg.get("tool_call_id") == tool_call_id + ][0] + + @pytest.mark.asyncio async def test_runner_batches_read_only_tools_before_exclusive_work(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner - tools = ToolRegistry() shared_events: list[str] = [] read_a = _DelayTool("read_a", delay=0.05, read_only=True, shared_events=shared_events) @@ -98,8 +137,6 @@ async def test_runner_batches_read_only_tools_before_exclusive_work(): @pytest.mark.asyncio async def test_runner_does_not_batch_exclusive_read_only_tools(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner - tools = ToolRegistry() shared_events: list[str] = [] read_a = _DelayTool("read_a", delay=0.03, read_only=True, shared_events=shared_events) @@ -140,9 +177,151 @@ async def test_runner_does_not_batch_exclusive_read_only_tools(): @pytest.mark.asyncio -async def test_runner_blocks_repeated_external_fetches(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner +async def test_runner_rejects_near_miss_tool_name_without_executing(): + provider = MagicMock() + call_count = {"n": 0} + captured_second_call: list[dict] = [] + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content="", + tool_calls=[ + ToolCallRequest( + id="call_1", + name="readFile", + arguments={"path": "notes.txt"}, + ) + ], + finish_reason="tool_calls", + usage={}, + ) + captured_second_call[:] = messages + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = ToolRegistry() + shared_events: list[str] = [] + tools.register(_DelayTool( + "read_file", + delay=0, + read_only=True, + shared_events=shared_events, + )) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "read notes"}], + tools=tools, + model="test-model", + max_iterations=2, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.final_content == "done" + assert result.tools_used == [] + assert shared_events == [] + assistant_message = [ + msg for msg in result.messages + if msg.get("role") == "assistant" and msg.get("tool_calls") + ][0] + assert assistant_message["tool_calls"][0]["function"]["name"] == "readFile" + tool_message = [ + msg for msg in result.messages + if msg.get("role") == "tool" and msg.get("tool_call_id") == "call_1" + ][0] + assert tool_message["name"] == "readFile" + assert "Tool 'readFile' not found" in tool_message["content"] + assert "Did you mean 'read_file'?" in tool_message["content"] + replayed_assistant = [ + msg for msg in captured_second_call + if msg.get("role") == "assistant" and msg.get("tool_calls") + ][0] + assert replayed_assistant["tool_calls"][0]["function"]["name"] == "readFile" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("arguments", ['{path:"notes.txt"}', "null"]) +async def test_runner_rejects_openai_compat_invalid_arguments_without_executing(arguments): + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + parsed = OpenAICompatProvider()._parse({ + "choices": [{ + "message": { + "tool_calls": [{ + "id": "call_1", + "type": "function", + "function": { + "name": "optional_tool", + "arguments": arguments, + }, + }], + }, + "finish_reason": "tool_calls", + }], + "usage": {}, + }) + + result, shared_events = await _run_optional_tool_response(parsed) + + assert result.final_content == "done" + assert parsed.tool_calls[0].arguments == arguments + assert result.tools_used == [] + assert shared_events == [] + tool_message = _tool_message(result, "call_1") + assert "parameters must be a JSON object" in tool_message["content"] + + +@pytest.mark.asyncio +async def test_runner_rejects_openai_responses_malformed_arguments_without_executing(): + parsed = parse_response_output({ + "output": [{ + "type": "function_call", + "call_id": "call_1", + "id": "fc_1", + "name": "optional_tool", + "arguments": "{bad", + }], + "status": "completed", + "usage": {}, + }) + + result, shared_events = await _run_optional_tool_response(parsed) + + assert result.final_content == "done" + assert parsed.tool_calls[0].arguments == "{bad" + assert result.tools_used == [] + assert shared_events == [] + tool_message = _tool_message(result, "call_1|fc_1") + assert "parameters must be a JSON object" in tool_message["content"] + + +@pytest.mark.asyncio +async def test_runner_rejects_openai_responses_array_arguments_without_executing(): + parsed = parse_response_output({ + "output": [{ + "type": "function_call", + "call_id": "call_1", + "id": "fc_1", + "name": "optional_tool", + "arguments": [], + }], + "status": "completed", + "usage": {}, + }) + + result, shared_events = await _run_optional_tool_response(parsed) + + assert result.final_content == "done" + assert parsed.tool_calls[0].arguments == [] + assert result.tools_used == [] + assert shared_events == [] + tool_message = _tool_message(result, "call_1|fc_1") + assert "parameters must be a JSON object" in tool_message["content"] + + +@pytest.mark.asyncio +async def test_runner_blocks_repeated_external_fetches(): provider = MagicMock() captured_final_call: list[dict] = [] call_count = {"n": 0} diff --git a/tests/providers/test_anthropic_tool_result.py b/tests/providers/test_anthropic_tool_result.py index f6f6abbfe..3021ff0be 100644 --- a/tests/providers/test_anthropic_tool_result.py +++ b/tests/providers/test_anthropic_tool_result.py @@ -80,3 +80,17 @@ def test_convert_user_content_coerces_mixed_typeless(): ]) assert result[0] == {"type": "text", "text": "42"} assert result[1] == {"type": "text", "text": str({"key": "val"})} + + +def test_convert_assistant_message_repairs_history_tool_arguments(): + blocks = AnthropicProvider._assistant_blocks({ + "role": "assistant", + "content": None, + "tool_calls": [{ + "id": "toolu_1", + "function": {"name": "read_file", "arguments": '{path:"foo.txt"}'}, + }], + }) + + assert blocks[0]["type"] == "tool_use" + assert blocks[0]["input"] == {"path": "foo.txt"} diff --git a/tests/providers/test_bedrock_provider.py b/tests/providers/test_bedrock_provider.py index 3a480ef1d..a1c175245 100644 --- a/tests/providers/test_bedrock_provider.py +++ b/tests/providers/test_bedrock_provider.py @@ -161,6 +161,16 @@ def test_build_kwargs_converts_messages_tools_and_tool_results() -> None: assert kwargs["toolConfig"]["toolChoice"] == {"any": {}} +def test_tool_use_block_repairs_history_tool_arguments() -> None: + block = BedrockProvider._tool_use_block({ + "id": "toolu_1", + "function": {"name": "read_file", "arguments": '{path:"foo.txt"}'}, + }) + + assert block is not None + assert block["toolUse"]["input"] == {"path": "foo.txt"} + + def test_build_kwargs_keeps_tool_config_for_historical_tool_blocks_without_tools() -> None: provider = BedrockProvider(region="us-east-1", client=FakeClient()) messages = [ diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py index d786aad3e..0a1b85f70 100644 --- a/tests/providers/test_litellm_kwargs.py +++ b/tests/providers/test_litellm_kwargs.py @@ -54,6 +54,15 @@ def _fake_tool_call_response() -> SimpleNamespace: return SimpleNamespace(choices=[choice], usage=usage) +def _fake_tool_call_response_with_arguments(arguments) -> SimpleNamespace: + """Build a minimal chat response with caller-supplied tool arguments.""" + function = SimpleNamespace(name="optional_tool", arguments=arguments) + tool_call = SimpleNamespace(id="call_123", type="function", function=function) + message = SimpleNamespace(content=None, tool_calls=[tool_call], reasoning_content=None) + choice = SimpleNamespace(message=message, finish_reason="tool_calls") + return SimpleNamespace(choices=[choice], usage=SimpleNamespace()) + + def _fake_responses_response(content: str = "ok") -> MagicMock: """Build a minimal Responses API response object.""" resp = MagicMock() @@ -611,6 +620,24 @@ async def test_openai_compat_preserves_extra_content_on_tool_calls() -> None: assert serialized["function"]["provider_specific_fields"] == {"inner": "value"} +def test_openai_compat_parse_preserves_malformed_tool_arguments() -> None: + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() + + result = provider._parse(_fake_tool_call_response_with_arguments('{path:"foo.txt"}')) + + assert result.tool_calls[0].arguments == '{path:"foo.txt"}' + + +def test_openai_compat_parse_preserves_array_tool_arguments() -> None: + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() + + result = provider._parse(_fake_tool_call_response_with_arguments('["foo.txt"]')) + + assert result.tool_calls[0].arguments == ["foo.txt"] + + def test_openai_model_passthrough() -> None: """OpenAI models pass through unchanged.""" spec = find_by_name("openai") @@ -1110,7 +1137,7 @@ def test_openai_compat_stringifies_dict_tool_arguments() -> None: assert sanitized[1]["tool_calls"][0]["function"]["arguments"] == '{"cmd": "ls -la"}' -def test_openai_compat_repairs_non_json_tool_arguments_string() -> None: +def test_openai_compat_repairs_object_like_history_tool_arguments_string() -> None: with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): provider = OpenAICompatProvider() diff --git a/tests/providers/test_openai_responses.py b/tests/providers/test_openai_responses.py index e9d8545e1..4c2251c08 100644 --- a/tests/providers/test_openai_responses.py +++ b/tests/providers/test_openai_responses.py @@ -155,6 +155,19 @@ class TestConvertMessages: assert items[0]["call_id"] == "call_abc" assert items[0]["id"] == "fc_1" assert items[0]["name"] == "get_weather" + assert items[0]["arguments"] == '{"city": "SF"}' + + def test_assistant_tool_call_history_repairs_malformed_arguments(self): + _, items = convert_messages([{ + "role": "assistant", + "content": None, + "tool_calls": [{ + "id": "call_abc|fc_1", + "function": {"name": "read_file", "arguments": '{path:"foo.txt"}'}, + }], + }]) + + assert json.loads(items[0]["arguments"]) == {"path": "foo.txt"} def test_duplicate_response_item_ids_are_made_unique(self): """Codex rejects replayed Responses input items with duplicate ids.""" @@ -367,7 +380,7 @@ class TestParseResponseOutput: assert result.tool_calls[0].id == "call_1|fc_1" def test_malformed_tool_arguments_logged(self): - """Malformed JSON arguments should log a warning and fallback.""" + """Malformed JSON arguments should log a warning and remain non-object.""" resp = { "output": [{ "type": "function_call", @@ -378,10 +391,29 @@ class TestParseResponseOutput: } with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger: result = parse_response_output(resp) - assert result.tool_calls[0].arguments == {"raw": "{bad json"} + assert result.tool_calls[0].arguments == "{bad json" mock_logger.warning.assert_called_once() assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args) + @pytest.mark.parametrize("arguments", [[], False, 0]) + def test_falsy_non_object_tool_arguments_preserved(self, arguments): + resp = { + "output": [{ + "type": "function_call", + "call_id": "c1", + "id": "fc1", + "name": "f", + "arguments": arguments, + }], + "status": "completed", + "usage": {}, + } + + result = parse_response_output(resp) + + assert result.tool_calls[0].arguments == arguments + assert type(result.tool_calls[0].arguments) is type(arguments) + def test_reasoning_content_extracted(self): resp = { "output": [ @@ -611,6 +643,38 @@ class TestConsumeSse: }, ] + @pytest.mark.asyncio + @pytest.mark.parametrize("arguments", [[], False, 0]) + async def test_falsy_non_object_tool_arguments_preserved(self, arguments): + response = _SseResponse([ + { + "type": "response.output_item.added", + "item": { + "type": "function_call", + "call_id": "c1", + "id": "fc1", + "name": "f", + "arguments": "", + }, + }, + { + "type": "response.output_item.done", + "item": { + "type": "function_call", + "call_id": "c1", + "id": "fc1", + "name": "f", + "arguments": arguments, + }, + }, + {"type": "response.completed", "response": {"status": "completed"}}, + ]) + + _, tool_calls, _, _, _ = await consume_sse_with_reasoning(response) + + assert tool_calls[0].arguments == arguments + assert type(tool_calls[0].arguments) is type(arguments) + # ====================================================================== # parsing - consume_sdk_stream @@ -764,6 +828,28 @@ class TestConsumeSdkStream: }, ] + @pytest.mark.asyncio + @pytest.mark.parametrize("arguments", [[], False, 0]) + async def test_falsy_non_object_tool_arguments_preserved(self, arguments): + item_added = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="") + item_added.name = "f" + ev1 = MagicMock(type="response.output_item.added", item=item_added) + item_done = MagicMock(type="function_call", call_id="c1", id="fc1") + item_done.name = "f" + item_done.arguments = arguments + ev2 = MagicMock(type="response.output_item.done", item=item_done) + resp_obj = MagicMock(status="completed", usage=None, output=[]) + ev3 = MagicMock(type="response.completed", response=resp_obj) + + async def stream(): + for e in [ev1, ev2, ev3]: + yield e + + _, tool_calls, _, _, _ = await consume_sdk_stream(stream()) + + assert tool_calls[0].arguments == arguments + assert type(tool_calls[0].arguments) is type(arguments) + @pytest.mark.asyncio async def test_usage_extracted(self): usage_obj = MagicMock(input_tokens=10, output_tokens=5, total_tokens=15) @@ -811,7 +897,7 @@ class TestConsumeSdkStream: @pytest.mark.asyncio async def test_malformed_tool_args_logged(self): - """Malformed JSON in streaming tool args should log a warning.""" + """Malformed JSON in streaming tool args should log a warning and remain non-object.""" item_added = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="") item_added.name = "f" ev1 = MagicMock(type="response.output_item.added", item=item_added) @@ -828,6 +914,6 @@ class TestConsumeSdkStream: with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger: _, tool_calls, _, _, _ = await consume_sdk_stream(stream()) - assert tool_calls[0].arguments == {"raw": "{bad"} + assert tool_calls[0].arguments == "{bad" mock_logger.warning.assert_called_once() assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args) diff --git a/tests/providers/test_provider_tool_arguments.py b/tests/providers/test_provider_tool_arguments.py new file mode 100644 index 000000000..d1b4326a4 --- /dev/null +++ b/tests/providers/test_provider_tool_arguments.py @@ -0,0 +1,30 @@ +"""Shared tool-argument parsing policy tests.""" + +from nanobot.providers.base import ( + parse_tool_arguments, + tool_arguments_json_for_replay, + tool_arguments_object_for_replay, +) + + +def test_parse_tool_arguments_preserves_malformed_executable_arguments() -> None: + assert parse_tool_arguments('{path:"foo.txt"}') == '{path:"foo.txt"}' + + +def test_parse_tool_arguments_preserves_non_object_executable_arguments() -> None: + assert parse_tool_arguments('["foo.txt"]') == ["foo.txt"] + assert parse_tool_arguments("false") is False + assert parse_tool_arguments("null") == "null" + + +def test_tool_arguments_object_for_replay_repairs_object_like_history_arguments() -> None: + assert tool_arguments_object_for_replay('{path:"foo.txt"}') == {"path": "foo.txt"} + + +def test_tool_arguments_object_for_replay_keeps_history_object_shaped() -> None: + for arguments in ['["foo.txt"]', "false", "null", "0", ["foo.txt"], False, None, 0]: + assert tool_arguments_object_for_replay(arguments) == {} + + +def test_tool_arguments_json_for_replay_returns_object_string() -> None: + assert tool_arguments_json_for_replay('{path:"foo.txt"}') == '{"path": "foo.txt"}' diff --git a/tests/tools/test_tool_registry.py b/tests/tools/test_tool_registry.py index ca60f30ed..7e9dbb35a 100644 --- a/tests/tools/test_tool_registry.py +++ b/tests/tools/test_tool_registry.py @@ -7,8 +7,9 @@ from nanobot.agent.tools.registry import ToolRegistry class _FakeTool(Tool): - def __init__(self, name: str): + def __init__(self, name: str, schema: dict[str, Any] | None = None): self._name = name + self._schema = schema @property def name(self) -> str: @@ -20,7 +21,7 @@ class _FakeTool(Tool): @property def parameters(self) -> dict[str, Any]: - return {"type": "object", "properties": {}} + return self._schema or {"type": "object", "properties": {}} async def execute(self, **kwargs: Any) -> Any: return kwargs @@ -34,6 +35,13 @@ def _tool_names(definitions: list[dict[str, Any]]) -> list[str]: return names +def _registry_with_names(names: list[str]) -> ToolRegistry: + registry = ToolRegistry() + for name in names: + registry.register(_FakeTool(name)) + return registry + + def test_get_definitions_orders_builtins_then_mcp_tools() -> None: registry = ToolRegistry() registry.register(_FakeTool("mcp_git_status")) @@ -49,17 +57,167 @@ def test_get_definitions_orders_builtins_then_mcp_tools() -> None: ] +def test_prepare_call_rejects_near_miss_tool_name_with_suggestion() -> None: + registry = ToolRegistry() + registry.register(_FakeTool("read_file")) + + tool, params, error = registry.prepare_call("readFile", {"path": "foo.txt"}) + + assert tool is None + assert params == {"path": "foo.txt"} + assert error is not None + assert "Tool 'readFile' not found" in error + assert "Did you mean 'read_file'?" in error + assert "must match exactly" in error + + +def test_suggest_name_handles_canonical_tool_name_variants() -> None: + registry = _registry_with_names(["read_file"]) + expected = { + "readFile": "read_file", + "read-file": "read_file", + "READ_FILE": "read_file", + "read file": "read_file", + "readfile": "read_file", + } + + assert {name: registry._suggest_name(name) for name in expected} == expected + + +def test_suggest_name_suppresses_low_confidence_and_non_unique_matches() -> None: + registry = _registry_with_names(["read_file", "write_file"]) + + for name in ["", "foo", "read", "file", "readfil", "read_file_tool"]: + assert registry._suggest_name(name) is None + + ambiguous = _registry_with_names(["read_file", "readFile"]) + assert ambiguous._suggest_name("readfile") is None + + +def test_suggest_name_updates_after_register_and_unregister() -> None: + registry = _registry_with_names(["read_file"]) + + assert registry._suggest_name("readFile") == "read_file" + + registry.register(_FakeTool("readFile")) + assert registry._suggest_name("read-file") is None + + registry.unregister("read_file") + assert registry._suggest_name("read-file") == "readFile" + + def test_prepare_call_read_file_rejects_non_object_params_with_actionable_hint() -> None: registry = ToolRegistry() registry.register(_FakeTool("read_file")) tool, params, error = registry.prepare_call("read_file", ["foo.txt"]) - assert tool is None + assert tool is not None assert params == ["foo.txt"] assert error is not None assert "must be a JSON object" in error - assert "Use named parameters" in error + assert 'tool_name(param1="value1", param2="value2")' in error + assert "matching the tool schema" in error + + +def test_prepare_call_parses_json_string_arguments() -> None: + registry = ToolRegistry() + registry.register(_FakeTool("read_file")) + + tool, params, error = registry.prepare_call("read_file", '{"path":"foo.txt"}') + + assert tool is not None + assert params == {"path": "foo.txt"} + assert error is None + + +def test_prepare_call_rejects_malformed_json_string_arguments() -> None: + registry = ToolRegistry() + registry.register(_FakeTool("read_file")) + + tool, params, error = registry.prepare_call("read_file", '{path:"foo.txt"}') + + assert tool is not None + assert params == '{path:"foo.txt"}' + assert error is not None + assert "parameters must be a JSON object" in error + + +def test_prepare_call_rejects_scalar_for_single_required_parameter() -> None: + registry = ToolRegistry() + registry.register(_FakeTool( + "web_fetch", + { + "type": "object", + "properties": {"url": {"type": "string"}}, + "required": ["url"], + }, + )) + + tool, params, error = registry.prepare_call("web_fetch", "https://example.com") + + assert tool is not None + assert params == "https://example.com" + assert error is not None + assert "parameters must be a JSON object" in error + + +def test_prepare_call_rejects_unquoted_scalar_strings_before_schema_cast() -> None: + registry = ToolRegistry() + registry.register(_FakeTool( + "message", + { + "type": "object", + "properties": {"content": {"type": "string"}}, + "required": ["content"], + }, + )) + + tool, params, error = registry.prepare_call("message", "true") + + assert tool is not None + assert params == "true" + assert error is not None + assert "parameters must be a JSON object" in error + + +def test_prepare_call_unwraps_arguments_payload() -> None: + registry = ToolRegistry() + registry.register(_FakeTool( + "read_file", + { + "type": "object", + "properties": {"path": {"type": "string"}}, + "required": ["path"], + }, + )) + + tool, params, error = registry.prepare_call( + "read_file", + {"arguments": '{"path":"foo.txt"}'}, + ) + + assert tool is not None + assert params == {"path": "foo.txt"} + assert error is None + + +def test_prepare_call_treats_none_arguments_as_empty_object() -> None: + registry = ToolRegistry() + registry.register(_FakeTool("list_exec_sessions")) + + tool, params, error = registry.prepare_call("list_exec_sessions", None) + + assert tool is not None + assert params == {} + assert error is None + + tool, params, error = registry.prepare_call("list_exec_sessions", "null") + + assert tool is not None + assert params == "null" + assert error is not None + assert "parameters must be a JSON object" in error def test_prepare_call_other_tools_keep_generic_object_validation() -> None: @@ -70,7 +228,11 @@ def test_prepare_call_other_tools_keep_generic_object_validation() -> None: assert tool is not None assert params == ["TODO"] - assert error == "Error: Invalid parameters for tool 'grep': parameters must be an object, got list" + assert error == ( + "Error: Tool 'grep' parameters must be a JSON object, got list. " + 'Use named parameters like tool_name(param1="value1", param2="value2") ' + "matching the tool schema." + ) def test_get_definitions_returns_cached_result() -> None: