mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-26 03:22:38 +00:00
Fix API stream lifecycle for tool-backed requests
This commit is contained in:
parent
73840b0af6
commit
1040124ede
@ -43,6 +43,26 @@ We use a two-branch model to balance stability and exploration:
|
|||||||
**When in doubt, target `nightly`.** It is easier to move a stable idea from `nightly`
|
**When in doubt, target `nightly`.** It is easier to move a stable idea from `nightly`
|
||||||
to `main` than to undo a risky change after it lands in the stable branch.
|
to `main` than to undo a risky change after it lands in the stable branch.
|
||||||
|
|
||||||
|
### Starting Work
|
||||||
|
|
||||||
|
Before making changes, sync the target branch and create a topic branch from it.
|
||||||
|
For stable bug fixes and documentation-only changes, start from the latest `main`.
|
||||||
|
For experimental work, start from the latest `nightly`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git fetch upstream
|
||||||
|
git switch main
|
||||||
|
git pull --ff-only upstream main
|
||||||
|
git switch -c your-topic-branch
|
||||||
|
```
|
||||||
|
|
||||||
|
Use your primary HKUDS/nanobot remote in place of `upstream` if your checkout
|
||||||
|
uses a different remote name.
|
||||||
|
|
||||||
|
Keep unrelated local changes out of the topic branch. If your checkout already has
|
||||||
|
work in progress, use a separate worktree or finish that work before starting a
|
||||||
|
new branch.
|
||||||
|
|
||||||
### How Does Nightly Get Merged to Main?
|
### How Does Nightly Get Merged to Main?
|
||||||
|
|
||||||
We don't merge the entire `nightly` branch. Instead, stable features are **cherry-picked** from `nightly` into individual PRs targeting `main`:
|
We don't merge the entire `nightly` branch. Instead, stable features are **cherry-picked** from `nightly` into individual PRs targeting `main`:
|
||||||
|
|||||||
@ -7,6 +7,7 @@ All requests route to a single persistent API session.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import json as _json
|
import json as _json
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
@ -18,8 +19,12 @@ from loguru import logger
|
|||||||
from nanobot.config.paths import get_media_dir
|
from nanobot.config.paths import get_media_dir
|
||||||
from nanobot.utils.helpers import safe_filename
|
from nanobot.utils.helpers import safe_filename
|
||||||
from nanobot.utils.media_decode import (
|
from nanobot.utils.media_decode import (
|
||||||
FileSizeExceeded as _FileSizeExceeded,
|
|
||||||
MAX_FILE_SIZE,
|
MAX_FILE_SIZE,
|
||||||
|
)
|
||||||
|
from nanobot.utils.media_decode import (
|
||||||
|
FileSizeExceeded as _FileSizeExceeded,
|
||||||
|
)
|
||||||
|
from nanobot.utils.media_decode import (
|
||||||
save_base64_data_url as _save_base64_data_url,
|
save_base64_data_url as _save_base64_data_url,
|
||||||
)
|
)
|
||||||
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
||||||
@ -240,18 +245,25 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
|||||||
chunk_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
chunk_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||||
queue: asyncio.Queue[str | None] = asyncio.Queue()
|
queue: asyncio.Queue[str | None] = asyncio.Queue()
|
||||||
stream_failed = False
|
stream_failed = False
|
||||||
|
emitted_content = False
|
||||||
|
|
||||||
async def _on_stream(token: str) -> None:
|
async def _on_stream(token: str) -> None:
|
||||||
|
nonlocal emitted_content
|
||||||
|
if token:
|
||||||
|
emitted_content = True
|
||||||
await queue.put(token)
|
await queue.put(token)
|
||||||
|
|
||||||
async def _on_stream_end(*_a: Any, **_kw: Any) -> None:
|
async def _on_stream_end(*_a: Any, **_kw: Any) -> None:
|
||||||
await queue.put(None)
|
# Agent stream-end callbacks mark generation segment boundaries.
|
||||||
|
# Tool-backed requests may continue after a segment ends, so the
|
||||||
|
# HTTP SSE stream is closed only when process_direct returns.
|
||||||
|
return None
|
||||||
|
|
||||||
async def _run() -> None:
|
async def _run() -> None:
|
||||||
nonlocal stream_failed
|
nonlocal stream_failed
|
||||||
try:
|
try:
|
||||||
async with session_lock:
|
async with session_lock:
|
||||||
await asyncio.wait_for(
|
response = await asyncio.wait_for(
|
||||||
agent_loop.process_direct(
|
agent_loop.process_direct(
|
||||||
content=text,
|
content=text,
|
||||||
media=media_paths if media_paths else None,
|
media=media_paths if media_paths else None,
|
||||||
@ -263,9 +275,14 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
|||||||
),
|
),
|
||||||
timeout=timeout_s,
|
timeout=timeout_s,
|
||||||
)
|
)
|
||||||
|
if not emitted_content:
|
||||||
|
response_text = _response_text(response)
|
||||||
|
if response_text.strip():
|
||||||
|
await queue.put(response_text)
|
||||||
except Exception:
|
except Exception:
|
||||||
stream_failed = True
|
stream_failed = True
|
||||||
logger.exception("Streaming error for session {}", session_key)
|
logger.exception("Streaming error for session {}", session_key)
|
||||||
|
finally:
|
||||||
await queue.put(None)
|
await queue.put(None)
|
||||||
|
|
||||||
task = asyncio.create_task(_run())
|
task = asyncio.create_task(_run())
|
||||||
@ -276,7 +293,10 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
|||||||
break
|
break
|
||||||
await resp.write(_sse_chunk(token, model_name, chunk_id))
|
await resp.write(_sse_chunk(token, model_name, chunk_id))
|
||||||
finally:
|
finally:
|
||||||
|
if not task.done():
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
|
await task
|
||||||
|
|
||||||
if not stream_failed:
|
if not stream_failed:
|
||||||
await resp.write(_sse_chunk("", model_name, chunk_id, finish_reason="stop"))
|
await resp.write(_sse_chunk("", model_name, chunk_id, finish_reason="stop"))
|
||||||
@ -284,7 +304,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
|||||||
return resp
|
return resp
|
||||||
|
|
||||||
# -- non-streaming path (original logic) --
|
# -- non-streaming path (original logic) --
|
||||||
_FALLBACK = EMPTY_FINAL_RESPONSE_MESSAGE
|
fallback = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with session_lock:
|
async with session_lock:
|
||||||
@ -316,7 +336,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
|||||||
response_text = _response_text(retry_response)
|
response_text = _response_text(retry_response)
|
||||||
if not response_text or not response_text.strip():
|
if not response_text or not response_text.strip():
|
||||||
logger.warning("Empty response after retry, using fallback")
|
logger.warning("Empty response after retry, using fallback")
|
||||||
response_text = _FALLBACK
|
response_text = fallback
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
return _error_json(504, f"Request timed out after {timeout_s}s")
|
return _error_json(504, f"Request timed out after {timeout_s}s")
|
||||||
|
|||||||
@ -10,8 +10,8 @@ import pytest
|
|||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
|
||||||
from nanobot.api.server import (
|
from nanobot.api.server import (
|
||||||
_sse_chunk,
|
|
||||||
_SSE_DONE,
|
_SSE_DONE,
|
||||||
|
_sse_chunk,
|
||||||
create_app,
|
create_app,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -111,13 +111,13 @@ async def test_stream_true_returns_sse(aiohttp_client) -> None:
|
|||||||
assert resp.content_type == "text/event-stream"
|
assert resp.content_type == "text/event-stream"
|
||||||
|
|
||||||
body = await resp.text()
|
body = await resp.text()
|
||||||
lines = [l for l in body.split("\n") if l.startswith("data: ")]
|
lines = [line for line in body.split("\n") if line.startswith("data: ")]
|
||||||
|
|
||||||
# Should have: 2 token chunks + 1 finish chunk + [DONE]
|
# Should have: 2 token chunks + 1 finish chunk + [DONE]
|
||||||
data_lines = [l[len("data: "):] for l in lines]
|
data_lines = [line[len("data: "):] for line in lines]
|
||||||
assert data_lines[-1] == "[DONE]"
|
assert data_lines[-1] == "[DONE]"
|
||||||
|
|
||||||
chunks = [json.loads(l) for l in data_lines[:-1]]
|
chunks = [json.loads(line) for line in data_lines[:-1]]
|
||||||
assert chunks[0]["choices"][0]["delta"]["content"] == "Hello"
|
assert chunks[0]["choices"][0]["delta"]["content"] == "Hello"
|
||||||
assert chunks[1]["choices"][0]["delta"]["content"] == " world"
|
assert chunks[1]["choices"][0]["delta"]["content"] == " world"
|
||||||
# Last chunk before [DONE] should have finish_reason=stop
|
# Last chunk before [DONE] should have finish_reason=stop
|
||||||
@ -181,8 +181,12 @@ async def test_stream_sse_chunk_ids_are_consistent(aiohttp_client) -> None:
|
|||||||
json={"messages": [{"role": "user", "content": "go"}], "stream": True},
|
json={"messages": [{"role": "user", "content": "go"}], "stream": True},
|
||||||
)
|
)
|
||||||
body = await resp.text()
|
body = await resp.text()
|
||||||
data_lines = [l[len("data: "):] for l in body.split("\n") if l.startswith("data: ") and l != "data: [DONE]"]
|
data_lines = [
|
||||||
chunks = [json.loads(l) for l in data_lines]
|
line[len("data: "):]
|
||||||
|
for line in body.split("\n")
|
||||||
|
if line.startswith("data: ") and line != "data: [DONE]"
|
||||||
|
]
|
||||||
|
chunks = [json.loads(line) for line in data_lines]
|
||||||
|
|
||||||
chunk_ids = {c["id"] for c in chunks}
|
chunk_ids = {c["id"] for c in chunks}
|
||||||
assert len(chunk_ids) == 1, f"Expected single chunk id, got {chunk_ids}"
|
assert len(chunk_ids) == 1, f"Expected single chunk id, got {chunk_ids}"
|
||||||
@ -218,6 +222,85 @@ async def test_stream_passes_on_stream_callbacks(aiohttp_client) -> None:
|
|||||||
assert captured_kwargs.get("on_stream_end") is not None
|
assert captured_kwargs.get("on_stream_end") is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_segment_end_does_not_close_sse(aiohttp_client) -> None:
|
||||||
|
"""Intermediate stream-end callbacks should not terminate the HTTP stream."""
|
||||||
|
agent = MagicMock()
|
||||||
|
|
||||||
|
async def fake_process_direct(*, on_stream=None, on_stream_end=None, **kwargs):
|
||||||
|
assert on_stream is not None
|
||||||
|
assert on_stream_end is not None
|
||||||
|
await on_stream("planning")
|
||||||
|
await on_stream_end(resuming=True)
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
await on_stream(" final")
|
||||||
|
await on_stream_end(resuming=False)
|
||||||
|
return "planning final"
|
||||||
|
|
||||||
|
agent.process_direct = fake_process_direct
|
||||||
|
agent._connect_mcp = AsyncMock()
|
||||||
|
agent.close_mcp = AsyncMock()
|
||||||
|
|
||||||
|
app = create_app(agent, model_name="m")
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
|
||||||
|
resp = await client.post(
|
||||||
|
"/v1/chat/completions",
|
||||||
|
json={"messages": [{"role": "user", "content": "use a tool"}], "stream": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status == 200
|
||||||
|
body = await resp.text()
|
||||||
|
data_lines = [
|
||||||
|
line[len("data: "):] for line in body.split("\n") if line.startswith("data: ")
|
||||||
|
]
|
||||||
|
assert data_lines[-1] == "[DONE]"
|
||||||
|
|
||||||
|
chunks = [json.loads(line) for line in data_lines[:-1]]
|
||||||
|
deltas = [c["choices"][0]["delta"].get("content", "") for c in chunks]
|
||||||
|
assert "planning" in deltas
|
||||||
|
assert " final" in deltas
|
||||||
|
assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_uses_final_response_when_no_deltas(aiohttp_client) -> None:
|
||||||
|
"""stream=true should not return an empty stream when the agent returns content."""
|
||||||
|
agent = MagicMock()
|
||||||
|
|
||||||
|
async def fake_process_direct(*, on_stream=None, on_stream_end=None, **kwargs):
|
||||||
|
assert on_stream is not None
|
||||||
|
assert on_stream_end is not None
|
||||||
|
await on_stream_end(resuming=False)
|
||||||
|
return "plain final"
|
||||||
|
|
||||||
|
agent.process_direct = fake_process_direct
|
||||||
|
agent._connect_mcp = AsyncMock()
|
||||||
|
agent.close_mcp = AsyncMock()
|
||||||
|
|
||||||
|
app = create_app(agent, model_name="m")
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
|
||||||
|
resp = await client.post(
|
||||||
|
"/v1/chat/completions",
|
||||||
|
json={"messages": [{"role": "user", "content": "hi"}], "stream": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status == 200
|
||||||
|
body = await resp.text()
|
||||||
|
data_lines = [
|
||||||
|
line[len("data: "):] for line in body.split("\n") if line.startswith("data: ")
|
||||||
|
]
|
||||||
|
chunks = [json.loads(line) for line in data_lines[:-1]]
|
||||||
|
deltas = [c["choices"][0]["delta"].get("content", "") for c in chunks]
|
||||||
|
|
||||||
|
assert "plain final" in deltas
|
||||||
|
assert data_lines[-1] == "[DONE]"
|
||||||
|
assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stream_with_session_id(aiohttp_client) -> None:
|
async def test_stream_with_session_id(aiohttp_client) -> None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user