feat(OpenAICompatProvider): enhance tool call handling with provider-specific fields

This commit is contained in:
Yohei Nishikubo 2026-03-25 09:31:42 +09:00 committed by Xubin Ren
parent 263069583d
commit 7b720ce9f7
2 changed files with 116 additions and 9 deletions

View File

@@ -24,6 +24,32 @@ _ALLOWED_MSG_KEYS = frozenset({
# Alphabet used by _short_tool_id: ASCII letters plus digits.
_ALNUM = string.ascii_letters + string.digits
def _get_attr_or_item(obj: Any, key: str, default: Any = None) -> Any:
"""Read an attribute or dict key from provider SDK objects."""
if obj is None:
return default
if isinstance(obj, dict):
return obj.get(key, default)
return getattr(obj, key, default)
def _coerce_dict(value: Any) -> dict[str, Any] | None:
"""Return a shallow dict if the value looks mapping-like."""
if isinstance(value, dict):
return dict(value)
return None
def _extract_tool_call_fields(tc: Any) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
    """Pull provider-specific metadata off a tool call and its nested function.

    Returns a (call-level, function-level) pair; each element is a shallow
    dict copy, or None when the provider did not supply that field.
    """
    call_fields = _coerce_dict(_get_attr_or_item(tc, "provider_specific_fields"))
    fn = _get_attr_or_item(tc, "function")
    fn_fields = _coerce_dict(_get_attr_or_item(fn, "provider_specific_fields"))
    return call_fields, fn_fields
def _short_tool_id() -> str:
    """9-char alphanumeric ID compatible with all providers (incl. Mistral)."""
    # secrets.choice keeps IDs unpredictable; 9 chars satisfies the strictest provider.
    chars = [secrets.choice(_ALNUM) for _ in range(9)]
    return "".join(chars)
@@ -333,13 +359,17 @@ class OpenAICompatProvider(LLMProvider):
tool_calls = []
for tc in raw_tool_calls:
args = tc.function.arguments
function = _get_attr_or_item(tc, "function")
args = _get_attr_or_item(function, "arguments")
if isinstance(args, str):
args = json_repair.loads(args)
provider_specific_fields, function_provider_specific_fields = _extract_tool_call_fields(tc)
tool_calls.append(ToolCallRequest(
id=_short_tool_id(),
name=tc.function.name,
name=_get_attr_or_item(function, "name", ""),
arguments=args,
provider_specific_fields=provider_specific_fields,
function_provider_specific_fields=function_provider_specific_fields,
))
return LLMResponse(
@@ -404,13 +434,34 @@ class OpenAICompatProvider(LLMProvider):
if delta and delta.content:
content_parts.append(delta.content)
for tc in (delta.tool_calls or []) if delta else []:
buf = tc_bufs.setdefault(tc.index, {"id": "", "name": "", "arguments": ""})
if tc.id:
buf["id"] = tc.id
if tc.function and tc.function.name:
buf["name"] = tc.function.name
if tc.function and tc.function.arguments:
buf["arguments"] += tc.function.arguments
idx = _get_attr_or_item(tc, "index")
if idx is None:
continue
buf = tc_bufs.setdefault(
idx,
{
"id": "",
"name": "",
"arguments": "",
"provider_specific_fields": None,
"function_provider_specific_fields": None,
},
)
tc_id = _get_attr_or_item(tc, "id")
if tc_id:
buf["id"] = tc_id
function = _get_attr_or_item(tc, "function")
function_name = _get_attr_or_item(function, "name")
if function_name:
buf["name"] = function_name
arguments = _get_attr_or_item(function, "arguments")
if arguments:
buf["arguments"] += arguments
provider_specific_fields, function_provider_specific_fields = _extract_tool_call_fields(tc)
if provider_specific_fields:
buf["provider_specific_fields"] = provider_specific_fields
if function_provider_specific_fields:
buf["function_provider_specific_fields"] = function_provider_specific_fields
return LLMResponse(
content="".join(content_parts) or None,
@@ -419,6 +470,8 @@ class OpenAICompatProvider(LLMProvider):
id=b["id"] or _short_tool_id(),
name=b["name"],
arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {},
provider_specific_fields=b["provider_specific_fields"],
function_provider_specific_fields=b["function_provider_specific_fields"],
)
for b in tc_bufs.values()
],

View File

@@ -29,6 +29,29 @@ def _fake_chat_response(content: str = "ok") -> SimpleNamespace:
return SimpleNamespace(choices=[choice], usage=usage)
def _fake_tool_call_response() -> SimpleNamespace:
"""Build a minimal chat response that includes Gemini-style provider fields."""
function = SimpleNamespace(
name="exec",
arguments='{"cmd":"ls"}',
provider_specific_fields={"inner": "value"},
)
tool_call = SimpleNamespace(
id="call_123",
index=0,
function=function,
provider_specific_fields={"thought_signature": "signed-token"},
)
message = SimpleNamespace(
content=None,
tool_calls=[tool_call],
reasoning_content=None,
)
choice = SimpleNamespace(message=message, finish_reason="tool_calls")
usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
return SimpleNamespace(choices=[choice], usage=usage)
def test_openrouter_spec_is_gateway() -> None:
    """The "openrouter" spec must be resolvable from the provider registry."""
    # NOTE(review): the hunk is truncated here; the gateway assertion the name
    # implies presumably follows below — confirm against the full file.
    spec = find_by_name("openrouter")
    assert spec is not None
@@ -110,6 +133,37 @@ async def test_standard_provider_passes_model_through() -> None:
assert call_kwargs["model"] == "deepseek-chat"
@pytest.mark.asyncio
async def test_openai_compat_preserves_provider_specific_fields_on_tool_calls() -> None:
    """Gemini thought signatures must survive parsing so they can be sent back."""
    fake_create = AsyncMock(return_value=_fake_tool_call_response())
    gemini_spec = find_by_name("gemini")
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_client_cls:
        # Route the provider's completion call to the canned tool-call response.
        mock_client_cls.return_value.chat.completions.create = fake_create
        provider = OpenAICompatProvider(
            api_key="test-key",
            api_base="https://generativelanguage.googleapis.com/v1beta/openai/",
            default_model="google/gemini-3.1-pro-preview",
            spec=gemini_spec,
        )
        response = await provider.chat(
            messages=[{"role": "user", "content": "run exec"}],
            model="google/gemini-3.1-pro-preview",
        )

    assert len(response.tool_calls) == 1
    call = response.tool_calls[0]
    # Both levels of provider metadata must round-trip through parsing...
    assert call.provider_specific_fields == {"thought_signature": "signed-token"}
    assert call.function_provider_specific_fields == {"inner": "value"}
    # ...and survive re-serialization to the OpenAI wire format.
    wire = call.to_openai_tool_call()
    assert wire["provider_specific_fields"] == {"thought_signature": "signed-token"}
    assert wire["function"]["provider_specific_fields"] == {"inner": "value"}
def test_openai_model_passthrough() -> None:
"""OpenAI models pass through unchanged."""
spec = find_by_name("openai")