mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 09:22:36 +00:00
Use SDK for stream
This commit is contained in:
parent
0417c3f03b
commit
8c0607e079
@ -11,12 +11,11 @@ import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse
|
||||
from nanobot.providers.openai_responses_common import (
|
||||
consume_sse,
|
||||
consume_sdk_stream,
|
||||
convert_messages,
|
||||
convert_tools,
|
||||
parse_response_output,
|
||||
@ -94,6 +93,7 @@ class AzureOpenAIProvider(LLMProvider):
|
||||
"model": deployment,
|
||||
"instructions": instructions or None,
|
||||
"input": input_items,
|
||||
"max_output_tokens": max(1, max_tokens),
|
||||
"store": False,
|
||||
"stream": False,
|
||||
}
|
||||
@ -159,31 +159,15 @@ class AzureOpenAIProvider(LLMProvider):
|
||||
body["stream"] = True
|
||||
|
||||
try:
|
||||
# Use raw httpx stream via the SDK's base URL so we can reuse
|
||||
# the shared Responses-API SSE parser (same as Codex provider).
|
||||
base_url = str(self._client.base_url).rstrip("/")
|
||||
url = f"{base_url}/responses"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._client.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
**(self._client._custom_headers or {}),
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=60.0, verify=True) as http:
|
||||
async with http.stream("POST", url, headers=headers, json=body) as response:
|
||||
if response.status_code != 200:
|
||||
text = await response.aread()
|
||||
return LLMResponse(
|
||||
content=f"Azure OpenAI API Error {response.status_code}: {text.decode('utf-8', 'ignore')}",
|
||||
finish_reason="error",
|
||||
)
|
||||
content, tool_calls, finish_reason = await consume_sse(
|
||||
response, on_content_delta,
|
||||
)
|
||||
return LLMResponse(
|
||||
content=content or None,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
stream = await self._client.responses.create(**body)
|
||||
content, tool_calls, finish_reason = await consume_sdk_stream(
|
||||
stream, on_content_delta,
|
||||
)
|
||||
return LLMResponse(
|
||||
content=content or None,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
except Exception as e:
|
||||
return self._handle_error(e)
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ from nanobot.providers.openai_responses_common.converters import (
|
||||
)
|
||||
from nanobot.providers.openai_responses_common.parsing import (
|
||||
FINISH_REASON_MAP,
|
||||
consume_sdk_stream,
|
||||
consume_sse,
|
||||
iter_sse,
|
||||
map_finish_reason,
|
||||
@ -21,6 +22,7 @@ __all__ = [
|
||||
"split_tool_call_id",
|
||||
"iter_sse",
|
||||
"consume_sse",
|
||||
"consume_sdk_stream",
|
||||
"map_finish_reason",
|
||||
"parse_response_output",
|
||||
"FINISH_REASON_MAP",
|
||||
|
||||
@ -171,3 +171,72 @@ def parse_response_output(response: Any) -> LLMResponse:
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
async def consume_sdk_stream(
    stream: Any,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str]:
    """Consume an SDK async stream from ``client.responses.create(stream=True)``.

    The SDK yields typed event objects with a ``.type`` attribute and
    event-specific fields. Returns ``(content, tool_calls, finish_reason)``.

    Args:
        stream: Async iterable of Responses-API stream events as produced by
            the OpenAI SDK; each event is expected to expose a ``type``
            attribute plus event-specific fields (read defensively via
            ``getattr`` throughout).
        on_content_delta: Optional async callback invoked with every
            non-empty text delta as it arrives.

    Returns:
        ``(content, tool_calls, finish_reason)`` — the accumulated output
        text, the parsed tool-call requests, and the mapped finish reason
        (defaults to ``"stop"`` when no ``response.completed`` event is seen).

    Raises:
        RuntimeError: If the stream emits an ``error`` or
            ``response.failed`` event.
    """
    content = ""
    tool_calls: list[ToolCallRequest] = []
    # Per-call accumulation keyed by the provider's call_id: function-call
    # arguments stream in as deltas and are finalized on the *.done events.
    tool_call_buffers: dict[str, dict[str, Any]] = {}
    finish_reason = "stop"

    async for event in stream:
        event_type = getattr(event, "type", None)
        if event_type == "response.output_item.added":
            # A new output item started; only function_call items need a buffer.
            item = getattr(event, "item", None)
            if item and getattr(item, "type", None) == "function_call":
                call_id = getattr(item, "call_id", None)
                if not call_id:
                    # Without a call_id later argument deltas cannot be
                    # correlated to this call, so skip it.
                    continue
                tool_call_buffers[call_id] = {
                    "id": getattr(item, "id", None) or "fc_0",
                    "name": getattr(item, "name", None),
                    "arguments": getattr(item, "arguments", None) or "",
                }
        elif event_type == "response.output_text.delta":
            delta_text = getattr(event, "delta", "") or ""
            content += delta_text
            if on_content_delta and delta_text:
                await on_content_delta(delta_text)
        elif event_type == "response.function_call_arguments.delta":
            # Incremental JSON fragment for a known call's arguments.
            call_id = getattr(event, "call_id", None)
            if call_id and call_id in tool_call_buffers:
                tool_call_buffers[call_id]["arguments"] += getattr(event, "delta", "") or ""
        elif event_type == "response.function_call_arguments.done":
            # The done event carries the complete argument string; it
            # replaces (rather than appends to) the concatenated deltas.
            call_id = getattr(event, "call_id", None)
            if call_id and call_id in tool_call_buffers:
                tool_call_buffers[call_id]["arguments"] = getattr(event, "arguments", "") or ""
        elif event_type == "response.output_item.done":
            item = getattr(event, "item", None)
            if item and getattr(item, "type", None) == "function_call":
                call_id = getattr(item, "call_id", None)
                if not call_id:
                    continue
                # Prefer the buffered values; fall back to the done item's
                # own fields when no buffer was built for this call.
                buf = tool_call_buffers.get(call_id) or {}
                args_raw = buf.get("arguments") or getattr(item, "arguments", None) or "{}"
                try:
                    args = json.loads(args_raw)
                except Exception:
                    # Preserve an unparseable argument payload instead of
                    # silently dropping the tool call.
                    args = {"raw": args_raw}
                tool_calls.append(
                    ToolCallRequest(
                        # Composite "<call_id>|<item_id>" id — presumably
                        # split back apart via split_tool_call_id (exported
                        # alongside this helper); verify against callers.
                        id=f"{call_id}|{buf.get('id') or getattr(item, 'id', None) or 'fc_0'}",
                        name=buf.get("name") or getattr(item, "name", None),
                        arguments=args,
                    )
                )
        elif event_type == "response.completed":
            resp = getattr(event, "response", None)
            status = getattr(resp, "status", None) if resp else None
            finish_reason = map_finish_reason(status)
        elif event_type in {"error", "response.failed"}:
            raise RuntimeError("Response failed")

    return content, tool_calls, finish_reason
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Test Azure OpenAI provider (Responses API via OpenAI SDK)."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
@ -93,6 +93,7 @@ def test_build_body_basic():
|
||||
assert body["model"] == "gpt-4o"
|
||||
assert body["instructions"] == "You are helpful."
|
||||
assert body["temperature"] == 0.7
|
||||
assert body["max_output_tokens"] == 4096
|
||||
assert body["store"] is False
|
||||
assert "reasoning" not in body
|
||||
# input should contain the converted user message only (system extracted)
|
||||
@ -102,6 +103,13 @@ def test_build_body_basic():
|
||||
)
|
||||
|
||||
|
||||
def test_build_body_max_tokens_minimum():
    """max_output_tokens should never be less than 1."""
    # A requested max_tokens of 0 must be clamped up to the API minimum.
    p = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
    messages = [{"role": "user", "content": "x"}]
    built = p._build_body(messages, None, None, 0, 0.7, None, None)
    assert built["max_output_tokens"] == 1
|
||||
|
||||
|
||||
def test_build_body_with_tools():
|
||||
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
|
||||
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
||||
@ -290,46 +298,29 @@ async def test_chat_stream_success():
|
||||
api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Build SSE lines for the mock httpx stream
|
||||
sse_events = [
|
||||
'event: response.output_text.delta',
|
||||
'data: {"type":"response.output_text.delta","delta":"Hello"}',
|
||||
'',
|
||||
'event: response.output_text.delta',
|
||||
'data: {"type":"response.output_text.delta","delta":" world"}',
|
||||
'',
|
||||
'event: response.completed',
|
||||
'data: {"type":"response.completed","response":{"status":"completed"}}',
|
||||
'',
|
||||
]
|
||||
# Build mock SDK stream events
|
||||
events = []
|
||||
ev1 = MagicMock(type="response.output_text.delta", delta="Hello")
|
||||
ev2 = MagicMock(type="response.output_text.delta", delta=" world")
|
||||
resp_obj = MagicMock(status="completed")
|
||||
ev3 = MagicMock(type="response.completed", response=resp_obj)
|
||||
events = [ev1, ev2, ev3]
|
||||
|
||||
async def mock_stream():
|
||||
for e in events:
|
||||
yield e
|
||||
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_stream())
|
||||
|
||||
deltas: list[str] = []
|
||||
|
||||
async def on_delta(text: str) -> None:
|
||||
deltas.append(text)
|
||||
|
||||
# Mock httpx stream
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
async def aiter_lines():
|
||||
for line in sse_events:
|
||||
yield line
|
||||
|
||||
mock_response.aiter_lines = aiter_lines
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_ctx = AsyncMock()
|
||||
mock_stream_ctx = AsyncMock()
|
||||
mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_stream_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_ctx.stream = MagicMock(return_value=mock_stream_ctx)
|
||||
mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx)
|
||||
mock_client.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
result = await provider.chat_stream(
|
||||
[{"role": "user", "content": "Hi"}], on_content_delta=on_delta,
|
||||
)
|
||||
result = await provider.chat_stream(
|
||||
[{"role": "user", "content": "Hi"}], on_content_delta=on_delta,
|
||||
)
|
||||
|
||||
assert result.content == "Hello world"
|
||||
assert result.finish_reason == "stop"
|
||||
@ -343,41 +334,34 @@ async def test_chat_stream_with_tool_calls():
|
||||
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
|
||||
sse_events = [
|
||||
'data: {"type":"response.output_item.added","item":{"type":"function_call","call_id":"call_1","id":"fc_1","name":"get_weather","arguments":""}}',
|
||||
'',
|
||||
'data: {"type":"response.function_call_arguments.delta","call_id":"call_1","delta":"{\\"loc"}',
|
||||
'',
|
||||
'data: {"type":"response.function_call_arguments.done","call_id":"call_1","arguments":"{\\"location\\":\\"SF\\"}"}',
|
||||
'',
|
||||
'data: {"type":"response.output_item.done","item":{"type":"function_call","call_id":"call_1","id":"fc_1","name":"get_weather","arguments":"{\\"location\\":\\"SF\\"}"}}',
|
||||
'',
|
||||
'data: {"type":"response.completed","response":{"status":"completed"}}',
|
||||
'',
|
||||
]
|
||||
item_added = MagicMock(type="function_call", call_id="call_1", id="fc_1", arguments="")
|
||||
item_added.name = "get_weather"
|
||||
ev_added = MagicMock(type="response.output_item.added", item=item_added)
|
||||
ev_args_delta = MagicMock(type="response.function_call_arguments.delta", call_id="call_1", delta='{"loc')
|
||||
ev_args_done = MagicMock(
|
||||
type="response.function_call_arguments.done",
|
||||
call_id="call_1", arguments='{"location":"SF"}',
|
||||
)
|
||||
item_done = MagicMock(
|
||||
type="function_call", call_id="call_1", id="fc_1",
|
||||
arguments='{"location":"SF"}',
|
||||
)
|
||||
item_done.name = "get_weather"
|
||||
ev_item_done = MagicMock(type="response.output_item.done", item=item_done)
|
||||
resp_obj = MagicMock(status="completed")
|
||||
ev_completed = MagicMock(type="response.completed", response=resp_obj)
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
async def mock_stream():
|
||||
for e in [ev_added, ev_args_delta, ev_args_done, ev_item_done, ev_completed]:
|
||||
yield e
|
||||
|
||||
async def aiter_lines():
|
||||
for line in sse_events:
|
||||
yield line
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_stream())
|
||||
|
||||
mock_response.aiter_lines = aiter_lines
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_ctx = AsyncMock()
|
||||
mock_stream_ctx = AsyncMock()
|
||||
mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_stream_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_ctx.stream = MagicMock(return_value=mock_stream_ctx)
|
||||
mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx)
|
||||
mock_client.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
result = await provider.chat_stream(
|
||||
[{"role": "user", "content": "weather?"}],
|
||||
tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
|
||||
)
|
||||
result = await provider.chat_stream(
|
||||
[{"role": "user", "content": "weather?"}],
|
||||
tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
|
||||
)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].name == "get_weather"
|
||||
@ -385,28 +369,17 @@ async def test_chat_stream_with_tool_calls():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_stream_error():
    """Streaming should return error when SDK raises.

    The scraped diff interleaved the superseded httpx-mock version of this
    test (401 response, patched ``httpx.AsyncClient``) with its SDK-based
    replacement; only the post-commit SDK version is kept here.
    """
    provider = AzureOpenAIProvider(
        api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
    )
    # Make the SDK streaming call fail outright; the provider is expected to
    # translate the exception into an error LLMResponse, not propagate it.
    provider._client.responses = MagicMock()
    provider._client.responses.create = AsyncMock(side_effect=Exception("Connection failed"))

    result = await provider.chat_stream([{"role": "user", "content": "Hi"}])

    assert "Connection failed" in result.content
    assert result.finish_reason == "error"
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user