diff --git a/nanobot/agent/tools/long_task.py b/nanobot/agent/tools/long_task.py index fa13cbc5a..a58d14ee5 100644 --- a/nanobot/agent/tools/long_task.py +++ b/nanobot/agent/tools/long_task.py @@ -46,6 +46,9 @@ class _GoalToolsMixin(ContextAware): def __init__(self, sessions: SessionManager, bus: Any | None = None) -> None: self._sessions = sessions self._bus = bus + # Each subclass gets its own ContextVar so concurrent tasks across + # different tool types (LongTaskTool vs CompleteGoalTool) do not + # interfere with each other. self._request_ctx: ContextVar[RequestContext | None] = ContextVar( f"{self.__class__.__name__}_request_ctx", default=None, diff --git a/tests/agent/test_runner_injections.py b/tests/agent/test_runner_injections.py index e00e9c86c..95cfc4f8d 100644 --- a/tests/agent/test_runner_injections.py +++ b/tests/agent/test_runner_injections.py @@ -566,10 +566,21 @@ async def test_waiting_dispatch_does_not_replace_active_pending_queue(tmp_path): active_pending = asyncio.Queue(maxsize=1) loop._pending_queues[session_key] = active_pending - waiting = asyncio.create_task( - loop._dispatch(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="queued")) - ) - await asyncio.sleep(0.05) + waiting_at_lock = asyncio.Event() + original_acquire = asyncio.Lock.acquire + + async def _patched_acquire(self, *args, **kwargs): + if self is lock: + waiting_at_lock.set() + return await original_acquire(self, *args, **kwargs) + + with patch.object(asyncio.Lock, "acquire", _patched_acquire): + waiting = asyncio.create_task( + loop._dispatch( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="queued") + ) + ) + await asyncio.wait_for(waiting_at_lock.wait(), timeout=2.0) assert loop._pending_queues[session_key] is active_pending diff --git a/tests/agent/tools/test_long_task.py b/tests/agent/tools/test_long_task.py index dd5a0d5d6..ef573a473 100644 --- a/tests/agent/tools/test_long_task.py +++ b/tests/agent/tools/test_long_task.py @@ -100,6 +100,22 @@ async def test_goal_tools_keep_request_context_per_task(tmp_path): assert sm.get_or_create("websocket:b").metadata[GOAL_STATE_KEY]["recap"] == "Done B" +@pytest.mark.asyncio +async def test_goal_tools_context_isolated_across_tool_types(tmp_path): + """LongTaskTool and CompleteGoalTool must not share routing context.""" + sm = SessionManager(tmp_path) + lt = LongTaskTool(sessions=sm) + cg = CompleteGoalTool(sessions=sm) + ctx = RequestContext(channel="websocket", chat_id="a", session_key="websocket:a") + + lt.set_context(ctx) + assert cg._request_ctx.get() is None + + cg.set_context(ctx) + assert lt._request_ctx.get() is ctx + assert cg._request_ctx.get() is ctx + + @pytest.mark.asyncio async def test_long_task_publishes_goal_state_ws_after_save(tmp_path): bus = MagicMock()