Use the OpenAI SDK for streaming

This commit is contained in:
Kunal Karmakar 2026-03-31 02:17:30 +00:00 committed by Xubin Ren
parent 0417c3f03b
commit 8c0607e079
4 changed files with 139 additions and 111 deletions

View File

@ -11,12 +11,11 @@ import uuid
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import Any from typing import Any
import httpx
from openai import AsyncOpenAI from openai import AsyncOpenAI
from nanobot.providers.base import LLMProvider, LLMResponse from nanobot.providers.base import LLMProvider, LLMResponse
from nanobot.providers.openai_responses_common import ( from nanobot.providers.openai_responses_common import (
consume_sse, consume_sdk_stream,
convert_messages, convert_messages,
convert_tools, convert_tools,
parse_response_output, parse_response_output,
@ -94,6 +93,7 @@ class AzureOpenAIProvider(LLMProvider):
"model": deployment, "model": deployment,
"instructions": instructions or None, "instructions": instructions or None,
"input": input_items, "input": input_items,
"max_output_tokens": max(1, max_tokens),
"store": False, "store": False,
"stream": False, "stream": False,
} }
@ -159,31 +159,15 @@ class AzureOpenAIProvider(LLMProvider):
body["stream"] = True body["stream"] = True
try: try:
# Use raw httpx stream via the SDK's base URL so we can reuse stream = await self._client.responses.create(**body)
# the shared Responses-API SSE parser (same as Codex provider). content, tool_calls, finish_reason = await consume_sdk_stream(
base_url = str(self._client.base_url).rstrip("/") stream, on_content_delta,
url = f"{base_url}/responses" )
headers = { return LLMResponse(
"Authorization": f"Bearer {self._client.api_key}", content=content or None,
"Content-Type": "application/json", tool_calls=tool_calls,
**(self._client._custom_headers or {}), finish_reason=finish_reason,
} )
async with httpx.AsyncClient(timeout=60.0, verify=True) as http:
async with http.stream("POST", url, headers=headers, json=body) as response:
if response.status_code != 200:
text = await response.aread()
return LLMResponse(
content=f"Azure OpenAI API Error {response.status_code}: {text.decode('utf-8', 'ignore')}",
finish_reason="error",
)
content, tool_calls, finish_reason = await consume_sse(
response, on_content_delta,
)
return LLMResponse(
content=content or None,
tool_calls=tool_calls,
finish_reason=finish_reason,
)
except Exception as e: except Exception as e:
return self._handle_error(e) return self._handle_error(e)

View File

@ -8,6 +8,7 @@ from nanobot.providers.openai_responses_common.converters import (
) )
from nanobot.providers.openai_responses_common.parsing import ( from nanobot.providers.openai_responses_common.parsing import (
FINISH_REASON_MAP, FINISH_REASON_MAP,
consume_sdk_stream,
consume_sse, consume_sse,
iter_sse, iter_sse,
map_finish_reason, map_finish_reason,
@ -21,6 +22,7 @@ __all__ = [
"split_tool_call_id", "split_tool_call_id",
"iter_sse", "iter_sse",
"consume_sse", "consume_sse",
"consume_sdk_stream",
"map_finish_reason", "map_finish_reason",
"parse_response_output", "parse_response_output",
"FINISH_REASON_MAP", "FINISH_REASON_MAP",

View File

@ -171,3 +171,72 @@ def parse_response_output(response: Any) -> LLMResponse:
finish_reason=finish_reason, finish_reason=finish_reason,
usage=usage, usage=usage,
) )
async def consume_sdk_stream(
    stream: Any,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str]:
    """Consume an SDK async stream from ``client.responses.create(stream=True)``.

    The SDK yields typed event objects with a ``.type`` attribute and
    event-specific fields. Returns ``(content, tool_calls, finish_reason)``.

    Args:
        stream: Async iterable of Responses-API stream events.
        on_content_delta: Optional async callback awaited with each text delta.

    Raises:
        RuntimeError: If the stream emits an ``error`` or ``response.failed``
            event; includes the event's detail when available.
    """
    # Accumulate text deltas in a list and join once at the end; repeated
    # ``+=`` on a str is quadratic for long streamed responses.
    content_parts: list[str] = []
    tool_calls: list[ToolCallRequest] = []
    tool_call_buffers: dict[str, dict[str, Any]] = {}
    finish_reason = "stop"

    async for event in stream:
        event_type = getattr(event, "type", None)
        if event_type == "response.output_item.added":
            item = getattr(event, "item", None)
            if item and getattr(item, "type", None) == "function_call":
                call_id = getattr(item, "call_id", None)
                if not call_id:
                    continue
                # Start a buffer for this tool call; arguments stream in later.
                tool_call_buffers[call_id] = {
                    "id": getattr(item, "id", None) or "fc_0",
                    "name": getattr(item, "name", None),
                    "arguments": getattr(item, "arguments", None) or "",
                }
        elif event_type == "response.output_text.delta":
            delta_text = getattr(event, "delta", "") or ""
            content_parts.append(delta_text)
            if on_content_delta and delta_text:
                await on_content_delta(delta_text)
        elif event_type == "response.function_call_arguments.delta":
            call_id = getattr(event, "call_id", None)
            if call_id and call_id in tool_call_buffers:
                tool_call_buffers[call_id]["arguments"] += getattr(event, "delta", "") or ""
        elif event_type == "response.function_call_arguments.done":
            call_id = getattr(event, "call_id", None)
            if call_id and call_id in tool_call_buffers:
                # The ``done`` event carries the authoritative full argument
                # string, so it replaces any accumulated deltas.
                tool_call_buffers[call_id]["arguments"] = getattr(event, "arguments", "") or ""
        elif event_type == "response.output_item.done":
            item = getattr(event, "item", None)
            if item and getattr(item, "type", None) == "function_call":
                call_id = getattr(item, "call_id", None)
                if not call_id:
                    continue
                # Fall back to the item's own fields if the ``added`` event
                # for this call was never seen.
                buf = tool_call_buffers.get(call_id) or {}
                args_raw = buf.get("arguments") or getattr(item, "arguments", None) or "{}"
                try:
                    args = json.loads(args_raw)
                except (ValueError, TypeError):
                    # Malformed/non-string JSON: keep the raw payload instead
                    # of silently dropping the tool call's arguments.
                    args = {"raw": args_raw}
                tool_calls.append(
                    ToolCallRequest(
                        id=f"{call_id}|{buf.get('id') or getattr(item, 'id', None) or 'fc_0'}",
                        name=buf.get("name") or getattr(item, "name", None),
                        arguments=args,
                    )
                )
        elif event_type == "response.completed":
            resp = getattr(event, "response", None)
            status = getattr(resp, "status", None) if resp else None
            finish_reason = map_finish_reason(status)
        elif event_type in {"error", "response.failed"}:
            # Surface whatever detail the event carries rather than losing it.
            detail = getattr(event, "message", None) or getattr(event, "error", None)
            raise RuntimeError(f"Response failed: {detail}" if detail else "Response failed")

    return "".join(content_parts), tool_calls, finish_reason

View File

@ -1,6 +1,6 @@
"""Test Azure OpenAI provider (Responses API via OpenAI SDK).""" """Test Azure OpenAI provider (Responses API via OpenAI SDK)."""
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
@ -93,6 +93,7 @@ def test_build_body_basic():
assert body["model"] == "gpt-4o" assert body["model"] == "gpt-4o"
assert body["instructions"] == "You are helpful." assert body["instructions"] == "You are helpful."
assert body["temperature"] == 0.7 assert body["temperature"] == 0.7
assert body["max_output_tokens"] == 4096
assert body["store"] is False assert body["store"] is False
assert "reasoning" not in body assert "reasoning" not in body
# input should contain the converted user message only (system extracted) # input should contain the converted user message only (system extracted)
@ -102,6 +103,13 @@ def test_build_body_basic():
) )
def test_build_body_max_tokens_minimum():
    """max_output_tokens should never be less than 1."""
    provider = AzureOpenAIProvider(
        api_key="k", api_base="https://r.com", default_model="gpt-4o",
    )
    messages = [{"role": "user", "content": "x"}]
    # A zero token budget must be clamped up to the API minimum of 1.
    request_body = provider._build_body(messages, None, None, 0, 0.7, None, None)
    assert request_body["max_output_tokens"] == 1
def test_build_body_with_tools(): def test_build_body_with_tools():
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o") provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}] tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
@ -290,46 +298,29 @@ async def test_chat_stream_success():
api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o", api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o",
) )
# Build SSE lines for the mock httpx stream # Build mock SDK stream events
sse_events = [ events = []
'event: response.output_text.delta', ev1 = MagicMock(type="response.output_text.delta", delta="Hello")
'data: {"type":"response.output_text.delta","delta":"Hello"}', ev2 = MagicMock(type="response.output_text.delta", delta=" world")
'', resp_obj = MagicMock(status="completed")
'event: response.output_text.delta', ev3 = MagicMock(type="response.completed", response=resp_obj)
'data: {"type":"response.output_text.delta","delta":" world"}', events = [ev1, ev2, ev3]
'',
'event: response.completed', async def mock_stream():
'data: {"type":"response.completed","response":{"status":"completed"}}', for e in events:
'', yield e
]
provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(return_value=mock_stream())
deltas: list[str] = [] deltas: list[str] = []
async def on_delta(text: str) -> None: async def on_delta(text: str) -> None:
deltas.append(text) deltas.append(text)
# Mock httpx stream result = await provider.chat_stream(
mock_response = AsyncMock() [{"role": "user", "content": "Hi"}], on_content_delta=on_delta,
mock_response.status_code = 200 )
async def aiter_lines():
for line in sse_events:
yield line
mock_response.aiter_lines = aiter_lines
with patch("httpx.AsyncClient") as mock_client:
mock_ctx = AsyncMock()
mock_stream_ctx = AsyncMock()
mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response)
mock_stream_ctx.__aexit__ = AsyncMock(return_value=False)
mock_ctx.stream = MagicMock(return_value=mock_stream_ctx)
mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx)
mock_client.return_value.__aexit__ = AsyncMock(return_value=False)
result = await provider.chat_stream(
[{"role": "user", "content": "Hi"}], on_content_delta=on_delta,
)
assert result.content == "Hello world" assert result.content == "Hello world"
assert result.finish_reason == "stop" assert result.finish_reason == "stop"
@ -343,41 +334,34 @@ async def test_chat_stream_with_tool_calls():
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o", api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
) )
sse_events = [ item_added = MagicMock(type="function_call", call_id="call_1", id="fc_1", arguments="")
'data: {"type":"response.output_item.added","item":{"type":"function_call","call_id":"call_1","id":"fc_1","name":"get_weather","arguments":""}}', item_added.name = "get_weather"
'', ev_added = MagicMock(type="response.output_item.added", item=item_added)
'data: {"type":"response.function_call_arguments.delta","call_id":"call_1","delta":"{\\"loc"}', ev_args_delta = MagicMock(type="response.function_call_arguments.delta", call_id="call_1", delta='{"loc')
'', ev_args_done = MagicMock(
'data: {"type":"response.function_call_arguments.done","call_id":"call_1","arguments":"{\\"location\\":\\"SF\\"}"}', type="response.function_call_arguments.done",
'', call_id="call_1", arguments='{"location":"SF"}',
'data: {"type":"response.output_item.done","item":{"type":"function_call","call_id":"call_1","id":"fc_1","name":"get_weather","arguments":"{\\"location\\":\\"SF\\"}"}}', )
'', item_done = MagicMock(
'data: {"type":"response.completed","response":{"status":"completed"}}', type="function_call", call_id="call_1", id="fc_1",
'', arguments='{"location":"SF"}',
] )
item_done.name = "get_weather"
ev_item_done = MagicMock(type="response.output_item.done", item=item_done)
resp_obj = MagicMock(status="completed")
ev_completed = MagicMock(type="response.completed", response=resp_obj)
mock_response = AsyncMock() async def mock_stream():
mock_response.status_code = 200 for e in [ev_added, ev_args_delta, ev_args_done, ev_item_done, ev_completed]:
yield e
async def aiter_lines(): provider._client.responses = MagicMock()
for line in sse_events: provider._client.responses.create = AsyncMock(return_value=mock_stream())
yield line
mock_response.aiter_lines = aiter_lines result = await provider.chat_stream(
[{"role": "user", "content": "weather?"}],
with patch("httpx.AsyncClient") as mock_client: tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
mock_ctx = AsyncMock() )
mock_stream_ctx = AsyncMock()
mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response)
mock_stream_ctx.__aexit__ = AsyncMock(return_value=False)
mock_ctx.stream = MagicMock(return_value=mock_stream_ctx)
mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx)
mock_client.return_value.__aexit__ = AsyncMock(return_value=False)
result = await provider.chat_stream(
[{"role": "user", "content": "weather?"}],
tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
)
assert len(result.tool_calls) == 1 assert len(result.tool_calls) == 1
assert result.tool_calls[0].name == "get_weather" assert result.tool_calls[0].name == "get_weather"
@ -385,28 +369,17 @@ async def test_chat_stream_with_tool_calls():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_chat_stream_http_error(): async def test_chat_stream_error():
"""Streaming should return error on non-200 status.""" """Streaming should return error when SDK raises."""
provider = AzureOpenAIProvider( provider = AzureOpenAIProvider(
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o", api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
) )
provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(side_effect=Exception("Connection failed"))
mock_response = AsyncMock() result = await provider.chat_stream([{"role": "user", "content": "Hi"}])
mock_response.status_code = 401
mock_response.aread = AsyncMock(return_value=b"Unauthorized")
with patch("httpx.AsyncClient") as mock_client: assert "Connection failed" in result.content
mock_ctx = AsyncMock()
mock_stream_ctx = AsyncMock()
mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response)
mock_stream_ctx.__aexit__ = AsyncMock(return_value=False)
mock_ctx.stream = MagicMock(return_value=mock_stream_ctx)
mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx)
mock_client.return_value.__aexit__ = AsyncMock(return_value=False)
result = await provider.chat_stream([{"role": "user", "content": "Hi"}])
assert "401" in result.content
assert result.finish_reason == "error" assert result.finish_reason == "error"