Fix API stream lifecycle for tool-backed requests

This commit is contained in:
hanyuanling 2026-04-30 17:53:18 +08:00 committed by Xubin Ren
parent 73840b0af6
commit 1040124ede
3 changed files with 135 additions and 12 deletions

View File

@ -43,6 +43,26 @@ We use a two-branch model to balance stability and exploration:
**When in doubt, target `nightly`.** It is easier to move a stable idea from `nightly`
to `main` than to undo a risky change after it lands in the stable branch.
### Starting Work
Before making changes, sync the target branch and create a topic branch from it.
For stable bug fixes and documentation-only changes, start from the latest `main`.
For experimental work, start from the latest `nightly`.
```bash
git fetch upstream
git switch main
git pull --ff-only upstream main
git switch -c your-topic-branch
```
Use your primary HKUDS/nanobot remote in place of `upstream` if your checkout
uses a different remote name.
Keep unrelated local changes out of the topic branch. If your checkout already has
work in progress, use a separate worktree or finish that work before starting a
new branch.
### How Does Nightly Get Merged to Main?
We don't merge the entire `nightly` branch. Instead, stable features are **cherry-picked** from `nightly` into individual PRs targeting `main`:

View File

@ -7,6 +7,7 @@ All requests route to a single persistent API session.
from __future__ import annotations
import asyncio
import contextlib
import json as _json
import time
import uuid
@ -18,8 +19,12 @@ from loguru import logger
from nanobot.config.paths import get_media_dir
from nanobot.utils.helpers import safe_filename
from nanobot.utils.media_decode import (
FileSizeExceeded as _FileSizeExceeded,
MAX_FILE_SIZE,
)
from nanobot.utils.media_decode import (
FileSizeExceeded as _FileSizeExceeded,
)
from nanobot.utils.media_decode import (
save_base64_data_url as _save_base64_data_url,
)
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
@ -240,18 +245,25 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
chunk_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
queue: asyncio.Queue[str | None] = asyncio.Queue()
stream_failed = False
emitted_content = False
async def _on_stream(token: str) -> None:
nonlocal emitted_content
if token:
emitted_content = True
await queue.put(token)
async def _on_stream_end(*_a: Any, **_kw: Any) -> None:
await queue.put(None)
# Agent stream-end callbacks mark generation segment boundaries.
# Tool-backed requests may continue after a segment ends, so the
# HTTP SSE stream is closed only when process_direct returns.
return None
async def _run() -> None:
nonlocal stream_failed
try:
async with session_lock:
await asyncio.wait_for(
response = await asyncio.wait_for(
agent_loop.process_direct(
content=text,
media=media_paths if media_paths else None,
@ -263,9 +275,14 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
),
timeout=timeout_s,
)
if not emitted_content:
response_text = _response_text(response)
if response_text.strip():
await queue.put(response_text)
except Exception:
stream_failed = True
logger.exception("Streaming error for session {}", session_key)
finally:
await queue.put(None)
task = asyncio.create_task(_run())
@ -276,7 +293,10 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
break
await resp.write(_sse_chunk(token, model_name, chunk_id))
finally:
task.cancel()
if not task.done():
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task
if not stream_failed:
await resp.write(_sse_chunk("", model_name, chunk_id, finish_reason="stop"))
@ -284,7 +304,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
return resp
# -- non-streaming path (original logic) --
_FALLBACK = EMPTY_FINAL_RESPONSE_MESSAGE
fallback = EMPTY_FINAL_RESPONSE_MESSAGE
try:
async with session_lock:
@ -316,7 +336,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
response_text = _response_text(retry_response)
if not response_text or not response_text.strip():
logger.warning("Empty response after retry, using fallback")
response_text = _FALLBACK
response_text = fallback
except asyncio.TimeoutError:
return _error_json(504, f"Request timed out after {timeout_s}s")

View File

@ -10,8 +10,8 @@ import pytest
import pytest_asyncio
from nanobot.api.server import (
_sse_chunk,
_SSE_DONE,
_sse_chunk,
create_app,
)
@ -111,13 +111,13 @@ async def test_stream_true_returns_sse(aiohttp_client) -> None:
assert resp.content_type == "text/event-stream"
body = await resp.text()
lines = [l for l in body.split("\n") if l.startswith("data: ")]
lines = [line for line in body.split("\n") if line.startswith("data: ")]
# Should have: 2 token chunks + 1 finish chunk + [DONE]
data_lines = [l[len("data: "):] for l in lines]
data_lines = [line[len("data: "):] for line in lines]
assert data_lines[-1] == "[DONE]"
chunks = [json.loads(l) for l in data_lines[:-1]]
chunks = [json.loads(line) for line in data_lines[:-1]]
assert chunks[0]["choices"][0]["delta"]["content"] == "Hello"
assert chunks[1]["choices"][0]["delta"]["content"] == " world"
# Last chunk before [DONE] should have finish_reason=stop
@ -181,8 +181,12 @@ async def test_stream_sse_chunk_ids_are_consistent(aiohttp_client) -> None:
json={"messages": [{"role": "user", "content": "go"}], "stream": True},
)
body = await resp.text()
data_lines = [l[len("data: "):] for l in body.split("\n") if l.startswith("data: ") and l != "data: [DONE]"]
chunks = [json.loads(l) for l in data_lines]
data_lines = [
line[len("data: "):]
for line in body.split("\n")
if line.startswith("data: ") and line != "data: [DONE]"
]
chunks = [json.loads(line) for line in data_lines]
chunk_ids = {c["id"] for c in chunks}
assert len(chunk_ids) == 1, f"Expected single chunk id, got {chunk_ids}"
@ -218,6 +222,85 @@ async def test_stream_passes_on_stream_callbacks(aiohttp_client) -> None:
assert captured_kwargs.get("on_stream_end") is not None
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_stream_segment_end_does_not_close_sse(aiohttp_client) -> None:
"""Intermediate stream-end callbacks should not terminate the HTTP stream."""
agent = MagicMock()
async def fake_process_direct(*, on_stream=None, on_stream_end=None, **kwargs):
assert on_stream is not None
assert on_stream_end is not None
await on_stream("planning")
await on_stream_end(resuming=True)
await asyncio.sleep(0.05)
await on_stream(" final")
await on_stream_end(resuming=False)
return "planning final"
agent.process_direct = fake_process_direct
agent._connect_mcp = AsyncMock()
agent.close_mcp = AsyncMock()
app = create_app(agent, model_name="m")
client = await aiohttp_client(app)
resp = await client.post(
"/v1/chat/completions",
json={"messages": [{"role": "user", "content": "use a tool"}], "stream": True},
)
assert resp.status == 200
body = await resp.text()
data_lines = [
line[len("data: "):] for line in body.split("\n") if line.startswith("data: ")
]
assert data_lines[-1] == "[DONE]"
chunks = [json.loads(line) for line in data_lines[:-1]]
deltas = [c["choices"][0]["delta"].get("content", "") for c in chunks]
assert "planning" in deltas
assert " final" in deltas
assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_stream_uses_final_response_when_no_deltas(aiohttp_client) -> None:
"""stream=true should not return an empty stream when the agent returns content."""
agent = MagicMock()
async def fake_process_direct(*, on_stream=None, on_stream_end=None, **kwargs):
assert on_stream is not None
assert on_stream_end is not None
await on_stream_end(resuming=False)
return "plain final"
agent.process_direct = fake_process_direct
agent._connect_mcp = AsyncMock()
agent.close_mcp = AsyncMock()
app = create_app(agent, model_name="m")
client = await aiohttp_client(app)
resp = await client.post(
"/v1/chat/completions",
json={"messages": [{"role": "user", "content": "hi"}], "stream": True},
)
assert resp.status == 200
body = await resp.text()
data_lines = [
line[len("data: "):] for line in body.split("\n") if line.startswith("data: ")
]
chunks = [json.loads(line) for line in data_lines[:-1]]
deltas = [c["choices"][0]["delta"].get("content", "") for c in chunks]
assert "plain final" in deltas
assert data_lines[-1] == "[DONE]"
assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_stream_with_session_id(aiohttp_client) -> None: