mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 09:22:36 +00:00
Use SDK for stream
This commit is contained in:
parent
0417c3f03b
commit
8c0607e079
@ -11,12 +11,11 @@ import uuid
|
|||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
from nanobot.providers.base import LLMProvider, LLMResponse
|
from nanobot.providers.base import LLMProvider, LLMResponse
|
||||||
from nanobot.providers.openai_responses_common import (
|
from nanobot.providers.openai_responses_common import (
|
||||||
consume_sse,
|
consume_sdk_stream,
|
||||||
convert_messages,
|
convert_messages,
|
||||||
convert_tools,
|
convert_tools,
|
||||||
parse_response_output,
|
parse_response_output,
|
||||||
@ -94,6 +93,7 @@ class AzureOpenAIProvider(LLMProvider):
|
|||||||
"model": deployment,
|
"model": deployment,
|
||||||
"instructions": instructions or None,
|
"instructions": instructions or None,
|
||||||
"input": input_items,
|
"input": input_items,
|
||||||
|
"max_output_tokens": max(1, max_tokens),
|
||||||
"store": False,
|
"store": False,
|
||||||
"stream": False,
|
"stream": False,
|
||||||
}
|
}
|
||||||
@ -159,31 +159,15 @@ class AzureOpenAIProvider(LLMProvider):
|
|||||||
body["stream"] = True
|
body["stream"] = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Use raw httpx stream via the SDK's base URL so we can reuse
|
stream = await self._client.responses.create(**body)
|
||||||
# the shared Responses-API SSE parser (same as Codex provider).
|
content, tool_calls, finish_reason = await consume_sdk_stream(
|
||||||
base_url = str(self._client.base_url).rstrip("/")
|
stream, on_content_delta,
|
||||||
url = f"{base_url}/responses"
|
)
|
||||||
headers = {
|
return LLMResponse(
|
||||||
"Authorization": f"Bearer {self._client.api_key}",
|
content=content or None,
|
||||||
"Content-Type": "application/json",
|
tool_calls=tool_calls,
|
||||||
**(self._client._custom_headers or {}),
|
finish_reason=finish_reason,
|
||||||
}
|
)
|
||||||
async with httpx.AsyncClient(timeout=60.0, verify=True) as http:
|
|
||||||
async with http.stream("POST", url, headers=headers, json=body) as response:
|
|
||||||
if response.status_code != 200:
|
|
||||||
text = await response.aread()
|
|
||||||
return LLMResponse(
|
|
||||||
content=f"Azure OpenAI API Error {response.status_code}: {text.decode('utf-8', 'ignore')}",
|
|
||||||
finish_reason="error",
|
|
||||||
)
|
|
||||||
content, tool_calls, finish_reason = await consume_sse(
|
|
||||||
response, on_content_delta,
|
|
||||||
)
|
|
||||||
return LLMResponse(
|
|
||||||
content=content or None,
|
|
||||||
tool_calls=tool_calls,
|
|
||||||
finish_reason=finish_reason,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return self._handle_error(e)
|
return self._handle_error(e)
|
||||||
|
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from nanobot.providers.openai_responses_common.converters import (
|
|||||||
)
|
)
|
||||||
from nanobot.providers.openai_responses_common.parsing import (
|
from nanobot.providers.openai_responses_common.parsing import (
|
||||||
FINISH_REASON_MAP,
|
FINISH_REASON_MAP,
|
||||||
|
consume_sdk_stream,
|
||||||
consume_sse,
|
consume_sse,
|
||||||
iter_sse,
|
iter_sse,
|
||||||
map_finish_reason,
|
map_finish_reason,
|
||||||
@ -21,6 +22,7 @@ __all__ = [
|
|||||||
"split_tool_call_id",
|
"split_tool_call_id",
|
||||||
"iter_sse",
|
"iter_sse",
|
||||||
"consume_sse",
|
"consume_sse",
|
||||||
|
"consume_sdk_stream",
|
||||||
"map_finish_reason",
|
"map_finish_reason",
|
||||||
"parse_response_output",
|
"parse_response_output",
|
||||||
"FINISH_REASON_MAP",
|
"FINISH_REASON_MAP",
|
||||||
|
|||||||
@ -171,3 +171,72 @@ def parse_response_output(response: Any) -> LLMResponse:
|
|||||||
finish_reason=finish_reason,
|
finish_reason=finish_reason,
|
||||||
usage=usage,
|
usage=usage,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def consume_sdk_stream(
    stream: Any,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str]:
    """Consume an SDK async stream from ``client.responses.create(stream=True)``.

    The SDK yields typed event objects with a ``.type`` attribute and
    event-specific fields. Every field is read with ``getattr`` so that an
    event lacking an attribute is tolerated rather than raising.

    Args:
        stream: Async iterable of Responses-API stream events.
        on_content_delta: Optional coroutine function awaited with each
            non-empty text delta, in arrival order.

    Returns:
        ``(content, tool_calls, finish_reason)`` where ``content`` is the
        concatenation of all text deltas, ``tool_calls`` holds one
        ``ToolCallRequest`` per completed ``function_call`` output item, and
        ``finish_reason`` is ``"stop"`` unless a terminal event supplied a
        response status, in which case it is ``map_finish_reason(status)``.

    Raises:
        RuntimeError: On an ``error`` or ``response.failed`` event. The
            message starts with ``"Response failed"`` and, when the event
            carries one, includes the reported error detail.
    """
    # Collect deltas and join once at the end instead of repeated ``str +=``.
    text_parts: list[str] = []
    tool_calls: list[ToolCallRequest] = []
    tool_call_buffers: dict[str, dict[str, Any]] = {}
    # Output-item id -> call_id, so argument events that only identify the
    # item (``item_id``) can still be routed to the right buffer.
    call_id_by_item_id: dict[str, str] = {}
    finish_reason = "stop"

    def _buffer_for(event: Any) -> dict[str, Any] | None:
        """Return the argument buffer an argument event refers to, if any."""
        call_id = getattr(event, "call_id", None)
        if call_id and call_id in tool_call_buffers:
            return tool_call_buffers[call_id]
        item_id = getattr(event, "item_id", None)
        if isinstance(item_id, str):
            mapped = call_id_by_item_id.get(item_id)
            if mapped is not None:
                return tool_call_buffers.get(mapped)
        return None

    def _failure_message(event: Any) -> str:
        """Build the RuntimeError text, appending error detail when present."""
        detail = getattr(event, "message", None)
        if not detail:
            resp = getattr(event, "response", None)
            err = getattr(resp, "error", None) if resp is not None else None
            detail = getattr(err, "message", None) or err
        if not detail:
            return "Response failed"
        # Cap the detail so a huge payload cannot flood logs or chat output.
        return f"Response failed: {str(detail)[:500]}"

    async for event in stream:
        event_type = getattr(event, "type", None)

        if event_type == "response.output_item.added":
            item = getattr(event, "item", None)
            if item and getattr(item, "type", None) == "function_call":
                call_id = getattr(item, "call_id", None)
                if not call_id:
                    continue
                item_id = getattr(item, "id", None)
                if isinstance(item_id, str) and item_id:
                    call_id_by_item_id[item_id] = call_id
                tool_call_buffers[call_id] = {
                    "id": item_id or "fc_0",
                    "name": getattr(item, "name", None),
                    "arguments": getattr(item, "arguments", None) or "",
                }

        elif event_type == "response.output_text.delta":
            delta_text = getattr(event, "delta", "") or ""
            if delta_text:
                text_parts.append(delta_text)
                if on_content_delta:
                    await on_content_delta(delta_text)

        elif event_type == "response.function_call_arguments.delta":
            buf = _buffer_for(event)
            if buf is not None:
                buf["arguments"] += getattr(event, "delta", "") or ""

        elif event_type == "response.function_call_arguments.done":
            buf = _buffer_for(event)
            if buf is not None:
                # The "done" event carries the full argument string and
                # supersedes whatever the deltas accumulated.
                buf["arguments"] = getattr(event, "arguments", "") or ""

        elif event_type == "response.output_item.done":
            item = getattr(event, "item", None)
            if item and getattr(item, "type", None) == "function_call":
                call_id = getattr(item, "call_id", None)
                if not call_id:
                    continue
                buf = tool_call_buffers.get(call_id) or {}
                args_raw = buf.get("arguments") or getattr(item, "arguments", None) or "{}"
                try:
                    args = json.loads(args_raw)
                except (TypeError, ValueError):
                    # Best effort: keep unparseable arguments for the caller
                    # instead of dropping the tool call.
                    args = {"raw": args_raw}
                tool_calls.append(
                    ToolCallRequest(
                        id=f"{call_id}|{buf.get('id') or getattr(item, 'id', None) or 'fc_0'}",
                        name=buf.get("name") or getattr(item, "name", None),
                        arguments=args,
                    )
                )

        elif event_type in {"response.completed", "response.incomplete"}:
            # Both are terminal events carrying the final response object;
            # the status is handed to map_finish_reason unchanged.
            # NOTE(review): assumes map_finish_reason maps an "incomplete"
            # status sensibly - confirm against its definition.
            resp = getattr(event, "response", None)
            status = getattr(resp, "status", None) if resp else None
            finish_reason = map_finish_reason(status)

        elif event_type in {"error", "response.failed"}:
            raise RuntimeError(_failure_message(event))

    return "".join(text_parts), tool_calls, finish_reason
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
"""Test Azure OpenAI provider (Responses API via OpenAI SDK)."""
|
"""Test Azure OpenAI provider (Responses API via OpenAI SDK)."""
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -93,6 +93,7 @@ def test_build_body_basic():
|
|||||||
assert body["model"] == "gpt-4o"
|
assert body["model"] == "gpt-4o"
|
||||||
assert body["instructions"] == "You are helpful."
|
assert body["instructions"] == "You are helpful."
|
||||||
assert body["temperature"] == 0.7
|
assert body["temperature"] == 0.7
|
||||||
|
assert body["max_output_tokens"] == 4096
|
||||||
assert body["store"] is False
|
assert body["store"] is False
|
||||||
assert "reasoning" not in body
|
assert "reasoning" not in body
|
||||||
# input should contain the converted user message only (system extracted)
|
# input should contain the converted user message only (system extracted)
|
||||||
@ -102,6 +103,13 @@ def test_build_body_basic():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_body_max_tokens_minimum():
    """A max_tokens of 0 must still yield max_output_tokens == 1 in the body."""
    messages = [{"role": "user", "content": "x"}]
    azure = AzureOpenAIProvider(
        api_key="k",
        api_base="https://r.com",
        default_model="gpt-4o",
    )

    # The fourth positional argument (0) is the max_tokens value under test.
    request_body = azure._build_body(messages, None, None, 0, 0.7, None, None)

    assert request_body["max_output_tokens"] == 1
|
||||||
|
|
||||||
|
|
||||||
def test_build_body_with_tools():
|
def test_build_body_with_tools():
|
||||||
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
|
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
|
||||||
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
||||||
@ -290,46 +298,29 @@ async def test_chat_stream_success():
|
|||||||
api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build SSE lines for the mock httpx stream
|
# Build mock SDK stream events
|
||||||
sse_events = [
|
events = []
|
||||||
'event: response.output_text.delta',
|
ev1 = MagicMock(type="response.output_text.delta", delta="Hello")
|
||||||
'data: {"type":"response.output_text.delta","delta":"Hello"}',
|
ev2 = MagicMock(type="response.output_text.delta", delta=" world")
|
||||||
'',
|
resp_obj = MagicMock(status="completed")
|
||||||
'event: response.output_text.delta',
|
ev3 = MagicMock(type="response.completed", response=resp_obj)
|
||||||
'data: {"type":"response.output_text.delta","delta":" world"}',
|
events = [ev1, ev2, ev3]
|
||||||
'',
|
|
||||||
'event: response.completed',
|
async def mock_stream():
|
||||||
'data: {"type":"response.completed","response":{"status":"completed"}}',
|
for e in events:
|
||||||
'',
|
yield e
|
||||||
]
|
|
||||||
|
provider._client.responses = MagicMock()
|
||||||
|
provider._client.responses.create = AsyncMock(return_value=mock_stream())
|
||||||
|
|
||||||
deltas: list[str] = []
|
deltas: list[str] = []
|
||||||
|
|
||||||
async def on_delta(text: str) -> None:
|
async def on_delta(text: str) -> None:
|
||||||
deltas.append(text)
|
deltas.append(text)
|
||||||
|
|
||||||
# Mock httpx stream
|
result = await provider.chat_stream(
|
||||||
mock_response = AsyncMock()
|
[{"role": "user", "content": "Hi"}], on_content_delta=on_delta,
|
||||||
mock_response.status_code = 200
|
)
|
||||||
|
|
||||||
async def aiter_lines():
|
|
||||||
for line in sse_events:
|
|
||||||
yield line
|
|
||||||
|
|
||||||
mock_response.aiter_lines = aiter_lines
|
|
||||||
|
|
||||||
with patch("httpx.AsyncClient") as mock_client:
|
|
||||||
mock_ctx = AsyncMock()
|
|
||||||
mock_stream_ctx = AsyncMock()
|
|
||||||
mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response)
|
|
||||||
mock_stream_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
mock_ctx.stream = MagicMock(return_value=mock_stream_ctx)
|
|
||||||
mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx)
|
|
||||||
mock_client.return_value.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
|
|
||||||
result = await provider.chat_stream(
|
|
||||||
[{"role": "user", "content": "Hi"}], on_content_delta=on_delta,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.content == "Hello world"
|
assert result.content == "Hello world"
|
||||||
assert result.finish_reason == "stop"
|
assert result.finish_reason == "stop"
|
||||||
@ -343,41 +334,34 @@ async def test_chat_stream_with_tool_calls():
|
|||||||
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||||
)
|
)
|
||||||
|
|
||||||
sse_events = [
|
item_added = MagicMock(type="function_call", call_id="call_1", id="fc_1", arguments="")
|
||||||
'data: {"type":"response.output_item.added","item":{"type":"function_call","call_id":"call_1","id":"fc_1","name":"get_weather","arguments":""}}',
|
item_added.name = "get_weather"
|
||||||
'',
|
ev_added = MagicMock(type="response.output_item.added", item=item_added)
|
||||||
'data: {"type":"response.function_call_arguments.delta","call_id":"call_1","delta":"{\\"loc"}',
|
ev_args_delta = MagicMock(type="response.function_call_arguments.delta", call_id="call_1", delta='{"loc')
|
||||||
'',
|
ev_args_done = MagicMock(
|
||||||
'data: {"type":"response.function_call_arguments.done","call_id":"call_1","arguments":"{\\"location\\":\\"SF\\"}"}',
|
type="response.function_call_arguments.done",
|
||||||
'',
|
call_id="call_1", arguments='{"location":"SF"}',
|
||||||
'data: {"type":"response.output_item.done","item":{"type":"function_call","call_id":"call_1","id":"fc_1","name":"get_weather","arguments":"{\\"location\\":\\"SF\\"}"}}',
|
)
|
||||||
'',
|
item_done = MagicMock(
|
||||||
'data: {"type":"response.completed","response":{"status":"completed"}}',
|
type="function_call", call_id="call_1", id="fc_1",
|
||||||
'',
|
arguments='{"location":"SF"}',
|
||||||
]
|
)
|
||||||
|
item_done.name = "get_weather"
|
||||||
|
ev_item_done = MagicMock(type="response.output_item.done", item=item_done)
|
||||||
|
resp_obj = MagicMock(status="completed")
|
||||||
|
ev_completed = MagicMock(type="response.completed", response=resp_obj)
|
||||||
|
|
||||||
mock_response = AsyncMock()
|
async def mock_stream():
|
||||||
mock_response.status_code = 200
|
for e in [ev_added, ev_args_delta, ev_args_done, ev_item_done, ev_completed]:
|
||||||
|
yield e
|
||||||
|
|
||||||
async def aiter_lines():
|
provider._client.responses = MagicMock()
|
||||||
for line in sse_events:
|
provider._client.responses.create = AsyncMock(return_value=mock_stream())
|
||||||
yield line
|
|
||||||
|
|
||||||
mock_response.aiter_lines = aiter_lines
|
result = await provider.chat_stream(
|
||||||
|
[{"role": "user", "content": "weather?"}],
|
||||||
with patch("httpx.AsyncClient") as mock_client:
|
tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
|
||||||
mock_ctx = AsyncMock()
|
)
|
||||||
mock_stream_ctx = AsyncMock()
|
|
||||||
mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response)
|
|
||||||
mock_stream_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
mock_ctx.stream = MagicMock(return_value=mock_stream_ctx)
|
|
||||||
mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx)
|
|
||||||
mock_client.return_value.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
|
|
||||||
result = await provider.chat_stream(
|
|
||||||
[{"role": "user", "content": "weather?"}],
|
|
||||||
tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(result.tool_calls) == 1
|
assert len(result.tool_calls) == 1
|
||||||
assert result.tool_calls[0].name == "get_weather"
|
assert result.tool_calls[0].name == "get_weather"
|
||||||
@ -385,28 +369,17 @@ async def test_chat_stream_with_tool_calls():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_chat_stream_http_error():
|
async def test_chat_stream_error():
|
||||||
"""Streaming should return error on non-200 status."""
|
"""Streaming should return error when SDK raises."""
|
||||||
provider = AzureOpenAIProvider(
|
provider = AzureOpenAIProvider(
|
||||||
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||||
)
|
)
|
||||||
|
provider._client.responses = MagicMock()
|
||||||
|
provider._client.responses.create = AsyncMock(side_effect=Exception("Connection failed"))
|
||||||
|
|
||||||
mock_response = AsyncMock()
|
result = await provider.chat_stream([{"role": "user", "content": "Hi"}])
|
||||||
mock_response.status_code = 401
|
|
||||||
mock_response.aread = AsyncMock(return_value=b"Unauthorized")
|
|
||||||
|
|
||||||
with patch("httpx.AsyncClient") as mock_client:
|
assert "Connection failed" in result.content
|
||||||
mock_ctx = AsyncMock()
|
|
||||||
mock_stream_ctx = AsyncMock()
|
|
||||||
mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response)
|
|
||||||
mock_stream_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
mock_ctx.stream = MagicMock(return_value=mock_stream_ctx)
|
|
||||||
mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx)
|
|
||||||
mock_client.return_value.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
|
|
||||||
result = await provider.chat_stream([{"role": "user", "content": "Hi"}])
|
|
||||||
|
|
||||||
assert "401" in result.content
|
|
||||||
assert result.finish_reason == "error"
|
assert result.finish_reason == "error"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user