mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 17:32:39 +00:00
Add tests and handle json
This commit is contained in:
parent
ac2ee58791
commit
e206cffd7a
@ -7,6 +7,7 @@ from collections.abc import Awaitable, Callable
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import httpx
|
||||
import json_repair
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
@ -27,24 +28,36 @@ def map_finish_reason(status: str | None) -> str:
|
||||
async def iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
|
||||
"""Yield parsed JSON events from a Responses API SSE stream."""
|
||||
buffer: list[str] = []
|
||||
|
||||
def _flush() -> dict[str, Any] | None:
|
||||
data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")]
|
||||
buffer.clear()
|
||||
if not data_lines:
|
||||
return None
|
||||
data = "\n".join(data_lines).strip()
|
||||
if not data or data == "[DONE]":
|
||||
return None
|
||||
try:
|
||||
return json.loads(data)
|
||||
except Exception:
|
||||
logger.warning("Failed to parse SSE event JSON: {}", data[:200])
|
||||
return None
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if line == "":
|
||||
if buffer:
|
||||
data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")]
|
||||
buffer = []
|
||||
if not data_lines:
|
||||
continue
|
||||
data = "\n".join(data_lines).strip()
|
||||
if not data or data == "[DONE]":
|
||||
continue
|
||||
try:
|
||||
yield json.loads(data)
|
||||
except Exception:
|
||||
logger.warning("Failed to parse SSE event JSON: {}", data[:200])
|
||||
continue
|
||||
event = _flush()
|
||||
if event is not None:
|
||||
yield event
|
||||
continue
|
||||
buffer.append(line)
|
||||
|
||||
# Flush any remaining buffer at EOF (#10)
|
||||
if buffer:
|
||||
event = _flush()
|
||||
if event is not None:
|
||||
yield event
|
||||
|
||||
|
||||
async def consume_sse(
|
||||
response: httpx.Response,
|
||||
@ -95,11 +108,13 @@ async def consume_sse(
|
||||
except Exception:
|
||||
logger.warning("Failed to parse tool call arguments for '{}': {}",
|
||||
buf.get("name") or item.get("name"), args_raw[:200])
|
||||
args = {"raw": args_raw}
|
||||
args = json_repair.loads(args_raw)
|
||||
if not isinstance(args, dict):
|
||||
args = {"raw": args_raw}
|
||||
tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
|
||||
name=buf.get("name") or item.get("name"),
|
||||
name=buf.get("name") or item.get("name") or "",
|
||||
arguments=args,
|
||||
)
|
||||
)
|
||||
@ -107,7 +122,8 @@ async def consume_sse(
|
||||
status = (event.get("response") or {}).get("status")
|
||||
finish_reason = map_finish_reason(status)
|
||||
elif event_type in {"error", "response.failed"}:
|
||||
raise RuntimeError("Response failed")
|
||||
detail = event.get("error") or event.get("message") or event
|
||||
raise RuntimeError(f"Response failed: {str(detail)[:500]}")
|
||||
|
||||
return content, tool_calls, finish_reason
|
||||
|
||||
@ -158,7 +174,9 @@ def parse_response_output(response: Any) -> LLMResponse:
|
||||
except Exception:
|
||||
logger.warning("Failed to parse tool call arguments for '{}': {}",
|
||||
item.get("name"), str(args_raw)[:200])
|
||||
args = {"raw": args_raw}
|
||||
args = json_repair.loads(args_raw) if isinstance(args_raw, str) else args_raw
|
||||
if not isinstance(args, dict):
|
||||
args = {"raw": args_raw}
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=f"{call_id}|{item_id}",
|
||||
name=item.get("name") or "",
|
||||
@ -246,11 +264,13 @@ async def consume_sdk_stream(
|
||||
logger.warning("Failed to parse tool call arguments for '{}': {}",
|
||||
buf.get("name") or getattr(item, "name", None),
|
||||
str(args_raw)[:200])
|
||||
args = {"raw": args_raw}
|
||||
args = json_repair.loads(args_raw)
|
||||
if not isinstance(args, dict):
|
||||
args = {"raw": args_raw}
|
||||
tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=f"{call_id}|{buf.get('id') or getattr(item, 'id', None) or 'fc_0'}",
|
||||
name=buf.get("name") or getattr(item, "name", None),
|
||||
name=buf.get("name") or getattr(item, "name", None) or "",
|
||||
arguments=args,
|
||||
)
|
||||
)
|
||||
@ -276,6 +296,7 @@ async def consume_sdk_stream(
|
||||
if text:
|
||||
reasoning_content = (reasoning_content or "") + text
|
||||
elif event_type in {"error", "response.failed"}:
|
||||
raise RuntimeError("Response failed")
|
||||
detail = getattr(event, "error", None) or getattr(event, "message", None) or event
|
||||
raise RuntimeError(f"Response failed: {str(detail)[:500]}")
|
||||
|
||||
return content, tool_calls, finish_reason, usage, reasoning_content
|
||||
|
||||
@ -492,22 +492,22 @@ class TestConsumeSdkStream:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_event_raises(self):
|
||||
ev = MagicMock(type="error")
|
||||
ev = MagicMock(type="error", error="rate_limit_exceeded")
|
||||
|
||||
async def stream():
|
||||
yield ev
|
||||
|
||||
with pytest.raises(RuntimeError, match="Response failed"):
|
||||
with pytest.raises(RuntimeError, match="Response failed.*rate_limit_exceeded"):
|
||||
await consume_sdk_stream(stream())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failed_event_raises(self):
|
||||
ev = MagicMock(type="response.failed")
|
||||
ev = MagicMock(type="response.failed", error="server_error")
|
||||
|
||||
async def stream():
|
||||
yield ev
|
||||
|
||||
with pytest.raises(RuntimeError, match="Response failed"):
|
||||
with pytest.raises(RuntimeError, match="Response failed.*server_error"):
|
||||
await consume_sdk_stream(stream())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user