From 0df60416ba70d3a70f37a694448cbd2214aedc94 Mon Sep 17 00:00:00 2001 From: hamb1y Date: Thu, 28 May 2026 16:39:48 +0530 Subject: [PATCH] fix(agent): address session and streaming concurrency bugs --- nanobot/agent/loop.py | 69 +++++++++++++++----------- nanobot/agent/runner.py | 8 ++- nanobot/agent/tools/long_task.py | 16 +++--- nanobot/command/builtin.py | 2 +- nanobot/providers/base.py | 18 ++++++- tests/agent/test_runner_governance.py | 54 ++++++++++++++++++++ tests/agent/test_runner_injections.py | 25 ++++++++++ tests/agent/test_unified_session.py | 28 ++++++++++- tests/agent/tools/test_long_task.py | 28 +++++++++++ tests/providers/test_provider_retry.py | 41 +++++++++++++++ 10 files changed, 250 insertions(+), 39 deletions(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 1ceb635c7..5a0985cbd 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -812,16 +812,16 @@ class AgentLoop: logger.warning("Error consuming inbound message: {}, continuing...", e) continue + raw = msg.content.strip() + effective_key = self._effective_session_key(msg) if await agent_context.handle_runtime_control(self, msg, self.tools): continue - raw = msg.content.strip() if self.commands.is_priority(raw): await self._dispatch_command_inline( - msg, msg.session_key, raw, + msg, effective_key, raw, self.commands.dispatch_priority, ) continue - effective_key = self._effective_session_key(msg) # If this session already has an active pending queue (i.e. a task # is processing this session), route the message there for mid-turn # injection instead of creating a competing task. @@ -872,13 +872,13 @@ class AgentLoop: lock = self._session_locks.setdefault(session_key, asyncio.Lock()) gate = self._concurrency_gate or nullcontext() - # Register a pending queue so follow-up messages for this session are - # routed here (mid-turn injection) instead of spawning a new task. - pending = asyncio.Queue(maxsize=20) - self._pending_queues[session_key] = pending - + pending: asyncio.Queue | None = None try: async with lock, gate: + # Only the task that owns the session lock may publish the + # active mid-turn injection queue for this session. + pending = asyncio.Queue(maxsize=20) + self._pending_queues[session_key] = pending try: on_stream = on_stream_end = None if msg.metadata.get("_wants_stream"): @@ -962,28 +962,39 @@ class AgentLoop: channel=msg.channel, chat_id=msg.chat_id, content="Sorry, I encountered an error.", )) + finally: + # Drain any messages still in the pending queue and re-publish + # them to the bus so they are processed as fresh inbound messages + # rather than silently lost. Only remove our own queue; a + # later task waiting on the lock must not be able to steal + # cleanup ownership. + queue = None + if self._pending_queues.get(session_key) is pending: + queue = self._pending_queues.pop(session_key, None) + else: + queue = pending + if queue is not None: + leftover = 0 + while True: + try: + item = queue.get_nowait() + except asyncio.QueueEmpty: + break + await self.bus.publish_inbound(item) + leftover += 1 + if leftover: + logger.info( + "Re-published {} leftover message(s) to bus for session {}", + leftover, session_key, + ) + await self._webui_turns.publish_run_status(msg, "idle") + self._pending_turn_latency_ms.pop(session_key, None) + self._webui_turns.discard(session_key) finally: - # Drain any messages still in the pending queue and re-publish - # them to the bus so they are processed as fresh inbound messages - # rather than silently lost. - queue = self._pending_queues.pop(session_key, None) - if queue is not None: - leftover = 0 - while True: - try: - item = queue.get_nowait() - except asyncio.QueueEmpty: - break - await self.bus.publish_inbound(item) - leftover += 1 - if leftover: - logger.info( - "Re-published {} leftover message(s) to bus for session {}", - leftover, session_key, - ) - await self._webui_turns.publish_run_status(msg, "idle") - self._pending_turn_latency_ms.pop(session_key, None) - self._webui_turns.discard(session_key) + if pending is None: + await self._webui_turns.publish_run_status(msg, "idle") + self._pending_turn_latency_ms.pop(session_key, None) + self._webui_turns.discard(session_key) async def close_mcp(self) -> None: """Drain pending background archives, then close MCP connections.""" diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 1991c7004..cb70116b6 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -1273,7 +1273,13 @@ class AgentRunner: return messages system_tokens = sum(estimate_message_tokens(msg) for msg in system_messages) - remaining_budget = max(128, budget - system_tokens) + fixed_tokens, _ = estimate_prompt_tokens_chain( + self.provider, + spec.model, + system_messages, + spec.tools.get_definitions(), + ) + remaining_budget = max(0, budget - max(system_tokens, fixed_tokens)) kept: list[dict[str, Any]] = [] kept_tokens = 0 for message in reversed(non_system): diff --git a/nanobot/agent/tools/long_task.py b/nanobot/agent/tools/long_task.py index 0d1650cd1..fa13cbc5a 100644 --- a/nanobot/agent/tools/long_task.py +++ b/nanobot/agent/tools/long_task.py @@ -16,6 +16,7 @@ There is **no** sub-agent orchestrator and **no** special WebSocket ``agent_ui`` from __future__ import annotations +from contextvars import ContextVar from datetime import datetime from typing import TYPE_CHECKING, Any @@ -45,15 +46,19 @@ class _GoalToolsMixin(ContextAware): def __init__(self, sessions: SessionManager, bus: Any | None = None) -> None: self._sessions = sessions self._bus = bus - self._request_ctx: RequestContext | None = None + self._request_ctx: ContextVar[RequestContext | None] = ContextVar( + f"{self.__class__.__name__}_request_ctx", + default=None, + ) def set_context(self, ctx: RequestContext) -> None: - self._request_ctx = ctx + self._request_ctx.set(ctx) def _session(self): - if self._request_ctx is None: + request_ctx = self._request_ctx.get() + if request_ctx is None: return None - key = self._request_ctx.session_key + key = request_ctx.session_key if not key: return None return self._sessions.get_or_create(key) @@ -61,7 +66,7 @@ class _GoalToolsMixin(ContextAware): async def _publish_goal_state_ws(self, metadata: dict[str, Any]) -> None: """Fan-out authoritative goal snapshot for this WebSocket chat only.""" bus = self._bus - rc = self._request_ctx + rc = self._request_ctx.get() if bus is None or rc is None or rc.channel != "websocket": return cid = (rc.chat_id or "").strip() @@ -224,4 +229,3 @@ class CompleteGoalTool(Tool, _GoalToolsMixin): if tail: return f"Goal marked complete ({ended}). Recap:\n{tail}" return f"Goal marked complete ({ended})." - diff --git a/nanobot/command/builtin.py b/nanobot/command/builtin.py index 4646df38a..997b7ca16 100644 --- a/nanobot/command/builtin.py +++ b/nanobot/command/builtin.py @@ -123,7 +123,7 @@ async def cmd_stop(ctx: CommandContext) -> OutboundMessage: """Cancel all active tasks and subagents for the session.""" loop = ctx.loop msg = ctx.msg - total = await loop._cancel_active_tasks(msg.session_key) + total = await loop._cancel_active_tasks(ctx.key) content = f"Stopped {total} task(s)." if total else "No active task to stop." return OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, content=content, diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index 87697650a..8bac5d4ba 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -557,11 +557,20 @@ class LLMProvider(ABC): if reasoning_effort is self._SENTINEL: reasoning_effort = self.generation.reasoning_effort + has_streamed_content = False + + async def _tracking_delta(text: str) -> None: + nonlocal has_streamed_content + if text: + has_streamed_content = True + if on_content_delta: + await on_content_delta(text) + kw: dict[str, Any] = dict( messages=messages, tools=tools, model=model, max_tokens=max_tokens, temperature=temperature, reasoning_effort=reasoning_effort, tool_choice=tool_choice, - on_content_delta=on_content_delta, + on_content_delta=_tracking_delta if on_content_delta is not None else None, on_thinking_delta=on_thinking_delta, on_tool_call_delta=on_tool_call_delta, ) @@ -571,6 +580,7 @@ class LLMProvider(ABC): messages, retry_mode=retry_mode, on_retry_wait=on_retry_wait, + should_retry_guard=lambda: not has_streamed_content, ) async def chat_with_retry( @@ -717,6 +727,7 @@ class LLMProvider(ABC): *, retry_mode: str, on_retry_wait: Callable[[str], Awaitable[None]] | None, + should_retry_guard: Callable[[], bool] | None = None, ) -> LLMResponse: attempt = 0 delays = list(self._CHAT_RETRY_DELAYS) @@ -730,6 +741,11 @@ class LLMProvider(ABC): if response.finish_reason != "error": return response last_response = response + if should_retry_guard is not None and not should_retry_guard(): + logger.warning( + "LLM stream failed after content was emitted; skipping retry" + ) + return response error_key = ((response.content or "").strip().lower() or None) if error_key and error_key == last_error_key: identical_error_count += 1 diff --git a/tests/agent/test_runner_governance.py b/tests/agent/test_runner_governance.py index 50e882ca6..901afc71e 100644 --- a/tests/agent/test_runner_governance.py +++ b/tests/agent/test_runner_governance.py @@ -105,6 +105,60 @@ def test_snip_history_drops_orphaned_tool_results_from_trimmed_slice(monkeypatch assert trimmed[0]["role"] == "system" non_system = [m for m in trimmed if m["role"] != "system"] assert non_system[0]["role"] == "user", f"Expected user after system, got {non_system[0]['role']}" + + +def test_snip_history_reserves_budget_for_tool_definitions(monkeypatch): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + tools = MagicMock() + tools.get_definitions.return_value = [{"type": "function", "function": {"name": "large_tool"}}] + runner = AgentRunner(provider) + messages = [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "old user"}, + {"role": "assistant", "content": "old assistant"}, + {"role": "user", "content": "recent one"}, + {"role": "assistant", "content": "recent answer"}, + {"role": "user", "content": "recent two"}, + ] + spec = AgentRunSpec( + initial_messages=messages, + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + context_window_tokens=2000, + context_block_limit=500, + ) + + def _estimate(_provider, _model, estimate_messages, estimate_tools): + if estimate_messages == messages: + return 1000, None + assert estimate_messages == [{"role": "system", "content": "system"}] + assert estimate_tools == tools.get_definitions.return_value + return 350, None + + monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", _estimate) + token_sizes = { + "system": 50, + "old user": 200, + "old assistant": 200, + "recent one": 200, + "recent answer": 200, + "recent two": 200, + } + monkeypatch.setattr( + "nanobot.agent.runner.estimate_message_tokens", + lambda msg: token_sizes.get(str(msg.get("content")), 40), + ) + + trimmed = runner._snip_history(spec, messages) + + contents = [message.get("content") for message in trimmed] + assert contents == ["system", "recent two"] + + async def test_backfill_missing_tool_results_inserts_error(): """Orphaned tool_use (no matching tool_result) should get a synthetic error.""" from nanobot.agent.runner import AgentRunner, _BACKFILL_CONTENT diff --git a/tests/agent/test_runner_injections.py b/tests/agent/test_runner_injections.py index 1aa504e32..e00e9c86c 100644 --- a/tests/agent/test_runner_injections.py +++ b/tests/agent/test_runner_injections.py @@ -554,6 +554,31 @@ async def test_pending_queue_cleanup_on_dispatch(tmp_path): assert msg.session_key not in loop._pending_queues +@pytest.mark.asyncio +async def test_waiting_dispatch_does_not_replace_active_pending_queue(tmp_path): + """A queued dispatch must not steal the active task's injection queue.""" + from nanobot.bus.events import InboundMessage + + loop = _make_loop(tmp_path) + session_key = "cli:c" + lock = loop._session_locks.setdefault(session_key, asyncio.Lock()) + await lock.acquire() + active_pending = asyncio.Queue(maxsize=1) + loop._pending_queues[session_key] = active_pending + + waiting = asyncio.create_task( + loop._dispatch(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="queued")) + ) + await asyncio.sleep(0.05) + + assert loop._pending_queues[session_key] is active_pending + + waiting.cancel() + with pytest.raises(asyncio.CancelledError): + await waiting + lock.release() + + @pytest.mark.asyncio async def test_followup_routed_to_pending_queue(tmp_path): """Unified-session follow-ups should route into the active pending queue.""" diff --git a/tests/agent/test_unified_session.py b/tests/agent/test_unified_session.py index f22290ba6..aa42c4e55 100644 --- a/tests/agent/test_unified_session.py +++ b/tests/agent/test_unified_session.py @@ -474,6 +474,32 @@ class TestStopCommandWithUnifiedSession: assert task.cancelled() or task.done() assert "Stopped 1 task" in result.content + @pytest.mark.asyncio + async def test_stop_command_uses_effective_key_without_session_override(self, tmp_path: Path): + """Priority /stop must cancel the unified session even before dispatch rewrites the message.""" + from nanobot.agent.loop import UNIFIED_SESSION_KEY + from nanobot.command.builtin import cmd_stop + + loop = _make_loop(tmp_path, unified_session=True) + + async def long_running(): + await asyncio.sleep(10) + + task = asyncio.create_task(long_running()) + loop._active_tasks[UNIFIED_SESSION_KEY] = [task] + msg = InboundMessage( + channel="telegram", + chat_id="123456", + sender_id="user1", + content="/stop", + ) + ctx = CommandContext(msg=msg, session=None, key=UNIFIED_SESSION_KEY, raw="/stop", loop=loop) + + result = await cmd_stop(ctx) + + assert task.cancelled() or task.done() + assert "Stopped 1 task" in result.content + @pytest.mark.asyncio async def test_stop_command_cross_channel_in_unified_mode(self, tmp_path: Path): """In unified mode, /stop from one channel cancels tasks from another channel.""" @@ -504,4 +530,4 @@ class TestStopCommandWithUnifiedSession: result = await cmd_stop(ctx) # Both tasks should be cancelled - assert "Stopped 2 task" in result.content \ No newline at end of file + assert "Stopped 2 task" in result.content diff --git a/tests/agent/tools/test_long_task.py b/tests/agent/tools/test_long_task.py index 15c5f8db5..dd5a0d5d6 100644 --- a/tests/agent/tools/test_long_task.py +++ b/tests/agent/tools/test_long_task.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio from unittest.mock import AsyncMock, MagicMock import pytest @@ -72,6 +73,33 @@ async def test_complete_goal_closes_active_goal(tmp_path): assert blob["recap"] == "Done." +@pytest.mark.asyncio +async def test_goal_tools_keep_request_context_per_task(tmp_path): + sm = SessionManager(tmp_path) + lt = LongTaskTool(sessions=sm) + cg = CompleteGoalTool(sessions=sm) + ctx_a = RequestContext(channel="websocket", chat_id="a", session_key="websocket:a") + ctx_b = RequestContext(channel="websocket", chat_id="b", session_key="websocket:b") + + lt.set_context(ctx_a) + task_a = asyncio.create_task(lt.execute(goal="Goal A")) + lt.set_context(ctx_b) + task_b = asyncio.create_task(lt.execute(goal="Goal B")) + await asyncio.gather(task_a, task_b) + + assert sm.get_or_create("websocket:a").metadata[GOAL_STATE_KEY]["objective"] == "Goal A" + assert sm.get_or_create("websocket:b").metadata[GOAL_STATE_KEY]["objective"] == "Goal B" + + cg.set_context(ctx_a) + done_a = asyncio.create_task(cg.execute(recap="Done A")) + cg.set_context(ctx_b) + done_b = asyncio.create_task(cg.execute(recap="Done B")) + await asyncio.gather(done_a, done_b) + + assert sm.get_or_create("websocket:a").metadata[GOAL_STATE_KEY]["recap"] == "Done A" + assert sm.get_or_create("websocket:b").metadata[GOAL_STATE_KEY]["recap"] == "Done B" + + @pytest.mark.asyncio async def test_long_task_publishes_goal_state_ws_after_save(tmp_path): bus = MagicMock() diff --git a/tests/providers/test_provider_retry.py b/tests/providers/test_provider_retry.py index 4b72c163a..6fc2137df 100644 --- a/tests/providers/test_provider_retry.py +++ b/tests/providers/test_provider_retry.py @@ -21,6 +21,17 @@ class ScriptedProvider(LLMProvider): raise response return response + async def chat_stream(self, *args, **kwargs) -> LLMResponse: + self.calls += 1 + self.last_kwargs = kwargs + response = self._responses.pop(0) + if isinstance(response, BaseException): + raise response + delta = getattr(response, "_test_stream_delta", None) + if delta and kwargs.get("on_content_delta"): + await kwargs["on_content_delta"](delta) + return response + def get_default_model(self) -> str: return "test-model" @@ -122,6 +133,36 @@ async def test_chat_with_retry_preserves_cancelled_error() -> None: await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}]) +@pytest.mark.asyncio +async def test_chat_stream_with_retry_does_not_retry_after_emitting_content(monkeypatch) -> None: + first = LLMResponse(content="stream stalled", finish_reason="error") + first._test_stream_delta = "partial" # type: ignore[attr-defined] + provider = ScriptedProvider([ + first, + LLMResponse(content="ok"), + ]) + deltas: list[str] = [] + delays: list[int] = [] + + async def _fake_sleep(delay: int) -> None: + delays.append(delay) + + async def _on_delta(delta: str) -> None: + deltas.append(delta) + + monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep) + + response = await provider.chat_stream_with_retry( + messages=[{"role": "user", "content": "hello"}], + on_content_delta=_on_delta, + ) + + assert response.content == "stream stalled" + assert provider.calls == 1 + assert deltas == ["partial"] + assert delays == [] + + @pytest.mark.asyncio async def test_chat_with_retry_uses_provider_generation_defaults() -> None: """When callers omit generation params, provider.generation defaults are used."""