diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 555c6b9f5..4f4c68ffe 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -257,6 +257,15 @@ class AgentLoop: def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: return loop_self._strip_think(content) + async def after_iteration(self, context: AgentHookContext) -> None: + u = context.usage or {} + logger.debug( + "LLM usage: prompt={} completion={} cached={}", + u.get("prompt_tokens", 0), + u.get("completion_tokens", 0), + u.get("cached_tokens", 0), + ) + result = await self.runner.run(AgentRunSpec( initial_messages=initial_messages, tools=self.tools, diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 819a26a99..4fec539dd 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -6,8 +6,6 @@ import asyncio from dataclasses import dataclass, field from typing import Any -from loguru import logger - from nanobot.agent.hook import AgentHook, AgentHookContext from nanobot.agent.tools.registry import ToolRegistry from nanobot.providers.base import LLMProvider, ToolCallRequest @@ -62,7 +60,7 @@ class AgentRunner: messages = list(spec.initial_messages) final_content: str | None = None tools_used: list[str] = [] - usage = {"prompt_tokens": 0, "completion_tokens": 0, "cached_tokens": 0} + usage: dict[str, int] = {} error: str | None = None stop_reason = "completed" tool_events: list[dict[str, str]] = [] @@ -94,27 +92,15 @@ class AgentRunner: response = await self.provider.chat_with_retry(**kwargs) raw_usage = response.usage or {} - iter_usage = { - "prompt_tokens": int(raw_usage.get("prompt_tokens", 0) or 0), - "completion_tokens": int(raw_usage.get("completion_tokens", 0) or 0), - } - # Pass through cached_tokens if present. + context.response = response + context.usage = raw_usage + context.tool_calls = list(response.tool_calls) + # Accumulate standard fields into result usage. + usage["prompt_tokens"] = usage.get("prompt_tokens", 0) + int(raw_usage.get("prompt_tokens", 0) or 0) + usage["completion_tokens"] = usage.get("completion_tokens", 0) + int(raw_usage.get("completion_tokens", 0) or 0) cached = raw_usage.get("cached_tokens") if cached: - iter_usage["cached_tokens"] = int(cached) - usage["prompt_tokens"] += iter_usage["prompt_tokens"] - usage["completion_tokens"] += iter_usage["completion_tokens"] - if "cached_tokens" in iter_usage: - usage["cached_tokens"] = usage.get("cached_tokens", 0) + iter_usage["cached_tokens"] - context.response = response - context.usage = iter_usage - logger.debug( - "LLM usage: prompt={} completion={} cached={}", - iter_usage["prompt_tokens"], - iter_usage["completion_tokens"], - iter_usage.get("cached_tokens", 0), - ) - context.tool_calls = list(response.tool_calls) + usage["cached_tokens"] = usage.get("cached_tokens", 0) + int(cached) if response.has_tool_calls: if hook.wants_streaming():