mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 17:32:39 +00:00
feat(provider): show cache hit rate in /status (#2645)
This commit is contained in:
parent
4741026538
commit
0334fa9944
@ -97,6 +97,15 @@ class _LoopHook(AgentHook):
|
||||
logger.info("Tool call: {}({})", tc.name, args_str[:200])
|
||||
self._loop._set_tool_context(self._channel, self._chat_id, self._message_id)
|
||||
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
    """Log the provider-reported token usage after each loop iteration."""
    usage = context.usage or {}
    prompt = usage.get("prompt_tokens", 0)
    completion = usage.get("completion_tokens", 0)
    cached = usage.get("cached_tokens", 0)
    logger.debug("LLM usage: prompt={} completion={} cached={}", prompt, completion, cached)
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
    """Delegate to the owning loop to strip think-tags from final content."""
    stripped = self._loop._strip_think(content)
    return stripped
|
||||
|
||||
|
||||
@ -60,7 +60,7 @@ class AgentRunner:
|
||||
messages = list(spec.initial_messages)
|
||||
final_content: str | None = None
|
||||
tools_used: list[str] = []
|
||||
usage = {"prompt_tokens": 0, "completion_tokens": 0}
|
||||
usage: dict[str, int] = {}
|
||||
error: str | None = None
|
||||
stop_reason = "completed"
|
||||
tool_events: list[dict[str, str]] = []
|
||||
@ -92,13 +92,15 @@ class AgentRunner:
|
||||
response = await self.provider.chat_with_retry(**kwargs)
|
||||
|
||||
raw_usage = response.usage or {}
|
||||
usage = {
|
||||
"prompt_tokens": int(raw_usage.get("prompt_tokens", 0) or 0),
|
||||
"completion_tokens": int(raw_usage.get("completion_tokens", 0) or 0),
|
||||
}
|
||||
context.response = response
|
||||
context.usage = usage
|
||||
context.usage = raw_usage
|
||||
context.tool_calls = list(response.tool_calls)
|
||||
# Accumulate standard fields into result usage.
|
||||
usage["prompt_tokens"] = usage.get("prompt_tokens", 0) + int(raw_usage.get("prompt_tokens", 0) or 0)
|
||||
usage["completion_tokens"] = usage.get("completion_tokens", 0) + int(raw_usage.get("completion_tokens", 0) or 0)
|
||||
cached = raw_usage.get("cached_tokens")
|
||||
if cached:
|
||||
usage["cached_tokens"] = usage.get("cached_tokens", 0) + int(cached)
|
||||
|
||||
if response.has_tool_calls:
|
||||
if hook.wants_streaming():
|
||||
|
||||
@ -379,6 +379,10 @@ class AnthropicProvider(LLMProvider):
|
||||
val = getattr(response.usage, attr, 0)
|
||||
if val:
|
||||
usage[attr] = val
|
||||
# Normalize to cached_tokens for downstream consistency.
|
||||
cache_read = usage.get("cache_read_input_tokens", 0)
|
||||
if cache_read:
|
||||
usage["cached_tokens"] = cache_read
|
||||
|
||||
return LLMResponse(
|
||||
content="".join(content_parts) or None,
|
||||
|
||||
@ -308,6 +308,13 @@ class OpenAICompatProvider(LLMProvider):
|
||||
|
||||
@classmethod
|
||||
def _extract_usage(cls, response: Any) -> dict[str, int]:
|
||||
"""Extract token usage from an OpenAI-compatible response.
|
||||
|
||||
Handles both dict-based (raw JSON) and object-based (SDK Pydantic)
|
||||
responses. Provider-specific ``cached_tokens`` fields are normalised
|
||||
under a single key; see the priority chain inside for details.
|
||||
"""
|
||||
# --- resolve usage object ---
|
||||
usage_obj = None
|
||||
response_map = cls._maybe_mapping(response)
|
||||
if response_map is not None:
|
||||
@ -317,19 +324,53 @@ class OpenAICompatProvider(LLMProvider):
|
||||
|
||||
usage_map = cls._maybe_mapping(usage_obj)
|
||||
if usage_map is not None:
|
||||
return {
|
||||
result = {
|
||||
"prompt_tokens": int(usage_map.get("prompt_tokens") or 0),
|
||||
"completion_tokens": int(usage_map.get("completion_tokens") or 0),
|
||||
"total_tokens": int(usage_map.get("total_tokens") or 0),
|
||||
}
|
||||
|
||||
if usage_obj:
|
||||
return {
|
||||
elif usage_obj:
|
||||
result = {
|
||||
"prompt_tokens": getattr(usage_obj, "prompt_tokens", 0) or 0,
|
||||
"completion_tokens": getattr(usage_obj, "completion_tokens", 0) or 0,
|
||||
"total_tokens": getattr(usage_obj, "total_tokens", 0) or 0,
|
||||
}
|
||||
return {}
|
||||
else:
|
||||
return {}
|
||||
|
||||
# --- cached_tokens (normalised across providers) ---
|
||||
# Try nested paths first (dict), fall back to attribute (SDK object).
|
||||
# Priority order ensures the most specific field wins.
|
||||
for path in (
|
||||
("prompt_tokens_details", "cached_tokens"), # OpenAI/Zhipu/MiniMax/Qwen/Mistral/xAI
|
||||
("cached_tokens",), # StepFun/Moonshot (top-level)
|
||||
("prompt_cache_hit_tokens",), # DeepSeek/SiliconFlow
|
||||
):
|
||||
cached = cls._get_nested_int(usage_map, path)
|
||||
if not cached and usage_obj:
|
||||
cached = cls._get_nested_int(usage_obj, path)
|
||||
if cached:
|
||||
result["cached_tokens"] = cached
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _get_nested_int(obj: Any, path: tuple[str, ...]) -> int:
|
||||
"""Drill into *obj* by *path* segments and return an ``int`` value.
|
||||
|
||||
Supports both dict-key access and attribute access so it works
|
||||
uniformly with raw JSON dicts **and** SDK Pydantic models.
|
||||
"""
|
||||
current = obj
|
||||
for segment in path:
|
||||
if current is None:
|
||||
return 0
|
||||
if isinstance(current, dict):
|
||||
current = current.get(segment)
|
||||
else:
|
||||
current = getattr(current, segment, None)
|
||||
return int(current or 0) if current is not None else 0
|
||||
|
||||
def _parse(self, response: Any) -> LLMResponse:
|
||||
if isinstance(response, str):
|
||||
|
||||
@ -255,14 +255,18 @@ def build_status_content(
|
||||
)
|
||||
last_in = last_usage.get("prompt_tokens", 0)
|
||||
last_out = last_usage.get("completion_tokens", 0)
|
||||
cached = last_usage.get("cached_tokens", 0)
|
||||
ctx_total = max(context_window_tokens, 0)
|
||||
ctx_pct = int((context_tokens_estimate / ctx_total) * 100) if ctx_total > 0 else 0
|
||||
ctx_used_str = f"{context_tokens_estimate // 1000}k" if context_tokens_estimate >= 1000 else str(context_tokens_estimate)
|
||||
ctx_total_str = f"{ctx_total // 1024}k" if ctx_total > 0 else "n/a"
|
||||
token_line = f"\U0001f4ca Tokens: {last_in} in / {last_out} out"
|
||||
if cached and last_in:
|
||||
token_line += f" ({cached * 100 // last_in}% cached)"
|
||||
return "\n".join([
|
||||
f"\U0001f408 nanobot v{version}",
|
||||
f"\U0001f9e0 Model: {model}",
|
||||
f"\U0001f4ca Tokens: {last_in} in / {last_out} out",
|
||||
token_line,
|
||||
f"\U0001f4da Context: {ctx_used_str}/{ctx_total_str} ({ctx_pct}%)",
|
||||
f"\U0001f4ac Session: {session_msg_count} messages",
|
||||
f"\u23f1 Uptime: {uptime}",
|
||||
|
||||
@ -333,3 +333,82 @@ async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, mon
|
||||
args = mgr._announce_result.await_args.args
|
||||
assert args[3] == "Task completed but no final response was generated."
|
||||
assert args[5] == "ok"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_accumulates_usage_and_preserves_cached_tokens():
    """Usage dicts from successive provider responses must be summed, and
    cached_tokens carried through into the final run result."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    calls = {"n": 0}

    tool_call_response = LLMResponse(
        content="thinking",
        tool_calls=[ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})],
        usage={"prompt_tokens": 100, "completion_tokens": 10, "cached_tokens": 80},
    )
    final_response = LLMResponse(
        content="done",
        tool_calls=[],
        usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150},
    )

    async def chat_with_retry(*, messages, **kwargs):
        # First iteration triggers a tool call; second finishes the run.
        calls["n"] += 1
        return tool_call_response if calls["n"] == 1 else final_response

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="file content")

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tools,
        model="test-model",
        max_iterations=3,
    )
    result = await AgentRunner(provider).run(spec)

    # Accumulated across both iterations: 100+200, 10+20, 80+150.
    assert result.usage["prompt_tokens"] == 300
    assert result.usage["completion_tokens"] == 30
    assert result.usage["cached_tokens"] == 230
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_passes_cached_tokens_to_hook_context():
    """Hook context.usage should contain cached_tokens."""
    from nanobot.agent.hook import AgentHook, AgentHookContext
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    seen: list[dict] = []

    class UsageHook(AgentHook):
        async def after_iteration(self, context: AgentHookContext) -> None:
            # Snapshot the usage dict so later mutation can't affect the assert.
            seen.append(dict(context.usage))

    async def chat_with_retry(**kwargs):
        return LLMResponse(
            content="done",
            tool_calls=[],
            usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150},
        )

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    spec = AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=1,
        hook=UsageHook(),
    )
    await AgentRunner(provider).run(spec)

    assert len(seen) == 1
    assert seen[0]["cached_tokens"] == 150
|
||||
|
||||
@ -152,10 +152,12 @@ class TestRestartCommand:
|
||||
])
|
||||
|
||||
await loop._run_agent_loop([])
|
||||
assert loop._last_usage == {"prompt_tokens": 9, "completion_tokens": 4}
|
||||
assert loop._last_usage["prompt_tokens"] == 9
|
||||
assert loop._last_usage["completion_tokens"] == 4
|
||||
|
||||
await loop._run_agent_loop([])
|
||||
assert loop._last_usage == {"prompt_tokens": 0, "completion_tokens": 0}
|
||||
assert loop._last_usage["prompt_tokens"] == 0
|
||||
assert loop._last_usage["completion_tokens"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_falls_back_to_last_usage_when_context_estimate_missing(self):
|
||||
|
||||
231
tests/providers/test_cached_tokens.py
Normal file
231
tests/providers/test_cached_tokens.py
Normal file
@ -0,0 +1,231 @@
|
||||
"""Tests for cached token extraction from OpenAI-compatible providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
|
||||
class FakeUsage:
    """Attribute-bag stand-in for an OpenAI SDK usage object.

    Exposes the given keyword arguments as plain attributes (no dict keys),
    mimicking a Pydantic model's attribute access.
    """

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
|
||||
|
||||
|
||||
class FakePromptDetails:
    """Stand-in for the ``prompt_tokens_details`` sub-object of a usage payload."""

    def __init__(self, cached_tokens=0):
        # Mirrors the single field the providers read from this sub-object.
        self.cached_tokens = cached_tokens
|
||||
|
||||
|
||||
class _FakeSpec:
    """Minimal provider spec stand-in: every optional feature switched off."""
    supports_prompt_caching = False
    model_id_prefix = None
    strip_model_prefix = False
    max_completion_tokens = False
    reasoning_effort = None
|
||||
|
||||
|
||||
def _provider():
    """Build an OpenAICompatProvider without running its real ``__init__``.

    The instance only needs ``client`` and ``spec`` for ``_parse`` to work.
    """
    from unittest.mock import MagicMock

    provider = OpenAICompatProvider.__new__(OpenAICompatProvider)
    provider.client = MagicMock()
    provider.spec = _FakeSpec()
    return provider
|
||||
|
||||
|
||||
# Minimal valid choice so _parse reaches _extract_usage.
_DICT_CHOICE = {"message": {"content": "Hello"}}


class _FakeMessage:
    """Bare SDK-style message: plain text content, no tool calls."""
    content = "Hello"
    tool_calls = None


class _FakeChoice:
    """Bare SDK-style choice wrapping a ``_FakeMessage``."""
    message = _FakeMessage()
    finish_reason = "stop"
|
||||
|
||||
|
||||
# --- dict-based response (raw JSON / mapping) ---

def _parse_dict_usage(usage):
    """Feed a dict-shaped response carrying *usage* through ``_parse``."""
    return _provider()._parse({"choices": [_DICT_CHOICE], "usage": usage})


def test_extract_usage_openai_cached_tokens_dict():
    """prompt_tokens_details.cached_tokens from a dict response."""
    result = _parse_dict_usage({
        "prompt_tokens": 2000,
        "completion_tokens": 300,
        "total_tokens": 2300,
        "prompt_tokens_details": {"cached_tokens": 1200},
    })
    assert result.usage["cached_tokens"] == 1200
    assert result.usage["prompt_tokens"] == 2000


def test_extract_usage_deepseek_cached_tokens_dict():
    """prompt_cache_hit_tokens from a DeepSeek dict response."""
    result = _parse_dict_usage({
        "prompt_tokens": 1500,
        "completion_tokens": 200,
        "total_tokens": 1700,
        "prompt_cache_hit_tokens": 1200,
        "prompt_cache_miss_tokens": 300,
    })
    assert result.usage["cached_tokens"] == 1200


def test_extract_usage_no_cached_tokens_dict():
    """Response without any cache fields -> no cached_tokens key."""
    result = _parse_dict_usage({
        "prompt_tokens": 1000,
        "completion_tokens": 200,
        "total_tokens": 1200,
    })
    assert "cached_tokens" not in result.usage


def test_extract_usage_openai_cached_zero_dict():
    """cached_tokens=0 should NOT be included (same as existing fields)."""
    result = _parse_dict_usage({
        "prompt_tokens": 2000,
        "completion_tokens": 300,
        "total_tokens": 2300,
        "prompt_tokens_details": {"cached_tokens": 0},
    })
    assert "cached_tokens" not in result.usage
|
||||
|
||||
|
||||
# --- object-based response (OpenAI SDK Pydantic model) ---

def _parse_sdk_usage(usage_obj):
    """Feed an SDK-object-shaped response carrying *usage_obj* through ``_parse``."""
    return _provider()._parse(FakeUsage(choices=[_FakeChoice()], usage=usage_obj))


def test_extract_usage_openai_cached_tokens_obj():
    """prompt_tokens_details.cached_tokens from an SDK object response."""
    result = _parse_sdk_usage(FakeUsage(
        prompt_tokens=2000,
        completion_tokens=300,
        total_tokens=2300,
        prompt_tokens_details=FakePromptDetails(cached_tokens=1200),
    ))
    assert result.usage["cached_tokens"] == 1200


def test_extract_usage_deepseek_cached_tokens_obj():
    """prompt_cache_hit_tokens from a DeepSeek SDK object response."""
    result = _parse_sdk_usage(FakeUsage(
        prompt_tokens=1500,
        completion_tokens=200,
        total_tokens=1700,
        prompt_cache_hit_tokens=1200,
    ))
    assert result.usage["cached_tokens"] == 1200
|
||||
|
||||
|
||||
def test_extract_usage_stepfun_top_level_cached_tokens_dict():
    """StepFun/Moonshot: usage.cached_tokens at top level (not nested)."""
    usage = {
        "prompt_tokens": 591,
        "completion_tokens": 120,
        "total_tokens": 711,
        "cached_tokens": 512,
    }
    parsed = _provider()._parse({"choices": [_DICT_CHOICE], "usage": usage})
    assert parsed.usage["cached_tokens"] == 512


def test_extract_usage_stepfun_top_level_cached_tokens_obj():
    """StepFun/Moonshot: usage.cached_tokens as SDK object attribute."""
    usage_obj = FakeUsage(
        prompt_tokens=591,
        completion_tokens=120,
        total_tokens=711,
        cached_tokens=512,
    )
    parsed = _provider()._parse(FakeUsage(choices=[_FakeChoice()], usage=usage_obj))
    assert parsed.usage["cached_tokens"] == 512


def test_extract_usage_priority_nested_over_top_level_dict():
    """When both nested and top-level cached_tokens exist, nested wins."""
    usage = {
        "prompt_tokens": 2000,
        "completion_tokens": 300,
        "total_tokens": 2300,
        "prompt_tokens_details": {"cached_tokens": 100},
        "cached_tokens": 500,
    }
    parsed = _provider()._parse({"choices": [_DICT_CHOICE], "usage": usage})
    assert parsed.usage["cached_tokens"] == 100
|
||||
|
||||
|
||||
def test_anthropic_maps_cache_fields_to_cached_tokens():
    """Anthropic's cache_read_input_tokens should map to cached_tokens."""
    from nanobot.providers.anthropic_provider import AnthropicProvider

    message = FakeUsage(
        id="msg_1",
        type="message",
        stop_reason="end_turn",
        content=[FakeUsage(type="text", text="hello")],
        usage=FakeUsage(
            input_tokens=800,
            output_tokens=200,
            cache_creation_input_tokens=0,
            cache_read_input_tokens=1200,
        ),
    )
    parsed = AnthropicProvider._parse_response(message)
    assert parsed.usage["cached_tokens"] == 1200
    assert parsed.usage["prompt_tokens"] == 800


def test_anthropic_no_cache_fields():
    """Anthropic response without cache fields should not have cached_tokens."""
    from nanobot.providers.anthropic_provider import AnthropicProvider

    message = FakeUsage(
        id="msg_1",
        type="message",
        stop_reason="end_turn",
        content=[FakeUsage(type="text", text="hello")],
        usage=FakeUsage(input_tokens=800, output_tokens=200),
    )
    parsed = AnthropicProvider._parse_response(message)
    assert "cached_tokens" not in parsed.usage
|
||||
59
tests/test_build_status.py
Normal file
59
tests/test_build_status.py
Normal file
@ -0,0 +1,59 @@
|
||||
"""Tests for build_status_content cache hit rate display."""
|
||||
|
||||
from nanobot.utils.helpers import build_status_content
|
||||
|
||||
|
||||
def _status_for(last_usage, **overrides):
    """Call build_status_content with fixed defaults; tests vary only usage/counts."""
    kwargs = dict(
        version="0.1.0",
        model="glm-4-plus",
        start_time=1000000.0,
        last_usage=last_usage,
        context_window_tokens=128000,
        session_msg_count=10,
        context_tokens_estimate=5000,
    )
    kwargs.update(overrides)
    return build_status_content(**kwargs)


def test_status_shows_cache_hit_rate():
    """1200 of 2000 prompt tokens cached -> 60% shown in the token line."""
    content = _status_for({"prompt_tokens": 2000, "completion_tokens": 300, "cached_tokens": 1200})
    assert "60% cached" in content
    assert "2000 in / 300 out" in content


def test_status_no_cache_info():
    """Without cached_tokens, display should not show cache percentage."""
    content = _status_for({"prompt_tokens": 2000, "completion_tokens": 300})
    assert "cached" not in content.lower()
    assert "2000 in / 300 out" in content


def test_status_zero_cached_tokens():
    """cached_tokens=0 should not show cache percentage."""
    content = _status_for({"prompt_tokens": 2000, "completion_tokens": 300, "cached_tokens": 0})
    assert "cached" not in content.lower()


def test_status_100_percent_cached():
    """All prompt tokens cached -> 100% shown."""
    content = _status_for(
        {"prompt_tokens": 1000, "completion_tokens": 100, "cached_tokens": 1000},
        session_msg_count=5,
        context_tokens_estimate=3000,
    )
    assert "100% cached" in content
|
||||
Loading…
x
Reference in New Issue
Block a user