mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 17:32:39 +00:00
Use OpenAI responses API
This commit is contained in:
parent
9ba413c82e
commit
0417c3f03b
@ -1,31 +1,37 @@
|
||||
"""Azure OpenAI provider implementation with API version 2024-10-21."""
|
||||
"""Azure OpenAI provider using the OpenAI SDK Responses API.
|
||||
|
||||
Uses ``AsyncOpenAI`` pointed at ``https://{endpoint}/openai/v1/`` which
|
||||
routes to the Responses API (``/responses``). Reuses shared conversion
|
||||
helpers from :mod:`nanobot.providers.openai_responses_common`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
import json_repair
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse
|
||||
from nanobot.providers.openai_responses_common import (
|
||||
consume_sse,
|
||||
convert_messages,
|
||||
convert_tools,
|
||||
parse_response_output,
|
||||
)
|
||||
|
||||
|
||||
class AzureOpenAIProvider(LLMProvider):
|
||||
"""
|
||||
Azure OpenAI provider with API version 2024-10-21 compliance.
|
||||
|
||||
"""Azure OpenAI provider backed by the Responses API.
|
||||
|
||||
Features:
|
||||
- Hardcoded API version 2024-10-21
|
||||
- Uses model field as Azure deployment name in URL path
|
||||
- Uses api-key header instead of Authorization Bearer
|
||||
- Uses max_completion_tokens instead of max_tokens
|
||||
- Direct HTTP calls, bypasses LiteLLM
|
||||
- Uses the OpenAI Python SDK (``AsyncOpenAI``) with
|
||||
``base_url = {endpoint}/openai/v1/``
|
||||
- Calls ``client.responses.create()`` (Responses API)
|
||||
- Reuses shared message/tool/SSE conversion from
|
||||
``openai_responses_common``
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -36,40 +42,28 @@ class AzureOpenAIProvider(LLMProvider):
|
||||
):
|
||||
super().__init__(api_key, api_base)
|
||||
self.default_model = default_model
|
||||
self.api_version = "2024-10-21"
|
||||
|
||||
# Validate required parameters
|
||||
|
||||
if not api_key:
|
||||
raise ValueError("Azure OpenAI api_key is required")
|
||||
if not api_base:
|
||||
raise ValueError("Azure OpenAI api_base is required")
|
||||
|
||||
# Ensure api_base ends with /
|
||||
if not api_base.endswith('/'):
|
||||
api_base += '/'
|
||||
|
||||
# Normalise: ensure trailing slash
|
||||
if not api_base.endswith("/"):
|
||||
api_base += "/"
|
||||
self.api_base = api_base
|
||||
|
||||
def _build_chat_url(self, deployment_name: str) -> str:
|
||||
"""Build the Azure OpenAI chat completions URL."""
|
||||
# Azure OpenAI URL format:
|
||||
# https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
|
||||
base_url = self.api_base
|
||||
if not base_url.endswith('/'):
|
||||
base_url += '/'
|
||||
|
||||
url = urljoin(
|
||||
base_url,
|
||||
f"openai/deployments/{deployment_name}/chat/completions"
|
||||
# SDK client targeting the Azure Responses API endpoint
|
||||
base_url = f"{api_base.rstrip('/')}/openai/v1/"
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
default_headers={"x-session-affinity": uuid.uuid4().hex},
|
||||
)
|
||||
return f"{url}?api-version={self.api_version}"
|
||||
|
||||
def _build_headers(self) -> dict[str, str]:
|
||||
"""Build headers for Azure OpenAI API with api-key header."""
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"api-key": self.api_key, # Azure OpenAI uses api-key header, not Authorization
|
||||
"x-session-affinity": uuid.uuid4().hex, # For cache locality
|
||||
}
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _supports_temperature(
|
||||
@ -82,36 +76,50 @@ class AzureOpenAIProvider(LLMProvider):
|
||||
name = deployment_name.lower()
|
||||
return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
|
||||
|
||||
def _prepare_request_payload(
|
||||
def _build_body(
|
||||
self,
|
||||
deployment_name: str,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
model: str | None,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
reasoning_effort: str | None,
|
||||
tool_choice: str | dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
|
||||
payload: dict[str, Any] = {
|
||||
"messages": self._sanitize_request_messages(
|
||||
self._sanitize_empty_content(messages),
|
||||
_AZURE_MSG_KEYS,
|
||||
),
|
||||
"max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens
|
||||
"""Build the Responses API request body from Chat-Completions-style args."""
|
||||
deployment = model or self.default_model
|
||||
instructions, input_items = convert_messages(messages)
|
||||
|
||||
body: dict[str, Any] = {
|
||||
"model": deployment,
|
||||
"instructions": instructions or None,
|
||||
"input": input_items,
|
||||
"store": False,
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
if self._supports_temperature(deployment_name, reasoning_effort):
|
||||
payload["temperature"] = temperature
|
||||
if self._supports_temperature(deployment, reasoning_effort):
|
||||
body["temperature"] = temperature
|
||||
|
||||
if reasoning_effort:
|
||||
payload["reasoning_effort"] = reasoning_effort
|
||||
body["reasoning"] = {"effort": reasoning_effort}
|
||||
body["include"] = ["reasoning.encrypted_content"]
|
||||
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
payload["tool_choice"] = tool_choice or "auto"
|
||||
body["tools"] = convert_tools(tools)
|
||||
body["tool_choice"] = tool_choice or "auto"
|
||||
|
||||
return payload
|
||||
return body
|
||||
|
||||
@staticmethod
|
||||
def _handle_error(e: Exception) -> LLMResponse:
|
||||
body = getattr(e, "body", None) or getattr(getattr(e, "response", None), "text", None)
|
||||
msg = f"Error: {str(body).strip()[:500]}" if body else f"Error calling Azure OpenAI: {e}"
|
||||
return LLMResponse(content=msg, finish_reason="error")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
@ -123,92 +131,15 @@ class AzureOpenAIProvider(LLMProvider):
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Send a chat completion request to Azure OpenAI.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
tools: Optional list of tool definitions in OpenAI format.
|
||||
model: Model identifier (used as deployment name).
|
||||
max_tokens: Maximum tokens in response (mapped to max_completion_tokens).
|
||||
temperature: Sampling temperature.
|
||||
reasoning_effort: Optional reasoning effort parameter.
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and/or tool calls.
|
||||
"""
|
||||
deployment_name = model or self.default_model
|
||||
url = self._build_chat_url(deployment_name)
|
||||
headers = self._build_headers()
|
||||
payload = self._prepare_request_payload(
|
||||
deployment_name, messages, tools, max_tokens, temperature, reasoning_effort,
|
||||
tool_choice=tool_choice,
|
||||
body = self._build_body(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
|
||||
response = await client.post(url, headers=headers, json=payload)
|
||||
if response.status_code != 200:
|
||||
return LLMResponse(
|
||||
content=f"Azure OpenAI API Error {response.status_code}: {response.text}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
response_data = response.json()
|
||||
return self._parse_response(response_data)
|
||||
|
||||
response = await self._client.responses.create(**body)
|
||||
return parse_response_output(response)
|
||||
except Exception as e:
|
||||
return LLMResponse(
|
||||
content=f"Error calling Azure OpenAI: {repr(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
def _parse_response(self, response: dict[str, Any]) -> LLMResponse:
|
||||
"""Parse Azure OpenAI response into our standard format."""
|
||||
try:
|
||||
choice = response["choices"][0]
|
||||
message = choice["message"]
|
||||
|
||||
tool_calls = []
|
||||
if message.get("tool_calls"):
|
||||
for tc in message["tool_calls"]:
|
||||
# Parse arguments from JSON string if needed
|
||||
args = tc["function"]["arguments"]
|
||||
if isinstance(args, str):
|
||||
args = json_repair.loads(args)
|
||||
|
||||
tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=tc["id"],
|
||||
name=tc["function"]["name"],
|
||||
arguments=args,
|
||||
)
|
||||
)
|
||||
|
||||
usage = {}
|
||||
if response.get("usage"):
|
||||
usage_data = response["usage"]
|
||||
usage = {
|
||||
"prompt_tokens": usage_data.get("prompt_tokens", 0),
|
||||
"completion_tokens": usage_data.get("completion_tokens", 0),
|
||||
"total_tokens": usage_data.get("total_tokens", 0),
|
||||
}
|
||||
|
||||
reasoning_content = message.get("reasoning_content") or None
|
||||
|
||||
return LLMResponse(
|
||||
content=message.get("content"),
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=choice.get("finish_reason", "stop"),
|
||||
usage=usage,
|
||||
reasoning_content=reasoning_content,
|
||||
)
|
||||
|
||||
except (KeyError, IndexError) as e:
|
||||
return LLMResponse(
|
||||
content=f"Error parsing Azure OpenAI response: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
return self._handle_error(e)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
@ -221,89 +152,40 @@ class AzureOpenAIProvider(LLMProvider):
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Stream a chat completion via Azure OpenAI SSE."""
|
||||
deployment_name = model or self.default_model
|
||||
url = self._build_chat_url(deployment_name)
|
||||
headers = self._build_headers()
|
||||
payload = self._prepare_request_payload(
|
||||
deployment_name, messages, tools, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice=tool_choice,
|
||||
body = self._build_body(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
payload["stream"] = True
|
||||
body["stream"] = True
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
|
||||
async with client.stream("POST", url, headers=headers, json=payload) as response:
|
||||
# Use raw httpx stream via the SDK's base URL so we can reuse
|
||||
# the shared Responses-API SSE parser (same as Codex provider).
|
||||
base_url = str(self._client.base_url).rstrip("/")
|
||||
url = f"{base_url}/responses"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._client.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
**(self._client._custom_headers or {}),
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=60.0, verify=True) as http:
|
||||
async with http.stream("POST", url, headers=headers, json=body) as response:
|
||||
if response.status_code != 200:
|
||||
text = await response.aread()
|
||||
return LLMResponse(
|
||||
content=f"Azure OpenAI API Error {response.status_code}: {text.decode('utf-8', 'ignore')}",
|
||||
finish_reason="error",
|
||||
)
|
||||
return await self._consume_stream(response, on_content_delta)
|
||||
content, tool_calls, finish_reason = await consume_sse(
|
||||
response, on_content_delta,
|
||||
)
|
||||
return LLMResponse(
|
||||
content=content or None,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
except Exception as e:
|
||||
return LLMResponse(content=f"Error calling Azure OpenAI: {repr(e)}", finish_reason="error")
|
||||
|
||||
async def _consume_stream(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None,
|
||||
) -> LLMResponse:
|
||||
"""Parse Azure OpenAI SSE stream into an LLMResponse."""
|
||||
content_parts: list[str] = []
|
||||
tool_call_buffers: dict[int, dict[str, str]] = {}
|
||||
finish_reason = "stop"
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
data = line[6:].strip()
|
||||
if data == "[DONE]":
|
||||
break
|
||||
try:
|
||||
chunk = json.loads(data)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
choices = chunk.get("choices") or []
|
||||
if not choices:
|
||||
continue
|
||||
choice = choices[0]
|
||||
if choice.get("finish_reason"):
|
||||
finish_reason = choice["finish_reason"]
|
||||
delta = choice.get("delta") or {}
|
||||
|
||||
text = delta.get("content")
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
if on_content_delta:
|
||||
await on_content_delta(text)
|
||||
|
||||
for tc in delta.get("tool_calls") or []:
|
||||
idx = tc.get("index", 0)
|
||||
buf = tool_call_buffers.setdefault(idx, {"id": "", "name": "", "arguments": ""})
|
||||
if tc.get("id"):
|
||||
buf["id"] = tc["id"]
|
||||
fn = tc.get("function") or {}
|
||||
if fn.get("name"):
|
||||
buf["name"] = fn["name"]
|
||||
if fn.get("arguments"):
|
||||
buf["arguments"] += fn["arguments"]
|
||||
|
||||
tool_calls = [
|
||||
ToolCallRequest(
|
||||
id=buf["id"], name=buf["name"],
|
||||
arguments=json_repair.loads(buf["arguments"]) if buf["arguments"] else {},
|
||||
)
|
||||
for buf in tool_call_buffers.values()
|
||||
]
|
||||
|
||||
return LLMResponse(
|
||||
content="".join(content_parts) or None,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
return self._handle_error(e)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
"""Get the default model (also used as default deployment name)."""
|
||||
return self.default_model
|
||||
@ -6,13 +6,18 @@ import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from oauth_cli_kit import get_token as get_codex_token
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.openai_responses_common import (
|
||||
consume_sse,
|
||||
convert_messages,
|
||||
convert_tools,
|
||||
)
|
||||
|
||||
DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
|
||||
DEFAULT_ORIGINATOR = "nanobot"
|
||||
@ -36,7 +41,7 @@ class OpenAICodexProvider(LLMProvider):
|
||||
) -> LLMResponse:
|
||||
"""Shared request logic for both chat() and chat_stream()."""
|
||||
model = model or self.default_model
|
||||
system_prompt, input_items = _convert_messages(messages)
|
||||
system_prompt, input_items = convert_messages(messages)
|
||||
|
||||
token = await asyncio.to_thread(get_codex_token)
|
||||
headers = _build_headers(token.account_id, token.access)
|
||||
@ -56,7 +61,7 @@ class OpenAICodexProvider(LLMProvider):
|
||||
if reasoning_effort:
|
||||
body["reasoning"] = {"effort": reasoning_effort}
|
||||
if tools:
|
||||
body["tools"] = _convert_tools(tools)
|
||||
body["tools"] = convert_tools(tools)
|
||||
|
||||
try:
|
||||
try:
|
||||
@ -127,96 +132,7 @@ async def _request_codex(
|
||||
if response.status_code != 200:
|
||||
text = await response.aread()
|
||||
raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore")))
|
||||
return await _consume_sse(response, on_content_delta)
|
||||
|
||||
|
||||
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Convert OpenAI function-calling schema to Codex flat format."""
|
||||
converted: list[dict[str, Any]] = []
|
||||
for tool in tools:
|
||||
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
|
||||
name = fn.get("name")
|
||||
if not name:
|
||||
continue
|
||||
params = fn.get("parameters") or {}
|
||||
converted.append({
|
||||
"type": "function",
|
||||
"name": name,
|
||||
"description": fn.get("description") or "",
|
||||
"parameters": params if isinstance(params, dict) else {},
|
||||
})
|
||||
return converted
|
||||
|
||||
|
||||
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
|
||||
system_prompt = ""
|
||||
input_items: list[dict[str, Any]] = []
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
role = msg.get("role")
|
||||
content = msg.get("content")
|
||||
|
||||
if role == "system":
|
||||
system_prompt = content if isinstance(content, str) else ""
|
||||
continue
|
||||
|
||||
if role == "user":
|
||||
input_items.append(_convert_user_message(content))
|
||||
continue
|
||||
|
||||
if role == "assistant":
|
||||
if isinstance(content, str) and content:
|
||||
input_items.append({
|
||||
"type": "message", "role": "assistant",
|
||||
"content": [{"type": "output_text", "text": content}],
|
||||
"status": "completed", "id": f"msg_{idx}",
|
||||
})
|
||||
for tool_call in msg.get("tool_calls", []) or []:
|
||||
fn = tool_call.get("function") or {}
|
||||
call_id, item_id = _split_tool_call_id(tool_call.get("id"))
|
||||
input_items.append({
|
||||
"type": "function_call",
|
||||
"id": item_id or f"fc_{idx}",
|
||||
"call_id": call_id or f"call_{idx}",
|
||||
"name": fn.get("name"),
|
||||
"arguments": fn.get("arguments") or "{}",
|
||||
})
|
||||
continue
|
||||
|
||||
if role == "tool":
|
||||
call_id, _ = _split_tool_call_id(msg.get("tool_call_id"))
|
||||
output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
|
||||
input_items.append({"type": "function_call_output", "call_id": call_id, "output": output_text})
|
||||
|
||||
return system_prompt, input_items
|
||||
|
||||
|
||||
def _convert_user_message(content: Any) -> dict[str, Any]:
|
||||
if isinstance(content, str):
|
||||
return {"role": "user", "content": [{"type": "input_text", "text": content}]}
|
||||
if isinstance(content, list):
|
||||
converted: list[dict[str, Any]] = []
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if item.get("type") == "text":
|
||||
converted.append({"type": "input_text", "text": item.get("text", "")})
|
||||
elif item.get("type") == "image_url":
|
||||
url = (item.get("image_url") or {}).get("url")
|
||||
if url:
|
||||
converted.append({"type": "input_image", "image_url": url, "detail": "auto"})
|
||||
if converted:
|
||||
return {"role": "user", "content": converted}
|
||||
return {"role": "user", "content": [{"type": "input_text", "text": ""}]}
|
||||
|
||||
|
||||
def _split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
|
||||
if isinstance(tool_call_id, str) and tool_call_id:
|
||||
if "|" in tool_call_id:
|
||||
call_id, item_id = tool_call_id.split("|", 1)
|
||||
return call_id, item_id or None
|
||||
return tool_call_id, None
|
||||
return "call_0", None
|
||||
return await consume_sse(response, on_content_delta)
|
||||
|
||||
|
||||
def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
|
||||
@ -224,96 +140,6 @@ def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
|
||||
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
|
||||
buffer: list[str] = []
|
||||
async for line in response.aiter_lines():
|
||||
if line == "":
|
||||
if buffer:
|
||||
data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")]
|
||||
buffer = []
|
||||
if not data_lines:
|
||||
continue
|
||||
data = "\n".join(data_lines).strip()
|
||||
if not data or data == "[DONE]":
|
||||
continue
|
||||
try:
|
||||
yield json.loads(data)
|
||||
except Exception:
|
||||
continue
|
||||
continue
|
||||
buffer.append(line)
|
||||
|
||||
|
||||
async def _consume_sse(
|
||||
response: httpx.Response,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> tuple[str, list[ToolCallRequest], str]:
|
||||
content = ""
|
||||
tool_calls: list[ToolCallRequest] = []
|
||||
tool_call_buffers: dict[str, dict[str, Any]] = {}
|
||||
finish_reason = "stop"
|
||||
|
||||
async for event in _iter_sse(response):
|
||||
event_type = event.get("type")
|
||||
if event_type == "response.output_item.added":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call":
|
||||
call_id = item.get("call_id")
|
||||
if not call_id:
|
||||
continue
|
||||
tool_call_buffers[call_id] = {
|
||||
"id": item.get("id") or "fc_0",
|
||||
"name": item.get("name"),
|
||||
"arguments": item.get("arguments") or "",
|
||||
}
|
||||
elif event_type == "response.output_text.delta":
|
||||
delta_text = event.get("delta") or ""
|
||||
content += delta_text
|
||||
if on_content_delta and delta_text:
|
||||
await on_content_delta(delta_text)
|
||||
elif event_type == "response.function_call_arguments.delta":
|
||||
call_id = event.get("call_id")
|
||||
if call_id and call_id in tool_call_buffers:
|
||||
tool_call_buffers[call_id]["arguments"] += event.get("delta") or ""
|
||||
elif event_type == "response.function_call_arguments.done":
|
||||
call_id = event.get("call_id")
|
||||
if call_id and call_id in tool_call_buffers:
|
||||
tool_call_buffers[call_id]["arguments"] = event.get("arguments") or ""
|
||||
elif event_type == "response.output_item.done":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call":
|
||||
call_id = item.get("call_id")
|
||||
if not call_id:
|
||||
continue
|
||||
buf = tool_call_buffers.get(call_id) or {}
|
||||
args_raw = buf.get("arguments") or item.get("arguments") or "{}"
|
||||
try:
|
||||
args = json.loads(args_raw)
|
||||
except Exception:
|
||||
args = {"raw": args_raw}
|
||||
tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
|
||||
name=buf.get("name") or item.get("name"),
|
||||
arguments=args,
|
||||
)
|
||||
)
|
||||
elif event_type == "response.completed":
|
||||
status = (event.get("response") or {}).get("status")
|
||||
finish_reason = _map_finish_reason(status)
|
||||
elif event_type in {"error", "response.failed"}:
|
||||
raise RuntimeError("Codex response failed")
|
||||
|
||||
return content, tool_calls, finish_reason
|
||||
|
||||
|
||||
_FINISH_REASON_MAP = {"completed": "stop", "incomplete": "length", "failed": "error", "cancelled": "error"}
|
||||
|
||||
|
||||
def _map_finish_reason(status: str | None) -> str:
|
||||
return _FINISH_REASON_MAP.get(status or "completed", "stop")
|
||||
|
||||
|
||||
def _friendly_error(status_code: int, raw: str) -> str:
|
||||
if status_code == 429:
|
||||
return "ChatGPT usage quota exceeded or rate limit triggered. Please try again later."
|
||||
|
||||
27
nanobot/providers/openai_responses_common/__init__.py
Normal file
27
nanobot/providers/openai_responses_common/__init__.py
Normal file
@ -0,0 +1,27 @@
|
||||
"""Shared helpers for OpenAI Responses API providers (Codex, Azure OpenAI)."""
|
||||
|
||||
from nanobot.providers.openai_responses_common.converters import (
|
||||
convert_messages,
|
||||
convert_tools,
|
||||
convert_user_message,
|
||||
split_tool_call_id,
|
||||
)
|
||||
from nanobot.providers.openai_responses_common.parsing import (
|
||||
FINISH_REASON_MAP,
|
||||
consume_sse,
|
||||
iter_sse,
|
||||
map_finish_reason,
|
||||
parse_response_output,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"convert_messages",
|
||||
"convert_tools",
|
||||
"convert_user_message",
|
||||
"split_tool_call_id",
|
||||
"iter_sse",
|
||||
"consume_sse",
|
||||
"map_finish_reason",
|
||||
"parse_response_output",
|
||||
"FINISH_REASON_MAP",
|
||||
]
|
||||
110
nanobot/providers/openai_responses_common/converters.py
Normal file
110
nanobot/providers/openai_responses_common/converters.py
Normal file
@ -0,0 +1,110 @@
|
||||
"""Convert Chat Completions messages/tools to Responses API format."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
|
||||
def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
|
||||
"""Convert Chat Completions messages to Responses API input items.
|
||||
|
||||
Returns ``(system_prompt, input_items)`` where *system_prompt* is extracted
|
||||
from any ``system`` role message and *input_items* is the Responses API
|
||||
``input`` array.
|
||||
"""
|
||||
system_prompt = ""
|
||||
input_items: list[dict[str, Any]] = []
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
role = msg.get("role")
|
||||
content = msg.get("content")
|
||||
|
||||
if role == "system":
|
||||
system_prompt = content if isinstance(content, str) else ""
|
||||
continue
|
||||
|
||||
if role == "user":
|
||||
input_items.append(convert_user_message(content))
|
||||
continue
|
||||
|
||||
if role == "assistant":
|
||||
if isinstance(content, str) and content:
|
||||
input_items.append({
|
||||
"type": "message", "role": "assistant",
|
||||
"content": [{"type": "output_text", "text": content}],
|
||||
"status": "completed", "id": f"msg_{idx}",
|
||||
})
|
||||
for tool_call in msg.get("tool_calls", []) or []:
|
||||
fn = tool_call.get("function") or {}
|
||||
call_id, item_id = split_tool_call_id(tool_call.get("id"))
|
||||
input_items.append({
|
||||
"type": "function_call",
|
||||
"id": item_id or f"fc_{idx}",
|
||||
"call_id": call_id or f"call_{idx}",
|
||||
"name": fn.get("name"),
|
||||
"arguments": fn.get("arguments") or "{}",
|
||||
})
|
||||
continue
|
||||
|
||||
if role == "tool":
|
||||
call_id, _ = split_tool_call_id(msg.get("tool_call_id"))
|
||||
output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
|
||||
input_items.append({"type": "function_call_output", "call_id": call_id, "output": output_text})
|
||||
|
||||
return system_prompt, input_items
|
||||
|
||||
|
||||
def convert_user_message(content: Any) -> dict[str, Any]:
|
||||
"""Convert a user message's content to Responses API format.
|
||||
|
||||
Handles plain strings, ``text`` blocks → ``input_text``, and
|
||||
``image_url`` blocks → ``input_image``.
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return {"role": "user", "content": [{"type": "input_text", "text": content}]}
|
||||
if isinstance(content, list):
|
||||
converted: list[dict[str, Any]] = []
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if item.get("type") == "text":
|
||||
converted.append({"type": "input_text", "text": item.get("text", "")})
|
||||
elif item.get("type") == "image_url":
|
||||
url = (item.get("image_url") or {}).get("url")
|
||||
if url:
|
||||
converted.append({"type": "input_image", "image_url": url, "detail": "auto"})
|
||||
if converted:
|
||||
return {"role": "user", "content": converted}
|
||||
return {"role": "user", "content": [{"type": "input_text", "text": ""}]}
|
||||
|
||||
|
||||
def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Convert OpenAI function-calling tool schema to Responses API flat format."""
|
||||
converted: list[dict[str, Any]] = []
|
||||
for tool in tools:
|
||||
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
|
||||
name = fn.get("name")
|
||||
if not name:
|
||||
continue
|
||||
params = fn.get("parameters") or {}
|
||||
converted.append({
|
||||
"type": "function",
|
||||
"name": name,
|
||||
"description": fn.get("description") or "",
|
||||
"parameters": params if isinstance(params, dict) else {},
|
||||
})
|
||||
return converted
|
||||
|
||||
|
||||
def split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
|
||||
"""Split a compound ``call_id|item_id`` string.
|
||||
|
||||
Returns ``(call_id, item_id)`` where *item_id* may be ``None``.
|
||||
"""
|
||||
if isinstance(tool_call_id, str) and tool_call_id:
|
||||
if "|" in tool_call_id:
|
||||
call_id, item_id = tool_call_id.split("|", 1)
|
||||
return call_id, item_id or None
|
||||
return tool_call_id, None
|
||||
return "call_0", None
|
||||
173
nanobot/providers/openai_responses_common/parsing.py
Normal file
173
nanobot/providers/openai_responses_common/parsing.py
Normal file
@ -0,0 +1,173 @@
|
||||
"""Parse Responses API SSE streams and SDK response objects."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import httpx
|
||||
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
FINISH_REASON_MAP = {
|
||||
"completed": "stop",
|
||||
"incomplete": "length",
|
||||
"failed": "error",
|
||||
"cancelled": "error",
|
||||
}
|
||||
|
||||
|
||||
def map_finish_reason(status: str | None) -> str:
|
||||
"""Map a Responses API status string to a Chat-Completions-style finish_reason."""
|
||||
return FINISH_REASON_MAP.get(status or "completed", "stop")
|
||||
|
||||
|
||||
async def iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
|
||||
"""Yield parsed JSON events from a Responses API SSE stream."""
|
||||
buffer: list[str] = []
|
||||
async for line in response.aiter_lines():
|
||||
if line == "":
|
||||
if buffer:
|
||||
data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")]
|
||||
buffer = []
|
||||
if not data_lines:
|
||||
continue
|
||||
data = "\n".join(data_lines).strip()
|
||||
if not data or data == "[DONE]":
|
||||
continue
|
||||
try:
|
||||
yield json.loads(data)
|
||||
except Exception:
|
||||
continue
|
||||
continue
|
||||
buffer.append(line)
|
||||
|
||||
|
||||
async def consume_sse(
|
||||
response: httpx.Response,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> tuple[str, list[ToolCallRequest], str]:
|
||||
"""Consume a Responses API SSE stream into ``(content, tool_calls, finish_reason)``."""
|
||||
content = ""
|
||||
tool_calls: list[ToolCallRequest] = []
|
||||
tool_call_buffers: dict[str, dict[str, Any]] = {}
|
||||
finish_reason = "stop"
|
||||
|
||||
async for event in iter_sse(response):
|
||||
event_type = event.get("type")
|
||||
if event_type == "response.output_item.added":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call":
|
||||
call_id = item.get("call_id")
|
||||
if not call_id:
|
||||
continue
|
||||
tool_call_buffers[call_id] = {
|
||||
"id": item.get("id") or "fc_0",
|
||||
"name": item.get("name"),
|
||||
"arguments": item.get("arguments") or "",
|
||||
}
|
||||
elif event_type == "response.output_text.delta":
|
||||
delta_text = event.get("delta") or ""
|
||||
content += delta_text
|
||||
if on_content_delta and delta_text:
|
||||
await on_content_delta(delta_text)
|
||||
elif event_type == "response.function_call_arguments.delta":
|
||||
call_id = event.get("call_id")
|
||||
if call_id and call_id in tool_call_buffers:
|
||||
tool_call_buffers[call_id]["arguments"] += event.get("delta") or ""
|
||||
elif event_type == "response.function_call_arguments.done":
|
||||
call_id = event.get("call_id")
|
||||
if call_id and call_id in tool_call_buffers:
|
||||
tool_call_buffers[call_id]["arguments"] = event.get("arguments") or ""
|
||||
elif event_type == "response.output_item.done":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call":
|
||||
call_id = item.get("call_id")
|
||||
if not call_id:
|
||||
continue
|
||||
buf = tool_call_buffers.get(call_id) or {}
|
||||
args_raw = buf.get("arguments") or item.get("arguments") or "{}"
|
||||
try:
|
||||
args = json.loads(args_raw)
|
||||
except Exception:
|
||||
args = {"raw": args_raw}
|
||||
tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
|
||||
name=buf.get("name") or item.get("name"),
|
||||
arguments=args,
|
||||
)
|
||||
)
|
||||
elif event_type == "response.completed":
|
||||
status = (event.get("response") or {}).get("status")
|
||||
finish_reason = map_finish_reason(status)
|
||||
elif event_type in {"error", "response.failed"}:
|
||||
raise RuntimeError("Response failed")
|
||||
|
||||
return content, tool_calls, finish_reason
|
||||
|
||||
|
||||
def parse_response_output(response: Any) -> LLMResponse:
|
||||
"""Parse an SDK ``Response`` object (from ``client.responses.create()``)
|
||||
into an ``LLMResponse``.
|
||||
|
||||
Works with both Pydantic model objects and plain dicts.
|
||||
"""
|
||||
# Normalise to dict
|
||||
if not isinstance(response, dict):
|
||||
dump = getattr(response, "model_dump", None)
|
||||
response = dump() if callable(dump) else vars(response)
|
||||
|
||||
output = response.get("output") or []
|
||||
content_parts: list[str] = []
|
||||
tool_calls: list[ToolCallRequest] = []
|
||||
|
||||
for item in output:
|
||||
if not isinstance(item, dict):
|
||||
dump = getattr(item, "model_dump", None)
|
||||
item = dump() if callable(dump) else vars(item)
|
||||
|
||||
item_type = item.get("type")
|
||||
if item_type == "message":
|
||||
for block in item.get("content") or []:
|
||||
if not isinstance(block, dict):
|
||||
dump = getattr(block, "model_dump", None)
|
||||
block = dump() if callable(dump) else vars(block)
|
||||
if block.get("type") == "output_text":
|
||||
content_parts.append(block.get("text") or "")
|
||||
elif item_type == "function_call":
|
||||
call_id = item.get("call_id") or ""
|
||||
item_id = item.get("id") or "fc_0"
|
||||
args_raw = item.get("arguments") or "{}"
|
||||
try:
|
||||
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
|
||||
except Exception:
|
||||
args = {"raw": args_raw}
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=f"{call_id}|{item_id}",
|
||||
name=item.get("name") or "",
|
||||
arguments=args if isinstance(args, dict) else {},
|
||||
))
|
||||
|
||||
usage_raw = response.get("usage") or {}
|
||||
if not isinstance(usage_raw, dict):
|
||||
dump = getattr(usage_raw, "model_dump", None)
|
||||
usage_raw = dump() if callable(dump) else vars(usage_raw)
|
||||
usage = {}
|
||||
if usage_raw:
|
||||
usage = {
|
||||
"prompt_tokens": int(usage_raw.get("input_tokens") or 0),
|
||||
"completion_tokens": int(usage_raw.get("output_tokens") or 0),
|
||||
"total_tokens": int(usage_raw.get("total_tokens") or 0),
|
||||
}
|
||||
|
||||
status = response.get("status")
|
||||
finish_reason = map_finish_reason(status)
|
||||
|
||||
return LLMResponse(
|
||||
content="".join(content_parts) or None,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
)
|
||||
@ -1,6 +1,6 @@
|
||||
"""Test Azure OpenAI provider implementation (updated for model-based deployment names)."""
|
||||
"""Test Azure OpenAI provider (Responses API via OpenAI SDK)."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -8,392 +8,415 @@ from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
|
||||
def test_azure_openai_provider_init():
|
||||
"""Test AzureOpenAIProvider initialization without deployment_name."""
|
||||
# ---------------------------------------------------------------------------
|
||||
# Init & validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_init_creates_sdk_client():
|
||||
"""Provider creates an AsyncOpenAI client with correct base_url."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o-deployment",
|
||||
)
|
||||
|
||||
assert provider.api_key == "test-key"
|
||||
assert provider.api_base == "https://test-resource.openai.azure.com/"
|
||||
assert provider.default_model == "gpt-4o-deployment"
|
||||
assert provider.api_version == "2024-10-21"
|
||||
# SDK client base_url ends with /openai/v1/
|
||||
assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1")
|
||||
|
||||
|
||||
def test_azure_openai_provider_init_validation():
|
||||
"""Test AzureOpenAIProvider initialization validation."""
|
||||
# Missing api_key
|
||||
def test_init_base_url_no_trailing_slash():
|
||||
"""Trailing slashes are normalised before building base_url."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="k", api_base="https://res.openai.azure.com",
|
||||
)
|
||||
assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1")
|
||||
|
||||
|
||||
def test_init_base_url_with_trailing_slash():
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="k", api_base="https://res.openai.azure.com/",
|
||||
)
|
||||
assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1")
|
||||
|
||||
|
||||
def test_init_validation_missing_key():
|
||||
with pytest.raises(ValueError, match="Azure OpenAI api_key is required"):
|
||||
AzureOpenAIProvider(api_key="", api_base="https://test.com")
|
||||
|
||||
# Missing api_base
|
||||
|
||||
|
||||
def test_init_validation_missing_base():
|
||||
with pytest.raises(ValueError, match="Azure OpenAI api_base is required"):
|
||||
AzureOpenAIProvider(api_key="test", api_base="")
|
||||
|
||||
|
||||
def test_build_chat_url():
|
||||
"""Test Azure OpenAI URL building with different deployment names."""
|
||||
def test_no_api_version_in_base_url():
|
||||
"""The /openai/v1/ path should NOT contain an api-version query param."""
|
||||
provider = AzureOpenAIProvider(api_key="k", api_base="https://res.openai.azure.com")
|
||||
base = str(provider._client.base_url)
|
||||
assert "api-version" not in base
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _supports_temperature
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_supports_temperature_standard_model():
|
||||
assert AzureOpenAIProvider._supports_temperature("gpt-4o") is True
|
||||
|
||||
|
||||
def test_supports_temperature_reasoning_model():
|
||||
assert AzureOpenAIProvider._supports_temperature("o3-mini") is False
|
||||
assert AzureOpenAIProvider._supports_temperature("gpt-5-chat") is False
|
||||
assert AzureOpenAIProvider._supports_temperature("o4-mini") is False
|
||||
|
||||
|
||||
def test_supports_temperature_with_reasoning_effort():
|
||||
assert AzureOpenAIProvider._supports_temperature("gpt-4o", reasoning_effort="medium") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_body — Responses API body construction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_body_basic():
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
api_key="k", api_base="https://res.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Test various deployment names
|
||||
test_cases = [
|
||||
("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"),
|
||||
("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"),
|
||||
("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"),
|
||||
]
|
||||
|
||||
for deployment_name, expected_url in test_cases:
|
||||
url = provider._build_chat_url(deployment_name)
|
||||
assert url == expected_url
|
||||
messages = [{"role": "system", "content": "You are helpful."}, {"role": "user", "content": "Hi"}]
|
||||
body = provider._build_body(messages, None, None, 4096, 0.7, None, None)
|
||||
|
||||
|
||||
def test_build_chat_url_api_base_without_slash():
|
||||
"""Test URL building when api_base doesn't end with slash."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com", # No trailing slash
|
||||
default_model="gpt-4o",
|
||||
assert body["model"] == "gpt-4o"
|
||||
assert body["instructions"] == "You are helpful."
|
||||
assert body["temperature"] == 0.7
|
||||
assert body["store"] is False
|
||||
assert "reasoning" not in body
|
||||
# input should contain the converted user message only (system extracted)
|
||||
assert any(
|
||||
item.get("role") == "user"
|
||||
for item in body["input"]
|
||||
)
|
||||
|
||||
url = provider._build_chat_url("test-deployment")
|
||||
expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert url == expected
|
||||
|
||||
|
||||
def test_build_headers():
|
||||
"""Test Azure OpenAI header building with api-key authentication."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-api-key-123",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
headers = provider._build_headers()
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
assert headers["api-key"] == "test-api-key-123" # Azure OpenAI specific header
|
||||
assert "x-session-affinity" in headers
|
||||
|
||||
|
||||
def test_prepare_request_payload():
|
||||
"""Test request payload preparation with Azure OpenAI 2024-10-21 compliance."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)
|
||||
|
||||
assert payload["messages"] == messages
|
||||
assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens
|
||||
assert payload["temperature"] == 0.8
|
||||
assert "tools" not in payload
|
||||
|
||||
# Test with tools
|
||||
def test_build_body_with_tools():
|
||||
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
|
||||
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
||||
payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
|
||||
assert payload_with_tools["tools"] == tools
|
||||
assert payload_with_tools["tool_choice"] == "auto"
|
||||
|
||||
# Test with reasoning_effort
|
||||
payload_with_reasoning = provider._prepare_request_payload(
|
||||
"gpt-5-chat", messages, reasoning_effort="medium"
|
||||
body = provider._build_body(
|
||||
[{"role": "user", "content": "weather?"}], tools, None, 4096, 0.7, None, None,
|
||||
)
|
||||
assert payload_with_reasoning["reasoning_effort"] == "medium"
|
||||
assert "temperature" not in payload_with_reasoning
|
||||
assert body["tools"] == [{"type": "function", "name": "get_weather", "description": "", "parameters": {}}]
|
||||
assert body["tool_choice"] == "auto"
|
||||
|
||||
|
||||
def test_prepare_request_payload_sanitizes_messages():
|
||||
"""Test Azure payload strips non-standard message keys before sending."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
def test_build_body_with_reasoning():
|
||||
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-5-chat")
|
||||
body = provider._build_body(
|
||||
[{"role": "user", "content": "think"}], None, "gpt-5-chat", 4096, 0.7, "medium", None,
|
||||
)
|
||||
assert body["reasoning"] == {"effort": "medium"}
|
||||
assert "reasoning.encrypted_content" in body.get("include", [])
|
||||
# temperature omitted for reasoning models
|
||||
assert "temperature" not in body
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
|
||||
"reasoning_content": "hidden chain-of-thought",
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"name": "x",
|
||||
"content": "ok",
|
||||
"extra_field": "should be removed",
|
||||
},
|
||||
]
|
||||
|
||||
payload = provider._prepare_request_payload("gpt-4o", messages)
|
||||
def test_build_body_image_conversion():
|
||||
"""image_url content blocks should be converted to input_image."""
|
||||
provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o")
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{"type": "image_url", "image_url": {"url": "https://example.com/img.png"}},
|
||||
],
|
||||
}]
|
||||
body = provider._build_body(messages, None, None, 4096, 0.7, None, None)
|
||||
user_item = body["input"][0]
|
||||
content_types = [b["type"] for b in user_item["content"]]
|
||||
assert "input_text" in content_types
|
||||
assert "input_image" in content_types
|
||||
image_block = next(b for b in user_item["content"] if b["type"] == "input_image")
|
||||
assert image_block["image_url"] == "https://example.com/img.png"
|
||||
|
||||
assert payload["messages"] == [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# chat() — non-streaming
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_sdk_response(
|
||||
content="Hello!", tool_calls=None, status="completed",
|
||||
usage=None,
|
||||
):
|
||||
"""Build a mock that quacks like an openai Response object."""
|
||||
resp = MagicMock()
|
||||
resp.model_dump = MagicMock(return_value={
|
||||
"output": [
|
||||
{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": content}]},
|
||||
*([{
|
||||
"type": "function_call",
|
||||
"call_id": tc["call_id"], "id": tc["id"],
|
||||
"name": tc["name"], "arguments": tc["arguments"],
|
||||
} for tc in (tool_calls or [])]),
|
||||
],
|
||||
"status": status,
|
||||
"usage": {
|
||||
"input_tokens": (usage or {}).get("input_tokens", 10),
|
||||
"output_tokens": (usage or {}).get("output_tokens", 5),
|
||||
"total_tokens": (usage or {}).get("total_tokens", 15),
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"name": "x",
|
||||
"content": "ok",
|
||||
},
|
||||
]
|
||||
})
|
||||
return resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_success():
|
||||
"""Test successful chat request using model as deployment name."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o-deployment",
|
||||
api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Mock response data
|
||||
mock_response_data = {
|
||||
"choices": [{
|
||||
"message": {
|
||||
"content": "Hello! How can I help you today?",
|
||||
"role": "assistant"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 12,
|
||||
"completion_tokens": 18,
|
||||
"total_tokens": 30
|
||||
}
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = Mock(return_value=mock_response_data)
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
# Test with specific model (deployment name)
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await provider.chat(messages, model="custom-deployment")
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert result.content == "Hello! How can I help you today?"
|
||||
assert result.finish_reason == "stop"
|
||||
assert result.usage["prompt_tokens"] == 12
|
||||
assert result.usage["completion_tokens"] == 18
|
||||
assert result.usage["total_tokens"] == 30
|
||||
|
||||
# Verify URL was built with the provided model as deployment name
|
||||
call_args = mock_context.post.call_args
|
||||
expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert call_args[0][0] == expected_url
|
||||
mock_resp = _make_sdk_response(content="Hello!")
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_resp)
|
||||
|
||||
result = await provider.chat([{"role": "user", "content": "Hi"}])
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert result.content == "Hello!"
|
||||
assert result.finish_reason == "stop"
|
||||
assert result.usage["prompt_tokens"] == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_uses_default_model_when_no_model_provided():
|
||||
"""Test that chat uses default_model when no model is specified."""
|
||||
async def test_chat_uses_default_model():
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="default-deployment",
|
||||
api_key="k", api_base="https://test.openai.azure.com", default_model="my-deployment",
|
||||
)
|
||||
|
||||
mock_response_data = {
|
||||
"choices": [{
|
||||
"message": {"content": "Response", "role": "assistant"},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = Mock(return_value=mock_response_data)
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "Test"}]
|
||||
await provider.chat(messages) # No model specified
|
||||
|
||||
# Verify URL was built with default model as deployment name
|
||||
call_args = mock_context.post.call_args
|
||||
expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert call_args[0][0] == expected_url
|
||||
mock_resp = _make_sdk_response(content="ok")
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_resp)
|
||||
|
||||
await provider.chat([{"role": "user", "content": "test"}])
|
||||
|
||||
call_kwargs = provider._client.responses.create.call_args[1]
|
||||
assert call_kwargs["model"] == "my-deployment"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_custom_model():
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
mock_resp = _make_sdk_response(content="ok")
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_resp)
|
||||
|
||||
await provider.chat([{"role": "user", "content": "test"}], model="custom-deploy")
|
||||
|
||||
call_kwargs = provider._client.responses.create.call_args[1]
|
||||
assert call_kwargs["model"] == "custom-deploy"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_tool_calls():
|
||||
"""Test chat request with tool calls in response."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Mock response with tool calls
|
||||
mock_response_data = {
|
||||
"choices": [{
|
||||
"message": {
|
||||
"content": None,
|
||||
"role": "assistant",
|
||||
"tool_calls": [{
|
||||
"id": "call_12345",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "San Francisco"}'
|
||||
}
|
||||
}]
|
||||
},
|
||||
"finish_reason": "tool_calls"
|
||||
mock_resp = _make_sdk_response(
|
||||
content=None,
|
||||
tool_calls=[{
|
||||
"call_id": "call_123", "id": "fc_1",
|
||||
"name": "get_weather", "arguments": '{"location": "SF"}',
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 20,
|
||||
"completion_tokens": 15,
|
||||
"total_tokens": 35
|
||||
}
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = Mock(return_value=mock_response_data)
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "What's the weather?"}]
|
||||
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
||||
result = await provider.chat(messages, tools=tools, model="weather-model")
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert result.content is None
|
||||
assert result.finish_reason == "tool_calls"
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].name == "get_weather"
|
||||
assert result.tool_calls[0].arguments == {"location": "San Francisco"}
|
||||
)
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_resp)
|
||||
|
||||
result = await provider.chat(
|
||||
[{"role": "user", "content": "Weather?"}],
|
||||
tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
|
||||
)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].name == "get_weather"
|
||||
assert result.tool_calls[0].arguments == {"location": "SF"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_api_error():
|
||||
"""Test chat request API error handling."""
|
||||
async def test_chat_error_handling():
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.text = "Invalid authentication credentials"
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await provider.chat(messages)
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert "Azure OpenAI API Error 401" in result.content
|
||||
assert "Invalid authentication credentials" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(side_effect=Exception("Connection failed"))
|
||||
|
||||
result = await provider.chat([{"role": "user", "content": "Hi"}])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_connection_error():
|
||||
"""Test chat request connection error handling."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(side_effect=Exception("Connection failed"))
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await provider.chat(messages)
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
|
||||
|
||||
def test_parse_response_malformed():
|
||||
"""Test response parsing with malformed data."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Test with missing choices
|
||||
malformed_response = {"usage": {"prompt_tokens": 10}}
|
||||
result = provider._parse_response(malformed_response)
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert "Error parsing Azure OpenAI response" in result.content
|
||||
assert "Connection failed" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_reasoning_param_format():
|
||||
"""reasoning_effort should be sent as reasoning={effort: ...} not a flat string."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-5-chat",
|
||||
)
|
||||
mock_resp = _make_sdk_response(content="thought")
|
||||
provider._client.responses = MagicMock()
|
||||
provider._client.responses.create = AsyncMock(return_value=mock_resp)
|
||||
|
||||
await provider.chat(
|
||||
[{"role": "user", "content": "think"}], reasoning_effort="medium",
|
||||
)
|
||||
|
||||
call_kwargs = provider._client.responses.create.call_args[1]
|
||||
assert call_kwargs["reasoning"] == {"effort": "medium"}
|
||||
assert "reasoning_effort" not in call_kwargs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# chat_stream()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_stream_success():
|
||||
"""Streaming should call on_content_delta and return combined response."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Build SSE lines for the mock httpx stream
|
||||
sse_events = [
|
||||
'event: response.output_text.delta',
|
||||
'data: {"type":"response.output_text.delta","delta":"Hello"}',
|
||||
'',
|
||||
'event: response.output_text.delta',
|
||||
'data: {"type":"response.output_text.delta","delta":" world"}',
|
||||
'',
|
||||
'event: response.completed',
|
||||
'data: {"type":"response.completed","response":{"status":"completed"}}',
|
||||
'',
|
||||
]
|
||||
|
||||
deltas: list[str] = []
|
||||
|
||||
async def on_delta(text: str) -> None:
|
||||
deltas.append(text)
|
||||
|
||||
# Mock httpx stream
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
async def aiter_lines():
|
||||
for line in sse_events:
|
||||
yield line
|
||||
|
||||
mock_response.aiter_lines = aiter_lines
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_ctx = AsyncMock()
|
||||
mock_stream_ctx = AsyncMock()
|
||||
mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_stream_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_ctx.stream = MagicMock(return_value=mock_stream_ctx)
|
||||
mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx)
|
||||
mock_client.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
result = await provider.chat_stream(
|
||||
[{"role": "user", "content": "Hi"}], on_content_delta=on_delta,
|
||||
)
|
||||
|
||||
assert result.content == "Hello world"
|
||||
assert result.finish_reason == "stop"
|
||||
assert deltas == ["Hello", " world"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_stream_with_tool_calls():
|
||||
"""Streaming tool calls should be accumulated correctly."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
|
||||
sse_events = [
|
||||
'data: {"type":"response.output_item.added","item":{"type":"function_call","call_id":"call_1","id":"fc_1","name":"get_weather","arguments":""}}',
|
||||
'',
|
||||
'data: {"type":"response.function_call_arguments.delta","call_id":"call_1","delta":"{\\"loc"}',
|
||||
'',
|
||||
'data: {"type":"response.function_call_arguments.done","call_id":"call_1","arguments":"{\\"location\\":\\"SF\\"}"}',
|
||||
'',
|
||||
'data: {"type":"response.output_item.done","item":{"type":"function_call","call_id":"call_1","id":"fc_1","name":"get_weather","arguments":"{\\"location\\":\\"SF\\"}"}}',
|
||||
'',
|
||||
'data: {"type":"response.completed","response":{"status":"completed"}}',
|
||||
'',
|
||||
]
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
async def aiter_lines():
|
||||
for line in sse_events:
|
||||
yield line
|
||||
|
||||
mock_response.aiter_lines = aiter_lines
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_ctx = AsyncMock()
|
||||
mock_stream_ctx = AsyncMock()
|
||||
mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_stream_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_ctx.stream = MagicMock(return_value=mock_stream_ctx)
|
||||
mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx)
|
||||
mock_client.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
result = await provider.chat_stream(
|
||||
[{"role": "user", "content": "weather?"}],
|
||||
tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
|
||||
)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].name == "get_weather"
|
||||
assert result.tool_calls[0].arguments == {"location": "SF"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_stream_http_error():
|
||||
"""Streaming should return error on non-200 status."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o",
|
||||
)
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.aread = AsyncMock(return_value=b"Unauthorized")
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_ctx = AsyncMock()
|
||||
mock_stream_ctx = AsyncMock()
|
||||
mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_stream_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_ctx.stream = MagicMock(return_value=mock_stream_ctx)
|
||||
mock_client.return_value.__aenter__ = AsyncMock(return_value=mock_ctx)
|
||||
mock_client.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
result = await provider.chat_stream([{"role": "user", "content": "Hi"}])
|
||||
|
||||
assert "401" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_default_model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_default_model():
|
||||
"""Test get_default_model method."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="my-custom-deployment",
|
||||
api_key="k", api_base="https://r.com", default_model="my-deploy",
|
||||
)
|
||||
|
||||
assert provider.get_default_model() == "my-custom-deployment"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run basic tests
|
||||
print("Running basic Azure OpenAI provider tests...")
|
||||
|
||||
# Test initialization
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o-deployment",
|
||||
)
|
||||
print("✅ Provider initialization successful")
|
||||
|
||||
# Test URL building
|
||||
url = provider._build_chat_url("my-deployment")
|
||||
expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert url == expected
|
||||
print("✅ URL building works correctly")
|
||||
|
||||
# Test headers
|
||||
headers = provider._build_headers()
|
||||
assert headers["api-key"] == "test-key"
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
print("✅ Header building works correctly")
|
||||
|
||||
# Test payload preparation
|
||||
messages = [{"role": "user", "content": "Test"}]
|
||||
payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
|
||||
assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format
|
||||
print("✅ Payload preparation works correctly")
|
||||
|
||||
print("✅ All basic tests passed! Updated test file is working correctly.")
|
||||
assert provider.get_default_model() == "my-deploy"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user