diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py
index 412044d64..e9dd08645 100644
--- a/nanobot/providers/openai_compat_provider.py
+++ b/nanobot/providers/openai_compat_provider.py
@@ -308,6 +308,13 @@ class OpenAICompatProvider(LLMProvider):
 
     @classmethod
     def _extract_usage(cls, response: Any) -> dict[str, int]:
+        """Extract token usage from an OpenAI-compatible response.
+
+        Handles both dict-based (raw JSON) and object-based (SDK Pydantic)
+        responses. Provider-specific ``cached_tokens`` fields are normalised
+        under a single key; see the priority chain inside for details.
+        """
+        # --- resolve usage object ---
         usage_obj = None
         response_map = cls._maybe_mapping(response)
         if response_map is not None:
@@ -331,34 +338,40 @@ class OpenAICompatProvider(LLMProvider):
         else:
             return {}
 
-        # Extract cached_tokens from various provider formats.
-        # Priority: prompt_tokens_details > top-level cached_tokens > prompt_cache_hit_tokens
-        cached = 0
-        # 1. OpenAI / Zhipu / MiniMax / Qwen / SiliconFlow / 豆包 / Mistral / xAI:
-        #    nested prompt_tokens_details.cached_tokens
-        details = (usage_map or {}).get("prompt_tokens_details") if usage_map else None
-        if not cls._maybe_mapping(details):
-            details = getattr(usage_obj, "prompt_tokens_details", None) if usage_obj else None
-        details_map = cls._maybe_mapping(details)
-        if details_map is not None:
-            cached = int(details_map.get("cached_tokens") or 0)
-        elif details is not None:
-            cached = int(getattr(details, "cached_tokens", 0) or 0)
-        # 2. StepFun / Moonshot: top-level usage.cached_tokens
-        if not cached and usage_map is not None:
-            cached = int(usage_map.get("cached_tokens") or 0)
-        if not cached and usage_obj and not usage_map:
-            cached = int(getattr(usage_obj, "cached_tokens", 0) or 0)
-        # 3. DeepSeek / SiliconFlow extra: top-level prompt_cache_hit_tokens
-        if not cached and usage_map is not None:
-            cached = int(usage_map.get("prompt_cache_hit_tokens") or 0)
-        if not cached and usage_obj and not usage_map:
-            cached = int(getattr(usage_obj, "prompt_cache_hit_tokens", 0) or 0)
-        if cached:
-            result["cached_tokens"] = cached
+        # --- cached_tokens (normalised across providers) ---
+        # Try nested paths first (dict), fall back to attribute (SDK object).
+        # Priority order ensures the most specific field wins.
+        for path in (
+            ("prompt_tokens_details", "cached_tokens"),  # OpenAI/Zhipu/MiniMax/Qwen/Mistral/xAI
+            ("cached_tokens",),  # StepFun/Moonshot (top-level)
+            ("prompt_cache_hit_tokens",),  # DeepSeek/SiliconFlow
+        ):
+            cached = cls._get_nested_int(usage_map, path)
+            if not cached and usage_obj:
+                cached = cls._get_nested_int(usage_obj, path)
+            if cached:
+                result["cached_tokens"] = cached
+                break
 
         return result
 
+    @staticmethod
+    def _get_nested_int(obj: Any, path: tuple[str, ...]) -> int:
+        """Drill into *obj* by *path* segments and return an ``int`` value.
+
+        Supports both dict-key access and attribute access so it works
+        uniformly with raw JSON dicts **and** SDK Pydantic models.
+        """
+        current = obj
+        for segment in path:
+            if current is None:
+                return 0
+            if isinstance(current, dict):
+                current = current.get(segment)
+            else:
+                current = getattr(current, segment, None)
+        return int(current or 0) if current is not None else 0
+
     def _parse(self, response: Any) -> LLMResponse:
         if isinstance(response, str):
            return LLMResponse(content=response, finish_reason="stop")