From 05fe73947f219be405be57d9a27eb97e00fa4953 Mon Sep 17 00:00:00 2001 From: Tejas1Koli Date: Wed, 1 Apr 2026 00:51:49 +0530 Subject: [PATCH] fix(providers): only apply cache_control for Claude models on OpenRouter --- nanobot/providers/openai_compat_provider.py | 117 +++++++++++++------- 1 file changed, 79 insertions(+), 38 deletions(-) diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index 397b8e797..a033b44ef 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -18,10 +18,17 @@ from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest if TYPE_CHECKING: from nanobot.providers.registry import ProviderSpec -_ALLOWED_MSG_KEYS = frozenset({ - "role", "content", "tool_calls", "tool_call_id", "name", - "reasoning_content", "extra_content", -}) +_ALLOWED_MSG_KEYS = frozenset( + { + "role", + "content", + "tool_calls", + "tool_call_id", + "name", + "reasoning_content", + "extra_content", + } +) _ALNUM = string.ascii_letters + string.digits _STANDARD_TC_KEYS = frozenset({"id", "type", "index", "function"}) @@ -59,7 +66,9 @@ def _coerce_dict(value: Any) -> dict[str, Any] | None: return None -def _extract_tc_extras(tc: Any) -> tuple[ +def _extract_tc_extras( + tc: Any, +) -> tuple[ dict[str, Any] | None, dict[str, Any] | None, dict[str, Any] | None, @@ -75,14 +84,18 @@ def _extract_tc_extras(tc: Any) -> tuple[ prov = None fn_prov = None if tc_dict is not None: - leftover = {k: v for k, v in tc_dict.items() - if k not in _STANDARD_TC_KEYS and k != "extra_content" and v is not None} + leftover = { + k: v + for k, v in tc_dict.items() + if k not in _STANDARD_TC_KEYS and k != "extra_content" and v is not None + } if leftover: prov = leftover fn = _coerce_dict(tc_dict.get("function")) if fn is not None: - fn_leftover = {k: v for k, v in fn.items() - if k not in _STANDARD_FN_KEYS and v is not None} + fn_leftover = { + k: v for k, v in fn.items() if k not in _STANDARD_FN_KEYS and v is not None + } if fn_leftover: fn_prov = fn_leftover else: @@ -163,9 +176,12 @@ class OpenAICompatProvider(LLMProvider): def _mark(msg: dict[str, Any]) -> dict[str, Any]: content = msg.get("content") if isinstance(content, str): - return {**msg, "content": [ - {"type": "text", "text": content, "cache_control": cache_marker}, - ]} + return { + **msg, + "content": [ + {"type": "text", "text": content, "cache_control": cache_marker}, + ], + } if isinstance(content, list) and content: nc = list(content) nc[-1] = {**nc[-1], "cache_control": cache_marker} @@ -235,7 +251,9 @@ class OpenAICompatProvider(LLMProvider): spec = self._spec if spec and spec.supports_prompt_caching: - messages, tools = self._apply_cache_control(messages, tools) + model_name = model or self.default_model + if any(model_name.lower().startswith(k) for k in ("anthropic/", "claude")): + messages, tools = self._apply_cache_control(messages, tools) if spec and spec.strip_model_prefix: model_name = model_name.split("/")[-1] @@ -348,7 +366,9 @@ class OpenAICompatProvider(LLMProvider): finish_reason=str(response_map.get("finish_reason") or "stop"), usage=self._extract_usage(response_map), ) - return LLMResponse(content="Error: API returned empty choices.", finish_reason="error") + return LLMResponse( + content="Error: API returned empty choices.", finish_reason="error" + ) choice0 = self._maybe_mapping(choices[0]) or {} msg0 = self._maybe_mapping(choice0.get("message")) or {} @@ -378,14 +398,16 @@ class OpenAICompatProvider(LLMProvider): if isinstance(args, str): args = json_repair.loads(args) ec, prov, fn_prov = _extract_tc_extras(tc) - parsed_tool_calls.append(ToolCallRequest( - id=_short_tool_id(), - name=str(fn.get("name") or ""), - arguments=args if isinstance(args, dict) else {}, - extra_content=ec, - provider_specific_fields=prov, - function_provider_specific_fields=fn_prov, - )) + parsed_tool_calls.append( + ToolCallRequest( + id=_short_tool_id(), + name=str(fn.get("name") or ""), + arguments=args if isinstance(args, dict) else {}, + extra_content=ec, + provider_specific_fields=prov, + function_provider_specific_fields=fn_prov, + ) + ) return LLMResponse( content=content, @@ -419,14 +441,16 @@ class OpenAICompatProvider(LLMProvider): if isinstance(args, str): args = json_repair.loads(args) ec, prov, fn_prov = _extract_tc_extras(tc) - tool_calls.append(ToolCallRequest( - id=_short_tool_id(), - name=tc.function.name, - arguments=args, - extra_content=ec, - provider_specific_fields=prov, - function_provider_specific_fields=fn_prov, - )) + tool_calls.append( + ToolCallRequest( + id=_short_tool_id(), + name=tc.function.name, + arguments=args, + extra_content=ec, + provider_specific_fields=prov, + function_provider_specific_fields=fn_prov, + ) + ) return LLMResponse( content=content, @@ -446,10 +470,17 @@ class OpenAICompatProvider(LLMProvider): def _accum_tc(tc: Any, idx_hint: int) -> None: """Accumulate one streaming tool-call delta into *tc_bufs*.""" tc_index: int = _get(tc, "index") if _get(tc, "index") is not None else idx_hint - buf = tc_bufs.setdefault(tc_index, { - "id": "", "name": "", "arguments": "", - "extra_content": None, "prov": None, "fn_prov": None, - }) + buf = tc_bufs.setdefault( + tc_index, + { + "id": "", + "name": "", + "arguments": "", + "extra_content": None, + "prov": None, + "fn_prov": None, + }, + ) tc_id = _get(tc, "id") if tc_id: buf["id"] = str(tc_id) @@ -547,8 +578,13 @@ class OpenAICompatProvider(LLMProvider): tool_choice: str | dict[str, Any] | None = None, ) -> LLMResponse: kwargs = self._build_kwargs( - messages, tools, model, max_tokens, temperature, - reasoning_effort, tool_choice, + messages, + tools, + model, + max_tokens, + temperature, + reasoning_effort, + tool_choice, ) try: return self._parse(await self._client.chat.completions.create(**kwargs)) @@ -567,8 +603,13 @@ class OpenAICompatProvider(LLMProvider): on_content_delta: Callable[[str], Awaitable[None]] | None = None, ) -> LLMResponse: kwargs = self._build_kwargs( - messages, tools, model, max_tokens, temperature, - reasoning_effort, tool_choice, + messages, + tools, + model, + max_tokens, + temperature, + reasoning_effort, + tool_choice, ) kwargs["stream"] = True kwargs["stream_options"] = {"include_usage": True}