fix(provider): preserve OpenAI-compatible tool call ids

This commit is contained in:
Yuxin Lou 2026-05-24 19:33:43 +08:00 committed by Xubin Ren
parent c4e2fcaf0c
commit 3f0098839e
2 changed files with 39 additions and 3 deletions

View File

@ -428,6 +428,10 @@ class OpenAICompatProvider(LLMProvider):
return tool_call_id return tool_call_id
return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9] return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]
def _should_normalize_tool_call_ids(self) -> bool:
"""Return True for providers that reject normal OpenAI tool call IDs."""
return bool(self._spec and self._spec.name == "mistral")
@staticmethod @staticmethod
def _normalize_tool_call_arguments(arguments: Any) -> str: def _normalize_tool_call_arguments(arguments: Any) -> str:
"""Force function.arguments into a valid JSON object string.""" """Force function.arguments into a valid JSON object string."""
@ -466,10 +470,13 @@ class OpenAICompatProvider(LLMProvider):
id_map: dict[str, str] = {} id_map: dict[str, str] = {}
pending_tool_ids: dict[str, deque[str]] = {} pending_tool_ids: dict[str, deque[str]] = {}
force_string_content = bool(self._spec and self._spec.name == "deepseek") force_string_content = bool(self._spec and self._spec.name == "deepseek")
normalize_tool_ids = self._should_normalize_tool_call_ids()
def map_id(value: Any) -> Any: def map_id(value: Any) -> Any:
if not isinstance(value, str): if not isinstance(value, str):
return value return value
if not normalize_tool_ids:
return value
return id_map.setdefault(value, self._normalize_tool_call_id(value)) return id_map.setdefault(value, self._normalize_tool_call_id(value))
def unique_tool_id(value: Any, used_ids: set[str], idx: int) -> str: def unique_tool_id(value: Any, used_ids: set[str], idx: int) -> str:
@ -956,7 +963,7 @@ class OpenAICompatProvider(LLMProvider):
args = json_repair.loads(args) args = json_repair.loads(args)
ec, prov, fn_prov = _extract_tc_extras(tc) ec, prov, fn_prov = _extract_tc_extras(tc)
parsed_tool_calls.append(ToolCallRequest( parsed_tool_calls.append(ToolCallRequest(
id=_short_tool_id(), id=str(tc_map.get("id") or _short_tool_id()),
name=str(fn.get("name") or ""), name=str(fn.get("name") or ""),
arguments=args if isinstance(args, dict) else {}, arguments=args if isinstance(args, dict) else {},
extra_content=ec, extra_content=ec,
@ -999,7 +1006,7 @@ class OpenAICompatProvider(LLMProvider):
args = json_repair.loads(args) args = json_repair.loads(args)
ec, prov, fn_prov = _extract_tc_extras(tc) ec, prov, fn_prov = _extract_tc_extras(tc)
tool_calls.append(ToolCallRequest( tool_calls.append(ToolCallRequest(
id=_short_tool_id(), id=str(getattr(tc, "id", None) or _short_tool_id()),
name=tc.function.name, name=tc.function.name,
arguments=args, arguments=args,
extra_content=ec, extra_content=ec,

View File

@ -602,6 +602,7 @@ async def test_openai_compat_preserves_extra_content_on_tool_calls() -> None:
assert len(result.tool_calls) == 1 assert len(result.tool_calls) == 1
tool_call = result.tool_calls[0] tool_call = result.tool_calls[0]
assert tool_call.id == "call_123"
assert tool_call.extra_content == {"google": {"thought_signature": "signed-token"}} assert tool_call.extra_content == {"google": {"thought_signature": "signed-token"}}
assert tool_call.function_provider_specific_fields == {"inner": "value"} assert tool_call.function_provider_specific_fields == {"inner": "value"}
@ -994,7 +995,7 @@ def test_deepseek_thinking_keeps_tool_history_with_reasoning_content() -> None:
assert kwargs["messages"][2]["role"] == "tool" assert kwargs["messages"][2]["role"] == "tool"
def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -> None: def test_openai_compat_preserves_tool_call_ids_after_consecutive_assistant_messages() -> None:
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider() provider = OpenAICompatProvider()
@ -1016,6 +1017,34 @@ def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -
{"role": "user", "content": "多少star了呢"}, {"role": "user", "content": "多少star了呢"},
]) ])
assert sanitized[1]["role"] == "assistant"
assert sanitized[1]["content"] is None
assert sanitized[1]["tool_calls"][0]["id"] == "call_function_akxp3wqzn7ph_1"
assert sanitized[2]["tool_call_id"] == "call_function_akxp3wqzn7ph_1"
def test_mistral_normalizes_tool_call_ids_after_consecutive_assistant_messages() -> None:
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider(spec=find_by_name("mistral"))
sanitized = provider._sanitize_messages([
{"role": "user", "content": "不错"},
{"role": "assistant", "content": "对,破 4 万指日可待"},
{
"role": "assistant",
"content": "<think>我再查一下</think>",
"tool_calls": [
{
"id": "call_function_akxp3wqzn7ph_1",
"type": "function",
"function": {"name": "exec", "arguments": "{}"},
}
],
},
{"role": "tool", "tool_call_id": "call_function_akxp3wqzn7ph_1", "name": "exec", "content": "ok"},
{"role": "user", "content": "多少star了呢"},
])
assert sanitized[1]["role"] == "assistant" assert sanitized[1]["role"] == "assistant"
assert sanitized[1]["content"] is None assert sanitized[1]["content"] is None
assert sanitized[1]["tool_calls"][0]["id"] == "3ec83c30d" assert sanitized[1]["tool_calls"][0]["id"] == "3ec83c30d"