diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index bc83cc77c..f7afbe901 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -43,6 +43,7 @@ if TYPE_CHECKING: UNIFIED_SESSION_KEY = "unified:default" + class _LoopHook(AgentHook): """Core hook for the main loop.""" @@ -76,7 +77,7 @@ class _LoopHook(AgentHook): prev_clean = strip_think(self._stream_buf) self._stream_buf += delta new_clean = strip_think(self._stream_buf) - incremental = new_clean[len(prev_clean):] + incremental = new_clean[len(prev_clean) :] if incremental and self._on_stream: await self._on_stream(incremental) @@ -112,6 +113,7 @@ class _LoopHook(AgentHook): def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: return self._loop._strip_think(content) + class AgentLoop: """ The agent loop is the core processing engine. @@ -196,7 +198,7 @@ class AgentLoop: self._unified_session = unified_session self._running = False self._mcp_servers = mcp_servers or {} - self._mcp_stack: AsyncExitStack | None = None + self._mcp_stacks: dict[str, AsyncExitStack] = {} self._mcp_connected = False self._mcp_connecting = False self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks @@ -228,24 +230,34 @@ class AgentLoop: def _register_default_tools(self) -> None: """Register the default set of tools.""" - allowed_dir = self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None + allowed_dir = ( + self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None + ) extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None - self.tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read)) + self.tools.register( + ReadFileTool( + workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read + ) + ) for cls in (WriteFileTool, EditFileTool, ListDirTool): self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir)) for cls in (GlobTool, GrepTool): self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir)) if self.exec_config.enable: - self.tools.register(ExecTool( - working_dir=str(self.workspace), - timeout=self.exec_config.timeout, - restrict_to_workspace=self.restrict_to_workspace, - sandbox=self.exec_config.sandbox, - path_append=self.exec_config.path_append, - allowed_env_keys=self.exec_config.allowed_env_keys, - )) + self.tools.register( + ExecTool( + working_dir=str(self.workspace), + timeout=self.exec_config.timeout, + restrict_to_workspace=self.restrict_to_workspace, + sandbox=self.exec_config.sandbox, + path_append=self.exec_config.path_append, + allowed_env_keys=self.exec_config.allowed_env_keys, + ) + ) if self.web_config.enable: - self.tools.register(WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy)) + self.tools.register( + WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy) + ) self.tools.register(WebFetchTool(proxy=self.web_config.proxy)) self.tools.register(MessageTool(send_callback=self.bus.publish_outbound)) self.tools.register(SpawnTool(manager=self.subagents)) @@ -260,19 +272,16 @@ class AgentLoop: return self._mcp_connecting = True from nanobot.agent.tools.mcp import connect_mcp_servers + try: - self._mcp_stack = AsyncExitStack() - await self._mcp_stack.__aenter__() - await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack) + self._mcp_stacks = await connect_mcp_servers(self._mcp_servers, self.tools) self._mcp_connected = True + except asyncio.CancelledError: + logger.warning("MCP connection cancelled (will retry next message)") + self._mcp_stacks.clear() except BaseException as e: logger.error("Failed to connect MCP servers (will retry next message): {}", e) - if self._mcp_stack: - try: - await self._mcp_stack.aclose() - except Exception: - pass - self._mcp_stack = None + self._mcp_stacks.clear() finally: self._mcp_connecting = False @@ -289,6 +298,7 @@ class AgentLoop: if not text: return None from nanobot.utils.helpers import strip_think + return strip_think(text) or None @staticmethod @@ -327,9 +337,7 @@ class AgentLoop: message_id=message_id, ) hook: AgentHook = ( - CompositeHook([loop_hook] + self._extra_hooks) - if self._extra_hooks - else loop_hook + CompositeHook([loop_hook] + self._extra_hooks) if self._extra_hooks else loop_hook ) async def _checkpoint(payload: dict[str, Any]) -> None: @@ -337,23 +345,25 @@ class AgentLoop: return self._set_runtime_checkpoint(session, payload) - result = await self.runner.run(AgentRunSpec( - initial_messages=initial_messages, - tools=self.tools, - model=self.model, - max_iterations=self.max_iterations, - max_tool_result_chars=self.max_tool_result_chars, - hook=hook, - error_message="Sorry, I encountered an error calling the AI model.", - concurrent_tools=True, - workspace=self.workspace, - session_key=session.key if session else None, - context_window_tokens=self.context_window_tokens, - context_block_limit=self.context_block_limit, - provider_retry_mode=self.provider_retry_mode, - progress_callback=on_progress, - checkpoint_callback=_checkpoint, - )) + result = await self.runner.run( + AgentRunSpec( + initial_messages=initial_messages, + tools=self.tools, + model=self.model, + max_iterations=self.max_iterations, + max_tool_result_chars=self.max_tool_result_chars, + hook=hook, + error_message="Sorry, I encountered an error calling the AI model.", + concurrent_tools=True, + workspace=self.workspace, + session_key=session.key if session else None, + context_window_tokens=self.context_window_tokens, + context_block_limit=self.context_block_limit, + provider_retry_mode=self.provider_retry_mode, + progress_callback=on_progress, + checkpoint_callback=_checkpoint, + ) + ) self._last_usage = result.usage if result.stop_reason == "max_iterations": logger.warning("Max iterations ({}) reached", self.max_iterations) @@ -391,10 +401,19 @@ class AgentLoop: continue # Compute the effective session key before dispatching # This ensures /stop command can find tasks correctly when unified session is enabled - effective_key = UNIFIED_SESSION_KEY if self._unified_session and not msg.session_key_override else msg.session_key + effective_key = ( + UNIFIED_SESSION_KEY + if self._unified_session and not msg.session_key_override + else msg.session_key + ) task = asyncio.create_task(self._dispatch(msg)) self._active_tasks.setdefault(effective_key, []).append(task) - task.add_done_callback(lambda t, k=effective_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None) + task.add_done_callback( + lambda t, k=effective_key: self._active_tasks.get(k, []) + and self._active_tasks[k].remove(t) + if t in self._active_tasks.get(k, []) + else None + ) async def _dispatch(self, msg: InboundMessage) -> None: """Process a message: per-session serial, cross-session concurrent.""" @@ -417,11 +436,14 @@ class AgentLoop: meta = dict(msg.metadata or {}) meta["_stream_delta"] = True meta["_stream_id"] = _current_stream_id() - await self.bus.publish_outbound(OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, - content=delta, - metadata=meta, - )) + await self.bus.publish_outbound( + OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content=delta, + metadata=meta, + ) + ) async def on_stream_end(*, resuming: bool = False) -> None: nonlocal stream_segment @@ -429,44 +451,56 @@ class AgentLoop: meta["_stream_end"] = True meta["_resuming"] = resuming meta["_stream_id"] = _current_stream_id() - await self.bus.publish_outbound(OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, - content="", - metadata=meta, - )) + await self.bus.publish_outbound( + OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content="", + metadata=meta, + ) + ) stream_segment += 1 response = await self._process_message( - msg, on_stream=on_stream, on_stream_end=on_stream_end, + msg, + on_stream=on_stream, + on_stream_end=on_stream_end, ) if response is not None: await self.bus.publish_outbound(response) elif msg.channel == "cli": - await self.bus.publish_outbound(OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, - content="", metadata=msg.metadata or {}, - )) + await self.bus.publish_outbound( + OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content="", + metadata=msg.metadata or {}, + ) + ) except asyncio.CancelledError: logger.info("Task cancelled for session {}", msg.session_key) raise except Exception: logger.exception("Error processing message for session {}", msg.session_key) - await self.bus.publish_outbound(OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, - content="Sorry, I encountered an error.", - )) + await self.bus.publish_outbound( + OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content="Sorry, I encountered an error.", + ) + ) async def close_mcp(self) -> None: """Drain pending background archives, then close MCP connections.""" if self._background_tasks: await asyncio.gather(*self._background_tasks, return_exceptions=True) self._background_tasks.clear() - if self._mcp_stack: + for name, stack in self._mcp_stacks.items(): try: - await self._mcp_stack.aclose() + await stack.aclose() except (RuntimeError, BaseExceptionGroup): - pass # MCP SDK cancel scope cleanup is noisy but harmless - self._mcp_stack = None + logger.debug("MCP server '{}' cleanup error (can be ignored)", name) + self._mcp_stacks.clear() def _schedule_background(self, coro) -> None: """Schedule a coroutine as a tracked background task (drained on shutdown).""" @@ -490,8 +524,9 @@ class AgentLoop: """Process a single inbound message and return the response.""" # System messages: parse origin from chat_id ("channel:chat_id") if msg.channel == "system": - channel, chat_id = (msg.chat_id.split(":", 1) if ":" in msg.chat_id - else ("cli", msg.chat_id)) + channel, chat_id = ( + msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id) + ) logger.info("Processing system message from {}", msg.sender_id) key = f"{channel}:{chat_id}" session = self.sessions.get_or_create(key) @@ -503,19 +538,27 @@ class AgentLoop: current_role = "assistant" if msg.sender_id == "subagent" else "user" messages = self.context.build_messages( history=history, - current_message=msg.content, channel=channel, chat_id=chat_id, + current_message=msg.content, + channel=channel, + chat_id=chat_id, current_role=current_role, ) final_content, _, all_msgs, _ = await self._run_agent_loop( - messages, session=session, channel=channel, chat_id=chat_id, + messages, + session=session, + channel=channel, + chat_id=chat_id, message_id=msg.metadata.get("message_id"), ) self._save_turn(session, all_msgs, 1 + len(history)) self._clear_runtime_checkpoint(session) self.sessions.save(session) self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session)) - return OutboundMessage(channel=channel, chat_id=chat_id, - content=final_content or "Background task completed.") + return OutboundMessage( + channel=channel, + chat_id=chat_id, + content=final_content or "Background task completed.", + ) preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview) @@ -543,16 +586,22 @@ class AgentLoop: history=history, current_message=msg.content, media=msg.media if msg.media else None, - channel=msg.channel, chat_id=msg.chat_id, + channel=msg.channel, + chat_id=msg.chat_id, ) async def _bus_progress(content: str, *, tool_hint: bool = False) -> None: meta = dict(msg.metadata or {}) meta["_progress"] = True meta["_tool_hint"] = tool_hint - await self.bus.publish_outbound(OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta, - )) + await self.bus.publish_outbound( + OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content=content, + metadata=meta, + ) + ) final_content, _, all_msgs, stop_reason = await self._run_agent_loop( initial_messages, @@ -560,7 +609,8 @@ class AgentLoop: on_stream=on_stream, on_stream_end=on_stream_end, session=session, - channel=msg.channel, chat_id=msg.chat_id, + channel=msg.channel, + chat_id=msg.chat_id, message_id=msg.metadata.get("message_id"), ) @@ -582,7 +632,9 @@ class AgentLoop: if on_stream is not None and stop_reason != "error": meta["_streamed"] = True return OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, content=final_content, + channel=msg.channel, + chat_id=msg.chat_id, + content=final_content, metadata=meta, ) @@ -608,10 +660,9 @@ class AgentLoop: ): continue - if ( - block.get("type") == "image_url" - and block.get("image_url", {}).get("url", "").startswith("data:image/") - ): + if block.get("type") == "image_url" and block.get("image_url", {}).get( + "url", "" + ).startswith("data:image/"): path = (block.get("_meta") or {}).get("path", "") filtered.append({"type": "text", "text": image_placeholder_text(path)}) continue @@ -630,6 +681,7 @@ class AgentLoop: def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None: """Save new-turn messages into session, truncating large tool results.""" from datetime import datetime + for m in messages[skip:]: entry = dict(m) role, content = entry.get("role"), entry.get("content") @@ -644,7 +696,9 @@ class AgentLoop: continue entry["content"] = filtered elif role == "user": - if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG): + if isinstance(content, str) and content.startswith( + ContextBuilder._RUNTIME_CONTEXT_TAG + ): # Strip the runtime-context prefix, keep only the user text. parts = content.split("\n\n", 1) if len(parts) > 1 and parts[1].strip(): @@ -708,13 +762,15 @@ class AgentLoop: continue tool_id = tool_call.get("id") name = ((tool_call.get("function") or {}).get("name")) or "tool" - restored_messages.append({ - "role": "tool", - "tool_call_id": tool_id, - "name": name, - "content": "Error: Task interrupted before this tool finished.", - "timestamp": datetime.now().isoformat(), - }) + restored_messages.append( + { + "role": "tool", + "tool_call_id": tool_id, + "name": name, + "content": "Error: Task interrupted before this tool finished.", + "timestamp": datetime.now().isoformat(), + } + ) overlap = 0 max_overlap = min(len(session.messages), len(restored_messages)) @@ -746,6 +802,9 @@ class AgentLoop: await self._connect_mcp() msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content) return await self._process_message( - msg, session_key=session_key, on_progress=on_progress, - on_stream=on_stream, on_stream_end=on_stream_end, + msg, + session_key=session_key, + on_progress=on_progress, + on_stream=on_stream, + on_stream_end=on_stream_end, ) diff --git a/nanobot/agent/tools/mcp.py b/nanobot/agent/tools/mcp.py index 86df9d744..1b5a71322 100644 --- a/nanobot/agent/tools/mcp.py +++ b/nanobot/agent/tools/mcp.py @@ -57,9 +57,7 @@ def _normalize_schema_for_openai(schema: Any) -> dict[str, Any]: if "properties" in normalized and isinstance(normalized["properties"], dict): normalized["properties"] = { - name: _normalize_schema_for_openai(prop) - if isinstance(prop, dict) - else prop + name: _normalize_schema_for_openai(prop) if isinstance(prop, dict) else prop for name, prop in normalized["properties"].items() } @@ -138,9 +136,7 @@ class MCPToolWrapper(Tool): class MCPResourceWrapper(Tool): """Wraps an MCP resource URI as a read-only nanobot Tool.""" - def __init__( - self, session, server_name: str, resource_def, resource_timeout: int = 30 - ): + def __init__(self, session, server_name: str, resource_def, resource_timeout: int = 30): self._session = session self._uri = resource_def.uri self._name = f"mcp_{server_name}_resource_{resource_def.name}" @@ -211,9 +207,7 @@ class MCPResourceWrapper(Tool): class MCPPromptWrapper(Tool): """Wraps an MCP prompt as a read-only nanobot Tool.""" - def __init__( - self, session, server_name: str, prompt_def, prompt_timeout: int = 30 - ): + def __init__(self, session, server_name: str, prompt_def, prompt_timeout: int = 30): self._session = session self._prompt_name = prompt_def.name self._name = f"mcp_{server_name}_prompt_{prompt_def.name}" @@ -266,9 +260,7 @@ class MCPPromptWrapper(Tool): timeout=self._prompt_timeout, ) except asyncio.TimeoutError: - logger.warning( - "MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout - ) + logger.warning("MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout) return f"(MCP prompt call timed out after {self._prompt_timeout}s)" except asyncio.CancelledError: task = asyncio.current_task() @@ -279,13 +271,17 @@ class MCPPromptWrapper(Tool): except McpError as exc: logger.error( "MCP prompt '{}' failed: code={} message={}", - self._name, exc.error.code, exc.error.message, + self._name, + exc.error.code, + exc.error.message, ) return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])" except Exception as exc: logger.exception( "MCP prompt '{}' failed: {}: {}", - self._name, type(exc).__name__, exc, + self._name, + type(exc).__name__, + exc, ) return f"(MCP prompt call failed: {type(exc).__name__})" @@ -307,35 +303,44 @@ class MCPPromptWrapper(Tool): async def connect_mcp_servers( - mcp_servers: dict, registry: ToolRegistry, stack: AsyncExitStack -) -> None: - """Connect to configured MCP servers and register their tools, resources, and prompts.""" + mcp_servers: dict, registry: ToolRegistry +) -> dict[str, AsyncExitStack]: + """Connect to configured MCP servers and register their tools, resources, prompts. + + Returns a dict mapping server name -> its dedicated AsyncExitStack. + Each server gets its own stack and runs in its own task to prevent + cancel scope conflicts when multiple MCP servers are configured. + """ from mcp import ClientSession, StdioServerParameters from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client from mcp.client.streamable_http import streamable_http_client - for name, cfg in mcp_servers.items(): + async def connect_single_server(name: str, cfg) -> tuple[str, AsyncExitStack | None]: + server_stack = AsyncExitStack() + await server_stack.__aenter__() + try: transport_type = cfg.type if not transport_type: if cfg.command: transport_type = "stdio" elif cfg.url: - # Convention: URLs ending with /sse use SSE transport; others use streamableHttp transport_type = ( "sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp" ) else: logger.warning("MCP server '{}': no command or url configured, skipping", name) - continue + await server_stack.aclose() + return name, None if transport_type == "stdio": params = StdioServerParameters( command=cfg.command, args=cfg.args, env=cfg.env or None ) - read, write = await stack.enter_async_context(stdio_client(params)) + read, write = await server_stack.enter_async_context(stdio_client(params)) elif transport_type == "sse": + def httpx_client_factory( headers: dict[str, str] | None = None, timeout: httpx.Timeout | None = None, @@ -353,27 +358,26 @@ async def connect_mcp_servers( auth=auth, ) - read, write = await stack.enter_async_context( + read, write = await server_stack.enter_async_context( sse_client(cfg.url, httpx_client_factory=httpx_client_factory) ) elif transport_type == "streamableHttp": - # Always provide an explicit httpx client so MCP HTTP transport does not - # inherit httpx's default 5s timeout and preempt the higher-level tool timeout. - http_client = await stack.enter_async_context( + http_client = await server_stack.enter_async_context( httpx.AsyncClient( headers=cfg.headers or None, follow_redirects=True, timeout=None, ) ) - read, write, _ = await stack.enter_async_context( + read, write, _ = await server_stack.enter_async_context( streamable_http_client(cfg.url, http_client=http_client) ) else: logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type) - continue + await server_stack.aclose() + return name, None - session = await stack.enter_async_context(ClientSession(read, write)) + session = await server_stack.enter_async_context(ClientSession(read, write)) await session.initialize() tools = await session.list_tools() @@ -418,7 +422,6 @@ async def connect_mcp_servers( ", ".join(available_wrapped_names) or "(none)", ) - # --- Register resources --- try: resources_result = await session.list_resources() for resource in resources_result.resources: @@ -433,7 +436,6 @@ async def connect_mcp_servers( except Exception as e: logger.debug("MCP server '{}': resources not supported or failed: {}", name, e) - # --- Register prompts --- try: prompts_result = await session.list_prompts() for prompt in prompts_result.prompts: @@ -442,14 +444,38 @@ async def connect_mcp_servers( ) registry.register(wrapper) registered_count += 1 - logger.debug( - "MCP: registered prompt '{}' from server '{}'", wrapper.name, name - ) + logger.debug("MCP: registered prompt '{}' from server '{}'", wrapper.name, name) except Exception as e: logger.debug("MCP server '{}': prompts not supported or failed: {}", name, e) logger.info( "MCP server '{}': connected, {} capabilities registered", name, registered_count ) + return name, server_stack + except Exception as e: logger.error("MCP server '{}': failed to connect: {}", name, e) + try: + await server_stack.aclose() + except Exception: + pass + return name, None + + server_stacks: dict[str, AsyncExitStack] = {} + + tasks: list[asyncio.Task] = [] + for name, cfg in mcp_servers.items(): + task = asyncio.create_task(connect_single_server(name, cfg)) + tasks.append(task) + + results = await asyncio.gather(*tasks, return_exceptions=True) + + for i, result in enumerate(results): + name = list(mcp_servers.keys())[i] + if isinstance(result, BaseException): + if not isinstance(result, asyncio.CancelledError): + logger.error("MCP server '{}' connection task failed: {}", name, result) + elif result is not None and result[1] is not None: + server_stacks[result[0]] = result[1] + + return server_stacks diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index 67382d9ea..adeb78e75 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -271,15 +271,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_raw_names( ) -> None: fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"]) registry = ToolRegistry() - stack = AsyncExitStack() - await stack.__aenter__() - try: - await connect_mcp_servers( - {"test": MCPServerConfig(command="fake", enabled_tools=["demo"])}, - registry, - stack, - ) - finally: + stacks = await connect_mcp_servers( + {"test": MCPServerConfig(command="fake", enabled_tools=["demo"])}, + registry, + ) + for stack in stacks.values(): await stack.aclose() assert registry.tool_names == ["mcp_test_demo"] @@ -291,15 +287,11 @@ async def test_connect_mcp_servers_enabled_tools_defaults_to_all( ) -> None: fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"]) registry = ToolRegistry() - stack = AsyncExitStack() - await stack.__aenter__() - try: - await connect_mcp_servers( - {"test": MCPServerConfig(command="fake")}, - registry, - stack, - ) - finally: + stacks = await connect_mcp_servers( + {"test": MCPServerConfig(command="fake")}, + registry, + ) + for stack in stacks.values(): await stack.aclose() assert registry.tool_names == ["mcp_test_demo", "mcp_test_other"] @@ -311,15 +303,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_wrapped_names( ) -> None: fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"]) registry = ToolRegistry() - stack = AsyncExitStack() - await stack.__aenter__() - try: - await connect_mcp_servers( - {"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])}, - registry, - stack, - ) - finally: + stacks = await connect_mcp_servers( + {"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])}, + registry, + ) + for stack in stacks.values(): await stack.aclose() assert registry.tool_names == ["mcp_test_demo"] @@ -331,15 +319,11 @@ async def test_connect_mcp_servers_enabled_tools_empty_list_registers_none( ) -> None: fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"]) registry = ToolRegistry() - stack = AsyncExitStack() - await stack.__aenter__() - try: - await connect_mcp_servers( - {"test": MCPServerConfig(command="fake", enabled_tools=[])}, - registry, - stack, - ) - finally: + stacks = await connect_mcp_servers( + {"test": MCPServerConfig(command="fake", enabled_tools=[])}, + registry, + ) + for stack in stacks.values(): await stack.aclose() assert registry.tool_names == [] @@ -358,15 +342,11 @@ async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries( monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning) - stack = AsyncExitStack() - await stack.__aenter__() - try: - await connect_mcp_servers( - {"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])}, - registry, - stack, - ) - finally: + stacks = await connect_mcp_servers( + {"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])}, + registry, + ) + for stack in stacks.values(): await stack.aclose() assert registry.tool_names == [] @@ -389,9 +369,7 @@ def _make_resource_def( return SimpleNamespace(name=name, uri=uri, description=description) -def _make_resource_wrapper( - session: object, *, timeout: float = 0.1 -) -> MCPResourceWrapper: +def _make_resource_wrapper(session: object, *, timeout: float = 0.1) -> MCPResourceWrapper: return MCPResourceWrapper(session, "srv", _make_resource_def(), resource_timeout=timeout) @@ -434,9 +412,7 @@ async def test_resource_wrapper_execute_handles_timeout() -> None: await asyncio.sleep(1) return SimpleNamespace(contents=[]) - wrapper = _make_resource_wrapper( - SimpleNamespace(read_resource=read_resource), timeout=0.01 - ) + wrapper = _make_resource_wrapper(SimpleNamespace(read_resource=read_resource), timeout=0.01) result = await wrapper.execute() assert result == "(MCP resource read timed out after 0.01s)" @@ -464,20 +440,14 @@ def _make_prompt_def( return SimpleNamespace(name=name, description=description, arguments=arguments) -def _make_prompt_wrapper( - session: object, *, timeout: float = 0.1 -) -> MCPPromptWrapper: - return MCPPromptWrapper( - session, "srv", _make_prompt_def(), prompt_timeout=timeout - ) +def _make_prompt_wrapper(session: object, *, timeout: float = 0.1) -> MCPPromptWrapper: + return MCPPromptWrapper(session, "srv", _make_prompt_def(), prompt_timeout=timeout) def test_prompt_wrapper_properties() -> None: arg1 = SimpleNamespace(name="topic", required=True) arg2 = SimpleNamespace(name="style", required=False) - wrapper = MCPPromptWrapper( - None, "myserver", _make_prompt_def(arguments=[arg1, arg2]) - ) + wrapper = MCPPromptWrapper(None, "myserver", _make_prompt_def(arguments=[arg1, arg2])) assert wrapper.name == "mcp_myserver_prompt_myprompt" assert "[MCP Prompt]" in wrapper.description assert "A test prompt" in wrapper.description @@ -528,9 +498,7 @@ async def test_prompt_wrapper_execute_handles_timeout() -> None: await asyncio.sleep(1) return SimpleNamespace(messages=[]) - wrapper = _make_prompt_wrapper( - SimpleNamespace(get_prompt=get_prompt), timeout=0.01 - ) + wrapper = _make_prompt_wrapper(SimpleNamespace(get_prompt=get_prompt), timeout=0.01) result = await wrapper.execute() assert result == "(MCP prompt call timed out after 0.01s)" @@ -616,15 +584,11 @@ async def test_connect_registers_resources_and_prompts( prompt_names=["prompt_c"], ) registry = ToolRegistry() - stack = AsyncExitStack() - await stack.__aenter__() - try: - await connect_mcp_servers( - {"test": MCPServerConfig(command="fake")}, - registry, - stack, - ) - finally: + stacks = await connect_mcp_servers( + {"test": MCPServerConfig(command="fake")}, + registry, + ) + for stack in stacks.values(): await stack.aclose() assert "mcp_test_tool_a" in registry.tool_names