# nanobot/tests/providers/test_litellm_kwargs.py
# Xubin Ren b5302b6f3d refactor(provider): preserve extra_content verbatim for Gemini thought_signature round-trip
# Replace the flatten/unflatten approach (merging extra_content.google.*
# into provider_specific_fields then reconstructing) with direct pass-through:
# parse extra_content as-is, store on ToolCallRequest.extra_content, serialize
# back untouched.  This is lossless, requires no hardcoded field names, and
# covers all three parsing branches (str, dict, SDK object) plus streaming.
# 2026-03-25 10:00:29 +08:00
#
# 178 lines
# 6.3 KiB
# Python

"""Tests for OpenAICompatProvider spec-driven behavior.

Validates that:
- OpenRouter (no strip) keeps model names intact.
- AiHubMix (strip_model_prefix=True) strips provider prefixes.
- Standard providers pass model names through as-is.
- Tool-call extra_content (e.g. Gemini thought signatures) survives the
  parse -> serialize round-trip.
"""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
from nanobot.providers.registry import find_by_name
def _fake_chat_response(content: str = "ok") -> SimpleNamespace:
"""Build a minimal OpenAI chat completion response."""
message = SimpleNamespace(
content=content,
tool_calls=None,
reasoning_content=None,
)
choice = SimpleNamespace(message=message, finish_reason="stop")
usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
return SimpleNamespace(choices=[choice], usage=usage)
def _fake_tool_call_response() -> SimpleNamespace:
"""Build a minimal chat response that includes Gemini-style extra_content."""
function = SimpleNamespace(
name="exec",
arguments='{"cmd":"ls"}',
provider_specific_fields={"inner": "value"},
)
tool_call = SimpleNamespace(
id="call_123",
index=0,
type="function",
function=function,
extra_content={"google": {"thought_signature": "signed-token"}},
)
message = SimpleNamespace(
content=None,
tool_calls=[tool_call],
reasoning_content=None,
)
choice = SimpleNamespace(message=message, finish_reason="tool_calls")
usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
return SimpleNamespace(choices=[choice], usage=usage)
def test_openrouter_spec_is_gateway() -> None:
    """The registry entry for "openrouter" is a gateway with the expected API base."""
    spec = find_by_name("openrouter")
    # Check the lookup succeeded first so a missing entry fails here rather
    # than as an AttributeError on the attribute checks below.
    assert spec is not None
    assert spec.is_gateway is True
    assert spec.default_api_base == "https://openrouter.ai/api/v1"
@pytest.mark.asyncio
async def test_openrouter_keeps_model_name_intact() -> None:
    """OpenRouter gateway keeps the full model name (gateway does its own routing)."""
    mock_create = AsyncMock(return_value=_fake_chat_response())
    spec = find_by_name("openrouter")

    # Replace the AsyncOpenAI name in the provider module with a mock while
    # the provider is built and used, and route chat.completions.create to
    # mock_create so the outgoing kwargs can be inspected afterwards.
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
        client_instance = MockClient.return_value
        client_instance.chat.completions.create = mock_create

        provider = OpenAICompatProvider(
            api_key="sk-or-test-key",
            api_base="https://openrouter.ai/api/v1",
            default_model="anthropic/claude-sonnet-4-5",
            spec=spec,
        )
        await provider.chat(
            messages=[{"role": "user", "content": "hello"}],
            model="anthropic/claude-sonnet-4-5",
        )

    # The "anthropic/" prefix must still be present in the request.
    call_kwargs = mock_create.call_args.kwargs
    assert call_kwargs["model"] == "anthropic/claude-sonnet-4-5"
@pytest.mark.asyncio
async def test_aihubmix_strips_model_prefix() -> None:
    """AiHubMix strips the provider prefix (strip_model_prefix=True)."""
    mock_create = AsyncMock(return_value=_fake_chat_response())
    spec = find_by_name("aihubmix")

    # Replace the AsyncOpenAI name in the provider module with a mock while
    # the provider is built and used, and route chat.completions.create to
    # mock_create so the outgoing kwargs can be inspected afterwards.
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
        client_instance = MockClient.return_value
        client_instance.chat.completions.create = mock_create

        provider = OpenAICompatProvider(
            api_key="sk-aihub-test-key",
            api_base="https://aihubmix.com/v1",
            default_model="claude-sonnet-4-5",
            spec=spec,
        )
        await provider.chat(
            messages=[{"role": "user", "content": "hello"}],
            model="anthropic/claude-sonnet-4-5",
        )

    # The request was made with "anthropic/claude-sonnet-4-5"; only the bare
    # model name should reach the client.
    call_kwargs = mock_create.call_args.kwargs
    assert call_kwargs["model"] == "claude-sonnet-4-5"
@pytest.mark.asyncio
async def test_standard_provider_passes_model_through() -> None:
    """Standard provider (e.g. deepseek) passes model name through as-is."""
    mock_create = AsyncMock(return_value=_fake_chat_response())
    spec = find_by_name("deepseek")

    # Replace the AsyncOpenAI name in the provider module with a mock while
    # the provider is built and used, and route chat.completions.create to
    # mock_create so the outgoing kwargs can be inspected afterwards.
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
        client_instance = MockClient.return_value
        client_instance.chat.completions.create = mock_create

        # No api_base is passed here, unlike the gateway tests.
        provider = OpenAICompatProvider(
            api_key="sk-deepseek-test-key",
            default_model="deepseek-chat",
            spec=spec,
        )
        await provider.chat(
            messages=[{"role": "user", "content": "hello"}],
            model="deepseek-chat",
        )

    call_kwargs = mock_create.call_args.kwargs
    assert call_kwargs["model"] == "deepseek-chat"
@pytest.mark.asyncio
async def test_openai_compat_preserves_extra_content_on_tool_calls() -> None:
    """Gemini extra_content (thought signatures) must survive parse→serialize round-trip."""
    mock_create = AsyncMock(return_value=_fake_tool_call_response())
    spec = find_by_name("gemini")

    # Replace the AsyncOpenAI name in the provider module with a mock while
    # the provider is built and used; the mocked create() returns a response
    # whose tool call carries extra_content and function-level
    # provider_specific_fields.
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
        client_instance = MockClient.return_value
        client_instance.chat.completions.create = mock_create

        provider = OpenAICompatProvider(
            api_key="test-key",
            api_base="https://generativelanguage.googleapis.com/v1beta/openai/",
            default_model="google/gemini-3.1-pro-preview",
            spec=spec,
        )
        result = await provider.chat(
            messages=[{"role": "user", "content": "run exec"}],
            model="google/gemini-3.1-pro-preview",
        )

    # Parse half: both payloads are exposed on the parsed tool call.
    assert len(result.tool_calls) == 1
    tool_call = result.tool_calls[0]
    assert tool_call.extra_content == {"google": {"thought_signature": "signed-token"}}
    assert tool_call.function_provider_specific_fields == {"inner": "value"}

    # Serialize half: both payloads come back out unchanged and in their
    # original positions (top level vs. under "function").
    serialized = tool_call.to_openai_tool_call()
    assert serialized["extra_content"] == {"google": {"thought_signature": "signed-token"}}
    assert serialized["function"]["provider_specific_fields"] == {"inner": "value"}
def test_openai_model_passthrough() -> None:
    """OpenAI models pass through unchanged."""
    spec = find_by_name("openai")

    # Replace the AsyncOpenAI name in the provider module with a mock while
    # the provider is constructed; no request is made in this test.
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        provider = OpenAICompatProvider(
            api_key="sk-test-key",
            default_model="gpt-4o",
            spec=spec,
        )

    assert provider.get_default_model() == "gpt-4o"