refactor: deduplicate tool cache marker helper in base provider

This commit is contained in:
pikaxinge 2026-04-02 07:29:07 +00:00
parent 607fd8fd7e
commit 87d493f354
3 changed files with 36 additions and 62 deletions

View File

@ -250,36 +250,6 @@ class AnthropicProvider(LLMProvider):
# Prompt caching
# ------------------------------------------------------------------
@staticmethod
def _tool_name(tool: dict[str, Any]) -> str:
name = tool.get("name")
if isinstance(name, str):
return name
fn = tool.get("function")
if isinstance(fn, dict):
fname = fn.get("name")
if isinstance(fname, str):
return fname
return ""
@classmethod
def _tool_cache_marker_indices(cls, tools: list[dict[str, Any]]) -> list[int]:
if not tools:
return []
tail_idx = len(tools) - 1
last_builtin_idx: int | None = None
for i in range(tail_idx, -1, -1):
if not cls._tool_name(tools[i]).startswith("mcp_"):
last_builtin_idx = i
break
ordered_unique: list[int] = []
for idx in (last_builtin_idx, tail_idx):
if idx is not None and idx not in ordered_unique:
ordered_unique.append(idx)
return ordered_unique
@classmethod
def _apply_cache_control(
cls,

View File

@ -48,7 +48,7 @@ class LLMResponse:
usage: dict[str, int] = field(default_factory=dict)
reasoning_content: str | None = None # Kimi, DeepSeek-R1 etc.
thinking_blocks: list[dict] | None = None # Anthropic extended thinking
@property
def has_tool_calls(self) -> bool:
"""Check if response contains tool calls."""
@ -73,7 +73,7 @@ class GenerationSettings:
class LLMProvider(ABC):
"""
Abstract base class for LLM providers.
Implementations should handle the specifics of each provider's API
while maintaining a consistent interface.
"""
@ -150,6 +150,38 @@ class LLMProvider(ABC):
result.append(msg)
return result
@staticmethod
def _tool_name(tool: dict[str, Any]) -> str:
"""Extract tool name from either OpenAI or Anthropic-style tool schemas."""
name = tool.get("name")
if isinstance(name, str):
return name
fn = tool.get("function")
if isinstance(fn, dict):
fname = fn.get("name")
if isinstance(fname, str):
return fname
return ""
@classmethod
def _tool_cache_marker_indices(cls, tools: list[dict[str, Any]]) -> list[int]:
"""Return cache marker indices: builtin/MCP boundary and tail index."""
if not tools:
return []
tail_idx = len(tools) - 1
last_builtin_idx: int | None = None
for i in range(tail_idx, -1, -1):
if not cls._tool_name(tools[i]).startswith("mcp_"):
last_builtin_idx = i
break
ordered_unique: list[int] = []
for idx in (last_builtin_idx, tail_idx):
if idx is not None and idx not in ordered_unique:
ordered_unique.append(idx)
return ordered_unique
@staticmethod
def _sanitize_request_messages(
messages: list[dict[str, Any]],
@ -177,7 +209,7 @@ class LLMProvider(ABC):
) -> LLMResponse:
"""
Send a chat completion request.
Args:
messages: List of message dicts with 'role' and 'content'.
tools: Optional list of tool definitions.
@ -185,7 +217,7 @@ class LLMProvider(ABC):
max_tokens: Maximum tokens in response.
temperature: Sampling temperature.
tool_choice: Tool selection strategy ("auto", "required", or specific tool dict).
Returns:
LLMResponse with content and/or tool calls.
"""

View File

@ -151,34 +151,6 @@ class OpenAICompatProvider(LLMProvider):
resolved = env_val.replace("{api_key}", api_key).replace("{api_base}", effective_base)
os.environ.setdefault(env_name, resolved)
@staticmethod
def _tool_name(tool: dict[str, Any]) -> str:
fn = tool.get("function")
if isinstance(fn, dict):
name = fn.get("name")
if isinstance(name, str):
return name
name = tool.get("name")
return name if isinstance(name, str) else ""
@classmethod
def _tool_cache_marker_indices(cls, tools: list[dict[str, Any]]) -> list[int]:
if not tools:
return []
tail_idx = len(tools) - 1
last_builtin_idx: int | None = None
for i in range(tail_idx, -1, -1):
if not cls._tool_name(tools[i]).startswith("mcp_"):
last_builtin_idx = i
break
ordered_unique: list[int] = []
for idx in (last_builtin_idx, tail_idx):
if idx is not None and idx not in ordered_unique:
ordered_unique.append(idx)
return ordered_unique
@classmethod
def _apply_cache_control(
cls,