mirror of https://github.com/HKUDS/nanobot.git

feat(openai): auto-route direct reasoning requests with responses fallback

parent c092896922
commit d084d10dc2
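In short: requests that target reasoning-capable models on a direct OpenAI endpoint are routed to the Responses API, and compatibility failures (HTTP 400/404/422 with a Responses-related message) fall back to Chat Completions. A self-contained sketch of the routing rule, restated from the diff below (the pick_endpoint helper is illustrative, not part of the commit):

# Illustrative restatement of the routing policy; mirrors
# _should_use_responses_api() and _is_direct_openai_base() below.
def pick_endpoint(spec_name: str | None, api_base: str | None,
                  model: str, reasoning_effort: str | None) -> str:
    direct = not api_base or (
        "api.openai.com" in api_base.strip().lower()
        and "openrouter" not in api_base.lower()
    )
    if spec_name not in (None, "openai") or not direct:
        return "chat.completions"
    if reasoning_effort and reasoning_effort.lower() != "none":
        return "responses"  # explicit reasoning request
    if any(t in model.lower() for t in ("gpt-5", "o1", "o3", "o4")):
        return "responses"  # reasoning-capable model family
    return "chat.completions"

assert pick_endpoint("openai", None, "gpt-5-chat", None) == "responses"
assert pick_endpoint("openai", None, "gpt-4o", "medium") == "responses"
assert pick_endpoint("openai", None, "gpt-4o", None) == "chat.completions"
assert pick_endpoint("openrouter", "https://openrouter.ai/api/v1",
                     "openai/gpt-5", None) == "chat.completions"

The four assertions mirror the four routing tests added at the bottom of this diff.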
@@ -26,6 +26,12 @@ else:
from openai import AsyncOpenAI

from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.openai_responses import (
    consume_sdk_stream,
    convert_messages,
    convert_tools,
    parse_response_output,
)

if TYPE_CHECKING:
    from nanobot.providers.registry import ProviderSpec
@@ -113,6 +119,14 @@ def _uses_openrouter_attribution(spec: "ProviderSpec | None", api_base: str | No
    return bool(api_base and "openrouter" in api_base.lower())


def _is_direct_openai_base(api_base: str | None) -> bool:
    """Return True for direct OpenAI endpoints, not generic OpenAI-compatible gateways."""
    if not api_base:
        return True
    normalized = api_base.strip().lower().rstrip("/")
    return "api.openai.com" in normalized and "openrouter" not in normalized
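To make the predicate concrete, a few illustrative checks (not part of the commit):

# Expected behavior of _is_direct_openai_base(), restated as assertions.
assert _is_direct_openai_base(None) is True                             # unset base defaults to OpenAI
assert _is_direct_openai_base("https://api.openai.com/v1/") is True     # trailing slash/case normalized
assert _is_direct_openai_base("https://openrouter.ai/api/v1") is False
assert _is_direct_openai_base("https://gateway.example.com/v1") is False  # compatible gateway, not direct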


class OpenAICompatProvider(LLMProvider):
    """Unified provider for all OpenAI-compatible APIs.
@@ -137,6 +151,7 @@ class OpenAICompatProvider(LLMProvider):
        self._setup_env(api_key, api_base)

        effective_base = api_base or (spec.default_api_base if spec else None) or None
        self._effective_base = effective_base
        default_headers = {"x-session-affinity": uuid.uuid4().hex}
        if _uses_openrouter_attribution(spec, effective_base):
            default_headers.update(_DEFAULT_OPENROUTER_HEADERS)
@@ -321,6 +336,88 @@ class OpenAICompatProvider(LLMProvider):

        return kwargs

    def _should_use_responses_api(
        self,
        model: str | None,
        reasoning_effort: str | None,
    ) -> bool:
        """Use Responses API only for direct OpenAI requests that benefit from it."""
        if self._spec and self._spec.name != "openai":
            return False
        if not _is_direct_openai_base(self._effective_base):
            return False

        model_name = (model or self.default_model).lower()
        if reasoning_effort and reasoning_effort.lower() != "none":
            return True
        return any(token in model_name for token in ("gpt-5", "o1", "o3", "o4"))

    @staticmethod
    def _should_fallback_from_responses_error(e: Exception) -> bool:
        """Fallback only for likely Responses API compatibility errors."""
        response = getattr(e, "response", None)
        status_code = getattr(e, "status_code", None)
        if status_code is None and response is not None:
            status_code = getattr(response, "status_code", None)
        if status_code not in {400, 404, 422}:
            return False

        body = (
            getattr(e, "body", None)
            or getattr(e, "doc", None)
            or getattr(response, "text", None)
        )
        body_text = str(body).lower() if body is not None else ""
        compatibility_markers = (
            "responses",
            "response api",
            "max_output_tokens",
            "instructions",
            "previous_response",
            "unsupported",
            "not supported",
            "unknown parameter",
            "unrecognized request argument",
        )
        return any(marker in body_text for marker in compatibility_markers)
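To see the predicate in action, a stand-in error object (the same shape as the _FakeResponsesError test helper further down) can exercise both branches; this snippet is illustrative, not part of the commit:

from types import SimpleNamespace

class FakeError(Exception):
    """Stand-in for an SDK error carrying a status code and a response body."""
    def __init__(self, status_code: int, text: str):
        super().__init__(text)
        self.status_code = status_code
        self.response = SimpleNamespace(status_code=status_code, text=text, headers={})

# 404 mentioning the Responses endpoint -> eligible for Chat Completions fallback
assert OpenAICompatProvider._should_fallback_from_responses_error(
    FakeError(404, "Responses endpoint not supported")
)
# 429 rate limit -> not a compatibility problem, so no silent fallback
assert not OpenAICompatProvider._should_fallback_from_responses_error(
    FakeError(429, "rate limit")
)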

    def _build_responses_body(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
        model: str | None,
        max_tokens: int,
        temperature: float,
        reasoning_effort: str | None,
        tool_choice: str | dict[str, Any] | None,
    ) -> dict[str, Any]:
        """Build a Responses API body for direct OpenAI requests."""
        model_name = model or self.default_model
        sanitized_messages = self._sanitize_messages(self._sanitize_empty_content(messages))
        instructions, input_items = convert_messages(sanitized_messages)

        body: dict[str, Any] = {
            "model": model_name,
            "instructions": instructions or None,
            "input": input_items,
            "max_output_tokens": max(1, max_tokens),
            "store": False,
            "stream": False,
        }

        if self._supports_temperature(model_name, reasoning_effort):
            body["temperature"] = temperature

        if reasoning_effort and reasoning_effort.lower() != "none":
            body["reasoning"] = {"effort": reasoning_effort}
            body["include"] = ["reasoning.encrypted_content"]

        if tools:
            body["tools"] = convert_tools(tools)
            body["tool_choice"] = tool_choice or "auto"

        return body
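For a concrete sense of the result, a call like chat(messages=..., model="gpt-4o", reasoning_effort="medium") with the provider's default max_tokens (4096, per the tests below) should yield a body roughly like this; the exact input item shape is whatever convert_messages() emits:

# Hypothetical body produced by _build_responses_body() for a medium-effort request.
{
    "model": "gpt-4o",
    "instructions": None,                 # no system message in this example
    "input": input_items,                 # produced by convert_messages(...)
    "max_output_tokens": 4096,            # max(1, max_tokens)
    "store": False,
    "stream": False,
    "reasoning": {"effort": "medium"},    # only when reasoning_effort is set
    "include": ["reasoning.encrypted_content"],
    # "temperature" is added only when _supports_temperature() allows it;
    # "tools"/"tool_choice" only when tools are passed.
}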

    # ------------------------------------------------------------------
    # Response parsing
    # ------------------------------------------------------------------
@@ -731,11 +828,22 @@ class OpenAICompatProvider(LLMProvider):
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, Any] | None = None,
    ) -> LLMResponse:
        kwargs = self._build_kwargs(
            messages, tools, model, max_tokens, temperature,
            reasoning_effort, tool_choice,
        )
        try:
            if self._should_use_responses_api(model, reasoning_effort):
                try:
                    body = self._build_responses_body(
                        messages, tools, model, max_tokens, temperature,
                        reasoning_effort, tool_choice,
                    )
                    return parse_response_output(await self._client.responses.create(**body))
                except Exception as responses_error:
                    if not self._should_fallback_from_responses_error(responses_error):
                        raise

            kwargs = self._build_kwargs(
                messages, tools, model, max_tokens, temperature,
                reasoning_effort, tool_choice,
            )
            return self._parse(await self._client.chat.completions.create(**kwargs))
        except Exception as e:
            return self._handle_error(e)
@@ -751,14 +859,49 @@ class OpenAICompatProvider(LLMProvider):
        tool_choice: str | dict[str, Any] | None = None,
        on_content_delta: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        kwargs = self._build_kwargs(
            messages, tools, model, max_tokens, temperature,
            reasoning_effort, tool_choice,
        )
        kwargs["stream"] = True
        kwargs["stream_options"] = {"include_usage": True}
        idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
        try:
            if self._should_use_responses_api(model, reasoning_effort):
                try:
                    body = self._build_responses_body(
                        messages, tools, model, max_tokens, temperature,
                        reasoning_effort, tool_choice,
                    )
                    body["stream"] = True
                    stream = await self._client.responses.create(**body)

                    async def _timed_stream():
                        stream_iter = stream.__aiter__()
                        while True:
                            try:
                                yield await asyncio.wait_for(
                                    stream_iter.__anext__(),
                                    timeout=idle_timeout_s,
                                )
                            except StopAsyncIteration:
                                break

                    content, tool_calls, finish_reason, usage, reasoning_content = await consume_sdk_stream(
                        _timed_stream(),
                        on_content_delta,
                    )
                    return LLMResponse(
                        content=content or None,
                        tool_calls=tool_calls,
                        finish_reason=finish_reason,
                        usage=usage,
                        reasoning_content=reasoning_content,
                    )
                except Exception as responses_error:
                    if not self._should_fallback_from_responses_error(responses_error):
                        raise

            kwargs = self._build_kwargs(
                messages, tools, model, max_tokens, temperature,
                reasoning_effort, tool_choice,
            )
            kwargs["stream"] = True
            kwargs["stream_options"] = {"include_usage": True}
            stream = await self._client.chat.completions.create(**kwargs)
            chunks: list[Any] = []
            stream_iter = stream.__aiter__()
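The _timed_stream wrapper above protects against a stalled upstream: every chunk must arrive within NANOBOT_STREAM_IDLE_TIMEOUT_S seconds (default 90), otherwise asyncio.wait_for raises TimeoutError and the stream aborts. The pattern is generic; a minimal standalone sketch (the helper name is ours, not the commit's):

import asyncio
from typing import AsyncIterator, TypeVar

T = TypeVar("T")

async def with_idle_timeout(source: AsyncIterator[T], idle_s: float) -> AsyncIterator[T]:
    """Re-yield items from `source`, raising TimeoutError if any single item stalls."""
    it = source.__aiter__()
    while True:
        try:
            yield await asyncio.wait_for(it.__anext__(), timeout=idle_s)
        except StopAsyncIteration:
            break  # upstream finished normally
        # asyncio.TimeoutError propagates to the consumer and aborts the stream

Usage would look like `async for event in with_idle_timeout(stream, 90.0): ...`, which is how consume_sdk_stream consumes _timed_stream() above.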
@@ -10,7 +10,7 @@ from __future__ import annotations

import asyncio
from types import SimpleNamespace
-from unittest.mock import AsyncMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch

import pytest
@@ -54,6 +54,57 @@ def _fake_tool_call_response() -> SimpleNamespace:
    return SimpleNamespace(choices=[choice], usage=usage)


def _fake_responses_response(content: str = "ok") -> MagicMock:
    """Build a minimal Responses API response object."""
    resp = MagicMock()
    resp.model_dump.return_value = {
        "output": [{
            "type": "message",
            "role": "assistant",
            "content": [{"type": "output_text", "text": content}],
        }],
        "status": "completed",
        "usage": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
    }
    return resp


def _fake_responses_stream(text: str = "ok"):
    async def _stream():
        yield SimpleNamespace(type="response.output_text.delta", delta=text)
        yield SimpleNamespace(
            type="response.completed",
            response=SimpleNamespace(
                status="completed",
                usage=SimpleNamespace(input_tokens=10, output_tokens=5, total_tokens=15),
                output=[],
            ),
        )

    return _stream()


def _fake_chat_stream(text: str = "ok"):
    async def _stream():
        yield SimpleNamespace(
            choices=[SimpleNamespace(finish_reason=None, delta=SimpleNamespace(content=text, reasoning_content=None, tool_calls=None))],
            usage=None,
        )
        yield SimpleNamespace(
            choices=[SimpleNamespace(finish_reason="stop", delta=SimpleNamespace(content=None, reasoning_content=None, tool_calls=None))],
            usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15),
        )

    return _stream()


class _FakeResponsesError(Exception):
    def __init__(self, status_code: int, text: str):
        super().__init__(text)
        self.status_code = status_code
        self.response = SimpleNamespace(status_code=status_code, text=text, headers={})


class _StalledStream:
    def __aiter__(self):
        return self
@@ -226,6 +277,224 @@ def test_openai_model_passthrough() -> None:
    assert provider.get_default_model() == "gpt-4o"


@pytest.mark.asyncio
async def test_direct_openai_gpt5_uses_responses_api() -> None:
    mock_chat = AsyncMock(return_value=_fake_chat_response())
    mock_responses = AsyncMock(return_value=_fake_responses_response("from responses"))
    spec = find_by_name("openai")

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
        client_instance = MockClient.return_value
        client_instance.chat.completions.create = mock_chat
        client_instance.responses.create = mock_responses

        provider = OpenAICompatProvider(
            api_key="sk-test-key",
            default_model="gpt-5-chat",
            spec=spec,
        )
        result = await provider.chat(
            messages=[{"role": "user", "content": "hello"}],
            model="gpt-5-chat",
        )

    assert result.content == "from responses"
    mock_responses.assert_awaited_once()
    mock_chat.assert_not_awaited()
    call_kwargs = mock_responses.call_args.kwargs
    assert call_kwargs["model"] == "gpt-5-chat"
    assert call_kwargs["max_output_tokens"] == 4096
    assert "input" in call_kwargs
    assert "messages" not in call_kwargs


@pytest.mark.asyncio
async def test_direct_openai_reasoning_prefers_responses_api() -> None:
    mock_chat = AsyncMock(return_value=_fake_chat_response())
    mock_responses = AsyncMock(return_value=_fake_responses_response("reasoned"))
    spec = find_by_name("openai")

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
        client_instance = MockClient.return_value
        client_instance.chat.completions.create = mock_chat
        client_instance.responses.create = mock_responses

        provider = OpenAICompatProvider(
            api_key="sk-test-key",
            default_model="gpt-4o",
            spec=spec,
        )
        await provider.chat(
            messages=[{"role": "user", "content": "hello"}],
            model="gpt-4o",
            reasoning_effort="medium",
        )

    mock_responses.assert_awaited_once()
    mock_chat.assert_not_awaited()
    call_kwargs = mock_responses.call_args.kwargs
    assert call_kwargs["reasoning"] == {"effort": "medium"}
    assert call_kwargs["include"] == ["reasoning.encrypted_content"]


@pytest.mark.asyncio
async def test_direct_openai_gpt4o_stays_on_chat_completions() -> None:
    mock_chat = AsyncMock(return_value=_fake_chat_response())
    mock_responses = AsyncMock(return_value=_fake_responses_response())
    spec = find_by_name("openai")

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
        client_instance = MockClient.return_value
        client_instance.chat.completions.create = mock_chat
        client_instance.responses.create = mock_responses

        provider = OpenAICompatProvider(
            api_key="sk-test-key",
            default_model="gpt-4o",
            spec=spec,
        )
        await provider.chat(
            messages=[{"role": "user", "content": "hello"}],
            model="gpt-4o",
        )

    mock_chat.assert_awaited_once()
    mock_responses.assert_not_awaited()


@pytest.mark.asyncio
async def test_openrouter_gpt5_stays_on_chat_completions() -> None:
    mock_chat = AsyncMock(return_value=_fake_chat_response())
    mock_responses = AsyncMock(return_value=_fake_responses_response())
    spec = find_by_name("openrouter")

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
        client_instance = MockClient.return_value
        client_instance.chat.completions.create = mock_chat
        client_instance.responses.create = mock_responses

        provider = OpenAICompatProvider(
            api_key="sk-or-test-key",
            api_base="https://openrouter.ai/api/v1",
            default_model="openai/gpt-5",
            spec=spec,
        )
        await provider.chat(
            messages=[{"role": "user", "content": "hello"}],
            model="openai/gpt-5",
        )

    mock_chat.assert_awaited_once()
    mock_responses.assert_not_awaited()


@pytest.mark.asyncio
async def test_direct_openai_streaming_gpt5_uses_responses_api() -> None:
    mock_chat = AsyncMock(return_value=_StalledStream())
    mock_responses = AsyncMock(return_value=_fake_responses_stream("hi"))
    spec = find_by_name("openai")

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
        client_instance = MockClient.return_value
        client_instance.chat.completions.create = mock_chat
        client_instance.responses.create = mock_responses

        provider = OpenAICompatProvider(
            api_key="sk-test-key",
            default_model="gpt-5-chat",
            spec=spec,
        )
        result = await provider.chat_stream(
            messages=[{"role": "user", "content": "hello"}],
            model="gpt-5-chat",
        )

    assert result.content == "hi"
    assert result.finish_reason == "stop"
    mock_responses.assert_awaited_once()
    mock_chat.assert_not_awaited()


@pytest.mark.asyncio
async def test_direct_openai_responses_404_falls_back_to_chat_completions() -> None:
    mock_chat = AsyncMock(return_value=_fake_chat_response("from chat"))
    mock_responses = AsyncMock(side_effect=_FakeResponsesError(404, "Responses endpoint not supported"))
    spec = find_by_name("openai")

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
        client_instance = MockClient.return_value
        client_instance.chat.completions.create = mock_chat
        client_instance.responses.create = mock_responses

        provider = OpenAICompatProvider(
            api_key="sk-test-key",
            default_model="gpt-5-chat",
            spec=spec,
        )
        result = await provider.chat(
            messages=[{"role": "user", "content": "hello"}],
            model="gpt-5-chat",
        )

    assert result.content == "from chat"
    mock_responses.assert_awaited_once()
    mock_chat.assert_awaited_once()


@pytest.mark.asyncio
async def test_direct_openai_stream_responses_unsupported_param_falls_back() -> None:
    mock_chat = AsyncMock(return_value=_fake_chat_stream("fallback stream"))
    mock_responses = AsyncMock(
        side_effect=_FakeResponsesError(400, "Unknown parameter: max_output_tokens for Responses API")
    )
    spec = find_by_name("openai")

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
        client_instance = MockClient.return_value
        client_instance.chat.completions.create = mock_chat
        client_instance.responses.create = mock_responses

        provider = OpenAICompatProvider(
            api_key="sk-test-key",
            default_model="gpt-5-chat",
            spec=spec,
        )
        result = await provider.chat_stream(
            messages=[{"role": "user", "content": "hello"}],
            model="gpt-5-chat",
        )

    assert result.content == "fallback stream"
    mock_responses.assert_awaited_once()
    mock_chat.assert_awaited_once()


@pytest.mark.asyncio
async def test_direct_openai_responses_rate_limit_does_not_fallback() -> None:
    mock_chat = AsyncMock(return_value=_fake_chat_response("from chat"))
    mock_responses = AsyncMock(side_effect=_FakeResponsesError(429, "rate limit"))
    spec = find_by_name("openai")

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
        client_instance = MockClient.return_value
        client_instance.chat.completions.create = mock_chat
        client_instance.responses.create = mock_responses

        provider = OpenAICompatProvider(
            api_key="sk-test-key",
            default_model="gpt-5-chat",
            spec=spec,
        )
        result = await provider.chat(
            messages=[{"role": "user", "content": "hello"}],
            model="gpt-5-chat",
        )

    assert result.finish_reason == "error"
    mock_responses.assert_awaited_once()
    mock_chat.assert_not_awaited()


def test_openai_compat_supports_temperature_matches_reasoning_model_rules() -> None:
    assert OpenAICompatProvider._supports_temperature("gpt-4o") is True
    assert OpenAICompatProvider._supports_temperature("gpt-5-chat") is False