mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-05 10:52:36 +00:00
Merge PR #2722: perf(cache): stabilize tool prefix caching under MCP tool churn
perf(cache): stabilize tool prefix caching under MCP tool churn
This commit is contained in:
commit
cf56d15bdf
@ -31,9 +31,36 @@ class ToolRegistry:
|
||||
"""Check if a tool is registered."""
|
||||
return name in self._tools
|
||||
|
||||
@staticmethod
|
||||
def _schema_name(schema: dict[str, Any]) -> str:
|
||||
"""Extract a normalized tool name from either OpenAI or flat schemas."""
|
||||
fn = schema.get("function")
|
||||
if isinstance(fn, dict):
|
||||
name = fn.get("name")
|
||||
if isinstance(name, str):
|
||||
return name
|
||||
name = schema.get("name")
|
||||
return name if isinstance(name, str) else ""
|
||||
|
||||
def get_definitions(self) -> list[dict[str, Any]]:
    """Get tool definitions with stable ordering for cache-friendly prompts.

    Built-in tools are sorted first as a stable prefix, then MCP tools
    (names starting with ``mcp_``) are sorted and appended.  A deterministic
    ordering keeps the serialized tool list identical across requests so
    provider prompt caching can hit on the tool prefix.

    Returns:
        All registered tool schemas, builtins first, each group sorted by
        normalized tool name.
    """
    # Defect fixed: the diff-merged text retained the pre-change
    # ``return [tool.to_schema() ...]`` line ahead of the new ordering
    # logic, making the body unreachable/invalid; this is the consolidated
    # post-change implementation.
    schemas = [tool.to_schema() for tool in self._tools.values()]
    builtins = sorted(
        (s for s in schemas if not self._schema_name(s).startswith("mcp_")),
        key=self._schema_name,
    )
    mcp_tools = sorted(
        (s for s in schemas if self._schema_name(s).startswith("mcp_")),
        key=self._schema_name,
    )
    return builtins + mcp_tools
|
||||
|
||||
def prepare_call(
|
||||
self,
|
||||
|
||||
@ -11,7 +11,6 @@ from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import json_repair
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
@ -255,8 +254,9 @@ class AnthropicProvider(LLMProvider):
|
||||
# Prompt caching
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
def _apply_cache_control(
|
||||
cls,
|
||||
system: str | list[dict[str, Any]],
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None,
|
||||
@ -283,7 +283,8 @@ class AnthropicProvider(LLMProvider):
|
||||
new_tools = tools
|
||||
if tools:
|
||||
new_tools = list(tools)
|
||||
new_tools[-1] = {**new_tools[-1], "cache_control": marker}
|
||||
for idx in cls._tool_cache_marker_indices(new_tools):
|
||||
new_tools[idx] = {**new_tools[idx], "cache_control": marker}
|
||||
|
||||
return system, new_msgs, new_tools
|
||||
|
||||
|
||||
@ -54,7 +54,7 @@ class LLMResponse:
|
||||
retry_after: float | None = None # Provider supplied retry wait in seconds.
|
||||
reasoning_content: str | None = None # Kimi, DeepSeek-R1, MiMo etc.
|
||||
thinking_blocks: list[dict] | None = None # Anthropic extended thinking
|
||||
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
"""Check if response contains tool calls."""
|
||||
@ -148,6 +148,38 @@ class LLMProvider(ABC):
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _tool_name(tool: dict[str, Any]) -> str:
|
||||
"""Extract tool name from either OpenAI or Anthropic-style tool schemas."""
|
||||
name = tool.get("name")
|
||||
if isinstance(name, str):
|
||||
return name
|
||||
fn = tool.get("function")
|
||||
if isinstance(fn, dict):
|
||||
fname = fn.get("name")
|
||||
if isinstance(fname, str):
|
||||
return fname
|
||||
return ""
|
||||
|
||||
@classmethod
def _tool_cache_marker_indices(cls, tools: list[dict[str, Any]]) -> list[int]:
    """Return indices that should carry ``cache_control`` markers.

    Marks the last non-MCP ("builtin") tool — the end of the stable
    builtin prefix — and the final tool in the list.  Duplicates are
    collapsed while preserving order (boundary first, then tail).
    Empty input yields an empty list.
    """
    if not tools:
        return []

    tail = len(tools) - 1
    # Scan backwards for the last tool whose name lacks the "mcp_" prefix.
    boundary: int | None = next(
        (
            i
            for i in reversed(range(len(tools)))
            if not cls._tool_name(tools[i]).startswith("mcp_")
        ),
        None,
    )

    indices: list[int] = []
    if boundary is not None:
        indices.append(boundary)
    if tail not in indices:
        indices.append(tail)
    return indices
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_request_messages(
|
||||
messages: list[dict[str, Any]],
|
||||
@ -175,7 +207,7 @@ class LLMProvider(ABC):
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Send a chat completion request.
|
||||
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
tools: Optional list of tool definitions.
|
||||
@ -183,7 +215,7 @@ class LLMProvider(ABC):
|
||||
max_tokens: Maximum tokens in response.
|
||||
temperature: Sampling temperature.
|
||||
tool_choice: Tool selection strategy ("auto", "required", or specific tool dict).
|
||||
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and/or tool calls.
|
||||
"""
|
||||
|
||||
@ -153,8 +153,9 @@ class OpenAICompatProvider(LLMProvider):
|
||||
resolved = env_val.replace("{api_key}", api_key).replace("{api_base}", effective_base)
|
||||
os.environ.setdefault(env_name, resolved)
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
def _apply_cache_control(
|
||||
cls,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None,
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
|
||||
@ -182,7 +183,8 @@ class OpenAICompatProvider(LLMProvider):
|
||||
new_tools = tools
|
||||
if tools:
|
||||
new_tools = list(tools)
|
||||
new_tools[-1] = {**new_tools[-1], "cache_control": cache_marker}
|
||||
for idx in cls._tool_cache_marker_indices(new_tools):
|
||||
new_tools[idx] = {**new_tools[idx], "cache_control": cache_marker}
|
||||
return new_messages, new_tools
|
||||
|
||||
@staticmethod
|
||||
|
||||
87
tests/providers/test_prompt_cache_markers.py
Normal file
87
tests/providers/test_prompt_cache_markers.py
Normal file
@ -0,0 +1,87 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
|
||||
def _openai_tools(*names: str) -> list[dict[str, Any]]:
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"description": f"{name} tool",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
for name in names
|
||||
]
|
||||
|
||||
|
||||
def _anthropic_tools(*names: str) -> list[dict[str, Any]]:
|
||||
return [
|
||||
{
|
||||
"name": name,
|
||||
"description": f"{name} tool",
|
||||
"input_schema": {"type": "object", "properties": {}},
|
||||
}
|
||||
for name in names
|
||||
]
|
||||
|
||||
|
||||
def _marked_openai_tool_names(tools: list[dict[str, Any]] | None) -> list[str]:
|
||||
if not tools:
|
||||
return []
|
||||
marked: list[str] = []
|
||||
for tool in tools:
|
||||
if "cache_control" in tool:
|
||||
marked.append((tool.get("function") or {}).get("name", ""))
|
||||
return marked
|
||||
|
||||
|
||||
def _marked_anthropic_tool_names(tools: list[dict[str, Any]] | None) -> list[str]:
|
||||
if not tools:
|
||||
return []
|
||||
return [tool.get("name", "") for tool in tools if "cache_control" in tool]
|
||||
|
||||
|
||||
def test_openai_compat_marks_builtin_boundary_and_tail_tool() -> None:
    """Both the builtin-prefix boundary and the final tool get markers."""
    convo = [
        {"role": role, "content": role}
        for role in ("system", "assistant", "user")
    ]
    _, marked_tools = OpenAICompatProvider._apply_cache_control(
        convo,
        _openai_tools("read_file", "write_file", "mcp_fs_ls", "mcp_git_status"),
    )
    assert _marked_openai_tool_names(marked_tools) == ["write_file", "mcp_git_status"]
|
||||
|
||||
|
||||
def test_anthropic_marks_builtin_boundary_and_tail_tool() -> None:
    """Anthropic provider marks the builtin boundary and the tail tool."""
    convo = [
        {"role": role, "content": text}
        for role, text in (("user", "u1"), ("assistant", "a1"), ("user", "u2"))
    ]
    _, _, marked_tools = AnthropicProvider._apply_cache_control(
        "system",
        convo,
        _anthropic_tools("read_file", "write_file", "mcp_fs_ls", "mcp_git_status"),
    )
    assert _marked_anthropic_tool_names(marked_tools) == ["write_file", "mcp_git_status"]
|
||||
|
||||
|
||||
def test_openai_compat_marks_only_tail_without_mcp() -> None:
    """With no MCP tools present, only the tail tool is marked."""
    convo = [
        {"role": role, "content": role}
        for role in ("system", "assistant", "user")
    ]
    _, marked_tools = OpenAICompatProvider._apply_cache_control(
        convo,
        _openai_tools("read_file", "write_file"),
    )
    assert _marked_openai_tool_names(marked_tools) == ["write_file"]
|
||||
49
tests/tools/test_tool_registry.py
Normal file
49
tests/tools/test_tool_registry.py
Normal file
@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
|
||||
|
||||
class _FakeTool(Tool):
    """Minimal Tool stub that only carries a name, for ordering tests."""

    def __init__(self, name: str):
        # Internal storage for the tool's display name.
        self._tool_name = name

    @property
    def name(self) -> str:
        return self._tool_name

    @property
    def description(self) -> str:
        return f"{self.name} tool"

    @property
    def parameters(self) -> dict[str, Any]:
        # Empty object schema: the stub takes no arguments.
        return {"type": "object", "properties": {}}

    async def execute(self, **kwargs: Any) -> Any:
        # Echo back whatever was passed; no real work is done.
        return kwargs
|
||||
|
||||
|
||||
def _tool_names(definitions: list[dict[str, Any]]) -> list[str]:
|
||||
names: list[str] = []
|
||||
for definition in definitions:
|
||||
fn = definition.get("function", {})
|
||||
names.append(fn.get("name", ""))
|
||||
return names
|
||||
|
||||
|
||||
def test_get_definitions_orders_builtins_then_mcp_tools() -> None:
    """Builtins come first (sorted), then MCP tools (sorted)."""
    registry = ToolRegistry()
    # Register deliberately out of order to exercise the sort.
    for name in ("mcp_git_status", "write_file", "mcp_fs_list", "read_file"):
        registry.register(_FakeTool(name))

    expected = [
        "read_file",
        "write_file",
        "mcp_fs_list",
        "mcp_git_status",
    ]
    assert _tool_names(registry.get_definitions()) == expected
|
||||
Loading…
x
Reference in New Issue
Block a user