Fill up gaps

This commit is contained in:
Kunal Karmakar 2026-03-31 02:29:40 +00:00 committed by Xubin Ren
parent 8c0607e079
commit 7c44aa92ca
2 changed files with 37 additions and 5 deletions

View File

@@ -160,13 +160,15 @@ class AzureOpenAIProvider(LLMProvider):
try:
stream = await self._client.responses.create(**body)
content, tool_calls, finish_reason = await consume_sdk_stream(
stream, on_content_delta,
content, tool_calls, finish_reason, usage, reasoning_content = (
await consume_sdk_stream(stream, on_content_delta)
)
return LLMResponse(
content=content or None,
tool_calls=tool_calls,
finish_reason=finish_reason,
usage=usage,
reasoning_content=reasoning_content,
)
except Exception as e:
return self._handle_error(e)

View File

@@ -122,6 +122,7 @@ def parse_response_output(response: Any) -> LLMResponse:
output = response.get("output") or []
content_parts: list[str] = []
tool_calls: list[ToolCallRequest] = []
reasoning_content: str | None = None
for item in output:
if not isinstance(item, dict):
@@ -136,6 +137,14 @@ def parse_response_output(response: Any) -> LLMResponse:
block = dump() if callable(dump) else vars(block)
if block.get("type") == "output_text":
content_parts.append(block.get("text") or "")
elif item_type == "reasoning":
# Reasoning items may have a summary list with text blocks
for s in item.get("summary") or []:
if not isinstance(s, dict):
dump = getattr(s, "model_dump", None)
s = dump() if callable(dump) else vars(s)
if s.get("type") == "summary_text" and s.get("text"):
reasoning_content = (reasoning_content or "") + s["text"]
elif item_type == "function_call":
call_id = item.get("call_id") or ""
item_id = item.get("id") or "fc_0"
@ -170,22 +179,26 @@ def parse_response_output(response: Any) -> LLMResponse:
tool_calls=tool_calls,
finish_reason=finish_reason,
usage=usage,
reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None,
)
async def consume_sdk_stream(
stream: Any,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str]:
) -> tuple[str, list[ToolCallRequest], str, dict[str, int], str | None]:
"""Consume an SDK async stream from ``client.responses.create(stream=True)``.
The SDK yields typed event objects with a ``.type`` attribute and
event-specific fields. Returns ``(content, tool_calls, finish_reason)``.
event-specific fields. Returns
``(content, tool_calls, finish_reason, usage, reasoning_content)``.
"""
content = ""
tool_calls: list[ToolCallRequest] = []
tool_call_buffers: dict[str, dict[str, Any]] = {}
finish_reason = "stop"
usage: dict[str, int] = {}
reasoning_content: str | None = None
async for event in stream:
event_type = getattr(event, "type", None)
@@ -236,7 +249,24 @@ async def consume_sdk_stream(
resp = getattr(event, "response", None)
status = getattr(resp, "status", None) if resp else None
finish_reason = map_finish_reason(status)
# Extract usage from the completed response
if resp:
usage_obj = getattr(resp, "usage", None)
if usage_obj:
usage = {
"prompt_tokens": int(getattr(usage_obj, "input_tokens", 0) or 0),
"completion_tokens": int(getattr(usage_obj, "output_tokens", 0) or 0),
"total_tokens": int(getattr(usage_obj, "total_tokens", 0) or 0),
}
# Extract reasoning_content from completed output items
for out_item in getattr(resp, "output", None) or []:
if getattr(out_item, "type", None) == "reasoning":
for s in getattr(out_item, "summary", None) or []:
if getattr(s, "type", None) == "summary_text":
text = getattr(s, "text", None)
if text:
reasoning_content = (reasoning_content or "") + text
elif event_type in {"error", "response.failed"}:
raise RuntimeError("Response failed")
return content, tool_calls, finish_reason
return content, tool_calls, finish_reason, usage, reasoning_content