mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-30 14:56:01 +00:00
feat(agent): accumulate usage across iterations and pass through cached_tokens
This commit is contained in:
parent
d02ba20971
commit
9a2f38d7a2
@ -60,7 +60,7 @@ class AgentRunner:
|
|||||||
messages = list(spec.initial_messages)
|
messages = list(spec.initial_messages)
|
||||||
final_content: str | None = None
|
final_content: str | None = None
|
||||||
tools_used: list[str] = []
|
tools_used: list[str] = []
|
||||||
usage = {"prompt_tokens": 0, "completion_tokens": 0}
|
usage = {"prompt_tokens": 0, "completion_tokens": 0, "cached_tokens": 0}
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
stop_reason = "completed"
|
stop_reason = "completed"
|
||||||
tool_events: list[dict[str, str]] = []
|
tool_events: list[dict[str, str]] = []
|
||||||
@ -92,12 +92,20 @@ class AgentRunner:
|
|||||||
response = await self.provider.chat_with_retry(**kwargs)
|
response = await self.provider.chat_with_retry(**kwargs)
|
||||||
|
|
||||||
raw_usage = response.usage or {}
|
raw_usage = response.usage or {}
|
||||||
usage = {
|
iter_usage = {
|
||||||
"prompt_tokens": int(raw_usage.get("prompt_tokens", 0) or 0),
|
"prompt_tokens": int(raw_usage.get("prompt_tokens", 0) or 0),
|
||||||
"completion_tokens": int(raw_usage.get("completion_tokens", 0) or 0),
|
"completion_tokens": int(raw_usage.get("completion_tokens", 0) or 0),
|
||||||
}
|
}
|
||||||
|
# Pass through cached_tokens if present.
|
||||||
|
cached = raw_usage.get("cached_tokens")
|
||||||
|
if cached:
|
||||||
|
iter_usage["cached_tokens"] = int(cached)
|
||||||
|
usage["prompt_tokens"] += iter_usage["prompt_tokens"]
|
||||||
|
usage["completion_tokens"] += iter_usage["completion_tokens"]
|
||||||
|
if "cached_tokens" in iter_usage:
|
||||||
|
usage["cached_tokens"] = usage.get("cached_tokens", 0) + iter_usage["cached_tokens"]
|
||||||
context.response = response
|
context.response = response
|
||||||
context.usage = usage
|
context.usage = iter_usage
|
||||||
context.tool_calls = list(response.tool_calls)
|
context.tool_calls = list(response.tool_calls)
|
||||||
|
|
||||||
if response.has_tool_calls:
|
if response.has_tool_calls:
|
||||||
|
|||||||
@ -333,3 +333,82 @@ async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, mon
|
|||||||
args = mgr._announce_result.await_args.args
|
args = mgr._announce_result.await_args.args
|
||||||
assert args[3] == "Task completed but no final response was generated."
|
assert args[3] == "Task completed but no final response was generated."
|
||||||
assert args[5] == "ok"
|
assert args[5] == "ok"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_accumulates_usage_and_preserves_cached_tokens():
    """Runner should accumulate prompt/completion tokens across iterations
    and preserve cached_tokens from provider responses."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    # Mutable cell so the closure below can count how many times the fake LLM ran.
    seen_calls = [0]

    async def chat_with_retry(*, messages, **kwargs):
        seen_calls[0] += 1
        if seen_calls[0] == 1:
            # First iteration: the model requests a tool and reports cached tokens.
            return LLMResponse(
                content="thinking",
                tool_calls=[ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})],
                usage={"prompt_tokens": 100, "completion_tokens": 10, "cached_tokens": 80},
            )
        # Any later iteration: final answer with no tool calls, ending the run.
        return LLMResponse(
            content="done",
            tool_calls=[],
            usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150},
        )

    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="file content")

    runner = AgentRunner(provider)
    result = await runner.run(AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tools,
        model="test-model",
        max_iterations=3,
    ))

    # Totals must be the sum over both iterations, cached_tokens included.
    assert result.usage["prompt_tokens"] == 300  # 100 + 200
    assert result.usage["completion_tokens"] == 30  # 10 + 20
    assert result.usage["cached_tokens"] == 230  # 80 + 150
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_passes_cached_tokens_to_hook_context():
    """Hook context.usage should contain cached_tokens."""
    from nanobot.agent.hook import AgentHook, AgentHookContext
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    # Snapshots of context.usage taken by the hook, one per iteration.
    captured_usage: list[dict] = []

    class UsageHook(AgentHook):
        async def after_iteration(self, context: AgentHookContext) -> None:
            # Copy so later mutation of context.usage cannot alter the record.
            captured_usage.append(dict(context.usage))

    async def chat_with_retry(**kwargs):
        # Single response with no tool calls: the run finishes in one iteration.
        return LLMResponse(
            content="done",
            tool_calls=[],
            usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150},
        )

    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []

    runner = AgentRunner(provider)
    await runner.run(AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=1,
        hook=UsageHook(),
    ))

    # Exactly one iteration ran, and its usage carried cached_tokens through.
    assert len(captured_usage) == 1
    assert captured_usage[0]["cached_tokens"] == 150
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user