feat(api): add SSE streaming for /v1/chat/completions Wire up the existing on_stream/on_stream_end callbacks from process_direct() to emit OpenAI-compatible SSE chunks when stream=true. Non-streaming path is untouched.

This commit is contained in:
whs 2026-04-17 00:41:39 +08:00 committed by Xubin Ren
parent db78574cb8
commit 4fce8d8b8d
3 changed files with 336 additions and 9 deletions

View File

@ -8,6 +8,7 @@ from __future__ import annotations
import asyncio import asyncio
import base64 import base64
import json as _json
import mimetypes import mimetypes
import re import re
import time import time
@ -71,6 +72,30 @@ def _response_text(value: Any) -> str:
return str(getattr(value, "content") or "") return str(getattr(value, "content") or "")
return str(value) return str(value)
# ---------------------------------------------------------------------------
# SSE helpers
# ---------------------------------------------------------------------------
def _sse_chunk(delta: str, model: str, chunk_id: str, finish_reason: str | None = None) -> bytes:
    """Encode one OpenAI-compatible ``chat.completion.chunk`` SSE event.

    Args:
        delta: Token text for this chunk; an empty string produces an empty
            ``delta`` object (used for the final finish chunk).
        model: Model name echoed back in the payload.
        chunk_id: Shared ``id`` for every chunk of one stream.
        finish_reason: ``None`` for intermediate chunks, ``"stop"`` on the last.

    Returns:
        The complete ``data: ...\\n\\n`` frame as bytes.
    """
    choice = {
        "index": 0,
        # OpenAI clients expect delta == {} (not {"content": ""}) when there
        # is no token text.
        "delta": {"content": delta} if delta else {},
        "finish_reason": finish_reason,
    }
    envelope = {
        "id": chunk_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [choice],
    }
    return b"data: " + _json.dumps(envelope).encode() + b"\n\n"
# Terminal SSE sentinel required by the OpenAI streaming protocol.
_SSE_DONE = b"data: [DONE]\n\n"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Upload helpers # Upload helpers
@ -188,6 +213,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
timeout_s: float = request.app.get("request_timeout", 120.0) timeout_s: float = request.app.get("request_timeout", 120.0)
model_name: str = request.app.get("model_name", "nanobot") model_name: str = request.app.get("model_name", "nanobot")
stream = False
try: try:
if content_type.startswith("multipart/"): if content_type.startswith("multipart/"):
text, media_paths, session_id, requested_model = await _parse_multipart(request) text, media_paths, session_id, requested_model = await _parse_multipart(request)
@ -196,10 +222,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
body = await request.json() body = await request.json()
except Exception: except Exception:
return _error_json(400, "Invalid JSON body") return _error_json(400, "Invalid JSON body")
if body.get("stream", False): stream = body.get("stream", False)
return _error_json(
400, "stream=true is not supported yet. Set stream=false or omit it."
)
requested_model = body.get("model") requested_model = body.get("model")
text, media_paths = _parse_json_content(body) text, media_paths = _parse_json_content(body)
session_id = body.get("session_id") session_id = body.get("session_id")
@ -219,9 +242,61 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
session_lock = session_locks.setdefault(session_key, asyncio.Lock()) session_lock = session_locks.setdefault(session_key, asyncio.Lock())
logger.info( logger.info(
"API request session_key={} media={} text={}", session_key, len(media_paths), text[:80] "API request session_key={} media={} text={} stream={}",
session_key, len(media_paths), text[:80], stream,
) )
# -- streaming path --
if stream:
resp = web.StreamResponse()
resp.content_type = "text/event-stream"
resp.headers["Cache-Control"] = "no-cache"
resp.headers["Connection"] = "keep-alive"
resp.enable_compression()
await resp.prepare(request)
chunk_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
queue: asyncio.Queue[str | None] = asyncio.Queue()
async def _on_stream(token: str) -> None:
await queue.put(token)
async def _on_stream_end(*_a: Any, **_kw: Any) -> None:
await queue.put(None)
async def _run() -> None:
try:
async with session_lock:
await asyncio.wait_for(
agent_loop.process_direct(
content=text,
media=media_paths if media_paths else None,
session_key=session_key,
channel="api",
chat_id=API_CHAT_ID,
on_stream=_on_stream,
on_stream_end=_on_stream_end,
),
timeout=timeout_s,
)
except Exception:
logger.exception("Streaming error for session {}", session_key)
await queue.put(None)
task = asyncio.create_task(_run())
try:
while True:
token = await queue.get()
if token is None:
break
await resp.write(_sse_chunk(token, model_name, chunk_id))
finally:
task.cancel()
await resp.write(_sse_chunk("", model_name, chunk_id, finish_reason="stop"))
await resp.write(_SSE_DONE)
return resp
# -- non-streaming path (original logic) --
_FALLBACK = EMPTY_FINAL_RESPONSE_MESSAGE _FALLBACK = EMPTY_FINAL_RESPONSE_MESSAGE
try: try:

253
tests/test_api_stream.py Normal file
View File

@ -0,0 +1,253 @@
"""Tests for SSE streaming support in /v1/chat/completions."""
from __future__ import annotations
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
import pytest_asyncio
from nanobot.api.server import (
_sse_chunk,
_SSE_DONE,
create_app,
)
try:
from aiohttp.test_utils import TestClient, TestServer
HAS_AIOHTTP = True
except ImportError:
HAS_AIOHTTP = False
pytest_plugins = ("pytest_asyncio",)
# ---------------------------------------------------------------------------
# Unit tests for SSE helpers
# ---------------------------------------------------------------------------
def test_sse_chunk_with_delta() -> None:
    """A non-empty delta is framed as a well-formed completion chunk."""
    decoded = _sse_chunk("hello", "test-model", "chatcmpl-abc123").decode()
    assert decoded.startswith("data: ")
    body = json.loads(decoded.split("data: ", 1)[1])
    assert body["id"] == "chatcmpl-abc123"
    assert body["object"] == "chat.completion.chunk"
    assert body["model"] == "test-model"
    choice = body["choices"][0]
    assert choice["delta"]["content"] == "hello"
    assert choice["finish_reason"] is None
def test_sse_chunk_finish_reason() -> None:
    """The finish chunk carries an empty delta and the finish_reason."""
    text = _sse_chunk("", "m", "id1", finish_reason="stop").decode()
    choice = json.loads(text[len("data: "):])["choices"][0]
    assert choice["delta"] == {}
    assert choice["finish_reason"] == "stop"
def test_sse_done_format() -> None:
    """The terminal sentinel must match the OpenAI [DONE] framing exactly."""
    assert _SSE_DONE.decode() == "data: [DONE]\n\n"
# ---------------------------------------------------------------------------
# Integration tests with aiohttp TestClient
# ---------------------------------------------------------------------------
def _make_streaming_agent(tokens: list[str]) -> MagicMock:
"""Create a mock agent that streams tokens via on_stream callback."""
agent = MagicMock()
agent._connect_mcp = AsyncMock()
agent.close_mcp = AsyncMock()
async def fake_process_direct(*, content="", media=None, session_key="",
channel="", chat_id="", on_stream=None,
on_stream_end=None, **kwargs):
if on_stream:
for token in tokens:
await on_stream(token)
if on_stream_end:
await on_stream_end()
return " ".join(tokens)
agent.process_direct = fake_process_direct
return agent
@pytest_asyncio.fixture
async def aiohttp_client():
    """Factory fixture: start one TestClient per app, close all on teardown."""
    started: list[TestClient] = []

    async def _factory(app):
        client = TestClient(TestServer(app))
        await client.start_server()
        started.append(client)
        return client

    try:
        yield _factory
    finally:
        # Tear down every client created during the test.
        for client in started:
            await client.close()
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_stream_true_returns_sse(aiohttp_client) -> None:
    """stream=true should return text/event-stream with SSE chunks."""
    agent = _make_streaming_agent(["Hello", " world"])
    app = create_app(agent, model_name="test-model")
    client = await aiohttp_client(app)
    resp = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hi"}], "stream": True},
    )
    assert resp.status == 200
    assert resp.content_type == "text/event-stream"
    body = await resp.text()
    lines = [l for l in body.split("\n") if l.startswith("data: ")]
    data_lines = [l[len("data: "):] for l in lines]
    assert data_lines[-1] == "[DONE]"
    chunks = [json.loads(l) for l in data_lines[:-1]]
    # 2 token chunks + 1 finish chunk; previously only commented, now asserted.
    assert len(chunks) == 3
    assert chunks[0]["choices"][0]["delta"]["content"] == "Hello"
    assert chunks[1]["choices"][0]["delta"]["content"] == " world"
    # Last chunk before [DONE] must carry finish_reason=stop and empty delta.
    assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
    assert chunks[-1]["choices"][0]["delta"] == {}
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_stream_false_returns_json(aiohttp_client) -> None:
    """stream=false should still return regular JSON response."""
    agent = MagicMock()
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()
    agent.process_direct = AsyncMock(return_value="normal reply")
    client = await aiohttp_client(create_app(agent, model_name="m"))
    resp = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hi"}], "stream": False},
    )
    assert resp.status == 200
    payload = await resp.json()
    assert payload["object"] == "chat.completion"
    assert payload["choices"][0]["message"]["content"] == "normal reply"
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_stream_default_is_false(aiohttp_client) -> None:
    """Omitting stream should behave like stream=false."""
    agent = MagicMock()
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()
    agent.process_direct = AsyncMock(return_value="default reply")
    client = await aiohttp_client(create_app(agent, model_name="m"))
    # No "stream" key in the request body at all.
    resp = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hi"}]},
    )
    assert resp.status == 200
    payload = await resp.json()
    assert payload["object"] == "chat.completion"
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_stream_sse_chunk_ids_are_consistent(aiohttp_client) -> None:
    """All SSE chunks in a single stream should share the same id."""
    agent = _make_streaming_agent(["A", "B", "C"])
    app = create_app(agent, model_name="m")
    client = await aiohttp_client(app)
    resp = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "go"}], "stream": True},
    )
    # Fail fast on an error response instead of producing a confusing
    # empty-set assertion below.
    assert resp.status == 200
    body = await resp.text()
    data_lines = [
        l[len("data: "):]
        for l in body.split("\n")
        if l.startswith("data: ") and l != "data: [DONE]"
    ]
    # Guard: an empty stream would previously pass the set check vacuously.
    assert data_lines, "expected at least one SSE data chunk"
    chunks = [json.loads(l) for l in data_lines]
    chunk_ids = {c["id"] for c in chunks}
    assert len(chunk_ids) == 1, f"Expected single chunk id, got {chunk_ids}"
    assert chunk_ids.pop().startswith("chatcmpl-")
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_stream_passes_on_stream_callbacks(aiohttp_client) -> None:
    """process_direct should be called with on_stream and on_stream_end when streaming."""
    seen: dict = {}

    async def fake_process_direct(**kwargs):
        seen.update(kwargs)
        end_cb = kwargs.get("on_stream_end")
        if end_cb:
            await end_cb()
        return "done"

    agent = MagicMock()
    agent.process_direct = fake_process_direct
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()
    client = await aiohttp_client(create_app(agent, model_name="m"))
    resp = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hi"}], "stream": True},
    )
    assert resp.status == 200
    assert seen.get("on_stream") is not None
    assert seen.get("on_stream_end") is not None
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_stream_with_session_id(aiohttp_client) -> None:
    """Streaming should respect session_id for session key routing."""
    routed: list[str] = []

    async def fake_process_direct(*, session_key="", on_stream=None,
                                  on_stream_end=None, **kwargs):
        routed.append(session_key)
        if on_stream:
            await on_stream("ok")
        if on_stream_end:
            await on_stream_end()
        return "ok"

    agent = MagicMock()
    agent.process_direct = fake_process_direct
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()
    client = await aiohttp_client(create_app(agent, model_name="m"))
    resp = await client.post(
        "/v1/chat/completions",
        json={
            "messages": [{"role": "user", "content": "hi"}],
            "stream": True,
            "session_id": "my-session",
        },
    )
    assert resp.status == 200
    # The handler should have been invoked exactly once with the routed key.
    assert routed == ["api:my-session"]

View File

@ -101,15 +101,14 @@ async def test_no_user_message_returns_400(aiohttp_client, app) -> None:
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") @pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stream_true_returns_400(aiohttp_client, app) -> None: async def test_stream_true_returns_sse(aiohttp_client, app) -> None:
client = await aiohttp_client(app) client = await aiohttp_client(app)
resp = await client.post( resp = await client.post(
"/v1/chat/completions", "/v1/chat/completions",
json={"messages": [{"role": "user", "content": "hello"}], "stream": True}, json={"messages": [{"role": "user", "content": "hello"}], "stream": True},
) )
assert resp.status == 400 assert resp.status == 200
body = await resp.json() assert resp.content_type == "text/event-stream"
assert "stream" in body["error"]["message"].lower()
@pytest.mark.asyncio @pytest.mark.asyncio