fix(provider): dedupe repeated tool ids in history

This commit is contained in:
Xubin Ren 2026-05-21 15:33:49 +08:00
parent d29fcaf5d1
commit 23d5148a57
2 changed files with 74 additions and 3 deletions

View File

@ -11,6 +11,7 @@ import secrets
import string import string
import time import time
import uuid import uuid
from collections import deque
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from ipaddress import ip_address from ipaddress import ip_address
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
@ -463,6 +464,7 @@ class OpenAICompatProvider(LLMProvider):
"""Strip non-standard keys, normalize tool_call IDs.""" """Strip non-standard keys, normalize tool_call IDs."""
sanitized = LLMProvider._sanitize_request_messages(messages, _ALLOWED_MSG_KEYS) sanitized = LLMProvider._sanitize_request_messages(messages, _ALLOWED_MSG_KEYS)
id_map: dict[str, str] = {} id_map: dict[str, str] = {}
pending_tool_ids: dict[str, deque[str]] = {}
force_string_content = bool(self._spec and self._spec.name == "deepseek") force_string_content = bool(self._spec and self._spec.name == "deepseek")
def map_id(value: Any) -> Any: def map_id(value: Any) -> Any:
@ -470,15 +472,49 @@ class OpenAICompatProvider(LLMProvider):
return value return value
return id_map.setdefault(value, self._normalize_tool_call_id(value)) return id_map.setdefault(value, self._normalize_tool_call_id(value))
def unique_tool_id(value: Any, used_ids: set[str], idx: int) -> str:
    """Pick a tool-call id that does not clash with ``used_ids``.

    Args:
        value: Raw id taken from the tool call; may be missing or a non-string.
        used_ids: Ids already handed out to earlier tool calls of the same
            assistant message.
        idx: Position of the tool call inside the message, mixed into the
            fallback seed.

    Returns:
        The normalized id for ``value`` when it is still free, otherwise a
        derived id that is absent from ``used_ids``.
    """
    has_raw_id = isinstance(value, str) and bool(value)

    # A usable raw id goes through the shared id_map-backed normalization;
    # anything else gets a freshly generated short id.
    preferred = map_id(value) if has_raw_id else _short_tool_id()
    if not (isinstance(preferred, str) and preferred):
        preferred = _short_tool_id()

    if preferred not in used_ids:
        return preferred

    # Clash: derive new candidates from the raw id (or the preferred id when
    # there was no raw id), the call position and an increasing counter.
    seed = value if has_raw_id else preferred
    attempt = 0
    while True:
        attempt += 1
        candidate = self._normalize_tool_call_id(f"{seed}:{idx}:{attempt}")
        if isinstance(candidate, str) and candidate not in used_ids:
            return candidate
def map_tool_result_id(value: Any) -> Any:
    """Translate a tool result's ``tool_call_id`` to the id given to its call.

    Ids recorded in ``pending_tool_ids`` are consumed first-in, first-out, so
    repeated raw ids are matched to their tool calls in order. A raw id with
    nothing pending falls back to the plain ``map_id`` normalization.
    Non-string values are returned unchanged.
    """
    if not isinstance(value, str):
        return value

    waiting = pending_tool_ids.get(value)
    if not waiting:
        # Nothing queued for this raw id (never seen, or already consumed).
        return map_id(value)

    resolved = waiting.popleft()
    if not waiting:
        # Drop the drained queue so later lookups take the fallback path.
        pending_tool_ids.pop(value, None)
    return resolved
for clean in sanitized: for clean in sanitized:
if isinstance(clean.get("tool_calls"), list): if isinstance(clean.get("tool_calls"), list):
normalized = [] normalized = []
for tc in clean["tool_calls"]: used_ids: set[str] = set()
for idx, tc in enumerate(clean["tool_calls"]):
if not isinstance(tc, dict): if not isinstance(tc, dict):
normalized.append(tc) normalized.append(tc)
continue continue
tc_clean = dict(tc) tc_clean = dict(tc)
tc_clean["id"] = map_id(tc_clean.get("id")) raw_id = tc_clean.get("id")
mapped_id = unique_tool_id(raw_id, used_ids, idx)
tc_clean["id"] = mapped_id
used_ids.add(mapped_id)
if isinstance(raw_id, str) and raw_id:
pending_tool_ids.setdefault(raw_id, deque()).append(mapped_id)
function = tc_clean.get("function") function = tc_clean.get("function")
if isinstance(function, dict): if isinstance(function, dict):
function_clean = dict(function) function_clean = dict(function)
@ -496,7 +532,7 @@ class OpenAICompatProvider(LLMProvider):
# that mix non-empty content with tool_calls. # that mix non-empty content with tool_calls.
clean["content"] = None clean["content"] = None
if "tool_call_id" in clean and clean["tool_call_id"]: if "tool_call_id" in clean and clean["tool_call_id"]:
clean["tool_call_id"] = map_id(clean["tool_call_id"]) clean["tool_call_id"] = map_tool_result_id(clean["tool_call_id"])
if ( if (
force_string_content force_string_content
and not (clean.get("role") == "assistant" and clean.get("tool_calls")) and not (clean.get("role") == "assistant" and clean.get("tool_calls"))

View File

@ -1007,6 +1007,41 @@ def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -
assert sanitized[2]["tool_call_id"] == "3ec83c30d" assert sanitized[2]["tool_call_id"] == "3ec83c30d"
def test_openai_compat_deduplicates_duplicate_tool_call_ids_in_history() -> None:
    """Two tool calls sharing one id get distinct ids, and results follow in order."""
    shared_id = "ab1b45c2a"
    history = [
        {"role": "user", "content": "check both files"},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": shared_id,
                    "type": "function",
                    "function": {"name": "read_file", "arguments": '{"path":"a.txt"}'},
                },
                {
                    "id": shared_id,
                    "type": "function",
                    "function": {"name": "read_file", "arguments": '{"path":"b.txt"}'},
                },
            ],
        },
        {"role": "tool", "tool_call_id": shared_id, "name": "read_file", "content": "a"},
        {"role": "tool", "tool_call_id": shared_id, "name": "read_file", "content": "b"},
        {"role": "user", "content": "continue"},
    ]

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        provider = OpenAICompatProvider()
        sanitized = provider._sanitize_messages(history)

    call_ids = [call["id"] for call in sanitized[1]["tool_calls"]]
    result_ids = [message["tool_call_id"] for message in (sanitized[2], sanitized[3])]

    # The first call keeps the original id; only the repeat is rewritten.
    assert call_ids[0] == "ab1b45c2a"
    assert len(call_ids) == 2
    assert len(set(call_ids)) == 2
    # Tool results are paired with the calls in the order they were issued.
    assert result_ids == call_ids
def test_openai_compat_stringifies_dict_tool_arguments() -> None: def test_openai_compat_stringifies_dict_tool_arguments() -> None:
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider() provider = OpenAICompatProvider()