mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-24 11:44:12 +00:00
fix(providers): only apply cache_control for Claude models on OpenRouter
This commit is contained in:
parent
485c75e065
commit
05fe73947f
@ -18,10 +18,17 @@ from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from nanobot.providers.registry import ProviderSpec
|
from nanobot.providers.registry import ProviderSpec
|
||||||
|
|
||||||
_ALLOWED_MSG_KEYS = frozenset({
|
_ALLOWED_MSG_KEYS = frozenset(
|
||||||
"role", "content", "tool_calls", "tool_call_id", "name",
|
{
|
||||||
"reasoning_content", "extra_content",
|
"role",
|
||||||
})
|
"content",
|
||||||
|
"tool_calls",
|
||||||
|
"tool_call_id",
|
||||||
|
"name",
|
||||||
|
"reasoning_content",
|
||||||
|
"extra_content",
|
||||||
|
}
|
||||||
|
)
|
||||||
_ALNUM = string.ascii_letters + string.digits
|
_ALNUM = string.ascii_letters + string.digits
|
||||||
|
|
||||||
_STANDARD_TC_KEYS = frozenset({"id", "type", "index", "function"})
|
_STANDARD_TC_KEYS = frozenset({"id", "type", "index", "function"})
|
||||||
@ -59,7 +66,9 @@ def _coerce_dict(value: Any) -> dict[str, Any] | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _extract_tc_extras(tc: Any) -> tuple[
|
def _extract_tc_extras(
|
||||||
|
tc: Any,
|
||||||
|
) -> tuple[
|
||||||
dict[str, Any] | None,
|
dict[str, Any] | None,
|
||||||
dict[str, Any] | None,
|
dict[str, Any] | None,
|
||||||
dict[str, Any] | None,
|
dict[str, Any] | None,
|
||||||
@ -75,14 +84,18 @@ def _extract_tc_extras(tc: Any) -> tuple[
|
|||||||
prov = None
|
prov = None
|
||||||
fn_prov = None
|
fn_prov = None
|
||||||
if tc_dict is not None:
|
if tc_dict is not None:
|
||||||
leftover = {k: v for k, v in tc_dict.items()
|
leftover = {
|
||||||
if k not in _STANDARD_TC_KEYS and k != "extra_content" and v is not None}
|
k: v
|
||||||
|
for k, v in tc_dict.items()
|
||||||
|
if k not in _STANDARD_TC_KEYS and k != "extra_content" and v is not None
|
||||||
|
}
|
||||||
if leftover:
|
if leftover:
|
||||||
prov = leftover
|
prov = leftover
|
||||||
fn = _coerce_dict(tc_dict.get("function"))
|
fn = _coerce_dict(tc_dict.get("function"))
|
||||||
if fn is not None:
|
if fn is not None:
|
||||||
fn_leftover = {k: v for k, v in fn.items()
|
fn_leftover = {
|
||||||
if k not in _STANDARD_FN_KEYS and v is not None}
|
k: v for k, v in fn.items() if k not in _STANDARD_FN_KEYS and v is not None
|
||||||
|
}
|
||||||
if fn_leftover:
|
if fn_leftover:
|
||||||
fn_prov = fn_leftover
|
fn_prov = fn_leftover
|
||||||
else:
|
else:
|
||||||
@ -163,9 +176,12 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
def _mark(msg: dict[str, Any]) -> dict[str, Any]:
|
def _mark(msg: dict[str, Any]) -> dict[str, Any]:
|
||||||
content = msg.get("content")
|
content = msg.get("content")
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
return {**msg, "content": [
|
return {
|
||||||
|
**msg,
|
||||||
|
"content": [
|
||||||
{"type": "text", "text": content, "cache_control": cache_marker},
|
{"type": "text", "text": content, "cache_control": cache_marker},
|
||||||
]}
|
],
|
||||||
|
}
|
||||||
if isinstance(content, list) and content:
|
if isinstance(content, list) and content:
|
||||||
nc = list(content)
|
nc = list(content)
|
||||||
nc[-1] = {**nc[-1], "cache_control": cache_marker}
|
nc[-1] = {**nc[-1], "cache_control": cache_marker}
|
||||||
@ -235,6 +251,8 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
spec = self._spec
|
spec = self._spec
|
||||||
|
|
||||||
if spec and spec.supports_prompt_caching:
|
if spec and spec.supports_prompt_caching:
|
||||||
|
model_name = model or self.default_model
|
||||||
|
if any(model_name.lower().startswith(k) for k in ("anthropic/", "claude")):
|
||||||
messages, tools = self._apply_cache_control(messages, tools)
|
messages, tools = self._apply_cache_control(messages, tools)
|
||||||
|
|
||||||
if spec and spec.strip_model_prefix:
|
if spec and spec.strip_model_prefix:
|
||||||
@ -348,7 +366,9 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
finish_reason=str(response_map.get("finish_reason") or "stop"),
|
finish_reason=str(response_map.get("finish_reason") or "stop"),
|
||||||
usage=self._extract_usage(response_map),
|
usage=self._extract_usage(response_map),
|
||||||
)
|
)
|
||||||
return LLMResponse(content="Error: API returned empty choices.", finish_reason="error")
|
return LLMResponse(
|
||||||
|
content="Error: API returned empty choices.", finish_reason="error"
|
||||||
|
)
|
||||||
|
|
||||||
choice0 = self._maybe_mapping(choices[0]) or {}
|
choice0 = self._maybe_mapping(choices[0]) or {}
|
||||||
msg0 = self._maybe_mapping(choice0.get("message")) or {}
|
msg0 = self._maybe_mapping(choice0.get("message")) or {}
|
||||||
@ -378,14 +398,16 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
if isinstance(args, str):
|
if isinstance(args, str):
|
||||||
args = json_repair.loads(args)
|
args = json_repair.loads(args)
|
||||||
ec, prov, fn_prov = _extract_tc_extras(tc)
|
ec, prov, fn_prov = _extract_tc_extras(tc)
|
||||||
parsed_tool_calls.append(ToolCallRequest(
|
parsed_tool_calls.append(
|
||||||
|
ToolCallRequest(
|
||||||
id=_short_tool_id(),
|
id=_short_tool_id(),
|
||||||
name=str(fn.get("name") or ""),
|
name=str(fn.get("name") or ""),
|
||||||
arguments=args if isinstance(args, dict) else {},
|
arguments=args if isinstance(args, dict) else {},
|
||||||
extra_content=ec,
|
extra_content=ec,
|
||||||
provider_specific_fields=prov,
|
provider_specific_fields=prov,
|
||||||
function_provider_specific_fields=fn_prov,
|
function_provider_specific_fields=fn_prov,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return LLMResponse(
|
return LLMResponse(
|
||||||
content=content,
|
content=content,
|
||||||
@ -419,14 +441,16 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
if isinstance(args, str):
|
if isinstance(args, str):
|
||||||
args = json_repair.loads(args)
|
args = json_repair.loads(args)
|
||||||
ec, prov, fn_prov = _extract_tc_extras(tc)
|
ec, prov, fn_prov = _extract_tc_extras(tc)
|
||||||
tool_calls.append(ToolCallRequest(
|
tool_calls.append(
|
||||||
|
ToolCallRequest(
|
||||||
id=_short_tool_id(),
|
id=_short_tool_id(),
|
||||||
name=tc.function.name,
|
name=tc.function.name,
|
||||||
arguments=args,
|
arguments=args,
|
||||||
extra_content=ec,
|
extra_content=ec,
|
||||||
provider_specific_fields=prov,
|
provider_specific_fields=prov,
|
||||||
function_provider_specific_fields=fn_prov,
|
function_provider_specific_fields=fn_prov,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return LLMResponse(
|
return LLMResponse(
|
||||||
content=content,
|
content=content,
|
||||||
@ -446,10 +470,17 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
def _accum_tc(tc: Any, idx_hint: int) -> None:
|
def _accum_tc(tc: Any, idx_hint: int) -> None:
|
||||||
"""Accumulate one streaming tool-call delta into *tc_bufs*."""
|
"""Accumulate one streaming tool-call delta into *tc_bufs*."""
|
||||||
tc_index: int = _get(tc, "index") if _get(tc, "index") is not None else idx_hint
|
tc_index: int = _get(tc, "index") if _get(tc, "index") is not None else idx_hint
|
||||||
buf = tc_bufs.setdefault(tc_index, {
|
buf = tc_bufs.setdefault(
|
||||||
"id": "", "name": "", "arguments": "",
|
tc_index,
|
||||||
"extra_content": None, "prov": None, "fn_prov": None,
|
{
|
||||||
})
|
"id": "",
|
||||||
|
"name": "",
|
||||||
|
"arguments": "",
|
||||||
|
"extra_content": None,
|
||||||
|
"prov": None,
|
||||||
|
"fn_prov": None,
|
||||||
|
},
|
||||||
|
)
|
||||||
tc_id = _get(tc, "id")
|
tc_id = _get(tc, "id")
|
||||||
if tc_id:
|
if tc_id:
|
||||||
buf["id"] = str(tc_id)
|
buf["id"] = str(tc_id)
|
||||||
@ -547,8 +578,13 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
kwargs = self._build_kwargs(
|
kwargs = self._build_kwargs(
|
||||||
messages, tools, model, max_tokens, temperature,
|
messages,
|
||||||
reasoning_effort, tool_choice,
|
tools,
|
||||||
|
model,
|
||||||
|
max_tokens,
|
||||||
|
temperature,
|
||||||
|
reasoning_effort,
|
||||||
|
tool_choice,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
return self._parse(await self._client.chat.completions.create(**kwargs))
|
return self._parse(await self._client.chat.completions.create(**kwargs))
|
||||||
@ -567,8 +603,13 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
kwargs = self._build_kwargs(
|
kwargs = self._build_kwargs(
|
||||||
messages, tools, model, max_tokens, temperature,
|
messages,
|
||||||
reasoning_effort, tool_choice,
|
tools,
|
||||||
|
model,
|
||||||
|
max_tokens,
|
||||||
|
temperature,
|
||||||
|
reasoning_effort,
|
||||||
|
tool_choice,
|
||||||
)
|
)
|
||||||
kwargs["stream"] = True
|
kwargs["stream"] = True
|
||||||
kwargs["stream_options"] = {"include_usage": True}
|
kwargs["stream_options"] = {"include_usage": True}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user