fix(agent): preserve interrupted tool-call turns

Keep tool-call assistant messages valid across provider sanitization and avoid trailing user-only history after model errors. This prevents follow-up requests from sending broken tool chains back to the gateway.
Author: Xubin Ren
Date: 2026-04-10 05:37:25 +00:00
Parent: c579d67887
Commit: 2bef9cb650
6 changed files with 235 additions and 2 deletions
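For review context, the failure mode in miniature: an interrupted turn can persist a tool result whose declaring assistant message (the one carrying tool_calls) was never saved, and strict gateways reject such a history. The helper below is a standalone paraphrase of the new _drop_orphan_tool_results, not the shipped code:

history = [
    {"role": "user", "content": "run the script"},
    # Orphan: no earlier assistant message declared tool_call_id "call_orphan".
    {"role": "tool", "tool_call_id": "call_orphan", "content": "stale"},
    {"role": "user", "content": "next question"},
]

def drop_orphan_tool_results(messages):
    declared, kept = set(), []
    for msg in messages:
        for tc in msg.get("tool_calls") or []:
            declared.add(tc.get("id"))
        if msg.get("role") == "tool" and msg.get("tool_call_id") not in declared:
            continue  # drop results that no assistant turn ever declared
        kept.append(msg)
    return kept

assert drop_orphan_tool_results(history) == [history[0], history[2]]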

nanobot/agent/runner.py

@@ -31,6 +31,7 @@ from nanobot.utils.runtime import (
 )
 _DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
+_PERSISTED_MODEL_ERROR_PLACEHOLDER = "[Assistant reply unavailable due to model error.]"
 _MAX_EMPTY_RETRIES = 2
 _MAX_LENGTH_RECOVERIES = 3
 _SNIP_SAFETY_BUFFER = 1024
@@ -105,7 +106,8 @@ class AgentRunner:
         # may repair or compact historical messages for the model, but
         # those synthetic edits must not shift the append boundary used
         # later when the caller saves only the new turn.
-        messages_for_model = self._backfill_missing_tool_results(messages)
+        messages_for_model = self._drop_orphan_tool_results(messages)
+        messages_for_model = self._backfill_missing_tool_results(messages_for_model)
         messages_for_model = self._microcompact(messages_for_model)
         messages_for_model = self._apply_tool_result_budget(spec, messages_for_model)
         messages_for_model = self._snip_history(spec, messages_for_model)
@@ -261,6 +263,7 @@ class AgentRunner:
             final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
             stop_reason = "error"
             error = final_content
+            self._append_model_error_placeholder(messages)
         context.final_content = final_content
         context.error = error
         context.stop_reason = stop_reason
@@ -524,6 +527,12 @@ class AgentRunner:
             return
         messages.append(build_assistant_message(content))

+    @staticmethod
+    def _append_model_error_placeholder(messages: list[dict[str, Any]]) -> None:
+        if messages and messages[-1].get("role") == "assistant" and not messages[-1].get("tool_calls"):
+            return
+        messages.append(build_assistant_message(_PERSISTED_MODEL_ERROR_PLACEHOLDER))
+
     def _normalize_tool_result(
         self,
         spec: AgentRunSpec,
@@ -552,6 +561,32 @@ class AgentRunner:
             return truncate_text(content, spec.max_tool_result_chars)
         return content

+    @staticmethod
+    def _drop_orphan_tool_results(
+        messages: list[dict[str, Any]],
+    ) -> list[dict[str, Any]]:
+        """Drop tool results that have no matching assistant tool_call earlier in the history."""
+        declared: set[str] = set()
+        updated: list[dict[str, Any]] | None = None
+        for idx, msg in enumerate(messages):
+            role = msg.get("role")
+            if role == "assistant":
+                for tc in msg.get("tool_calls") or []:
+                    if isinstance(tc, dict) and tc.get("id"):
+                        declared.add(str(tc["id"]))
+            if role == "tool":
+                tid = msg.get("tool_call_id")
+                if tid and str(tid) not in declared:
+                    if updated is None:
+                        updated = [dict(m) for m in messages[:idx]]
+                    continue
+            if updated is not None:
+                updated.append(dict(msg))
+        if updated is None:
+            return messages
+        return updated
+
     @staticmethod
     def _backfill_missing_tool_results(
         messages: list[dict[str, Any]],
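Taken together, the runner changes guarantee that a persisted turn always closes with an assistant message. A rough sketch of the placeholder invariant, assuming build_assistant_message(c) simply returns {"role": "assistant", "content": c}:

PLACEHOLDER = "[Assistant reply unavailable due to model error.]"

def append_model_error_placeholder(messages):
    last = messages[-1] if messages else None
    if last and last.get("role") == "assistant" and not last.get("tool_calls"):
        return  # turn already ends with a plain assistant reply
    # Close the turn so saved history never ends user-only (or on a
    # dangling tool-call turn) after a model error.
    messages.append({"role": "assistant", "content": PLACEHOLDER})

history = [{"role": "user", "content": "first question"}]
append_model_error_placeholder(history)
assert history[-1]["role"] == "assistant"

The next request then resumes from a clean user/assistant boundary, which is exactly what test_next_turn_after_llm_error_keeps_turn_boundary below asserts.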


@@ -375,6 +375,14 @@ class LLMProvider(ABC):
                 and role in ("user", "assistant")
             ):
                 prev = merged[-1]
+                if role == "assistant":
+                    prev_has_tools = bool(prev.get("tool_calls"))
+                    curr_has_tools = bool(msg.get("tool_calls"))
+                    if curr_has_tools:
+                        merged[-1] = dict(msg)
+                        continue
+                    if prev_has_tools:
+                        continue
                 prev_content = prev.get("content") or ""
                 curr_content = msg.get("content") or ""
                 if isinstance(prev_content, str) and isinstance(curr_content, str):
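The new precedence rule for merging consecutive assistant messages, reduced to a sketch (the real method also handles non-string content and the user role): a tool-call-bearing message survives intact regardless of order, so the tool message that follows keeps a matching tool_call_id.

def merge_consecutive_assistants(prev, curr):
    if curr.get("tool_calls"):
        return dict(curr)  # a later tool-call turn replaces the plain reply
    if prev.get("tool_calls"):
        return prev  # an existing tool-call turn is never overwritten
    merged = dict(prev)  # neither carries tool_calls: concatenate as before
    merged["content"] = (prev.get("content") or "") + (curr.get("content") or "")
    return merged

Either ordering of the pair collapses to the single tool-call message, which the two new TestEnforceRoleAlternation cases below pin down.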

nanobot/providers/openai_compat_provider.py

@@ -243,6 +243,10 @@ class OpenAICompatProvider(LLMProvider):
                     tc_clean["id"] = map_id(tc_clean.get("id"))
                     normalized.append(tc_clean)
                 clean["tool_calls"] = normalized
+                if clean.get("role") == "assistant":
+                    # Some OpenAI-compatible gateways reject assistant messages
+                    # that mix non-empty content with tool_calls.
+                    clean["content"] = None
             if "tool_call_id" in clean and clean["tool_call_id"]:
                 clean["tool_call_id"] = map_id(clean["tool_call_id"])
         return self._enforce_role_alternation(sanitized)
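Sketched effect on an outgoing assistant turn (the tool-call id remapping via map_id is left out; values here are illustrative):

before = {
    "role": "assistant",
    "content": "<think>let me check</think>",  # narration alongside tool calls
    "tool_calls": [{"id": "call_1", "type": "function",
                    "function": {"name": "exec", "arguments": "{}"}}],
}
after = {**before, "content": None}  # shape the gateway now receives

Only the interim narration is dropped; the tool calls themselves, and message-level fields such as reasoning_content, pass through unchanged (see the updated provider tests).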


@@ -859,7 +859,11 @@ async def test_loop_retries_think_only_final_response(tmp_path):
 async def test_llm_error_not_appended_to_session_messages():
     """When LLM returns finish_reason='error', the error content must NOT be
     appended to the messages list (prevents polluting session history)."""
-    from nanobot.agent.runner import AgentRunSpec, AgentRunner
+    from nanobot.agent.runner import (
+        AgentRunSpec,
+        AgentRunner,
+        _PERSISTED_MODEL_ERROR_PLACEHOLDER,
+    )

     provider = MagicMock()
     provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
@@ -882,6 +886,7 @@ async def test_llm_error_not_appended_to_session_messages():
     assistant_msgs = [m for m in result.messages if m.get("role") == "assistant"]
     assert all("429" not in (m.get("content") or "") for m in assistant_msgs), \
         "Error content should not appear in session messages"
+    assert assistant_msgs[-1]["content"] == _PERSISTED_MODEL_ERROR_PLACEHOLDER


 @pytest.mark.asyncio
@@ -918,6 +923,56 @@ async def test_streamed_flag_not_set_on_llm_error(tmp_path):
         "_streamed must not be set when stop_reason is error"


+@pytest.mark.asyncio
+async def test_next_turn_after_llm_error_keeps_turn_boundary(tmp_path):
+    from nanobot.agent.loop import AgentLoop
+    from nanobot.agent.runner import _PERSISTED_MODEL_ERROR_PLACEHOLDER
+    from nanobot.bus.events import InboundMessage
+    from nanobot.bus.queue import MessageBus
+
+    provider = MagicMock()
+    provider.get_default_model.return_value = "test-model"
+    provider.chat_with_retry = AsyncMock(side_effect=[
+        LLMResponse(content="429 rate limit exceeded", finish_reason="error", tool_calls=[], usage={}),
+        LLMResponse(content="Recovered answer", tool_calls=[], usage={}),
+    ])
+
+    loop = AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model")
+    loop.tools.get_definitions = MagicMock(return_value=[])
+    loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False)  # type: ignore[method-assign]
+
+    first = await loop._process_message(
+        InboundMessage(channel="cli", sender_id="user", chat_id="test", content="first question")
+    )
+    assert first is not None
+    assert first.content == "429 rate limit exceeded"
+
+    session = loop.sessions.get_or_create("cli:test")
+    assert [
+        {key: value for key, value in message.items() if key in {"role", "content"}}
+        for message in session.messages
+    ] == [
+        {"role": "user", "content": "first question"},
+        {"role": "assistant", "content": _PERSISTED_MODEL_ERROR_PLACEHOLDER},
+    ]
+
+    second = await loop._process_message(
+        InboundMessage(channel="cli", sender_id="user", chat_id="test", content="second question")
+    )
+    assert second is not None
+    assert second.content == "Recovered answer"
+
+    request_messages = provider.chat_with_retry.await_args_list[1].kwargs["messages"]
+    non_system = [message for message in request_messages if message.get("role") != "system"]
+    assert non_system[0] == {"role": "user", "content": "first question"}
+    assert non_system[1] == {
+        "role": "assistant",
+        "content": _PERSISTED_MODEL_ERROR_PLACEHOLDER,
+    }
+    assert non_system[2]["role"] == "user"
+    assert "second question" in non_system[2]["content"]
+
+
 @pytest.mark.asyncio
 async def test_runner_tool_error_sets_final_content():
     from nanobot.agent.runner import AgentRunSpec, AgentRunner
@@ -1218,6 +1273,41 @@ async def test_backfill_missing_tool_results_inserts_error():
     assert backfilled[0]["name"] == "read_file"


+def test_drop_orphan_tool_results_removes_unmatched_tool_messages():
+    from nanobot.agent.runner import AgentRunner
+
+    messages = [
+        {"role": "system", "content": "system"},
+        {"role": "user", "content": "old user"},
+        {
+            "role": "assistant",
+            "content": "",
+            "tool_calls": [
+                {"id": "call_ok", "type": "function", "function": {"name": "read_file", "arguments": "{}"}},
+            ],
+        },
+        {"role": "tool", "tool_call_id": "call_ok", "name": "read_file", "content": "ok"},
+        {"role": "tool", "tool_call_id": "call_orphan", "name": "exec", "content": "stale"},
+        {"role": "assistant", "content": "after tool"},
+    ]
+
+    cleaned = AgentRunner._drop_orphan_tool_results(messages)
+
+    assert cleaned == [
+        {"role": "system", "content": "system"},
+        {"role": "user", "content": "old user"},
+        {
+            "role": "assistant",
+            "content": "",
+            "tool_calls": [
+                {"id": "call_ok", "type": "function", "function": {"name": "read_file", "arguments": "{}"}},
+            ],
+        },
+        {"role": "tool", "tool_call_id": "call_ok", "name": "read_file", "content": "ok"},
+        {"role": "assistant", "content": "after tool"},
+    ]
+
+
 @pytest.mark.asyncio
 async def test_backfill_noop_when_complete():
     """Complete message chains should not be modified."""
@@ -1239,6 +1329,45 @@ async def test_backfill_noop_when_complete():
     assert result is messages  # same object — no copy


+@pytest.mark.asyncio
+async def test_runner_drops_orphan_tool_results_before_model_request():
+    from nanobot.agent.runner import AgentRunSpec, AgentRunner
+
+    provider = MagicMock()
+    captured_messages: list[dict] = []
+
+    async def chat_with_retry(*, messages, **kwargs):
+        captured_messages[:] = messages
+        return LLMResponse(content="done", tool_calls=[], usage={})
+
+    provider.chat_with_retry = chat_with_retry
+    tools = MagicMock()
+    tools.get_definitions.return_value = []
+
+    runner = AgentRunner(provider)
+    result = await runner.run(AgentRunSpec(
+        initial_messages=[
+            {"role": "system", "content": "system"},
+            {"role": "user", "content": "old user"},
+            {"role": "tool", "tool_call_id": "call_orphan", "name": "exec", "content": "stale"},
+            {"role": "assistant", "content": "after orphan"},
+            {"role": "user", "content": "new prompt"},
+        ],
+        tools=tools,
+        model="test-model",
+        max_iterations=1,
+        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
+    ))
+
+    assert all(
+        message.get("tool_call_id") != "call_orphan"
+        for message in captured_messages
+        if message.get("role") == "tool"
+    )
+    assert result.messages[2]["tool_call_id"] == "call_orphan"
+    assert result.final_content == "done"
+
+
 @pytest.mark.asyncio
 async def test_backfill_repairs_model_context_without_shifting_save_turn_boundary(tmp_path):
     """Historical backfill should not duplicate old tail messages on persist."""


@@ -84,6 +84,34 @@ class TestEnforceRoleAlternation:
         tool_msgs = [m for m in result if m["role"] == "tool"]
         assert len(tool_msgs) == 2

+    def test_consecutive_assistant_keeps_later_tool_call_message(self):
+        msgs = [
+            {"role": "user", "content": "Hi"},
+            {"role": "assistant", "content": "Previous reply"},
+            {"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]},
+            {"role": "tool", "content": "result1", "tool_call_id": "1"},
+            {"role": "user", "content": "Next"},
+        ]
+        result = LLMProvider._enforce_role_alternation(msgs)
+        assert result[1]["role"] == "assistant"
+        assert result[1]["tool_calls"] == [{"id": "1"}]
+        assert result[1]["content"] is None
+        assert result[2]["role"] == "tool"
+
+    def test_consecutive_assistant_does_not_overwrite_existing_tool_call_message(self):
+        msgs = [
+            {"role": "user", "content": "Hi"},
+            {"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]},
+            {"role": "assistant", "content": "Later plain assistant"},
+            {"role": "tool", "content": "result1", "tool_call_id": "1"},
+            {"role": "user", "content": "Next"},
+        ]
+        result = LLMProvider._enforce_role_alternation(msgs)
+        assert result[1]["role"] == "assistant"
+        assert result[1]["tool_calls"] == [{"id": "1"}]
+        assert result[1]["content"] is None
+        assert result[2]["role"] == "tool"
+
     def test_non_string_content_uses_latest(self):
         msgs = [
             {"role": "user", "content": [{"type": "text", "text": "A"}]},


@@ -550,11 +550,40 @@ def test_openai_compat_preserves_message_level_reasoning_fields() -> None:
         {"role": "user", "content": "thanks"},
     ])

+    assert sanitized[1]["content"] is None
     assert sanitized[1]["reasoning_content"] == "hidden"
     assert sanitized[1]["extra_content"] == {"debug": True}
     assert sanitized[1]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}


+def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -> None:
+    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
+        provider = OpenAICompatProvider()
+
+    sanitized = provider._sanitize_messages([
+        {"role": "user", "content": "不错"},
+        {"role": "assistant", "content": "对,破 4 万指日可待"},
+        {
+            "role": "assistant",
+            "content": "<think>我再查一下</think>",
+            "tool_calls": [
+                {
+                    "id": "call_function_akxp3wqzn7ph_1",
+                    "type": "function",
+                    "function": {"name": "exec", "arguments": "{}"},
+                }
+            ],
+        },
+        {"role": "tool", "tool_call_id": "call_function_akxp3wqzn7ph_1", "name": "exec", "content": "ok"},
+        {"role": "user", "content": "多少star了呢"},
+    ])
+
+    assert sanitized[1]["role"] == "assistant"
+    assert sanitized[1]["content"] is None
+    assert sanitized[1]["tool_calls"][0]["id"] == "3ec83c30d"
+    assert sanitized[2]["tool_call_id"] == "3ec83c30d"
+
+
 @pytest.mark.asyncio
 async def test_openai_compat_stream_watchdog_returns_error_on_stall(monkeypatch) -> None:
     monkeypatch.setenv("NANOBOT_STREAM_IDLE_TIMEOUT_S", "0")