mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-26 04:45:59 +00:00
fix(agent): preserve interrupted tool-call turns
Keep tool-call assistant messages valid across provider sanitization and avoid trailing user-only history after model errors. This prevents follow-up requests from sending broken tool chains back to the gateway.
This commit is contained in:
parent
c579d67887
commit
2bef9cb650
@ -31,6 +31,7 @@ from nanobot.utils.runtime import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
|
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
|
||||||
|
_PERSISTED_MODEL_ERROR_PLACEHOLDER = "[Assistant reply unavailable due to model error.]"
|
||||||
_MAX_EMPTY_RETRIES = 2
|
_MAX_EMPTY_RETRIES = 2
|
||||||
_MAX_LENGTH_RECOVERIES = 3
|
_MAX_LENGTH_RECOVERIES = 3
|
||||||
_SNIP_SAFETY_BUFFER = 1024
|
_SNIP_SAFETY_BUFFER = 1024
|
||||||
@ -105,7 +106,8 @@ class AgentRunner:
|
|||||||
# may repair or compact historical messages for the model, but
|
# may repair or compact historical messages for the model, but
|
||||||
# those synthetic edits must not shift the append boundary used
|
# those synthetic edits must not shift the append boundary used
|
||||||
# later when the caller saves only the new turn.
|
# later when the caller saves only the new turn.
|
||||||
messages_for_model = self._backfill_missing_tool_results(messages)
|
messages_for_model = self._drop_orphan_tool_results(messages)
|
||||||
|
messages_for_model = self._backfill_missing_tool_results(messages_for_model)
|
||||||
messages_for_model = self._microcompact(messages_for_model)
|
messages_for_model = self._microcompact(messages_for_model)
|
||||||
messages_for_model = self._apply_tool_result_budget(spec, messages_for_model)
|
messages_for_model = self._apply_tool_result_budget(spec, messages_for_model)
|
||||||
messages_for_model = self._snip_history(spec, messages_for_model)
|
messages_for_model = self._snip_history(spec, messages_for_model)
|
||||||
@ -261,6 +263,7 @@ class AgentRunner:
|
|||||||
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
|
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
|
||||||
stop_reason = "error"
|
stop_reason = "error"
|
||||||
error = final_content
|
error = final_content
|
||||||
|
self._append_model_error_placeholder(messages)
|
||||||
context.final_content = final_content
|
context.final_content = final_content
|
||||||
context.error = error
|
context.error = error
|
||||||
context.stop_reason = stop_reason
|
context.stop_reason = stop_reason
|
||||||
@ -524,6 +527,12 @@ class AgentRunner:
|
|||||||
return
|
return
|
||||||
messages.append(build_assistant_message(content))
|
messages.append(build_assistant_message(content))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _append_model_error_placeholder(messages: list[dict[str, Any]]) -> None:
|
||||||
|
if messages and messages[-1].get("role") == "assistant" and not messages[-1].get("tool_calls"):
|
||||||
|
return
|
||||||
|
messages.append(build_assistant_message(_PERSISTED_MODEL_ERROR_PLACEHOLDER))
|
||||||
|
|
||||||
def _normalize_tool_result(
|
def _normalize_tool_result(
|
||||||
self,
|
self,
|
||||||
spec: AgentRunSpec,
|
spec: AgentRunSpec,
|
||||||
@ -552,6 +561,32 @@ class AgentRunner:
|
|||||||
return truncate_text(content, spec.max_tool_result_chars)
|
return truncate_text(content, spec.max_tool_result_chars)
|
||||||
return content
|
return content
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _drop_orphan_tool_results(
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Drop tool results that have no matching assistant tool_call earlier in the history."""
|
||||||
|
declared: set[str] = set()
|
||||||
|
updated: list[dict[str, Any]] | None = None
|
||||||
|
for idx, msg in enumerate(messages):
|
||||||
|
role = msg.get("role")
|
||||||
|
if role == "assistant":
|
||||||
|
for tc in msg.get("tool_calls") or []:
|
||||||
|
if isinstance(tc, dict) and tc.get("id"):
|
||||||
|
declared.add(str(tc["id"]))
|
||||||
|
if role == "tool":
|
||||||
|
tid = msg.get("tool_call_id")
|
||||||
|
if tid and str(tid) not in declared:
|
||||||
|
if updated is None:
|
||||||
|
updated = [dict(m) for m in messages[:idx]]
|
||||||
|
continue
|
||||||
|
if updated is not None:
|
||||||
|
updated.append(dict(msg))
|
||||||
|
|
||||||
|
if updated is None:
|
||||||
|
return messages
|
||||||
|
return updated
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _backfill_missing_tool_results(
|
def _backfill_missing_tool_results(
|
||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
|
|||||||
@ -375,6 +375,14 @@ class LLMProvider(ABC):
|
|||||||
and role in ("user", "assistant")
|
and role in ("user", "assistant")
|
||||||
):
|
):
|
||||||
prev = merged[-1]
|
prev = merged[-1]
|
||||||
|
if role == "assistant":
|
||||||
|
prev_has_tools = bool(prev.get("tool_calls"))
|
||||||
|
curr_has_tools = bool(msg.get("tool_calls"))
|
||||||
|
if curr_has_tools:
|
||||||
|
merged[-1] = dict(msg)
|
||||||
|
continue
|
||||||
|
if prev_has_tools:
|
||||||
|
continue
|
||||||
prev_content = prev.get("content") or ""
|
prev_content = prev.get("content") or ""
|
||||||
curr_content = msg.get("content") or ""
|
curr_content = msg.get("content") or ""
|
||||||
if isinstance(prev_content, str) and isinstance(curr_content, str):
|
if isinstance(prev_content, str) and isinstance(curr_content, str):
|
||||||
|
|||||||
@ -243,6 +243,10 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
tc_clean["id"] = map_id(tc_clean.get("id"))
|
tc_clean["id"] = map_id(tc_clean.get("id"))
|
||||||
normalized.append(tc_clean)
|
normalized.append(tc_clean)
|
||||||
clean["tool_calls"] = normalized
|
clean["tool_calls"] = normalized
|
||||||
|
if clean.get("role") == "assistant":
|
||||||
|
# Some OpenAI-compatible gateways reject assistant messages
|
||||||
|
# that mix non-empty content with tool_calls.
|
||||||
|
clean["content"] = None
|
||||||
if "tool_call_id" in clean and clean["tool_call_id"]:
|
if "tool_call_id" in clean and clean["tool_call_id"]:
|
||||||
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
||||||
return self._enforce_role_alternation(sanitized)
|
return self._enforce_role_alternation(sanitized)
|
||||||
|
|||||||
@ -859,7 +859,11 @@ async def test_loop_retries_think_only_final_response(tmp_path):
|
|||||||
async def test_llm_error_not_appended_to_session_messages():
|
async def test_llm_error_not_appended_to_session_messages():
|
||||||
"""When LLM returns finish_reason='error', the error content must NOT be
|
"""When LLM returns finish_reason='error', the error content must NOT be
|
||||||
appended to the messages list (prevents polluting session history)."""
|
appended to the messages list (prevents polluting session history)."""
|
||||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
from nanobot.agent.runner import (
|
||||||
|
AgentRunSpec,
|
||||||
|
AgentRunner,
|
||||||
|
_PERSISTED_MODEL_ERROR_PLACEHOLDER,
|
||||||
|
)
|
||||||
|
|
||||||
provider = MagicMock()
|
provider = MagicMock()
|
||||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||||
@ -882,6 +886,7 @@ async def test_llm_error_not_appended_to_session_messages():
|
|||||||
assistant_msgs = [m for m in result.messages if m.get("role") == "assistant"]
|
assistant_msgs = [m for m in result.messages if m.get("role") == "assistant"]
|
||||||
assert all("429" not in (m.get("content") or "") for m in assistant_msgs), \
|
assert all("429" not in (m.get("content") or "") for m in assistant_msgs), \
|
||||||
"Error content should not appear in session messages"
|
"Error content should not appear in session messages"
|
||||||
|
assert assistant_msgs[-1]["content"] == _PERSISTED_MODEL_ERROR_PLACEHOLDER
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -918,6 +923,56 @@ async def test_streamed_flag_not_set_on_llm_error(tmp_path):
|
|||||||
"_streamed must not be set when stop_reason is error"
|
"_streamed must not be set when stop_reason is error"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_next_turn_after_llm_error_keeps_turn_boundary(tmp_path):
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.agent.runner import _PERSISTED_MODEL_ERROR_PLACEHOLDER
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
provider.chat_with_retry = AsyncMock(side_effect=[
|
||||||
|
LLMResponse(content="429 rate limit exceeded", finish_reason="error", tool_calls=[], usage={}),
|
||||||
|
LLMResponse(content="Recovered answer", tool_calls=[], usage={}),
|
||||||
|
])
|
||||||
|
|
||||||
|
loop = AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model")
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
|
||||||
|
|
||||||
|
first = await loop._process_message(
|
||||||
|
InboundMessage(channel="cli", sender_id="user", chat_id="test", content="first question")
|
||||||
|
)
|
||||||
|
assert first is not None
|
||||||
|
assert first.content == "429 rate limit exceeded"
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
assert [
|
||||||
|
{key: value for key, value in message.items() if key in {"role", "content"}}
|
||||||
|
for message in session.messages
|
||||||
|
] == [
|
||||||
|
{"role": "user", "content": "first question"},
|
||||||
|
{"role": "assistant", "content": _PERSISTED_MODEL_ERROR_PLACEHOLDER},
|
||||||
|
]
|
||||||
|
|
||||||
|
second = await loop._process_message(
|
||||||
|
InboundMessage(channel="cli", sender_id="user", chat_id="test", content="second question")
|
||||||
|
)
|
||||||
|
assert second is not None
|
||||||
|
assert second.content == "Recovered answer"
|
||||||
|
|
||||||
|
request_messages = provider.chat_with_retry.await_args_list[1].kwargs["messages"]
|
||||||
|
non_system = [message for message in request_messages if message.get("role") != "system"]
|
||||||
|
assert non_system[0] == {"role": "user", "content": "first question"}
|
||||||
|
assert non_system[1] == {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": _PERSISTED_MODEL_ERROR_PLACEHOLDER,
|
||||||
|
}
|
||||||
|
assert non_system[2]["role"] == "user"
|
||||||
|
assert "second question" in non_system[2]["content"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_runner_tool_error_sets_final_content():
|
async def test_runner_tool_error_sets_final_content():
|
||||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||||
@ -1218,6 +1273,41 @@ async def test_backfill_missing_tool_results_inserts_error():
|
|||||||
assert backfilled[0]["name"] == "read_file"
|
assert backfilled[0]["name"] == "read_file"
|
||||||
|
|
||||||
|
|
||||||
|
def test_drop_orphan_tool_results_removes_unmatched_tool_messages():
|
||||||
|
from nanobot.agent.runner import AgentRunner
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "system"},
|
||||||
|
{"role": "user", "content": "old user"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
"tool_calls": [
|
||||||
|
{"id": "call_ok", "type": "function", "function": {"name": "read_file", "arguments": "{}"}},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{"role": "tool", "tool_call_id": "call_ok", "name": "read_file", "content": "ok"},
|
||||||
|
{"role": "tool", "tool_call_id": "call_orphan", "name": "exec", "content": "stale"},
|
||||||
|
{"role": "assistant", "content": "after tool"},
|
||||||
|
]
|
||||||
|
|
||||||
|
cleaned = AgentRunner._drop_orphan_tool_results(messages)
|
||||||
|
|
||||||
|
assert cleaned == [
|
||||||
|
{"role": "system", "content": "system"},
|
||||||
|
{"role": "user", "content": "old user"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
"tool_calls": [
|
||||||
|
{"id": "call_ok", "type": "function", "function": {"name": "read_file", "arguments": "{}"}},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{"role": "tool", "tool_call_id": "call_ok", "name": "read_file", "content": "ok"},
|
||||||
|
{"role": "assistant", "content": "after tool"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_backfill_noop_when_complete():
|
async def test_backfill_noop_when_complete():
|
||||||
"""Complete message chains should not be modified."""
|
"""Complete message chains should not be modified."""
|
||||||
@ -1239,6 +1329,45 @@ async def test_backfill_noop_when_complete():
|
|||||||
assert result is messages # same object — no copy
|
assert result is messages # same object — no copy
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runner_drops_orphan_tool_results_before_model_request():
|
||||||
|
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||||
|
|
||||||
|
provider = MagicMock()
|
||||||
|
captured_messages: list[dict] = []
|
||||||
|
|
||||||
|
async def chat_with_retry(*, messages, **kwargs):
|
||||||
|
captured_messages[:] = messages
|
||||||
|
return LLMResponse(content="done", tool_calls=[], usage={})
|
||||||
|
|
||||||
|
provider.chat_with_retry = chat_with_retry
|
||||||
|
tools = MagicMock()
|
||||||
|
tools.get_definitions.return_value = []
|
||||||
|
|
||||||
|
runner = AgentRunner(provider)
|
||||||
|
result = await runner.run(AgentRunSpec(
|
||||||
|
initial_messages=[
|
||||||
|
{"role": "system", "content": "system"},
|
||||||
|
{"role": "user", "content": "old user"},
|
||||||
|
{"role": "tool", "tool_call_id": "call_orphan", "name": "exec", "content": "stale"},
|
||||||
|
{"role": "assistant", "content": "after orphan"},
|
||||||
|
{"role": "user", "content": "new prompt"},
|
||||||
|
],
|
||||||
|
tools=tools,
|
||||||
|
model="test-model",
|
||||||
|
max_iterations=1,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
))
|
||||||
|
|
||||||
|
assert all(
|
||||||
|
message.get("tool_call_id") != "call_orphan"
|
||||||
|
for message in captured_messages
|
||||||
|
if message.get("role") == "tool"
|
||||||
|
)
|
||||||
|
assert result.messages[2]["tool_call_id"] == "call_orphan"
|
||||||
|
assert result.final_content == "done"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_backfill_repairs_model_context_without_shifting_save_turn_boundary(tmp_path):
|
async def test_backfill_repairs_model_context_without_shifting_save_turn_boundary(tmp_path):
|
||||||
"""Historical backfill should not duplicate old tail messages on persist."""
|
"""Historical backfill should not duplicate old tail messages on persist."""
|
||||||
|
|||||||
@ -84,6 +84,34 @@ class TestEnforceRoleAlternation:
|
|||||||
tool_msgs = [m for m in result if m["role"] == "tool"]
|
tool_msgs = [m for m in result if m["role"] == "tool"]
|
||||||
assert len(tool_msgs) == 2
|
assert len(tool_msgs) == 2
|
||||||
|
|
||||||
|
def test_consecutive_assistant_keeps_later_tool_call_message(self):
|
||||||
|
msgs = [
|
||||||
|
{"role": "user", "content": "Hi"},
|
||||||
|
{"role": "assistant", "content": "Previous reply"},
|
||||||
|
{"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]},
|
||||||
|
{"role": "tool", "content": "result1", "tool_call_id": "1"},
|
||||||
|
{"role": "user", "content": "Next"},
|
||||||
|
]
|
||||||
|
result = LLMProvider._enforce_role_alternation(msgs)
|
||||||
|
assert result[1]["role"] == "assistant"
|
||||||
|
assert result[1]["tool_calls"] == [{"id": "1"}]
|
||||||
|
assert result[1]["content"] is None
|
||||||
|
assert result[2]["role"] == "tool"
|
||||||
|
|
||||||
|
def test_consecutive_assistant_does_not_overwrite_existing_tool_call_message(self):
|
||||||
|
msgs = [
|
||||||
|
{"role": "user", "content": "Hi"},
|
||||||
|
{"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]},
|
||||||
|
{"role": "assistant", "content": "Later plain assistant"},
|
||||||
|
{"role": "tool", "content": "result1", "tool_call_id": "1"},
|
||||||
|
{"role": "user", "content": "Next"},
|
||||||
|
]
|
||||||
|
result = LLMProvider._enforce_role_alternation(msgs)
|
||||||
|
assert result[1]["role"] == "assistant"
|
||||||
|
assert result[1]["tool_calls"] == [{"id": "1"}]
|
||||||
|
assert result[1]["content"] is None
|
||||||
|
assert result[2]["role"] == "tool"
|
||||||
|
|
||||||
def test_non_string_content_uses_latest(self):
|
def test_non_string_content_uses_latest(self):
|
||||||
msgs = [
|
msgs = [
|
||||||
{"role": "user", "content": [{"type": "text", "text": "A"}]},
|
{"role": "user", "content": [{"type": "text", "text": "A"}]},
|
||||||
|
|||||||
@ -550,11 +550,40 @@ def test_openai_compat_preserves_message_level_reasoning_fields() -> None:
|
|||||||
{"role": "user", "content": "thanks"},
|
{"role": "user", "content": "thanks"},
|
||||||
])
|
])
|
||||||
|
|
||||||
|
assert sanitized[1]["content"] is None
|
||||||
assert sanitized[1]["reasoning_content"] == "hidden"
|
assert sanitized[1]["reasoning_content"] == "hidden"
|
||||||
assert sanitized[1]["extra_content"] == {"debug": True}
|
assert sanitized[1]["extra_content"] == {"debug": True}
|
||||||
assert sanitized[1]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
|
assert sanitized[1]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -> None:
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
provider = OpenAICompatProvider()
|
||||||
|
|
||||||
|
sanitized = provider._sanitize_messages([
|
||||||
|
{"role": "user", "content": "不错"},
|
||||||
|
{"role": "assistant", "content": "对,破 4 万指日可待"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "<think>我再查一下</think>",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_function_akxp3wqzn7ph_1",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "exec", "arguments": "{}"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{"role": "tool", "tool_call_id": "call_function_akxp3wqzn7ph_1", "name": "exec", "content": "ok"},
|
||||||
|
{"role": "user", "content": "多少star了呢"},
|
||||||
|
])
|
||||||
|
|
||||||
|
assert sanitized[1]["role"] == "assistant"
|
||||||
|
assert sanitized[1]["content"] is None
|
||||||
|
assert sanitized[1]["tool_calls"][0]["id"] == "3ec83c30d"
|
||||||
|
assert sanitized[2]["tool_call_id"] == "3ec83c30d"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_openai_compat_stream_watchdog_returns_error_on_stall(monkeypatch) -> None:
|
async def test_openai_compat_stream_watchdog_returns_error_on_stall(monkeypatch) -> None:
|
||||||
monkeypatch.setenv("NANOBOT_STREAM_IDLE_TIMEOUT_S", "0")
|
monkeypatch.setenv("NANOBOT_STREAM_IDLE_TIMEOUT_S", "0")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user