mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 09:22:36 +00:00
fix(providers): only apply cache_control for Claude models on OpenRouter
This commit is contained in:
parent
485c75e065
commit
05fe73947f
@ -18,10 +18,17 @@ from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.providers.registry import ProviderSpec
|
||||
|
||||
_ALLOWED_MSG_KEYS = frozenset({
|
||||
"role", "content", "tool_calls", "tool_call_id", "name",
|
||||
"reasoning_content", "extra_content",
|
||||
})
|
||||
_ALLOWED_MSG_KEYS = frozenset(
|
||||
{
|
||||
"role",
|
||||
"content",
|
||||
"tool_calls",
|
||||
"tool_call_id",
|
||||
"name",
|
||||
"reasoning_content",
|
||||
"extra_content",
|
||||
}
|
||||
)
|
||||
_ALNUM = string.ascii_letters + string.digits
|
||||
|
||||
_STANDARD_TC_KEYS = frozenset({"id", "type", "index", "function"})
|
||||
@ -59,7 +66,9 @@ def _coerce_dict(value: Any) -> dict[str, Any] | None:
|
||||
return None
|
||||
|
||||
|
||||
def _extract_tc_extras(tc: Any) -> tuple[
|
||||
def _extract_tc_extras(
|
||||
tc: Any,
|
||||
) -> tuple[
|
||||
dict[str, Any] | None,
|
||||
dict[str, Any] | None,
|
||||
dict[str, Any] | None,
|
||||
@ -75,14 +84,18 @@ def _extract_tc_extras(tc: Any) -> tuple[
|
||||
prov = None
|
||||
fn_prov = None
|
||||
if tc_dict is not None:
|
||||
leftover = {k: v for k, v in tc_dict.items()
|
||||
if k not in _STANDARD_TC_KEYS and k != "extra_content" and v is not None}
|
||||
leftover = {
|
||||
k: v
|
||||
for k, v in tc_dict.items()
|
||||
if k not in _STANDARD_TC_KEYS and k != "extra_content" and v is not None
|
||||
}
|
||||
if leftover:
|
||||
prov = leftover
|
||||
fn = _coerce_dict(tc_dict.get("function"))
|
||||
if fn is not None:
|
||||
fn_leftover = {k: v for k, v in fn.items()
|
||||
if k not in _STANDARD_FN_KEYS and v is not None}
|
||||
fn_leftover = {
|
||||
k: v for k, v in fn.items() if k not in _STANDARD_FN_KEYS and v is not None
|
||||
}
|
||||
if fn_leftover:
|
||||
fn_prov = fn_leftover
|
||||
else:
|
||||
@ -163,9 +176,12 @@ class OpenAICompatProvider(LLMProvider):
|
||||
def _mark(msg: dict[str, Any]) -> dict[str, Any]:
|
||||
content = msg.get("content")
|
||||
if isinstance(content, str):
|
||||
return {**msg, "content": [
|
||||
{"type": "text", "text": content, "cache_control": cache_marker},
|
||||
]}
|
||||
return {
|
||||
**msg,
|
||||
"content": [
|
||||
{"type": "text", "text": content, "cache_control": cache_marker},
|
||||
],
|
||||
}
|
||||
if isinstance(content, list) and content:
|
||||
nc = list(content)
|
||||
nc[-1] = {**nc[-1], "cache_control": cache_marker}
|
||||
@ -235,7 +251,9 @@ class OpenAICompatProvider(LLMProvider):
|
||||
spec = self._spec
|
||||
|
||||
if spec and spec.supports_prompt_caching:
|
||||
messages, tools = self._apply_cache_control(messages, tools)
|
||||
model_name = model or self.default_model
|
||||
if any(model_name.lower().startswith(k) for k in ("anthropic/", "claude")):
|
||||
messages, tools = self._apply_cache_control(messages, tools)
|
||||
|
||||
if spec and spec.strip_model_prefix:
|
||||
model_name = model_name.split("/")[-1]
|
||||
@ -348,7 +366,9 @@ class OpenAICompatProvider(LLMProvider):
|
||||
finish_reason=str(response_map.get("finish_reason") or "stop"),
|
||||
usage=self._extract_usage(response_map),
|
||||
)
|
||||
return LLMResponse(content="Error: API returned empty choices.", finish_reason="error")
|
||||
return LLMResponse(
|
||||
content="Error: API returned empty choices.", finish_reason="error"
|
||||
)
|
||||
|
||||
choice0 = self._maybe_mapping(choices[0]) or {}
|
||||
msg0 = self._maybe_mapping(choice0.get("message")) or {}
|
||||
@ -378,14 +398,16 @@ class OpenAICompatProvider(LLMProvider):
|
||||
if isinstance(args, str):
|
||||
args = json_repair.loads(args)
|
||||
ec, prov, fn_prov = _extract_tc_extras(tc)
|
||||
parsed_tool_calls.append(ToolCallRequest(
|
||||
id=_short_tool_id(),
|
||||
name=str(fn.get("name") or ""),
|
||||
arguments=args if isinstance(args, dict) else {},
|
||||
extra_content=ec,
|
||||
provider_specific_fields=prov,
|
||||
function_provider_specific_fields=fn_prov,
|
||||
))
|
||||
parsed_tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=_short_tool_id(),
|
||||
name=str(fn.get("name") or ""),
|
||||
arguments=args if isinstance(args, dict) else {},
|
||||
extra_content=ec,
|
||||
provider_specific_fields=prov,
|
||||
function_provider_specific_fields=fn_prov,
|
||||
)
|
||||
)
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
@ -419,14 +441,16 @@ class OpenAICompatProvider(LLMProvider):
|
||||
if isinstance(args, str):
|
||||
args = json_repair.loads(args)
|
||||
ec, prov, fn_prov = _extract_tc_extras(tc)
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=_short_tool_id(),
|
||||
name=tc.function.name,
|
||||
arguments=args,
|
||||
extra_content=ec,
|
||||
provider_specific_fields=prov,
|
||||
function_provider_specific_fields=fn_prov,
|
||||
))
|
||||
tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=_short_tool_id(),
|
||||
name=tc.function.name,
|
||||
arguments=args,
|
||||
extra_content=ec,
|
||||
provider_specific_fields=prov,
|
||||
function_provider_specific_fields=fn_prov,
|
||||
)
|
||||
)
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
@ -446,10 +470,17 @@ class OpenAICompatProvider(LLMProvider):
|
||||
def _accum_tc(tc: Any, idx_hint: int) -> None:
|
||||
"""Accumulate one streaming tool-call delta into *tc_bufs*."""
|
||||
tc_index: int = _get(tc, "index") if _get(tc, "index") is not None else idx_hint
|
||||
buf = tc_bufs.setdefault(tc_index, {
|
||||
"id": "", "name": "", "arguments": "",
|
||||
"extra_content": None, "prov": None, "fn_prov": None,
|
||||
})
|
||||
buf = tc_bufs.setdefault(
|
||||
tc_index,
|
||||
{
|
||||
"id": "",
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
"extra_content": None,
|
||||
"prov": None,
|
||||
"fn_prov": None,
|
||||
},
|
||||
)
|
||||
tc_id = _get(tc, "id")
|
||||
if tc_id:
|
||||
buf["id"] = str(tc_id)
|
||||
@ -547,8 +578,13 @@ class OpenAICompatProvider(LLMProvider):
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
) -> LLMResponse:
|
||||
kwargs = self._build_kwargs(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
messages,
|
||||
tools,
|
||||
model,
|
||||
max_tokens,
|
||||
temperature,
|
||||
reasoning_effort,
|
||||
tool_choice,
|
||||
)
|
||||
try:
|
||||
return self._parse(await self._client.chat.completions.create(**kwargs))
|
||||
@ -567,8 +603,13 @@ class OpenAICompatProvider(LLMProvider):
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
kwargs = self._build_kwargs(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
messages,
|
||||
tools,
|
||||
model,
|
||||
max_tokens,
|
||||
temperature,
|
||||
reasoning_effort,
|
||||
tool_choice,
|
||||
)
|
||||
kwargs["stream"] = True
|
||||
kwargs["stream_options"] = {"include_usage": True}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user