From 4fce8d8b8d7cef80a098ff2fdadd1e99415549a4 Mon Sep 17 00:00:00 2001 From: whs Date: Fri, 17 Apr 2026 00:41:39 +0800 Subject: [PATCH] feat(api): add SSE streaming for /v1/chat/completions Wire up the existing on_stream/on_stream_end callbacks from process_direct() to emit OpenAI-compatible SSE chunks when stream=true. Non-streaming path is untouched. --- nanobot/api/server.py | 85 ++++++++++++- tests/test_api_stream.py | 253 +++++++++++++++++++++++++++++++++++++++ tests/test_openai_api.py | 7 +- 3 files changed, 336 insertions(+), 9 deletions(-) create mode 100644 tests/test_api_stream.py diff --git a/nanobot/api/server.py b/nanobot/api/server.py index d8a230340..e384eabcb 100644 --- a/nanobot/api/server.py +++ b/nanobot/api/server.py @@ -8,6 +8,7 @@ from __future__ import annotations import asyncio import base64 +import json as _json import mimetypes import re import time @@ -71,6 +72,30 @@ def _response_text(value: Any) -> str: return str(getattr(value, "content") or "") return str(value) +# --------------------------------------------------------------------------- +# SSE helpers +# --------------------------------------------------------------------------- + + +def _sse_chunk(delta: str, model: str, chunk_id: str, finish_reason: str | None = None) -> bytes: + """Format a single OpenAI-compatible SSE chunk.""" + payload = { + "id": chunk_id, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": {"content": delta} if delta else {}, + "finish_reason": finish_reason, + } + ], + } + return f"data: {_json.dumps(payload)}\n\n".encode() + + +_SSE_DONE = b"data: [DONE]\n\n" # --------------------------------------------------------------------------- # Upload helpers @@ -188,6 +213,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response: timeout_s: float = request.app.get("request_timeout", 120.0) model_name: str = request.app.get("model_name", "nanobot") + stream = False try: if content_type.startswith("multipart/"): text, media_paths, session_id, requested_model = await _parse_multipart(request) @@ -196,10 +222,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response: body = await request.json() except Exception: return _error_json(400, "Invalid JSON body") - if body.get("stream", False): - return _error_json( - 400, "stream=true is not supported yet. Set stream=false or omit it." 
-                )
+            stream = body.get("stream", False)
             requested_model = body.get("model")
             text, media_paths = _parse_json_content(body)
             session_id = body.get("session_id")
@@ -219,9 +242,61 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
 
     session_lock = session_locks.setdefault(session_key, asyncio.Lock())
 
     logger.info(
-        "API request session_key={} media={} text={}", session_key, len(media_paths), text[:80]
+        "API request session_key={} media={} text={} stream={}",
+        session_key, len(media_paths), text[:80], stream,
     )
 
+    # -- streaming path --
+    if stream:
+        resp = web.StreamResponse()
+        resp.content_type = "text/event-stream"
+        resp.headers["Cache-Control"] = "no-cache"
+        resp.headers["Connection"] = "keep-alive"
+        resp.headers["X-Accel-Buffering"] = "no"  # no compression: gzip buffering would delay SSE chunks
+        await resp.prepare(request)
+        chunk_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
+        queue: asyncio.Queue[str | None] = asyncio.Queue()  # None is the end-of-stream sentinel
+
+        async def _on_stream(token: str) -> None:
+            await queue.put(token)
+
+        async def _on_stream_end(*_a: Any, **_kw: Any) -> None:
+            await queue.put(None)
+
+        async def _run() -> None:
+            try:
+                async with session_lock:
+                    await asyncio.wait_for(
+                        agent_loop.process_direct(
+                            content=text,
+                            media=media_paths if media_paths else None,
+                            session_key=session_key,
+                            channel="api",
+                            chat_id=API_CHAT_ID,
+                            on_stream=_on_stream,
+                            on_stream_end=_on_stream_end,
+                        ),
+                        timeout=timeout_s,
+                    )
+            except Exception:
+                logger.exception("Streaming error for session {}", session_key)
+            await queue.put(None)  # always unblock the reader, even if on_stream_end never fires
+
+        task = asyncio.create_task(_run())
+        try:
+            while True:
+                token = await queue.get()
+                if token is None:
+                    break
+                await resp.write(_sse_chunk(token, model_name, chunk_id))
+        finally:
+            task.cancel()
+
+        await resp.write(_sse_chunk("", model_name, chunk_id, finish_reason="stop"))
+        await resp.write(_SSE_DONE)
+        return resp
+
+    # -- non-streaming path (original logic) --
     _FALLBACK = EMPTY_FINAL_RESPONSE_MESSAGE
     try:
diff --git a/tests/test_api_stream.py b/tests/test_api_stream.py
new file mode 100644
index 000000000..cb9fa484f
--- /dev/null
+++ b/tests/test_api_stream.py
@@ -0,0 +1,253 @@
+"""Tests for SSE streaming support in /v1/chat/completions."""
+
+from __future__ import annotations
+
+import asyncio
+import json
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+import pytest_asyncio
+
+from nanobot.api.server import (
+    _sse_chunk,
+    _SSE_DONE,
+    create_app,
+)
+
+try:
+    from aiohttp.test_utils import TestClient, TestServer
+
+    HAS_AIOHTTP = True
+except ImportError:
+    HAS_AIOHTTP = False
+
+pytest_plugins = ("pytest_asyncio",)
+
+
+# ---------------------------------------------------------------------------
+# Unit tests for SSE helpers
+# ---------------------------------------------------------------------------
+
+
+def test_sse_chunk_with_delta() -> None:
+    raw = _sse_chunk("hello", "test-model", "chatcmpl-abc123")
+    line = raw.decode()
+    assert line.startswith("data: ")
+    payload = json.loads(line[len("data: "):])
+    assert payload["id"] == "chatcmpl-abc123"
+    assert payload["object"] == "chat.completion.chunk"
+    assert payload["model"] == "test-model"
+    assert payload["choices"][0]["delta"]["content"] == "hello"
+    assert payload["choices"][0]["finish_reason"] is None
+
+
+def test_sse_chunk_finish_reason() -> None:
+    raw = _sse_chunk("", "m", "id1", finish_reason="stop")
+    payload = json.loads(raw.decode().split("data: ", 1)[1])
+    assert payload["choices"][0]["delta"] == {}
+    assert payload["choices"][0]["finish_reason"] == "stop"
+
+
+def test_sse_done_format() -> None:
+    assert _SSE_DONE 
== b"data: [DONE]\n\n" + + +# --------------------------------------------------------------------------- +# Integration tests with aiohttp TestClient +# --------------------------------------------------------------------------- + + +def _make_streaming_agent(tokens: list[str]) -> MagicMock: + """Create a mock agent that streams tokens via on_stream callback.""" + agent = MagicMock() + agent._connect_mcp = AsyncMock() + agent.close_mcp = AsyncMock() + + async def fake_process_direct(*, content="", media=None, session_key="", + channel="", chat_id="", on_stream=None, + on_stream_end=None, **kwargs): + if on_stream: + for token in tokens: + await on_stream(token) + if on_stream_end: + await on_stream_end() + return " ".join(tokens) + + agent.process_direct = fake_process_direct + return agent + + +@pytest_asyncio.fixture +async def aiohttp_client(): + clients: list[TestClient] = [] + + async def _make_client(app): + client = TestClient(TestServer(app)) + await client.start_server() + clients.append(client) + return client + + try: + yield _make_client + finally: + for client in clients: + await client.close() + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_stream_true_returns_sse(aiohttp_client) -> None: + """stream=true should return text/event-stream with SSE chunks.""" + agent = _make_streaming_agent(["Hello", " world"]) + app = create_app(agent, model_name="test-model") + client = await aiohttp_client(app) + + resp = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hi"}], "stream": True}, + ) + assert resp.status == 200 + assert resp.content_type == "text/event-stream" + + body = await resp.text() + lines = [l for l in body.split("\n") if l.startswith("data: ")] + + # Should have: 2 token chunks + 1 finish chunk + [DONE] + data_lines = [l[len("data: "):] for l in lines] + assert data_lines[-1] == "[DONE]" + + chunks = [json.loads(l) for l in data_lines[:-1]] + assert chunks[0]["choices"][0]["delta"]["content"] == "Hello" + assert chunks[1]["choices"][0]["delta"]["content"] == " world" + # Last chunk before [DONE] should have finish_reason=stop + assert chunks[-1]["choices"][0]["finish_reason"] == "stop" + assert chunks[-1]["choices"][0]["delta"] == {} + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_stream_false_returns_json(aiohttp_client) -> None: + """stream=false should still return regular JSON response.""" + agent = MagicMock() + agent.process_direct = AsyncMock(return_value="normal reply") + agent._connect_mcp = AsyncMock() + agent.close_mcp = AsyncMock() + + app = create_app(agent, model_name="m") + client = await aiohttp_client(app) + + resp = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hi"}], "stream": False}, + ) + assert resp.status == 200 + body = await resp.json() + assert body["object"] == "chat.completion" + assert body["choices"][0]["message"]["content"] == "normal reply" + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_stream_default_is_false(aiohttp_client) -> None: + """Omitting stream should behave like stream=false.""" + agent = MagicMock() + agent.process_direct = AsyncMock(return_value="default reply") + agent._connect_mcp = AsyncMock() + agent.close_mcp = AsyncMock() + + app = create_app(agent, model_name="m") + client = await aiohttp_client(app) + + resp = await client.post( + 
"/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hi"}]}, + ) + assert resp.status == 200 + body = await resp.json() + assert body["object"] == "chat.completion" + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_stream_sse_chunk_ids_are_consistent(aiohttp_client) -> None: + """All SSE chunks in a single stream should share the same id.""" + agent = _make_streaming_agent(["A", "B", "C"]) + app = create_app(agent, model_name="m") + client = await aiohttp_client(app) + + resp = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "go"}], "stream": True}, + ) + body = await resp.text() + data_lines = [l[len("data: "):] for l in body.split("\n") if l.startswith("data: ") and l != "data: [DONE]"] + chunks = [json.loads(l) for l in data_lines] + + chunk_ids = {c["id"] for c in chunks} + assert len(chunk_ids) == 1, f"Expected single chunk id, got {chunk_ids}" + assert chunk_ids.pop().startswith("chatcmpl-") + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_stream_passes_on_stream_callbacks(aiohttp_client) -> None: + """process_direct should be called with on_stream and on_stream_end when streaming.""" + captured_kwargs: dict = {} + + async def fake_process_direct(**kwargs): + captured_kwargs.update(kwargs) + if kwargs.get("on_stream_end"): + await kwargs["on_stream_end"]() + return "done" + + agent = MagicMock() + agent.process_direct = fake_process_direct + agent._connect_mcp = AsyncMock() + agent.close_mcp = AsyncMock() + + app = create_app(agent, model_name="m") + client = await aiohttp_client(app) + + resp = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hi"}], "stream": True}, + ) + assert resp.status == 200 + assert captured_kwargs.get("on_stream") is not None + assert captured_kwargs.get("on_stream_end") is not None + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_stream_with_session_id(aiohttp_client) -> None: + """Streaming should respect session_id for session key routing.""" + captured_key: str = "" + + async def fake_process_direct(*, session_key="", on_stream=None, on_stream_end=None, **kwargs): + nonlocal captured_key + captured_key = session_key + if on_stream: + await on_stream("ok") + if on_stream_end: + await on_stream_end() + return "ok" + + agent = MagicMock() + agent.process_direct = fake_process_direct + agent._connect_mcp = AsyncMock() + agent.close_mcp = AsyncMock() + + app = create_app(agent, model_name="m") + client = await aiohttp_client(app) + + resp = await client.post( + "/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "hi"}], + "stream": True, + "session_id": "my-session", + }, + ) + assert resp.status == 200 + assert captured_key == "api:my-session" diff --git a/tests/test_openai_api.py b/tests/test_openai_api.py index 50607de44..59b52b191 100644 --- a/tests/test_openai_api.py +++ b/tests/test_openai_api.py @@ -101,15 +101,14 @@ async def test_no_user_message_returns_400(aiohttp_client, app) -> None: @pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") @pytest.mark.asyncio -async def test_stream_true_returns_400(aiohttp_client, app) -> None: +async def test_stream_true_returns_sse(aiohttp_client, app) -> None: client = await aiohttp_client(app) resp = await client.post( "/v1/chat/completions", json={"messages": [{"role": "user", 
"content": "hello"}], "stream": True}, ) - assert resp.status == 400 - body = await resp.json() - assert "stream" in body["error"]["message"].lower() + assert resp.status == 200 + assert resp.content_type == "text/event-stream" @pytest.mark.asyncio