refactor(provider): preserve extra_content verbatim for Gemini thought_signature round-trip

Replace the flatten/unflatten approach (merging extra_content.google.*
into provider_specific_fields then reconstructing) with direct pass-through:
parse extra_content as-is, store on ToolCallRequest.extra_content, serialize
back untouched. This is lossless, requires no hardcoded field names, and
covers all three parsing branches (str, dict, SDK object) plus streaming.
This commit is contained in:
Xubin Ren 2026-03-25 01:56:44 +00:00 committed by Xubin Ren
parent af84b1b8c0
commit b5302b6f3d
4 changed files with 299 additions and 106 deletions

View File

@ -16,6 +16,7 @@ class ToolCallRequest:
id: str
name: str
arguments: dict[str, Any]
extra_content: dict[str, Any] | None = None
provider_specific_fields: dict[str, Any] | None = None
function_provider_specific_fields: dict[str, Any] | None = None
@ -29,22 +30,10 @@ class ToolCallRequest:
"arguments": json.dumps(self.arguments, ensure_ascii=False),
},
}
if self.extra_content:
tool_call["extra_content"] = self.extra_content
if self.provider_specific_fields:
# Gemini OpenAI compatibility expects thought signatures in extra_content.google.
if "thought_signature" in self.provider_specific_fields:
tool_call["extra_content"] = {
"google": {
"thought_signature": self.provider_specific_fields["thought_signature"],
}
}
other_fields = {
k: v for k, v in self.provider_specific_fields.items()
if k != "thought_signature"
}
if other_fields:
tool_call["provider_specific_fields"] = other_fields
else:
tool_call["provider_specific_fields"] = self.provider_specific_fields
tool_call["provider_specific_fields"] = self.provider_specific_fields
if self.function_provider_specific_fields:
tool_call["function"]["provider_specific_fields"] = self.function_provider_specific_fields
return tool_call

View File

@ -19,42 +19,13 @@ if TYPE_CHECKING:
from nanobot.providers.registry import ProviderSpec
_ALLOWED_MSG_KEYS = frozenset({
"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content",
"role", "content", "tool_calls", "tool_call_id", "name",
"reasoning_content", "extra_content",
})
_ALNUM = string.ascii_letters + string.digits
def _get_attr_or_item(obj: Any, key: str, default: Any = None) -> Any:
"""Read an attribute or dict key from provider SDK objects."""
if obj is None:
return default
if isinstance(obj, dict):
return obj.get(key, default)
return getattr(obj, key, default)
def _coerce_dict(value: Any) -> dict[str, Any] | None:
"""Return a shallow dict if the value looks mapping-like."""
if isinstance(value, dict):
return dict(value)
return None
def _extract_tool_call_fields(tc: Any) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
"""Extract provider-specific metadata from a tool call object."""
provider_specific_fields = _coerce_dict(_get_attr_or_item(tc, "provider_specific_fields"))
extra_content = _coerce_dict(_get_attr_or_item(tc, "extra_content"))
google_content = _coerce_dict(_get_attr_or_item(extra_content, "google")) if extra_content else None
if google_content:
provider_specific_fields = {
**(provider_specific_fields or {}),
**google_content,
}
function = _get_attr_or_item(tc, "function")
function_provider_specific_fields = _coerce_dict(
_get_attr_or_item(function, "provider_specific_fields")
)
return provider_specific_fields, function_provider_specific_fields
_STANDARD_TC_KEYS = frozenset({"id", "type", "index", "function"})
_STANDARD_FN_KEYS = frozenset({"name", "arguments"})
def _short_tool_id() -> str:
@ -62,6 +33,62 @@ def _short_tool_id() -> str:
return "".join(secrets.choice(_ALNUM) for _ in range(9))
def _get(obj: Any, key: str) -> Any:
"""Get a value from dict or object attribute, returning None if absent."""
if isinstance(obj, dict):
return obj.get(key)
return getattr(obj, key, None)
def _coerce_dict(value: Any) -> dict[str, Any] | None:
"""Try to coerce *value* to a dict; return None if not possible or empty."""
if value is None:
return None
if isinstance(value, dict):
return value if value else None
model_dump = getattr(value, "model_dump", None)
if callable(model_dump):
dumped = model_dump()
if isinstance(dumped, dict) and dumped:
return dumped
return None
def _extract_tc_extras(tc: Any) -> tuple[
    dict[str, Any] | None,
    dict[str, Any] | None,
    dict[str, Any] | None,
]:
    """Extract (extra_content, provider_specific_fields, fn_provider_specific_fields).

    Works for both SDK objects and dicts. Captures Gemini ``extra_content``
    verbatim and any non-standard keys on the tool-call / function.
    """
    # extra_content is captured as-is (never merged or rewritten) so that
    # provider payloads such as Gemini thought signatures round-trip losslessly.
    extra_content = _coerce_dict(_get(tc, "extra_content"))
    tc_dict = _coerce_dict(tc)
    prov = None
    fn_prov = None
    if tc_dict is not None:
        # Dict-shaped tool call: any key beyond the standard OpenAI ones is
        # treated as provider-specific metadata.
        # NOTE(review): a literal "provider_specific_fields" key here stays
        # nested as {"provider_specific_fields": {...}} instead of being
        # flattened like the object branch below returns it — confirm intended.
        leftover = {k: v for k, v in tc_dict.items()
                    if k not in _STANDARD_TC_KEYS and k != "extra_content" and v is not None}
        if leftover:
            prov = leftover
        fn = _coerce_dict(tc_dict.get("function"))
        if fn is not None:
            # Same treatment for non-standard keys on the nested function.
            fn_leftover = {k: v for k, v in fn.items()
                          if k not in _STANDARD_FN_KEYS and v is not None}
            if fn_leftover:
                fn_prov = fn_leftover
    else:
        # Object-shaped (SDK) tool call: read explicit provider_specific_fields
        # attributes off the tool call and its function, if present.
        prov = _coerce_dict(_get(tc, "provider_specific_fields"))
        fn_obj = _get(tc, "function")
        if fn_obj is not None:
            fn_prov = _coerce_dict(_get(fn_obj, "provider_specific_fields"))
    return extra_content, prov, fn_prov
class OpenAICompatProvider(LLMProvider):
"""Unified provider for all OpenAI-compatible APIs.
@ -332,10 +359,14 @@ class OpenAICompatProvider(LLMProvider):
args = fn.get("arguments", {})
if isinstance(args, str):
args = json_repair.loads(args)
ec, prov, fn_prov = _extract_tc_extras(tc)
parsed_tool_calls.append(ToolCallRequest(
id=_short_tool_id(),
name=str(fn.get("name") or ""),
arguments=args if isinstance(args, dict) else {},
extra_content=ec,
provider_specific_fields=prov,
function_provider_specific_fields=fn_prov,
))
return LLMResponse(
@ -366,17 +397,17 @@ class OpenAICompatProvider(LLMProvider):
tool_calls = []
for tc in raw_tool_calls:
function = _get_attr_or_item(tc, "function")
args = _get_attr_or_item(function, "arguments")
args = tc.function.arguments
if isinstance(args, str):
args = json_repair.loads(args)
provider_specific_fields, function_provider_specific_fields = _extract_tool_call_fields(tc)
ec, prov, fn_prov = _extract_tc_extras(tc)
tool_calls.append(ToolCallRequest(
id=_short_tool_id(),
name=_get_attr_or_item(function, "name", ""),
name=tc.function.name,
arguments=args,
provider_specific_fields=provider_specific_fields,
function_provider_specific_fields=function_provider_specific_fields,
extra_content=ec,
provider_specific_fields=prov,
function_provider_specific_fields=fn_prov,
))
return LLMResponse(
@ -390,10 +421,36 @@ class OpenAICompatProvider(LLMProvider):
@classmethod
def _parse_chunks(cls, chunks: list[Any]) -> LLMResponse:
content_parts: list[str] = []
tc_bufs: dict[int, dict[str, str]] = {}
tc_bufs: dict[int, dict[str, Any]] = {}
finish_reason = "stop"
usage: dict[str, int] = {}
def _accum_tc(tc: Any, idx_hint: int) -> None:
    """Accumulate one streaming tool-call delta into *tc_bufs*.

    *idx_hint* is used as the buffer index when the delta carries no
    explicit ``index`` field.
    """
    tc_index: int = _get(tc, "index") if _get(tc, "index") is not None else idx_hint
    buf = tc_bufs.setdefault(tc_index, {
        "id": "", "name": "", "arguments": "",
        "extra_content": None, "prov": None, "fn_prov": None,
    })
    tc_id = _get(tc, "id")
    if tc_id:
        buf["id"] = str(tc_id)
    fn = _get(tc, "function")
    if fn is not None:
        fn_name = _get(fn, "name")
        if fn_name:
            buf["name"] = str(fn_name)
        fn_args = _get(fn, "arguments")
        if fn_args:
            # Argument JSON arrives in fragments across chunks; concatenate.
            buf["arguments"] += str(fn_args)
    # NOTE(review): a later delta carrying extras overwrites (does not merge
    # with) earlier ones — confirm that is acceptable for streaming providers.
    ec, prov, fn_prov = _extract_tc_extras(tc)
    if ec:
        buf["extra_content"] = ec
    if prov:
        buf["prov"] = prov
    if fn_prov:
        buf["fn_prov"] = fn_prov
for chunk in chunks:
if isinstance(chunk, str):
content_parts.append(chunk)
@ -418,16 +475,7 @@ class OpenAICompatProvider(LLMProvider):
if text:
content_parts.append(text)
for idx, tc in enumerate(delta.get("tool_calls") or []):
tc_map = cls._maybe_mapping(tc) or {}
tc_index = tc_map.get("index", idx)
buf = tc_bufs.setdefault(tc_index, {"id": "", "name": "", "arguments": ""})
if tc_map.get("id"):
buf["id"] = str(tc_map["id"])
fn = cls._maybe_mapping(tc_map.get("function")) or {}
if fn.get("name"):
buf["name"] = str(fn["name"])
if fn.get("arguments"):
buf["arguments"] += str(fn["arguments"])
_accum_tc(tc, idx)
usage = cls._extract_usage(chunk_map) or usage
continue
@ -441,34 +489,7 @@ class OpenAICompatProvider(LLMProvider):
if delta and delta.content:
content_parts.append(delta.content)
for tc in (delta.tool_calls or []) if delta else []:
idx = _get_attr_or_item(tc, "index")
if idx is None:
continue
buf = tc_bufs.setdefault(
idx,
{
"id": "",
"name": "",
"arguments": "",
"provider_specific_fields": None,
"function_provider_specific_fields": None,
},
)
tc_id = _get_attr_or_item(tc, "id")
if tc_id:
buf["id"] = tc_id
function = _get_attr_or_item(tc, "function")
function_name = _get_attr_or_item(function, "name")
if function_name:
buf["name"] = function_name
arguments = _get_attr_or_item(function, "arguments")
if arguments:
buf["arguments"] += arguments
provider_specific_fields, function_provider_specific_fields = _extract_tool_call_fields(tc)
if provider_specific_fields:
buf["provider_specific_fields"] = provider_specific_fields
if function_provider_specific_fields:
buf["function_provider_specific_fields"] = function_provider_specific_fields
_accum_tc(tc, getattr(tc, "index", 0))
return LLMResponse(
content="".join(content_parts) or None,
@ -477,8 +498,9 @@ class OpenAICompatProvider(LLMProvider):
id=b["id"] or _short_tool_id(),
name=b["name"],
arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {},
provider_specific_fields=b["provider_specific_fields"],
function_provider_specific_fields=b["function_provider_specific_fields"],
extra_content=b.get("extra_content"),
provider_specific_fields=b.get("prov"),
function_provider_specific_fields=b.get("fn_prov"),
)
for b in tc_bufs.values()
],

View File

@ -1,19 +1,200 @@
"""Tests for Gemini thought_signature round-trip through extra_content.
The Gemini OpenAI-compatibility API returns tool calls with an extra_content
field: ``{"google": {"thought_signature": "..."}}``. This MUST survive the
parse → serialize round-trip so the model can continue reasoning.
"""
from types import SimpleNamespace
from unittest.mock import patch
from nanobot.providers.base import ToolCallRequest
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
def test_tool_call_request_serializes_provider_fields() -> None:
tool_call = ToolCallRequest(
GEMINI_EXTRA = {"google": {"thought_signature": "sig-abc-123"}}
# ── ToolCallRequest serialization ──────────────────────────────────────
def test_tool_call_request_serializes_extra_content() -> None:
    """extra_content must appear verbatim on the serialized tool call."""
    request = ToolCallRequest(
        id="abc123xyz",
        name="read_file",
        arguments={"path": "todo.md"},
        extra_content=GEMINI_EXTRA,
    )
    serialized = request.to_openai_tool_call()
    assert serialized["extra_content"] == GEMINI_EXTRA
    assert serialized["function"]["arguments"] == '{"path": "todo.md"}'
def test_tool_call_request_serializes_provider_fields() -> None:
tc = ToolCallRequest(
id="abc123xyz",
name="read_file",
arguments={"path": "todo.md"},
provider_specific_fields={"custom_key": "custom_val"},
function_provider_specific_fields={"inner": "value"},
)
message = tool_call.to_openai_tool_call()
payload = tc.to_openai_tool_call()
assert message["extra_content"] == {"google": {"thought_signature": "signed-token"}}
assert message["function"]["provider_specific_fields"] == {"inner": "value"}
assert message["function"]["arguments"] == '{"path": "todo.md"}'
assert payload["provider_specific_fields"] == {"custom_key": "custom_val"}
assert payload["function"]["provider_specific_fields"] == {"inner": "value"}
def test_tool_call_request_omits_absent_extras() -> None:
    """Optional extras that were never set must not appear in the payload."""
    minimal = ToolCallRequest(id="x", name="fn", arguments={})
    serialized = minimal.to_openai_tool_call()
    for absent_key in ("extra_content", "provider_specific_fields"):
        assert absent_key not in serialized
    assert "provider_specific_fields" not in serialized["function"]
# ── _parse: SDK-object branch ──────────────────────────────────────────
def _make_sdk_response_with_extra_content():
    """Simulate a Gemini response via the OpenAI SDK (SimpleNamespace)."""
    tool_call = SimpleNamespace(
        id="call_1",
        index=0,
        type="function",
        function=SimpleNamespace(name="get_weather", arguments='{"city":"Tokyo"}'),
        extra_content=GEMINI_EXTRA,
    )
    message = SimpleNamespace(
        content=None,
        tool_calls=[tool_call],
        reasoning_content=None,
    )
    return SimpleNamespace(
        choices=[SimpleNamespace(message=message, finish_reason="tool_calls")],
        usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15),
    )
def test_parse_sdk_object_preserves_extra_content() -> None:
    """SDK-object responses keep extra_content through parse and serialize."""
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        provider = OpenAICompatProvider()
        result = provider._parse(_make_sdk_response_with_extra_content())

    assert len(result.tool_calls) == 1
    call = result.tool_calls[0]
    assert call.name == "get_weather"
    assert call.extra_content == GEMINI_EXTRA
    assert call.to_openai_tool_call()["extra_content"] == GEMINI_EXTRA
# ── _parse: dict/mapping branch ───────────────────────────────────────
def test_parse_dict_preserves_extra_content() -> None:
    """Dict/mapping responses keep extra_content through parse and serialize."""
    raw_tool_call = {
        "id": "call_1",
        "type": "function",
        "function": {"name": "get_weather", "arguments": '{"city":"Tokyo"}'},
        "extra_content": GEMINI_EXTRA,
    }
    response = {
        "choices": [{
            "message": {"content": None, "tool_calls": [raw_tool_call]},
            "finish_reason": "tool_calls",
        }],
        "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
    }
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        provider = OpenAICompatProvider()
        result = provider._parse(response)

    assert len(result.tool_calls) == 1
    call = result.tool_calls[0]
    assert call.name == "get_weather"
    assert call.extra_content == GEMINI_EXTRA
    assert call.to_openai_tool_call()["extra_content"] == GEMINI_EXTRA
# ── _parse_chunks: streaming round-trip ───────────────────────────────
def test_parse_chunks_sdk_preserves_extra_content() -> None:
    """Streaming SDK deltas must carry extra_content into the parsed result."""
    delta_call = SimpleNamespace(
        id="call_1",
        index=0,
        function=SimpleNamespace(name="get_weather", arguments='{"city":"Tokyo"}'),
        extra_content=GEMINI_EXTRA,
    )
    chunk = SimpleNamespace(
        choices=[SimpleNamespace(
            finish_reason="tool_calls",
            delta=SimpleNamespace(content=None, tool_calls=[delta_call]),
        )],
        usage=None,
    )

    result = OpenAICompatProvider._parse_chunks([chunk])

    assert len(result.tool_calls) == 1
    call = result.tool_calls[0]
    assert call.extra_content == GEMINI_EXTRA
    assert call.to_openai_tool_call()["extra_content"] == GEMINI_EXTRA
def test_parse_chunks_dict_preserves_extra_content() -> None:
    """Streaming dict deltas must carry extra_content into the parsed result."""
    delta_call = {
        "index": 0,
        "id": "call_1",
        "function": {"name": "get_weather", "arguments": '{"city":"Tokyo"}'},
        "extra_content": GEMINI_EXTRA,
    }
    chunk = {
        "choices": [{
            "finish_reason": "tool_calls",
            "delta": {"content": None, "tool_calls": [delta_call]},
        }],
    }

    result = OpenAICompatProvider._parse_chunks([chunk])

    assert len(result.tool_calls) == 1
    call = result.tool_calls[0]
    assert call.extra_content == GEMINI_EXTRA
    assert call.to_openai_tool_call()["extra_content"] == GEMINI_EXTRA
# ── Model switching: stale extras shouldn't break other providers ─────
def test_stale_extra_content_in_tool_calls_survives_sanitize() -> None:
    """When switching from Gemini to OpenAI, extra_content inside tool_calls
    should survive message sanitization (it lives inside the tool_call dict,
    not at message level, so it bypasses _ALLOWED_MSG_KEYS filtering)."""
    history = [{
        "role": "assistant",
        "content": None,
        "tool_calls": [{
            "id": "call_1",
            "type": "function",
            "function": {"name": "fn", "arguments": "{}"},
            "extra_content": GEMINI_EXTRA,
        }],
    }]
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        provider = OpenAICompatProvider()
        cleaned = provider._sanitize_messages(history)
    assert cleaned[0]["tool_calls"][0]["extra_content"] == GEMINI_EXTRA

View File

@ -30,7 +30,7 @@ def _fake_chat_response(content: str = "ok") -> SimpleNamespace:
def _fake_tool_call_response() -> SimpleNamespace:
"""Build a minimal chat response that includes Gemini-style provider fields."""
"""Build a minimal chat response that includes Gemini-style extra_content."""
function = SimpleNamespace(
name="exec",
arguments='{"cmd":"ls"}',
@ -39,6 +39,7 @@ def _fake_tool_call_response() -> SimpleNamespace:
tool_call = SimpleNamespace(
id="call_123",
index=0,
type="function",
function=function,
extra_content={"google": {"thought_signature": "signed-token"}},
)
@ -134,8 +135,8 @@ async def test_standard_provider_passes_model_through() -> None:
@pytest.mark.asyncio
async def test_openai_compat_preserves_provider_specific_fields_on_tool_calls() -> None:
"""Gemini thought signatures must survive parsing so they can be sent back."""
async def test_openai_compat_preserves_extra_content_on_tool_calls() -> None:
"""Gemini extra_content (thought signatures) must survive parse→serialize round-trip."""
mock_create = AsyncMock(return_value=_fake_tool_call_response())
spec = find_by_name("gemini")
@ -156,7 +157,7 @@ async def test_openai_compat_preserves_provider_specific_fields_on_tool_calls()
assert len(result.tool_calls) == 1
tool_call = result.tool_calls[0]
assert tool_call.provider_specific_fields == {"thought_signature": "signed-token"}
assert tool_call.extra_content == {"google": {"thought_signature": "signed-token"}}
assert tool_call.function_provider_specific_fields == {"inner": "value"}
serialized = tool_call.to_openai_tool_call()