mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-22 09:32:33 +00:00
fix(provider): dedupe repeated tool ids in history
This commit is contained in:
parent
d29fcaf5d1
commit
23d5148a57
@ -11,6 +11,7 @@ import secrets
|
||||
import string
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
from collections.abc import Awaitable, Callable
|
||||
from ipaddress import ip_address
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@ -463,6 +464,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
"""Strip non-standard keys, normalize tool_call IDs."""
|
||||
sanitized = LLMProvider._sanitize_request_messages(messages, _ALLOWED_MSG_KEYS)
|
||||
id_map: dict[str, str] = {}
|
||||
pending_tool_ids: dict[str, deque[str]] = {}
|
||||
force_string_content = bool(self._spec and self._spec.name == "deepseek")
|
||||
|
||||
def map_id(value: Any) -> Any:
|
||||
@ -470,15 +472,49 @@ class OpenAICompatProvider(LLMProvider):
|
||||
return value
|
||||
return id_map.setdefault(value, self._normalize_tool_call_id(value))
|
||||
|
||||
def unique_tool_id(value: Any, used_ids: set[str], idx: int) -> str:
|
||||
if isinstance(value, str) and value:
|
||||
base = map_id(value)
|
||||
else:
|
||||
base = _short_tool_id()
|
||||
if not isinstance(base, str) or not base:
|
||||
base = _short_tool_id()
|
||||
if base not in used_ids:
|
||||
return base
|
||||
seed = value if isinstance(value, str) and value else base
|
||||
salt = 1
|
||||
while True:
|
||||
candidate = self._normalize_tool_call_id(f"{seed}:{idx}:{salt}")
|
||||
if isinstance(candidate, str) and candidate not in used_ids:
|
||||
return candidate
|
||||
salt += 1
|
||||
|
||||
def map_tool_result_id(value: Any) -> Any:
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
queue = pending_tool_ids.get(value)
|
||||
if queue:
|
||||
mapped = queue.popleft()
|
||||
if not queue:
|
||||
pending_tool_ids.pop(value, None)
|
||||
return mapped
|
||||
return map_id(value)
|
||||
|
||||
for clean in sanitized:
|
||||
if isinstance(clean.get("tool_calls"), list):
|
||||
normalized = []
|
||||
for tc in clean["tool_calls"]:
|
||||
used_ids: set[str] = set()
|
||||
for idx, tc in enumerate(clean["tool_calls"]):
|
||||
if not isinstance(tc, dict):
|
||||
normalized.append(tc)
|
||||
continue
|
||||
tc_clean = dict(tc)
|
||||
tc_clean["id"] = map_id(tc_clean.get("id"))
|
||||
raw_id = tc_clean.get("id")
|
||||
mapped_id = unique_tool_id(raw_id, used_ids, idx)
|
||||
tc_clean["id"] = mapped_id
|
||||
used_ids.add(mapped_id)
|
||||
if isinstance(raw_id, str) and raw_id:
|
||||
pending_tool_ids.setdefault(raw_id, deque()).append(mapped_id)
|
||||
function = tc_clean.get("function")
|
||||
if isinstance(function, dict):
|
||||
function_clean = dict(function)
|
||||
@ -496,7 +532,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
# that mix non-empty content with tool_calls.
|
||||
clean["content"] = None
|
||||
if "tool_call_id" in clean and clean["tool_call_id"]:
|
||||
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
||||
clean["tool_call_id"] = map_tool_result_id(clean["tool_call_id"])
|
||||
if (
|
||||
force_string_content
|
||||
and not (clean.get("role") == "assistant" and clean.get("tool_calls"))
|
||||
|
||||
@ -1007,6 +1007,41 @@ def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -
|
||||
assert sanitized[2]["tool_call_id"] == "3ec83c30d"
|
||||
|
||||
|
||||
def test_openai_compat_deduplicates_duplicate_tool_call_ids_in_history() -> None:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
sanitized = provider._sanitize_messages([
|
||||
{"role": "user", "content": "check both files"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "ab1b45c2a",
|
||||
"type": "function",
|
||||
"function": {"name": "read_file", "arguments": '{"path":"a.txt"}'},
|
||||
},
|
||||
{
|
||||
"id": "ab1b45c2a",
|
||||
"type": "function",
|
||||
"function": {"name": "read_file", "arguments": '{"path":"b.txt"}'},
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "ab1b45c2a", "name": "read_file", "content": "a"},
|
||||
{"role": "tool", "tool_call_id": "ab1b45c2a", "name": "read_file", "content": "b"},
|
||||
{"role": "user", "content": "continue"},
|
||||
])
|
||||
|
||||
tool_call_ids = [tc["id"] for tc in sanitized[1]["tool_calls"]]
|
||||
tool_result_ids = [sanitized[2]["tool_call_id"], sanitized[3]["tool_call_id"]]
|
||||
|
||||
assert tool_call_ids[0] == "ab1b45c2a"
|
||||
assert len(tool_call_ids) == len(set(tool_call_ids)) == 2
|
||||
assert tool_result_ids == tool_call_ids
|
||||
|
||||
|
||||
def test_openai_compat_stringifies_dict_tool_arguments() -> None:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user