feat: configurable context budget for tool-loop iterations (#2317)

* feat: add contextBudgetTokens config field for tool-loop trimming

* feat: implement _trim_history_for_budget for tool-loop cost reduction

* feat: thread contextBudgetTokens into AgentLoop constructor

* feat: wire context budget trimming into agent loop

* refactor: move trim_history_for_budget to helpers and add docs

- Extract trim_history_for_budget() as a pure function in helpers.py
- AgentLoop._trim_history_for_budget becomes a thin wrapper
- Add docs/CONTEXT_BUDGET.md with usage guide and trade-off notes
- Replace wrapper tests with direct helper unit tests

---------

Co-authored-by: chengyongru <chengyongru.ai@gmail.com>
Jesse 2026-03-23 06:13:03 -04:00 committed by GitHub
parent 0182ce2852
commit 528b3cfe5a
6 changed files with 319 additions and 5 deletions

docs/CONTEXT_BUDGET.md (new file)

@@ -0,0 +1,62 @@
# Context Budget (`context_budget_tokens`)
Caps how many tokens of old session history are sent to the LLM during tool-loop iterations 2+. Reduces cost and first-token latency by trimming history between turns.
## How It Works
During multi-turn tool-use sessions, each iteration re-sends the full conversation history. `context_budget_tokens` limits how many old tokens are included:
- **Iteration 1** — always receives full context (no trimming)
- **Iteration 2+** — old history is trimmed to fit within the budget; current turn is never trimmed
- **Memory consolidation** — runs before/after the loop and always sees the full canonical history; trimming only affects the LLM's view
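The gating rules above can be sketched as follows (illustrative helper, not the actual API):

```python
def should_trim(iteration: int, budget_tokens: int, old_history_len: int) -> bool:
    """Sketch of when trimming applies, per the rules above.

    Iteration 1 always gets full context, a zero budget disables the
    feature entirely, and there must be old history to trim.
    """
    return budget_tokens > 0 and iteration >= 2 and old_history_len > 0
```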
## Configuration
```json
{
"agents": {
"defaults": {
"context_budget_tokens": 1000
}
}
}
```
| Value | Behavior |
|---|---|
| `0` (default) | No trimming — full history sent every iteration |
| `4000` | Conservative — barely trims in practice; good for multi-step tasks |
| `1000` | Aggressive — significant savings; works well for typical linear tasks |
| `< 500` | Clamped to a `500` minimum when positive (12 message pairs at typical token density) |
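The clamping rule in the last row can be expressed as a one-liner (a sketch of the documented behavior, not the actual implementation):

```python
def effective_budget(configured: int) -> int:
    """0 disables trimming; any positive value is floored at 500 tokens."""
    return max(configured, 500) if configured > 0 else 0
```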
## Trade-offs
**Cost & latency** — Trimming reduces the tokens sent on each iteration, which saves money and lowers time to first token (TTFT). This is nanobot's primary sweet spot.
**Context loss** — Older context is not visible to the LLM in later iterations. For tasks that genuinely require 20+ iterations of history to stay coherent, consider `0` or `4000`.
**Tool-result truncation** — Large results from a previous turn (e.g., reading a 10,000-line file in Round 1, then editing in Round 2) can be trimmed. The agent can re-read the file via its tools — this is a 1-tool-call recovery cost, not a failure.
**Prefix caching** — Some providers (e.g., DeepSeek) use implicit prefix-based caching. Aggressive trimming breaks prefix matching and can reduce cache hit rates. For these providers, `0` or a high value may be more cost-effective overall.
## When to Use
| Use case | Recommended value |
|---|---|
| Simple read → process → act chains | `1000` |
| Multi-step reasoning with tool chains | `4000` |
| Complex debugging / long task traces | `0` |
| Providers with implicit prefix caching | `0` or `4000` |
| Long file operations across turns | `0` or re-read via tools |
## Example
```
Turn 1: User asks to read a.py (10k lines)
Turn 2: User asks to edit line 100
```
With `context_budget_tokens=500`, the file-content result from Turn 1 may be trimmed before Turn 2. The agent will re-read the file to perform the edit — a 1-call recovery. This is normal behavior for the feature; it is not a bug.
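This behavior can be demonstrated with a simplified, self-contained sketch of the trimming pass (4-chars-per-token estimate; the real helper also repairs orphaned tool results, which this sketch omits):

```python
def trim(messages: list[dict], turn_start: int, budget: int) -> list[dict]:
    """Drop oldest history messages until the remainder fits the budget."""
    system, old, current = messages[:1], messages[1:turn_start], messages[turn_start:]
    counts = [max(1, len(m["content"]) // 4) for m in old]  # rough token estimate
    while old and sum(counts) > budget:
        old, counts = old[1:], counts[1:]  # trim from the oldest end
    return system + old + current

msgs = [
    {"role": "system", "content": "You are a bot."},
    {"role": "user", "content": "x" * 4000},       # Turn 1: ~1000 tokens
    {"role": "assistant", "content": "y" * 4000},  # Turn 1 result: ~1000 tokens
    {"role": "user", "content": "edit line 100"},  # Turn 2 (current turn)
]
out = trim(msgs, turn_start=3, budget=500)
# Both oversized Turn-1 messages are trimmed; system + current turn survive.
```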


@@ -27,7 +27,7 @@ from nanobot.agent.tools.shell import ExecTool
 from nanobot.agent.tools.spawn import SpawnTool
 from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
 from nanobot.bus.events import InboundMessage, OutboundMessage
-from nanobot.utils.helpers import build_status_content
+from nanobot.utils.helpers import build_status_content, trim_history_for_budget
 from nanobot.bus.queue import MessageBus
 from nanobot.providers.base import LLMProvider
 from nanobot.session.manager import Session, SessionManager
@@ -59,6 +59,7 @@ class AgentLoop:
         model: str | None = None,
         max_iterations: int = 40,
         context_window_tokens: int = 65_536,
+        context_budget_tokens: int = 0,
         web_search_config: WebSearchConfig | None = None,
         web_proxy: str | None = None,
         exec_config: ExecToolConfig | None = None,
@@ -78,6 +79,7 @@ class AgentLoop:
         self.model = model or provider.get_default_model()
         self.max_iterations = max_iterations
         self.context_window_tokens = context_window_tokens
+        self.context_budget_tokens = max(context_budget_tokens, 500) if context_budget_tokens > 0 else 0
         self.web_search_config = web_search_config or WebSearchConfig()
         self.web_proxy = web_proxy
         self.exec_config = exec_config or ExecToolConfig()
@@ -232,6 +234,22 @@ class AgentLoop:
             metadata={"render_as": "text"},
         )

+    def _trim_history_for_budget(
+        self,
+        messages: list[dict],
+        turn_start_index: int,
+        iteration: int,
+    ) -> list[dict]:
+        """Thin wrapper: delegates to the trim_history_for_budget() helper."""
+        return trim_history_for_budget(
+            messages,
+            turn_start_index,
+            iteration,
+            self.context_budget_tokens,
+            Session._find_legal_start,
+        )
+
     async def _run_agent_loop(
         self,
         initial_messages: list[dict],
@@ -250,6 +268,7 @@ class AgentLoop:
         iteration = 0
         final_content = None
         tools_used: list[str] = []
+        turn_start_index = len(initial_messages) - 1

         # Wrap on_stream with stateful think-tag filter so downstream
         # consumers (CLI, channels) never see <think> blocks.
@@ -271,20 +290,23 @@ class AgentLoop:
             tool_defs = self.tools.get_definitions()

+            send_messages = self._trim_history_for_budget(
+                messages, turn_start_index, iteration,
+            )
+
             if on_stream:
                 response = await self.provider.chat_stream_with_retry(
-                    messages=messages,
+                    messages=send_messages,
                     tools=tool_defs,
                     model=self.model,
                     on_content_delta=_filtered_stream,
                 )
             else:
                 response = await self.provider.chat_with_retry(
-                    messages=messages,
+                    messages=send_messages,
                     tools=tool_defs,
                     model=self.model,
                 )

             usage = response.usage or {}
             self._last_usage = {
                 "prompt_tokens": int(usage.get("prompt_tokens", 0) or 0),


@@ -527,6 +527,7 @@ def gateway(
         model=config.agents.defaults.model,
         max_iterations=config.agents.defaults.max_tool_iterations,
         context_window_tokens=config.agents.defaults.context_window_tokens,
+        context_budget_tokens=config.agents.defaults.context_budget_tokens,
         web_search_config=config.tools.web.search,
         web_proxy=config.tools.web.proxy or None,
         exec_config=config.tools.exec,
@@ -722,6 +723,7 @@ def agent(
         model=config.agents.defaults.model,
         max_iterations=config.agents.defaults.max_tool_iterations,
         context_window_tokens=config.agents.defaults.context_window_tokens,
+        context_budget_tokens=config.agents.defaults.context_budget_tokens,
         web_search_config=config.tools.web.search,
         web_proxy=config.tools.web.proxy or None,
         exec_config=config.tools.exec,


@@ -39,6 +39,7 @@ class AgentDefaults(Base):
     context_window_tokens: int = 65_536
     temperature: float = 0.1
     max_tool_iterations: int = 40
+    context_budget_tokens: int = 0  # Max old-history tokens during tool iterations (0 = no trim)
     reasoning_effort: str | None = None  # low / medium / high — enables LLM thinking mode


@@ -6,9 +6,10 @@ import re
 import time
 from datetime import datetime
 from pathlib import Path
-from typing import Any
+from typing import Any, Callable

 import tiktoken
+from loguru import logger


 def strip_think(text: str) -> str:
@@ -201,6 +202,58 @@ def estimate_message_tokens(message: dict[str, Any]) -> int:
     return max(4, len(payload) // 4 + 4)


+def trim_history_for_budget(
+    messages: list[dict[str, Any]],
+    turn_start_index: int,
+    iteration: int,
+    context_budget_tokens: int,
+    find_legal_start: Callable[[list[dict[str, Any]]], int],
+) -> list[dict[str, Any]]:
+    """Trim old session history to fit within context_budget_tokens.
+
+    Returns the original list unchanged when no trimming is needed.
+    Only trims on iteration >= 2 when context_budget_tokens > 0.
+    Current-turn messages (from turn_start_index onward) are never trimmed.
+    """
+    if context_budget_tokens <= 0 or iteration <= 1:
+        return messages
+    if turn_start_index <= 1:
+        return messages  # no old history to trim
+
+    system = messages[:1]
+    old_history = messages[1:turn_start_index]
+    current_turn = messages[turn_start_index:]
+
+    # Pre-compute token counts to avoid double-estimation
+    token_counts = [estimate_message_tokens(m) for m in old_history]
+    total = sum(token_counts)
+    if total <= context_budget_tokens:
+        return messages  # fits, no trim needed
+
+    # Find cut index (O(n) scan, then single slice)
+    cut = 0
+    removed_tokens = 0
+    while cut < len(old_history) and total > context_budget_tokens:
+        removed_tokens += token_counts[cut]
+        total -= token_counts[cut]
+        cut += 1
+    old_history = old_history[cut:]
+
+    # Fix orphaned tool results after trimming
+    legal_start = find_legal_start(old_history)
+    if legal_start > 0:
+        old_history = old_history[legal_start:]
+
+    removed_count = turn_start_index - 1 - len(old_history)
+    if removed_count > 0:
+        logger.debug(
+            "Context budget: trimmed {} history messages ({} tokens) for iteration {}",
+            removed_count, removed_tokens, iteration,
+        )
+    return system + old_history + current_turn
+
+
 def estimate_prompt_tokens_chain(
     provider: Any,
     model: str | None,


@@ -0,0 +1,174 @@
"""Direct unit tests for trim_history_for_budget() helper."""
import pytest
from nanobot.session.manager import Session
from nanobot.utils.helpers import estimate_message_tokens, trim_history_for_budget
def _msg(role: str, content: str, **kw) -> dict:
return {"role": role, "content": content, **kw}
def _system(content: str = "You are a bot.") -> dict:
return _msg("system", content)
def _user(content: str) -> dict:
return _msg("user", content)
def _assistant(content: str | None = None, tool_calls: list | None = None) -> dict:
m = {"role": "assistant", "content": content}
if tool_calls:
m["tool_calls"] = tool_calls
return m
def _tool_call(tc_id: str, name: str = "exec", args: str = "{}") -> dict:
return {"id": tc_id, "type": "function", "function": {"name": name, "arguments": args}}
def _tool_result(tc_id: str, content: str = "ok") -> dict:
return {"role": "tool", "tool_call_id": tc_id, "name": "exec", "content": content}
# --- Early-exit cases ---
def test_budget_zero_returns_same_list():
msgs = [_system(), _user("old1"), _assistant("old reply"), _user("current")]
result = trim_history_for_budget(msgs, turn_start_index=3, iteration=2, context_budget_tokens=0, find_legal_start=Session._find_legal_start)
assert result is msgs
def test_iteration_one_never_trims():
msgs = [_system(), _user("old1"), _assistant("old reply"), _user("current")]
result = trim_history_for_budget(msgs, turn_start_index=3, iteration=1, context_budget_tokens=1000, find_legal_start=Session._find_legal_start)
assert result is msgs
def test_turn_start_at_one_returns_same():
"""turn_start_index=1 means no old history before the current turn."""
msgs = [_system(), _user("current")]
result = trim_history_for_budget(msgs, turn_start_index=1, iteration=2, context_budget_tokens=0, find_legal_start=Session._find_legal_start)
assert result is msgs
def test_history_under_budget_returns_unchanged():
msgs = [_system(), _user("short msg"), _assistant("short reply"), _user("current")]
result = trim_history_for_budget(msgs, turn_start_index=3, iteration=2, context_budget_tokens=50000, find_legal_start=Session._find_legal_start)
assert result is msgs
# --- Trimming cases ---
def test_trim_removes_oldest_messages():
old_msgs = []
for i in range(40):
old_msgs.append(_user(f"old message number {i} padding extra text here"))
old_msgs.append(_assistant(f"reply to message {i} with more padding"))
current_user = _user("current task")
current_tc = _assistant(None, [_tool_call("tc1")])
current_result = _tool_result("tc1", "done")
msgs = [_system()] + old_msgs + [current_user, current_tc, current_result]
turn_start = 1 + len(old_msgs)
result = trim_history_for_budget(msgs, turn_start, iteration=2, context_budget_tokens=500, find_legal_start=Session._find_legal_start)
# System and current turn preserved
assert result[0] == msgs[0]
assert result[-3:] == [current_user, current_tc, current_result]
# Old history trimmed
trimmed_history = result[1:-3]
assert len(trimmed_history) < len(old_msgs)
# Token budget respected
trimmed_tokens = sum(estimate_message_tokens(m) for m in trimmed_history)
assert trimmed_tokens <= 500
def test_trim_preserves_tool_call_boundary():
"""Trimming must not leave orphaned tool results."""
old = [
_user("padding " * 200),
_assistant(None, [_tool_call("old_tc1")]),
_tool_result("old_tc1", "short result"),
_user("recent msg"),
_assistant("recent reply"),
]
current = _user("current")
msgs = [_system()] + old + [current]
turn_start = 1 + len(old)
result = trim_history_for_budget(msgs, turn_start, iteration=2, context_budget_tokens=500, find_legal_start=Session._find_legal_start)
# Check no orphaned tool results
trimmed_history = result[1:-1]
declared_ids = set()
for m in trimmed_history:
if m.get("role") == "assistant" and m.get("tool_calls"):
for tc in m["tool_calls"]:
declared_ids.add(tc["id"])
for m in trimmed_history:
if m.get("role") == "tool":
tc_id = m.get("tool_call_id")
assert tc_id in declared_ids, f"Orphan tool result: {tc_id}"
def test_extreme_trim_keeps_system_and_current_turn():
"""When budget is tiny, only system and current turn remain."""
old = [_user("x" * 2000), _assistant("y" * 2000)]
current = _user("current")
msgs = [_system()] + old + [current]
result = trim_history_for_budget(msgs, turn_start_index=3, iteration=2, context_budget_tokens=500, find_legal_start=Session._find_legal_start)
assert result[0] == msgs[0] # system
assert result[-1] == current # current turn
assert len(result) <= len(msgs)
def test_original_messages_not_mutated():
old = [_user("x" * 2000), _assistant("y" * 2000)]
current = _user("current")
msgs = [_system()] + old + [current]
original_len = len(msgs)
_ = trim_history_for_budget(msgs, turn_start_index=3, iteration=2, context_budget_tokens=500, find_legal_start=Session._find_legal_start)
assert len(msgs) == original_len
def test_current_turn_never_trimmed():
"""All messages at or after turn_start_index must be preserved verbatim."""
old = [_user("old"), _assistant("reply")]
current_turn = [
_user("current user message"),
_assistant(None, [_tool_call("tc1")]),
_tool_result("tc1", "result"),
]
msgs = [_system()] + old + current_turn
turn_start = 1 + len(old)
result = trim_history_for_budget(msgs, turn_start, iteration=2, context_budget_tokens=1, find_legal_start=Session._find_legal_start)
assert result[-len(current_turn):] == current_turn
def test_iteration_two_first_trim():
"""iteration=2 is the first iteration where trimming kicks in."""
old = [_user("x" * 2000), _assistant("y" * 2000)]
msgs = [_system()] + old + [_user("current")]
turn_start = 3
# iteration=1: no trim
r1 = trim_history_for_budget(msgs, turn_start, iteration=1, context_budget_tokens=0, find_legal_start=Session._find_legal_start)
assert r1 is msgs
# iteration=2 with budget=0: no trim (budget is 0)
r2 = trim_history_for_budget(msgs, turn_start, iteration=2, context_budget_tokens=0, find_legal_start=Session._find_legal_start)
assert r2 is msgs
# iteration=2 with positive budget: trim occurs (2000-char msgs ~= 500+ tokens each)
r3 = trim_history_for_budget(msgs, turn_start, iteration=2, context_budget_tokens=500, find_legal_start=Session._find_legal_start)
assert r3 is not msgs