mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 17:32:39 +00:00
feat(OpenAICompatProvider): enhance tool call handling with provider-specific fields
This commit is contained in:
parent
263069583d
commit
7b720ce9f7
@ -24,6 +24,32 @@ _ALLOWED_MSG_KEYS = frozenset({
|
||||
_ALNUM = string.ascii_letters + string.digits
|
||||
|
||||
|
||||
def _get_attr_or_item(obj: Any, key: str, default: Any = None) -> Any:
|
||||
"""Read an attribute or dict key from provider SDK objects."""
|
||||
if obj is None:
|
||||
return default
|
||||
if isinstance(obj, dict):
|
||||
return obj.get(key, default)
|
||||
return getattr(obj, key, default)
|
||||
|
||||
|
||||
def _coerce_dict(value: Any) -> dict[str, Any] | None:
|
||||
"""Return a shallow dict if the value looks mapping-like."""
|
||||
if isinstance(value, dict):
|
||||
return dict(value)
|
||||
return None
|
||||
|
||||
|
||||
def _extract_tool_call_fields(tc: Any) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
|
||||
"""Extract provider-specific metadata from a tool call object."""
|
||||
provider_specific_fields = _coerce_dict(_get_attr_or_item(tc, "provider_specific_fields"))
|
||||
function = _get_attr_or_item(tc, "function")
|
||||
function_provider_specific_fields = _coerce_dict(
|
||||
_get_attr_or_item(function, "provider_specific_fields")
|
||||
)
|
||||
return provider_specific_fields, function_provider_specific_fields
|
||||
|
||||
|
||||
def _short_tool_id() -> str:
|
||||
"""9-char alphanumeric ID compatible with all providers (incl. Mistral)."""
|
||||
return "".join(secrets.choice(_ALNUM) for _ in range(9))
|
||||
@ -333,13 +359,17 @@ class OpenAICompatProvider(LLMProvider):
|
||||
|
||||
tool_calls = []
|
||||
for tc in raw_tool_calls:
|
||||
args = tc.function.arguments
|
||||
function = _get_attr_or_item(tc, "function")
|
||||
args = _get_attr_or_item(function, "arguments")
|
||||
if isinstance(args, str):
|
||||
args = json_repair.loads(args)
|
||||
provider_specific_fields, function_provider_specific_fields = _extract_tool_call_fields(tc)
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=_short_tool_id(),
|
||||
name=tc.function.name,
|
||||
name=_get_attr_or_item(function, "name", ""),
|
||||
arguments=args,
|
||||
provider_specific_fields=provider_specific_fields,
|
||||
function_provider_specific_fields=function_provider_specific_fields,
|
||||
))
|
||||
|
||||
return LLMResponse(
|
||||
@ -404,13 +434,34 @@ class OpenAICompatProvider(LLMProvider):
|
||||
if delta and delta.content:
|
||||
content_parts.append(delta.content)
|
||||
for tc in (delta.tool_calls or []) if delta else []:
|
||||
buf = tc_bufs.setdefault(tc.index, {"id": "", "name": "", "arguments": ""})
|
||||
if tc.id:
|
||||
buf["id"] = tc.id
|
||||
if tc.function and tc.function.name:
|
||||
buf["name"] = tc.function.name
|
||||
if tc.function and tc.function.arguments:
|
||||
buf["arguments"] += tc.function.arguments
|
||||
idx = _get_attr_or_item(tc, "index")
|
||||
if idx is None:
|
||||
continue
|
||||
buf = tc_bufs.setdefault(
|
||||
idx,
|
||||
{
|
||||
"id": "",
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
"provider_specific_fields": None,
|
||||
"function_provider_specific_fields": None,
|
||||
},
|
||||
)
|
||||
tc_id = _get_attr_or_item(tc, "id")
|
||||
if tc_id:
|
||||
buf["id"] = tc_id
|
||||
function = _get_attr_or_item(tc, "function")
|
||||
function_name = _get_attr_or_item(function, "name")
|
||||
if function_name:
|
||||
buf["name"] = function_name
|
||||
arguments = _get_attr_or_item(function, "arguments")
|
||||
if arguments:
|
||||
buf["arguments"] += arguments
|
||||
provider_specific_fields, function_provider_specific_fields = _extract_tool_call_fields(tc)
|
||||
if provider_specific_fields:
|
||||
buf["provider_specific_fields"] = provider_specific_fields
|
||||
if function_provider_specific_fields:
|
||||
buf["function_provider_specific_fields"] = function_provider_specific_fields
|
||||
|
||||
return LLMResponse(
|
||||
content="".join(content_parts) or None,
|
||||
@ -419,6 +470,8 @@ class OpenAICompatProvider(LLMProvider):
|
||||
id=b["id"] or _short_tool_id(),
|
||||
name=b["name"],
|
||||
arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {},
|
||||
provider_specific_fields=b["provider_specific_fields"],
|
||||
function_provider_specific_fields=b["function_provider_specific_fields"],
|
||||
)
|
||||
for b in tc_bufs.values()
|
||||
],
|
||||
|
||||
@ -29,6 +29,29 @@ def _fake_chat_response(content: str = "ok") -> SimpleNamespace:
|
||||
return SimpleNamespace(choices=[choice], usage=usage)
|
||||
|
||||
|
||||
def _fake_tool_call_response() -> SimpleNamespace:
|
||||
"""Build a minimal chat response that includes Gemini-style provider fields."""
|
||||
function = SimpleNamespace(
|
||||
name="exec",
|
||||
arguments='{"cmd":"ls"}',
|
||||
provider_specific_fields={"inner": "value"},
|
||||
)
|
||||
tool_call = SimpleNamespace(
|
||||
id="call_123",
|
||||
index=0,
|
||||
function=function,
|
||||
provider_specific_fields={"thought_signature": "signed-token"},
|
||||
)
|
||||
message = SimpleNamespace(
|
||||
content=None,
|
||||
tool_calls=[tool_call],
|
||||
reasoning_content=None,
|
||||
)
|
||||
choice = SimpleNamespace(message=message, finish_reason="tool_calls")
|
||||
usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
|
||||
return SimpleNamespace(choices=[choice], usage=usage)
|
||||
|
||||
|
||||
def test_openrouter_spec_is_gateway() -> None:
|
||||
spec = find_by_name("openrouter")
|
||||
assert spec is not None
|
||||
@ -110,6 +133,37 @@ async def test_standard_provider_passes_model_through() -> None:
|
||||
assert call_kwargs["model"] == "deepseek-chat"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_compat_preserves_provider_specific_fields_on_tool_calls() -> None:
|
||||
"""Gemini thought signatures must survive parsing so they can be sent back."""
|
||||
mock_create = AsyncMock(return_value=_fake_tool_call_response())
|
||||
spec = find_by_name("gemini")
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||
client_instance = MockClient.return_value
|
||||
client_instance.chat.completions.create = mock_create
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
default_model="google/gemini-3.1-pro-preview",
|
||||
spec=spec,
|
||||
)
|
||||
result = await provider.chat(
|
||||
messages=[{"role": "user", "content": "run exec"}],
|
||||
model="google/gemini-3.1-pro-preview",
|
||||
)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
tool_call = result.tool_calls[0]
|
||||
assert tool_call.provider_specific_fields == {"thought_signature": "signed-token"}
|
||||
assert tool_call.function_provider_specific_fields == {"inner": "value"}
|
||||
|
||||
serialized = tool_call.to_openai_tool_call()
|
||||
assert serialized["provider_specific_fields"] == {"thought_signature": "signed-token"}
|
||||
assert serialized["function"]["provider_specific_fields"] == {"inner": "value"}
|
||||
|
||||
|
||||
def test_openai_model_passthrough() -> None:
|
||||
"""OpenAI models pass through unchanged."""
|
||||
spec = find_by_name("openai")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user