mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 17:32:39 +00:00
fix(runtime): address review feedback on retry and cleanup
This commit is contained in:
parent
eefd7e60f2
commit
714a4c7bb6
@ -72,6 +72,7 @@ class LLMProvider(ABC):
|
||||
|
||||
_CHAT_RETRY_DELAYS = (1, 2, 4)
|
||||
_PERSISTENT_MAX_DELAY = 60
|
||||
_PERSISTENT_IDENTICAL_ERROR_LIMIT = 10
|
||||
_RETRY_HEARTBEAT_CHUNK = 30
|
||||
_TRANSIENT_ERROR_MARKERS = (
|
||||
"429",
|
||||
@ -377,12 +378,20 @@ class LLMProvider(ABC):
|
||||
delays = list(self._CHAT_RETRY_DELAYS)
|
||||
persistent = retry_mode == "persistent"
|
||||
last_response: LLMResponse | None = None
|
||||
last_error_key: str | None = None
|
||||
identical_error_count = 0
|
||||
while True:
|
||||
attempt += 1
|
||||
response = await call(**kw)
|
||||
if response.finish_reason != "error":
|
||||
return response
|
||||
last_response = response
|
||||
error_key = ((response.content or "").strip().lower() or None)
|
||||
if error_key and error_key == last_error_key:
|
||||
identical_error_count += 1
|
||||
else:
|
||||
last_error_key = error_key
|
||||
identical_error_count = 1 if error_key else 0
|
||||
|
||||
if not self._is_transient_error(response.content):
|
||||
stripped = self._strip_image_content(original_messages)
|
||||
@ -395,6 +404,14 @@ class LLMProvider(ABC):
|
||||
return await call(**retry_kw)
|
||||
return response
|
||||
|
||||
if persistent and identical_error_count >= self._PERSISTENT_IDENTICAL_ERROR_LIMIT:
|
||||
logger.warning(
|
||||
"Stopping persistent retry after {} identical transient errors: {}",
|
||||
identical_error_count,
|
||||
(response.content or "")[:120].lower(),
|
||||
)
|
||||
return response
|
||||
|
||||
if not persistent and attempt > len(delays):
|
||||
break
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import tiktoken
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def strip_think(text: str) -> str:
|
||||
@ -214,8 +215,8 @@ def maybe_persist_tool_result(
|
||||
bucket = ensure_dir(root / safe_filename(session_key or "default"))
|
||||
try:
|
||||
_cleanup_tool_result_buckets(root, bucket)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to clean stale tool result buckets in {}: {}", root, exc)
|
||||
path = bucket / f"{safe_filename(tool_call_id)}.{suffix}"
|
||||
if not path.exists():
|
||||
if suffix == "json" and isinstance(content, list):
|
||||
|
||||
@ -359,6 +359,32 @@ def test_persist_tool_result_leaves_no_temp_files(tmp_path):
|
||||
assert list((root / "current_session").glob("*.tmp")) == []
|
||||
|
||||
|
||||
def test_persist_tool_result_logs_cleanup_failures(monkeypatch, tmp_path):
|
||||
from nanobot.utils.helpers import maybe_persist_tool_result
|
||||
|
||||
warnings: list[str] = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
"nanobot.utils.helpers._cleanup_tool_result_buckets",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(OSError("busy")),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.utils.helpers.logger.warning",
|
||||
lambda message, *args: warnings.append(message.format(*args)),
|
||||
)
|
||||
|
||||
persisted = maybe_persist_tool_result(
|
||||
tmp_path,
|
||||
"current:session",
|
||||
"call_big",
|
||||
"x" * 5000,
|
||||
max_chars=64,
|
||||
)
|
||||
|
||||
assert "[tool output persisted]" in persisted
|
||||
assert warnings and "Failed to clean stale tool result buckets" in warnings[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_uses_raw_messages_when_context_governance_fails():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
@ -392,6 +418,55 @@ async def test_runner_uses_raw_messages_when_context_governance_fails():
|
||||
assert captured_messages == initial_messages
|
||||
|
||||
|
||||
def test_snip_history_drops_orphaned_tool_results_from_trimmed_slice(monkeypatch):
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
runner = AgentRunner(provider)
|
||||
messages = [
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "old user"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "tool call",
|
||||
"tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "ls", "arguments": "{}"}}],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_1", "content": "tool output"},
|
||||
{"role": "assistant", "content": "after tool"},
|
||||
]
|
||||
spec = AgentRunSpec(
|
||||
initial_messages=messages,
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
context_window_tokens=2000,
|
||||
context_block_limit=100,
|
||||
)
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", lambda *_args, **_kwargs: (500, None))
|
||||
token_sizes = {
|
||||
"old user": 120,
|
||||
"tool call": 120,
|
||||
"tool output": 40,
|
||||
"after tool": 40,
|
||||
"system": 0,
|
||||
}
|
||||
monkeypatch.setattr(
|
||||
"nanobot.agent.runner.estimate_message_tokens",
|
||||
lambda msg: token_sizes.get(str(msg.get("content")), 40),
|
||||
)
|
||||
|
||||
trimmed = runner._snip_history(spec, messages)
|
||||
|
||||
assert trimmed == [
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "assistant", "content": "after tool"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_keeps_going_when_tool_result_persistence_fails():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
@ -614,6 +689,7 @@ async def test_runner_accumulates_usage_and_preserves_cached_tokens():
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=3,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
))
|
||||
|
||||
# Usage should be accumulated across iterations
|
||||
@ -652,6 +728,7 @@ async def test_runner_passes_cached_tokens_to_hook_context():
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
hook=UsageHook(),
|
||||
))
|
||||
|
||||
|
||||
@ -240,3 +240,27 @@ async def test_chat_with_retry_uses_retry_after_and_emits_wait_progress(monkeypa
|
||||
assert progress and "7s" in progress[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persistent_retry_aborts_after_ten_identical_transient_errors(monkeypatch) -> None:
|
||||
provider = ScriptedProvider([
|
||||
*[LLMResponse(content="429 rate limit", finish_reason="error") for _ in range(10)],
|
||||
LLMResponse(content="ok"),
|
||||
])
|
||||
delays: list[float] = []
|
||||
|
||||
async def _fake_sleep(delay: float) -> None:
|
||||
delays.append(delay)
|
||||
|
||||
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||
|
||||
response = await provider.chat_with_retry(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
retry_mode="persistent",
|
||||
)
|
||||
|
||||
assert response.finish_reason == "error"
|
||||
assert response.content == "429 rate limit"
|
||||
assert provider.calls == 10
|
||||
assert delays == [1, 2, 4, 4, 4, 4, 4, 4, 4]
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user