Add tests and handle json

This commit is contained in:
Kunal Karmakar 2026-03-31 08:37:41 +00:00 committed by Xubin Ren
parent ac2ee58791
commit e206cffd7a
2 changed files with 44 additions and 23 deletions

View File

@ -7,6 +7,7 @@ from collections.abc import Awaitable, Callable
from typing import Any, AsyncGenerator
import httpx
import json_repair
from loguru import logger
from nanobot.providers.base import LLMResponse, ToolCallRequest
@ -27,24 +28,36 @@ def map_finish_reason(status: str | None) -> str:
async def iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
"""Yield parsed JSON events from a Responses API SSE stream."""
buffer: list[str] = []
def _flush() -> dict[str, Any] | None:
data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")]
buffer.clear()
if not data_lines:
return None
data = "\n".join(data_lines).strip()
if not data or data == "[DONE]":
return None
try:
return json.loads(data)
except Exception:
logger.warning("Failed to parse SSE event JSON: {}", data[:200])
return None
async for line in response.aiter_lines():
if line == "":
if buffer:
data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")]
buffer = []
if not data_lines:
continue
data = "\n".join(data_lines).strip()
if not data or data == "[DONE]":
continue
try:
yield json.loads(data)
except Exception:
logger.warning("Failed to parse SSE event JSON: {}", data[:200])
continue
event = _flush()
if event is not None:
yield event
continue
buffer.append(line)
# Flush any remaining buffer at EOF (#10)
if buffer:
event = _flush()
if event is not None:
yield event
async def consume_sse(
response: httpx.Response,
@ -95,11 +108,13 @@ async def consume_sse(
except Exception:
logger.warning("Failed to parse tool call arguments for '{}': {}",
buf.get("name") or item.get("name"), args_raw[:200])
args = {"raw": args_raw}
args = json_repair.loads(args_raw)
if not isinstance(args, dict):
args = {"raw": args_raw}
tool_calls.append(
ToolCallRequest(
id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
name=buf.get("name") or item.get("name"),
name=buf.get("name") or item.get("name") or "",
arguments=args,
)
)
@ -107,7 +122,8 @@ async def consume_sse(
status = (event.get("response") or {}).get("status")
finish_reason = map_finish_reason(status)
elif event_type in {"error", "response.failed"}:
raise RuntimeError("Response failed")
detail = event.get("error") or event.get("message") or event
raise RuntimeError(f"Response failed: {str(detail)[:500]}")
return content, tool_calls, finish_reason
@ -158,7 +174,9 @@ def parse_response_output(response: Any) -> LLMResponse:
except Exception:
logger.warning("Failed to parse tool call arguments for '{}': {}",
item.get("name"), str(args_raw)[:200])
args = {"raw": args_raw}
args = json_repair.loads(args_raw) if isinstance(args_raw, str) else args_raw
if not isinstance(args, dict):
args = {"raw": args_raw}
tool_calls.append(ToolCallRequest(
id=f"{call_id}|{item_id}",
name=item.get("name") or "",
@ -246,11 +264,13 @@ async def consume_sdk_stream(
logger.warning("Failed to parse tool call arguments for '{}': {}",
buf.get("name") or getattr(item, "name", None),
str(args_raw)[:200])
args = {"raw": args_raw}
args = json_repair.loads(args_raw)
if not isinstance(args, dict):
args = {"raw": args_raw}
tool_calls.append(
ToolCallRequest(
id=f"{call_id}|{buf.get('id') or getattr(item, 'id', None) or 'fc_0'}",
name=buf.get("name") or getattr(item, "name", None),
name=buf.get("name") or getattr(item, "name", None) or "",
arguments=args,
)
)
@ -276,6 +296,7 @@ async def consume_sdk_stream(
if text:
reasoning_content = (reasoning_content or "") + text
elif event_type in {"error", "response.failed"}:
raise RuntimeError("Response failed")
detail = getattr(event, "error", None) or getattr(event, "message", None) or event
raise RuntimeError(f"Response failed: {str(detail)[:500]}")
return content, tool_calls, finish_reason, usage, reasoning_content

View File

@ -492,22 +492,22 @@ class TestConsumeSdkStream:
@pytest.mark.asyncio
async def test_error_event_raises(self):
ev = MagicMock(type="error")
ev = MagicMock(type="error", error="rate_limit_exceeded")
async def stream():
yield ev
with pytest.raises(RuntimeError, match="Response failed"):
with pytest.raises(RuntimeError, match="Response failed.*rate_limit_exceeded"):
await consume_sdk_stream(stream())
@pytest.mark.asyncio
async def test_failed_event_raises(self):
ev = MagicMock(type="response.failed")
ev = MagicMock(type="response.failed", error="server_error")
async def stream():
yield ev
with pytest.raises(RuntimeError, match="Response failed"):
with pytest.raises(RuntimeError, match="Response failed.*server_error"):
await consume_sdk_stream(stream())
@pytest.mark.asyncio