fix(cache): stabilize tool ordering and cache markers for MCP

This commit is contained in:
pikaxinge 2026-04-01 17:07:22 +00:00
parent 63d646f731
commit 607fd8fd7e
5 changed files with 234 additions and 10 deletions

View File

@ -31,13 +31,40 @@ class ToolRegistry:
"""Check if a tool is registered."""
return name in self._tools
@staticmethod
def _schema_name(schema: dict[str, Any]) -> str:
"""Extract a normalized tool name from either OpenAI or flat schemas."""
fn = schema.get("function")
if isinstance(fn, dict):
name = fn.get("name")
if isinstance(name, str):
return name
name = schema.get("name")
return name if isinstance(name, str) else ""
def get_definitions(self) -> list[dict[str, Any]]:
    """Get tool definitions with stable ordering for cache-friendly prompts.

    Built-in tools are sorted first as a stable prefix, then MCP tools
    (names prefixed with ``mcp_``) are sorted and appended, so the
    serialized tool list stays byte-stable across runs and prompt-cache
    markers can hit.

    Returns:
        Tool schemas in OpenAI format: sorted built-ins followed by
        sorted MCP tools.
    """
    definitions = [tool.to_schema() for tool in self._tools.values()]
    builtins: list[dict[str, Any]] = []
    mcp_tools: list[dict[str, Any]] = []
    for schema in definitions:
        # MCP-bridged tools are namespaced with the "mcp_" prefix.
        if self._schema_name(schema).startswith("mcp_"):
            mcp_tools.append(schema)
        else:
            builtins.append(schema)
    builtins.sort(key=self._schema_name)
    mcp_tools.sort(key=self._schema_name)
    return builtins + mcp_tools
async def execute(self, name: str, params: dict[str, Any]) -> Any:
"""Execute a tool by name with given parameters."""
_HINT = "\n\n[Analyze the error above and try a different approach.]"
hint = "\n\n[Analyze the error above and try a different approach.]"
tool = self._tools.get(name)
if not tool:
@ -46,17 +73,17 @@ class ToolRegistry:
try:
# Attempt to cast parameters to match schema types
params = tool.cast_params(params)
# Validate parameters
errors = tool.validate_params(params)
if errors:
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + hint
result = await tool.execute(**params)
if isinstance(result, str) and result.startswith("Error"):
return result + _HINT
return result + hint
return result
except Exception as e:
return f"Error executing {name}: {str(e)}" + _HINT
return f"Error executing {name}: {str(e)}" + hint
@property
def tool_names(self) -> list[str]:

View File

@ -9,7 +9,6 @@ from collections.abc import Awaitable, Callable
from typing import Any
import json_repair
from loguru import logger
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
@ -252,7 +251,38 @@ class AnthropicProvider(LLMProvider):
# ------------------------------------------------------------------
@staticmethod
def _tool_name(tool: dict[str, Any]) -> str:
name = tool.get("name")
if isinstance(name, str):
return name
fn = tool.get("function")
if isinstance(fn, dict):
fname = fn.get("name")
if isinstance(fname, str):
return fname
return ""
@classmethod
def _tool_cache_marker_indices(cls, tools: list[dict[str, Any]]) -> list[int]:
if not tools:
return []
tail_idx = len(tools) - 1
last_builtin_idx: int | None = None
for i in range(tail_idx, -1, -1):
if not cls._tool_name(tools[i]).startswith("mcp_"):
last_builtin_idx = i
break
ordered_unique: list[int] = []
for idx in (last_builtin_idx, tail_idx):
if idx is not None and idx not in ordered_unique:
ordered_unique.append(idx)
return ordered_unique
@classmethod
def _apply_cache_control(
cls,
system: str | list[dict[str, Any]],
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None,
@ -279,7 +309,8 @@ class AnthropicProvider(LLMProvider):
new_tools = tools
if tools:
new_tools = list(tools)
new_tools[-1] = {**new_tools[-1], "cache_control": marker}
for idx in cls._tool_cache_marker_indices(new_tools):
new_tools[idx] = {**new_tools[idx], "cache_control": marker}
return system, new_msgs, new_tools

View File

@ -152,7 +152,36 @@ class OpenAICompatProvider(LLMProvider):
os.environ.setdefault(env_name, resolved)
@staticmethod
def _tool_name(tool: dict[str, Any]) -> str:
fn = tool.get("function")
if isinstance(fn, dict):
name = fn.get("name")
if isinstance(name, str):
return name
name = tool.get("name")
return name if isinstance(name, str) else ""
@classmethod
def _tool_cache_marker_indices(cls, tools: list[dict[str, Any]]) -> list[int]:
if not tools:
return []
tail_idx = len(tools) - 1
last_builtin_idx: int | None = None
for i in range(tail_idx, -1, -1):
if not cls._tool_name(tools[i]).startswith("mcp_"):
last_builtin_idx = i
break
ordered_unique: list[int] = []
for idx in (last_builtin_idx, tail_idx):
if idx is not None and idx not in ordered_unique:
ordered_unique.append(idx)
return ordered_unique
@classmethod
def _apply_cache_control(
cls,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
@ -180,7 +209,8 @@ class OpenAICompatProvider(LLMProvider):
new_tools = tools
if tools:
new_tools = list(tools)
new_tools[-1] = {**new_tools[-1], "cache_control": cache_marker}
for idx in cls._tool_cache_marker_indices(new_tools):
new_tools[idx] = {**new_tools[idx], "cache_control": cache_marker}
return new_messages, new_tools
@staticmethod

View File

@ -0,0 +1,87 @@
from __future__ import annotations
from typing import Any
from nanobot.providers.anthropic_provider import AnthropicProvider
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
def _openai_tools(*names: str) -> list[dict[str, Any]]:
return [
{
"type": "function",
"function": {
"name": name,
"description": f"{name} tool",
"parameters": {"type": "object", "properties": {}},
},
}
for name in names
]
def _anthropic_tools(*names: str) -> list[dict[str, Any]]:
return [
{
"name": name,
"description": f"{name} tool",
"input_schema": {"type": "object", "properties": {}},
}
for name in names
]
def _marked_openai_tool_names(tools: list[dict[str, Any]] | None) -> list[str]:
if not tools:
return []
marked: list[str] = []
for tool in tools:
if "cache_control" in tool:
marked.append((tool.get("function") or {}).get("name", ""))
return marked
def _marked_anthropic_tool_names(tools: list[dict[str, Any]] | None) -> list[str]:
if not tools:
return []
return [tool.get("name", "") for tool in tools if "cache_control" in tool]
def test_openai_compat_marks_builtin_boundary_and_tail_tool() -> None:
    """Both the last built-in tool and the final MCP tool get cache markers."""
    conversation = [
        {"role": "system", "content": "system"},
        {"role": "assistant", "content": "assistant"},
        {"role": "user", "content": "user"},
    ]
    tools = _openai_tools("read_file", "write_file", "mcp_fs_ls", "mcp_git_status")
    _, marked_tools = OpenAICompatProvider._apply_cache_control(conversation, tools)
    assert _marked_openai_tool_names(marked_tools) == ["write_file", "mcp_git_status"]
def test_anthropic_marks_builtin_boundary_and_tail_tool() -> None:
    """Anthropic provider marks the built-in boundary and the tail tool."""
    conversation = [
        {"role": "user", "content": "u1"},
        {"role": "assistant", "content": "a1"},
        {"role": "user", "content": "u2"},
    ]
    tools = _anthropic_tools("read_file", "write_file", "mcp_fs_ls", "mcp_git_status")
    _, _, marked_tools = AnthropicProvider._apply_cache_control(
        "system",
        conversation,
        tools,
    )
    assert _marked_anthropic_tool_names(marked_tools) == ["write_file", "mcp_git_status"]
def test_openai_compat_marks_only_tail_without_mcp() -> None:
    """Without MCP tools only the final tool receives a cache marker."""
    conversation = [
        {"role": "system", "content": "system"},
        {"role": "assistant", "content": "assistant"},
        {"role": "user", "content": "user"},
    ]
    _, marked_tools = OpenAICompatProvider._apply_cache_control(
        conversation,
        _openai_tools("read_file", "write_file"),
    )
    assert _marked_openai_tool_names(marked_tools) == ["write_file"]

View File

@ -0,0 +1,49 @@
from __future__ import annotations
from typing import Any
from nanobot.agent.tools.base import Tool
from nanobot.agent.tools.registry import ToolRegistry
class _FakeTool(Tool):
    """Minimal Tool stub whose schema fields are derived from its name."""

    def __init__(self, name: str):
        self._name = name

    @property
    def name(self) -> str:
        return self._name

    @property
    def description(self) -> str:
        return self._name + " tool"

    @property
    def parameters(self) -> dict[str, Any]:
        # An empty JSON-schema object: the stub accepts anything.
        return {"type": "object", "properties": {}}

    async def execute(self, **kwargs: Any) -> Any:
        # Echo the parameters back so callers can assert on what was passed.
        return kwargs
def _tool_names(definitions: list[dict[str, Any]]) -> list[str]:
names: list[str] = []
for definition in definitions:
fn = definition.get("function", {})
names.append(fn.get("name", ""))
return names
def test_get_definitions_orders_builtins_then_mcp_tools() -> None:
    """Built-ins come first (sorted), then MCP tools (sorted)."""
    registry = ToolRegistry()
    # Register deliberately out of order to exercise the sort.
    for tool_name in ("mcp_git_status", "write_file", "mcp_fs_list", "read_file"):
        registry.register(_FakeTool(tool_name))
    expected = ["read_file", "write_file", "mcp_fs_list", "mcp_git_status"]
    assert _tool_names(registry.get_definitions()) == expected