diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 98701c004..f52b17d14 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -97,6 +97,15 @@ class _LoopHook(AgentHook): logger.info("Tool call: {}({})", tc.name, args_str[:200]) self._loop._set_tool_context(self._channel, self._chat_id, self._message_id) + async def after_iteration(self, context: AgentHookContext) -> None: + u = context.usage or {} + logger.debug( + "LLM usage: prompt={} completion={} cached={}", + u.get("prompt_tokens", 0), + u.get("completion_tokens", 0), + u.get("cached_tokens", 0), + ) + def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: return self._loop._strip_think(content) diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index d6242a6b4..4fec539dd 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -60,7 +60,7 @@ class AgentRunner: messages = list(spec.initial_messages) final_content: str | None = None tools_used: list[str] = [] - usage = {"prompt_tokens": 0, "completion_tokens": 0} + usage: dict[str, int] = {} error: str | None = None stop_reason = "completed" tool_events: list[dict[str, str]] = [] @@ -92,13 +92,15 @@ class AgentRunner: response = await self.provider.chat_with_retry(**kwargs) raw_usage = response.usage or {} - usage = { - "prompt_tokens": int(raw_usage.get("prompt_tokens", 0) or 0), - "completion_tokens": int(raw_usage.get("completion_tokens", 0) or 0), - } context.response = response - context.usage = usage + context.usage = raw_usage context.tool_calls = list(response.tool_calls) + # Accumulate standard fields into result usage. + usage["prompt_tokens"] = usage.get("prompt_tokens", 0) + int(raw_usage.get("prompt_tokens", 0) or 0) + usage["completion_tokens"] = usage.get("completion_tokens", 0) + int(raw_usage.get("completion_tokens", 0) or 0) + cached = raw_usage.get("cached_tokens") + if cached: + usage["cached_tokens"] = usage.get("cached_tokens", 0) + int(cached) if response.has_tool_calls: if hook.wants_streaming(): diff --git a/nanobot/providers/anthropic_provider.py b/nanobot/providers/anthropic_provider.py index 3c789e730..fabcd5656 100644 --- a/nanobot/providers/anthropic_provider.py +++ b/nanobot/providers/anthropic_provider.py @@ -379,6 +379,10 @@ class AnthropicProvider(LLMProvider): val = getattr(response.usage, attr, 0) if val: usage[attr] = val + # Normalize to cached_tokens for downstream consistency. + cache_read = usage.get("cache_read_input_tokens", 0) + if cache_read: + usage["cached_tokens"] = cache_read return LLMResponse( content="".join(content_parts) or None, diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index 397b8e797..e9dd08645 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -308,6 +308,13 @@ class OpenAICompatProvider(LLMProvider): @classmethod def _extract_usage(cls, response: Any) -> dict[str, int]: + """Extract token usage from an OpenAI-compatible response. + + Handles both dict-based (raw JSON) and object-based (SDK Pydantic) + responses. Provider-specific ``cached_tokens`` fields are normalised + under a single key; see the priority chain inside for details. + """ + # --- resolve usage object --- usage_obj = None response_map = cls._maybe_mapping(response) if response_map is not None: @@ -317,19 +324,53 @@ class OpenAICompatProvider(LLMProvider): usage_map = cls._maybe_mapping(usage_obj) if usage_map is not None: - return { + result = { "prompt_tokens": int(usage_map.get("prompt_tokens") or 0), "completion_tokens": int(usage_map.get("completion_tokens") or 0), "total_tokens": int(usage_map.get("total_tokens") or 0), } - - if usage_obj: - return { + elif usage_obj: + result = { "prompt_tokens": getattr(usage_obj, "prompt_tokens", 0) or 0, "completion_tokens": getattr(usage_obj, "completion_tokens", 0) or 0, "total_tokens": getattr(usage_obj, "total_tokens", 0) or 0, } - return {} + else: + return {} + + # --- cached_tokens (normalised across providers) --- + # Try nested paths first (dict), fall back to attribute (SDK object). + # Priority order ensures the most specific field wins. + for path in ( + ("prompt_tokens_details", "cached_tokens"), # OpenAI/Zhipu/MiniMax/Qwen/Mistral/xAI + ("cached_tokens",), # StepFun/Moonshot (top-level) + ("prompt_cache_hit_tokens",), # DeepSeek/SiliconFlow + ): + cached = cls._get_nested_int(usage_map, path) + if not cached and usage_obj: + cached = cls._get_nested_int(usage_obj, path) + if cached: + result["cached_tokens"] = cached + break + + return result + + @staticmethod + def _get_nested_int(obj: Any, path: tuple[str, ...]) -> int: + """Drill into *obj* by *path* segments and return an ``int`` value. + + Supports both dict-key access and attribute access so it works + uniformly with raw JSON dicts **and** SDK Pydantic models. + """ + current = obj + for segment in path: + if current is None: + return 0 + if isinstance(current, dict): + current = current.get(segment) + else: + current = getattr(current, segment, None) + return int(current or 0) if current is not None else 0 def _parse(self, response: Any) -> LLMResponse: if isinstance(response, str): diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index a10a4f18b..64b8448ec 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ -255,14 +255,18 @@ def build_status_content( ) last_in = last_usage.get("prompt_tokens", 0) last_out = last_usage.get("completion_tokens", 0) + cached = last_usage.get("cached_tokens", 0) ctx_total = max(context_window_tokens, 0) ctx_pct = int((context_tokens_estimate / ctx_total) * 100) if ctx_total > 0 else 0 ctx_used_str = f"{context_tokens_estimate // 1000}k" if context_tokens_estimate >= 1000 else str(context_tokens_estimate) ctx_total_str = f"{ctx_total // 1024}k" if ctx_total > 0 else "n/a" + token_line = f"\U0001f4ca Tokens: {last_in} in / {last_out} out" + if cached and last_in: + token_line += f" ({cached * 100 // last_in}% cached)" return "\n".join([ f"\U0001f408 nanobot v{version}", f"\U0001f9e0 Model: {model}", - f"\U0001f4ca Tokens: {last_in} in / {last_out} out", + token_line, f"\U0001f4da Context: {ctx_used_str}/{ctx_total_str} ({ctx_pct}%)", f"\U0001f4ac Session: {session_msg_count} messages", f"\u23f1 Uptime: {uptime}", diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index 86b0ba710..98f1d73ae 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -333,3 +333,82 @@ async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, mon args = mgr._announce_result.await_args.args assert args[3] == "Task completed but no final response was generated." assert args[5] == "ok" + + +@pytest.mark.asyncio +async def test_runner_accumulates_usage_and_preserves_cached_tokens(): + """Runner should accumulate prompt/completion tokens across iterations + and preserve cached_tokens from provider responses.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content="thinking", + tool_calls=[ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})], + usage={"prompt_tokens": 100, "completion_tokens": 10, "cached_tokens": 80}, + ) + return LLMResponse( + content="done", + tool_calls=[], + usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="file content") + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "do task"}], + tools=tools, + model="test-model", + max_iterations=3, + )) + + # Usage should be accumulated across iterations + assert result.usage["prompt_tokens"] == 300 # 100 + 200 + assert result.usage["completion_tokens"] == 30 # 10 + 20 + assert result.usage["cached_tokens"] == 230 # 80 + 150 + + +@pytest.mark.asyncio +async def test_runner_passes_cached_tokens_to_hook_context(): + """Hook context.usage should contain cached_tokens.""" + from nanobot.agent.hook import AgentHook, AgentHookContext + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + captured_usage: list[dict] = [] + + class UsageHook(AgentHook): + async def after_iteration(self, context: AgentHookContext) -> None: + captured_usage.append(dict(context.usage)) + + async def chat_with_retry(**kwargs): + return LLMResponse( + content="done", + tool_calls=[], + usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + runner = AgentRunner(provider) + await runner.run(AgentRunSpec( + initial_messages=[], + tools=tools, + model="test-model", + max_iterations=1, + hook=UsageHook(), + )) + + assert len(captured_usage) == 1 + assert captured_usage[0]["cached_tokens"] == 150 diff --git a/tests/cli/test_restart_command.py b/tests/cli/test_restart_command.py index 3281afe2d..6efcdad0d 100644 --- a/tests/cli/test_restart_command.py +++ b/tests/cli/test_restart_command.py @@ -152,10 +152,12 @@ class TestRestartCommand: ]) await loop._run_agent_loop([]) - assert loop._last_usage == {"prompt_tokens": 9, "completion_tokens": 4} + assert loop._last_usage["prompt_tokens"] == 9 + assert loop._last_usage["completion_tokens"] == 4 await loop._run_agent_loop([]) - assert loop._last_usage == {"prompt_tokens": 0, "completion_tokens": 0} + assert loop._last_usage["prompt_tokens"] == 0 + assert loop._last_usage["completion_tokens"] == 0 @pytest.mark.asyncio async def test_status_falls_back_to_last_usage_when_context_estimate_missing(self): diff --git a/tests/providers/test_cached_tokens.py b/tests/providers/test_cached_tokens.py new file mode 100644 index 000000000..fce22cf65 --- /dev/null +++ b/tests/providers/test_cached_tokens.py @@ -0,0 +1,231 @@ +"""Tests for cached token extraction from OpenAI-compatible providers.""" + +from __future__ import annotations + +from nanobot.providers.openai_compat_provider import OpenAICompatProvider + + +class FakeUsage: + """Mimics an OpenAI SDK usage object (has attributes, not dict keys).""" + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +class FakePromptDetails: + """Mimics prompt_tokens_details sub-object.""" + def __init__(self, cached_tokens=0): + self.cached_tokens = cached_tokens + + +class _FakeSpec: + supports_prompt_caching = False + model_id_prefix = None + strip_model_prefix = False + max_completion_tokens = False + reasoning_effort = None + + +def _provider(): + from unittest.mock import MagicMock + p = OpenAICompatProvider.__new__(OpenAICompatProvider) + p.client = MagicMock() + p.spec = _FakeSpec() + return p + + +# Minimal valid choice so _parse reaches _extract_usage. +_DICT_CHOICE = {"message": {"content": "Hello"}} + +class _FakeMessage: + content = "Hello" + tool_calls = None + + +class _FakeChoice: + message = _FakeMessage() + finish_reason = "stop" + + +# --- dict-based response (raw JSON / mapping) --- + +def test_extract_usage_openai_cached_tokens_dict(): + """prompt_tokens_details.cached_tokens from a dict response.""" + p = _provider() + response = { + "choices": [_DICT_CHOICE], + "usage": { + "prompt_tokens": 2000, + "completion_tokens": 300, + "total_tokens": 2300, + "prompt_tokens_details": {"cached_tokens": 1200}, + } + } + result = p._parse(response) + assert result.usage["cached_tokens"] == 1200 + assert result.usage["prompt_tokens"] == 2000 + + +def test_extract_usage_deepseek_cached_tokens_dict(): + """prompt_cache_hit_tokens from a DeepSeek dict response.""" + p = _provider() + response = { + "choices": [_DICT_CHOICE], + "usage": { + "prompt_tokens": 1500, + "completion_tokens": 200, + "total_tokens": 1700, + "prompt_cache_hit_tokens": 1200, + "prompt_cache_miss_tokens": 300, + } + } + result = p._parse(response) + assert result.usage["cached_tokens"] == 1200 + + +def test_extract_usage_no_cached_tokens_dict(): + """Response without any cache fields -> no cached_tokens key.""" + p = _provider() + response = { + "choices": [_DICT_CHOICE], + "usage": { + "prompt_tokens": 1000, + "completion_tokens": 200, + "total_tokens": 1200, + } + } + result = p._parse(response) + assert "cached_tokens" not in result.usage + + +def test_extract_usage_openai_cached_zero_dict(): + """cached_tokens=0 should NOT be included (same as existing fields).""" + p = _provider() + response = { + "choices": [_DICT_CHOICE], + "usage": { + "prompt_tokens": 2000, + "completion_tokens": 300, + "total_tokens": 2300, + "prompt_tokens_details": {"cached_tokens": 0}, + } + } + result = p._parse(response) + assert "cached_tokens" not in result.usage + + +# --- object-based response (OpenAI SDK Pydantic model) --- + +def test_extract_usage_openai_cached_tokens_obj(): + """prompt_tokens_details.cached_tokens from an SDK object response.""" + p = _provider() + usage_obj = FakeUsage( + prompt_tokens=2000, + completion_tokens=300, + total_tokens=2300, + prompt_tokens_details=FakePromptDetails(cached_tokens=1200), + ) + response = FakeUsage(choices=[_FakeChoice()], usage=usage_obj) + result = p._parse(response) + assert result.usage["cached_tokens"] == 1200 + + +def test_extract_usage_deepseek_cached_tokens_obj(): + """prompt_cache_hit_tokens from a DeepSeek SDK object response.""" + p = _provider() + usage_obj = FakeUsage( + prompt_tokens=1500, + completion_tokens=200, + total_tokens=1700, + prompt_cache_hit_tokens=1200, + ) + response = FakeUsage(choices=[_FakeChoice()], usage=usage_obj) + result = p._parse(response) + assert result.usage["cached_tokens"] == 1200 + + +def test_extract_usage_stepfun_top_level_cached_tokens_dict(): + """StepFun/Moonshot: usage.cached_tokens at top level (not nested).""" + p = _provider() + response = { + "choices": [_DICT_CHOICE], + "usage": { + "prompt_tokens": 591, + "completion_tokens": 120, + "total_tokens": 711, + "cached_tokens": 512, + } + } + result = p._parse(response) + assert result.usage["cached_tokens"] == 512 + + +def test_extract_usage_stepfun_top_level_cached_tokens_obj(): + """StepFun/Moonshot: usage.cached_tokens as SDK object attribute.""" + p = _provider() + usage_obj = FakeUsage( + prompt_tokens=591, + completion_tokens=120, + total_tokens=711, + cached_tokens=512, + ) + response = FakeUsage(choices=[_FakeChoice()], usage=usage_obj) + result = p._parse(response) + assert result.usage["cached_tokens"] == 512 + + +def test_extract_usage_priority_nested_over_top_level_dict(): + """When both nested and top-level cached_tokens exist, nested wins.""" + p = _provider() + response = { + "choices": [_DICT_CHOICE], + "usage": { + "prompt_tokens": 2000, + "completion_tokens": 300, + "total_tokens": 2300, + "prompt_tokens_details": {"cached_tokens": 100}, + "cached_tokens": 500, + } + } + result = p._parse(response) + assert result.usage["cached_tokens"] == 100 + + +def test_anthropic_maps_cache_fields_to_cached_tokens(): + """Anthropic's cache_read_input_tokens should map to cached_tokens.""" + from nanobot.providers.anthropic_provider import AnthropicProvider + + usage_obj = FakeUsage( + input_tokens=800, + output_tokens=200, + cache_creation_input_tokens=0, + cache_read_input_tokens=1200, + ) + content_block = FakeUsage(type="text", text="hello") + response = FakeUsage( + id="msg_1", + type="message", + stop_reason="end_turn", + content=[content_block], + usage=usage_obj, + ) + result = AnthropicProvider._parse_response(response) + assert result.usage["cached_tokens"] == 1200 + assert result.usage["prompt_tokens"] == 800 + + +def test_anthropic_no_cache_fields(): + """Anthropic response without cache fields should not have cached_tokens.""" + from nanobot.providers.anthropic_provider import AnthropicProvider + + usage_obj = FakeUsage(input_tokens=800, output_tokens=200) + content_block = FakeUsage(type="text", text="hello") + response = FakeUsage( + id="msg_1", + type="message", + stop_reason="end_turn", + content=[content_block], + usage=usage_obj, + ) + result = AnthropicProvider._parse_response(response) + assert "cached_tokens" not in result.usage diff --git a/tests/test_build_status.py b/tests/test_build_status.py new file mode 100644 index 000000000..d98301cf7 --- /dev/null +++ b/tests/test_build_status.py @@ -0,0 +1,59 @@ +"""Tests for build_status_content cache hit rate display.""" + +from nanobot.utils.helpers import build_status_content + + +def test_status_shows_cache_hit_rate(): + content = build_status_content( + version="0.1.0", + model="glm-4-plus", + start_time=1000000.0, + last_usage={"prompt_tokens": 2000, "completion_tokens": 300, "cached_tokens": 1200}, + context_window_tokens=128000, + session_msg_count=10, + context_tokens_estimate=5000, + ) + assert "60% cached" in content + assert "2000 in / 300 out" in content + + +def test_status_no_cache_info(): + """Without cached_tokens, display should not show cache percentage.""" + content = build_status_content( + version="0.1.0", + model="glm-4-plus", + start_time=1000000.0, + last_usage={"prompt_tokens": 2000, "completion_tokens": 300}, + context_window_tokens=128000, + session_msg_count=10, + context_tokens_estimate=5000, + ) + assert "cached" not in content.lower() + assert "2000 in / 300 out" in content + + +def test_status_zero_cached_tokens(): + """cached_tokens=0 should not show cache percentage.""" + content = build_status_content( + version="0.1.0", + model="glm-4-plus", + start_time=1000000.0, + last_usage={"prompt_tokens": 2000, "completion_tokens": 300, "cached_tokens": 0}, + context_window_tokens=128000, + session_msg_count=10, + context_tokens_estimate=5000, + ) + assert "cached" not in content.lower() + + +def test_status_100_percent_cached(): + content = build_status_content( + version="0.1.0", + model="glm-4-plus", + start_time=1000000.0, + last_usage={"prompt_tokens": 1000, "completion_tokens": 100, "cached_tokens": 1000}, + context_window_tokens=128000, + session_msg_count=5, + context_tokens_estimate=3000, + ) + assert "100% cached" in content