mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-26 11:32:25 +00:00
Fix API stream lifecycle for tool-backed requests
This commit is contained in:
parent
73840b0af6
commit
1040124ede
@ -43,6 +43,26 @@ We use a two-branch model to balance stability and exploration:
|
||||
**When in doubt, target `nightly`.** It is easier to move a stable idea from `nightly`
|
||||
to `main` than to undo a risky change after it lands in the stable branch.
|
||||
|
||||
### Starting Work
|
||||
|
||||
Before making changes, sync the target branch and create a topic branch from it.
|
||||
For stable bug fixes and documentation-only changes, start from the latest `main`.
|
||||
For experimental work, start from the latest `nightly`.
|
||||
|
||||
```bash
|
||||
git fetch upstream
|
||||
git switch main
|
||||
git pull --ff-only upstream main
|
||||
git switch -c your-topic-branch
|
||||
```
|
||||
|
||||
Use your primary HKUDS/nanobot remote in place of `upstream` if your checkout
|
||||
uses a different remote name.
|
||||
|
||||
Keep unrelated local changes out of the topic branch. If your checkout already has
|
||||
work in progress, use a separate worktree or finish that work before starting a
|
||||
new branch.
|
||||
|
||||
### How Does Nightly Get Merged to Main?
|
||||
|
||||
We don't merge the entire `nightly` branch. Instead, stable features are **cherry-picked** from `nightly` into individual PRs targeting `main`:
|
||||
|
||||
@ -7,6 +7,7 @@ All requests route to a single persistent API session.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json as _json
|
||||
import time
|
||||
import uuid
|
||||
@ -18,8 +19,12 @@ from loguru import logger
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.utils.helpers import safe_filename
|
||||
from nanobot.utils.media_decode import (
|
||||
FileSizeExceeded as _FileSizeExceeded,
|
||||
MAX_FILE_SIZE,
|
||||
)
|
||||
from nanobot.utils.media_decode import (
|
||||
FileSizeExceeded as _FileSizeExceeded,
|
||||
)
|
||||
from nanobot.utils.media_decode import (
|
||||
save_base64_data_url as _save_base64_data_url,
|
||||
)
|
||||
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
@ -240,18 +245,25 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
||||
chunk_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||
queue: asyncio.Queue[str | None] = asyncio.Queue()
|
||||
stream_failed = False
|
||||
emitted_content = False
|
||||
|
||||
async def _on_stream(token: str) -> None:
|
||||
nonlocal emitted_content
|
||||
if token:
|
||||
emitted_content = True
|
||||
await queue.put(token)
|
||||
|
||||
async def _on_stream_end(*_a: Any, **_kw: Any) -> None:
|
||||
await queue.put(None)
|
||||
# Agent stream-end callbacks mark generation segment boundaries.
|
||||
# Tool-backed requests may continue after a segment ends, so the
|
||||
# HTTP SSE stream is closed only when process_direct returns.
|
||||
return None
|
||||
|
||||
async def _run() -> None:
|
||||
nonlocal stream_failed
|
||||
try:
|
||||
async with session_lock:
|
||||
await asyncio.wait_for(
|
||||
response = await asyncio.wait_for(
|
||||
agent_loop.process_direct(
|
||||
content=text,
|
||||
media=media_paths if media_paths else None,
|
||||
@ -263,9 +275,14 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
||||
),
|
||||
timeout=timeout_s,
|
||||
)
|
||||
if not emitted_content:
|
||||
response_text = _response_text(response)
|
||||
if response_text.strip():
|
||||
await queue.put(response_text)
|
||||
except Exception:
|
||||
stream_failed = True
|
||||
logger.exception("Streaming error for session {}", session_key)
|
||||
finally:
|
||||
await queue.put(None)
|
||||
|
||||
task = asyncio.create_task(_run())
|
||||
@ -276,7 +293,10 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
||||
break
|
||||
await resp.write(_sse_chunk(token, model_name, chunk_id))
|
||||
finally:
|
||||
task.cancel()
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
if not stream_failed:
|
||||
await resp.write(_sse_chunk("", model_name, chunk_id, finish_reason="stop"))
|
||||
@ -284,7 +304,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
||||
return resp
|
||||
|
||||
# -- non-streaming path (original logic) --
|
||||
_FALLBACK = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
fallback = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
|
||||
try:
|
||||
async with session_lock:
|
||||
@ -316,7 +336,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
||||
response_text = _response_text(retry_response)
|
||||
if not response_text or not response_text.strip():
|
||||
logger.warning("Empty response after retry, using fallback")
|
||||
response_text = _FALLBACK
|
||||
response_text = fallback
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return _error_json(504, f"Request timed out after {timeout_s}s")
|
||||
|
||||
@ -10,8 +10,8 @@ import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from nanobot.api.server import (
|
||||
_sse_chunk,
|
||||
_SSE_DONE,
|
||||
_sse_chunk,
|
||||
create_app,
|
||||
)
|
||||
|
||||
@ -111,13 +111,13 @@ async def test_stream_true_returns_sse(aiohttp_client) -> None:
|
||||
assert resp.content_type == "text/event-stream"
|
||||
|
||||
body = await resp.text()
|
||||
lines = [l for l in body.split("\n") if l.startswith("data: ")]
|
||||
lines = [line for line in body.split("\n") if line.startswith("data: ")]
|
||||
|
||||
# Should have: 2 token chunks + 1 finish chunk + [DONE]
|
||||
data_lines = [l[len("data: "):] for l in lines]
|
||||
data_lines = [line[len("data: "):] for line in lines]
|
||||
assert data_lines[-1] == "[DONE]"
|
||||
|
||||
chunks = [json.loads(l) for l in data_lines[:-1]]
|
||||
chunks = [json.loads(line) for line in data_lines[:-1]]
|
||||
assert chunks[0]["choices"][0]["delta"]["content"] == "Hello"
|
||||
assert chunks[1]["choices"][0]["delta"]["content"] == " world"
|
||||
# Last chunk before [DONE] should have finish_reason=stop
|
||||
@ -181,8 +181,12 @@ async def test_stream_sse_chunk_ids_are_consistent(aiohttp_client) -> None:
|
||||
json={"messages": [{"role": "user", "content": "go"}], "stream": True},
|
||||
)
|
||||
body = await resp.text()
|
||||
data_lines = [l[len("data: "):] for l in body.split("\n") if l.startswith("data: ") and l != "data: [DONE]"]
|
||||
chunks = [json.loads(l) for l in data_lines]
|
||||
data_lines = [
|
||||
line[len("data: "):]
|
||||
for line in body.split("\n")
|
||||
if line.startswith("data: ") and line != "data: [DONE]"
|
||||
]
|
||||
chunks = [json.loads(line) for line in data_lines]
|
||||
|
||||
chunk_ids = {c["id"] for c in chunks}
|
||||
assert len(chunk_ids) == 1, f"Expected single chunk id, got {chunk_ids}"
|
||||
@ -218,6 +222,85 @@ async def test_stream_passes_on_stream_callbacks(aiohttp_client) -> None:
|
||||
assert captured_kwargs.get("on_stream_end") is not None
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_segment_end_does_not_close_sse(aiohttp_client) -> None:
|
||||
"""Intermediate stream-end callbacks should not terminate the HTTP stream."""
|
||||
agent = MagicMock()
|
||||
|
||||
async def fake_process_direct(*, on_stream=None, on_stream_end=None, **kwargs):
|
||||
assert on_stream is not None
|
||||
assert on_stream_end is not None
|
||||
await on_stream("planning")
|
||||
await on_stream_end(resuming=True)
|
||||
await asyncio.sleep(0.05)
|
||||
await on_stream(" final")
|
||||
await on_stream_end(resuming=False)
|
||||
return "planning final"
|
||||
|
||||
agent.process_direct = fake_process_direct
|
||||
agent._connect_mcp = AsyncMock()
|
||||
agent.close_mcp = AsyncMock()
|
||||
|
||||
app = create_app(agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "use a tool"}], "stream": True},
|
||||
)
|
||||
|
||||
assert resp.status == 200
|
||||
body = await resp.text()
|
||||
data_lines = [
|
||||
line[len("data: "):] for line in body.split("\n") if line.startswith("data: ")
|
||||
]
|
||||
assert data_lines[-1] == "[DONE]"
|
||||
|
||||
chunks = [json.loads(line) for line in data_lines[:-1]]
|
||||
deltas = [c["choices"][0]["delta"].get("content", "") for c in chunks]
|
||||
assert "planning" in deltas
|
||||
assert " final" in deltas
|
||||
assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_uses_final_response_when_no_deltas(aiohttp_client) -> None:
|
||||
"""stream=true should not return an empty stream when the agent returns content."""
|
||||
agent = MagicMock()
|
||||
|
||||
async def fake_process_direct(*, on_stream=None, on_stream_end=None, **kwargs):
|
||||
assert on_stream is not None
|
||||
assert on_stream_end is not None
|
||||
await on_stream_end(resuming=False)
|
||||
return "plain final"
|
||||
|
||||
agent.process_direct = fake_process_direct
|
||||
agent._connect_mcp = AsyncMock()
|
||||
agent.close_mcp = AsyncMock()
|
||||
|
||||
app = create_app(agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hi"}], "stream": True},
|
||||
)
|
||||
|
||||
assert resp.status == 200
|
||||
body = await resp.text()
|
||||
data_lines = [
|
||||
line[len("data: "):] for line in body.split("\n") if line.startswith("data: ")
|
||||
]
|
||||
chunks = [json.loads(line) for line in data_lines[:-1]]
|
||||
deltas = [c["choices"][0]["delta"].get("content", "") for c in chunks]
|
||||
|
||||
assert "plain final" in deltas
|
||||
assert data_lines[-1] == "[DONE]"
|
||||
assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_with_session_id(aiohttp_client) -> None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user