From 87d493f3549fd5a90586f03c07246dfc0be72e5e Mon Sep 17 00:00:00 2001 From: pikaxinge <2392811793@qq.com> Date: Thu, 2 Apr 2026 07:29:07 +0000 Subject: [PATCH] refactor: deduplicate tool cache marker helper in base provider --- nanobot/providers/anthropic_provider.py | 30 ---------------- nanobot/providers/base.py | 40 ++++++++++++++++++--- nanobot/providers/openai_compat_provider.py | 28 --------------- 3 files changed, 36 insertions(+), 62 deletions(-) diff --git a/nanobot/providers/anthropic_provider.py b/nanobot/providers/anthropic_provider.py index 563484585..defbe0bc6 100644 --- a/nanobot/providers/anthropic_provider.py +++ b/nanobot/providers/anthropic_provider.py @@ -250,36 +250,6 @@ class AnthropicProvider(LLMProvider): # Prompt caching # ------------------------------------------------------------------ - @staticmethod - def _tool_name(tool: dict[str, Any]) -> str: - name = tool.get("name") - if isinstance(name, str): - return name - fn = tool.get("function") - if isinstance(fn, dict): - fname = fn.get("name") - if isinstance(fname, str): - return fname - return "" - - @classmethod - def _tool_cache_marker_indices(cls, tools: list[dict[str, Any]]) -> list[int]: - if not tools: - return [] - - tail_idx = len(tools) - 1 - last_builtin_idx: int | None = None - for i in range(tail_idx, -1, -1): - if not cls._tool_name(tools[i]).startswith("mcp_"): - last_builtin_idx = i - break - - ordered_unique: list[int] = [] - for idx in (last_builtin_idx, tail_idx): - if idx is not None and idx not in ordered_unique: - ordered_unique.append(idx) - return ordered_unique - @classmethod def _apply_cache_control( cls, diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index 9ce2b0c63..8eb67d6b0 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -48,7 +48,7 @@ class LLMResponse: usage: dict[str, int] = field(default_factory=dict) reasoning_content: str | None = None # Kimi, DeepSeek-R1 etc. thinking_blocks: list[dict] | None = None # Anthropic extended thinking - + @property def has_tool_calls(self) -> bool: """Check if response contains tool calls.""" @@ -73,7 +73,7 @@ class GenerationSettings: class LLMProvider(ABC): """ Abstract base class for LLM providers. - + Implementations should handle the specifics of each provider's API while maintaining a consistent interface. """ @@ -150,6 +150,38 @@ class LLMProvider(ABC): result.append(msg) return result + @staticmethod + def _tool_name(tool: dict[str, Any]) -> str: + """Extract tool name from either OpenAI or Anthropic-style tool schemas.""" + name = tool.get("name") + if isinstance(name, str): + return name + fn = tool.get("function") + if isinstance(fn, dict): + fname = fn.get("name") + if isinstance(fname, str): + return fname + return "" + + @classmethod + def _tool_cache_marker_indices(cls, tools: list[dict[str, Any]]) -> list[int]: + """Return cache marker indices: builtin/MCP boundary and tail index.""" + if not tools: + return [] + + tail_idx = len(tools) - 1 + last_builtin_idx: int | None = None + for i in range(tail_idx, -1, -1): + if not cls._tool_name(tools[i]).startswith("mcp_"): + last_builtin_idx = i + break + + ordered_unique: list[int] = [] + for idx in (last_builtin_idx, tail_idx): + if idx is not None and idx not in ordered_unique: + ordered_unique.append(idx) + return ordered_unique + @staticmethod def _sanitize_request_messages( messages: list[dict[str, Any]], @@ -177,7 +209,7 @@ class LLMProvider(ABC): ) -> LLMResponse: """ Send a chat completion request. - + Args: messages: List of message dicts with 'role' and 'content'. tools: Optional list of tool definitions. @@ -185,7 +217,7 @@ class LLMProvider(ABC): max_tokens: Maximum tokens in response. temperature: Sampling temperature. tool_choice: Tool selection strategy ("auto", "required", or specific tool dict). - + Returns: LLMResponse with content and/or tool calls. """ diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index 9d70d269d..d9a0be7f9 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -151,34 +151,6 @@ class OpenAICompatProvider(LLMProvider): resolved = env_val.replace("{api_key}", api_key).replace("{api_base}", effective_base) os.environ.setdefault(env_name, resolved) - @staticmethod - def _tool_name(tool: dict[str, Any]) -> str: - fn = tool.get("function") - if isinstance(fn, dict): - name = fn.get("name") - if isinstance(name, str): - return name - name = tool.get("name") - return name if isinstance(name, str) else "" - - @classmethod - def _tool_cache_marker_indices(cls, tools: list[dict[str, Any]]) -> list[int]: - if not tools: - return [] - - tail_idx = len(tools) - 1 - last_builtin_idx: int | None = None - for i in range(tail_idx, -1, -1): - if not cls._tool_name(tools[i]).startswith("mcp_"): - last_builtin_idx = i - break - - ordered_unique: list[int] = [] - for idx in (last_builtin_idx, tail_idx): - if idx is not None and idx not in ordered_unique: - ordered_unique.append(idx) - return ordered_unique - @classmethod def _apply_cache_control( cls,