fix(agent): address session and streaming concurrency bugs

This commit is contained in:
hamb1y 2026-05-28 16:39:48 +05:30 committed by Xubin Ren
parent 1a4ae8994d
commit 0df60416ba
10 changed files with 250 additions and 39 deletions

View File

@ -812,16 +812,16 @@ class AgentLoop:
logger.warning("Error consuming inbound message: {}, continuing...", e)
continue
raw = msg.content.strip()
effective_key = self._effective_session_key(msg)
if await agent_context.handle_runtime_control(self, msg, self.tools):
continue
raw = msg.content.strip()
if self.commands.is_priority(raw):
await self._dispatch_command_inline(
msg, msg.session_key, raw,
msg, effective_key, raw,
self.commands.dispatch_priority,
)
continue
effective_key = self._effective_session_key(msg)
# If this session already has an active pending queue (i.e. a task
# is processing this session), route the message there for mid-turn
# injection instead of creating a competing task.
@ -872,13 +872,13 @@ class AgentLoop:
lock = self._session_locks.setdefault(session_key, asyncio.Lock())
gate = self._concurrency_gate or nullcontext()
# Register a pending queue so follow-up messages for this session are
# routed here (mid-turn injection) instead of spawning a new task.
pending = asyncio.Queue(maxsize=20)
self._pending_queues[session_key] = pending
pending: asyncio.Queue | None = None
try:
async with lock, gate:
# Only the task that owns the session lock may publish the
# active mid-turn injection queue for this session.
pending = asyncio.Queue(maxsize=20)
self._pending_queues[session_key] = pending
try:
on_stream = on_stream_end = None
if msg.metadata.get("_wants_stream"):
@ -962,28 +962,39 @@ class AgentLoop:
channel=msg.channel, chat_id=msg.chat_id,
content="Sorry, I encountered an error.",
))
finally:
# Drain any messages still in the pending queue and re-publish
# them to the bus so they are processed as fresh inbound messages
# rather than silently lost. Only remove our own queue; a
# later task waiting on the lock must not be able to steal
# cleanup ownership.
queue = None
if self._pending_queues.get(session_key) is pending:
queue = self._pending_queues.pop(session_key, None)
else:
queue = pending
if queue is not None:
leftover = 0
while True:
try:
item = queue.get_nowait()
except asyncio.QueueEmpty:
break
await self.bus.publish_inbound(item)
leftover += 1
if leftover:
logger.info(
"Re-published {} leftover message(s) to bus for session {}",
leftover, session_key,
)
await self._webui_turns.publish_run_status(msg, "idle")
self._pending_turn_latency_ms.pop(session_key, None)
self._webui_turns.discard(session_key)
finally:
# Drain any messages still in the pending queue and re-publish
# them to the bus so they are processed as fresh inbound messages
# rather than silently lost.
queue = self._pending_queues.pop(session_key, None)
if queue is not None:
leftover = 0
while True:
try:
item = queue.get_nowait()
except asyncio.QueueEmpty:
break
await self.bus.publish_inbound(item)
leftover += 1
if leftover:
logger.info(
"Re-published {} leftover message(s) to bus for session {}",
leftover, session_key,
)
await self._webui_turns.publish_run_status(msg, "idle")
self._pending_turn_latency_ms.pop(session_key, None)
self._webui_turns.discard(session_key)
if pending is None:
await self._webui_turns.publish_run_status(msg, "idle")
self._pending_turn_latency_ms.pop(session_key, None)
self._webui_turns.discard(session_key)
async def close_mcp(self) -> None:
"""Drain pending background archives, then close MCP connections."""

View File

@ -1273,7 +1273,13 @@ class AgentRunner:
return messages
system_tokens = sum(estimate_message_tokens(msg) for msg in system_messages)
remaining_budget = max(128, budget - system_tokens)
fixed_tokens, _ = estimate_prompt_tokens_chain(
self.provider,
spec.model,
system_messages,
spec.tools.get_definitions(),
)
remaining_budget = max(0, budget - max(system_tokens, fixed_tokens))
kept: list[dict[str, Any]] = []
kept_tokens = 0
for message in reversed(non_system):

View File

@ -16,6 +16,7 @@ There is **no** sub-agent orchestrator and **no** special WebSocket ``agent_ui``
from __future__ import annotations
from contextvars import ContextVar
from datetime import datetime
from typing import TYPE_CHECKING, Any
@ -45,15 +46,19 @@ class _GoalToolsMixin(ContextAware):
def __init__(self, sessions: SessionManager, bus: Any | None = None) -> None:
self._sessions = sessions
self._bus = bus
self._request_ctx: RequestContext | None = None
self._request_ctx: ContextVar[RequestContext | None] = ContextVar(
f"{self.__class__.__name__}_request_ctx",
default=None,
)
def set_context(self, ctx: RequestContext) -> None:
self._request_ctx = ctx
self._request_ctx.set(ctx)
def _session(self):
if self._request_ctx is None:
request_ctx = self._request_ctx.get()
if request_ctx is None:
return None
key = self._request_ctx.session_key
key = request_ctx.session_key
if not key:
return None
return self._sessions.get_or_create(key)
@ -61,7 +66,7 @@ class _GoalToolsMixin(ContextAware):
async def _publish_goal_state_ws(self, metadata: dict[str, Any]) -> None:
"""Fan-out authoritative goal snapshot for this WebSocket chat only."""
bus = self._bus
rc = self._request_ctx
rc = self._request_ctx.get()
if bus is None or rc is None or rc.channel != "websocket":
return
cid = (rc.chat_id or "").strip()
@ -224,4 +229,3 @@ class CompleteGoalTool(Tool, _GoalToolsMixin):
if tail:
return f"Goal marked complete ({ended}). Recap:\n{tail}"
return f"Goal marked complete ({ended})."

View File

@ -123,7 +123,7 @@ async def cmd_stop(ctx: CommandContext) -> OutboundMessage:
"""Cancel all active tasks and subagents for the session."""
loop = ctx.loop
msg = ctx.msg
total = await loop._cancel_active_tasks(msg.session_key)
total = await loop._cancel_active_tasks(ctx.key)
content = f"Stopped {total} task(s)." if total else "No active task to stop."
return OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, content=content,

View File

@ -557,11 +557,20 @@ class LLMProvider(ABC):
if reasoning_effort is self._SENTINEL:
reasoning_effort = self.generation.reasoning_effort
has_streamed_content = False
async def _tracking_delta(text: str) -> None:
nonlocal has_streamed_content
if text:
has_streamed_content = True
if on_content_delta:
await on_content_delta(text)
kw: dict[str, Any] = dict(
messages=messages, tools=tools, model=model,
max_tokens=max_tokens, temperature=temperature,
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
on_content_delta=on_content_delta,
on_content_delta=_tracking_delta if on_content_delta is not None else None,
on_thinking_delta=on_thinking_delta,
on_tool_call_delta=on_tool_call_delta,
)
@ -571,6 +580,7 @@ class LLMProvider(ABC):
messages,
retry_mode=retry_mode,
on_retry_wait=on_retry_wait,
should_retry_guard=lambda: not has_streamed_content,
)
async def chat_with_retry(
@ -717,6 +727,7 @@ class LLMProvider(ABC):
*,
retry_mode: str,
on_retry_wait: Callable[[str], Awaitable[None]] | None,
should_retry_guard: Callable[[], bool] | None = None,
) -> LLMResponse:
attempt = 0
delays = list(self._CHAT_RETRY_DELAYS)
@ -730,6 +741,11 @@ class LLMProvider(ABC):
if response.finish_reason != "error":
return response
last_response = response
if should_retry_guard is not None and not should_retry_guard():
logger.warning(
"LLM stream failed after content was emitted; skipping retry"
)
return response
error_key = ((response.content or "").strip().lower() or None)
if error_key and error_key == last_error_key:
identical_error_count += 1

View File

@ -105,6 +105,60 @@ def test_snip_history_drops_orphaned_tool_results_from_trimmed_slice(monkeypatch
assert trimmed[0]["role"] == "system"
non_system = [m for m in trimmed if m["role"] != "system"]
assert non_system[0]["role"] == "user", f"Expected user after system, got {non_system[0]['role']}"
def test_snip_history_reserves_budget_for_tool_definitions(monkeypatch):
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
tools = MagicMock()
tools.get_definitions.return_value = [{"type": "function", "function": {"name": "large_tool"}}]
runner = AgentRunner(provider)
messages = [
{"role": "system", "content": "system"},
{"role": "user", "content": "old user"},
{"role": "assistant", "content": "old assistant"},
{"role": "user", "content": "recent one"},
{"role": "assistant", "content": "recent answer"},
{"role": "user", "content": "recent two"},
]
spec = AgentRunSpec(
initial_messages=messages,
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
context_window_tokens=2000,
context_block_limit=500,
)
def _estimate(_provider, _model, estimate_messages, estimate_tools):
if estimate_messages == messages:
return 1000, None
assert estimate_messages == [{"role": "system", "content": "system"}]
assert estimate_tools == tools.get_definitions.return_value
return 350, None
monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", _estimate)
token_sizes = {
"system": 50,
"old user": 200,
"old assistant": 200,
"recent one": 200,
"recent answer": 200,
"recent two": 200,
}
monkeypatch.setattr(
"nanobot.agent.runner.estimate_message_tokens",
lambda msg: token_sizes.get(str(msg.get("content")), 40),
)
trimmed = runner._snip_history(spec, messages)
contents = [message.get("content") for message in trimmed]
assert contents == ["system", "recent two"]
async def test_backfill_missing_tool_results_inserts_error():
"""Orphaned tool_use (no matching tool_result) should get a synthetic error."""
from nanobot.agent.runner import AgentRunner, _BACKFILL_CONTENT

View File

@ -554,6 +554,31 @@ async def test_pending_queue_cleanup_on_dispatch(tmp_path):
assert msg.session_key not in loop._pending_queues
@pytest.mark.asyncio
async def test_waiting_dispatch_does_not_replace_active_pending_queue(tmp_path):
"""A queued dispatch must not steal the active task's injection queue."""
from nanobot.bus.events import InboundMessage
loop = _make_loop(tmp_path)
session_key = "cli:c"
lock = loop._session_locks.setdefault(session_key, asyncio.Lock())
await lock.acquire()
active_pending = asyncio.Queue(maxsize=1)
loop._pending_queues[session_key] = active_pending
waiting = asyncio.create_task(
loop._dispatch(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="queued"))
)
await asyncio.sleep(0.05)
assert loop._pending_queues[session_key] is active_pending
waiting.cancel()
with pytest.raises(asyncio.CancelledError):
await waiting
lock.release()
@pytest.mark.asyncio
async def test_followup_routed_to_pending_queue(tmp_path):
"""Unified-session follow-ups should route into the active pending queue."""

View File

@ -474,6 +474,32 @@ class TestStopCommandWithUnifiedSession:
assert task.cancelled() or task.done()
assert "Stopped 1 task" in result.content
@pytest.mark.asyncio
async def test_stop_command_uses_effective_key_without_session_override(self, tmp_path: Path):
"""Priority /stop must cancel the unified session even before dispatch rewrites the message."""
from nanobot.agent.loop import UNIFIED_SESSION_KEY
from nanobot.command.builtin import cmd_stop
loop = _make_loop(tmp_path, unified_session=True)
async def long_running():
await asyncio.sleep(10)
task = asyncio.create_task(long_running())
loop._active_tasks[UNIFIED_SESSION_KEY] = [task]
msg = InboundMessage(
channel="telegram",
chat_id="123456",
sender_id="user1",
content="/stop",
)
ctx = CommandContext(msg=msg, session=None, key=UNIFIED_SESSION_KEY, raw="/stop", loop=loop)
result = await cmd_stop(ctx)
assert task.cancelled() or task.done()
assert "Stopped 1 task" in result.content
@pytest.mark.asyncio
async def test_stop_command_cross_channel_in_unified_mode(self, tmp_path: Path):
"""In unified mode, /stop from one channel cancels tasks from another channel."""
@ -504,4 +530,4 @@ class TestStopCommandWithUnifiedSession:
result = await cmd_stop(ctx)
# Both tasks should be cancelled
assert "Stopped 2 task" in result.content
assert "Stopped 2 task" in result.content

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, MagicMock
import pytest
@ -72,6 +73,33 @@ async def test_complete_goal_closes_active_goal(tmp_path):
assert blob["recap"] == "Done."
@pytest.mark.asyncio
async def test_goal_tools_keep_request_context_per_task(tmp_path):
sm = SessionManager(tmp_path)
lt = LongTaskTool(sessions=sm)
cg = CompleteGoalTool(sessions=sm)
ctx_a = RequestContext(channel="websocket", chat_id="a", session_key="websocket:a")
ctx_b = RequestContext(channel="websocket", chat_id="b", session_key="websocket:b")
lt.set_context(ctx_a)
task_a = asyncio.create_task(lt.execute(goal="Goal A"))
lt.set_context(ctx_b)
task_b = asyncio.create_task(lt.execute(goal="Goal B"))
await asyncio.gather(task_a, task_b)
assert sm.get_or_create("websocket:a").metadata[GOAL_STATE_KEY]["objective"] == "Goal A"
assert sm.get_or_create("websocket:b").metadata[GOAL_STATE_KEY]["objective"] == "Goal B"
cg.set_context(ctx_a)
done_a = asyncio.create_task(cg.execute(recap="Done A"))
cg.set_context(ctx_b)
done_b = asyncio.create_task(cg.execute(recap="Done B"))
await asyncio.gather(done_a, done_b)
assert sm.get_or_create("websocket:a").metadata[GOAL_STATE_KEY]["recap"] == "Done A"
assert sm.get_or_create("websocket:b").metadata[GOAL_STATE_KEY]["recap"] == "Done B"
@pytest.mark.asyncio
async def test_long_task_publishes_goal_state_ws_after_save(tmp_path):
bus = MagicMock()

View File

@ -21,6 +21,17 @@ class ScriptedProvider(LLMProvider):
raise response
return response
async def chat_stream(self, *args, **kwargs) -> LLMResponse:
self.calls += 1
self.last_kwargs = kwargs
response = self._responses.pop(0)
if isinstance(response, BaseException):
raise response
delta = getattr(response, "_test_stream_delta", None)
if delta and kwargs.get("on_content_delta"):
await kwargs["on_content_delta"](delta)
return response
def get_default_model(self) -> str:
return "test-model"
@ -122,6 +133,36 @@ async def test_chat_with_retry_preserves_cancelled_error() -> None:
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
@pytest.mark.asyncio
async def test_chat_stream_with_retry_does_not_retry_after_emitting_content(monkeypatch) -> None:
first = LLMResponse(content="stream stalled", finish_reason="error")
first._test_stream_delta = "partial" # type: ignore[attr-defined]
provider = ScriptedProvider([
first,
LLMResponse(content="ok"),
])
deltas: list[str] = []
delays: list[int] = []
async def _fake_sleep(delay: int) -> None:
delays.append(delay)
async def _on_delta(delta: str) -> None:
deltas.append(delta)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_stream_with_retry(
messages=[{"role": "user", "content": "hello"}],
on_content_delta=_on_delta,
)
assert response.content == "stream stalled"
assert provider.calls == 1
assert deltas == ["partial"]
assert delays == []
@pytest.mark.asyncio
async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
"""When callers omit generation params, provider.generation defaults are used."""