feat: configurable context budget for tool-loop iterations (#2317)

* feat: add contextBudgetTokens config field for tool-loop trimming

* feat: implement _trim_history_for_budget for tool-loop cost reduction

* feat: thread contextBudgetTokens into AgentLoop constructor

* feat: wire context budget trimming into agent loop

* refactor: move trim_history_for_budget to helpers and add docs

- Extract trim_history_for_budget() as a pure function in helpers.py
- AgentLoop._trim_history_for_budget becomes a thin wrapper
- Add docs/CONTEXT_BUDGET.md with usage guide and trade-off notes
- Replace wrapper tests with direct helper unit tests

---------

Co-authored-by: chengyongru <chengyongru.ai@gmail.com>
Jesse 2026-03-23 06:13:03 -04:00 committed by GitHub
parent 0182ce2852
commit 528b3cfe5a
6 changed files with 319 additions and 5 deletions

docs/CONTEXT_BUDGET.md (new file)

@ -0,0 +1,62 @@
# Context Budget (`context_budget_tokens`)
Caps how many tokens of old session history are sent to the LLM during tool-loop iterations 2+. Reduces cost and first-token latency by trimming history between turns.
## How It Works
During multi-turn tool-use sessions, each iteration re-sends the full conversation history. `context_budget_tokens` limits how many old tokens are included:
- **Iteration 1** — always receives full context (no trimming)
- **Iteration 2+** — old history is trimmed to fit within the budget; current turn is never trimmed
- **Memory consolidation** — runs before/after the loop and always sees the full canonical history; trimming only affects the LLM's view
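The gating rule above can be sketched as a small predicate (illustrative only; the actual check lives inline in `trim_history_for_budget` in `helpers.py`):

```python
def should_trim(iteration: int, context_budget_tokens: int, turn_start_index: int) -> bool:
    """Illustrative predicate: trimming applies only on iteration 2+,
    with a positive budget, and only when old history exists before
    the current turn (turn_start_index > 1, i.e. something beyond the
    system prompt)."""
    return (
        context_budget_tokens > 0
        and iteration >= 2
        and turn_start_index > 1
    )
```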
## Configuration
```json
{
"agents": {
"defaults": {
"context_budget_tokens": 1000
}
}
}
```
| Value | Behavior |
|---|---|
| `0` (default) | No trimming — full history sent every iteration |
| `4000` | Conservative — barely trims in practice; good for multi-step tasks |
| `1000` | Aggressive — significant savings; works well for typical linear tasks |
| `< 500` | Clamped to a `500` minimum when positive (~12 message pairs at typical token density) |
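The clamping rule in the last row matches how `AgentLoop` normalizes the setting; a minimal sketch:

```python
def effective_budget(context_budget_tokens: int) -> int:
    # 0 (or any non-positive value) disables trimming entirely;
    # positive values are clamped to a 500-token floor.
    if context_budget_tokens <= 0:
        return 0
    return max(context_budget_tokens, 500)
```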
## Trade-offs
**Cost & latency** — Trimming reduces tokens sent each iteration, which saves money and lowers first-token time (TTFT). This is nanobot's primary sweet spot.
**Context loss** — Older context is not visible to the LLM in later iterations. For tasks that genuinely require 20+ iterations of history to stay coherent, consider `0` or `4000`.
**Tool-result truncation** — Large results from a previous turn (e.g., reading a 10,000-line file in Round 1, then editing in Round 2) can be trimmed. The agent can re-read the file via its tools — this is a 1-tool-call recovery cost, not a failure.
**Prefix caching** — Some providers (e.g., DeepSeek) use implicit prefix-based caching. Aggressive trimming breaks prefix matching and can reduce cache hit rates. For these providers, `0` or a high value may be more cost-effective overall.
## When to Use
| Use case | Recommended value |
|---|---|
| Simple read → process → act chains | `1000` |
| Multi-step reasoning with tool chains | `4000` |
| Complex debugging / long task traces | `0` |
| Providers with implicit prefix caching | `0` or `4000` |
| Long file operations across turns | `0` or re-read via tools |
## Example
```
Turn 1: User asks to read a.py (10k lines)
Turn 2: User asks to edit line 100
```
With `context_budget_tokens=500`, the file-content result from Turn 1 may be trimmed before Turn 2. The agent will re-read the file to perform the edit — a 1-call recovery. This is normal behavior for the feature; it is not a bug.
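For intuition on what "fits within the budget" means: tokens are counted with a rough chars/4 heuristic, not a real tokenizer. A hedged sketch in the spirit of `estimate_message_tokens()` (the exact payload serialization in `helpers.py` may differ):

```python
def rough_message_tokens(message: dict) -> int:
    # Approximation: ~4 characters per token, plus a small floor so
    # even empty messages contribute to the running total.
    payload = str(message.get("content") or "") + str(message.get("tool_calls") or "")
    return max(4, len(payload) // 4 + 4)
```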


@ -27,7 +27,7 @@ from nanobot.agent.tools.shell import ExecTool
from nanobot.agent.tools.spawn import SpawnTool
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
from nanobot.bus.events import InboundMessage, OutboundMessage
-from nanobot.utils.helpers import build_status_content
+from nanobot.utils.helpers import build_status_content, trim_history_for_budget
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMProvider
from nanobot.session.manager import Session, SessionManager
@ -59,6 +59,7 @@ class AgentLoop:
model: str | None = None,
max_iterations: int = 40,
context_window_tokens: int = 65_536,
context_budget_tokens: int = 0,
web_search_config: WebSearchConfig | None = None,
web_proxy: str | None = None,
exec_config: ExecToolConfig | None = None,
@ -78,6 +79,7 @@ class AgentLoop:
self.model = model or provider.get_default_model()
self.max_iterations = max_iterations
self.context_window_tokens = context_window_tokens
self.context_budget_tokens = max(context_budget_tokens, 500) if context_budget_tokens > 0 else 0
self.web_search_config = web_search_config or WebSearchConfig()
self.web_proxy = web_proxy
self.exec_config = exec_config or ExecToolConfig()
@ -232,6 +234,22 @@ class AgentLoop:
metadata={"render_as": "text"},
)
def _trim_history_for_budget(
self,
messages: list[dict],
turn_start_index: int,
iteration: int,
) -> list[dict]:
"""Thin wrapper: delegates to trim_history_for_budget helper."""
return trim_history_for_budget(
messages,
turn_start_index,
iteration,
self.context_budget_tokens,
Session._find_legal_start,
)
async def _run_agent_loop(
self,
initial_messages: list[dict],
@ -250,6 +268,7 @@ class AgentLoop:
iteration = 0
final_content = None
tools_used: list[str] = []
turn_start_index = len(initial_messages) - 1
# Wrap on_stream with stateful think-tag filter so downstream
# consumers (CLI, channels) never see <think> blocks.
@ -271,20 +290,23 @@ class AgentLoop:
tool_defs = self.tools.get_definitions()
send_messages = self._trim_history_for_budget(
messages, turn_start_index, iteration,
)
if on_stream:
response = await self.provider.chat_stream_with_retry(
-messages=messages,
+messages=send_messages,
tools=tool_defs,
model=self.model,
on_content_delta=_filtered_stream,
)
else:
response = await self.provider.chat_with_retry(
-messages=messages,
+messages=send_messages,
tools=tool_defs,
model=self.model,
)
usage = response.usage or {}
self._last_usage = {
"prompt_tokens": int(usage.get("prompt_tokens", 0) or 0),


@ -527,6 +527,7 @@ def gateway(
model=config.agents.defaults.model,
max_iterations=config.agents.defaults.max_tool_iterations,
context_window_tokens=config.agents.defaults.context_window_tokens,
context_budget_tokens=config.agents.defaults.context_budget_tokens,
web_search_config=config.tools.web.search,
web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec,
@ -722,6 +723,7 @@ def agent(
model=config.agents.defaults.model,
max_iterations=config.agents.defaults.max_tool_iterations,
context_window_tokens=config.agents.defaults.context_window_tokens,
context_budget_tokens=config.agents.defaults.context_budget_tokens,
web_search_config=config.tools.web.search,
web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec,


@ -39,6 +39,7 @@ class AgentDefaults(Base):
context_window_tokens: int = 65_536
temperature: float = 0.1
max_tool_iterations: int = 40
context_budget_tokens: int = 0 # Max old-history tokens during tool iterations (0 = no trim)
reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode


@ -6,9 +6,10 @@ import re
import time
from datetime import datetime
from pathlib import Path
-from typing import Any
+from typing import Any, Callable
import tiktoken
from loguru import logger
def strip_think(text: str) -> str:
@ -201,6 +202,58 @@ def estimate_message_tokens(message: dict[str, Any]) -> int:
return max(4, len(payload) // 4 + 4)
def trim_history_for_budget(
messages: list[dict[str, Any]],
turn_start_index: int,
iteration: int,
context_budget_tokens: int,
find_legal_start: Callable[[list[dict[str, Any]]], int],
) -> list[dict[str, Any]]:
"""Trim old session history to fit within context_budget_tokens.
Returns the original list unchanged when no trimming is needed.
Only trims on iteration >= 2 when context_budget_tokens > 0.
Current-turn messages (from turn_start_index onward) are never trimmed.
"""
if context_budget_tokens <= 0 or iteration <= 1:
return messages
if turn_start_index <= 1:
return messages # no old history to trim
system = messages[:1]
old_history = messages[1:turn_start_index]
current_turn = messages[turn_start_index:]
# Pre-compute token counts to avoid double-estimation
token_counts = [estimate_message_tokens(m) for m in old_history]
total = sum(token_counts)
if total <= context_budget_tokens:
return messages # fits, no trim needed
# Find cut index (O(n) scan, then single slice)
cut = 0
removed_tokens = 0
while cut < len(old_history) and total > context_budget_tokens:
removed_tokens += token_counts[cut]
total -= token_counts[cut]
cut += 1
old_history = old_history[cut:]
# Fix orphaned tool results after trimming
legal_start = find_legal_start(old_history)
if legal_start > 0:
old_history = old_history[legal_start:]
removed_count = turn_start_index - 1 - len(old_history)
if removed_count > 0:
logger.debug(
"Context budget: trimmed {} history messages ({} tokens) for iteration {}",
removed_count, removed_tokens, iteration,
)
return system + old_history + current_turn
def estimate_prompt_tokens_chain(
provider: Any,
model: str | None,


@ -0,0 +1,174 @@
"""Direct unit tests for trim_history_for_budget() helper."""
import pytest
from nanobot.session.manager import Session
from nanobot.utils.helpers import estimate_message_tokens, trim_history_for_budget
def _msg(role: str, content: str, **kw) -> dict:
return {"role": role, "content": content, **kw}
def _system(content: str = "You are a bot.") -> dict:
return _msg("system", content)
def _user(content: str) -> dict:
return _msg("user", content)
def _assistant(content: str | None = None, tool_calls: list | None = None) -> dict:
m = {"role": "assistant", "content": content}
if tool_calls:
m["tool_calls"] = tool_calls
return m
def _tool_call(tc_id: str, name: str = "exec", args: str = "{}") -> dict:
return {"id": tc_id, "type": "function", "function": {"name": name, "arguments": args}}
def _tool_result(tc_id: str, content: str = "ok") -> dict:
return {"role": "tool", "tool_call_id": tc_id, "name": "exec", "content": content}
# --- Early-exit cases ---
def test_budget_zero_returns_same_list():
msgs = [_system(), _user("old1"), _assistant("old reply"), _user("current")]
result = trim_history_for_budget(msgs, turn_start_index=3, iteration=2, context_budget_tokens=0, find_legal_start=Session._find_legal_start)
assert result is msgs
def test_iteration_one_never_trims():
msgs = [_system(), _user("old1"), _assistant("old reply"), _user("current")]
result = trim_history_for_budget(msgs, turn_start_index=3, iteration=1, context_budget_tokens=1000, find_legal_start=Session._find_legal_start)
assert result is msgs
def test_turn_start_at_one_returns_same():
"""turn_start_index=1 means no old history before the current turn."""
msgs = [_system(), _user("current")]
result = trim_history_for_budget(msgs, turn_start_index=1, iteration=2, context_budget_tokens=1000, find_legal_start=Session._find_legal_start)
assert result is msgs
def test_history_under_budget_returns_unchanged():
msgs = [_system(), _user("short msg"), _assistant("short reply"), _user("current")]
result = trim_history_for_budget(msgs, turn_start_index=3, iteration=2, context_budget_tokens=50000, find_legal_start=Session._find_legal_start)
assert result is msgs
# --- Trimming cases ---
def test_trim_removes_oldest_messages():
old_msgs = []
for i in range(40):
old_msgs.append(_user(f"old message number {i} padding extra text here"))
old_msgs.append(_assistant(f"reply to message {i} with more padding"))
current_user = _user("current task")
current_tc = _assistant(None, [_tool_call("tc1")])
current_result = _tool_result("tc1", "done")
msgs = [_system()] + old_msgs + [current_user, current_tc, current_result]
turn_start = 1 + len(old_msgs)
result = trim_history_for_budget(msgs, turn_start, iteration=2, context_budget_tokens=500, find_legal_start=Session._find_legal_start)
# System and current turn preserved
assert result[0] == msgs[0]
assert result[-3:] == [current_user, current_tc, current_result]
# Old history trimmed
trimmed_history = result[1:-3]
assert len(trimmed_history) < len(old_msgs)
# Token budget respected
trimmed_tokens = sum(estimate_message_tokens(m) for m in trimmed_history)
assert trimmed_tokens <= 500
def test_trim_preserves_tool_call_boundary():
"""Trimming must not leave orphaned tool results."""
old = [
_user("padding " * 200),
_assistant(None, [_tool_call("old_tc1")]),
_tool_result("old_tc1", "short result"),
_user("recent msg"),
_assistant("recent reply"),
]
current = _user("current")
msgs = [_system()] + old + [current]
turn_start = 1 + len(old)
result = trim_history_for_budget(msgs, turn_start, iteration=2, context_budget_tokens=500, find_legal_start=Session._find_legal_start)
# Check no orphaned tool results
trimmed_history = result[1:-1]
declared_ids = set()
for m in trimmed_history:
if m.get("role") == "assistant" and m.get("tool_calls"):
for tc in m["tool_calls"]:
declared_ids.add(tc["id"])
for m in trimmed_history:
if m.get("role") == "tool":
tc_id = m.get("tool_call_id")
assert tc_id in declared_ids, f"Orphan tool result: {tc_id}"
def test_extreme_trim_keeps_system_and_current_turn():
"""When budget is tiny, only system and current turn remain."""
old = [_user("x" * 2000), _assistant("y" * 2000)]
current = _user("current")
msgs = [_system()] + old + [current]
result = trim_history_for_budget(msgs, turn_start_index=3, iteration=2, context_budget_tokens=500, find_legal_start=Session._find_legal_start)
assert result[0] == msgs[0] # system
assert result[-1] == current # current turn
assert len(result) <= len(msgs)
def test_original_messages_not_mutated():
old = [_user("x" * 2000), _assistant("y" * 2000)]
current = _user("current")
msgs = [_system()] + old + [current]
original_len = len(msgs)
_ = trim_history_for_budget(msgs, turn_start_index=3, iteration=2, context_budget_tokens=500, find_legal_start=Session._find_legal_start)
assert len(msgs) == original_len
def test_current_turn_never_trimmed():
"""All messages at or after turn_start_index must be preserved verbatim."""
old = [_user("old"), _assistant("reply")]
current_turn = [
_user("current user message"),
_assistant(None, [_tool_call("tc1")]),
_tool_result("tc1", "result"),
]
msgs = [_system()] + old + current_turn
turn_start = 1 + len(old)
result = trim_history_for_budget(msgs, turn_start, iteration=2, context_budget_tokens=1, find_legal_start=Session._find_legal_start)
assert result[-len(current_turn):] == current_turn
def test_iteration_two_first_trim():
"""iteration=2 is the first iteration where trimming kicks in."""
old = [_user("x" * 2000), _assistant("y" * 2000)]
msgs = [_system()] + old + [_user("current")]
turn_start = 3
# iteration=1: no trim
r1 = trim_history_for_budget(msgs, turn_start, iteration=1, context_budget_tokens=0, find_legal_start=Session._find_legal_start)
assert r1 is msgs
# iteration=2 with budget=0: no trim (budget is 0)
r2 = trim_history_for_budget(msgs, turn_start, iteration=2, context_budget_tokens=0, find_legal_start=Session._find_legal_start)
assert r2 is msgs
# iteration=2 with positive budget: trim occurs (2000-char msgs ~= 500+ tokens each)
r3 = trim_history_for_budget(msgs, turn_start, iteration=2, context_budget_tokens=500, find_legal_start=Session._find_legal_start)
assert r3 is not msgs