fix(agent): preserve interrupted tool-call turns

Keep tool-call assistant messages valid across provider sanitization and avoid trailing user-only history after model errors. This prevents follow-up requests from sending broken tool chains back to the gateway.
This commit is contained in:
Xubin Ren 2026-04-10 05:37:25 +00:00
parent c579d67887
commit 2bef9cb650
6 changed files with 235 additions and 2 deletions

View File

@ -31,6 +31,7 @@ from nanobot.utils.runtime import (
)
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
_PERSISTED_MODEL_ERROR_PLACEHOLDER = "[Assistant reply unavailable due to model error.]"
_MAX_EMPTY_RETRIES = 2
_MAX_LENGTH_RECOVERIES = 3
_SNIP_SAFETY_BUFFER = 1024
@ -105,7 +106,8 @@ class AgentRunner:
# may repair or compact historical messages for the model, but
# those synthetic edits must not shift the append boundary used
# later when the caller saves only the new turn.
messages_for_model = self._backfill_missing_tool_results(messages)
messages_for_model = self._drop_orphan_tool_results(messages)
messages_for_model = self._backfill_missing_tool_results(messages_for_model)
messages_for_model = self._microcompact(messages_for_model)
messages_for_model = self._apply_tool_result_budget(spec, messages_for_model)
messages_for_model = self._snip_history(spec, messages_for_model)
@ -261,6 +263,7 @@ class AgentRunner:
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
stop_reason = "error"
error = final_content
self._append_model_error_placeholder(messages)
context.final_content = final_content
context.error = error
context.stop_reason = stop_reason
@ -524,6 +527,12 @@ class AgentRunner:
return
messages.append(build_assistant_message(content))
@staticmethod
def _append_model_error_placeholder(messages: list[dict[str, Any]]) -> None:
if messages and messages[-1].get("role") == "assistant" and not messages[-1].get("tool_calls"):
return
messages.append(build_assistant_message(_PERSISTED_MODEL_ERROR_PLACEHOLDER))
def _normalize_tool_result(
self,
spec: AgentRunSpec,
@ -552,6 +561,32 @@ class AgentRunner:
return truncate_text(content, spec.max_tool_result_chars)
return content
@staticmethod
def _drop_orphan_tool_results(
messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Drop tool results that have no matching assistant tool_call earlier in the history."""
declared: set[str] = set()
updated: list[dict[str, Any]] | None = None
for idx, msg in enumerate(messages):
role = msg.get("role")
if role == "assistant":
for tc in msg.get("tool_calls") or []:
if isinstance(tc, dict) and tc.get("id"):
declared.add(str(tc["id"]))
if role == "tool":
tid = msg.get("tool_call_id")
if tid and str(tid) not in declared:
if updated is None:
updated = [dict(m) for m in messages[:idx]]
continue
if updated is not None:
updated.append(dict(msg))
if updated is None:
return messages
return updated
@staticmethod
def _backfill_missing_tool_results(
messages: list[dict[str, Any]],

View File

@ -375,6 +375,14 @@ class LLMProvider(ABC):
and role in ("user", "assistant")
):
prev = merged[-1]
if role == "assistant":
prev_has_tools = bool(prev.get("tool_calls"))
curr_has_tools = bool(msg.get("tool_calls"))
if curr_has_tools:
merged[-1] = dict(msg)
continue
if prev_has_tools:
continue
prev_content = prev.get("content") or ""
curr_content = msg.get("content") or ""
if isinstance(prev_content, str) and isinstance(curr_content, str):

View File

@ -243,6 +243,10 @@ class OpenAICompatProvider(LLMProvider):
tc_clean["id"] = map_id(tc_clean.get("id"))
normalized.append(tc_clean)
clean["tool_calls"] = normalized
if clean.get("role") == "assistant":
# Some OpenAI-compatible gateways reject assistant messages
# that mix non-empty content with tool_calls.
clean["content"] = None
if "tool_call_id" in clean and clean["tool_call_id"]:
clean["tool_call_id"] = map_id(clean["tool_call_id"])
return self._enforce_role_alternation(sanitized)

View File

@ -859,7 +859,11 @@ async def test_loop_retries_think_only_final_response(tmp_path):
async def test_llm_error_not_appended_to_session_messages():
"""When LLM returns finish_reason='error', the error content must NOT be
appended to the messages list (prevents polluting session history)."""
from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.agent.runner import (
AgentRunSpec,
AgentRunner,
_PERSISTED_MODEL_ERROR_PLACEHOLDER,
)
provider = MagicMock()
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
@ -882,6 +886,7 @@ async def test_llm_error_not_appended_to_session_messages():
assistant_msgs = [m for m in result.messages if m.get("role") == "assistant"]
assert all("429" not in (m.get("content") or "") for m in assistant_msgs), \
"Error content should not appear in session messages"
assert assistant_msgs[-1]["content"] == _PERSISTED_MODEL_ERROR_PLACEHOLDER
@pytest.mark.asyncio
@ -918,6 +923,56 @@ async def test_streamed_flag_not_set_on_llm_error(tmp_path):
"_streamed must not be set when stop_reason is error"
@pytest.mark.asyncio
async def test_next_turn_after_llm_error_keeps_turn_boundary(tmp_path):
from nanobot.agent.loop import AgentLoop
from nanobot.agent.runner import _PERSISTED_MODEL_ERROR_PLACEHOLDER
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
provider.chat_with_retry = AsyncMock(side_effect=[
LLMResponse(content="429 rate limit exceeded", finish_reason="error", tool_calls=[], usage={}),
LLMResponse(content="Recovered answer", tool_calls=[], usage={}),
])
loop = AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model")
loop.tools.get_definitions = MagicMock(return_value=[])
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
first = await loop._process_message(
InboundMessage(channel="cli", sender_id="user", chat_id="test", content="first question")
)
assert first is not None
assert first.content == "429 rate limit exceeded"
session = loop.sessions.get_or_create("cli:test")
assert [
{key: value for key, value in message.items() if key in {"role", "content"}}
for message in session.messages
] == [
{"role": "user", "content": "first question"},
{"role": "assistant", "content": _PERSISTED_MODEL_ERROR_PLACEHOLDER},
]
second = await loop._process_message(
InboundMessage(channel="cli", sender_id="user", chat_id="test", content="second question")
)
assert second is not None
assert second.content == "Recovered answer"
request_messages = provider.chat_with_retry.await_args_list[1].kwargs["messages"]
non_system = [message for message in request_messages if message.get("role") != "system"]
assert non_system[0] == {"role": "user", "content": "first question"}
assert non_system[1] == {
"role": "assistant",
"content": _PERSISTED_MODEL_ERROR_PLACEHOLDER,
}
assert non_system[2]["role"] == "user"
assert "second question" in non_system[2]["content"]
@pytest.mark.asyncio
async def test_runner_tool_error_sets_final_content():
from nanobot.agent.runner import AgentRunSpec, AgentRunner
@ -1218,6 +1273,41 @@ async def test_backfill_missing_tool_results_inserts_error():
assert backfilled[0]["name"] == "read_file"
def test_drop_orphan_tool_results_removes_unmatched_tool_messages():
from nanobot.agent.runner import AgentRunner
messages = [
{"role": "system", "content": "system"},
{"role": "user", "content": "old user"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{"id": "call_ok", "type": "function", "function": {"name": "read_file", "arguments": "{}"}},
],
},
{"role": "tool", "tool_call_id": "call_ok", "name": "read_file", "content": "ok"},
{"role": "tool", "tool_call_id": "call_orphan", "name": "exec", "content": "stale"},
{"role": "assistant", "content": "after tool"},
]
cleaned = AgentRunner._drop_orphan_tool_results(messages)
assert cleaned == [
{"role": "system", "content": "system"},
{"role": "user", "content": "old user"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{"id": "call_ok", "type": "function", "function": {"name": "read_file", "arguments": "{}"}},
],
},
{"role": "tool", "tool_call_id": "call_ok", "name": "read_file", "content": "ok"},
{"role": "assistant", "content": "after tool"},
]
@pytest.mark.asyncio
async def test_backfill_noop_when_complete():
"""Complete message chains should not be modified."""
@ -1239,6 +1329,45 @@ async def test_backfill_noop_when_complete():
assert result is messages # same object — no copy
@pytest.mark.asyncio
async def test_runner_drops_orphan_tool_results_before_model_request():
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
captured_messages: list[dict] = []
async def chat_with_retry(*, messages, **kwargs):
captured_messages[:] = messages
return LLMResponse(content="done", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[
{"role": "system", "content": "system"},
{"role": "user", "content": "old user"},
{"role": "tool", "tool_call_id": "call_orphan", "name": "exec", "content": "stale"},
{"role": "assistant", "content": "after orphan"},
{"role": "user", "content": "new prompt"},
],
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
))
assert all(
message.get("tool_call_id") != "call_orphan"
for message in captured_messages
if message.get("role") == "tool"
)
assert result.messages[2]["tool_call_id"] == "call_orphan"
assert result.final_content == "done"
@pytest.mark.asyncio
async def test_backfill_repairs_model_context_without_shifting_save_turn_boundary(tmp_path):
"""Historical backfill should not duplicate old tail messages on persist."""

View File

@ -84,6 +84,34 @@ class TestEnforceRoleAlternation:
tool_msgs = [m for m in result if m["role"] == "tool"]
assert len(tool_msgs) == 2
def test_consecutive_assistant_keeps_later_tool_call_message(self):
msgs = [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Previous reply"},
{"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]},
{"role": "tool", "content": "result1", "tool_call_id": "1"},
{"role": "user", "content": "Next"},
]
result = LLMProvider._enforce_role_alternation(msgs)
assert result[1]["role"] == "assistant"
assert result[1]["tool_calls"] == [{"id": "1"}]
assert result[1]["content"] is None
assert result[2]["role"] == "tool"
def test_consecutive_assistant_does_not_overwrite_existing_tool_call_message(self):
msgs = [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]},
{"role": "assistant", "content": "Later plain assistant"},
{"role": "tool", "content": "result1", "tool_call_id": "1"},
{"role": "user", "content": "Next"},
]
result = LLMProvider._enforce_role_alternation(msgs)
assert result[1]["role"] == "assistant"
assert result[1]["tool_calls"] == [{"id": "1"}]
assert result[1]["content"] is None
assert result[2]["role"] == "tool"
def test_non_string_content_uses_latest(self):
msgs = [
{"role": "user", "content": [{"type": "text", "text": "A"}]},

View File

@ -550,11 +550,40 @@ def test_openai_compat_preserves_message_level_reasoning_fields() -> None:
{"role": "user", "content": "thanks"},
])
assert sanitized[1]["content"] is None
assert sanitized[1]["reasoning_content"] == "hidden"
assert sanitized[1]["extra_content"] == {"debug": True}
assert sanitized[1]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -> None:
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider()
sanitized = provider._sanitize_messages([
{"role": "user", "content": "不错"},
{"role": "assistant", "content": "对,破 4 万指日可待"},
{
"role": "assistant",
"content": "<think>我再查一下</think>",
"tool_calls": [
{
"id": "call_function_akxp3wqzn7ph_1",
"type": "function",
"function": {"name": "exec", "arguments": "{}"},
}
],
},
{"role": "tool", "tool_call_id": "call_function_akxp3wqzn7ph_1", "name": "exec", "content": "ok"},
{"role": "user", "content": "多少star了呢"},
])
assert sanitized[1]["role"] == "assistant"
assert sanitized[1]["content"] is None
assert sanitized[1]["tool_calls"][0]["id"] == "3ec83c30d"
assert sanitized[2]["tool_call_id"] == "3ec83c30d"
@pytest.mark.asyncio
async def test_openai_compat_stream_watchdog_returns_error_on_stall(monkeypatch) -> None:
monkeypatch.setenv("NANOBOT_STREAM_IDLE_TIMEOUT_S", "0")