Use the OpenAI SDK for streaming

This commit is contained in:
Kunal Karmakar 2026-03-31 02:17:30 +00:00 committed by Xubin Ren
parent 0417c3f03b
commit 8c0607e079
4 changed files with 139 additions and 111 deletions

View File

@ -11,12 +11,11 @@ import uuid
from collections.abc import Awaitable, Callable
from typing import Any
import httpx
from openai import AsyncOpenAI
from nanobot.providers.base import LLMProvider, LLMResponse
from nanobot.providers.openai_responses_common import (
consume_sse,
consume_sdk_stream,
convert_messages,
convert_tools,
parse_response_output,
@ -94,6 +93,7 @@ class AzureOpenAIProvider(LLMProvider):
"model": deployment,
"instructions": instructions or None,
"input": input_items,
"max_output_tokens": max(1, max_tokens),
"store": False,
"stream": False,
}
@ -159,31 +159,15 @@ class AzureOpenAIProvider(LLMProvider):
body["stream"] = True
try:
# Use raw httpx stream via the SDK's base URL so we can reuse
# the shared Responses-API SSE parser (same as Codex provider).
base_url = str(self._client.base_url).rstrip("/")
url = f"{base_url}/responses"
headers = {
"Authorization": f"Bearer {self._client.api_key}",
"Content-Type": "application/json",
**(self._client._custom_headers or {}),
}
async with httpx.AsyncClient(timeout=60.0, verify=True) as http:
async with http.stream("POST", url, headers=headers, json=body) as response:
if response.status_code != 200:
text = await response.aread()
return LLMResponse(
content=f"Azure OpenAI API Error {response.status_code}: {text.decode('utf-8', 'ignore')}",
finish_reason="error",
)
content, tool_calls, finish_reason = await consume_sse(
response, on_content_delta,
)
return LLMResponse(
content=content or None,
tool_calls=tool_calls,
finish_reason=finish_reason,
)
stream = await self._client.responses.create(**body)
content, tool_calls, finish_reason = await consume_sdk_stream(
stream, on_content_delta,
)
return LLMResponse(
content=content or None,
tool_calls=tool_calls,
finish_reason=finish_reason,
)
except Exception as e:
return self._handle_error(e)

View File

@ -8,6 +8,7 @@ from nanobot.providers.openai_responses_common.converters import (
)
from nanobot.providers.openai_responses_common.parsing import (
FINISH_REASON_MAP,
consume_sdk_stream,
consume_sse,
iter_sse,
map_finish_reason,
@ -21,6 +22,7 @@ __all__ = [
"split_tool_call_id",
"iter_sse",
"consume_sse",
"consume_sdk_stream",
"map_finish_reason",
"parse_response_output",
"FINISH_REASON_MAP",

View File

@ -171,3 +171,72 @@ def parse_response_output(response: Any) -> LLMResponse:
finish_reason=finish_reason,
usage=usage,
)
async def consume_sdk_stream(
    stream: Any,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str]:
    """Drain a typed event stream from ``client.responses.create(stream=True)``.

    The SDK yields event objects exposing a ``.type`` attribute plus
    event-specific fields.  Text deltas are forwarded to *on_content_delta*
    as they arrive; function-call fragments are buffered per ``call_id`` and
    finalized on the matching ``output_item.done`` event.

    Returns:
        The accumulated ``(content, tool_calls, finish_reason)`` triple.

    Raises:
        RuntimeError: if the stream emits an ``error`` or ``response.failed``
            event.
    """
    text_parts: list[str] = []
    tool_calls: list[ToolCallRequest] = []
    # Partial function-call state, keyed by call_id, filled in as events arrive.
    pending: dict[str, dict[str, Any]] = {}
    finish_reason = "stop"

    async for event in stream:
        kind = getattr(event, "type", None)

        if kind == "response.output_text.delta":
            piece = getattr(event, "delta", "") or ""
            text_parts.append(piece)
            if piece and on_content_delta is not None:
                await on_content_delta(piece)

        elif kind == "response.output_item.added":
            item = getattr(event, "item", None)
            if not item or getattr(item, "type", None) != "function_call":
                continue
            call_id = getattr(item, "call_id", None)
            if not call_id:
                continue
            pending[call_id] = {
                "id": getattr(item, "id", None) or "fc_0",
                "name": getattr(item, "name", None),
                "arguments": getattr(item, "arguments", None) or "",
            }

        elif kind == "response.function_call_arguments.delta":
            call_id = getattr(event, "call_id", None)
            buf = pending.get(call_id) if call_id else None
            if buf is not None:
                buf["arguments"] += getattr(event, "delta", "") or ""

        elif kind == "response.function_call_arguments.done":
            call_id = getattr(event, "call_id", None)
            buf = pending.get(call_id) if call_id else None
            if buf is not None:
                # The "done" event carries the complete argument string;
                # prefer it over the concatenated deltas.
                buf["arguments"] = getattr(event, "arguments", "") or ""

        elif kind == "response.output_item.done":
            item = getattr(event, "item", None)
            if not item or getattr(item, "type", None) != "function_call":
                continue
            call_id = getattr(item, "call_id", None)
            if not call_id:
                continue
            buf = pending.get(call_id) or {}
            raw = buf.get("arguments") or getattr(item, "arguments", None) or "{}"
            try:
                arguments = json.loads(raw)
            except Exception:
                # Keep malformed JSON around rather than dropping the call.
                arguments = {"raw": raw}
            item_id = buf.get("id") or getattr(item, "id", None) or "fc_0"
            tool_calls.append(
                ToolCallRequest(
                    id=f"{call_id}|{item_id}",
                    name=buf.get("name") or getattr(item, "name", None),
                    arguments=arguments,
                )
            )

        elif kind == "response.completed":
            resp = getattr(event, "response", None)
            finish_reason = map_finish_reason(
                getattr(resp, "status", None) if resp else None
            )

        elif kind in ("error", "response.failed"):
            raise RuntimeError("Response failed")

    return "".join(text_parts), tool_calls, finish_reason

View File

@ -1,6 +1,6 @@
"""Test Azure OpenAI provider (Responses API via OpenAI SDK)."""
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock
import pytest
@ -93,6 +93,7 @@ def test_build_body_basic():
assert body["model"] == "gpt-4o"
assert body["instructions"] == "You are helpful."
assert body["temperature"] == 0.7
assert body["max_output_tokens"] == 4096
assert body["store"] is False
assert "reasoning" not in body
# input should contain the converted user message only (system extracted)
@ -102,6 +103,13 @@ def test_build_body_basic():
)
def test_build_body_max_tokens_minimum():
    """A non-positive max_tokens is clamped so max_output_tokens is at least 1."""
    prov = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
    payload = prov._build_body([{"role": "user", "content": "x"}], None, None, 0, 0.7, None, None)
    assert payload["max_output_tokens"] == 1
def test_build_body_with_tools():
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
@ -290,46 +298,29 @@ async def test_chat_stream_success():
api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o",
)
# Build SSE lines for the mock httpx stream
sse_events = [
'event: response.output_text.delta',
'data: {"type":"response.output_text.delta","delta":"Hello"}',
'',
'event: response.output_text.delta',
'data: {"type":"response.output_text.delta","delta":" world"}',
'',
'event: response.completed',
'data: {"type":"response.completed","response":{"status":"completed"}}',
'',
]
# Build mock SDK stream events
events = []
ev1 = MagicMock(type="response.output_text.delta", delta="Hello")
ev2 = MagicMock(type="response.output_text.delta", delta=" world")
resp_obj = MagicMock(status="completed")
ev3 = MagicMock(type="response.completed", response=resp_obj)
events = [ev1, ev2, ev3]
async def mock_stream():
for e in events:
yield e
provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(return_value=mock_stream())
deltas: list[str] = []
async def on_delta(text: str) -> None:
deltas.append(text)
# Mock httpx stream
mock_response = AsyncMock()
mock_response.status_code = 200
async def aiter_lines():
for line in sse_events:
yield line
mock_response.aiter_lines = aiter_lines
with patch("httpx.AsyncClient") as mock_client:
mock_ctx = AsyncMock()
mock_stream_ctx = AsyncMock()
mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response)
mock_stream_ctx.__aexit__ = AsyncMock(return_value=False)
mock_ctx.stream = MagicMock(return_value=mock_stream_ctx)
mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx)
mock_client.return_value.__aexit__ = AsyncMock(return_value=False)
result = await provider.chat_stream(
[{"role": "user", "content": "Hi"}], on_content_delta=on_delta,
)
result = await provider.chat_stream(
[{"role": "user", "content": "Hi"}], on_content_delta=on_delta,
)
assert result.content == "Hello world"
assert result.finish_reason == "stop"
@ -343,41 +334,34 @@ async def test_chat_stream_with_tool_calls():
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
)
sse_events = [
'data: {"type":"response.output_item.added","item":{"type":"function_call","call_id":"call_1","id":"fc_1","name":"get_weather","arguments":""}}',
'',
'data: {"type":"response.function_call_arguments.delta","call_id":"call_1","delta":"{\\"loc"}',
'',
'data: {"type":"response.function_call_arguments.done","call_id":"call_1","arguments":"{\\"location\\":\\"SF\\"}"}',
'',
'data: {"type":"response.output_item.done","item":{"type":"function_call","call_id":"call_1","id":"fc_1","name":"get_weather","arguments":"{\\"location\\":\\"SF\\"}"}}',
'',
'data: {"type":"response.completed","response":{"status":"completed"}}',
'',
]
item_added = MagicMock(type="function_call", call_id="call_1", id="fc_1", arguments="")
item_added.name = "get_weather"
ev_added = MagicMock(type="response.output_item.added", item=item_added)
ev_args_delta = MagicMock(type="response.function_call_arguments.delta", call_id="call_1", delta='{"loc')
ev_args_done = MagicMock(
type="response.function_call_arguments.done",
call_id="call_1", arguments='{"location":"SF"}',
)
item_done = MagicMock(
type="function_call", call_id="call_1", id="fc_1",
arguments='{"location":"SF"}',
)
item_done.name = "get_weather"
ev_item_done = MagicMock(type="response.output_item.done", item=item_done)
resp_obj = MagicMock(status="completed")
ev_completed = MagicMock(type="response.completed", response=resp_obj)
mock_response = AsyncMock()
mock_response.status_code = 200
async def mock_stream():
for e in [ev_added, ev_args_delta, ev_args_done, ev_item_done, ev_completed]:
yield e
async def aiter_lines():
for line in sse_events:
yield line
provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(return_value=mock_stream())
mock_response.aiter_lines = aiter_lines
with patch("httpx.AsyncClient") as mock_client:
mock_ctx = AsyncMock()
mock_stream_ctx = AsyncMock()
mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response)
mock_stream_ctx.__aexit__ = AsyncMock(return_value=False)
mock_ctx.stream = MagicMock(return_value=mock_stream_ctx)
mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx)
mock_client.return_value.__aexit__ = AsyncMock(return_value=False)
result = await provider.chat_stream(
[{"role": "user", "content": "weather?"}],
tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
)
result = await provider.chat_stream(
[{"role": "user", "content": "weather?"}],
tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
)
assert len(result.tool_calls) == 1
assert result.tool_calls[0].name == "get_weather"
@ -385,28 +369,17 @@ async def test_chat_stream_with_tool_calls():
@pytest.mark.asyncio
async def test_chat_stream_http_error():
"""Streaming should return error on non-200 status."""
async def test_chat_stream_error():
"""Streaming should return error when SDK raises."""
provider = AzureOpenAIProvider(
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
)
provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(side_effect=Exception("Connection failed"))
mock_response = AsyncMock()
mock_response.status_code = 401
mock_response.aread = AsyncMock(return_value=b"Unauthorized")
result = await provider.chat_stream([{"role": "user", "content": "Hi"}])
with patch("httpx.AsyncClient") as mock_client:
mock_ctx = AsyncMock()
mock_stream_ctx = AsyncMock()
mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response)
mock_stream_ctx.__aexit__ = AsyncMock(return_value=False)
mock_ctx.stream = MagicMock(return_value=mock_stream_ctx)
mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx)
mock_client.return_value.__aexit__ = AsyncMock(return_value=False)
result = await provider.chat_stream([{"role": "user", "content": "Hi"}])
assert "401" in result.content
assert "Connection failed" in result.content
assert result.finish_reason == "error"