refactor: deduplicate tool cache marker helper in base provider

This commit is contained in:
pikaxinge 2026-04-02 07:29:07 +00:00
parent 607fd8fd7e
commit 87d493f354
3 changed files with 36 additions and 62 deletions

View File

@ -250,36 +250,6 @@ class AnthropicProvider(LLMProvider):
# Prompt caching
# ------------------------------------------------------------------
@staticmethod
def _tool_name(tool: dict[str, Any]) -> str:
name = tool.get("name")
if isinstance(name, str):
return name
fn = tool.get("function")
if isinstance(fn, dict):
fname = fn.get("name")
if isinstance(fname, str):
return fname
return ""
@classmethod
def _tool_cache_marker_indices(cls, tools: list[dict[str, Any]]) -> list[int]:
if not tools:
return []
tail_idx = len(tools) - 1
last_builtin_idx: int | None = None
for i in range(tail_idx, -1, -1):
if not cls._tool_name(tools[i]).startswith("mcp_"):
last_builtin_idx = i
break
ordered_unique: list[int] = []
for idx in (last_builtin_idx, tail_idx):
if idx is not None and idx not in ordered_unique:
ordered_unique.append(idx)
return ordered_unique
@classmethod
def _apply_cache_control(
cls,

View File

@ -150,6 +150,38 @@ class LLMProvider(ABC):
result.append(msg)
return result
@staticmethod
def _tool_name(tool: dict[str, Any]) -> str:
"""Extract tool name from either OpenAI or Anthropic-style tool schemas."""
name = tool.get("name")
if isinstance(name, str):
return name
fn = tool.get("function")
if isinstance(fn, dict):
fname = fn.get("name")
if isinstance(fname, str):
return fname
return ""
@classmethod
def _tool_cache_marker_indices(cls, tools: list[dict[str, Any]]) -> list[int]:
"""Return cache marker indices: builtin/MCP boundary and tail index."""
if not tools:
return []
tail_idx = len(tools) - 1
last_builtin_idx: int | None = None
for i in range(tail_idx, -1, -1):
if not cls._tool_name(tools[i]).startswith("mcp_"):
last_builtin_idx = i
break
ordered_unique: list[int] = []
for idx in (last_builtin_idx, tail_idx):
if idx is not None and idx not in ordered_unique:
ordered_unique.append(idx)
return ordered_unique
@staticmethod
def _sanitize_request_messages(
messages: list[dict[str, Any]],

View File

@ -151,34 +151,6 @@ class OpenAICompatProvider(LLMProvider):
resolved = env_val.replace("{api_key}", api_key).replace("{api_base}", effective_base)
os.environ.setdefault(env_name, resolved)
@staticmethod
def _tool_name(tool: dict[str, Any]) -> str:
fn = tool.get("function")
if isinstance(fn, dict):
name = fn.get("name")
if isinstance(name, str):
return name
name = tool.get("name")
return name if isinstance(name, str) else ""
@classmethod
def _tool_cache_marker_indices(cls, tools: list[dict[str, Any]]) -> list[int]:
if not tools:
return []
tail_idx = len(tools) - 1
last_builtin_idx: int | None = None
for i in range(tail_idx, -1, -1):
if not cls._tool_name(tools[i]).startswith("mcp_"):
last_builtin_idx = i
break
ordered_unique: list[int] = []
for idx in (last_builtin_idx, tail_idx):
if idx is not None and idx not in ordered_unique:
ordered_unique.append(idx)
return ordered_unique
@classmethod
def _apply_cache_control(
cls,