mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 17:32:39 +00:00
feat: configurable context budget for tool-loop iterations (#2317)
* feat: add contextBudgetTokens config field for tool-loop trimming
* feat: implement _trim_history_for_budget for tool-loop cost reduction
* feat: thread contextBudgetTokens into AgentLoop constructor
* feat: wire context budget trimming into agent loop
* refactor: move trim_history_for_budget to helpers and add docs
  - Extract trim_history_for_budget() as a pure function in helpers.py
  - AgentLoop._trim_history_for_budget becomes a thin wrapper
  - Add docs/CONTEXT_BUDGET.md with usage guide and trade-off notes
  - Replace wrapper tests with direct helper unit tests

---------

Co-authored-by: chengyongru <chengyongru.ai@gmail.com>
This commit is contained in:
parent
0182ce2852
commit
528b3cfe5a
62
docs/CONTEXT_BUDGET.md
Normal file
@@ -0,0 +1,62 @@

# Context Budget (`context_budget_tokens`)

Caps how many tokens of old session history are sent to the LLM during tool-loop iterations 2+. Reduces cost and first-token latency by trimming history between turns.

## How It Works

During multi-turn tool-use sessions, each iteration re-sends the full conversation history. `context_budget_tokens` limits how many old tokens are included:

- **Iteration 1** — always receives full context (no trimming)
- **Iteration 2+** — old history is trimmed to fit within the budget; the current turn is never trimmed
- **Memory consolidation** — runs before/after the loop and always sees the full canonical history; trimming only affects the LLM's view
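
The gating rules above can be sketched as a small predicate. This is a standalone illustration, not part of the nanobot API; the function name is hypothetical:

```python
def should_trim(iteration: int, context_budget_tokens: int, turn_start_index: int) -> bool:
    """Mirror of the early-exit checks described above (hypothetical helper)."""
    if context_budget_tokens <= 0:   # 0 disables the feature entirely
        return False
    if iteration <= 1:               # iteration 1 always receives full context
        return False
    if turn_start_index <= 1:        # no old history exists before the current turn
        return False
    return True

print(should_trim(1, 1000, 5))  # False: first iteration is never trimmed
print(should_trim(2, 1000, 5))  # True: trimming may apply
```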

## Configuration

```json
{
  "agents": {
    "defaults": {
      "context_budget_tokens": 1000
    }
  }
}
```
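
Positive values below 500 are clamped up to a 500-token floor, while `0` keeps trimming off entirely. A minimal sketch of that rule (standalone, not the nanobot API; the function name is hypothetical):

```python
def effective_budget(context_budget_tokens: int) -> int:
    """0 (or negative) disables trimming; any positive value gets a 500-token floor."""
    if context_budget_tokens <= 0:
        return 0
    return max(context_budget_tokens, 500)

print(effective_budget(0))     # 0 (trimming disabled)
print(effective_budget(200))   # 500 (clamped up to the floor)
print(effective_budget(4000))  # 4000 (used as-is)
```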

| Value | Behavior |
|---|---|
| `0` (default) | No trimming — full history sent every iteration |
| `4000` | Conservative — barely trims in practice; good for multi-step tasks |
| `1000` | Aggressive — significant savings; works well for typical linear tasks |
| `< 500` | Clamped to `500` minimum when positive (1–2 message pairs at typical token density) |

---
## Trade-offs

**Cost & latency** — Trimming reduces tokens sent each iteration, which saves money and lowers time to first token (TTFT). This is nanobot's primary sweet spot.

**Context loss** — Older context is not visible to the LLM in later iterations. For tasks that genuinely require 20+ iterations of history to stay coherent, consider `0` or `4000`.

**Tool-result truncation** — Large results from a previous turn (e.g., reading a 10,000-line file in Round 1, then editing in Round 2) can be trimmed. The agent can re-read the file via its tools — this is a 1-tool-call recovery cost, not a failure.

**Prefix caching** — Some providers (e.g., DeepSeek) use implicit prefix-based caching. Aggressive trimming breaks prefix matching and can reduce cache hit rates. For these providers, `0` or a high value may be more cost-effective overall.
## When to Use

| Use case | Recommended value |
|---|---|
| Simple read → process → act chains | `1000` |
| Multi-step reasoning with tool chains | `4000` |
| Complex debugging / long task traces | `0` |
| Providers with implicit prefix caching | `0` or `4000` |
| Long file operations across turns | `0` or re-read via tools |
## Example

```
Turn 1: User asks to read a.py (10k lines)
Turn 2: User asks to edit line 100
```

With `context_budget_tokens=500`, the file-content result from Turn 1 may be trimmed before Turn 2. The agent will re-read the file to perform the edit — a 1-call recovery. This is normal behavior for the feature; it is not a bug.
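
The trim decision rests on a cheap character-based token estimate (roughly 4 characters per token with a small floor, as in the repo's `estimate_message_tokens`). A standalone sketch of that estimate, showing why a large Turn-1 result blows a 500-token budget:

```python
def rough_token_estimate(payload: str) -> int:
    # ~4 chars per token plus a small constant, floored at 4, mirroring
    # the estimate used in nanobot.utils.helpers.estimate_message_tokens.
    return max(4, len(payload) // 4 + 4)

# A 2,000-character tool result alone exceeds a 500-token budget:
print(rough_token_estimate("x" * 2000))  # 504
print(rough_token_estimate(""))          # 4 (floor)
```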

@@ -27,7 +27,7 @@ from nanobot.agent.tools.shell import ExecTool
 from nanobot.agent.tools.spawn import SpawnTool
 from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
 from nanobot.bus.events import InboundMessage, OutboundMessage
-from nanobot.utils.helpers import build_status_content
+from nanobot.utils.helpers import build_status_content, trim_history_for_budget
 from nanobot.bus.queue import MessageBus
 from nanobot.providers.base import LLMProvider
 from nanobot.session.manager import Session, SessionManager
@@ -59,6 +59,7 @@ class AgentLoop:
         model: str | None = None,
         max_iterations: int = 40,
         context_window_tokens: int = 65_536,
+        context_budget_tokens: int = 0,
         web_search_config: WebSearchConfig | None = None,
         web_proxy: str | None = None,
         exec_config: ExecToolConfig | None = None,
@@ -78,6 +79,7 @@ class AgentLoop:
         self.model = model or provider.get_default_model()
         self.max_iterations = max_iterations
         self.context_window_tokens = context_window_tokens
+        self.context_budget_tokens = max(context_budget_tokens, 500) if context_budget_tokens > 0 else 0
         self.web_search_config = web_search_config or WebSearchConfig()
         self.web_proxy = web_proxy
         self.exec_config = exec_config or ExecToolConfig()
@@ -232,6 +234,22 @@ class AgentLoop:
             metadata={"render_as": "text"},
         )

+    def _trim_history_for_budget(
+        self,
+        messages: list[dict],
+        turn_start_index: int,
+        iteration: int,
+    ) -> list[dict]:
+        """Thin wrapper: delegates to trim_history_for_budget helper."""
+        return trim_history_for_budget(
+            messages,
+            turn_start_index,
+            iteration,
+            self.context_budget_tokens,
+            Session._find_legal_start,
+        )
+
     async def _run_agent_loop(
         self,
         initial_messages: list[dict],
@@ -250,6 +268,7 @@ class AgentLoop:
         iteration = 0
         final_content = None
         tools_used: list[str] = []
+        turn_start_index = len(initial_messages) - 1

         # Wrap on_stream with stateful think-tag filter so downstream
         # consumers (CLI, channels) never see <think> blocks.
@@ -271,20 +290,23 @@ class AgentLoop:

             tool_defs = self.tools.get_definitions()

+            send_messages = self._trim_history_for_budget(
+                messages, turn_start_index, iteration,
+            )
+
             if on_stream:
                 response = await self.provider.chat_stream_with_retry(
-                    messages=messages,
+                    messages=send_messages,
                     tools=tool_defs,
                     model=self.model,
                     on_content_delta=_filtered_stream,
                 )
             else:
                 response = await self.provider.chat_with_retry(
-                    messages=messages,
+                    messages=send_messages,
                     tools=tool_defs,
                     model=self.model,
                 )

             usage = response.usage or {}
             self._last_usage = {
                 "prompt_tokens": int(usage.get("prompt_tokens", 0) or 0),
@@ -527,6 +527,7 @@ def gateway(
         model=config.agents.defaults.model,
         max_iterations=config.agents.defaults.max_tool_iterations,
         context_window_tokens=config.agents.defaults.context_window_tokens,
+        context_budget_tokens=config.agents.defaults.context_budget_tokens,
         web_search_config=config.tools.web.search,
         web_proxy=config.tools.web.proxy or None,
         exec_config=config.tools.exec,
@@ -722,6 +723,7 @@ def agent(
         model=config.agents.defaults.model,
         max_iterations=config.agents.defaults.max_tool_iterations,
         context_window_tokens=config.agents.defaults.context_window_tokens,
+        context_budget_tokens=config.agents.defaults.context_budget_tokens,
         web_search_config=config.tools.web.search,
         web_proxy=config.tools.web.proxy or None,
         exec_config=config.tools.exec,
@@ -39,6 +39,7 @@ class AgentDefaults(Base):
     context_window_tokens: int = 65_536
     temperature: float = 0.1
     max_tool_iterations: int = 40
+    context_budget_tokens: int = 0  # Max old-history tokens during tool iterations (0 = no trim)
     reasoning_effort: str | None = None  # low / medium / high — enables LLM thinking mode
@@ -6,9 +6,10 @@ import re
 import time
 from datetime import datetime
 from pathlib import Path
-from typing import Any
+from typing import Any, Callable

 import tiktoken
 from loguru import logger


 def strip_think(text: str) -> str:
@@ -201,6 +202,58 @@ def estimate_message_tokens(message: dict[str, Any]) -> int:
     return max(4, len(payload) // 4 + 4)


+def trim_history_for_budget(
+    messages: list[dict[str, Any]],
+    turn_start_index: int,
+    iteration: int,
+    context_budget_tokens: int,
+    find_legal_start: Callable[[list[dict[str, Any]]], int],
+) -> list[dict[str, Any]]:
+    """Trim old session history to fit within context_budget_tokens.
+
+    Returns the original list unchanged when no trimming is needed.
+    Only trims on iteration >= 2 when context_budget_tokens > 0.
+    Current-turn messages (from turn_start_index onward) are never trimmed.
+    """
+    if context_budget_tokens <= 0 or iteration <= 1:
+        return messages
+    if turn_start_index <= 1:
+        return messages  # no old history to trim
+
+    system = messages[:1]
+    old_history = messages[1:turn_start_index]
+    current_turn = messages[turn_start_index:]
+
+    # Pre-compute token counts to avoid double-estimation
+    token_counts = [estimate_message_tokens(m) for m in old_history]
+    total = sum(token_counts)
+    if total <= context_budget_tokens:
+        return messages  # fits, no trim needed
+
+    # Find cut index (O(n) scan, then single slice)
+    cut = 0
+    removed_tokens = 0
+    while cut < len(old_history) and total > context_budget_tokens:
+        removed_tokens += token_counts[cut]
+        total -= token_counts[cut]
+        cut += 1
+    old_history = old_history[cut:]
+
+    # Fix orphaned tool results after trimming
+    legal_start = find_legal_start(old_history)
+    if legal_start > 0:
+        old_history = old_history[legal_start:]
+
+    removed_count = turn_start_index - 1 - len(old_history)
+    if removed_count > 0:
+        logger.debug(
+            "Context budget: trimmed {} history messages ({} tokens) for iteration {}",
+            removed_count, removed_tokens, iteration,
+        )
+
+    return system + old_history + current_turn
+
+
 def estimate_prompt_tokens_chain(
     provider: Any,
     model: str | None,
174
tests/test_trim_history_for_budget.py
Normal file
@@ -0,0 +1,174 @@
"""Direct unit tests for trim_history_for_budget() helper."""

import pytest
from nanobot.session.manager import Session
from nanobot.utils.helpers import estimate_message_tokens, trim_history_for_budget


def _msg(role: str, content: str, **kw) -> dict:
    return {"role": role, "content": content, **kw}


def _system(content: str = "You are a bot.") -> dict:
    return _msg("system", content)


def _user(content: str) -> dict:
    return _msg("user", content)


def _assistant(content: str | None = None, tool_calls: list | None = None) -> dict:
    m = {"role": "assistant", "content": content}
    if tool_calls:
        m["tool_calls"] = tool_calls
    return m


def _tool_call(tc_id: str, name: str = "exec", args: str = "{}") -> dict:
    return {"id": tc_id, "type": "function", "function": {"name": name, "arguments": args}}


def _tool_result(tc_id: str, content: str = "ok") -> dict:
    return {"role": "tool", "tool_call_id": tc_id, "name": "exec", "content": content}


# --- Early-exit cases ---

def test_budget_zero_returns_same_list():
    msgs = [_system(), _user("old1"), _assistant("old reply"), _user("current")]
    result = trim_history_for_budget(msgs, turn_start_index=3, iteration=2, context_budget_tokens=0, find_legal_start=Session._find_legal_start)
    assert result is msgs


def test_iteration_one_never_trims():
    msgs = [_system(), _user("old1"), _assistant("old reply"), _user("current")]
    result = trim_history_for_budget(msgs, turn_start_index=3, iteration=1, context_budget_tokens=1000, find_legal_start=Session._find_legal_start)
    assert result is msgs


def test_turn_start_at_one_returns_same():
    """turn_start_index=1 means no old history before the current turn."""
    msgs = [_system(), _user("current")]
    result = trim_history_for_budget(msgs, turn_start_index=1, iteration=2, context_budget_tokens=1000, find_legal_start=Session._find_legal_start)
    assert result is msgs


def test_history_under_budget_returns_unchanged():
    msgs = [_system(), _user("short msg"), _assistant("short reply"), _user("current")]
    result = trim_history_for_budget(msgs, turn_start_index=3, iteration=2, context_budget_tokens=50000, find_legal_start=Session._find_legal_start)
    assert result is msgs


# --- Trimming cases ---

def test_trim_removes_oldest_messages():
    old_msgs = []
    for i in range(40):
        old_msgs.append(_user(f"old message number {i} padding extra text here"))
        old_msgs.append(_assistant(f"reply to message {i} with more padding"))

    current_user = _user("current task")
    current_tc = _assistant(None, [_tool_call("tc1")])
    current_result = _tool_result("tc1", "done")

    msgs = [_system()] + old_msgs + [current_user, current_tc, current_result]
    turn_start = 1 + len(old_msgs)

    result = trim_history_for_budget(msgs, turn_start, iteration=2, context_budget_tokens=500, find_legal_start=Session._find_legal_start)

    # System and current turn preserved
    assert result[0] == msgs[0]
    assert result[-3:] == [current_user, current_tc, current_result]
    # Old history trimmed
    trimmed_history = result[1:-3]
    assert len(trimmed_history) < len(old_msgs)
    # Token budget respected
    trimmed_tokens = sum(estimate_message_tokens(m) for m in trimmed_history)
    assert trimmed_tokens <= 500


def test_trim_preserves_tool_call_boundary():
    """Trimming must not leave orphaned tool results."""
    old = [
        _user("padding " * 200),
        _assistant(None, [_tool_call("old_tc1")]),
        _tool_result("old_tc1", "short result"),
        _user("recent msg"),
        _assistant("recent reply"),
    ]
    current = _user("current")
    msgs = [_system()] + old + [current]
    turn_start = 1 + len(old)

    result = trim_history_for_budget(msgs, turn_start, iteration=2, context_budget_tokens=500, find_legal_start=Session._find_legal_start)

    # Check no orphaned tool results
    trimmed_history = result[1:-1]
    declared_ids = set()
    for m in trimmed_history:
        if m.get("role") == "assistant" and m.get("tool_calls"):
            for tc in m["tool_calls"]:
                declared_ids.add(tc["id"])
    for m in trimmed_history:
        if m.get("role") == "tool":
            tc_id = m.get("tool_call_id")
            assert tc_id in declared_ids, f"Orphan tool result: {tc_id}"


def test_extreme_trim_keeps_system_and_current_turn():
    """When budget is tiny, only system and current turn remain."""
    old = [_user("x" * 2000), _assistant("y" * 2000)]
    current = _user("current")
    msgs = [_system()] + old + [current]

    result = trim_history_for_budget(msgs, turn_start_index=3, iteration=2, context_budget_tokens=500, find_legal_start=Session._find_legal_start)

    assert result[0] == msgs[0]  # system
    assert result[-1] == current  # current turn
    assert len(result) <= len(msgs)


def test_original_messages_not_mutated():
    old = [_user("x" * 2000), _assistant("y" * 2000)]
    current = _user("current")
    msgs = [_system()] + old + [current]
    original_len = len(msgs)

    _ = trim_history_for_budget(msgs, turn_start_index=3, iteration=2, context_budget_tokens=500, find_legal_start=Session._find_legal_start)

    assert len(msgs) == original_len


def test_current_turn_never_trimmed():
    """All messages at or after turn_start_index must be preserved verbatim."""
    old = [_user("old"), _assistant("reply")]
    current_turn = [
        _user("current user message"),
        _assistant(None, [_tool_call("tc1")]),
        _tool_result("tc1", "result"),
    ]
    msgs = [_system()] + old + current_turn
    turn_start = 1 + len(old)

    result = trim_history_for_budget(msgs, turn_start, iteration=2, context_budget_tokens=1, find_legal_start=Session._find_legal_start)

    assert result[-len(current_turn):] == current_turn


def test_iteration_two_first_trim():
    """iteration=2 is the first iteration where trimming kicks in."""
    old = [_user("x" * 2000), _assistant("y" * 2000)]
    msgs = [_system()] + old + [_user("current")]
    turn_start = 3

    # iteration=1: no trim, even with a positive budget
    r1 = trim_history_for_budget(msgs, turn_start, iteration=1, context_budget_tokens=500, find_legal_start=Session._find_legal_start)
    assert r1 is msgs

    # iteration=2 with budget=0: no trim (budget is 0)
    r2 = trim_history_for_budget(msgs, turn_start, iteration=2, context_budget_tokens=0, find_legal_start=Session._find_legal_start)
    assert r2 is msgs

    # iteration=2 with positive budget: trim occurs (2000-char msgs ~= 500+ tokens each)
    r3 = trim_history_for_budget(msgs, turn_start, iteration=2, context_budget_tokens=500, find_legal_start=Session._find_legal_start)
    assert r3 is not msgs