refactor(providers): simplify cached_tokens extraction with _get_nested_int

Extract a _get_nested_int helper that unifies dict-key and attribute
access, then express the 3-tier provider fallback as a simple loop
over path tuples instead of duplicated if/else chains.
This commit is contained in:
chengyongru 2026-03-30 17:46:22 +08:00
parent 07f216b13f
commit 9c869d0bdf

View File

@ -308,6 +308,13 @@ class OpenAICompatProvider(LLMProvider):
@classmethod
def _extract_usage(cls, response: Any) -> dict[str, int]:
"""Extract token usage from an OpenAI-compatible response.
Handles both dict-based (raw JSON) and object-based (SDK Pydantic)
responses. Provider-specific ``cached_tokens`` fields are normalised
under a single key; see the priority chain inside for details.
"""
# --- resolve usage object ---
usage_obj = None
response_map = cls._maybe_mapping(response)
if response_map is not None:
@ -331,34 +338,40 @@ class OpenAICompatProvider(LLMProvider):
else:
return {}
# Extract cached_tokens from various provider formats.
# Priority: prompt_tokens_details > top-level cached_tokens > prompt_cache_hit_tokens
cached = 0
# 1. OpenAI / Zhipu / MiniMax / Qwen / SiliconFlow / 豆包 / Mistral / xAI:
# nested prompt_tokens_details.cached_tokens
details = (usage_map or {}).get("prompt_tokens_details") if usage_map else None
if not cls._maybe_mapping(details):
details = getattr(usage_obj, "prompt_tokens_details", None) if usage_obj else None
details_map = cls._maybe_mapping(details)
if details_map is not None:
cached = int(details_map.get("cached_tokens") or 0)
elif details is not None:
cached = int(getattr(details, "cached_tokens", 0) or 0)
# 2. StepFun / Moonshot: top-level usage.cached_tokens
if not cached and usage_map is not None:
cached = int(usage_map.get("cached_tokens") or 0)
if not cached and usage_obj and not usage_map:
cached = int(getattr(usage_obj, "cached_tokens", 0) or 0)
# 3. DeepSeek / SiliconFlow extra: top-level prompt_cache_hit_tokens
if not cached and usage_map is not None:
cached = int(usage_map.get("prompt_cache_hit_tokens") or 0)
if not cached and usage_obj and not usage_map:
cached = int(getattr(usage_obj, "prompt_cache_hit_tokens", 0) or 0)
if cached:
result["cached_tokens"] = cached
# --- cached_tokens (normalised across providers) ---
# Try nested paths first (dict), fall back to attribute (SDK object).
# Priority order ensures the most specific field wins.
for path in (
("prompt_tokens_details", "cached_tokens"), # OpenAI/Zhipu/MiniMax/Qwen/Mistral/xAI
("cached_tokens",), # StepFun/Moonshot (top-level)
("prompt_cache_hit_tokens",), # DeepSeek/SiliconFlow
):
cached = cls._get_nested_int(usage_map, path)
if not cached and usage_obj:
cached = cls._get_nested_int(usage_obj, path)
if cached:
result["cached_tokens"] = cached
break
return result
@staticmethod
def _get_nested_int(obj: Any, path: tuple[str, ...]) -> int:
    """Walk *obj* along *path* and return the leaf value as an ``int``.

    Each segment is resolved with ``dict.get`` when the current node is a
    ``dict`` and with ``getattr`` otherwise, so the same path works for
    plain dicts and for attribute-style objects.

    Args:
        obj: Root container to read from; may be ``None``.
        path: Key/attribute names to follow, outermost first.

    Returns:
        The leaf converted with ``int()``. ``0`` is returned when any node
        along the path is missing or ``None``, when the leaf is falsy, or
        when the leaf cannot be converted (for example a non-numeric
        string, a container, or a non-finite float).
    """
    current = obj
    for segment in path:
        if current is None:
            return 0
        if isinstance(current, dict):
            current = current.get(segment)
        else:
            current = getattr(current, segment, None)
    if not current:
        return 0
    try:
        return int(current)
    except (TypeError, ValueError, OverflowError):
        # An unconvertible leaf is treated like an absent one so that a
        # single malformed field does not raise out of the lookup.
        return 0
def _parse(self, response: Any) -> LLMResponse:
if isinstance(response, str):
return LLMResponse(content=response, finish_reason="stop")