fix(providers): only apply cache_control for Claude models on OpenRouter

Tejas1Koli 2026-04-01 10:36:24 +05:30 committed by Xubin Ren
parent 05fe73947f
commit 42fa8fa933

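The gating the commit title describes could look like the following minimal sketch. This is a hypothetical helper, not code from this commit: OpenRouter slugs embed the vendor (e.g. "anthropic/claude-3.5-sonnet"), so the provider can check for a Claude model before attaching Anthropic-style `cache_control` breakpoints.

```python
# Hypothetical sketch of the gate named in the commit title; the actual
# predicate in nanobot may differ. OpenRouter model slugs carry the vendor,
# so only Claude models should receive cache_control markers.
def _wants_cache_control(provider_name: str, model: str) -> bool:
    return provider_name == "openrouter" and "claude" in model.lower()
```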

@@ -18,17 +18,10 @@ from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
 if TYPE_CHECKING:
     from nanobot.providers.registry import ProviderSpec
 
 
-_ALLOWED_MSG_KEYS = frozenset(
-    {
-        "role",
-        "content",
-        "tool_calls",
-        "tool_call_id",
-        "name",
-        "reasoning_content",
-        "extra_content",
-    }
-)
+_ALLOWED_MSG_KEYS = frozenset({
+    "role", "content", "tool_calls", "tool_call_id", "name",
+    "reasoning_content", "extra_content",
+})
 _ALNUM = string.ascii_letters + string.digits
 _STANDARD_TC_KEYS = frozenset({"id", "type", "index", "function"})
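For illustration, `_ALLOWED_MSG_KEYS` is the kind of allow-list used to strip non-standard keys from a message before it goes to an OpenAI-compatible endpoint; the sanitizer itself is outside this hunk, so the sketch below is assumed:

```python
# Minimal sketch of how an allow-list like _ALLOWED_MSG_KEYS is typically
# applied; the real sanitizing code is not part of this diff.
_ALLOWED_MSG_KEYS = frozenset({
    "role", "content", "tool_calls", "tool_call_id", "name",
    "reasoning_content", "extra_content",
})

def sanitize(msg: dict) -> dict:
    return {k: v for k, v in msg.items() if k in _ALLOWED_MSG_KEYS}

assert sanitize({"role": "user", "content": "hi", "internal_id": 7}) == {
    "role": "user", "content": "hi",
}
```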
@@ -66,9 +59,7 @@ def _coerce_dict(value: Any) -> dict[str, Any] | None:
     return None
 
 
-def _extract_tc_extras(
-    tc: Any,
-) -> tuple[
+def _extract_tc_extras(tc: Any) -> tuple[
     dict[str, Any] | None,
     dict[str, Any] | None,
     dict[str, Any] | None,
@@ -84,18 +75,14 @@ def _extract_tc_extras(
     prov = None
     fn_prov = None
     if tc_dict is not None:
-        leftover = {
-            k: v
-            for k, v in tc_dict.items()
-            if k not in _STANDARD_TC_KEYS and k != "extra_content" and v is not None
-        }
+        leftover = {k: v for k, v in tc_dict.items()
+                    if k not in _STANDARD_TC_KEYS and k != "extra_content" and v is not None}
         if leftover:
             prov = leftover
         fn = _coerce_dict(tc_dict.get("function"))
         if fn is not None:
-            fn_leftover = {
-                k: v for k, v in fn.items() if k not in _STANDARD_FN_KEYS and v is not None
-            }
+            fn_leftover = {k: v for k, v in fn.items()
+                           if k not in _STANDARD_FN_KEYS and v is not None}
             if fn_leftover:
                 fn_prov = fn_leftover
     else:
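A worked example of the leftover extraction above, with illustrative key names (assuming `_STANDARD_FN_KEYS` covers the standard `name`/`arguments` fields):

```python
# Illustrative input/output for the leftover logic above; the non-standard
# key names here are made up for demonstration.
tc_dict = {
    "id": "call_1", "type": "function", "index": 0,
    "function": {"name": "search", "arguments": "{}", "thought_signature": "abc"},
    "provider_meta": {"trace": "xyz"},
}
# leftover    -> {"provider_meta": {"trace": "xyz"}}   (becomes prov)
# fn_leftover -> {"thought_signature": "abc"}          (becomes fn_prov)
```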
@@ -176,12 +163,9 @@ class OpenAICompatProvider(LLMProvider):
         def _mark(msg: dict[str, Any]) -> dict[str, Any]:
             content = msg.get("content")
             if isinstance(content, str):
-                return {
-                    **msg,
-                    "content": [
-                        {"type": "text", "text": content, "cache_control": cache_marker},
-                    ],
-                }
+                return {**msg, "content": [
+                    {"type": "text", "text": content, "cache_control": cache_marker},
+                ]}
             if isinstance(content, list) and content:
                 nc = list(content)
                 nc[-1] = {**nc[-1], "cache_control": cache_marker}
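Before/after for the string branch of `_mark`, assuming `cache_marker` is the Anthropic-style `{"type": "ephemeral"}` prompt-caching marker:

```python
# Effect of _mark on a plain string message, assuming
# cache_marker = {"type": "ephemeral"}:
before = {"role": "system", "content": "You are helpful."}
after = {
    "role": "system",
    "content": [
        {"type": "text", "text": "You are helpful.",
         "cache_control": {"type": "ephemeral"}},
    ],
}
```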
@@ -366,9 +350,7 @@ class OpenAICompatProvider(LLMProvider):
                     finish_reason=str(response_map.get("finish_reason") or "stop"),
                     usage=self._extract_usage(response_map),
                 )
-            return LLMResponse(
-                content="Error: API returned empty choices.", finish_reason="error"
-            )
+            return LLMResponse(content="Error: API returned empty choices.", finish_reason="error")
 
         choice0 = self._maybe_mapping(choices[0]) or {}
         msg0 = self._maybe_mapping(choice0.get("message")) or {}
@@ -398,16 +380,14 @@ class OpenAICompatProvider(LLMProvider):
                 if isinstance(args, str):
                     args = json_repair.loads(args)
                 ec, prov, fn_prov = _extract_tc_extras(tc)
-                parsed_tool_calls.append(
-                    ToolCallRequest(
-                        id=_short_tool_id(),
-                        name=str(fn.get("name") or ""),
-                        arguments=args if isinstance(args, dict) else {},
-                        extra_content=ec,
-                        provider_specific_fields=prov,
-                        function_provider_specific_fields=fn_prov,
-                    )
-                )
+                parsed_tool_calls.append(ToolCallRequest(
+                    id=_short_tool_id(),
+                    name=str(fn.get("name") or ""),
+                    arguments=args if isinstance(args, dict) else {},
+                    extra_content=ec,
+                    provider_specific_fields=prov,
+                    function_provider_specific_fields=fn_prov,
+                ))
 
         return LLMResponse(
             content=content,
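The `json_repair.loads` call in this hunk tolerates the slightly malformed JSON that models sometimes emit for tool arguments; for example, it accepts single quotes and truncated objects that `json.loads` would reject:

```python
import json_repair

# json_repair.loads parses broken JSON that the stdlib would reject,
# e.g. single-quoted keys or a missing closing brace:
print(json_repair.loads("{'city': 'Oslo'"))  # {'city': 'Oslo'}
```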
@@ -441,16 +421,14 @@ class OpenAICompatProvider(LLMProvider):
                 if isinstance(args, str):
                     args = json_repair.loads(args)
                 ec, prov, fn_prov = _extract_tc_extras(tc)
-                tool_calls.append(
-                    ToolCallRequest(
-                        id=_short_tool_id(),
-                        name=tc.function.name,
-                        arguments=args,
-                        extra_content=ec,
-                        provider_specific_fields=prov,
-                        function_provider_specific_fields=fn_prov,
-                    )
-                )
+                tool_calls.append(ToolCallRequest(
+                    id=_short_tool_id(),
+                    name=tc.function.name,
+                    arguments=args,
+                    extra_content=ec,
+                    provider_specific_fields=prov,
+                    function_provider_specific_fields=fn_prov,
+                ))
 
         return LLMResponse(
             content=content,
@@ -470,17 +448,10 @@ class OpenAICompatProvider(LLMProvider):
         def _accum_tc(tc: Any, idx_hint: int) -> None:
             """Accumulate one streaming tool-call delta into *tc_bufs*."""
             tc_index: int = _get(tc, "index") if _get(tc, "index") is not None else idx_hint
-            buf = tc_bufs.setdefault(
-                tc_index,
-                {
-                    "id": "",
-                    "name": "",
-                    "arguments": "",
-                    "extra_content": None,
-                    "prov": None,
-                    "fn_prov": None,
-                },
-            )
+            buf = tc_bufs.setdefault(tc_index, {
+                "id": "", "name": "", "arguments": "",
+                "extra_content": None, "prov": None, "fn_prov": None,
+            })
             tc_id = _get(tc, "id")
             if tc_id:
                 buf["id"] = str(tc_id)
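A standalone replay of how the buffer in `_accum_tc` fills up across streaming deltas (simplified re-implementation for illustration; the extras handling is omitted):

```python
# Simplified re-implementation of the accumulation above: the first delta
# carries id/name, later deltas append argument fragments at the same index.
tc_bufs: dict[int, dict] = {}
deltas = [
    {"index": 0, "id": "call_1", "function": {"name": "search", "arguments": ""}},
    {"index": 0, "function": {"arguments": '{"q": "ne'}},
    {"index": 0, "function": {"arguments": 'ws"}'}},
]
for d in deltas:
    buf = tc_bufs.setdefault(d["index"], {"id": "", "name": "", "arguments": ""})
    if d.get("id"):
        buf["id"] = str(d["id"])
    fn = d.get("function") or {}
    if fn.get("name"):
        buf["name"] = fn["name"]
    buf["arguments"] += fn.get("arguments") or ""

assert tc_bufs[0] == {"id": "call_1", "name": "search", "arguments": '{"q": "news"}'}
```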
@@ -578,13 +549,8 @@ class OpenAICompatProvider(LLMProvider):
         tool_choice: str | dict[str, Any] | None = None,
     ) -> LLMResponse:
         kwargs = self._build_kwargs(
-            messages,
-            tools,
-            model,
-            max_tokens,
-            temperature,
-            reasoning_effort,
-            tool_choice,
+            messages, tools, model, max_tokens, temperature,
+            reasoning_effort, tool_choice,
         )
         try:
             return self._parse(await self._client.chat.completions.create(**kwargs))
@@ -603,13 +569,8 @@ class OpenAICompatProvider(LLMProvider):
         on_content_delta: Callable[[str], Awaitable[None]] | None = None,
     ) -> LLMResponse:
         kwargs = self._build_kwargs(
-            messages,
-            tools,
-            model,
-            max_tokens,
-            temperature,
-            reasoning_effort,
-            tool_choice,
+            messages, tools, model, max_tokens, temperature,
+            reasoning_effort, tool_choice,
         )
         kwargs["stream"] = True
         kwargs["stream_options"] = {"include_usage": True}
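`stream_options={"include_usage": True}` asks OpenAI-compatible servers to append one final chunk that has empty `choices` and a populated `usage`. A minimal consumption sketch, assuming an `openai.AsyncOpenAI`-style client:

```python
from openai import AsyncOpenAI

# Minimal sketch, assuming an AsyncOpenAI-compatible client: with
# include_usage, only the final streamed chunk carries a non-None .usage.
async def stream_total_tokens(client: AsyncOpenAI, **kwargs) -> int | None:
    stream = await client.chat.completions.create(
        stream=True, stream_options={"include_usage": True}, **kwargs
    )
    total = None
    async for chunk in stream:
        if chunk.usage is not None:
            total = chunk.usage.total_tokens
    return total
```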
@@ -627,4 +588,4 @@ class OpenAICompatProvider(LLMProvider):
             return self._handle_error(e)
 
     def get_default_model(self) -> str:
-        return self.default_model
\ No newline at end of file
+        return self.default_model