From 9a2f38d7a2f6b6541b9410fa88bf4c1e7dac3a4f Mon Sep 17 00:00:00 2001 From: chengyongru Date: Mon, 30 Mar 2026 16:50:32 +0800 Subject: [PATCH] feat(agent): accumulate usage across iterations and pass through cached_tokens --- nanobot/agent/runner.py | 14 +++++-- tests/agent/test_runner.py | 79 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 3 deletions(-) diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index d6242a6b4..49f1d4487 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -60,7 +60,7 @@ class AgentRunner: messages = list(spec.initial_messages) final_content: str | None = None tools_used: list[str] = [] - usage = {"prompt_tokens": 0, "completion_tokens": 0} + usage = {"prompt_tokens": 0, "completion_tokens": 0, "cached_tokens": 0} error: str | None = None stop_reason = "completed" tool_events: list[dict[str, str]] = [] @@ -92,12 +92,20 @@ class AgentRunner: response = await self.provider.chat_with_retry(**kwargs) raw_usage = response.usage or {} - usage = { + iter_usage = { "prompt_tokens": int(raw_usage.get("prompt_tokens", 0) or 0), "completion_tokens": int(raw_usage.get("completion_tokens", 0) or 0), } + # Pass through cached_tokens if present. + cached = raw_usage.get("cached_tokens") + if cached: + iter_usage["cached_tokens"] = int(cached) + usage["prompt_tokens"] += iter_usage["prompt_tokens"] + usage["completion_tokens"] += iter_usage["completion_tokens"] + if "cached_tokens" in iter_usage: + usage["cached_tokens"] = usage.get("cached_tokens", 0) + iter_usage["cached_tokens"] context.response = response - context.usage = usage + context.usage = iter_usage context.tool_calls = list(response.tool_calls) if response.has_tool_calls: diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index 86b0ba710..98f1d73ae 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -333,3 +333,82 @@ async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, mon args = mgr._announce_result.await_args.args assert args[3] == "Task completed but no final response was generated." assert args[5] == "ok" + + +@pytest.mark.asyncio +async def test_runner_accumulates_usage_and_preserves_cached_tokens(): + """Runner should accumulate prompt/completion tokens across iterations + and preserve cached_tokens from provider responses.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content="thinking", + tool_calls=[ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})], + usage={"prompt_tokens": 100, "completion_tokens": 10, "cached_tokens": 80}, + ) + return LLMResponse( + content="done", + tool_calls=[], + usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="file content") + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "do task"}], + tools=tools, + model="test-model", + max_iterations=3, + )) + + # Usage should be accumulated across iterations + assert result.usage["prompt_tokens"] == 300 # 100 + 200 + assert result.usage["completion_tokens"] == 30 # 10 + 20 + assert result.usage["cached_tokens"] == 230 # 80 + 150 + + +@pytest.mark.asyncio +async def test_runner_passes_cached_tokens_to_hook_context(): + """Hook context.usage should contain cached_tokens.""" + from nanobot.agent.hook import AgentHook, AgentHookContext + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + captured_usage: list[dict] = [] + + class UsageHook(AgentHook): + async def after_iteration(self, context: AgentHookContext) -> None: + captured_usage.append(dict(context.usage)) + + async def chat_with_retry(**kwargs): + return LLMResponse( + content="done", + tool_calls=[], + usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + runner = AgentRunner(provider) + await runner.run(AgentRunSpec( + initial_messages=[], + tools=tools, + model="test-model", + max_iterations=1, + hook=UsageHook(), + )) + + assert len(captured_usage) == 1 + assert captured_usage[0]["cached_tokens"] == 150