diff --git a/nanobot/providers/azure_openai_provider.py b/nanobot/providers/azure_openai_provider.py index ab4d187ae..b97743ab2 100644 --- a/nanobot/providers/azure_openai_provider.py +++ b/nanobot/providers/azure_openai_provider.py @@ -11,12 +11,11 @@ import uuid from collections.abc import Awaitable, Callable from typing import Any -import httpx from openai import AsyncOpenAI from nanobot.providers.base import LLMProvider, LLMResponse from nanobot.providers.openai_responses_common import ( - consume_sse, + consume_sdk_stream, convert_messages, convert_tools, parse_response_output, @@ -94,6 +93,7 @@ class AzureOpenAIProvider(LLMProvider): "model": deployment, "instructions": instructions or None, "input": input_items, + "max_output_tokens": max(1, max_tokens), "store": False, "stream": False, } @@ -159,31 +159,15 @@ class AzureOpenAIProvider(LLMProvider): body["stream"] = True try: - # Use raw httpx stream via the SDK's base URL so we can reuse - # the shared Responses-API SSE parser (same as Codex provider). 
- base_url = str(self._client.base_url).rstrip("/") - url = f"{base_url}/responses" - headers = { - "Authorization": f"Bearer {self._client.api_key}", - "Content-Type": "application/json", - **(self._client._custom_headers or {}), - } - async with httpx.AsyncClient(timeout=60.0, verify=True) as http: - async with http.stream("POST", url, headers=headers, json=body) as response: - if response.status_code != 200: - text = await response.aread() - return LLMResponse( - content=f"Azure OpenAI API Error {response.status_code}: {text.decode('utf-8', 'ignore')}", - finish_reason="error", - ) - content, tool_calls, finish_reason = await consume_sse( - response, on_content_delta, - ) - return LLMResponse( - content=content or None, - tool_calls=tool_calls, - finish_reason=finish_reason, - ) + stream = await self._client.responses.create(**body) + content, tool_calls, finish_reason = await consume_sdk_stream( + stream, on_content_delta, + ) + return LLMResponse( + content=content or None, + tool_calls=tool_calls, + finish_reason=finish_reason, + ) except Exception as e: return self._handle_error(e) diff --git a/nanobot/providers/openai_responses_common/__init__.py b/nanobot/providers/openai_responses_common/__init__.py index cfc327bdb..80a03e43a 100644 --- a/nanobot/providers/openai_responses_common/__init__.py +++ b/nanobot/providers/openai_responses_common/__init__.py @@ -8,6 +8,7 @@ from nanobot.providers.openai_responses_common.converters import ( ) from nanobot.providers.openai_responses_common.parsing import ( FINISH_REASON_MAP, + consume_sdk_stream, consume_sse, iter_sse, map_finish_reason, @@ -21,6 +22,7 @@ __all__ = [ "split_tool_call_id", "iter_sse", "consume_sse", + "consume_sdk_stream", "map_finish_reason", "parse_response_output", "FINISH_REASON_MAP", diff --git a/nanobot/providers/openai_responses_common/parsing.py b/nanobot/providers/openai_responses_common/parsing.py index e0d5f4462..5de895534 100644 --- a/nanobot/providers/openai_responses_common/parsing.py 
async def consume_sdk_stream(
    stream: Any,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str]:
    """Consume an SDK async stream from ``client.responses.create(stream=True)``.

    Each yielded event is inspected through its ``.type`` attribute; every
    other field is read defensively with ``getattr`` so that unknown or
    partially populated events are ignored instead of crashing the stream.

    Args:
        stream: Async iterable of Responses-API stream events.
        on_content_delta: Optional coroutine function awaited with every
            non-empty text delta, in arrival order.

    Returns:
        ``(content, tool_calls, finish_reason)`` where ``content`` is the
        concatenation of all text deltas, ``tool_calls`` holds one
        ``ToolCallRequest`` per finished ``function_call`` output item, and
        ``finish_reason`` is ``"stop"`` unless a terminal event carried a
        response status, in which case it is ``map_finish_reason(status)``.

    Raises:
        RuntimeError: On an ``error`` or ``response.failed`` event. The
            message always starts with ``"Response failed"`` and includes the
            server-provided detail when one is available.
    """
    text_parts: list[str] = []
    tool_calls: list[ToolCallRequest] = []
    # Function-call state keyed by call_id, as announced by output_item.added.
    buffers: dict[str, dict[str, Any]] = {}
    # Output-item id -> call_id.  Per the Responses streaming reference the
    # ``function_call_arguments.*`` events identify their item via ``item_id``
    # rather than ``call_id``, so both keys must resolve to the same buffer.
    call_id_by_item_id: dict[str, str] = {}
    finish_reason = "stop"

    def find_buffer(event: Any) -> dict[str, Any] | None:
        """Return the buffer an argument event belongs to, if any."""
        call_id = getattr(event, "call_id", None)
        if call_id and call_id in buffers:
            return buffers[call_id]
        item_id = getattr(event, "item_id", None)
        mapped = call_id_by_item_id.get(item_id) if item_id else None
        return buffers.get(mapped) if mapped else None

    async for event in stream:
        event_type = getattr(event, "type", None)

        if event_type == "response.output_text.delta":
            piece = getattr(event, "delta", "") or ""
            if piece:
                # Collect parts and join once at the end instead of
                # repeatedly re-allocating a growing string.
                text_parts.append(piece)
                if on_content_delta:
                    await on_content_delta(piece)

        elif event_type == "response.output_item.added":
            item = getattr(event, "item", None)
            if item and getattr(item, "type", None) == "function_call":
                call_id = getattr(item, "call_id", None)
                if not call_id:
                    continue
                item_id = getattr(item, "id", None)
                buffers[call_id] = {
                    "id": item_id or "fc_0",
                    "name": getattr(item, "name", None),
                    "arguments": getattr(item, "arguments", None) or "",
                }
                if item_id:
                    call_id_by_item_id[item_id] = call_id

        elif event_type == "response.function_call_arguments.delta":
            buf = find_buffer(event)
            if buf is not None:
                buf["arguments"] += getattr(event, "delta", "") or ""

        elif event_type == "response.function_call_arguments.done":
            buf = find_buffer(event)
            if buf is not None:
                # The done event carries the full argument string and
                # supersedes whatever the deltas accumulated.
                buf["arguments"] = getattr(event, "arguments", "") or ""

        elif event_type == "response.output_item.done":
            item = getattr(event, "item", None)
            if item and getattr(item, "type", None) == "function_call":
                call_id = getattr(item, "call_id", None)
                if not call_id:
                    continue
                buf = buffers.get(call_id) or {}
                args_raw = buf.get("arguments") or getattr(item, "arguments", None) or "{}"
                try:
                    args = json.loads(args_raw)
                except Exception:
                    # Deliberate best effort: keep malformed arguments
                    # visible to the caller instead of aborting the stream.
                    args = {"raw": args_raw}
                if not isinstance(args, dict):
                    # Valid JSON that is not an object (e.g. ``null`` or a
                    # list) gets the same wrapper as malformed input.
                    # NOTE(review): assumes ToolCallRequest.arguments is
                    # meant to be a dict -- confirm against its definition.
                    args = {"raw": args_raw}
                tool_calls.append(
                    ToolCallRequest(
                        id=f"{call_id}|{buf.get('id') or getattr(item, 'id', None) or 'fc_0'}",
                        name=buf.get("name") or getattr(item, "name", None),
                        arguments=args,
                    )
                )

        elif event_type in {"response.completed", "response.incomplete"}:
            # ``response.incomplete`` is the terminal event for a truncated
            # reply (e.g. the output-token limit was hit); without handling
            # it the truncation would be reported as a normal "stop".
            # NOTE(review): relies on map_finish_reason (defined elsewhere)
            # translating the "incomplete" status -- confirm its mapping.
            resp = getattr(event, "response", None)
            status = getattr(resp, "status", None) if resp else None
            finish_reason = map_finish_reason(status)

        elif event_type in {"error", "response.failed"}:
            detail = _stream_failure_detail(event)
            raise RuntimeError(f"Response failed: {detail}" if detail else "Response failed")

    return "".join(text_parts), tool_calls, finish_reason


def _stream_failure_detail(event: Any) -> str:
    """Return the server-provided error message of a failure event, or ``""``.

    ``response.failed`` events are expected to carry the message under
    ``event.response.error.message`` and plain ``error`` events under
    ``event.message``; anything that is not a non-empty string is ignored.
    """
    resp = getattr(event, "response", None)
    err = getattr(resp, "error", None) if resp is not None else None
    for candidate in (getattr(err, "message", None), getattr(event, "message", None)):
        if isinstance(candidate, str) and candidate:
            return candidate
    return ""
assert body["temperature"] == 0.7 + assert body["max_output_tokens"] == 4096 assert body["store"] is False assert "reasoning" not in body # input should contain the converted user message only (system extracted) @@ -102,6 +103,13 @@ def test_build_body_basic(): ) +def test_build_body_max_tokens_minimum(): + """max_output_tokens should never be less than 1.""" + provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o") + body = provider._build_body([{"role": "user", "content": "x"}], None, None, 0, 0.7, None, None) + assert body["max_output_tokens"] == 1 + + def test_build_body_with_tools(): provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o") tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}] @@ -290,46 +298,29 @@ async def test_chat_stream_success(): api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o", ) - # Build SSE lines for the mock httpx stream - sse_events = [ - 'event: response.output_text.delta', - 'data: {"type":"response.output_text.delta","delta":"Hello"}', - '', - 'event: response.output_text.delta', - 'data: {"type":"response.output_text.delta","delta":" world"}', - '', - 'event: response.completed', - 'data: {"type":"response.completed","response":{"status":"completed"}}', - '', - ] + # Build mock SDK stream events + events = [] + ev1 = MagicMock(type="response.output_text.delta", delta="Hello") + ev2 = MagicMock(type="response.output_text.delta", delta=" world") + resp_obj = MagicMock(status="completed") + ev3 = MagicMock(type="response.completed", response=resp_obj) + events = [ev1, ev2, ev3] + + async def mock_stream(): + for e in events: + yield e + + provider._client.responses = MagicMock() + provider._client.responses.create = AsyncMock(return_value=mock_stream()) deltas: list[str] = [] async def on_delta(text: str) -> None: deltas.append(text) - # Mock httpx stream - mock_response = AsyncMock() - 
mock_response.status_code = 200 - - async def aiter_lines(): - for line in sse_events: - yield line - - mock_response.aiter_lines = aiter_lines - - with patch("httpx.AsyncClient") as mock_client: - mock_ctx = AsyncMock() - mock_stream_ctx = AsyncMock() - mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response) - mock_stream_ctx.__aexit__ = AsyncMock(return_value=False) - mock_ctx.stream = MagicMock(return_value=mock_stream_ctx) - mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) - mock_client.return_value.__aexit__ = AsyncMock(return_value=False) - - result = await provider.chat_stream( - [{"role": "user", "content": "Hi"}], on_content_delta=on_delta, - ) + result = await provider.chat_stream( + [{"role": "user", "content": "Hi"}], on_content_delta=on_delta, + ) assert result.content == "Hello world" assert result.finish_reason == "stop" @@ -343,41 +334,34 @@ async def test_chat_stream_with_tool_calls(): api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o", ) - sse_events = [ - 'data: {"type":"response.output_item.added","item":{"type":"function_call","call_id":"call_1","id":"fc_1","name":"get_weather","arguments":""}}', - '', - 'data: {"type":"response.function_call_arguments.delta","call_id":"call_1","delta":"{\\"loc"}', - '', - 'data: {"type":"response.function_call_arguments.done","call_id":"call_1","arguments":"{\\"location\\":\\"SF\\"}"}', - '', - 'data: {"type":"response.output_item.done","item":{"type":"function_call","call_id":"call_1","id":"fc_1","name":"get_weather","arguments":"{\\"location\\":\\"SF\\"}"}}', - '', - 'data: {"type":"response.completed","response":{"status":"completed"}}', - '', - ] + item_added = MagicMock(type="function_call", call_id="call_1", id="fc_1", arguments="") + item_added.name = "get_weather" + ev_added = MagicMock(type="response.output_item.added", item=item_added) + ev_args_delta = MagicMock(type="response.function_call_arguments.delta", call_id="call_1", delta='{"loc') 
+ ev_args_done = MagicMock( + type="response.function_call_arguments.done", + call_id="call_1", arguments='{"location":"SF"}', + ) + item_done = MagicMock( + type="function_call", call_id="call_1", id="fc_1", + arguments='{"location":"SF"}', + ) + item_done.name = "get_weather" + ev_item_done = MagicMock(type="response.output_item.done", item=item_done) + resp_obj = MagicMock(status="completed") + ev_completed = MagicMock(type="response.completed", response=resp_obj) - mock_response = AsyncMock() - mock_response.status_code = 200 + async def mock_stream(): + for e in [ev_added, ev_args_delta, ev_args_done, ev_item_done, ev_completed]: + yield e - async def aiter_lines(): - for line in sse_events: - yield line + provider._client.responses = MagicMock() + provider._client.responses.create = AsyncMock(return_value=mock_stream()) - mock_response.aiter_lines = aiter_lines - - with patch("httpx.AsyncClient") as mock_client: - mock_ctx = AsyncMock() - mock_stream_ctx = AsyncMock() - mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response) - mock_stream_ctx.__aexit__ = AsyncMock(return_value=False) - mock_ctx.stream = MagicMock(return_value=mock_stream_ctx) - mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) - mock_client.return_value.__aexit__ = AsyncMock(return_value=False) - - result = await provider.chat_stream( - [{"role": "user", "content": "weather?"}], - tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}], - ) + result = await provider.chat_stream( + [{"role": "user", "content": "weather?"}], + tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}], + ) assert len(result.tool_calls) == 1 assert result.tool_calls[0].name == "get_weather" @@ -385,28 +369,17 @@ async def test_chat_stream_with_tool_calls(): @pytest.mark.asyncio -async def test_chat_stream_http_error(): - """Streaming should return error on non-200 status.""" +async def test_chat_stream_error(): + """Streaming 
should return error when SDK raises.""" provider = AzureOpenAIProvider( api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o", ) + provider._client.responses = MagicMock() + provider._client.responses.create = AsyncMock(side_effect=Exception("Connection failed")) - mock_response = AsyncMock() - mock_response.status_code = 401 - mock_response.aread = AsyncMock(return_value=b"Unauthorized") + result = await provider.chat_stream([{"role": "user", "content": "Hi"}]) - with patch("httpx.AsyncClient") as mock_client: - mock_ctx = AsyncMock() - mock_stream_ctx = AsyncMock() - mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response) - mock_stream_ctx.__aexit__ = AsyncMock(return_value=False) - mock_ctx.stream = MagicMock(return_value=mock_stream_ctx) - mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) - mock_client.return_value.__aexit__ = AsyncMock(return_value=False) - - result = await provider.chat_stream([{"role": "user", "content": "Hi"}]) - - assert "401" in result.content + assert "Connection failed" in result.content assert result.finish_reason == "error"