mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-10 21:23:39 +00:00
fix(agent): preserve interrupted tool-call turns
Keep tool-call assistant messages valid across provider sanitization and avoid trailing user-only history after model errors. This prevents follow-up requests from sending broken tool chains back to the gateway.
This commit is contained in:
parent
c579d67887
commit
2bef9cb650
@ -31,6 +31,7 @@ from nanobot.utils.runtime import (
|
||||
)
|
||||
|
||||
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
|
||||
_PERSISTED_MODEL_ERROR_PLACEHOLDER = "[Assistant reply unavailable due to model error.]"
|
||||
_MAX_EMPTY_RETRIES = 2
|
||||
_MAX_LENGTH_RECOVERIES = 3
|
||||
_SNIP_SAFETY_BUFFER = 1024
|
||||
@ -105,7 +106,8 @@ class AgentRunner:
|
||||
# may repair or compact historical messages for the model, but
|
||||
# those synthetic edits must not shift the append boundary used
|
||||
# later when the caller saves only the new turn.
|
||||
messages_for_model = self._backfill_missing_tool_results(messages)
|
||||
messages_for_model = self._drop_orphan_tool_results(messages)
|
||||
messages_for_model = self._backfill_missing_tool_results(messages_for_model)
|
||||
messages_for_model = self._microcompact(messages_for_model)
|
||||
messages_for_model = self._apply_tool_result_budget(spec, messages_for_model)
|
||||
messages_for_model = self._snip_history(spec, messages_for_model)
|
||||
@ -261,6 +263,7 @@ class AgentRunner:
|
||||
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
|
||||
stop_reason = "error"
|
||||
error = final_content
|
||||
self._append_model_error_placeholder(messages)
|
||||
context.final_content = final_content
|
||||
context.error = error
|
||||
context.stop_reason = stop_reason
|
||||
@ -524,6 +527,12 @@ class AgentRunner:
|
||||
return
|
||||
messages.append(build_assistant_message(content))
|
||||
|
||||
@staticmethod
|
||||
def _append_model_error_placeholder(messages: list[dict[str, Any]]) -> None:
|
||||
if messages and messages[-1].get("role") == "assistant" and not messages[-1].get("tool_calls"):
|
||||
return
|
||||
messages.append(build_assistant_message(_PERSISTED_MODEL_ERROR_PLACEHOLDER))
|
||||
|
||||
def _normalize_tool_result(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
@ -552,6 +561,32 @@ class AgentRunner:
|
||||
return truncate_text(content, spec.max_tool_result_chars)
|
||||
return content
|
||||
|
||||
@staticmethod
|
||||
def _drop_orphan_tool_results(
|
||||
messages: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Drop tool results that have no matching assistant tool_call earlier in the history."""
|
||||
declared: set[str] = set()
|
||||
updated: list[dict[str, Any]] | None = None
|
||||
for idx, msg in enumerate(messages):
|
||||
role = msg.get("role")
|
||||
if role == "assistant":
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
if isinstance(tc, dict) and tc.get("id"):
|
||||
declared.add(str(tc["id"]))
|
||||
if role == "tool":
|
||||
tid = msg.get("tool_call_id")
|
||||
if tid and str(tid) not in declared:
|
||||
if updated is None:
|
||||
updated = [dict(m) for m in messages[:idx]]
|
||||
continue
|
||||
if updated is not None:
|
||||
updated.append(dict(msg))
|
||||
|
||||
if updated is None:
|
||||
return messages
|
||||
return updated
|
||||
|
||||
@staticmethod
|
||||
def _backfill_missing_tool_results(
|
||||
messages: list[dict[str, Any]],
|
||||
|
||||
@ -375,6 +375,14 @@ class LLMProvider(ABC):
|
||||
and role in ("user", "assistant")
|
||||
):
|
||||
prev = merged[-1]
|
||||
if role == "assistant":
|
||||
prev_has_tools = bool(prev.get("tool_calls"))
|
||||
curr_has_tools = bool(msg.get("tool_calls"))
|
||||
if curr_has_tools:
|
||||
merged[-1] = dict(msg)
|
||||
continue
|
||||
if prev_has_tools:
|
||||
continue
|
||||
prev_content = prev.get("content") or ""
|
||||
curr_content = msg.get("content") or ""
|
||||
if isinstance(prev_content, str) and isinstance(curr_content, str):
|
||||
|
||||
@ -243,6 +243,10 @@ class OpenAICompatProvider(LLMProvider):
|
||||
tc_clean["id"] = map_id(tc_clean.get("id"))
|
||||
normalized.append(tc_clean)
|
||||
clean["tool_calls"] = normalized
|
||||
if clean.get("role") == "assistant":
|
||||
# Some OpenAI-compatible gateways reject assistant messages
|
||||
# that mix non-empty content with tool_calls.
|
||||
clean["content"] = None
|
||||
if "tool_call_id" in clean and clean["tool_call_id"]:
|
||||
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
||||
return self._enforce_role_alternation(sanitized)
|
||||
|
||||
@ -859,7 +859,11 @@ async def test_loop_retries_think_only_final_response(tmp_path):
|
||||
async def test_llm_error_not_appended_to_session_messages():
|
||||
"""When LLM returns finish_reason='error', the error content must NOT be
|
||||
appended to the messages list (prevents polluting session history)."""
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.agent.runner import (
|
||||
AgentRunSpec,
|
||||
AgentRunner,
|
||||
_PERSISTED_MODEL_ERROR_PLACEHOLDER,
|
||||
)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
@ -882,6 +886,7 @@ async def test_llm_error_not_appended_to_session_messages():
|
||||
assistant_msgs = [m for m in result.messages if m.get("role") == "assistant"]
|
||||
assert all("429" not in (m.get("content") or "") for m in assistant_msgs), \
|
||||
"Error content should not appear in session messages"
|
||||
assert assistant_msgs[-1]["content"] == _PERSISTED_MODEL_ERROR_PLACEHOLDER
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -918,6 +923,56 @@ async def test_streamed_flag_not_set_on_llm_error(tmp_path):
|
||||
"_streamed must not be set when stop_reason is error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_next_turn_after_llm_error_keeps_turn_boundary(tmp_path):
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.runner import _PERSISTED_MODEL_ERROR_PLACEHOLDER
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.chat_with_retry = AsyncMock(side_effect=[
|
||||
LLMResponse(content="429 rate limit exceeded", finish_reason="error", tool_calls=[], usage={}),
|
||||
LLMResponse(content="Recovered answer", tool_calls=[], usage={}),
|
||||
])
|
||||
|
||||
loop = AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model")
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
|
||||
|
||||
first = await loop._process_message(
|
||||
InboundMessage(channel="cli", sender_id="user", chat_id="test", content="first question")
|
||||
)
|
||||
assert first is not None
|
||||
assert first.content == "429 rate limit exceeded"
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
assert [
|
||||
{key: value for key, value in message.items() if key in {"role", "content"}}
|
||||
for message in session.messages
|
||||
] == [
|
||||
{"role": "user", "content": "first question"},
|
||||
{"role": "assistant", "content": _PERSISTED_MODEL_ERROR_PLACEHOLDER},
|
||||
]
|
||||
|
||||
second = await loop._process_message(
|
||||
InboundMessage(channel="cli", sender_id="user", chat_id="test", content="second question")
|
||||
)
|
||||
assert second is not None
|
||||
assert second.content == "Recovered answer"
|
||||
|
||||
request_messages = provider.chat_with_retry.await_args_list[1].kwargs["messages"]
|
||||
non_system = [message for message in request_messages if message.get("role") != "system"]
|
||||
assert non_system[0] == {"role": "user", "content": "first question"}
|
||||
assert non_system[1] == {
|
||||
"role": "assistant",
|
||||
"content": _PERSISTED_MODEL_ERROR_PLACEHOLDER,
|
||||
}
|
||||
assert non_system[2]["role"] == "user"
|
||||
assert "second question" in non_system[2]["content"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_tool_error_sets_final_content():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
@ -1218,6 +1273,41 @@ async def test_backfill_missing_tool_results_inserts_error():
|
||||
assert backfilled[0]["name"] == "read_file"
|
||||
|
||||
|
||||
def test_drop_orphan_tool_results_removes_unmatched_tool_messages():
|
||||
from nanobot.agent.runner import AgentRunner
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "old user"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{"id": "call_ok", "type": "function", "function": {"name": "read_file", "arguments": "{}"}},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_ok", "name": "read_file", "content": "ok"},
|
||||
{"role": "tool", "tool_call_id": "call_orphan", "name": "exec", "content": "stale"},
|
||||
{"role": "assistant", "content": "after tool"},
|
||||
]
|
||||
|
||||
cleaned = AgentRunner._drop_orphan_tool_results(messages)
|
||||
|
||||
assert cleaned == [
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "old user"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{"id": "call_ok", "type": "function", "function": {"name": "read_file", "arguments": "{}"}},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_ok", "name": "read_file", "content": "ok"},
|
||||
{"role": "assistant", "content": "after tool"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_noop_when_complete():
|
||||
"""Complete message chains should not be modified."""
|
||||
@ -1239,6 +1329,45 @@ async def test_backfill_noop_when_complete():
|
||||
assert result is messages # same object — no copy
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_drops_orphan_tool_results_before_model_request():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
captured_messages: list[dict] = []
|
||||
|
||||
async def chat_with_retry(*, messages, **kwargs):
|
||||
captured_messages[:] = messages
|
||||
return LLMResponse(content="done", tool_calls=[], usage={})
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "old user"},
|
||||
{"role": "tool", "tool_call_id": "call_orphan", "name": "exec", "content": "stale"},
|
||||
{"role": "assistant", "content": "after orphan"},
|
||||
{"role": "user", "content": "new prompt"},
|
||||
],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
))
|
||||
|
||||
assert all(
|
||||
message.get("tool_call_id") != "call_orphan"
|
||||
for message in captured_messages
|
||||
if message.get("role") == "tool"
|
||||
)
|
||||
assert result.messages[2]["tool_call_id"] == "call_orphan"
|
||||
assert result.final_content == "done"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_repairs_model_context_without_shifting_save_turn_boundary(tmp_path):
|
||||
"""Historical backfill should not duplicate old tail messages on persist."""
|
||||
|
||||
@ -84,6 +84,34 @@ class TestEnforceRoleAlternation:
|
||||
tool_msgs = [m for m in result if m["role"] == "tool"]
|
||||
assert len(tool_msgs) == 2
|
||||
|
||||
def test_consecutive_assistant_keeps_later_tool_call_message(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": "Hi"},
|
||||
{"role": "assistant", "content": "Previous reply"},
|
||||
{"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]},
|
||||
{"role": "tool", "content": "result1", "tool_call_id": "1"},
|
||||
{"role": "user", "content": "Next"},
|
||||
]
|
||||
result = LLMProvider._enforce_role_alternation(msgs)
|
||||
assert result[1]["role"] == "assistant"
|
||||
assert result[1]["tool_calls"] == [{"id": "1"}]
|
||||
assert result[1]["content"] is None
|
||||
assert result[2]["role"] == "tool"
|
||||
|
||||
def test_consecutive_assistant_does_not_overwrite_existing_tool_call_message(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": "Hi"},
|
||||
{"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]},
|
||||
{"role": "assistant", "content": "Later plain assistant"},
|
||||
{"role": "tool", "content": "result1", "tool_call_id": "1"},
|
||||
{"role": "user", "content": "Next"},
|
||||
]
|
||||
result = LLMProvider._enforce_role_alternation(msgs)
|
||||
assert result[1]["role"] == "assistant"
|
||||
assert result[1]["tool_calls"] == [{"id": "1"}]
|
||||
assert result[1]["content"] is None
|
||||
assert result[2]["role"] == "tool"
|
||||
|
||||
def test_non_string_content_uses_latest(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": [{"type": "text", "text": "A"}]},
|
||||
|
||||
@ -550,11 +550,40 @@ def test_openai_compat_preserves_message_level_reasoning_fields() -> None:
|
||||
{"role": "user", "content": "thanks"},
|
||||
])
|
||||
|
||||
assert sanitized[1]["content"] is None
|
||||
assert sanitized[1]["reasoning_content"] == "hidden"
|
||||
assert sanitized[1]["extra_content"] == {"debug": True}
|
||||
assert sanitized[1]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
|
||||
|
||||
|
||||
def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -> None:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
sanitized = provider._sanitize_messages([
|
||||
{"role": "user", "content": "不错"},
|
||||
{"role": "assistant", "content": "对,破 4 万指日可待"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "<think>我再查一下</think>",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_function_akxp3wqzn7ph_1",
|
||||
"type": "function",
|
||||
"function": {"name": "exec", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_function_akxp3wqzn7ph_1", "name": "exec", "content": "ok"},
|
||||
{"role": "user", "content": "多少star了呢"},
|
||||
])
|
||||
|
||||
assert sanitized[1]["role"] == "assistant"
|
||||
assert sanitized[1]["content"] is None
|
||||
assert sanitized[1]["tool_calls"][0]["id"] == "3ec83c30d"
|
||||
assert sanitized[2]["tool_call_id"] == "3ec83c30d"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_compat_stream_watchdog_returns_error_on_stall(monkeypatch) -> None:
|
||||
monkeypatch.setenv("NANOBOT_STREAM_IDLE_TIMEOUT_S", "0")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user