mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-20 01:50:08 +00:00
fix(providers): only apply cache_control for Claude models on OpenRouter
This commit is contained in:
parent
05fe73947f
commit
42fa8fa933
@ -18,17 +18,10 @@ from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from nanobot.providers.registry import ProviderSpec
|
from nanobot.providers.registry import ProviderSpec
|
||||||
|
|
||||||
_ALLOWED_MSG_KEYS = frozenset(
|
_ALLOWED_MSG_KEYS = frozenset({
|
||||||
{
|
"role", "content", "tool_calls", "tool_call_id", "name",
|
||||||
"role",
|
"reasoning_content", "extra_content",
|
||||||
"content",
|
})
|
||||||
"tool_calls",
|
|
||||||
"tool_call_id",
|
|
||||||
"name",
|
|
||||||
"reasoning_content",
|
|
||||||
"extra_content",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
_ALNUM = string.ascii_letters + string.digits
|
_ALNUM = string.ascii_letters + string.digits
|
||||||
|
|
||||||
_STANDARD_TC_KEYS = frozenset({"id", "type", "index", "function"})
|
_STANDARD_TC_KEYS = frozenset({"id", "type", "index", "function"})
|
||||||
@ -66,9 +59,7 @@ def _coerce_dict(value: Any) -> dict[str, Any] | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _extract_tc_extras(
|
def _extract_tc_extras(tc: Any) -> tuple[
|
||||||
tc: Any,
|
|
||||||
) -> tuple[
|
|
||||||
dict[str, Any] | None,
|
dict[str, Any] | None,
|
||||||
dict[str, Any] | None,
|
dict[str, Any] | None,
|
||||||
dict[str, Any] | None,
|
dict[str, Any] | None,
|
||||||
@ -84,18 +75,14 @@ def _extract_tc_extras(
|
|||||||
prov = None
|
prov = None
|
||||||
fn_prov = None
|
fn_prov = None
|
||||||
if tc_dict is not None:
|
if tc_dict is not None:
|
||||||
leftover = {
|
leftover = {k: v for k, v in tc_dict.items()
|
||||||
k: v
|
if k not in _STANDARD_TC_KEYS and k != "extra_content" and v is not None}
|
||||||
for k, v in tc_dict.items()
|
|
||||||
if k not in _STANDARD_TC_KEYS and k != "extra_content" and v is not None
|
|
||||||
}
|
|
||||||
if leftover:
|
if leftover:
|
||||||
prov = leftover
|
prov = leftover
|
||||||
fn = _coerce_dict(tc_dict.get("function"))
|
fn = _coerce_dict(tc_dict.get("function"))
|
||||||
if fn is not None:
|
if fn is not None:
|
||||||
fn_leftover = {
|
fn_leftover = {k: v for k, v in fn.items()
|
||||||
k: v for k, v in fn.items() if k not in _STANDARD_FN_KEYS and v is not None
|
if k not in _STANDARD_FN_KEYS and v is not None}
|
||||||
}
|
|
||||||
if fn_leftover:
|
if fn_leftover:
|
||||||
fn_prov = fn_leftover
|
fn_prov = fn_leftover
|
||||||
else:
|
else:
|
||||||
@ -176,12 +163,9 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
def _mark(msg: dict[str, Any]) -> dict[str, Any]:
|
def _mark(msg: dict[str, Any]) -> dict[str, Any]:
|
||||||
content = msg.get("content")
|
content = msg.get("content")
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
return {
|
return {**msg, "content": [
|
||||||
**msg,
|
{"type": "text", "text": content, "cache_control": cache_marker},
|
||||||
"content": [
|
]}
|
||||||
{"type": "text", "text": content, "cache_control": cache_marker},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
if isinstance(content, list) and content:
|
if isinstance(content, list) and content:
|
||||||
nc = list(content)
|
nc = list(content)
|
||||||
nc[-1] = {**nc[-1], "cache_control": cache_marker}
|
nc[-1] = {**nc[-1], "cache_control": cache_marker}
|
||||||
@ -366,9 +350,7 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
finish_reason=str(response_map.get("finish_reason") or "stop"),
|
finish_reason=str(response_map.get("finish_reason") or "stop"),
|
||||||
usage=self._extract_usage(response_map),
|
usage=self._extract_usage(response_map),
|
||||||
)
|
)
|
||||||
return LLMResponse(
|
return LLMResponse(content="Error: API returned empty choices.", finish_reason="error")
|
||||||
content="Error: API returned empty choices.", finish_reason="error"
|
|
||||||
)
|
|
||||||
|
|
||||||
choice0 = self._maybe_mapping(choices[0]) or {}
|
choice0 = self._maybe_mapping(choices[0]) or {}
|
||||||
msg0 = self._maybe_mapping(choice0.get("message")) or {}
|
msg0 = self._maybe_mapping(choice0.get("message")) or {}
|
||||||
@ -398,16 +380,14 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
if isinstance(args, str):
|
if isinstance(args, str):
|
||||||
args = json_repair.loads(args)
|
args = json_repair.loads(args)
|
||||||
ec, prov, fn_prov = _extract_tc_extras(tc)
|
ec, prov, fn_prov = _extract_tc_extras(tc)
|
||||||
parsed_tool_calls.append(
|
parsed_tool_calls.append(ToolCallRequest(
|
||||||
ToolCallRequest(
|
id=_short_tool_id(),
|
||||||
id=_short_tool_id(),
|
name=str(fn.get("name") or ""),
|
||||||
name=str(fn.get("name") or ""),
|
arguments=args if isinstance(args, dict) else {},
|
||||||
arguments=args if isinstance(args, dict) else {},
|
extra_content=ec,
|
||||||
extra_content=ec,
|
provider_specific_fields=prov,
|
||||||
provider_specific_fields=prov,
|
function_provider_specific_fields=fn_prov,
|
||||||
function_provider_specific_fields=fn_prov,
|
))
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return LLMResponse(
|
return LLMResponse(
|
||||||
content=content,
|
content=content,
|
||||||
@ -441,16 +421,14 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
if isinstance(args, str):
|
if isinstance(args, str):
|
||||||
args = json_repair.loads(args)
|
args = json_repair.loads(args)
|
||||||
ec, prov, fn_prov = _extract_tc_extras(tc)
|
ec, prov, fn_prov = _extract_tc_extras(tc)
|
||||||
tool_calls.append(
|
tool_calls.append(ToolCallRequest(
|
||||||
ToolCallRequest(
|
id=_short_tool_id(),
|
||||||
id=_short_tool_id(),
|
name=tc.function.name,
|
||||||
name=tc.function.name,
|
arguments=args,
|
||||||
arguments=args,
|
extra_content=ec,
|
||||||
extra_content=ec,
|
provider_specific_fields=prov,
|
||||||
provider_specific_fields=prov,
|
function_provider_specific_fields=fn_prov,
|
||||||
function_provider_specific_fields=fn_prov,
|
))
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return LLMResponse(
|
return LLMResponse(
|
||||||
content=content,
|
content=content,
|
||||||
@ -470,17 +448,10 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
def _accum_tc(tc: Any, idx_hint: int) -> None:
|
def _accum_tc(tc: Any, idx_hint: int) -> None:
|
||||||
"""Accumulate one streaming tool-call delta into *tc_bufs*."""
|
"""Accumulate one streaming tool-call delta into *tc_bufs*."""
|
||||||
tc_index: int = _get(tc, "index") if _get(tc, "index") is not None else idx_hint
|
tc_index: int = _get(tc, "index") if _get(tc, "index") is not None else idx_hint
|
||||||
buf = tc_bufs.setdefault(
|
buf = tc_bufs.setdefault(tc_index, {
|
||||||
tc_index,
|
"id": "", "name": "", "arguments": "",
|
||||||
{
|
"extra_content": None, "prov": None, "fn_prov": None,
|
||||||
"id": "",
|
})
|
||||||
"name": "",
|
|
||||||
"arguments": "",
|
|
||||||
"extra_content": None,
|
|
||||||
"prov": None,
|
|
||||||
"fn_prov": None,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
tc_id = _get(tc, "id")
|
tc_id = _get(tc, "id")
|
||||||
if tc_id:
|
if tc_id:
|
||||||
buf["id"] = str(tc_id)
|
buf["id"] = str(tc_id)
|
||||||
@ -578,13 +549,8 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
kwargs = self._build_kwargs(
|
kwargs = self._build_kwargs(
|
||||||
messages,
|
messages, tools, model, max_tokens, temperature,
|
||||||
tools,
|
reasoning_effort, tool_choice,
|
||||||
model,
|
|
||||||
max_tokens,
|
|
||||||
temperature,
|
|
||||||
reasoning_effort,
|
|
||||||
tool_choice,
|
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
return self._parse(await self._client.chat.completions.create(**kwargs))
|
return self._parse(await self._client.chat.completions.create(**kwargs))
|
||||||
@ -603,13 +569,8 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
kwargs = self._build_kwargs(
|
kwargs = self._build_kwargs(
|
||||||
messages,
|
messages, tools, model, max_tokens, temperature,
|
||||||
tools,
|
reasoning_effort, tool_choice,
|
||||||
model,
|
|
||||||
max_tokens,
|
|
||||||
temperature,
|
|
||||||
reasoning_effort,
|
|
||||||
tool_choice,
|
|
||||||
)
|
)
|
||||||
kwargs["stream"] = True
|
kwargs["stream"] = True
|
||||||
kwargs["stream_options"] = {"include_usage": True}
|
kwargs["stream_options"] = {"include_usage": True}
|
||||||
@ -627,4 +588,4 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
return self._handle_error(e)
|
return self._handle_error(e)
|
||||||
|
|
||||||
def get_default_model(self) -> str:
|
def get_default_model(self) -> str:
|
||||||
return self.default_model
|
return self.default_model
|
||||||
Loading…
x
Reference in New Issue
Block a user