Merge PR #2722: perf(cache): stabilize tool prefix caching under MCP tool churn

perf(cache): stabilize tool prefix caching under MCP tool churn
This commit is contained in:
Xubin Ren 2026-04-04 21:57:15 +08:00 committed by GitHub
commit cf56d15bdf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 208 additions and 10 deletions

View File

@@ -31,9 +31,36 @@ class ToolRegistry:
"""Check if a tool is registered."""
return name in self._tools
@staticmethod
def _schema_name(schema: dict[str, Any]) -> str:
"""Extract a normalized tool name from either OpenAI or flat schemas."""
fn = schema.get("function")
if isinstance(fn, dict):
name = fn.get("name")
if isinstance(name, str):
return name
name = schema.get("name")
return name if isinstance(name, str) else ""
def get_definitions(self) -> list[dict[str, Any]]:
    """Get tool definitions with stable ordering for cache-friendly prompts.

    Built-in tools are sorted first as a stable prefix, then MCP tools
    (names prefixed with ``mcp_``) are sorted and appended. A deterministic
    ordering keeps the serialized tool list stable across requests even as
    MCP tools churn, so providers can reuse cached prompt prefixes.

    Returns:
        Tool schemas in OpenAI format: sorted built-ins followed by sorted
        MCP tools.
    """
    definitions = [tool.to_schema() for tool in self._tools.values()]
    builtins: list[dict[str, Any]] = []
    mcp_tools: list[dict[str, Any]] = []
    for schema in definitions:
        # MCP-provided tools are namespaced with the "mcp_" prefix.
        if self._schema_name(schema).startswith("mcp_"):
            mcp_tools.append(schema)
        else:
            builtins.append(schema)
    builtins.sort(key=self._schema_name)
    mcp_tools.sort(key=self._schema_name)
    return builtins + mcp_tools
def prepare_call(
self,

View File

@@ -11,7 +11,6 @@ from collections.abc import Awaitable, Callable
from typing import Any
import json_repair
from loguru import logger
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
@@ -255,8 +254,9 @@ class AnthropicProvider(LLMProvider):
# Prompt caching
# ------------------------------------------------------------------
@staticmethod
@classmethod
def _apply_cache_control(
cls,
system: str | list[dict[str, Any]],
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None,
@@ -283,7 +283,8 @@ class AnthropicProvider(LLMProvider):
new_tools = tools
if tools:
new_tools = list(tools)
new_tools[-1] = {**new_tools[-1], "cache_control": marker}
for idx in cls._tool_cache_marker_indices(new_tools):
new_tools[idx] = {**new_tools[idx], "cache_control": marker}
return system, new_msgs, new_tools

View File

@@ -54,7 +54,7 @@ class LLMResponse:
retry_after: float | None = None # Provider supplied retry wait in seconds.
reasoning_content: str | None = None # Kimi, DeepSeek-R1, MiMo etc.
thinking_blocks: list[dict] | None = None # Anthropic extended thinking
@property
def has_tool_calls(self) -> bool:
"""Check if response contains tool calls."""
@@ -148,6 +148,38 @@ class LLMProvider(ABC):
result.append(msg)
return result
@staticmethod
def _tool_name(tool: dict[str, Any]) -> str:
"""Extract tool name from either OpenAI or Anthropic-style tool schemas."""
name = tool.get("name")
if isinstance(name, str):
return name
fn = tool.get("function")
if isinstance(fn, dict):
fname = fn.get("name")
if isinstance(fname, str):
return fname
return ""
@classmethod
def _tool_cache_marker_indices(cls, tools: list[dict[str, Any]]) -> list[int]:
    """Return cache marker indices: builtin/MCP boundary and tail index.

    Scans backwards for the last tool whose name does not start with
    ``mcp_`` (the builtin/MCP boundary) and always includes the final
    tool's index, deduplicated while preserving order.
    """
    if not tools:
        return []
    tail = len(tools) - 1
    # Walk from the end towards the front looking for the last builtin.
    boundary: int | None = None
    for pos in range(tail, -1, -1):
        if not cls._tool_name(tools[pos]).startswith("mcp_"):
            boundary = pos
            break
    # All tools are MCP tools, or the last builtin IS the tail: one marker.
    if boundary is None or boundary == tail:
        return [tail]
    return [boundary, tail]
@staticmethod
def _sanitize_request_messages(
messages: list[dict[str, Any]],
@@ -175,7 +207,7 @@
) -> LLMResponse:
"""
Send a chat completion request.
Args:
messages: List of message dicts with 'role' and 'content'.
tools: Optional list of tool definitions.
@@ -183,7 +215,7 @@
max_tokens: Maximum tokens in response.
temperature: Sampling temperature.
tool_choice: Tool selection strategy ("auto", "required", or specific tool dict).
Returns:
LLMResponse with content and/or tool calls.
"""

View File

@@ -153,8 +153,9 @@ class OpenAICompatProvider(LLMProvider):
resolved = env_val.replace("{api_key}", api_key).replace("{api_base}", effective_base)
os.environ.setdefault(env_name, resolved)
@staticmethod
@classmethod
def _apply_cache_control(
cls,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
@@ -182,7 +183,8 @@ class OpenAICompatProvider(LLMProvider):
new_tools = tools
if tools:
new_tools = list(tools)
new_tools[-1] = {**new_tools[-1], "cache_control": cache_marker}
for idx in cls._tool_cache_marker_indices(new_tools):
new_tools[idx] = {**new_tools[idx], "cache_control": cache_marker}
return new_messages, new_tools
@staticmethod

View File

@@ -0,0 +1,87 @@
from __future__ import annotations
from typing import Any
from nanobot.providers.anthropic_provider import AnthropicProvider
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
def _openai_tools(*names: str) -> list[dict[str, Any]]:
return [
{
"type": "function",
"function": {
"name": name,
"description": f"{name} tool",
"parameters": {"type": "object", "properties": {}},
},
}
for name in names
]
def _anthropic_tools(*names: str) -> list[dict[str, Any]]:
return [
{
"name": name,
"description": f"{name} tool",
"input_schema": {"type": "object", "properties": {}},
}
for name in names
]
def _marked_openai_tool_names(tools: list[dict[str, Any]] | None) -> list[str]:
if not tools:
return []
marked: list[str] = []
for tool in tools:
if "cache_control" in tool:
marked.append((tool.get("function") or {}).get("name", ""))
return marked
def _marked_anthropic_tool_names(tools: list[dict[str, Any]] | None) -> list[str]:
if not tools:
return []
return [tool.get("name", "") for tool in tools if "cache_control" in tool]
def test_openai_compat_marks_builtin_boundary_and_tail_tool() -> None:
    """Cache markers land on the last builtin tool and the final (tail) tool."""
    conversation = [
        {"role": role, "content": role}
        for role in ("system", "assistant", "user")
    ]
    tools = _openai_tools("read_file", "write_file", "mcp_fs_ls", "mcp_git_status")
    _, marked_tools = OpenAICompatProvider._apply_cache_control(conversation, tools)
    assert _marked_openai_tool_names(marked_tools) == ["write_file", "mcp_git_status"]
def test_anthropic_marks_builtin_boundary_and_tail_tool() -> None:
    """Cache markers land on the last builtin tool and the final (tail) tool."""
    history = [
        {"role": "user", "content": "u1"},
        {"role": "assistant", "content": "a1"},
        {"role": "user", "content": "u2"},
    ]
    tools = _anthropic_tools("read_file", "write_file", "mcp_fs_ls", "mcp_git_status")
    _, _, marked_tools = AnthropicProvider._apply_cache_control("system", history, tools)
    assert _marked_anthropic_tool_names(marked_tools) == ["write_file", "mcp_git_status"]
def test_openai_compat_marks_only_tail_without_mcp() -> None:
    """With no MCP tools the boundary and tail coincide: a single marker."""
    conversation = [
        {"role": role, "content": role}
        for role in ("system", "assistant", "user")
    ]
    tools = _openai_tools("read_file", "write_file")
    _, marked_tools = OpenAICompatProvider._apply_cache_control(conversation, tools)
    assert _marked_openai_tool_names(marked_tools) == ["write_file"]

View File

@@ -0,0 +1,49 @@
from __future__ import annotations
from typing import Any
from nanobot.agent.tools.base import Tool
from nanobot.agent.tools.registry import ToolRegistry
class _FakeTool(Tool):
    """Minimal Tool stub whose schema fields derive from its name."""

    def __init__(self, name: str):
        self._tool_name = name

    @property
    def name(self) -> str:
        """Name supplied at construction."""
        return self._tool_name

    @property
    def description(self) -> str:
        """Synthetic description derived from the name."""
        return f"{self._tool_name} tool"

    @property
    def parameters(self) -> dict[str, Any]:
        """Empty JSON-schema object."""
        return {"type": "object", "properties": {}}

    async def execute(self, **kwargs: Any) -> Any:
        """Echo the keyword arguments back to the caller."""
        return kwargs
def _tool_names(definitions: list[dict[str, Any]]) -> list[str]:
names: list[str] = []
for definition in definitions:
fn = definition.get("function", {})
names.append(fn.get("name", ""))
return names
def test_get_definitions_orders_builtins_then_mcp_tools() -> None:
    """Registry yields sorted builtins first, then sorted MCP tools."""
    registry = ToolRegistry()
    # Register in deliberately shuffled order to exercise the sort.
    for tool_name in ("mcp_git_status", "write_file", "mcp_fs_list", "read_file"):
        registry.register(_FakeTool(tool_name))
    expected = ["read_file", "write_file", "mcp_fs_list", "mcp_git_status"]
    assert _tool_names(registry.get_definitions()) == expected