mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 14:23:58 +00:00
fix(agent): address session and streaming concurrency bugs
This commit is contained in:
parent
1a4ae8994d
commit
0df60416ba
@ -812,16 +812,16 @@ class AgentLoop:
|
||||
logger.warning("Error consuming inbound message: {}, continuing...", e)
|
||||
continue
|
||||
|
||||
raw = msg.content.strip()
|
||||
effective_key = self._effective_session_key(msg)
|
||||
if await agent_context.handle_runtime_control(self, msg, self.tools):
|
||||
continue
|
||||
raw = msg.content.strip()
|
||||
if self.commands.is_priority(raw):
|
||||
await self._dispatch_command_inline(
|
||||
msg, msg.session_key, raw,
|
||||
msg, effective_key, raw,
|
||||
self.commands.dispatch_priority,
|
||||
)
|
||||
continue
|
||||
effective_key = self._effective_session_key(msg)
|
||||
# If this session already has an active pending queue (i.e. a task
|
||||
# is processing this session), route the message there for mid-turn
|
||||
# injection instead of creating a competing task.
|
||||
@ -872,13 +872,13 @@ class AgentLoop:
|
||||
lock = self._session_locks.setdefault(session_key, asyncio.Lock())
|
||||
gate = self._concurrency_gate or nullcontext()
|
||||
|
||||
# Register a pending queue so follow-up messages for this session are
|
||||
# routed here (mid-turn injection) instead of spawning a new task.
|
||||
pending = asyncio.Queue(maxsize=20)
|
||||
self._pending_queues[session_key] = pending
|
||||
|
||||
pending: asyncio.Queue | None = None
|
||||
try:
|
||||
async with lock, gate:
|
||||
# Only the task that owns the session lock may publish the
|
||||
# active mid-turn injection queue for this session.
|
||||
pending = asyncio.Queue(maxsize=20)
|
||||
self._pending_queues[session_key] = pending
|
||||
try:
|
||||
on_stream = on_stream_end = None
|
||||
if msg.metadata.get("_wants_stream"):
|
||||
@ -962,28 +962,39 @@ class AgentLoop:
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="Sorry, I encountered an error.",
|
||||
))
|
||||
finally:
|
||||
# Drain any messages still in the pending queue and re-publish
|
||||
# them to the bus so they are processed as fresh inbound messages
|
||||
# rather than silently lost. Only remove our own queue; a
|
||||
# later task waiting on the lock must not be able to steal
|
||||
# cleanup ownership.
|
||||
queue = None
|
||||
if self._pending_queues.get(session_key) is pending:
|
||||
queue = self._pending_queues.pop(session_key, None)
|
||||
else:
|
||||
queue = pending
|
||||
if queue is not None:
|
||||
leftover = 0
|
||||
while True:
|
||||
try:
|
||||
item = queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
await self.bus.publish_inbound(item)
|
||||
leftover += 1
|
||||
if leftover:
|
||||
logger.info(
|
||||
"Re-published {} leftover message(s) to bus for session {}",
|
||||
leftover, session_key,
|
||||
)
|
||||
await self._webui_turns.publish_run_status(msg, "idle")
|
||||
self._pending_turn_latency_ms.pop(session_key, None)
|
||||
self._webui_turns.discard(session_key)
|
||||
finally:
|
||||
# Drain any messages still in the pending queue and re-publish
|
||||
# them to the bus so they are processed as fresh inbound messages
|
||||
# rather than silently lost.
|
||||
queue = self._pending_queues.pop(session_key, None)
|
||||
if queue is not None:
|
||||
leftover = 0
|
||||
while True:
|
||||
try:
|
||||
item = queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
await self.bus.publish_inbound(item)
|
||||
leftover += 1
|
||||
if leftover:
|
||||
logger.info(
|
||||
"Re-published {} leftover message(s) to bus for session {}",
|
||||
leftover, session_key,
|
||||
)
|
||||
await self._webui_turns.publish_run_status(msg, "idle")
|
||||
self._pending_turn_latency_ms.pop(session_key, None)
|
||||
self._webui_turns.discard(session_key)
|
||||
if pending is None:
|
||||
await self._webui_turns.publish_run_status(msg, "idle")
|
||||
self._pending_turn_latency_ms.pop(session_key, None)
|
||||
self._webui_turns.discard(session_key)
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
"""Drain pending background archives, then close MCP connections."""
|
||||
|
||||
@ -1273,7 +1273,13 @@ class AgentRunner:
|
||||
return messages
|
||||
|
||||
system_tokens = sum(estimate_message_tokens(msg) for msg in system_messages)
|
||||
remaining_budget = max(128, budget - system_tokens)
|
||||
fixed_tokens, _ = estimate_prompt_tokens_chain(
|
||||
self.provider,
|
||||
spec.model,
|
||||
system_messages,
|
||||
spec.tools.get_definitions(),
|
||||
)
|
||||
remaining_budget = max(0, budget - max(system_tokens, fixed_tokens))
|
||||
kept: list[dict[str, Any]] = []
|
||||
kept_tokens = 0
|
||||
for message in reversed(non_system):
|
||||
|
||||
@ -16,6 +16,7 @@ There is **no** sub-agent orchestrator and **no** special WebSocket ``agent_ui``
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextvars import ContextVar
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@ -45,15 +46,19 @@ class _GoalToolsMixin(ContextAware):
|
||||
def __init__(self, sessions: SessionManager, bus: Any | None = None) -> None:
|
||||
self._sessions = sessions
|
||||
self._bus = bus
|
||||
self._request_ctx: RequestContext | None = None
|
||||
self._request_ctx: ContextVar[RequestContext | None] = ContextVar(
|
||||
f"{self.__class__.__name__}_request_ctx",
|
||||
default=None,
|
||||
)
|
||||
|
||||
def set_context(self, ctx: RequestContext) -> None:
|
||||
self._request_ctx = ctx
|
||||
self._request_ctx.set(ctx)
|
||||
|
||||
def _session(self):
|
||||
if self._request_ctx is None:
|
||||
request_ctx = self._request_ctx.get()
|
||||
if request_ctx is None:
|
||||
return None
|
||||
key = self._request_ctx.session_key
|
||||
key = request_ctx.session_key
|
||||
if not key:
|
||||
return None
|
||||
return self._sessions.get_or_create(key)
|
||||
@ -61,7 +66,7 @@ class _GoalToolsMixin(ContextAware):
|
||||
async def _publish_goal_state_ws(self, metadata: dict[str, Any]) -> None:
|
||||
"""Fan-out authoritative goal snapshot for this WebSocket chat only."""
|
||||
bus = self._bus
|
||||
rc = self._request_ctx
|
||||
rc = self._request_ctx.get()
|
||||
if bus is None or rc is None or rc.channel != "websocket":
|
||||
return
|
||||
cid = (rc.chat_id or "").strip()
|
||||
@ -224,4 +229,3 @@ class CompleteGoalTool(Tool, _GoalToolsMixin):
|
||||
if tail:
|
||||
return f"Goal marked complete ({ended}). Recap:\n{tail}"
|
||||
return f"Goal marked complete ({ended})."
|
||||
|
||||
|
||||
@ -123,7 +123,7 @@ async def cmd_stop(ctx: CommandContext) -> OutboundMessage:
|
||||
"""Cancel all active tasks and subagents for the session."""
|
||||
loop = ctx.loop
|
||||
msg = ctx.msg
|
||||
total = await loop._cancel_active_tasks(msg.session_key)
|
||||
total = await loop._cancel_active_tasks(ctx.key)
|
||||
content = f"Stopped {total} task(s)." if total else "No active task to stop."
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content=content,
|
||||
|
||||
@ -557,11 +557,20 @@ class LLMProvider(ABC):
|
||||
if reasoning_effort is self._SENTINEL:
|
||||
reasoning_effort = self.generation.reasoning_effort
|
||||
|
||||
has_streamed_content = False
|
||||
|
||||
async def _tracking_delta(text: str) -> None:
|
||||
nonlocal has_streamed_content
|
||||
if text:
|
||||
has_streamed_content = True
|
||||
if on_content_delta:
|
||||
await on_content_delta(text)
|
||||
|
||||
kw: dict[str, Any] = dict(
|
||||
messages=messages, tools=tools, model=model,
|
||||
max_tokens=max_tokens, temperature=temperature,
|
||||
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
||||
on_content_delta=on_content_delta,
|
||||
on_content_delta=_tracking_delta if on_content_delta is not None else None,
|
||||
on_thinking_delta=on_thinking_delta,
|
||||
on_tool_call_delta=on_tool_call_delta,
|
||||
)
|
||||
@ -571,6 +580,7 @@ class LLMProvider(ABC):
|
||||
messages,
|
||||
retry_mode=retry_mode,
|
||||
on_retry_wait=on_retry_wait,
|
||||
should_retry_guard=lambda: not has_streamed_content,
|
||||
)
|
||||
|
||||
async def chat_with_retry(
|
||||
@ -717,6 +727,7 @@ class LLMProvider(ABC):
|
||||
*,
|
||||
retry_mode: str,
|
||||
on_retry_wait: Callable[[str], Awaitable[None]] | None,
|
||||
should_retry_guard: Callable[[], bool] | None = None,
|
||||
) -> LLMResponse:
|
||||
attempt = 0
|
||||
delays = list(self._CHAT_RETRY_DELAYS)
|
||||
@ -730,6 +741,11 @@ class LLMProvider(ABC):
|
||||
if response.finish_reason != "error":
|
||||
return response
|
||||
last_response = response
|
||||
if should_retry_guard is not None and not should_retry_guard():
|
||||
logger.warning(
|
||||
"LLM stream failed after content was emitted; skipping retry"
|
||||
)
|
||||
return response
|
||||
error_key = ((response.content or "").strip().lower() or None)
|
||||
if error_key and error_key == last_error_key:
|
||||
identical_error_count += 1
|
||||
|
||||
@ -105,6 +105,60 @@ def test_snip_history_drops_orphaned_tool_results_from_trimmed_slice(monkeypatch
|
||||
assert trimmed[0]["role"] == "system"
|
||||
non_system = [m for m in trimmed if m["role"] != "system"]
|
||||
assert non_system[0]["role"] == "user", f"Expected user after system, got {non_system[0]['role']}"
|
||||
|
||||
|
||||
def test_snip_history_reserves_budget_for_tool_definitions(monkeypatch):
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = [{"type": "function", "function": {"name": "large_tool"}}]
|
||||
runner = AgentRunner(provider)
|
||||
messages = [
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "old user"},
|
||||
{"role": "assistant", "content": "old assistant"},
|
||||
{"role": "user", "content": "recent one"},
|
||||
{"role": "assistant", "content": "recent answer"},
|
||||
{"role": "user", "content": "recent two"},
|
||||
]
|
||||
spec = AgentRunSpec(
|
||||
initial_messages=messages,
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
context_window_tokens=2000,
|
||||
context_block_limit=500,
|
||||
)
|
||||
|
||||
def _estimate(_provider, _model, estimate_messages, estimate_tools):
|
||||
if estimate_messages == messages:
|
||||
return 1000, None
|
||||
assert estimate_messages == [{"role": "system", "content": "system"}]
|
||||
assert estimate_tools == tools.get_definitions.return_value
|
||||
return 350, None
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", _estimate)
|
||||
token_sizes = {
|
||||
"system": 50,
|
||||
"old user": 200,
|
||||
"old assistant": 200,
|
||||
"recent one": 200,
|
||||
"recent answer": 200,
|
||||
"recent two": 200,
|
||||
}
|
||||
monkeypatch.setattr(
|
||||
"nanobot.agent.runner.estimate_message_tokens",
|
||||
lambda msg: token_sizes.get(str(msg.get("content")), 40),
|
||||
)
|
||||
|
||||
trimmed = runner._snip_history(spec, messages)
|
||||
|
||||
contents = [message.get("content") for message in trimmed]
|
||||
assert contents == ["system", "recent two"]
|
||||
|
||||
|
||||
async def test_backfill_missing_tool_results_inserts_error():
|
||||
"""Orphaned tool_use (no matching tool_result) should get a synthetic error."""
|
||||
from nanobot.agent.runner import AgentRunner, _BACKFILL_CONTENT
|
||||
|
||||
@ -554,6 +554,31 @@ async def test_pending_queue_cleanup_on_dispatch(tmp_path):
|
||||
assert msg.session_key not in loop._pending_queues
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_waiting_dispatch_does_not_replace_active_pending_queue(tmp_path):
|
||||
"""A queued dispatch must not steal the active task's injection queue."""
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
loop = _make_loop(tmp_path)
|
||||
session_key = "cli:c"
|
||||
lock = loop._session_locks.setdefault(session_key, asyncio.Lock())
|
||||
await lock.acquire()
|
||||
active_pending = asyncio.Queue(maxsize=1)
|
||||
loop._pending_queues[session_key] = active_pending
|
||||
|
||||
waiting = asyncio.create_task(
|
||||
loop._dispatch(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="queued"))
|
||||
)
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert loop._pending_queues[session_key] is active_pending
|
||||
|
||||
waiting.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await waiting
|
||||
lock.release()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_followup_routed_to_pending_queue(tmp_path):
|
||||
"""Unified-session follow-ups should route into the active pending queue."""
|
||||
|
||||
@ -474,6 +474,32 @@ class TestStopCommandWithUnifiedSession:
|
||||
assert task.cancelled() or task.done()
|
||||
assert "Stopped 1 task" in result.content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_command_uses_effective_key_without_session_override(self, tmp_path: Path):
|
||||
"""Priority /stop must cancel the unified session even before dispatch rewrites the message."""
|
||||
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
||||
from nanobot.command.builtin import cmd_stop
|
||||
|
||||
loop = _make_loop(tmp_path, unified_session=True)
|
||||
|
||||
async def long_running():
|
||||
await asyncio.sleep(10)
|
||||
|
||||
task = asyncio.create_task(long_running())
|
||||
loop._active_tasks[UNIFIED_SESSION_KEY] = [task]
|
||||
msg = InboundMessage(
|
||||
channel="telegram",
|
||||
chat_id="123456",
|
||||
sender_id="user1",
|
||||
content="/stop",
|
||||
)
|
||||
ctx = CommandContext(msg=msg, session=None, key=UNIFIED_SESSION_KEY, raw="/stop", loop=loop)
|
||||
|
||||
result = await cmd_stop(ctx)
|
||||
|
||||
assert task.cancelled() or task.done()
|
||||
assert "Stopped 1 task" in result.content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_command_cross_channel_in_unified_mode(self, tmp_path: Path):
|
||||
"""In unified mode, /stop from one channel cancels tasks from another channel."""
|
||||
@ -504,4 +530,4 @@ class TestStopCommandWithUnifiedSession:
|
||||
result = await cmd_stop(ctx)
|
||||
|
||||
# Both tasks should be cancelled
|
||||
assert "Stopped 2 task" in result.content
|
||||
assert "Stopped 2 task" in result.content
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
@ -72,6 +73,33 @@ async def test_complete_goal_closes_active_goal(tmp_path):
|
||||
assert blob["recap"] == "Done."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_goal_tools_keep_request_context_per_task(tmp_path):
|
||||
sm = SessionManager(tmp_path)
|
||||
lt = LongTaskTool(sessions=sm)
|
||||
cg = CompleteGoalTool(sessions=sm)
|
||||
ctx_a = RequestContext(channel="websocket", chat_id="a", session_key="websocket:a")
|
||||
ctx_b = RequestContext(channel="websocket", chat_id="b", session_key="websocket:b")
|
||||
|
||||
lt.set_context(ctx_a)
|
||||
task_a = asyncio.create_task(lt.execute(goal="Goal A"))
|
||||
lt.set_context(ctx_b)
|
||||
task_b = asyncio.create_task(lt.execute(goal="Goal B"))
|
||||
await asyncio.gather(task_a, task_b)
|
||||
|
||||
assert sm.get_or_create("websocket:a").metadata[GOAL_STATE_KEY]["objective"] == "Goal A"
|
||||
assert sm.get_or_create("websocket:b").metadata[GOAL_STATE_KEY]["objective"] == "Goal B"
|
||||
|
||||
cg.set_context(ctx_a)
|
||||
done_a = asyncio.create_task(cg.execute(recap="Done A"))
|
||||
cg.set_context(ctx_b)
|
||||
done_b = asyncio.create_task(cg.execute(recap="Done B"))
|
||||
await asyncio.gather(done_a, done_b)
|
||||
|
||||
assert sm.get_or_create("websocket:a").metadata[GOAL_STATE_KEY]["recap"] == "Done A"
|
||||
assert sm.get_or_create("websocket:b").metadata[GOAL_STATE_KEY]["recap"] == "Done B"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_long_task_publishes_goal_state_ws_after_save(tmp_path):
|
||||
bus = MagicMock()
|
||||
|
||||
@ -21,6 +21,17 @@ class ScriptedProvider(LLMProvider):
|
||||
raise response
|
||||
return response
|
||||
|
||||
async def chat_stream(self, *args, **kwargs) -> LLMResponse:
|
||||
self.calls += 1
|
||||
self.last_kwargs = kwargs
|
||||
response = self._responses.pop(0)
|
||||
if isinstance(response, BaseException):
|
||||
raise response
|
||||
delta = getattr(response, "_test_stream_delta", None)
|
||||
if delta and kwargs.get("on_content_delta"):
|
||||
await kwargs["on_content_delta"](delta)
|
||||
return response
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "test-model"
|
||||
|
||||
@ -122,6 +133,36 @@ async def test_chat_with_retry_preserves_cancelled_error() -> None:
|
||||
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_stream_with_retry_does_not_retry_after_emitting_content(monkeypatch) -> None:
|
||||
first = LLMResponse(content="stream stalled", finish_reason="error")
|
||||
first._test_stream_delta = "partial" # type: ignore[attr-defined]
|
||||
provider = ScriptedProvider([
|
||||
first,
|
||||
LLMResponse(content="ok"),
|
||||
])
|
||||
deltas: list[str] = []
|
||||
delays: list[int] = []
|
||||
|
||||
async def _fake_sleep(delay: int) -> None:
|
||||
delays.append(delay)
|
||||
|
||||
async def _on_delta(delta: str) -> None:
|
||||
deltas.append(delta)
|
||||
|
||||
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||
|
||||
response = await provider.chat_stream_with_retry(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
on_content_delta=_on_delta,
|
||||
)
|
||||
|
||||
assert response.content == "stream stalled"
|
||||
assert provider.calls == 1
|
||||
assert deltas == ["partial"]
|
||||
assert delays == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
|
||||
"""When callers omit generation params, provider.generation defaults are used."""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user