Use OpenAI responses API

Kunal Karmakar 2026-03-31 02:05:59 +00:00 committed by Xubin Ren
parent 9ba413c82e
commit 0417c3f03b
6 changed files with 769 additions and 728 deletions

View File: nanobot/providers/azure_openai_provider.py

@@ -1,31 +1,37 @@
-"""Azure OpenAI provider implementation with API version 2024-10-21."""
+"""Azure OpenAI provider using the OpenAI SDK Responses API.
+
+Uses ``AsyncOpenAI`` pointed at ``https://{endpoint}/openai/v1/`` which
+routes to the Responses API (``/responses``). Reuses shared conversion
+helpers from :mod:`nanobot.providers.openai_responses_common`.
+"""
 
 from __future__ import annotations
 
-import json
 import uuid
 from collections.abc import Awaitable, Callable
 from typing import Any
-from urllib.parse import urljoin
 
 import httpx
-import json_repair
+from openai import AsyncOpenAI
 
-from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
-
-_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
+from nanobot.providers.base import LLMProvider, LLMResponse
+from nanobot.providers.openai_responses_common import (
+    consume_sse,
+    convert_messages,
+    convert_tools,
+    parse_response_output,
+)
 
 
 class AzureOpenAIProvider(LLMProvider):
-    """
-    Azure OpenAI provider with API version 2024-10-21 compliance.
+    """Azure OpenAI provider backed by the Responses API.
 
     Features:
-    - Hardcoded API version 2024-10-21
-    - Uses model field as Azure deployment name in URL path
-    - Uses api-key header instead of Authorization Bearer
-    - Uses max_completion_tokens instead of max_tokens
-    - Direct HTTP calls, bypasses LiteLLM
+    - Uses the OpenAI Python SDK (``AsyncOpenAI``) with
+      ``base_url = {endpoint}/openai/v1/``
+    - Calls ``client.responses.create()`` (Responses API)
+    - Reuses shared message/tool/SSE conversion from
+      ``openai_responses_common``
     """
     def __init__(
@@ -36,40 +42,28 @@ class AzureOpenAIProvider(LLMProvider):
     ):
         super().__init__(api_key, api_base)
         self.default_model = default_model
-        self.api_version = "2024-10-21"
 
-        # Validate required parameters
         if not api_key:
             raise ValueError("Azure OpenAI api_key is required")
         if not api_base:
             raise ValueError("Azure OpenAI api_base is required")
 
-        # Ensure api_base ends with /
-        if not api_base.endswith('/'):
-            api_base += '/'
+        # Normalise: ensure trailing slash
+        if not api_base.endswith("/"):
+            api_base += "/"
         self.api_base = api_base
 
-    def _build_chat_url(self, deployment_name: str) -> str:
-        """Build the Azure OpenAI chat completions URL."""
-        # Azure OpenAI URL format:
-        # https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
-        base_url = self.api_base
-        if not base_url.endswith('/'):
-            base_url += '/'
-        url = urljoin(
-            base_url,
-            f"openai/deployments/{deployment_name}/chat/completions"
-        )
-        return f"{url}?api-version={self.api_version}"
-
-    def _build_headers(self) -> dict[str, str]:
-        """Build headers for Azure OpenAI API with api-key header."""
-        return {
-            "Content-Type": "application/json",
-            "api-key": self.api_key,  # Azure OpenAI uses api-key header, not Authorization
-            "x-session-affinity": uuid.uuid4().hex,  # For cache locality
-        }
+        # SDK client targeting the Azure Responses API endpoint
+        base_url = f"{api_base.rstrip('/')}/openai/v1/"
+        self._client = AsyncOpenAI(
+            api_key=api_key,
+            base_url=base_url,
+            default_headers={"x-session-affinity": uuid.uuid4().hex},
+        )
+
+    # ------------------------------------------------------------------
+    # Helpers
+    # ------------------------------------------------------------------
     @staticmethod
     def _supports_temperature(
@@ -82,36 +76,50 @@ class AzureOpenAIProvider(LLMProvider):
         name = deployment_name.lower()
         return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
 
-    def _prepare_request_payload(
+    def _build_body(
         self,
-        deployment_name: str,
         messages: list[dict[str, Any]],
-        tools: list[dict[str, Any]] | None = None,
-        max_tokens: int = 4096,
-        temperature: float = 0.7,
-        reasoning_effort: str | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
+        tools: list[dict[str, Any]] | None,
+        model: str | None,
+        max_tokens: int,
+        temperature: float,
+        reasoning_effort: str | None,
+        tool_choice: str | dict[str, Any] | None,
     ) -> dict[str, Any]:
-        """Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
-        payload: dict[str, Any] = {
-            "messages": self._sanitize_request_messages(
-                self._sanitize_empty_content(messages),
-                _AZURE_MSG_KEYS,
-            ),
-            "max_completion_tokens": max(1, max_tokens),  # Azure API 2024-10-21 uses max_completion_tokens
+        """Build the Responses API request body from Chat-Completions-style args."""
+        deployment = model or self.default_model
+        instructions, input_items = convert_messages(messages)
+
+        body: dict[str, Any] = {
+            "model": deployment,
+            "instructions": instructions or None,
+            "input": input_items,
+            "store": False,
+            "stream": False,
         }
-        if self._supports_temperature(deployment_name, reasoning_effort):
-            payload["temperature"] = temperature
+        if self._supports_temperature(deployment, reasoning_effort):
+            body["temperature"] = temperature
         if reasoning_effort:
-            payload["reasoning_effort"] = reasoning_effort
+            body["reasoning"] = {"effort": reasoning_effort}
+            body["include"] = ["reasoning.encrypted_content"]
         if tools:
-            payload["tools"] = tools
-            payload["tool_choice"] = tool_choice or "auto"
-        return payload
+            body["tools"] = convert_tools(tools)
+            body["tool_choice"] = tool_choice or "auto"
+        return body
+
+    @staticmethod
+    def _handle_error(e: Exception) -> LLMResponse:
+        body = getattr(e, "body", None) or getattr(getattr(e, "response", None), "text", None)
+        msg = f"Error: {str(body).strip()[:500]}" if body else f"Error calling Azure OpenAI: {e}"
+        return LLMResponse(content=msg, finish_reason="error")
+
+    # ------------------------------------------------------------------
+    # Public API
+    # ------------------------------------------------------------------
     async def chat(
         self,
@@ -123,92 +131,15 @@ class AzureOpenAIProvider(LLMProvider):
         reasoning_effort: str | None = None,
         tool_choice: str | dict[str, Any] | None = None,
     ) -> LLMResponse:
-        """
-        Send a chat completion request to Azure OpenAI.
-
-        Args:
-            messages: List of message dicts with 'role' and 'content'.
-            tools: Optional list of tool definitions in OpenAI format.
-            model: Model identifier (used as deployment name).
-            max_tokens: Maximum tokens in response (mapped to max_completion_tokens).
-            temperature: Sampling temperature.
-            reasoning_effort: Optional reasoning effort parameter.
-
-        Returns:
-            LLMResponse with content and/or tool calls.
-        """
-        deployment_name = model or self.default_model
-        url = self._build_chat_url(deployment_name)
-        headers = self._build_headers()
-        payload = self._prepare_request_payload(
-            deployment_name, messages, tools, max_tokens, temperature, reasoning_effort,
-            tool_choice=tool_choice,
+        body = self._build_body(
+            messages, tools, model, max_tokens, temperature,
+            reasoning_effort, tool_choice,
         )
 
         try:
-            async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
-                response = await client.post(url, headers=headers, json=payload)
-
-                if response.status_code != 200:
-                    return LLMResponse(
-                        content=f"Azure OpenAI API Error {response.status_code}: {response.text}",
-                        finish_reason="error",
-                    )
-
-                response_data = response.json()
-                return self._parse_response(response_data)
+            response = await self._client.responses.create(**body)
+            return parse_response_output(response)
         except Exception as e:
-            return LLMResponse(
-                content=f"Error calling Azure OpenAI: {repr(e)}",
-                finish_reason="error",
-            )
-    def _parse_response(self, response: dict[str, Any]) -> LLMResponse:
-        """Parse Azure OpenAI response into our standard format."""
-        try:
-            choice = response["choices"][0]
-            message = choice["message"]
-
-            tool_calls = []
-            if message.get("tool_calls"):
-                for tc in message["tool_calls"]:
-                    # Parse arguments from JSON string if needed
-                    args = tc["function"]["arguments"]
-                    if isinstance(args, str):
-                        args = json_repair.loads(args)
-
-                    tool_calls.append(
-                        ToolCallRequest(
-                            id=tc["id"],
-                            name=tc["function"]["name"],
-                            arguments=args,
-                        )
-                    )
-
-            usage = {}
-            if response.get("usage"):
-                usage_data = response["usage"]
-                usage = {
-                    "prompt_tokens": usage_data.get("prompt_tokens", 0),
-                    "completion_tokens": usage_data.get("completion_tokens", 0),
-                    "total_tokens": usage_data.get("total_tokens", 0),
-                }
-
-            reasoning_content = message.get("reasoning_content") or None
-
-            return LLMResponse(
-                content=message.get("content"),
-                tool_calls=tool_calls,
-                finish_reason=choice.get("finish_reason", "stop"),
-                usage=usage,
-                reasoning_content=reasoning_content,
-            )
-        except (KeyError, IndexError) as e:
-            return LLMResponse(
-                content=f"Error parsing Azure OpenAI response: {str(e)}",
-                finish_reason="error",
-            )
     async def chat_stream(
         self,
@@ -221,89 +152,40 @@ class AzureOpenAIProvider(LLMProvider):
         tool_choice: str | dict[str, Any] | None = None,
         on_content_delta: Callable[[str], Awaitable[None]] | None = None,
     ) -> LLMResponse:
-        """Stream a chat completion via Azure OpenAI SSE."""
-        deployment_name = model or self.default_model
-        url = self._build_chat_url(deployment_name)
-        headers = self._build_headers()
-        payload = self._prepare_request_payload(
-            deployment_name, messages, tools, max_tokens, temperature,
-            reasoning_effort, tool_choice=tool_choice,
+        body = self._build_body(
+            messages, tools, model, max_tokens, temperature,
+            reasoning_effort, tool_choice,
         )
-        payload["stream"] = True
+        body["stream"] = True
 
         try:
-            async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
-                async with client.stream("POST", url, headers=headers, json=payload) as response:
+            # Use raw httpx stream via the SDK's base URL so we can reuse
+            # the shared Responses-API SSE parser (same as Codex provider).
+            base_url = str(self._client.base_url).rstrip("/")
+            url = f"{base_url}/responses"
+            headers = {
+                "Authorization": f"Bearer {self._client.api_key}",
+                "Content-Type": "application/json",
+                **(self._client._custom_headers or {}),
+            }
+            async with httpx.AsyncClient(timeout=60.0, verify=True) as http:
+                async with http.stream("POST", url, headers=headers, json=body) as response:
                     if response.status_code != 200:
                         text = await response.aread()
                         return LLMResponse(
                             content=f"Azure OpenAI API Error {response.status_code}: {text.decode('utf-8', 'ignore')}",
                             finish_reason="error",
                         )
-                    return await self._consume_stream(response, on_content_delta)
+                    content, tool_calls, finish_reason = await consume_sse(
+                        response, on_content_delta,
+                    )
+                    return LLMResponse(
+                        content=content or None,
+                        tool_calls=tool_calls,
+                        finish_reason=finish_reason,
+                    )
         except Exception as e:
-            return LLMResponse(content=f"Error calling Azure OpenAI: {repr(e)}", finish_reason="error")
+            return self._handle_error(e)
-    async def _consume_stream(
-        self,
-        response: httpx.Response,
-        on_content_delta: Callable[[str], Awaitable[None]] | None,
-    ) -> LLMResponse:
-        """Parse Azure OpenAI SSE stream into an LLMResponse."""
-        content_parts: list[str] = []
-        tool_call_buffers: dict[int, dict[str, str]] = {}
-        finish_reason = "stop"
-
-        async for line in response.aiter_lines():
-            if not line.startswith("data: "):
-                continue
-            data = line[6:].strip()
-            if data == "[DONE]":
-                break
-            try:
-                chunk = json.loads(data)
-            except Exception:
-                continue
-            choices = chunk.get("choices") or []
-            if not choices:
-                continue
-            choice = choices[0]
-            if choice.get("finish_reason"):
-                finish_reason = choice["finish_reason"]
-            delta = choice.get("delta") or {}
-            text = delta.get("content")
-            if text:
-                content_parts.append(text)
-                if on_content_delta:
-                    await on_content_delta(text)
-            for tc in delta.get("tool_calls") or []:
-                idx = tc.get("index", 0)
-                buf = tool_call_buffers.setdefault(idx, {"id": "", "name": "", "arguments": ""})
-                if tc.get("id"):
-                    buf["id"] = tc["id"]
-                fn = tc.get("function") or {}
-                if fn.get("name"):
-                    buf["name"] = fn["name"]
-                if fn.get("arguments"):
-                    buf["arguments"] += fn["arguments"]
-
-        tool_calls = [
-            ToolCallRequest(
-                id=buf["id"], name=buf["name"],
-                arguments=json_repair.loads(buf["arguments"]) if buf["arguments"] else {},
-            )
-            for buf in tool_call_buffers.values()
-        ]
-        return LLMResponse(
-            content="".join(content_parts) or None,
-            tool_calls=tool_calls,
-            finish_reason=finish_reason,
-        )
     def get_default_model(self) -> str:
-        """Get the default model (also used as default deployment name)."""
         return self.default_model
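
As an aside, here is a minimal usage sketch of the refactored provider. The endpoint, key, and deployment name are placeholders rather than values from this commit, and running it would of course hit the live Azure endpoint:

# Hypothetical usage sketch; endpoint/key/deployment are placeholders.
import asyncio

from nanobot.providers.azure_openai_provider import AzureOpenAIProvider


async def main() -> None:
    provider = AzureOpenAIProvider(
        api_key="<azure-key>",
        api_base="https://my-resource.openai.azure.com",
        default_model="gpt-4o-deployment",
    )
    # Non-streaming: one client.responses.create() call, parsed by
    # parse_response_output() into an LLMResponse.
    result = await provider.chat([{"role": "user", "content": "Hello"}])
    print(result.content, result.finish_reason)

    # Streaming: text deltas arrive through the callback as the shared
    # SSE parser consumes response.output_text.delta events.
    async def on_delta(text: str) -> None:
        print(text, end="", flush=True)

    await provider.chat_stream(
        [{"role": "user", "content": "Hello again"}], on_content_delta=on_delta,
    )


asyncio.run(main())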

View File: OpenAI Codex provider (OpenAICodexProvider)

@@ -6,13 +6,18 @@ import asyncio
 import hashlib
 import json
 from collections.abc import Awaitable, Callable
-from typing import Any, AsyncGenerator
+from typing import Any
 
 import httpx
 from loguru import logger
 from oauth_cli_kit import get_token as get_codex_token
 
 from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
+from nanobot.providers.openai_responses_common import (
+    consume_sse,
+    convert_messages,
+    convert_tools,
+)
 
 DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
 DEFAULT_ORIGINATOR = "nanobot"
@@ -36,7 +41,7 @@ class OpenAICodexProvider(LLMProvider):
     ) -> LLMResponse:
         """Shared request logic for both chat() and chat_stream()."""
         model = model or self.default_model
-        system_prompt, input_items = _convert_messages(messages)
+        system_prompt, input_items = convert_messages(messages)
 
         token = await asyncio.to_thread(get_codex_token)
         headers = _build_headers(token.account_id, token.access)
@@ -56,7 +61,7 @@ class OpenAICodexProvider(LLMProvider):
         if reasoning_effort:
             body["reasoning"] = {"effort": reasoning_effort}
         if tools:
-            body["tools"] = _convert_tools(tools)
+            body["tools"] = convert_tools(tools)
 
         try:
             try:
@@ -127,96 +132,7 @@ async def _request_codex(
         if response.status_code != 200:
             text = await response.aread()
             raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore")))
-        return await _consume_sse(response, on_content_delta)
+        return await consume_sse(response, on_content_delta)
-
-
-def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
-    """Convert OpenAI function-calling schema to Codex flat format."""
-    converted: list[dict[str, Any]] = []
-    for tool in tools:
-        fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
-        name = fn.get("name")
-        if not name:
-            continue
-        params = fn.get("parameters") or {}
-        converted.append({
-            "type": "function",
-            "name": name,
-            "description": fn.get("description") or "",
-            "parameters": params if isinstance(params, dict) else {},
-        })
-    return converted
-
-
-def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
-    system_prompt = ""
-    input_items: list[dict[str, Any]] = []
-    for idx, msg in enumerate(messages):
-        role = msg.get("role")
-        content = msg.get("content")
-        if role == "system":
-            system_prompt = content if isinstance(content, str) else ""
-            continue
-        if role == "user":
-            input_items.append(_convert_user_message(content))
-            continue
-        if role == "assistant":
-            if isinstance(content, str) and content:
-                input_items.append({
-                    "type": "message", "role": "assistant",
-                    "content": [{"type": "output_text", "text": content}],
-                    "status": "completed", "id": f"msg_{idx}",
-                })
-            for tool_call in msg.get("tool_calls", []) or []:
-                fn = tool_call.get("function") or {}
-                call_id, item_id = _split_tool_call_id(tool_call.get("id"))
-                input_items.append({
-                    "type": "function_call",
-                    "id": item_id or f"fc_{idx}",
-                    "call_id": call_id or f"call_{idx}",
-                    "name": fn.get("name"),
-                    "arguments": fn.get("arguments") or "{}",
-                })
-            continue
-        if role == "tool":
-            call_id, _ = _split_tool_call_id(msg.get("tool_call_id"))
-            output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
-            input_items.append({"type": "function_call_output", "call_id": call_id, "output": output_text})
-    return system_prompt, input_items
-
-
-def _convert_user_message(content: Any) -> dict[str, Any]:
-    if isinstance(content, str):
-        return {"role": "user", "content": [{"type": "input_text", "text": content}]}
-    if isinstance(content, list):
-        converted: list[dict[str, Any]] = []
-        for item in content:
-            if not isinstance(item, dict):
-                continue
-            if item.get("type") == "text":
-                converted.append({"type": "input_text", "text": item.get("text", "")})
-            elif item.get("type") == "image_url":
-                url = (item.get("image_url") or {}).get("url")
-                if url:
-                    converted.append({"type": "input_image", "image_url": url, "detail": "auto"})
-        if converted:
-            return {"role": "user", "content": converted}
-    return {"role": "user", "content": [{"type": "input_text", "text": ""}]}
-
-
-def _split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
-    if isinstance(tool_call_id, str) and tool_call_id:
-        if "|" in tool_call_id:
-            call_id, item_id = tool_call_id.split("|", 1)
-            return call_id, item_id or None
-        return tool_call_id, None
-    return "call_0", None
 def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
@@ -224,96 +140,6 @@ def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
     return hashlib.sha256(raw.encode("utf-8")).hexdigest()
-
-
-async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
-    buffer: list[str] = []
-    async for line in response.aiter_lines():
-        if line == "":
-            if buffer:
-                data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")]
-                buffer = []
-                if not data_lines:
-                    continue
-                data = "\n".join(data_lines).strip()
-                if not data or data == "[DONE]":
-                    continue
-                try:
-                    yield json.loads(data)
-                except Exception:
-                    continue
-            continue
-        buffer.append(line)
-
-
-async def _consume_sse(
-    response: httpx.Response,
-    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
-) -> tuple[str, list[ToolCallRequest], str]:
-    content = ""
-    tool_calls: list[ToolCallRequest] = []
-    tool_call_buffers: dict[str, dict[str, Any]] = {}
-    finish_reason = "stop"
-
-    async for event in _iter_sse(response):
-        event_type = event.get("type")
-        if event_type == "response.output_item.added":
-            item = event.get("item") or {}
-            if item.get("type") == "function_call":
-                call_id = item.get("call_id")
-                if not call_id:
-                    continue
-                tool_call_buffers[call_id] = {
-                    "id": item.get("id") or "fc_0",
-                    "name": item.get("name"),
-                    "arguments": item.get("arguments") or "",
-                }
-        elif event_type == "response.output_text.delta":
-            delta_text = event.get("delta") or ""
-            content += delta_text
-            if on_content_delta and delta_text:
-                await on_content_delta(delta_text)
-        elif event_type == "response.function_call_arguments.delta":
-            call_id = event.get("call_id")
-            if call_id and call_id in tool_call_buffers:
-                tool_call_buffers[call_id]["arguments"] += event.get("delta") or ""
-        elif event_type == "response.function_call_arguments.done":
-            call_id = event.get("call_id")
-            if call_id and call_id in tool_call_buffers:
-                tool_call_buffers[call_id]["arguments"] = event.get("arguments") or ""
-        elif event_type == "response.output_item.done":
-            item = event.get("item") or {}
-            if item.get("type") == "function_call":
-                call_id = item.get("call_id")
-                if not call_id:
-                    continue
-                buf = tool_call_buffers.get(call_id) or {}
-                args_raw = buf.get("arguments") or item.get("arguments") or "{}"
-                try:
-                    args = json.loads(args_raw)
-                except Exception:
-                    args = {"raw": args_raw}
-                tool_calls.append(
-                    ToolCallRequest(
-                        id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
-                        name=buf.get("name") or item.get("name"),
-                        arguments=args,
-                    )
-                )
-        elif event_type == "response.completed":
-            status = (event.get("response") or {}).get("status")
-            finish_reason = _map_finish_reason(status)
-        elif event_type in {"error", "response.failed"}:
-            raise RuntimeError("Codex response failed")
-
-    return content, tool_calls, finish_reason
-
-
-_FINISH_REASON_MAP = {"completed": "stop", "incomplete": "length", "failed": "error", "cancelled": "error"}
-
-
-def _map_finish_reason(status: str | None) -> str:
-    return _FINISH_REASON_MAP.get(status or "completed", "stop")
 def _friendly_error(status_code: int, raw: str) -> str:
     if status_code == 429:
         return "ChatGPT usage quota exceeded or rate limit triggered. Please try again later."

View File: nanobot/providers/openai_responses_common/__init__.py

@@ -0,0 +1,27 @@
"""Shared helpers for OpenAI Responses API providers (Codex, Azure OpenAI)."""

from nanobot.providers.openai_responses_common.converters import (
    convert_messages,
    convert_tools,
    convert_user_message,
    split_tool_call_id,
)
from nanobot.providers.openai_responses_common.parsing import (
    FINISH_REASON_MAP,
    consume_sse,
    iter_sse,
    map_finish_reason,
    parse_response_output,
)

__all__ = [
    "convert_messages",
    "convert_tools",
    "convert_user_message",
    "split_tool_call_id",
    "iter_sse",
    "consume_sse",
    "map_finish_reason",
    "parse_response_output",
    "FINISH_REASON_MAP",
]
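
A quick sketch of the package's primary entry point, `convert_messages`, which splits out the system prompt and converts the remaining turns into Responses API input items. The conversation content is illustrative:

from nanobot.providers.openai_responses_common import convert_messages

# System messages become the "instructions" string; user turns become
# input items with input_text blocks.
system, items = convert_messages([
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "Hi"},
])
assert system == "You are helpful."
assert items == [
    {"role": "user", "content": [{"type": "input_text", "text": "Hi"}]},
]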

View File: nanobot/providers/openai_responses_common/converters.py

@@ -0,0 +1,110 @@
"""Convert Chat Completions messages/tools to Responses API format."""

from __future__ import annotations

import json
from typing import Any


def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
    """Convert Chat Completions messages to Responses API input items.

    Returns ``(system_prompt, input_items)`` where *system_prompt* is extracted
    from any ``system`` role message and *input_items* is the Responses API
    ``input`` array.
    """
    system_prompt = ""
    input_items: list[dict[str, Any]] = []
    for idx, msg in enumerate(messages):
        role = msg.get("role")
        content = msg.get("content")
        if role == "system":
            system_prompt = content if isinstance(content, str) else ""
            continue
        if role == "user":
            input_items.append(convert_user_message(content))
            continue
        if role == "assistant":
            if isinstance(content, str) and content:
                input_items.append({
                    "type": "message", "role": "assistant",
                    "content": [{"type": "output_text", "text": content}],
                    "status": "completed", "id": f"msg_{idx}",
                })
            for tool_call in msg.get("tool_calls", []) or []:
                fn = tool_call.get("function") or {}
                call_id, item_id = split_tool_call_id(tool_call.get("id"))
                input_items.append({
                    "type": "function_call",
                    "id": item_id or f"fc_{idx}",
                    "call_id": call_id or f"call_{idx}",
                    "name": fn.get("name"),
                    "arguments": fn.get("arguments") or "{}",
                })
            continue
        if role == "tool":
            call_id, _ = split_tool_call_id(msg.get("tool_call_id"))
            output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
            input_items.append({"type": "function_call_output", "call_id": call_id, "output": output_text})
    return system_prompt, input_items


def convert_user_message(content: Any) -> dict[str, Any]:
    """Convert a user message's content to Responses API format.

    Handles plain strings, and maps ``text`` blocks to ``input_text`` and
    ``image_url`` blocks to ``input_image``.
    """
    if isinstance(content, str):
        return {"role": "user", "content": [{"type": "input_text", "text": content}]}
    if isinstance(content, list):
        converted: list[dict[str, Any]] = []
        for item in content:
            if not isinstance(item, dict):
                continue
            if item.get("type") == "text":
                converted.append({"type": "input_text", "text": item.get("text", "")})
            elif item.get("type") == "image_url":
                url = (item.get("image_url") or {}).get("url")
                if url:
                    converted.append({"type": "input_image", "image_url": url, "detail": "auto"})
        if converted:
            return {"role": "user", "content": converted}
    return {"role": "user", "content": [{"type": "input_text", "text": ""}]}


def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert OpenAI function-calling tool schema to Responses API flat format."""
    converted: list[dict[str, Any]] = []
    for tool in tools:
        fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
        name = fn.get("name")
        if not name:
            continue
        params = fn.get("parameters") or {}
        converted.append({
            "type": "function",
            "name": name,
            "description": fn.get("description") or "",
            "parameters": params if isinstance(params, dict) else {},
        })
    return converted


def split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
    """Split a compound ``call_id|item_id`` string.

    Returns ``(call_id, item_id)`` where *item_id* may be ``None``.
    """
    if isinstance(tool_call_id, str) and tool_call_id:
        if "|" in tool_call_id:
            call_id, item_id = tool_call_id.split("|", 1)
            return call_id, item_id or None
        return tool_call_id, None
    return "call_0", None

View File: nanobot/providers/openai_responses_common/parsing.py

@@ -0,0 +1,173 @@
"""Parse Responses API SSE streams and SDK response objects."""

from __future__ import annotations

import json
from collections.abc import Awaitable, Callable
from typing import Any, AsyncGenerator

import httpx

from nanobot.providers.base import LLMResponse, ToolCallRequest

FINISH_REASON_MAP = {
    "completed": "stop",
    "incomplete": "length",
    "failed": "error",
    "cancelled": "error",
}


def map_finish_reason(status: str | None) -> str:
    """Map a Responses API status string to a Chat-Completions-style finish_reason."""
    return FINISH_REASON_MAP.get(status or "completed", "stop")


async def iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
    """Yield parsed JSON events from a Responses API SSE stream."""
    buffer: list[str] = []
    async for line in response.aiter_lines():
        if line == "":
            if buffer:
                data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")]
                buffer = []
                if not data_lines:
                    continue
                data = "\n".join(data_lines).strip()
                if not data or data == "[DONE]":
                    continue
                try:
                    yield json.loads(data)
                except Exception:
                    continue
            continue
        buffer.append(line)


async def consume_sse(
    response: httpx.Response,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str]:
    """Consume a Responses API SSE stream into ``(content, tool_calls, finish_reason)``."""
    content = ""
    tool_calls: list[ToolCallRequest] = []
    tool_call_buffers: dict[str, dict[str, Any]] = {}
    finish_reason = "stop"

    async for event in iter_sse(response):
        event_type = event.get("type")
        if event_type == "response.output_item.added":
            item = event.get("item") or {}
            if item.get("type") == "function_call":
                call_id = item.get("call_id")
                if not call_id:
                    continue
                tool_call_buffers[call_id] = {
                    "id": item.get("id") or "fc_0",
                    "name": item.get("name"),
                    "arguments": item.get("arguments") or "",
                }
        elif event_type == "response.output_text.delta":
            delta_text = event.get("delta") or ""
            content += delta_text
            if on_content_delta and delta_text:
                await on_content_delta(delta_text)
        elif event_type == "response.function_call_arguments.delta":
            call_id = event.get("call_id")
            if call_id and call_id in tool_call_buffers:
                tool_call_buffers[call_id]["arguments"] += event.get("delta") or ""
        elif event_type == "response.function_call_arguments.done":
            call_id = event.get("call_id")
            if call_id and call_id in tool_call_buffers:
                tool_call_buffers[call_id]["arguments"] = event.get("arguments") or ""
        elif event_type == "response.output_item.done":
            item = event.get("item") or {}
            if item.get("type") == "function_call":
                call_id = item.get("call_id")
                if not call_id:
                    continue
                buf = tool_call_buffers.get(call_id) or {}
                args_raw = buf.get("arguments") or item.get("arguments") or "{}"
                try:
                    args = json.loads(args_raw)
                except Exception:
                    args = {"raw": args_raw}
                tool_calls.append(
                    ToolCallRequest(
                        id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
                        name=buf.get("name") or item.get("name"),
                        arguments=args,
                    )
                )
        elif event_type == "response.completed":
            status = (event.get("response") or {}).get("status")
            finish_reason = map_finish_reason(status)
        elif event_type in {"error", "response.failed"}:
            raise RuntimeError("Response failed")

    return content, tool_calls, finish_reason


def parse_response_output(response: Any) -> LLMResponse:
    """Parse an SDK ``Response`` object (from ``client.responses.create()``)
    into an ``LLMResponse``.

    Works with both Pydantic model objects and plain dicts.
    """
    # Normalise to dict
    if not isinstance(response, dict):
        dump = getattr(response, "model_dump", None)
        response = dump() if callable(dump) else vars(response)

    output = response.get("output") or []
    content_parts: list[str] = []
    tool_calls: list[ToolCallRequest] = []

    for item in output:
        if not isinstance(item, dict):
            dump = getattr(item, "model_dump", None)
            item = dump() if callable(dump) else vars(item)
        item_type = item.get("type")
        if item_type == "message":
            for block in item.get("content") or []:
                if not isinstance(block, dict):
                    dump = getattr(block, "model_dump", None)
                    block = dump() if callable(dump) else vars(block)
                if block.get("type") == "output_text":
                    content_parts.append(block.get("text") or "")
        elif item_type == "function_call":
            call_id = item.get("call_id") or ""
            item_id = item.get("id") or "fc_0"
            args_raw = item.get("arguments") or "{}"
            try:
                args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
            except Exception:
                args = {"raw": args_raw}
            tool_calls.append(ToolCallRequest(
                id=f"{call_id}|{item_id}",
                name=item.get("name") or "",
                arguments=args if isinstance(args, dict) else {},
            ))

    usage_raw = response.get("usage") or {}
    if not isinstance(usage_raw, dict):
        dump = getattr(usage_raw, "model_dump", None)
        usage_raw = dump() if callable(dump) else vars(usage_raw)
    usage = {}
    if usage_raw:
        usage = {
            "prompt_tokens": int(usage_raw.get("input_tokens") or 0),
            "completion_tokens": int(usage_raw.get("output_tokens") or 0),
            "total_tokens": int(usage_raw.get("total_tokens") or 0),
        }

    status = response.get("status")
    finish_reason = map_finish_reason(status)

    return LLMResponse(
        content="".join(content_parts) or None,
        tool_calls=tool_calls,
        finish_reason=finish_reason,
        usage=usage,
    )
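
A small sketch of `parse_response_output` on a plain dict payload. The values are illustrative and mirror the shape built by `_make_sdk_response` in the tests below; note how Responses usage fields are remapped to Chat-Completions-style keys:

from nanobot.providers.openai_responses_common import parse_response_output

resp = parse_response_output({
    "output": [
        {"type": "message", "role": "assistant",
         "content": [{"type": "output_text", "text": "Hello!"}]},
    ],
    "status": "completed",
    "usage": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
})
assert resp.content == "Hello!"
assert resp.finish_reason == "stop"          # "completed" -> "stop"
assert resp.usage["prompt_tokens"] == 10     # input_tokens -> prompt_tokens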

View File: Azure OpenAI provider tests

@@ -1,6 +1,6 @@
-"""Test Azure OpenAI provider implementation (updated for model-based deployment names)."""
+"""Test Azure OpenAI provider (Responses API via OpenAI SDK)."""
 
-from unittest.mock import AsyncMock, Mock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 
@@ -8,392 +8,415 @@ from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
 from nanobot.providers.base import LLMResponse
-def test_azure_openai_provider_init():
-    """Test AzureOpenAIProvider initialization without deployment_name."""
+# ---------------------------------------------------------------------------
+# Init & validation
+# ---------------------------------------------------------------------------
+
+def test_init_creates_sdk_client():
+    """Provider creates an AsyncOpenAI client with correct base_url."""
     provider = AzureOpenAIProvider(
         api_key="test-key",
         api_base="https://test-resource.openai.azure.com",
         default_model="gpt-4o-deployment",
     )
     assert provider.api_key == "test-key"
     assert provider.api_base == "https://test-resource.openai.azure.com/"
     assert provider.default_model == "gpt-4o-deployment"
-    assert provider.api_version == "2024-10-21"
+    # SDK client base_url ends with /openai/v1/
+    assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1")
 
 
-def test_azure_openai_provider_init_validation():
-    """Test AzureOpenAIProvider initialization validation."""
-    # Missing api_key
+def test_init_base_url_no_trailing_slash():
+    """Trailing slashes are normalised before building base_url."""
+    provider = AzureOpenAIProvider(
+        api_key="k", api_base="https://res.openai.azure.com",
+    )
+    assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1")
+
+
+def test_init_base_url_with_trailing_slash():
+    provider = AzureOpenAIProvider(
+        api_key="k", api_base="https://res.openai.azure.com/",
+    )
+    assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1")
+
+
+def test_init_validation_missing_key():
     with pytest.raises(ValueError, match="Azure OpenAI api_key is required"):
         AzureOpenAIProvider(api_key="", api_base="https://test.com")
 
-    # Missing api_base
+
+def test_init_validation_missing_base():
     with pytest.raises(ValueError, match="Azure OpenAI api_base is required"):
         AzureOpenAIProvider(api_key="test", api_base="")
 
 
-def test_build_chat_url():
-    """Test Azure OpenAI URL building with different deployment names."""
-    provider = AzureOpenAIProvider(
-        api_key="test-key",
-        api_base="https://test-resource.openai.azure.com",
-        default_model="gpt-4o",
-    )
-
-    # Test various deployment names
-    test_cases = [
-        ("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"),
-        ("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"),
-        ("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"),
-    ]
-    for deployment_name, expected_url in test_cases:
-        url = provider._build_chat_url(deployment_name)
-        assert url == expected_url
-
-
-def test_build_chat_url_api_base_without_slash():
-    """Test URL building when api_base doesn't end with slash."""
-    provider = AzureOpenAIProvider(
-        api_key="test-key",
-        api_base="https://test-resource.openai.azure.com",  # No trailing slash
-        default_model="gpt-4o",
-    )
-    url = provider._build_chat_url("test-deployment")
-    expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21"
-    assert url == expected
-
-
-def test_build_headers():
-    """Test Azure OpenAI header building with api-key authentication."""
-    provider = AzureOpenAIProvider(
-        api_key="test-api-key-123",
-        api_base="https://test-resource.openai.azure.com",
-        default_model="gpt-4o",
-    )
-    headers = provider._build_headers()
-    assert headers["Content-Type"] == "application/json"
-    assert headers["api-key"] == "test-api-key-123"  # Azure OpenAI specific header
-    assert "x-session-affinity" in headers
+def test_no_api_version_in_base_url():
+    """The /openai/v1/ path should NOT contain an api-version query param."""
+    provider = AzureOpenAIProvider(api_key="k", api_base="https://res.openai.azure.com")
+    base = str(provider._client.base_url)
+    assert "api-version" not in base
+
+
+# ---------------------------------------------------------------------------
+# _supports_temperature
+# ---------------------------------------------------------------------------
+
+def test_supports_temperature_standard_model():
+    assert AzureOpenAIProvider._supports_temperature("gpt-4o") is True
+
+
+def test_supports_temperature_reasoning_model():
+    assert AzureOpenAIProvider._supports_temperature("o3-mini") is False
+    assert AzureOpenAIProvider._supports_temperature("gpt-5-chat") is False
+    assert AzureOpenAIProvider._supports_temperature("o4-mini") is False
+
+
+def test_supports_temperature_with_reasoning_effort():
+    assert AzureOpenAIProvider._supports_temperature("gpt-4o", reasoning_effort="medium") is False
 
 
-def test_prepare_request_payload():
-    """Test request payload preparation with Azure OpenAI 2024-10-21 compliance."""
-    provider = AzureOpenAIProvider(
-        api_key="test-key",
-        api_base="https://test-resource.openai.azure.com",
-        default_model="gpt-4o",
-    )
-    messages = [{"role": "user", "content": "Hello"}]
-    payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)
-    assert payload["messages"] == messages
-    assert payload["max_completion_tokens"] == 1500  # Azure API 2024-10-21 uses max_completion_tokens
-    assert payload["temperature"] == 0.8
-    assert "tools" not in payload
-
-    # Test with tools
-    tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
-    payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
-    assert payload_with_tools["tools"] == tools
-    assert payload_with_tools["tool_choice"] == "auto"
-
-    # Test with reasoning_effort
-    payload_with_reasoning = provider._prepare_request_payload(
-        "gpt-5-chat", messages, reasoning_effort="medium"
-    )
-    assert payload_with_reasoning["reasoning_effort"] == "medium"
-    assert "temperature" not in payload_with_reasoning
-
-
-def test_prepare_request_payload_sanitizes_messages():
-    """Test Azure payload strips non-standard message keys before sending."""
-    provider = AzureOpenAIProvider(
-        api_key="test-key",
-        api_base="https://test-resource.openai.azure.com",
-        default_model="gpt-4o",
-    )
-
-    messages = [
-        {
-            "role": "assistant",
-            "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
-            "reasoning_content": "hidden chain-of-thought",
-        },
-        {
-            "role": "tool",
-            "tool_call_id": "call_123",
-            "name": "x",
-            "content": "ok",
-            "extra_field": "should be removed",
-        },
-    ]
-
-    payload = provider._prepare_request_payload("gpt-4o", messages)
-
-    assert payload["messages"] == [
-        {
-            "role": "assistant",
-            "content": None,
-            "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
-        },
-        {
-            "role": "tool",
-            "tool_call_id": "call_123",
-            "name": "x",
-            "content": "ok",
-        },
-    ]
+# ---------------------------------------------------------------------------
+# _build_body — Responses API body construction
+# ---------------------------------------------------------------------------
+
+def test_build_body_basic():
+    provider = AzureOpenAIProvider(
+        api_key="k", api_base="https://res.openai.azure.com", default_model="gpt-4o",
+    )
+    messages = [{"role": "system", "content": "You are helpful."}, {"role": "user", "content": "Hi"}]
+    body = provider._build_body(messages, None, None, 4096, 0.7, None, None)
+
+    assert body["model"] == "gpt-4o"
+    assert body["instructions"] == "You are helpful."
+    assert body["temperature"] == 0.7
+    assert body["store"] is False
+    assert "reasoning" not in body
+    # input should contain the converted user message only (system extracted)
+    assert any(
+        item.get("role") == "user"
+        for item in body["input"]
+    )
+
+
+def test_build_body_with_tools():
+    provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
+    tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
+    body = provider._build_body(
+        [{"role": "user", "content": "weather?"}], tools, None, 4096, 0.7, None, None,
+    )
+    assert body["tools"] == [{"type": "function", "name": "get_weather", "description": "", "parameters": {}}]
+    assert body["tool_choice"] == "auto"
+
+
+def test_build_body_with_reasoning():
+    provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-5-chat")
+    body = provider._build_body(
+        [{"role": "user", "content": "think"}], None, "gpt-5-chat", 4096, 0.7, "medium", None,
+    )
+    assert body["reasoning"] == {"effort": "medium"}
+    assert "reasoning.encrypted_content" in body.get("include", [])
+    # temperature omitted for reasoning models
+    assert "temperature" not in body
+
+
+def test_build_body_image_conversion():
+    """image_url content blocks should be converted to input_image."""
+    provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
+    messages = [{
+        "role": "user",
+        "content": [
+            {"type": "text", "text": "What's in this image?"},
+            {"type": "image_url", "image_url": {"url": "https://example.com/img.png"}},
+        ],
+    }]
+    body = provider._build_body(messages, None, None, 4096, 0.7, None, None)
+    user_item = body["input"][0]
+    content_types = [b["type"] for b in user_item["content"]]
+    assert "input_text" in content_types
+    assert "input_image" in content_types
+    image_block = next(b for b in user_item["content"] if b["type"] == "input_image")
+    assert image_block["image_url"] == "https://example.com/img.png"
assert payload["messages"] == [
{ # ---------------------------------------------------------------------------
"role": "assistant", # chat() — non-streaming
"content": None, # ---------------------------------------------------------------------------
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
def _make_sdk_response(
content="Hello!", tool_calls=None, status="completed",
usage=None,
):
"""Build a mock that quacks like an openai Response object."""
resp = MagicMock()
resp.model_dump = MagicMock(return_value={
"output": [
{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": content}]},
*([{
"type": "function_call",
"call_id": tc["call_id"], "id": tc["id"],
"name": tc["name"], "arguments": tc["arguments"],
} for tc in (tool_calls or [])]),
],
"status": status,
"usage": {
"input_tokens": (usage or {}).get("input_tokens", 10),
"output_tokens": (usage or {}).get("output_tokens", 5),
"total_tokens": (usage or {}).get("total_tokens", 15),
}, },
{ })
"role": "tool", return resp
"tool_call_id": "call_123",
"name": "x",
"content": "ok",
},
]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_chat_success(): async def test_chat_success():
"""Test successful chat request using model as deployment name."""
provider = AzureOpenAIProvider( provider = AzureOpenAIProvider(
api_key="test-key", api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o-deployment",
) )
mock_resp = _make_sdk_response(content="Hello!")
# Mock response data provider._client.responses = MagicMock()
mock_response_data = { provider._client.responses.create = AsyncMock(return_value=mock_resp)
"choices": [{
"message": { result = await provider.chat([{"role": "user", "content": "Hi"}])
"content": "Hello! How can I help you today?",
"role": "assistant" assert isinstance(result, LLMResponse)
}, assert result.content == "Hello!"
"finish_reason": "stop" assert result.finish_reason == "stop"
}], assert result.usage["prompt_tokens"] == 10
"usage": {
"prompt_tokens": 12,
"completion_tokens": 18,
"total_tokens": 30
}
}
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=mock_response_data)
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
# Test with specific model (deployment name)
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages, model="custom-deployment")
assert isinstance(result, LLMResponse)
assert result.content == "Hello! How can I help you today?"
assert result.finish_reason == "stop"
assert result.usage["prompt_tokens"] == 12
assert result.usage["completion_tokens"] == 18
assert result.usage["total_tokens"] == 30
# Verify URL was built with the provided model as deployment name
call_args = mock_context.post.call_args
expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21"
assert call_args[0][0] == expected_url
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_chat_uses_default_model_when_no_model_provided(): async def test_chat_uses_default_model():
"""Test that chat uses default_model when no model is specified."""
provider = AzureOpenAIProvider( provider = AzureOpenAIProvider(
api_key="test-key", api_key="k", api_base="https://test.openai.azure.com", default_model="my-deployment",
api_base="https://test-resource.openai.azure.com",
default_model="default-deployment",
) )
mock_resp = _make_sdk_response(content="ok")
mock_response_data = { provider._client.responses = MagicMock()
"choices": [{ provider._client.responses.create = AsyncMock(return_value=mock_resp)
"message": {"content": "Response", "role": "assistant"},
"finish_reason": "stop" await provider.chat([{"role": "user", "content": "test"}])
}],
"usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10} call_kwargs = provider._client.responses.create.call_args[1]
} assert call_kwargs["model"] == "my-deployment"
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock() @pytest.mark.asyncio
mock_response.status_code = 200 async def test_chat_custom_model():
mock_response.json = Mock(return_value=mock_response_data) provider = AzureOpenAIProvider(
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
mock_context = AsyncMock() )
mock_context.post = AsyncMock(return_value=mock_response) mock_resp = _make_sdk_response(content="ok")
mock_client.return_value.__aenter__.return_value = mock_context provider._client.responses = MagicMock()
provider._client.responses.create = AsyncMock(return_value=mock_resp)
messages = [{"role": "user", "content": "Test"}]
await provider.chat(messages) # No model specified await provider.chat([{"role": "user", "content": "test"}], model="custom-deploy")
# Verify URL was built with default model as deployment name call_kwargs = provider._client.responses.create.call_args[1]
call_args = mock_context.post.call_args assert call_kwargs["model"] == "custom-deploy"
expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21"
assert call_args[0][0] == expected_url
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_chat_with_tool_calls(): async def test_chat_with_tool_calls():
"""Test chat request with tool calls in response."""
provider = AzureOpenAIProvider( provider = AzureOpenAIProvider(
api_key="test-key", api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
) )
mock_resp = _make_sdk_response(
# Mock response with tool calls content=None,
mock_response_data = { tool_calls=[{
"choices": [{ "call_id": "call_123", "id": "fc_1",
"message": { "name": "get_weather", "arguments": '{"location": "SF"}',
"content": None,
"role": "assistant",
"tool_calls": [{
"id": "call_12345",
"function": {
"name": "get_weather",
"arguments": '{"location": "San Francisco"}'
}
}]
},
"finish_reason": "tool_calls"
}], }],
"usage": { )
"prompt_tokens": 20, provider._client.responses = MagicMock()
"completion_tokens": 15, provider._client.responses.create = AsyncMock(return_value=mock_resp)
"total_tokens": 35
} result = await provider.chat(
} [{"role": "user", "content": "Weather?"}],
tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
with patch("httpx.AsyncClient") as mock_client: )
mock_response = AsyncMock()
mock_response.status_code = 200 assert len(result.tool_calls) == 1
mock_response.json = Mock(return_value=mock_response_data) assert result.tool_calls[0].name == "get_weather"
assert result.tool_calls[0].arguments == {"location": "SF"}
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "What's the weather?"}]
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
result = await provider.chat(messages, tools=tools, model="weather-model")
assert isinstance(result, LLMResponse)
assert result.content is None
assert result.finish_reason == "tool_calls"
assert len(result.tool_calls) == 1
assert result.tool_calls[0].name == "get_weather"
assert result.tool_calls[0].arguments == {"location": "San Francisco"}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_chat_api_error(): async def test_chat_error_handling():
"""Test chat request API error handling."""
provider = AzureOpenAIProvider( provider = AzureOpenAIProvider(
api_key="test-key", api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
) )
provider._client.responses = MagicMock()
with patch("httpx.AsyncClient") as mock_client: provider._client.responses.create = AsyncMock(side_effect=Exception("Connection failed"))
mock_response = AsyncMock()
mock_response.status_code = 401
mock_response.text = "Invalid authentication credentials"
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages)
assert isinstance(result, LLMResponse)
assert "Azure OpenAI API Error 401" in result.content
assert "Invalid authentication credentials" in result.content
assert result.finish_reason == "error"
result = await provider.chat([{"role": "user", "content": "Hi"}])
@pytest.mark.asyncio
async def test_chat_connection_error():
"""Test chat request connection error handling."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
with patch("httpx.AsyncClient") as mock_client:
mock_context = AsyncMock()
mock_context.post = AsyncMock(side_effect=Exception("Connection failed"))
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages)
assert isinstance(result, LLMResponse)
assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content
assert result.finish_reason == "error"
def test_parse_response_malformed():
"""Test response parsing with malformed data."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
# Test with missing choices
malformed_response = {"usage": {"prompt_tokens": 10}}
result = provider._parse_response(malformed_response)
assert isinstance(result, LLMResponse) assert isinstance(result, LLMResponse)
assert "Error parsing Azure OpenAI response" in result.content assert "Connection failed" in result.content
assert result.finish_reason == "error" assert result.finish_reason == "error"
+
+
+@pytest.mark.asyncio
+async def test_chat_reasoning_param_format():
+    """reasoning_effort should be sent as reasoning={effort: ...} not a flat string."""
+    provider = AzureOpenAIProvider(
+        api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-5-chat",
+    )
+    mock_resp = _make_sdk_response(content="thought")
+    provider._client.responses = MagicMock()
+    provider._client.responses.create = AsyncMock(return_value=mock_resp)
+
+    await provider.chat(
+        [{"role": "user", "content": "think"}], reasoning_effort="medium",
+    )
+
+    call_kwargs = provider._client.responses.create.call_args[1]
+    assert call_kwargs["reasoning"] == {"effort": "medium"}
+    assert "reasoning_effort" not in call_kwargs
+
+
+# ---------------------------------------------------------------------------
+# chat_stream()
+# ---------------------------------------------------------------------------
+
+@pytest.mark.asyncio
+async def test_chat_stream_success():
+    """Streaming should call on_content_delta and return combined response."""
+    provider = AzureOpenAIProvider(
+        api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o",
+    )
+
+    # Build SSE lines for the mock httpx stream
+    sse_events = [
+        'event: response.output_text.delta',
+        'data: {"type":"response.output_text.delta","delta":"Hello"}',
+        '',
+        'event: response.output_text.delta',
+        'data: {"type":"response.output_text.delta","delta":" world"}',
+        '',
+        'event: response.completed',
+        'data: {"type":"response.completed","response":{"status":"completed"}}',
+        '',
+    ]
+
+    deltas: list[str] = []
+
+    async def on_delta(text: str) -> None:
+        deltas.append(text)
+
+    # Mock httpx stream
+    mock_response = AsyncMock()
+    mock_response.status_code = 200
+
+    async def aiter_lines():
+        for line in sse_events:
+            yield line
+
+    mock_response.aiter_lines = aiter_lines
+
+    with patch("httpx.AsyncClient") as mock_client:
+        mock_ctx = AsyncMock()
+        mock_stream_ctx = AsyncMock()
+        mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response)
+        mock_stream_ctx.__aexit__ = AsyncMock(return_value=False)
+        mock_ctx.stream = MagicMock(return_value=mock_stream_ctx)
+        mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx)
+        mock_client.return_value.__aexit__ = AsyncMock(return_value=False)
+
+        result = await provider.chat_stream(
+            [{"role": "user", "content": "Hi"}], on_content_delta=on_delta,
+        )
+
+        assert result.content == "Hello world"
+        assert result.finish_reason == "stop"
+        assert deltas == ["Hello", " world"]
+
+
+@pytest.mark.asyncio
+async def test_chat_stream_with_tool_calls():
+    """Streaming tool calls should be accumulated correctly."""
+    provider = AzureOpenAIProvider(
+        api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
+    )
+
+    sse_events = [
+        'data: {"type":"response.output_item.added","item":{"type":"function_call","call_id":"call_1","id":"fc_1","name":"get_weather","arguments":""}}',
+        '',
+        'data: {"type":"response.function_call_arguments.delta","call_id":"call_1","delta":"{\\"loc"}',
+        '',
+        'data: {"type":"response.function_call_arguments.done","call_id":"call_1","arguments":"{\\"location\\":\\"SF\\"}"}',
+        '',
+        'data: {"type":"response.output_item.done","item":{"type":"function_call","call_id":"call_1","id":"fc_1","name":"get_weather","arguments":"{\\"location\\":\\"SF\\"}"}}',
+        '',
+        'data: {"type":"response.completed","response":{"status":"completed"}}',
+        '',
+    ]
+
+    mock_response = AsyncMock()
+    mock_response.status_code = 200
+
+    async def aiter_lines():
+        for line in sse_events:
+            yield line
+
+    mock_response.aiter_lines = aiter_lines
+
+    with patch("httpx.AsyncClient") as mock_client:
+        mock_ctx = AsyncMock()
+        mock_stream_ctx = AsyncMock()
+        mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response)
+        mock_stream_ctx.__aexit__ = AsyncMock(return_value=False)
+        mock_ctx.stream = MagicMock(return_value=mock_stream_ctx)
+        mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx)
+        mock_client.return_value.__aexit__ = AsyncMock(return_value=False)
+
+        result = await provider.chat_stream(
+            [{"role": "user", "content": "weather?"}],
+            tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
+        )
+
+        assert len(result.tool_calls) == 1
+        assert result.tool_calls[0].name == "get_weather"
+        assert result.tool_calls[0].arguments == {"location": "SF"}
+
+
+@pytest.mark.asyncio
+async def test_chat_stream_http_error():
+    """Streaming should return error on non-200 status."""
+    provider = AzureOpenAIProvider(
+        api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
+    )
+
+    mock_response = AsyncMock()
+    mock_response.status_code = 401
+    mock_response.aread = AsyncMock(return_value=b"Unauthorized")
+
+    with patch("httpx.AsyncClient") as mock_client:
+        mock_ctx = AsyncMock()
+        mock_stream_ctx = AsyncMock()
+        mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response)
+        mock_stream_ctx.__aexit__ = AsyncMock(return_value=False)
+        mock_ctx.stream = MagicMock(return_value=mock_stream_ctx)
+        mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx)
+        mock_client.return_value.__aexit__ = AsyncMock(return_value=False)
+
+        result = await provider.chat_stream([{"role": "user", "content": "Hi"}])
+
+        assert "401" in result.content
+        assert result.finish_reason == "error"
+
+
+# ---------------------------------------------------------------------------
+# get_default_model
+# ---------------------------------------------------------------------------
+
 def test_get_default_model():
-    """Test get_default_model method."""
     provider = AzureOpenAIProvider(
-        api_key="test-key",
-        api_base="https://test-resource.openai.azure.com",
-        default_model="my-custom-deployment",
+        api_key="k", api_base="https://r.com", default_model="my-deploy",
     )
-
-    assert provider.get_default_model() == "my-custom-deployment"
+    assert provider.get_default_model() == "my-deploy"
-
-
-if __name__ == "__main__":
-    # Run basic tests
-    print("Running basic Azure OpenAI provider tests...")
-
-    # Test initialization
-    provider = AzureOpenAIProvider(
-        api_key="test-key",
-        api_base="https://test-resource.openai.azure.com",
-        default_model="gpt-4o-deployment",
-    )
-    print("✅ Provider initialization successful")
-
-    # Test URL building
-    url = provider._build_chat_url("my-deployment")
-    expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21"
-    assert url == expected
-    print("✅ URL building works correctly")
-
-    # Test headers
-    headers = provider._build_headers()
-    assert headers["api-key"] == "test-key"
-    assert headers["Content-Type"] == "application/json"
-    print("✅ Header building works correctly")
-
-    # Test payload preparation
-    messages = [{"role": "user", "content": "Test"}]
-    payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
-    assert payload["max_completion_tokens"] == 1000  # Azure 2024-10-21 format
-    print("✅ Payload preparation works correctly")
-
-    print("✅ All basic tests passed! Updated test file is working correctly.")