# Patch summary (review preamble; text before the first "diff --git" header is
# not part of any hunk).
#
# nanobot/agent/tools/registry.py
#   * Adds ToolRegistry._schema_name: returns schema["function"]["name"] when
#     it is a str, else schema["name"] when it is a str, else "".
#   * ToolRegistry.get_definitions now returns the schemas whose name does not
#     start with "mcp_" (sorted by name) followed by the schemas whose name
#     starts with "mcp_" (sorted by name), instead of registration order.
#
# nanobot/providers/base.py
#   * Adds LLMProvider._tool_name: same lookup as _schema_name but checks the
#     top-level "name" key first and the nested "function" dict second.
#     NOTE(review): the two helpers only disagree when a schema carries both
#     keys with different values; consider sharing one helper.
#   * Adds LLMProvider._tool_cache_marker_indices: returns [] for an empty
#     list; otherwise the index of the last tool whose name does not start
#     with "mcp_" (if any) followed by the last index, without duplicates.
#   * Three whitespace-only line changes.
#
# nanobot/providers/anthropic_provider.py, openai_compat_provider.py
#   * _apply_cache_control becomes a classmethod and attaches "cache_control"
#     to every index returned by _tool_cache_marker_indices instead of only
#     to the last tool.
#   * anthropic_provider.py drops "from loguru import logger".
#     NOTE(review): confirm nothing else in that file still uses `logger`;
#     the rest of the file is not visible in this patch.
#
# tests/
#   * New tests for the marker placement and for get_definitions ordering.
#
# NOTE(review): this copy was rebuilt from a whitespace-collapsed paste. Line
# breaks were checked against the hunk header counts, but indentation of
# context lines, spacing before inline comments, and the exact whitespace on
# the removed whitespace-only lines are best-effort. Verify against the
# target files before applying.
diff --git a/nanobot/agent/tools/registry.py b/nanobot/agent/tools/registry.py
index 725706dce..99d3ec63a 100644
--- a/nanobot/agent/tools/registry.py
+++ b/nanobot/agent/tools/registry.py
@@ -31,9 +31,36 @@ class ToolRegistry:
         """Check if a tool is registered."""
         return name in self._tools
 
+    @staticmethod
+    def _schema_name(schema: dict[str, Any]) -> str:
+        """Extract a normalized tool name from either OpenAI or flat schemas."""
+        fn = schema.get("function")
+        if isinstance(fn, dict):
+            name = fn.get("name")
+            if isinstance(name, str):
+                return name
+        name = schema.get("name")
+        return name if isinstance(name, str) else ""
+
     def get_definitions(self) -> list[dict[str, Any]]:
-        """Get all tool definitions in OpenAI format."""
-        return [tool.to_schema() for tool in self._tools.values()]
+        """Get tool definitions with stable ordering for cache-friendly prompts.
+
+        Built-in tools are sorted first as a stable prefix, then MCP tools are
+        sorted and appended.
+        """
+        definitions = [tool.to_schema() for tool in self._tools.values()]
+        builtins: list[dict[str, Any]] = []
+        mcp_tools: list[dict[str, Any]] = []
+        for schema in definitions:
+            name = self._schema_name(schema)
+            if name.startswith("mcp_"):
+                mcp_tools.append(schema)
+            else:
+                builtins.append(schema)
+
+        builtins.sort(key=self._schema_name)
+        mcp_tools.sort(key=self._schema_name)
+        return builtins + mcp_tools
 
     def prepare_call(
         self,
diff --git a/nanobot/providers/anthropic_provider.py b/nanobot/providers/anthropic_provider.py
index 00a7f8271..1cade5fb5 100644
--- a/nanobot/providers/anthropic_provider.py
+++ b/nanobot/providers/anthropic_provider.py
@@ -11,7 +11,6 @@ from collections.abc import Awaitable, Callable
 from typing import Any
 
 import json_repair
-from loguru import logger
 
 from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
 
@@ -255,8 +254,9 @@ class AnthropicProvider(LLMProvider):
     # Prompt caching
     # ------------------------------------------------------------------
 
-    @staticmethod
+    @classmethod
     def _apply_cache_control(
+        cls,
         system: str | list[dict[str, Any]],
         messages: list[dict[str, Any]],
         tools: list[dict[str, Any]] | None,
@@ -283,7 +283,8 @@ class AnthropicProvider(LLMProvider):
         new_tools = tools
         if tools:
             new_tools = list(tools)
-            new_tools[-1] = {**new_tools[-1], "cache_control": marker}
+            for idx in cls._tool_cache_marker_indices(new_tools):
+                new_tools[idx] = {**new_tools[idx], "cache_control": marker}
 
         return system, new_msgs, new_tools
 
diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py
index 5644d194f..118eb80ca 100644
--- a/nanobot/providers/base.py
+++ b/nanobot/providers/base.py
@@ -54,7 +54,7 @@ class LLMResponse:
     retry_after: float | None = None  # Provider supplied retry wait in seconds.
     reasoning_content: str | None = None  # Kimi, DeepSeek-R1, MiMo etc.
     thinking_blocks: list[dict] | None = None  # Anthropic extended thinking
-    
+
     @property
     def has_tool_calls(self) -> bool:
         """Check if response contains tool calls."""
@@ -148,6 +148,38 @@ class LLMProvider(ABC):
             result.append(msg)
         return result
 
+    @staticmethod
+    def _tool_name(tool: dict[str, Any]) -> str:
+        """Extract tool name from either OpenAI or Anthropic-style tool schemas."""
+        name = tool.get("name")
+        if isinstance(name, str):
+            return name
+        fn = tool.get("function")
+        if isinstance(fn, dict):
+            fname = fn.get("name")
+            if isinstance(fname, str):
+                return fname
+        return ""
+
+    @classmethod
+    def _tool_cache_marker_indices(cls, tools: list[dict[str, Any]]) -> list[int]:
+        """Return cache marker indices: builtin/MCP boundary and tail index."""
+        if not tools:
+            return []
+
+        tail_idx = len(tools) - 1
+        last_builtin_idx: int | None = None
+        for i in range(tail_idx, -1, -1):
+            if not cls._tool_name(tools[i]).startswith("mcp_"):
+                last_builtin_idx = i
+                break
+
+        ordered_unique: list[int] = []
+        for idx in (last_builtin_idx, tail_idx):
+            if idx is not None and idx not in ordered_unique:
+                ordered_unique.append(idx)
+        return ordered_unique
+
     @staticmethod
     def _sanitize_request_messages(
         messages: list[dict[str, Any]],
@@ -175,7 +207,7 @@ class LLMProvider(ABC):
     ) -> LLMResponse:
         """
         Send a chat completion request.
-        
+
         Args:
             messages: List of message dicts with 'role' and 'content'.
             tools: Optional list of tool definitions.
@@ -183,7 +215,7 @@ class LLMProvider(ABC):
             max_tokens: Maximum tokens in response.
             temperature: Sampling temperature.
             tool_choice: Tool selection strategy ("auto", "required", or specific tool dict).
-        
+
         Returns:
             LLMResponse with content and/or tool calls.
         """
diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py
index 1dca0248b..c9f797705 100644
--- a/nanobot/providers/openai_compat_provider.py
+++ b/nanobot/providers/openai_compat_provider.py
@@ -153,8 +153,9 @@ class OpenAICompatProvider(LLMProvider):
             resolved = env_val.replace("{api_key}", api_key).replace("{api_base}", effective_base)
             os.environ.setdefault(env_name, resolved)
 
-    @staticmethod
+    @classmethod
     def _apply_cache_control(
+        cls,
         messages: list[dict[str, Any]],
         tools: list[dict[str, Any]] | None,
     ) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
@@ -182,7 +183,8 @@ class OpenAICompatProvider(LLMProvider):
         new_tools = tools
         if tools:
             new_tools = list(tools)
-            new_tools[-1] = {**new_tools[-1], "cache_control": cache_marker}
+            for idx in cls._tool_cache_marker_indices(new_tools):
+                new_tools[idx] = {**new_tools[idx], "cache_control": cache_marker}
         return new_messages, new_tools
 
     @staticmethod
diff --git a/tests/providers/test_prompt_cache_markers.py b/tests/providers/test_prompt_cache_markers.py
new file mode 100644
index 000000000..61d5677de
--- /dev/null
+++ b/tests/providers/test_prompt_cache_markers.py
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+from typing import Any
+
+from nanobot.providers.anthropic_provider import AnthropicProvider
+from nanobot.providers.openai_compat_provider import OpenAICompatProvider
+
+
+def _openai_tools(*names: str) -> list[dict[str, Any]]:
+    return [
+        {
+            "type": "function",
+            "function": {
+                "name": name,
+                "description": f"{name} tool",
+                "parameters": {"type": "object", "properties": {}},
+            },
+        }
+        for name in names
+    ]
+
+
+def _anthropic_tools(*names: str) -> list[dict[str, Any]]:
+    return [
+        {
+            "name": name,
+            "description": f"{name} tool",
+            "input_schema": {"type": "object", "properties": {}},
+        }
+        for name in names
+    ]
+
+
+def _marked_openai_tool_names(tools: list[dict[str, Any]] | None) -> list[str]:
+    if not tools:
+        return []
+    marked: list[str] = []
+    for tool in tools:
+        if "cache_control" in tool:
+            marked.append((tool.get("function") or {}).get("name", ""))
+    return marked
+
+
+def _marked_anthropic_tool_names(tools: list[dict[str, Any]] | None) -> list[str]:
+    if not tools:
+        return []
+    return [tool.get("name", "") for tool in tools if "cache_control" in tool]
+
+
+def test_openai_compat_marks_builtin_boundary_and_tail_tool() -> None:
+    messages = [
+        {"role": "system", "content": "system"},
+        {"role": "assistant", "content": "assistant"},
+        {"role": "user", "content": "user"},
+    ]
+    _, marked_tools = OpenAICompatProvider._apply_cache_control(
+        messages,
+        _openai_tools("read_file", "write_file", "mcp_fs_ls", "mcp_git_status"),
+    )
+    assert _marked_openai_tool_names(marked_tools) == ["write_file", "mcp_git_status"]
+
+
+def test_anthropic_marks_builtin_boundary_and_tail_tool() -> None:
+    messages = [
+        {"role": "user", "content": "u1"},
+        {"role": "assistant", "content": "a1"},
+        {"role": "user", "content": "u2"},
+    ]
+    _, _, marked_tools = AnthropicProvider._apply_cache_control(
+        "system",
+        messages,
+        _anthropic_tools("read_file", "write_file", "mcp_fs_ls", "mcp_git_status"),
+    )
+    assert _marked_anthropic_tool_names(marked_tools) == ["write_file", "mcp_git_status"]
+
+
+def test_openai_compat_marks_only_tail_without_mcp() -> None:
+    messages = [
+        {"role": "system", "content": "system"},
+        {"role": "assistant", "content": "assistant"},
+        {"role": "user", "content": "user"},
+    ]
+    _, marked_tools = OpenAICompatProvider._apply_cache_control(
+        messages,
+        _openai_tools("read_file", "write_file"),
+    )
+    assert _marked_openai_tool_names(marked_tools) == ["write_file"]
diff --git a/tests/tools/test_tool_registry.py b/tests/tools/test_tool_registry.py
new file mode 100644
index 000000000..5b259119e
--- /dev/null
+++ b/tests/tools/test_tool_registry.py
@@ -0,0 +1,49 @@
+from __future__ import annotations
+
+from typing import Any
+
+from nanobot.agent.tools.base import Tool
+from nanobot.agent.tools.registry import ToolRegistry
+
+
+class _FakeTool(Tool):
+    def __init__(self, name: str):
+        self._name = name
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def description(self) -> str:
+        return f"{self._name} tool"
+
+    @property
+    def parameters(self) -> dict[str, Any]:
+        return {"type": "object", "properties": {}}
+
+    async def execute(self, **kwargs: Any) -> Any:
+        return kwargs
+
+
+def _tool_names(definitions: list[dict[str, Any]]) -> list[str]:
+    names: list[str] = []
+    for definition in definitions:
+        fn = definition.get("function", {})
+        names.append(fn.get("name", ""))
+    return names
+
+
+def test_get_definitions_orders_builtins_then_mcp_tools() -> None:
+    registry = ToolRegistry()
+    registry.register(_FakeTool("mcp_git_status"))
+    registry.register(_FakeTool("write_file"))
+    registry.register(_FakeTool("mcp_fs_list"))
+    registry.register(_FakeTool("read_file"))
+
+    assert _tool_names(registry.get_definitions()) == [
+        "read_file",
+        "write_file",
+        "mcp_fs_list",
+        "mcp_git_status",
+    ]