feat(openai): auto-route direct reasoning requests with responses fallback

Author: Xubin Ren
Date: 2026-04-08 15:21:08 +00:00
Parent: c092896922
Commit: d084d10dc2
2 changed files with 423 additions and 11 deletions
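The change routes eligible requests to the Responses API first and falls back to Chat Completions only on compatibility-shaped errors. A minimal, self-contained sketch of that control flow (call_responses, call_chat, and ApiError are hypothetical stand-ins, not the provider's real methods):

import asyncio

COMPAT_STATUSES = {400, 404, 422}  # mirrors the status set in the diff below

class ApiError(Exception):
    def __init__(self, status_code: int, text: str):
        super().__init__(text)
        self.status_code = status_code

async def call_responses(prompt: str) -> str:
    # Simulate a backend without /responses support.
    raise ApiError(404, "responses not supported")

async def call_chat(prompt: str) -> str:
    return f"chat: {prompt}"

async def route(prompt: str, eligible: bool) -> str:
    if eligible:
        try:
            return await call_responses(prompt)
        except ApiError as err:
            if err.status_code not in COMPAT_STATUSES:
                raise  # e.g. a 429 rate limit surfaces instead of silently retrying
    return await call_chat(prompt)

print(asyncio.run(route("hello", eligible=True)))  # -> chat: hello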

nanobot/providers/openai_compat_provider.py

@@ -26,6 +26,12 @@ else:
from openai import AsyncOpenAI
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.openai_responses import (
consume_sdk_stream,
convert_messages,
convert_tools,
parse_response_output,
)
if TYPE_CHECKING:
from nanobot.providers.registry import ProviderSpec
@@ -113,6 +119,14 @@ def _uses_openrouter_attribution(spec: "ProviderSpec | None", api_base: str | No
return bool(api_base and "openrouter" in api_base.lower())
def _is_direct_openai_base(api_base: str | None) -> bool:
"""Return True for direct OpenAI endpoints, not generic OpenAI-compatible gateways."""
if not api_base:
return True
normalized = api_base.strip().lower().rstrip("/")
return "api.openai.com" in normalized and "openrouter" not in normalized
class OpenAICompatProvider(LLMProvider):
"""Unified provider for all OpenAI-compatible APIs.
@@ -137,6 +151,7 @@ class OpenAICompatProvider(LLMProvider):
self._setup_env(api_key, api_base)
effective_base = api_base or (spec.default_api_base if spec else None) or None
self._effective_base = effective_base
default_headers = {"x-session-affinity": uuid.uuid4().hex}
if _uses_openrouter_attribution(spec, effective_base):
default_headers.update(_DEFAULT_OPENROUTER_HEADERS)
@@ -321,6 +336,88 @@ class OpenAICompatProvider(LLMProvider):
return kwargs
def _should_use_responses_api(
self,
model: str | None,
reasoning_effort: str | None,
) -> bool:
"""Use Responses API only for direct OpenAI requests that benefit from it."""
if self._spec and self._spec.name != "openai":
return False
if not _is_direct_openai_base(self._effective_base):
return False
model_name = (model or self.default_model).lower()
if reasoning_effort and reasoning_effort.lower() != "none":
return True
return any(token in model_name for token in ("gpt-5", "o1", "o3", "o4"))
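# Routing summary for the predicate above (descriptive, not part of the diff):
# non-"openai" specs and non-direct bases never route to Responses; an explicit
# reasoning_effort other than "none" always opts in; otherwise only models
# whose names contain gpt-5 / o1 / o3 / o4 do.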
@staticmethod
def _should_fallback_from_responses_error(e: Exception) -> bool:
"""Fallback only for likely Responses API compatibility errors."""
response = getattr(e, "response", None)
status_code = getattr(e, "status_code", None)
if status_code is None and response is not None:
status_code = getattr(response, "status_code", None)
if status_code not in {400, 404, 422}:
return False
body = (
getattr(e, "body", None)
or getattr(e, "doc", None)
or getattr(response, "text", None)
)
body_text = str(body).lower() if body is not None else ""
compatibility_markers = (
"responses",
"response api",
"max_output_tokens",
"instructions",
"previous_response",
"unsupported",
"not supported",
"unknown parameter",
"unrecognized request argument",
)
return any(marker in body_text for marker in compatibility_markers)
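# Example (illustrative): a 404 whose body reads "Responses endpoint not
# supported" passes both the status check and the "not supported" marker, so
# it triggers fallback; a 429 rate limit fails the status check, so the
# error propagates to the caller instead.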
def _build_responses_body(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None,
model: str | None,
max_tokens: int,
temperature: float,
reasoning_effort: str | None,
tool_choice: str | dict[str, Any] | None,
) -> dict[str, Any]:
"""Build a Responses API body for direct OpenAI requests."""
model_name = model or self.default_model
sanitized_messages = self._sanitize_messages(self._sanitize_empty_content(messages))
instructions, input_items = convert_messages(sanitized_messages)
body: dict[str, Any] = {
"model": model_name,
"instructions": instructions or None,
"input": input_items,
"max_output_tokens": max(1, max_tokens),
"store": False,
"stream": False,
}
if self._supports_temperature(model_name, reasoning_effort):
body["temperature"] = temperature
if reasoning_effort and reasoning_effort.lower() != "none":
body["reasoning"] = {"effort": reasoning_effort}
body["include"] = ["reasoning.encrypted_content"]
if tools:
body["tools"] = convert_tools(tools)
body["tool_choice"] = tool_choice or "auto"
return body
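# Illustrative body shape for model="gpt-5", reasoning_effort="medium",
# max_tokens=4096 (assumed inputs; temperature is omitted here because
# reasoning models typically fail the _supports_temperature check):
# {"model": "gpt-5", "instructions": None, "input": [...],
#  "max_output_tokens": 4096, "store": False, "stream": False,
#  "reasoning": {"effort": "medium"},
#  "include": ["reasoning.encrypted_content"]}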
# ------------------------------------------------------------------
# Response parsing
# ------------------------------------------------------------------
@@ -731,11 +828,22 @@ class OpenAICompatProvider(LLMProvider):
reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
try:
if self._should_use_responses_api(model, reasoning_effort):
try:
body = self._build_responses_body(
messages, tools, model, max_tokens, temperature,
reasoning_effort, tool_choice,
)
return parse_response_output(await self._client.responses.create(**body))
except Exception as responses_error:
if not self._should_fallback_from_responses_error(responses_error):
raise
kwargs = self._build_kwargs(
messages, tools, model, max_tokens, temperature,
reasoning_effort, tool_choice,
)
return self._parse(await self._client.chat.completions.create(**kwargs))
except Exception as e:
return self._handle_error(e)
@@ -751,14 +859,49 @@ class OpenAICompatProvider(LLMProvider):
tool_choice: str | dict[str, Any] | None = None,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
try:
if self._should_use_responses_api(model, reasoning_effort):
try:
body = self._build_responses_body(
messages, tools, model, max_tokens, temperature,
reasoning_effort, tool_choice,
)
body["stream"] = True
stream = await self._client.responses.create(**body)
async def _timed_stream():
stream_iter = stream.__aiter__()
while True:
try:
yield await asyncio.wait_for(
stream_iter.__anext__(),
timeout=idle_timeout_s,
)
except StopAsyncIteration:
break
content, tool_calls, finish_reason, usage, reasoning_content = await consume_sdk_stream(
_timed_stream(),
on_content_delta,
)
return LLMResponse(
content=content or None,
tool_calls=tool_calls,
finish_reason=finish_reason,
usage=usage,
reasoning_content=reasoning_content,
)
except Exception as responses_error:
if not self._should_fallback_from_responses_error(responses_error):
raise
kwargs = self._build_kwargs(
messages, tools, model, max_tokens, temperature,
reasoning_effort, tool_choice,
)
kwargs["stream"] = True
kwargs["stream_options"] = {"include_usage": True}
stream = await self._client.chat.completions.create(**kwargs)
chunks: list[Any] = []
stream_iter = stream.__aiter__()
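
The _timed_stream helper above applies a per-chunk idle timeout so a stalled Responses stream cannot hang the caller. A self-contained sketch of the same pattern (slow_stream and the other names below are illustrative, not from the diff):

import asyncio

async def slow_stream():
    for chunk in ("a", "b"):
        await asyncio.sleep(0.01)
        yield chunk

async def timed(stream, idle_timeout_s: float):
    # asyncio.TimeoutError propagates if any single chunk takes too long.
    it = stream.__aiter__()
    while True:
        try:
            yield await asyncio.wait_for(it.__anext__(), timeout=idle_timeout_s)
        except StopAsyncIteration:
            break

async def main() -> None:
    async for chunk in timed(slow_stream(), idle_timeout_s=1.0):
        print(chunk)  # prints "a" then "b"; a stalled stream raises instead

asyncio.run(main())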

Tests for OpenAICompatProvider (file path not shown in this view)

@@ -10,7 +10,7 @@ from __future__ import annotations
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -54,6 +54,57 @@ def _fake_tool_call_response() -> SimpleNamespace:
return SimpleNamespace(choices=[choice], usage=usage)
def _fake_responses_response(content: str = "ok") -> MagicMock:
"""Build a minimal Responses API response object."""
resp = MagicMock()
resp.model_dump.return_value = {
"output": [{
"type": "message",
"role": "assistant",
"content": [{"type": "output_text", "text": content}],
}],
"status": "completed",
"usage": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
}
return resp
def _fake_responses_stream(text: str = "ok"):
async def _stream():
yield SimpleNamespace(type="response.output_text.delta", delta=text)
yield SimpleNamespace(
type="response.completed",
response=SimpleNamespace(
status="completed",
usage=SimpleNamespace(input_tokens=10, output_tokens=5, total_tokens=15),
output=[],
),
)
return _stream()
def _fake_chat_stream(text: str = "ok"):
async def _stream():
yield SimpleNamespace(
choices=[SimpleNamespace(finish_reason=None, delta=SimpleNamespace(content=text, reasoning_content=None, tool_calls=None))],
usage=None,
)
yield SimpleNamespace(
choices=[SimpleNamespace(finish_reason="stop", delta=SimpleNamespace(content=None, reasoning_content=None, tool_calls=None))],
usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15),
)
return _stream()
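# Note (descriptive, not part of the diff): the two fake streams mirror the
# two wire formats the provider must consume -- typed Responses events
# ("response.output_text.delta", "response.completed") versus Chat
# Completions chunks carrying choices[].delta and a final usage payload.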
class _FakeResponsesError(Exception):
def __init__(self, status_code: int, text: str):
super().__init__(text)
self.status_code = status_code
self.response = SimpleNamespace(status_code=status_code, text=text, headers={})
class _StalledStream:
def __aiter__(self):
return self
@@ -226,6 +277,224 @@ def test_openai_model_passthrough() -> None:
assert provider.get_default_model() == "gpt-4o"
@pytest.mark.asyncio
async def test_direct_openai_gpt5_uses_responses_api() -> None:
mock_chat = AsyncMock(return_value=_fake_chat_response())
mock_responses = AsyncMock(return_value=_fake_responses_response("from responses"))
spec = find_by_name("openai")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
client_instance = MockClient.return_value
client_instance.chat.completions.create = mock_chat
client_instance.responses.create = mock_responses
provider = OpenAICompatProvider(
api_key="sk-test-key",
default_model="gpt-5-chat",
spec=spec,
)
result = await provider.chat(
messages=[{"role": "user", "content": "hello"}],
model="gpt-5-chat",
)
assert result.content == "from responses"
mock_responses.assert_awaited_once()
mock_chat.assert_not_awaited()
call_kwargs = mock_responses.call_args.kwargs
assert call_kwargs["model"] == "gpt-5-chat"
assert call_kwargs["max_output_tokens"] == 4096
assert "input" in call_kwargs
assert "messages" not in call_kwargs
@pytest.mark.asyncio
async def test_direct_openai_reasoning_prefers_responses_api() -> None:
mock_chat = AsyncMock(return_value=_fake_chat_response())
mock_responses = AsyncMock(return_value=_fake_responses_response("reasoned"))
spec = find_by_name("openai")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
client_instance = MockClient.return_value
client_instance.chat.completions.create = mock_chat
client_instance.responses.create = mock_responses
provider = OpenAICompatProvider(
api_key="sk-test-key",
default_model="gpt-4o",
spec=spec,
)
await provider.chat(
messages=[{"role": "user", "content": "hello"}],
model="gpt-4o",
reasoning_effort="medium",
)
mock_responses.assert_awaited_once()
mock_chat.assert_not_awaited()
call_kwargs = mock_responses.call_args.kwargs
assert call_kwargs["reasoning"] == {"effort": "medium"}
assert call_kwargs["include"] == ["reasoning.encrypted_content"]
@pytest.mark.asyncio
async def test_direct_openai_gpt4o_stays_on_chat_completions() -> None:
mock_chat = AsyncMock(return_value=_fake_chat_response())
mock_responses = AsyncMock(return_value=_fake_responses_response())
spec = find_by_name("openai")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
client_instance = MockClient.return_value
client_instance.chat.completions.create = mock_chat
client_instance.responses.create = mock_responses
provider = OpenAICompatProvider(
api_key="sk-test-key",
default_model="gpt-4o",
spec=spec,
)
await provider.chat(
messages=[{"role": "user", "content": "hello"}],
model="gpt-4o",
)
mock_chat.assert_awaited_once()
mock_responses.assert_not_awaited()
@pytest.mark.asyncio
async def test_openrouter_gpt5_stays_on_chat_completions() -> None:
mock_chat = AsyncMock(return_value=_fake_chat_response())
mock_responses = AsyncMock(return_value=_fake_responses_response())
spec = find_by_name("openrouter")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
client_instance = MockClient.return_value
client_instance.chat.completions.create = mock_chat
client_instance.responses.create = mock_responses
provider = OpenAICompatProvider(
api_key="sk-or-test-key",
api_base="https://openrouter.ai/api/v1",
default_model="openai/gpt-5",
spec=spec,
)
await provider.chat(
messages=[{"role": "user", "content": "hello"}],
model="openai/gpt-5",
)
mock_chat.assert_awaited_once()
mock_responses.assert_not_awaited()
@pytest.mark.asyncio
async def test_direct_openai_streaming_gpt5_uses_responses_api() -> None:
mock_chat = AsyncMock(return_value=_StalledStream())
mock_responses = AsyncMock(return_value=_fake_responses_stream("hi"))
spec = find_by_name("openai")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
client_instance = MockClient.return_value
client_instance.chat.completions.create = mock_chat
client_instance.responses.create = mock_responses
provider = OpenAICompatProvider(
api_key="sk-test-key",
default_model="gpt-5-chat",
spec=spec,
)
result = await provider.chat_stream(
messages=[{"role": "user", "content": "hello"}],
model="gpt-5-chat",
)
assert result.content == "hi"
assert result.finish_reason == "stop"
mock_responses.assert_awaited_once()
mock_chat.assert_not_awaited()
@pytest.mark.asyncio
async def test_direct_openai_responses_404_falls_back_to_chat_completions() -> None:
mock_chat = AsyncMock(return_value=_fake_chat_response("from chat"))
mock_responses = AsyncMock(side_effect=_FakeResponsesError(404, "Responses endpoint not supported"))
spec = find_by_name("openai")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
client_instance = MockClient.return_value
client_instance.chat.completions.create = mock_chat
client_instance.responses.create = mock_responses
provider = OpenAICompatProvider(
api_key="sk-test-key",
default_model="gpt-5-chat",
spec=spec,
)
result = await provider.chat(
messages=[{"role": "user", "content": "hello"}],
model="gpt-5-chat",
)
assert result.content == "from chat"
mock_responses.assert_awaited_once()
mock_chat.assert_awaited_once()
@pytest.mark.asyncio
async def test_direct_openai_stream_responses_unsupported_param_falls_back() -> None:
mock_chat = AsyncMock(return_value=_fake_chat_stream("fallback stream"))
mock_responses = AsyncMock(
side_effect=_FakeResponsesError(400, "Unknown parameter: max_output_tokens for Responses API")
)
spec = find_by_name("openai")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
client_instance = MockClient.return_value
client_instance.chat.completions.create = mock_chat
client_instance.responses.create = mock_responses
provider = OpenAICompatProvider(
api_key="sk-test-key",
default_model="gpt-5-chat",
spec=spec,
)
result = await provider.chat_stream(
messages=[{"role": "user", "content": "hello"}],
model="gpt-5-chat",
)
assert result.content == "fallback stream"
mock_responses.assert_awaited_once()
mock_chat.assert_awaited_once()
@pytest.mark.asyncio
async def test_direct_openai_responses_rate_limit_does_not_fallback() -> None:
mock_chat = AsyncMock(return_value=_fake_chat_response("from chat"))
mock_responses = AsyncMock(side_effect=_FakeResponsesError(429, "rate limit"))
spec = find_by_name("openai")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
client_instance = MockClient.return_value
client_instance.chat.completions.create = mock_chat
client_instance.responses.create = mock_responses
provider = OpenAICompatProvider(
api_key="sk-test-key",
default_model="gpt-5-chat",
spec=spec,
)
result = await provider.chat(
messages=[{"role": "user", "content": "hello"}],
model="gpt-5-chat",
)
assert result.finish_reason == "error"
mock_responses.assert_awaited_once()
mock_chat.assert_not_awaited()
def test_openai_compat_supports_temperature_matches_reasoning_model_rules() -> None:
assert OpenAICompatProvider._supports_temperature("gpt-4o") is True
assert OpenAICompatProvider._supports_temperature("gpt-5-chat") is False