Fill up gaps

This commit is contained in:
Kunal Karmakar 2026-03-31 02:29:40 +00:00 committed by Xubin Ren
parent 8c0607e079
commit 7c44aa92ca
2 changed files with 37 additions and 5 deletions

View File

@@ -160,13 +160,15 @@ class AzureOpenAIProvider(LLMProvider):
         try:
             stream = await self._client.responses.create(**body)
-            content, tool_calls, finish_reason = await consume_sdk_stream(
-                stream, on_content_delta,
+            content, tool_calls, finish_reason, usage, reasoning_content = (
+                await consume_sdk_stream(stream, on_content_delta)
             )
             return LLMResponse(
                 content=content or None,
                 tool_calls=tool_calls,
                 finish_reason=finish_reason,
+                usage=usage,
+                reasoning_content=reasoning_content,
             )
         except Exception as e:
             return self._handle_error(e)
 

View File

@@ -122,6 +122,7 @@ def parse_response_output(response: Any) -> LLMResponse:
     output = response.get("output") or []
     content_parts: list[str] = []
     tool_calls: list[ToolCallRequest] = []
+    reasoning_content: str | None = None
 
     for item in output:
         if not isinstance(item, dict):
@@ -136,6 +137,14 @@ def parse_response_output(response: Any) -> LLMResponse:
                     block = dump() if callable(dump) else vars(block)
                 if block.get("type") == "output_text":
                     content_parts.append(block.get("text") or "")
+        elif item_type == "reasoning":
+            # Reasoning items may have a summary list with text blocks
+            for s in item.get("summary") or []:
+                if not isinstance(s, dict):
+                    dump = getattr(s, "model_dump", None)
+                    s = dump() if callable(dump) else vars(s)
+                if s.get("type") == "summary_text" and s.get("text"):
+                    reasoning_content = (reasoning_content or "") + s["text"]
         elif item_type == "function_call":
             call_id = item.get("call_id") or ""
             item_id = item.get("id") or "fc_0"
@@ -170,22 +179,26 @@ def parse_response_output(response: Any) -> LLMResponse:
         tool_calls=tool_calls,
         finish_reason=finish_reason,
         usage=usage,
+        reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None,
     )
 
 
 async def consume_sdk_stream(
     stream: Any,
     on_content_delta: Callable[[str], Awaitable[None]] | None = None,
-) -> tuple[str, list[ToolCallRequest], str]:
+) -> tuple[str, list[ToolCallRequest], str, dict[str, int], str | None]:
     """Consume an SDK async stream from ``client.responses.create(stream=True)``.
 
     The SDK yields typed event objects with a ``.type`` attribute and
-    event-specific fields. Returns ``(content, tool_calls, finish_reason)``.
+    event-specific fields. Returns
+    ``(content, tool_calls, finish_reason, usage, reasoning_content)``.
     """
     content = ""
     tool_calls: list[ToolCallRequest] = []
     tool_call_buffers: dict[str, dict[str, Any]] = {}
     finish_reason = "stop"
+    usage: dict[str, int] = {}
+    reasoning_content: str | None = None
 
     async for event in stream:
         event_type = getattr(event, "type", None)
@@ -236,7 +249,24 @@ async def consume_sdk_stream(
             resp = getattr(event, "response", None)
             status = getattr(resp, "status", None) if resp else None
             finish_reason = map_finish_reason(status)
+            # Extract usage from the completed response
+            if resp:
+                usage_obj = getattr(resp, "usage", None)
+                if usage_obj:
+                    usage = {
+                        "prompt_tokens": int(getattr(usage_obj, "input_tokens", 0) or 0),
+                        "completion_tokens": int(getattr(usage_obj, "output_tokens", 0) or 0),
+                        "total_tokens": int(getattr(usage_obj, "total_tokens", 0) or 0),
+                    }
+                # Extract reasoning_content from completed output items
+                for out_item in getattr(resp, "output", None) or []:
+                    if getattr(out_item, "type", None) == "reasoning":
+                        for s in getattr(out_item, "summary", None) or []:
+                            if getattr(s, "type", None) == "summary_text":
+                                text = getattr(s, "text", None)
+                                if text:
+                                    reasoning_content = (reasoning_content or "") + text
         elif event_type in {"error", "response.failed"}:
             raise RuntimeError("Response failed")
 
-    return content, tool_calls, finish_reason
+    return content, tool_calls, finish_reason, usage, reasoning_content