diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index 1fd610b91..9ce2b0c63 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -16,6 +16,7 @@ class ToolCallRequest: id: str name: str arguments: dict[str, Any] + extra_content: dict[str, Any] | None = None provider_specific_fields: dict[str, Any] | None = None function_provider_specific_fields: dict[str, Any] | None = None @@ -29,22 +30,10 @@ class ToolCallRequest: "arguments": json.dumps(self.arguments, ensure_ascii=False), }, } + if self.extra_content: + tool_call["extra_content"] = self.extra_content if self.provider_specific_fields: - # Gemini OpenAI compatibility expects thought signatures in extra_content.google. - if "thought_signature" in self.provider_specific_fields: - tool_call["extra_content"] = { - "google": { - "thought_signature": self.provider_specific_fields["thought_signature"], - } - } - other_fields = { - k: v for k, v in self.provider_specific_fields.items() - if k != "thought_signature" - } - if other_fields: - tool_call["provider_specific_fields"] = other_fields - else: - tool_call["provider_specific_fields"] = self.provider_specific_fields + tool_call["provider_specific_fields"] = self.provider_specific_fields if self.function_provider_specific_fields: tool_call["function"]["provider_specific_fields"] = self.function_provider_specific_fields return tool_call diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index 1157e176d..ffb221e50 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -19,42 +19,13 @@ if TYPE_CHECKING: from nanobot.providers.registry import ProviderSpec _ALLOWED_MSG_KEYS = frozenset({ - "role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content", + "role", "content", "tool_calls", "tool_call_id", "name", + "reasoning_content", "extra_content", }) _ALNUM = string.ascii_letters + string.digits - -def _get_attr_or_item(obj: Any, key: str, default: Any = None) -> Any: - """Read an attribute or dict key from provider SDK objects.""" - if obj is None: - return default - if isinstance(obj, dict): - return obj.get(key, default) - return getattr(obj, key, default) - - -def _coerce_dict(value: Any) -> dict[str, Any] | None: - """Return a shallow dict if the value looks mapping-like.""" - if isinstance(value, dict): - return dict(value) - return None - - -def _extract_tool_call_fields(tc: Any) -> tuple[dict[str, Any] | None, dict[str, Any] | None]: - """Extract provider-specific metadata from a tool call object.""" - provider_specific_fields = _coerce_dict(_get_attr_or_item(tc, "provider_specific_fields")) - extra_content = _coerce_dict(_get_attr_or_item(tc, "extra_content")) - google_content = _coerce_dict(_get_attr_or_item(extra_content, "google")) if extra_content else None - if google_content: - provider_specific_fields = { - **(provider_specific_fields or {}), - **google_content, - } - function = _get_attr_or_item(tc, "function") - function_provider_specific_fields = _coerce_dict( - _get_attr_or_item(function, "provider_specific_fields") - ) - return provider_specific_fields, function_provider_specific_fields +_STANDARD_TC_KEYS = frozenset({"id", "type", "index", "function"}) +_STANDARD_FN_KEYS = frozenset({"name", "arguments"}) def _short_tool_id() -> str: @@ -62,6 +33,62 @@ def _short_tool_id() -> str: return "".join(secrets.choice(_ALNUM) for _ in range(9)) +def _get(obj: Any, key: str) -> Any: + """Get a value from dict or object 
attribute, returning None if absent.""" + if isinstance(obj, dict): + return obj.get(key) + return getattr(obj, key, None) + + +def _coerce_dict(value: Any) -> dict[str, Any] | None: + """Try to coerce *value* to a dict; return None if not possible or empty.""" + if value is None: + return None + if isinstance(value, dict): + return value if value else None + model_dump = getattr(value, "model_dump", None) + if callable(model_dump): + dumped = model_dump() + if isinstance(dumped, dict) and dumped: + return dumped + return None + + +def _extract_tc_extras(tc: Any) -> tuple[ + dict[str, Any] | None, + dict[str, Any] | None, + dict[str, Any] | None, +]: + """Extract (extra_content, provider_specific_fields, fn_provider_specific_fields). + + Works for both SDK objects and dicts. Captures Gemini ``extra_content`` + verbatim and any non-standard keys on the tool-call / function. + """ + extra_content = _coerce_dict(_get(tc, "extra_content")) + + tc_dict = _coerce_dict(tc) + prov = None + fn_prov = None + if tc_dict is not None: + leftover = {k: v for k, v in tc_dict.items() + if k not in _STANDARD_TC_KEYS and k != "extra_content" and v is not None} + if leftover: + prov = leftover + fn = _coerce_dict(tc_dict.get("function")) + if fn is not None: + fn_leftover = {k: v for k, v in fn.items() + if k not in _STANDARD_FN_KEYS and v is not None} + if fn_leftover: + fn_prov = fn_leftover + else: + prov = _coerce_dict(_get(tc, "provider_specific_fields")) + fn_obj = _get(tc, "function") + if fn_obj is not None: + fn_prov = _coerce_dict(_get(fn_obj, "provider_specific_fields")) + + return extra_content, prov, fn_prov + + class OpenAICompatProvider(LLMProvider): """Unified provider for all OpenAI-compatible APIs. @@ -332,10 +359,14 @@ class OpenAICompatProvider(LLMProvider): args = fn.get("arguments", {}) if isinstance(args, str): args = json_repair.loads(args) + ec, prov, fn_prov = _extract_tc_extras(tc) parsed_tool_calls.append(ToolCallRequest( id=_short_tool_id(), name=str(fn.get("name") or ""), arguments=args if isinstance(args, dict) else {}, + extra_content=ec, + provider_specific_fields=prov, + function_provider_specific_fields=fn_prov, )) return LLMResponse( @@ -366,17 +397,17 @@ class OpenAICompatProvider(LLMProvider): tool_calls = [] for tc in raw_tool_calls: - function = _get_attr_or_item(tc, "function") - args = _get_attr_or_item(function, "arguments") + args = tc.function.arguments if isinstance(args, str): args = json_repair.loads(args) - provider_specific_fields, function_provider_specific_fields = _extract_tool_call_fields(tc) + ec, prov, fn_prov = _extract_tc_extras(tc) tool_calls.append(ToolCallRequest( id=_short_tool_id(), - name=_get_attr_or_item(function, "name", ""), + name=tc.function.name, arguments=args, - provider_specific_fields=provider_specific_fields, - function_provider_specific_fields=function_provider_specific_fields, + extra_content=ec, + provider_specific_fields=prov, + function_provider_specific_fields=fn_prov, )) return LLMResponse( @@ -390,10 +421,36 @@ class OpenAICompatProvider(LLMProvider): @classmethod def _parse_chunks(cls, chunks: list[Any]) -> LLMResponse: content_parts: list[str] = [] - tc_bufs: dict[int, dict[str, str]] = {} + tc_bufs: dict[int, dict[str, Any]] = {} finish_reason = "stop" usage: dict[str, int] = {} + def _accum_tc(tc: Any, idx_hint: int) -> None: + """Accumulate one streaming tool-call delta into *tc_bufs*.""" + tc_index: int = _get(tc, "index") if _get(tc, "index") is not None else idx_hint + buf = tc_bufs.setdefault(tc_index, { + "id": "", 
"name": "", "arguments": "", + "extra_content": None, "prov": None, "fn_prov": None, + }) + tc_id = _get(tc, "id") + if tc_id: + buf["id"] = str(tc_id) + fn = _get(tc, "function") + if fn is not None: + fn_name = _get(fn, "name") + if fn_name: + buf["name"] = str(fn_name) + fn_args = _get(fn, "arguments") + if fn_args: + buf["arguments"] += str(fn_args) + ec, prov, fn_prov = _extract_tc_extras(tc) + if ec: + buf["extra_content"] = ec + if prov: + buf["prov"] = prov + if fn_prov: + buf["fn_prov"] = fn_prov + for chunk in chunks: if isinstance(chunk, str): content_parts.append(chunk) @@ -418,16 +475,7 @@ class OpenAICompatProvider(LLMProvider): if text: content_parts.append(text) for idx, tc in enumerate(delta.get("tool_calls") or []): - tc_map = cls._maybe_mapping(tc) or {} - tc_index = tc_map.get("index", idx) - buf = tc_bufs.setdefault(tc_index, {"id": "", "name": "", "arguments": ""}) - if tc_map.get("id"): - buf["id"] = str(tc_map["id"]) - fn = cls._maybe_mapping(tc_map.get("function")) or {} - if fn.get("name"): - buf["name"] = str(fn["name"]) - if fn.get("arguments"): - buf["arguments"] += str(fn["arguments"]) + _accum_tc(tc, idx) usage = cls._extract_usage(chunk_map) or usage continue @@ -441,34 +489,7 @@ class OpenAICompatProvider(LLMProvider): if delta and delta.content: content_parts.append(delta.content) for tc in (delta.tool_calls or []) if delta else []: - idx = _get_attr_or_item(tc, "index") - if idx is None: - continue - buf = tc_bufs.setdefault( - idx, - { - "id": "", - "name": "", - "arguments": "", - "provider_specific_fields": None, - "function_provider_specific_fields": None, - }, - ) - tc_id = _get_attr_or_item(tc, "id") - if tc_id: - buf["id"] = tc_id - function = _get_attr_or_item(tc, "function") - function_name = _get_attr_or_item(function, "name") - if function_name: - buf["name"] = function_name - arguments = _get_attr_or_item(function, "arguments") - if arguments: - buf["arguments"] += arguments - provider_specific_fields, function_provider_specific_fields = _extract_tool_call_fields(tc) - if provider_specific_fields: - buf["provider_specific_fields"] = provider_specific_fields - if function_provider_specific_fields: - buf["function_provider_specific_fields"] = function_provider_specific_fields + _accum_tc(tc, getattr(tc, "index", 0)) return LLMResponse( content="".join(content_parts) or None, @@ -477,8 +498,9 @@ class OpenAICompatProvider(LLMProvider): id=b["id"] or _short_tool_id(), name=b["name"], arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {}, - provider_specific_fields=b["provider_specific_fields"], - function_provider_specific_fields=b["function_provider_specific_fields"], + extra_content=b.get("extra_content"), + provider_specific_fields=b.get("prov"), + function_provider_specific_fields=b.get("fn_prov"), ) for b in tc_bufs.values() ], diff --git a/tests/agent/test_gemini_thought_signature.py b/tests/agent/test_gemini_thought_signature.py index f4b279b65..320c1ecd2 100644 --- a/tests/agent/test_gemini_thought_signature.py +++ b/tests/agent/test_gemini_thought_signature.py @@ -1,19 +1,200 @@ +"""Tests for Gemini thought_signature round-trip through extra_content. + +The Gemini OpenAI-compatibility API returns tool calls with an extra_content +field: ``{"google": {"thought_signature": "..."}}``. This MUST survive the +parse → serialize round-trip so the model can continue reasoning. 
+""" + from types import SimpleNamespace +from unittest.mock import patch from nanobot.providers.base import ToolCallRequest +from nanobot.providers.openai_compat_provider import OpenAICompatProvider -def test_tool_call_request_serializes_provider_fields() -> None: - tool_call = ToolCallRequest( +GEMINI_EXTRA = {"google": {"thought_signature": "sig-abc-123"}} + + +# ── ToolCallRequest serialization ────────────────────────────────────── + +def test_tool_call_request_serializes_extra_content() -> None: + tc = ToolCallRequest( id="abc123xyz", name="read_file", arguments={"path": "todo.md"}, - provider_specific_fields={"thought_signature": "signed-token"}, + extra_content=GEMINI_EXTRA, + ) + + payload = tc.to_openai_tool_call() + + assert payload["extra_content"] == GEMINI_EXTRA + assert payload["function"]["arguments"] == '{"path": "todo.md"}' + + +def test_tool_call_request_serializes_provider_fields() -> None: + tc = ToolCallRequest( + id="abc123xyz", + name="read_file", + arguments={"path": "todo.md"}, + provider_specific_fields={"custom_key": "custom_val"}, function_provider_specific_fields={"inner": "value"}, ) - message = tool_call.to_openai_tool_call() + payload = tc.to_openai_tool_call() - assert message["extra_content"] == {"google": {"thought_signature": "signed-token"}} - assert message["function"]["provider_specific_fields"] == {"inner": "value"} - assert message["function"]["arguments"] == '{"path": "todo.md"}' + assert payload["provider_specific_fields"] == {"custom_key": "custom_val"} + assert payload["function"]["provider_specific_fields"] == {"inner": "value"} + + +def test_tool_call_request_omits_absent_extras() -> None: + tc = ToolCallRequest(id="x", name="fn", arguments={}) + payload = tc.to_openai_tool_call() + + assert "extra_content" not in payload + assert "provider_specific_fields" not in payload + assert "provider_specific_fields" not in payload["function"] + + +# ── _parse: SDK-object branch ────────────────────────────────────────── + +def _make_sdk_response_with_extra_content(): + """Simulate a Gemini response via the OpenAI SDK (SimpleNamespace).""" + fn = SimpleNamespace(name="get_weather", arguments='{"city":"Tokyo"}') + tc = SimpleNamespace( + id="call_1", + index=0, + type="function", + function=fn, + extra_content=GEMINI_EXTRA, + ) + msg = SimpleNamespace( + content=None, + tool_calls=[tc], + reasoning_content=None, + ) + choice = SimpleNamespace(message=msg, finish_reason="tool_calls") + usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15) + return SimpleNamespace(choices=[choice], usage=usage) + + +def test_parse_sdk_object_preserves_extra_content() -> None: + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() + + result = provider._parse(_make_sdk_response_with_extra_content()) + + assert len(result.tool_calls) == 1 + tc = result.tool_calls[0] + assert tc.name == "get_weather" + assert tc.extra_content == GEMINI_EXTRA + + payload = tc.to_openai_tool_call() + assert payload["extra_content"] == GEMINI_EXTRA + + +# ── _parse: dict/mapping branch ─────────────────────────────────────── + +def test_parse_dict_preserves_extra_content() -> None: + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() + + response_dict = { + "choices": [{ + "message": { + "content": None, + "tool_calls": [{ + "id": "call_1", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"city":"Tokyo"}'}, + "extra_content": GEMINI_EXTRA, + 
}], + }, + "finish_reason": "tool_calls", + }], + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + } + + result = provider._parse(response_dict) + + assert len(result.tool_calls) == 1 + tc = result.tool_calls[0] + assert tc.name == "get_weather" + assert tc.extra_content == GEMINI_EXTRA + + payload = tc.to_openai_tool_call() + assert payload["extra_content"] == GEMINI_EXTRA + + +# ── _parse_chunks: streaming round-trip ─────────────────────────────── + +def test_parse_chunks_sdk_preserves_extra_content() -> None: + fn_delta = SimpleNamespace(name="get_weather", arguments='{"city":"Tokyo"}') + tc_delta = SimpleNamespace( + id="call_1", + index=0, + function=fn_delta, + extra_content=GEMINI_EXTRA, + ) + delta = SimpleNamespace(content=None, tool_calls=[tc_delta]) + choice = SimpleNamespace(finish_reason="tool_calls", delta=delta) + chunk = SimpleNamespace(choices=[choice], usage=None) + + result = OpenAICompatProvider._parse_chunks([chunk]) + + assert len(result.tool_calls) == 1 + tc = result.tool_calls[0] + assert tc.extra_content == GEMINI_EXTRA + + payload = tc.to_openai_tool_call() + assert payload["extra_content"] == GEMINI_EXTRA + + +def test_parse_chunks_dict_preserves_extra_content() -> None: + chunk = { + "choices": [{ + "finish_reason": "tool_calls", + "delta": { + "content": None, + "tool_calls": [{ + "index": 0, + "id": "call_1", + "function": {"name": "get_weather", "arguments": '{"city":"Tokyo"}'}, + "extra_content": GEMINI_EXTRA, + }], + }, + }], + } + + result = OpenAICompatProvider._parse_chunks([chunk]) + + assert len(result.tool_calls) == 1 + tc = result.tool_calls[0] + assert tc.extra_content == GEMINI_EXTRA + + payload = tc.to_openai_tool_call() + assert payload["extra_content"] == GEMINI_EXTRA + + +# ── Model switching: stale extras shouldn't break other providers ───── + +def test_stale_extra_content_in_tool_calls_survives_sanitize() -> None: + """When switching from Gemini to OpenAI, extra_content inside tool_calls + should survive message sanitization (it lives inside the tool_call dict, + not at message level, so it bypasses _ALLOWED_MSG_KEYS filtering).""" + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() + + messages = [{ + "role": "assistant", + "content": None, + "tool_calls": [{ + "id": "call_1", + "type": "function", + "function": {"name": "fn", "arguments": "{}"}, + "extra_content": GEMINI_EXTRA, + }], + }] + + sanitized = provider._sanitize_messages(messages) + + assert sanitized[0]["tool_calls"][0]["extra_content"] == GEMINI_EXTRA diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py index e912a7bfd..b166cb026 100644 --- a/tests/providers/test_litellm_kwargs.py +++ b/tests/providers/test_litellm_kwargs.py @@ -30,7 +30,7 @@ def _fake_chat_response(content: str = "ok") -> SimpleNamespace: def _fake_tool_call_response() -> SimpleNamespace: - """Build a minimal chat response that includes Gemini-style provider fields.""" + """Build a minimal chat response that includes Gemini-style extra_content.""" function = SimpleNamespace( name="exec", arguments='{"cmd":"ls"}', @@ -39,6 +39,7 @@ def _fake_tool_call_response() -> SimpleNamespace: tool_call = SimpleNamespace( id="call_123", index=0, + type="function", function=function, extra_content={"google": {"thought_signature": "signed-token"}}, ) @@ -134,8 +135,8 @@ async def test_standard_provider_passes_model_through() -> None: @pytest.mark.asyncio -async def 
test_openai_compat_preserves_provider_specific_fields_on_tool_calls() -> None:
-    """Gemini thought signatures must survive parsing so they can be sent back."""
+async def test_openai_compat_preserves_extra_content_on_tool_calls() -> None:
+    """Gemini extra_content (thought signatures) must survive the parse → serialize round-trip."""
     mock_create = AsyncMock(return_value=_fake_tool_call_response())
 
     spec = find_by_name("gemini")
@@ -156,7 +157,7 @@ async def test_openai_compat_preserves_provider_specific_fields_on_tool_calls()
     assert len(result.tool_calls) == 1
     tool_call = result.tool_calls[0]
-    assert tool_call.provider_specific_fields == {"thought_signature": "signed-token"}
+    assert tool_call.extra_content == {"google": {"thought_signature": "signed-token"}}
     assert tool_call.function_provider_specific_fields == {"inner": "value"}
 
     serialized = tool_call.to_openai_tool_call()
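
A usage sketch (not part of the patch), grounded in the test fixtures above: because
_parse_chunks is a classmethod, the streaming round-trip can be exercised without
constructing a provider or touching the network. The chunk shape and the
"get_weather"/signature values are illustrative, copied from the tests in this diff.

    from types import SimpleNamespace

    from nanobot.providers.openai_compat_provider import OpenAICompatProvider

    GEMINI_EXTRA = {"google": {"thought_signature": "sig-abc-123"}}

    # One streaming delta carrying a Gemini tool call, faked with SimpleNamespace
    # exactly as in test_parse_chunks_sdk_preserves_extra_content.
    fn_delta = SimpleNamespace(name="get_weather", arguments='{"city":"Tokyo"}')
    tc_delta = SimpleNamespace(id="call_1", index=0, function=fn_delta,
                               extra_content=GEMINI_EXTRA)
    delta = SimpleNamespace(content=None, tool_calls=[tc_delta])
    chunk = SimpleNamespace(
        choices=[SimpleNamespace(finish_reason="tool_calls", delta=delta)],
        usage=None,
    )

    # _accum_tc captures extra_content from the delta into the tool-call buffer.
    result = OpenAICompatProvider._parse_chunks([chunk])
    tool_call = result.tool_calls[0]
    assert tool_call.extra_content == GEMINI_EXTRA

    # Serializing for the next request echoes the signature back verbatim,
    # which is what lets Gemini resume its reasoning across tool calls.
    replay = tool_call.to_openai_tool_call()
    assert replay["extra_content"] == GEMINI_EXTRA
    assert replay["function"]["arguments"] == '{"city": "Tokyo"}'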