mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-09 11:15:55 +00:00
feat(api): add SSE streaming for /v1/chat/completions Wire up the existing on_stream/on_stream_end callbacks from process_direct() to emit OpenAI-compatible SSE chunks when stream=true. Non-streaming path is untouched.
This commit is contained in:
parent
db78574cb8
commit
4fce8d8b8d
@ -8,6 +8,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
|
import json as _json
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
@ -71,6 +72,30 @@ def _response_text(value: Any) -> str:
|
|||||||
return str(getattr(value, "content") or "")
|
return str(getattr(value, "content") or "")
|
||||||
return str(value)
|
return str(value)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# SSE helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _sse_chunk(delta: str, model: str, chunk_id: str, finish_reason: str | None = None) -> bytes:
|
||||||
|
"""Format a single OpenAI-compatible SSE chunk."""
|
||||||
|
payload = {
|
||||||
|
"id": chunk_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": int(time.time()),
|
||||||
|
"model": model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"content": delta} if delta else {},
|
||||||
|
"finish_reason": finish_reason,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
return f"data: {_json.dumps(payload)}\n\n".encode()
|
||||||
|
|
||||||
|
|
||||||
|
_SSE_DONE = b"data: [DONE]\n\n"
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Upload helpers
|
# Upload helpers
|
||||||
@ -188,6 +213,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
|||||||
timeout_s: float = request.app.get("request_timeout", 120.0)
|
timeout_s: float = request.app.get("request_timeout", 120.0)
|
||||||
model_name: str = request.app.get("model_name", "nanobot")
|
model_name: str = request.app.get("model_name", "nanobot")
|
||||||
|
|
||||||
|
stream = False
|
||||||
try:
|
try:
|
||||||
if content_type.startswith("multipart/"):
|
if content_type.startswith("multipart/"):
|
||||||
text, media_paths, session_id, requested_model = await _parse_multipart(request)
|
text, media_paths, session_id, requested_model = await _parse_multipart(request)
|
||||||
@ -196,10 +222,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
|||||||
body = await request.json()
|
body = await request.json()
|
||||||
except Exception:
|
except Exception:
|
||||||
return _error_json(400, "Invalid JSON body")
|
return _error_json(400, "Invalid JSON body")
|
||||||
if body.get("stream", False):
|
stream = body.get("stream", False)
|
||||||
return _error_json(
|
|
||||||
400, "stream=true is not supported yet. Set stream=false or omit it."
|
|
||||||
)
|
|
||||||
requested_model = body.get("model")
|
requested_model = body.get("model")
|
||||||
text, media_paths = _parse_json_content(body)
|
text, media_paths = _parse_json_content(body)
|
||||||
session_id = body.get("session_id")
|
session_id = body.get("session_id")
|
||||||
@ -219,9 +242,61 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
|||||||
session_lock = session_locks.setdefault(session_key, asyncio.Lock())
|
session_lock = session_locks.setdefault(session_key, asyncio.Lock())
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"API request session_key={} media={} text={}", session_key, len(media_paths), text[:80]
|
"API request session_key={} media={} text={} stream={}",
|
||||||
|
session_key, len(media_paths), text[:80], stream,
|
||||||
)
|
)
|
||||||
|
# -- streaming path --
|
||||||
|
if stream:
|
||||||
|
resp = web.StreamResponse()
|
||||||
|
resp.content_type = "text/event-stream"
|
||||||
|
resp.headers["Cache-Control"] = "no-cache"
|
||||||
|
resp.headers["Connection"] = "keep-alive"
|
||||||
|
resp.enable_compression()
|
||||||
|
await resp.prepare(request)
|
||||||
|
|
||||||
|
chunk_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||||
|
queue: asyncio.Queue[str | None] = asyncio.Queue()
|
||||||
|
|
||||||
|
async def _on_stream(token: str) -> None:
|
||||||
|
await queue.put(token)
|
||||||
|
|
||||||
|
async def _on_stream_end(*_a: Any, **_kw: Any) -> None:
|
||||||
|
await queue.put(None)
|
||||||
|
|
||||||
|
async def _run() -> None:
|
||||||
|
try:
|
||||||
|
async with session_lock:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
agent_loop.process_direct(
|
||||||
|
content=text,
|
||||||
|
media=media_paths if media_paths else None,
|
||||||
|
session_key=session_key,
|
||||||
|
channel="api",
|
||||||
|
chat_id=API_CHAT_ID,
|
||||||
|
on_stream=_on_stream,
|
||||||
|
on_stream_end=_on_stream_end,
|
||||||
|
),
|
||||||
|
timeout=timeout_s,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Streaming error for session {}", session_key)
|
||||||
|
await queue.put(None)
|
||||||
|
|
||||||
|
task = asyncio.create_task(_run())
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
token = await queue.get()
|
||||||
|
if token is None:
|
||||||
|
break
|
||||||
|
await resp.write(_sse_chunk(token, model_name, chunk_id))
|
||||||
|
finally:
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
await resp.write(_sse_chunk("", model_name, chunk_id, finish_reason="stop"))
|
||||||
|
await resp.write(_SSE_DONE)
|
||||||
|
return resp
|
||||||
|
|
||||||
|
# -- non-streaming path (original logic) --
|
||||||
_FALLBACK = EMPTY_FINAL_RESPONSE_MESSAGE
|
_FALLBACK = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
253
tests/test_api_stream.py
Normal file
253
tests/test_api_stream.py
Normal file
@ -0,0 +1,253 @@
|
|||||||
|
"""Tests for SSE streaming support in /v1/chat/completions."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from nanobot.api.server import (
|
||||||
|
_sse_chunk,
|
||||||
|
_SSE_DONE,
|
||||||
|
create_app,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from aiohttp.test_utils import TestClient, TestServer
|
||||||
|
|
||||||
|
HAS_AIOHTTP = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_AIOHTTP = False
|
||||||
|
|
||||||
|
pytest_plugins = ("pytest_asyncio",)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Unit tests for SSE helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_sse_chunk_with_delta() -> None:
|
||||||
|
raw = _sse_chunk("hello", "test-model", "chatcmpl-abc123")
|
||||||
|
line = raw.decode()
|
||||||
|
assert line.startswith("data: ")
|
||||||
|
payload = json.loads(line[len("data: "):])
|
||||||
|
assert payload["id"] == "chatcmpl-abc123"
|
||||||
|
assert payload["object"] == "chat.completion.chunk"
|
||||||
|
assert payload["model"] == "test-model"
|
||||||
|
assert payload["choices"][0]["delta"]["content"] == "hello"
|
||||||
|
assert payload["choices"][0]["finish_reason"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_sse_chunk_finish_reason() -> None:
|
||||||
|
raw = _sse_chunk("", "m", "id1", finish_reason="stop")
|
||||||
|
payload = json.loads(raw.decode().split("data: ", 1)[1])
|
||||||
|
assert payload["choices"][0]["delta"] == {}
|
||||||
|
assert payload["choices"][0]["finish_reason"] == "stop"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sse_done_format() -> None:
|
||||||
|
assert _SSE_DONE == b"data: [DONE]\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Integration tests with aiohttp TestClient
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_streaming_agent(tokens: list[str]) -> MagicMock:
|
||||||
|
"""Create a mock agent that streams tokens via on_stream callback."""
|
||||||
|
agent = MagicMock()
|
||||||
|
agent._connect_mcp = AsyncMock()
|
||||||
|
agent.close_mcp = AsyncMock()
|
||||||
|
|
||||||
|
async def fake_process_direct(*, content="", media=None, session_key="",
|
||||||
|
channel="", chat_id="", on_stream=None,
|
||||||
|
on_stream_end=None, **kwargs):
|
||||||
|
if on_stream:
|
||||||
|
for token in tokens:
|
||||||
|
await on_stream(token)
|
||||||
|
if on_stream_end:
|
||||||
|
await on_stream_end()
|
||||||
|
return " ".join(tokens)
|
||||||
|
|
||||||
|
agent.process_direct = fake_process_direct
|
||||||
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def aiohttp_client():
|
||||||
|
clients: list[TestClient] = []
|
||||||
|
|
||||||
|
async def _make_client(app):
|
||||||
|
client = TestClient(TestServer(app))
|
||||||
|
await client.start_server()
|
||||||
|
clients.append(client)
|
||||||
|
return client
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield _make_client
|
||||||
|
finally:
|
||||||
|
for client in clients:
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_true_returns_sse(aiohttp_client) -> None:
|
||||||
|
"""stream=true should return text/event-stream with SSE chunks."""
|
||||||
|
agent = _make_streaming_agent(["Hello", " world"])
|
||||||
|
app = create_app(agent, model_name="test-model")
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
|
||||||
|
resp = await client.post(
|
||||||
|
"/v1/chat/completions",
|
||||||
|
json={"messages": [{"role": "user", "content": "hi"}], "stream": True},
|
||||||
|
)
|
||||||
|
assert resp.status == 200
|
||||||
|
assert resp.content_type == "text/event-stream"
|
||||||
|
|
||||||
|
body = await resp.text()
|
||||||
|
lines = [l for l in body.split("\n") if l.startswith("data: ")]
|
||||||
|
|
||||||
|
# Should have: 2 token chunks + 1 finish chunk + [DONE]
|
||||||
|
data_lines = [l[len("data: "):] for l in lines]
|
||||||
|
assert data_lines[-1] == "[DONE]"
|
||||||
|
|
||||||
|
chunks = [json.loads(l) for l in data_lines[:-1]]
|
||||||
|
assert chunks[0]["choices"][0]["delta"]["content"] == "Hello"
|
||||||
|
assert chunks[1]["choices"][0]["delta"]["content"] == " world"
|
||||||
|
# Last chunk before [DONE] should have finish_reason=stop
|
||||||
|
assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
|
||||||
|
assert chunks[-1]["choices"][0]["delta"] == {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_false_returns_json(aiohttp_client) -> None:
|
||||||
|
"""stream=false should still return regular JSON response."""
|
||||||
|
agent = MagicMock()
|
||||||
|
agent.process_direct = AsyncMock(return_value="normal reply")
|
||||||
|
agent._connect_mcp = AsyncMock()
|
||||||
|
agent.close_mcp = AsyncMock()
|
||||||
|
|
||||||
|
app = create_app(agent, model_name="m")
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
|
||||||
|
resp = await client.post(
|
||||||
|
"/v1/chat/completions",
|
||||||
|
json={"messages": [{"role": "user", "content": "hi"}], "stream": False},
|
||||||
|
)
|
||||||
|
assert resp.status == 200
|
||||||
|
body = await resp.json()
|
||||||
|
assert body["object"] == "chat.completion"
|
||||||
|
assert body["choices"][0]["message"]["content"] == "normal reply"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_default_is_false(aiohttp_client) -> None:
|
||||||
|
"""Omitting stream should behave like stream=false."""
|
||||||
|
agent = MagicMock()
|
||||||
|
agent.process_direct = AsyncMock(return_value="default reply")
|
||||||
|
agent._connect_mcp = AsyncMock()
|
||||||
|
agent.close_mcp = AsyncMock()
|
||||||
|
|
||||||
|
app = create_app(agent, model_name="m")
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
|
||||||
|
resp = await client.post(
|
||||||
|
"/v1/chat/completions",
|
||||||
|
json={"messages": [{"role": "user", "content": "hi"}]},
|
||||||
|
)
|
||||||
|
assert resp.status == 200
|
||||||
|
body = await resp.json()
|
||||||
|
assert body["object"] == "chat.completion"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_sse_chunk_ids_are_consistent(aiohttp_client) -> None:
|
||||||
|
"""All SSE chunks in a single stream should share the same id."""
|
||||||
|
agent = _make_streaming_agent(["A", "B", "C"])
|
||||||
|
app = create_app(agent, model_name="m")
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
|
||||||
|
resp = await client.post(
|
||||||
|
"/v1/chat/completions",
|
||||||
|
json={"messages": [{"role": "user", "content": "go"}], "stream": True},
|
||||||
|
)
|
||||||
|
body = await resp.text()
|
||||||
|
data_lines = [l[len("data: "):] for l in body.split("\n") if l.startswith("data: ") and l != "data: [DONE]"]
|
||||||
|
chunks = [json.loads(l) for l in data_lines]
|
||||||
|
|
||||||
|
chunk_ids = {c["id"] for c in chunks}
|
||||||
|
assert len(chunk_ids) == 1, f"Expected single chunk id, got {chunk_ids}"
|
||||||
|
assert chunk_ids.pop().startswith("chatcmpl-")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_passes_on_stream_callbacks(aiohttp_client) -> None:
|
||||||
|
"""process_direct should be called with on_stream and on_stream_end when streaming."""
|
||||||
|
captured_kwargs: dict = {}
|
||||||
|
|
||||||
|
async def fake_process_direct(**kwargs):
|
||||||
|
captured_kwargs.update(kwargs)
|
||||||
|
if kwargs.get("on_stream_end"):
|
||||||
|
await kwargs["on_stream_end"]()
|
||||||
|
return "done"
|
||||||
|
|
||||||
|
agent = MagicMock()
|
||||||
|
agent.process_direct = fake_process_direct
|
||||||
|
agent._connect_mcp = AsyncMock()
|
||||||
|
agent.close_mcp = AsyncMock()
|
||||||
|
|
||||||
|
app = create_app(agent, model_name="m")
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
|
||||||
|
resp = await client.post(
|
||||||
|
"/v1/chat/completions",
|
||||||
|
json={"messages": [{"role": "user", "content": "hi"}], "stream": True},
|
||||||
|
)
|
||||||
|
assert resp.status == 200
|
||||||
|
assert captured_kwargs.get("on_stream") is not None
|
||||||
|
assert captured_kwargs.get("on_stream_end") is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_with_session_id(aiohttp_client) -> None:
|
||||||
|
"""Streaming should respect session_id for session key routing."""
|
||||||
|
captured_key: str = ""
|
||||||
|
|
||||||
|
async def fake_process_direct(*, session_key="", on_stream=None, on_stream_end=None, **kwargs):
|
||||||
|
nonlocal captured_key
|
||||||
|
captured_key = session_key
|
||||||
|
if on_stream:
|
||||||
|
await on_stream("ok")
|
||||||
|
if on_stream_end:
|
||||||
|
await on_stream_end()
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
agent = MagicMock()
|
||||||
|
agent.process_direct = fake_process_direct
|
||||||
|
agent._connect_mcp = AsyncMock()
|
||||||
|
agent.close_mcp = AsyncMock()
|
||||||
|
|
||||||
|
app = create_app(agent, model_name="m")
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
|
||||||
|
resp = await client.post(
|
||||||
|
"/v1/chat/completions",
|
||||||
|
json={
|
||||||
|
"messages": [{"role": "user", "content": "hi"}],
|
||||||
|
"stream": True,
|
||||||
|
"session_id": "my-session",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status == 200
|
||||||
|
assert captured_key == "api:my-session"
|
||||||
@ -101,15 +101,14 @@ async def test_no_user_message_returns_400(aiohttp_client, app) -> None:
|
|||||||
|
|
||||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stream_true_returns_400(aiohttp_client, app) -> None:
|
async def test_stream_true_returns_sse(aiohttp_client, app) -> None:
|
||||||
client = await aiohttp_client(app)
|
client = await aiohttp_client(app)
|
||||||
resp = await client.post(
|
resp = await client.post(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
json={"messages": [{"role": "user", "content": "hello"}], "stream": True},
|
json={"messages": [{"role": "user", "content": "hello"}], "stream": True},
|
||||||
)
|
)
|
||||||
assert resp.status == 400
|
assert resp.status == 200
|
||||||
body = await resp.json()
|
assert resp.content_type == "text/event-stream"
|
||||||
assert "stream" in body["error"]["message"].lower()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user