diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index a69a716b1..866e05ef8 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -24,6 +24,32 @@ _ALLOWED_MSG_KEYS = frozenset({ _ALNUM = string.ascii_letters + string.digits +def _get_attr_or_item(obj: Any, key: str, default: Any = None) -> Any: + """Read an attribute or dict key from provider SDK objects.""" + if obj is None: + return default + if isinstance(obj, dict): + return obj.get(key, default) + return getattr(obj, key, default) + + +def _coerce_dict(value: Any) -> dict[str, Any] | None: + """Return a shallow dict if the value looks mapping-like.""" + if isinstance(value, dict): + return dict(value) + return None + + +def _extract_tool_call_fields(tc: Any) -> tuple[dict[str, Any] | None, dict[str, Any] | None]: + """Extract provider-specific metadata from a tool call object.""" + provider_specific_fields = _coerce_dict(_get_attr_or_item(tc, "provider_specific_fields")) + function = _get_attr_or_item(tc, "function") + function_provider_specific_fields = _coerce_dict( + _get_attr_or_item(function, "provider_specific_fields") + ) + return provider_specific_fields, function_provider_specific_fields + + def _short_tool_id() -> str: """9-char alphanumeric ID compatible with all providers (incl. Mistral).""" return "".join(secrets.choice(_ALNUM) for _ in range(9)) @@ -333,13 +359,17 @@ class OpenAICompatProvider(LLMProvider): tool_calls = [] for tc in raw_tool_calls: - args = tc.function.arguments + function = _get_attr_or_item(tc, "function") + args = _get_attr_or_item(function, "arguments") if isinstance(args, str): args = json_repair.loads(args) + provider_specific_fields, function_provider_specific_fields = _extract_tool_call_fields(tc) tool_calls.append(ToolCallRequest( id=_short_tool_id(), - name=tc.function.name, + name=_get_attr_or_item(function, "name", ""), arguments=args, + provider_specific_fields=provider_specific_fields, + function_provider_specific_fields=function_provider_specific_fields, )) return LLMResponse( @@ -404,13 +434,34 @@ class OpenAICompatProvider(LLMProvider): if delta and delta.content: content_parts.append(delta.content) for tc in (delta.tool_calls or []) if delta else []: - buf = tc_bufs.setdefault(tc.index, {"id": "", "name": "", "arguments": ""}) - if tc.id: - buf["id"] = tc.id - if tc.function and tc.function.name: - buf["name"] = tc.function.name - if tc.function and tc.function.arguments: - buf["arguments"] += tc.function.arguments + idx = _get_attr_or_item(tc, "index") + if idx is None: + continue + buf = tc_bufs.setdefault( + idx, + { + "id": "", + "name": "", + "arguments": "", + "provider_specific_fields": None, + "function_provider_specific_fields": None, + }, + ) + tc_id = _get_attr_or_item(tc, "id") + if tc_id: + buf["id"] = tc_id + function = _get_attr_or_item(tc, "function") + function_name = _get_attr_or_item(function, "name") + if function_name: + buf["name"] = function_name + arguments = _get_attr_or_item(function, "arguments") + if arguments: + buf["arguments"] += arguments + provider_specific_fields, function_provider_specific_fields = _extract_tool_call_fields(tc) + if provider_specific_fields: + buf["provider_specific_fields"] = provider_specific_fields + if function_provider_specific_fields: + buf["function_provider_specific_fields"] = function_provider_specific_fields return LLMResponse( content="".join(content_parts) or None, @@ -419,6 +470,8 @@ class OpenAICompatProvider(LLMProvider): id=b["id"] or _short_tool_id(), name=b["name"], arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {}, + provider_specific_fields=b["provider_specific_fields"], + function_provider_specific_fields=b["function_provider_specific_fields"], ) for b in tc_bufs.values() ], diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py index c55857b3b..4d1572075 100644 --- a/tests/providers/test_litellm_kwargs.py +++ b/tests/providers/test_litellm_kwargs.py @@ -29,6 +29,29 @@ def _fake_chat_response(content: str = "ok") -> SimpleNamespace: return SimpleNamespace(choices=[choice], usage=usage) +def _fake_tool_call_response() -> SimpleNamespace: + """Build a minimal chat response that includes Gemini-style provider fields.""" + function = SimpleNamespace( + name="exec", + arguments='{"cmd":"ls"}', + provider_specific_fields={"inner": "value"}, + ) + tool_call = SimpleNamespace( + id="call_123", + index=0, + function=function, + provider_specific_fields={"thought_signature": "signed-token"}, + ) + message = SimpleNamespace( + content=None, + tool_calls=[tool_call], + reasoning_content=None, + ) + choice = SimpleNamespace(message=message, finish_reason="tool_calls") + usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15) + return SimpleNamespace(choices=[choice], usage=usage) + + def test_openrouter_spec_is_gateway() -> None: spec = find_by_name("openrouter") assert spec is not None @@ -110,6 +133,37 @@ async def test_standard_provider_passes_model_through() -> None: assert call_kwargs["model"] == "deepseek-chat" +@pytest.mark.asyncio +async def test_openai_compat_preserves_provider_specific_fields_on_tool_calls() -> None: + """Gemini thought signatures must survive parsing so they can be sent back.""" + mock_create = AsyncMock(return_value=_fake_tool_call_response()) + spec = find_by_name("gemini") + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient: + client_instance = MockClient.return_value + client_instance.chat.completions.create = mock_create + + provider = OpenAICompatProvider( + api_key="test-key", + api_base="https://generativelanguage.googleapis.com/v1beta/openai/", + default_model="google/gemini-3.1-pro-preview", + spec=spec, + ) + result = await provider.chat( + messages=[{"role": "user", "content": "run exec"}], + model="google/gemini-3.1-pro-preview", + ) + + assert len(result.tool_calls) == 1 + tool_call = result.tool_calls[0] + assert tool_call.provider_specific_fields == {"thought_signature": "signed-token"} + assert tool_call.function_provider_specific_fields == {"inner": "value"} + + serialized = tool_call.to_openai_tool_call() + assert serialized["provider_specific_fields"] == {"thought_signature": "signed-token"} + assert serialized["function"]["provider_specific_fields"] == {"inner": "value"} + + def test_openai_model_passthrough() -> None: """OpenAI models pass through unchanged.""" spec = find_by_name("openai")