mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-22 09:32:33 +00:00
fix(provider): dedupe repeated tool ids in history
This commit is contained in:
parent
d29fcaf5d1
commit
23d5148a57
@ -11,6 +11,7 @@ import secrets
|
|||||||
import string
|
import string
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
from collections import deque
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from ipaddress import ip_address
|
from ipaddress import ip_address
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
@ -463,6 +464,7 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
"""Strip non-standard keys, normalize tool_call IDs."""
|
"""Strip non-standard keys, normalize tool_call IDs."""
|
||||||
sanitized = LLMProvider._sanitize_request_messages(messages, _ALLOWED_MSG_KEYS)
|
sanitized = LLMProvider._sanitize_request_messages(messages, _ALLOWED_MSG_KEYS)
|
||||||
id_map: dict[str, str] = {}
|
id_map: dict[str, str] = {}
|
||||||
|
pending_tool_ids: dict[str, deque[str]] = {}
|
||||||
force_string_content = bool(self._spec and self._spec.name == "deepseek")
|
force_string_content = bool(self._spec and self._spec.name == "deepseek")
|
||||||
|
|
||||||
def map_id(value: Any) -> Any:
|
def map_id(value: Any) -> Any:
|
||||||
@ -470,15 +472,49 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
return value
|
return value
|
||||||
return id_map.setdefault(value, self._normalize_tool_call_id(value))
|
return id_map.setdefault(value, self._normalize_tool_call_id(value))
|
||||||
|
|
||||||
|
def unique_tool_id(value: Any, used_ids: set[str], idx: int) -> str:
|
||||||
|
if isinstance(value, str) and value:
|
||||||
|
base = map_id(value)
|
||||||
|
else:
|
||||||
|
base = _short_tool_id()
|
||||||
|
if not isinstance(base, str) or not base:
|
||||||
|
base = _short_tool_id()
|
||||||
|
if base not in used_ids:
|
||||||
|
return base
|
||||||
|
seed = value if isinstance(value, str) and value else base
|
||||||
|
salt = 1
|
||||||
|
while True:
|
||||||
|
candidate = self._normalize_tool_call_id(f"{seed}:{idx}:{salt}")
|
||||||
|
if isinstance(candidate, str) and candidate not in used_ids:
|
||||||
|
return candidate
|
||||||
|
salt += 1
|
||||||
|
|
||||||
|
def map_tool_result_id(value: Any) -> Any:
|
||||||
|
if not isinstance(value, str):
|
||||||
|
return value
|
||||||
|
queue = pending_tool_ids.get(value)
|
||||||
|
if queue:
|
||||||
|
mapped = queue.popleft()
|
||||||
|
if not queue:
|
||||||
|
pending_tool_ids.pop(value, None)
|
||||||
|
return mapped
|
||||||
|
return map_id(value)
|
||||||
|
|
||||||
for clean in sanitized:
|
for clean in sanitized:
|
||||||
if isinstance(clean.get("tool_calls"), list):
|
if isinstance(clean.get("tool_calls"), list):
|
||||||
normalized = []
|
normalized = []
|
||||||
for tc in clean["tool_calls"]:
|
used_ids: set[str] = set()
|
||||||
|
for idx, tc in enumerate(clean["tool_calls"]):
|
||||||
if not isinstance(tc, dict):
|
if not isinstance(tc, dict):
|
||||||
normalized.append(tc)
|
normalized.append(tc)
|
||||||
continue
|
continue
|
||||||
tc_clean = dict(tc)
|
tc_clean = dict(tc)
|
||||||
tc_clean["id"] = map_id(tc_clean.get("id"))
|
raw_id = tc_clean.get("id")
|
||||||
|
mapped_id = unique_tool_id(raw_id, used_ids, idx)
|
||||||
|
tc_clean["id"] = mapped_id
|
||||||
|
used_ids.add(mapped_id)
|
||||||
|
if isinstance(raw_id, str) and raw_id:
|
||||||
|
pending_tool_ids.setdefault(raw_id, deque()).append(mapped_id)
|
||||||
function = tc_clean.get("function")
|
function = tc_clean.get("function")
|
||||||
if isinstance(function, dict):
|
if isinstance(function, dict):
|
||||||
function_clean = dict(function)
|
function_clean = dict(function)
|
||||||
@ -496,7 +532,7 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
# that mix non-empty content with tool_calls.
|
# that mix non-empty content with tool_calls.
|
||||||
clean["content"] = None
|
clean["content"] = None
|
||||||
if "tool_call_id" in clean and clean["tool_call_id"]:
|
if "tool_call_id" in clean and clean["tool_call_id"]:
|
||||||
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
clean["tool_call_id"] = map_tool_result_id(clean["tool_call_id"])
|
||||||
if (
|
if (
|
||||||
force_string_content
|
force_string_content
|
||||||
and not (clean.get("role") == "assistant" and clean.get("tool_calls"))
|
and not (clean.get("role") == "assistant" and clean.get("tool_calls"))
|
||||||
|
|||||||
@ -1007,6 +1007,41 @@ def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -
|
|||||||
assert sanitized[2]["tool_call_id"] == "3ec83c30d"
|
assert sanitized[2]["tool_call_id"] == "3ec83c30d"
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_compat_deduplicates_duplicate_tool_call_ids_in_history() -> None:
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
provider = OpenAICompatProvider()
|
||||||
|
|
||||||
|
sanitized = provider._sanitize_messages([
|
||||||
|
{"role": "user", "content": "check both files"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "ab1b45c2a",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "read_file", "arguments": '{"path":"a.txt"}'},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "ab1b45c2a",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "read_file", "arguments": '{"path":"b.txt"}'},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{"role": "tool", "tool_call_id": "ab1b45c2a", "name": "read_file", "content": "a"},
|
||||||
|
{"role": "tool", "tool_call_id": "ab1b45c2a", "name": "read_file", "content": "b"},
|
||||||
|
{"role": "user", "content": "continue"},
|
||||||
|
])
|
||||||
|
|
||||||
|
tool_call_ids = [tc["id"] for tc in sanitized[1]["tool_calls"]]
|
||||||
|
tool_result_ids = [sanitized[2]["tool_call_id"], sanitized[3]["tool_call_id"]]
|
||||||
|
|
||||||
|
assert tool_call_ids[0] == "ab1b45c2a"
|
||||||
|
assert len(tool_call_ids) == len(set(tool_call_ids)) == 2
|
||||||
|
assert tool_result_ids == tool_call_ids
|
||||||
|
|
||||||
|
|
||||||
def test_openai_compat_stringifies_dict_tool_arguments() -> None:
|
def test_openai_compat_stringifies_dict_tool_arguments() -> None:
|
||||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
provider = OpenAICompatProvider()
|
provider = OpenAICompatProvider()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user