mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-06 19:23:39 +00:00
refactor: deduplicate tool cache marker helper in base provider
This commit is contained in:
parent
607fd8fd7e
commit
87d493f354
@ -250,36 +250,6 @@ class AnthropicProvider(LLMProvider):
|
||||
# Prompt caching
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _tool_name(tool: dict[str, Any]) -> str:
|
||||
name = tool.get("name")
|
||||
if isinstance(name, str):
|
||||
return name
|
||||
fn = tool.get("function")
|
||||
if isinstance(fn, dict):
|
||||
fname = fn.get("name")
|
||||
if isinstance(fname, str):
|
||||
return fname
|
||||
return ""
|
||||
|
||||
@classmethod
def _tool_cache_marker_indices(cls, tools: list[dict[str, Any]]) -> list[int]:
    """Indices that should carry prompt-cache markers.

    Yields the last builtin (non-``mcp_``-prefixed) tool index when one
    exists, followed by the final tool index, with duplicates collapsed —
    so the result holds one or two positions, or none for an empty list.
    """
    if not tools:
        return []

    tail = len(tools) - 1
    # Scan from the end for the last tool whose name lacks the "mcp_" prefix.
    boundary = next(
        (
            i
            for i in range(tail, -1, -1)
            if not cls._tool_name(tools[i]).startswith("mcp_")
        ),
        None,
    )
    if boundary is None or boundary == tail:
        return [tail]
    return [boundary, tail]
|
||||
|
||||
@classmethod
|
||||
def _apply_cache_control(
|
||||
cls,
|
||||
|
||||
@ -48,7 +48,7 @@ class LLMResponse:
|
||||
usage: dict[str, int] = field(default_factory=dict)
|
||||
reasoning_content: str | None = None # Kimi, DeepSeek-R1 etc.
|
||||
thinking_blocks: list[dict] | None = None # Anthropic extended thinking
|
||||
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
"""Check if response contains tool calls."""
|
||||
@ -73,7 +73,7 @@ class GenerationSettings:
|
||||
class LLMProvider(ABC):
|
||||
"""
|
||||
Abstract base class for LLM providers.
|
||||
|
||||
|
||||
Implementations should handle the specifics of each provider's API
|
||||
while maintaining a consistent interface.
|
||||
"""
|
||||
@ -150,6 +150,38 @@ class LLMProvider(ABC):
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _tool_name(tool: dict[str, Any]) -> str:
|
||||
"""Extract tool name from either OpenAI or Anthropic-style tool schemas."""
|
||||
name = tool.get("name")
|
||||
if isinstance(name, str):
|
||||
return name
|
||||
fn = tool.get("function")
|
||||
if isinstance(fn, dict):
|
||||
fname = fn.get("name")
|
||||
if isinstance(fname, str):
|
||||
return fname
|
||||
return ""
|
||||
|
||||
@classmethod
def _tool_cache_marker_indices(cls, tools: list[dict[str, Any]]) -> list[int]:
    """Return cache marker positions: builtin/MCP boundary plus tail index.

    The boundary is the highest index whose tool name does not start with
    "mcp_"; the tail is the last index overall. The two are deduplicated
    while preserving order, and an empty tool list yields no markers.
    """
    if not tools:
        return []

    tail = len(tools) - 1
    boundary: int | None = None
    # Walk backwards so the *last* builtin tool wins.
    for pos in reversed(range(len(tools))):
        if not cls._tool_name(tools[pos]).startswith("mcp_"):
            boundary = pos
            break

    candidates = [pos for pos in (boundary, tail) if pos is not None]
    # dict.fromkeys dedupes while keeping insertion order.
    return list(dict.fromkeys(candidates))
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_request_messages(
|
||||
messages: list[dict[str, Any]],
|
||||
@ -177,7 +209,7 @@ class LLMProvider(ABC):
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Send a chat completion request.
|
||||
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
tools: Optional list of tool definitions.
|
||||
@ -185,7 +217,7 @@ class LLMProvider(ABC):
|
||||
max_tokens: Maximum tokens in response.
|
||||
temperature: Sampling temperature.
|
||||
tool_choice: Tool selection strategy ("auto", "required", or specific tool dict).
|
||||
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and/or tool calls.
|
||||
"""
|
||||
|
||||
@ -151,34 +151,6 @@ class OpenAICompatProvider(LLMProvider):
|
||||
resolved = env_val.replace("{api_key}", api_key).replace("{api_base}", effective_base)
|
||||
os.environ.setdefault(env_name, resolved)
|
||||
|
||||
@staticmethod
|
||||
def _tool_name(tool: dict[str, Any]) -> str:
|
||||
fn = tool.get("function")
|
||||
if isinstance(fn, dict):
|
||||
name = fn.get("name")
|
||||
if isinstance(name, str):
|
||||
return name
|
||||
name = tool.get("name")
|
||||
return name if isinstance(name, str) else ""
|
||||
|
||||
@classmethod
def _tool_cache_marker_indices(cls, tools: list[dict[str, Any]]) -> list[int]:
    """Pick the tool indices that receive prompt-cache breakpoints.

    The result contains the last non-MCP ("builtin") tool index when one
    exists, then the final index, without repeats — length one or two,
    or empty when no tools were supplied.
    """
    if not tools:
        return []

    tail = len(tools) - 1
    # Lazily scan from the tail; stop at the first name without "mcp_".
    non_mcp = (
        idx
        for idx in range(tail, -1, -1)
        if not cls._tool_name(tools[idx]).startswith("mcp_")
    )
    boundary = next(non_mcp, None)

    markers: list[int] = []
    for idx in (boundary, tail):
        if idx is not None and idx not in markers:
            markers.append(idx)
    return markers
|
||||
|
||||
@classmethod
|
||||
def _apply_cache_control(
|
||||
cls,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user