refactor: deduplicate tool cache marker helper in base provider

pikaxinge 2026-04-02 07:29:07 +00:00
parent 607fd8fd7e
commit 87d493f354
3 changed files with 36 additions and 62 deletions


@@ -250,36 +250,6 @@ class AnthropicProvider(LLMProvider):
     # Prompt caching
     # ------------------------------------------------------------------
-    @staticmethod
-    def _tool_name(tool: dict[str, Any]) -> str:
-        name = tool.get("name")
-        if isinstance(name, str):
-            return name
-        fn = tool.get("function")
-        if isinstance(fn, dict):
-            fname = fn.get("name")
-            if isinstance(fname, str):
-                return fname
-        return ""
-
-    @classmethod
-    def _tool_cache_marker_indices(cls, tools: list[dict[str, Any]]) -> list[int]:
-        if not tools:
-            return []
-        tail_idx = len(tools) - 1
-        last_builtin_idx: int | None = None
-        for i in range(tail_idx, -1, -1):
-            if not cls._tool_name(tools[i]).startswith("mcp_"):
-                last_builtin_idx = i
-                break
-        ordered_unique: list[int] = []
-        for idx in (last_builtin_idx, tail_idx):
-            if idx is not None and idx not in ordered_unique:
-                ordered_unique.append(idx)
-        return ordered_unique
-
     @classmethod
     def _apply_cache_control(
         cls,

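The copy deleted above and its twin in OpenAICompatProvider (last file below) are replaced by a single shared implementation in the base class. For reference, the two tool-schema shapes the shared _tool_name helper has to cover look roughly like this; the field values are illustrative, not taken from this diff:

    # OpenAI-style schema: the name lives under the nested "function" key.
    openai_style = {"type": "function", "function": {"name": "search", "parameters": {}}}

    # Anthropic-style schema: the name sits at the top level.
    anthropic_style = {"name": "search", "input_schema": {}}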

@@ -48,7 +48,7 @@ class LLMResponse:
     usage: dict[str, int] = field(default_factory=dict)
     reasoning_content: str | None = None  # Kimi, DeepSeek-R1 etc.
     thinking_blocks: list[dict] | None = None  # Anthropic extended thinking
-
+
     @property
     def has_tool_calls(self) -> bool:
         """Check if response contains tool calls."""
@@ -73,7 +73,7 @@ class GenerationSettings:
 class LLMProvider(ABC):
     """
     Abstract base class for LLM providers.
-
+
     Implementations should handle the specifics of each provider's API
     while maintaining a consistent interface.
     """
@@ -150,6 +150,38 @@ class LLMProvider(ABC):
             result.append(msg)
         return result

+    @staticmethod
+    def _tool_name(tool: dict[str, Any]) -> str:
+        """Extract tool name from either OpenAI or Anthropic-style tool schemas."""
+        name = tool.get("name")
+        if isinstance(name, str):
+            return name
+        fn = tool.get("function")
+        if isinstance(fn, dict):
+            fname = fn.get("name")
+            if isinstance(fname, str):
+                return fname
+        return ""
+
+    @classmethod
+    def _tool_cache_marker_indices(cls, tools: list[dict[str, Any]]) -> list[int]:
+        """Return cache marker indices: builtin/MCP boundary and tail index."""
+        if not tools:
+            return []
+        tail_idx = len(tools) - 1
+        last_builtin_idx: int | None = None
+        for i in range(tail_idx, -1, -1):
+            if not cls._tool_name(tools[i]).startswith("mcp_"):
+                last_builtin_idx = i
+                break
+        ordered_unique: list[int] = []
+        for idx in (last_builtin_idx, tail_idx):
+            if idx is not None and idx not in ordered_unique:
+                ordered_unique.append(idx)
+        return ordered_unique
+
     @staticmethod
     def _sanitize_request_messages(
         messages: list[dict[str, Any]],
@@ -177,7 +209,7 @@
     ) -> LLMResponse:
         """
         Send a chat completion request.
-
+
         Args:
             messages: List of message dicts with 'role' and 'content'.
             tools: Optional list of tool definitions.
@@ -185,7 +217,7 @@
             max_tokens: Maximum tokens in response.
             temperature: Sampling temperature.
             tool_choice: Tool selection strategy ("auto", "required", or specific tool dict).
-
+
         Returns:
             LLMResponse with content and/or tool calls.
         """


@@ -151,34 +151,6 @@ class OpenAICompatProvider(LLMProvider):
             resolved = env_val.replace("{api_key}", api_key).replace("{api_base}", effective_base)
             os.environ.setdefault(env_name, resolved)

-    @staticmethod
-    def _tool_name(tool: dict[str, Any]) -> str:
-        fn = tool.get("function")
-        if isinstance(fn, dict):
-            name = fn.get("name")
-            if isinstance(name, str):
-                return name
-        name = tool.get("name")
-        return name if isinstance(name, str) else ""
-
-    @classmethod
-    def _tool_cache_marker_indices(cls, tools: list[dict[str, Any]]) -> list[int]:
-        if not tools:
-            return []
-        tail_idx = len(tools) - 1
-        last_builtin_idx: int | None = None
-        for i in range(tail_idx, -1, -1):
-            if not cls._tool_name(tools[i]).startswith("mcp_"):
-                last_builtin_idx = i
-                break
-        ordered_unique: list[int] = []
-        for idx in (last_builtin_idx, tail_idx):
-            if idx is not None and idx not in ordered_unique:
-                ordered_unique.append(idx)
-        return ordered_unique
-
     @classmethod
     def _apply_cache_control(
         cls,
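One subtlety in the deduplication: the OpenAICompatProvider copy deleted above preferred the nested function.name, while the shared base-class helper checks the top-level name first. The two orderings only disagree for a dict carrying both keys, a shape shown here purely to illustrate the precedence change:

    tool = {"name": "outer", "function": {"name": "inner"}}
    # Deleted OpenAICompatProvider._tool_name(tool) -> "inner"
    # Shared LLMProvider._tool_name(tool)           -> "outer"
    # Well-formed OpenAI or Anthropic tool schemas use one shape only,
    # so the computed marker indices should be unaffected in practice.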