diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 05a27349f..56a662add 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -44,6 +44,7 @@ if TYPE_CHECKING: UNIFIED_SESSION_KEY = "unified:default" + class _LoopHook(AgentHook): """Core hook for the main loop.""" @@ -77,7 +78,7 @@ class _LoopHook(AgentHook): prev_clean = strip_think(self._stream_buf) self._stream_buf += delta new_clean = strip_think(self._stream_buf) - incremental = new_clean[len(prev_clean):] + incremental = new_clean[len(prev_clean) :] if incremental and self._on_stream: await self._on_stream(incremental) @@ -113,6 +114,7 @@ class _LoopHook(AgentHook): def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: return self._loop._strip_think(content) + class AgentLoop: """ The agent loop is the core processing engine. @@ -198,7 +200,7 @@ class AgentLoop: self._unified_session = unified_session self._running = False self._mcp_servers = mcp_servers or {} - self._mcp_stack: AsyncExitStack | None = None + self._mcp_stacks: dict[str, AsyncExitStack] = {} self._mcp_connected = False self._mcp_connecting = False self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks @@ -235,24 +237,34 @@ class AgentLoop: def _register_default_tools(self) -> None: """Register the default set of tools.""" - allowed_dir = self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None + allowed_dir = ( + self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None + ) extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None - self.tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read)) + self.tools.register( + ReadFileTool( + workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read + ) + ) for cls in (WriteFileTool, EditFileTool, ListDirTool): self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir)) for cls in (GlobTool, GrepTool): self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir)) if self.exec_config.enable: - self.tools.register(ExecTool( - working_dir=str(self.workspace), - timeout=self.exec_config.timeout, - restrict_to_workspace=self.restrict_to_workspace, - sandbox=self.exec_config.sandbox, - path_append=self.exec_config.path_append, - allowed_env_keys=self.exec_config.allowed_env_keys, - )) + self.tools.register( + ExecTool( + working_dir=str(self.workspace), + timeout=self.exec_config.timeout, + restrict_to_workspace=self.restrict_to_workspace, + sandbox=self.exec_config.sandbox, + path_append=self.exec_config.path_append, + allowed_env_keys=self.exec_config.allowed_env_keys, + ) + ) if self.web_config.enable: - self.tools.register(WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy)) + self.tools.register( + WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy) + ) self.tools.register(WebFetchTool(proxy=self.web_config.proxy)) self.tools.register(MessageTool(send_callback=self.bus.publish_outbound)) self.tools.register(SpawnTool(manager=self.subagents)) @@ -267,19 +279,19 @@ class AgentLoop: return self._mcp_connecting = True from nanobot.agent.tools.mcp import connect_mcp_servers + try: - self._mcp_stack = AsyncExitStack() - await self._mcp_stack.__aenter__() - await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack) - self._mcp_connected = True + self._mcp_stacks = await connect_mcp_servers(self._mcp_servers, self.tools) + if self._mcp_stacks: + self._mcp_connected = True + else: + logger.warning("No MCP servers connected successfully (will retry next message)") + except asyncio.CancelledError: + logger.warning("MCP connection cancelled (will retry next message)") + self._mcp_stacks.clear() except BaseException as e: logger.error("Failed to connect MCP servers (will retry next message): {}", e) - if self._mcp_stack: - try: - await self._mcp_stack.aclose() - except Exception: - pass - self._mcp_stack = None + self._mcp_stacks.clear() finally: self._mcp_connecting = False @@ -296,6 +308,7 @@ class AgentLoop: if not text: return None from nanobot.utils.helpers import strip_think + return strip_think(text) or None @staticmethod @@ -334,9 +347,7 @@ class AgentLoop: message_id=message_id, ) hook: AgentHook = ( - CompositeHook([loop_hook] + self._extra_hooks) - if self._extra_hooks - else loop_hook + CompositeHook([loop_hook] + self._extra_hooks) if self._extra_hooks else loop_hook ) async def _checkpoint(payload: dict[str, Any]) -> None: @@ -344,23 +355,25 @@ class AgentLoop: return self._set_runtime_checkpoint(session, payload) - result = await self.runner.run(AgentRunSpec( - initial_messages=initial_messages, - tools=self.tools, - model=self.model, - max_iterations=self.max_iterations, - max_tool_result_chars=self.max_tool_result_chars, - hook=hook, - error_message="Sorry, I encountered an error calling the AI model.", - concurrent_tools=True, - workspace=self.workspace, - session_key=session.key if session else None, - context_window_tokens=self.context_window_tokens, - context_block_limit=self.context_block_limit, - provider_retry_mode=self.provider_retry_mode, - progress_callback=on_progress, - checkpoint_callback=_checkpoint, - )) + result = await self.runner.run( + AgentRunSpec( + initial_messages=initial_messages, + tools=self.tools, + model=self.model, + max_iterations=self.max_iterations, + max_tool_result_chars=self.max_tool_result_chars, + hook=hook, + error_message="Sorry, I encountered an error calling the AI model.", + concurrent_tools=True, + workspace=self.workspace, + session_key=session.key if session else None, + context_window_tokens=self.context_window_tokens, + context_block_limit=self.context_block_limit, + provider_retry_mode=self.provider_retry_mode, + progress_callback=on_progress, + checkpoint_callback=_checkpoint, + ) + ) self._last_usage = result.usage if result.stop_reason == "max_iterations": logger.warning("Max iterations ({}) reached", self.max_iterations) @@ -399,10 +412,19 @@ class AgentLoop: continue # Compute the effective session key before dispatching # This ensures /stop command can find tasks correctly when unified session is enabled - effective_key = UNIFIED_SESSION_KEY if self._unified_session and not msg.session_key_override else msg.session_key + effective_key = ( + UNIFIED_SESSION_KEY + if self._unified_session and not msg.session_key_override + else msg.session_key + ) task = asyncio.create_task(self._dispatch(msg)) self._active_tasks.setdefault(effective_key, []).append(task) - task.add_done_callback(lambda t, k=effective_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None) + task.add_done_callback( + lambda t, k=effective_key: self._active_tasks.get(k, []) + and self._active_tasks[k].remove(t) + if t in self._active_tasks.get(k, []) + else None + ) async def _dispatch(self, msg: InboundMessage) -> None: """Process a message: per-session serial, cross-session concurrent.""" @@ -425,11 +447,14 @@ class AgentLoop: meta = dict(msg.metadata or {}) meta["_stream_delta"] = True meta["_stream_id"] = _current_stream_id() - await self.bus.publish_outbound(OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, - content=delta, - metadata=meta, - )) + await self.bus.publish_outbound( + OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content=delta, + metadata=meta, + ) + ) async def on_stream_end(*, resuming: bool = False) -> None: nonlocal stream_segment @@ -437,44 +462,56 @@ class AgentLoop: meta["_stream_end"] = True meta["_resuming"] = resuming meta["_stream_id"] = _current_stream_id() - await self.bus.publish_outbound(OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, - content="", - metadata=meta, - )) + await self.bus.publish_outbound( + OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content="", + metadata=meta, + ) + ) stream_segment += 1 response = await self._process_message( - msg, on_stream=on_stream, on_stream_end=on_stream_end, + msg, + on_stream=on_stream, + on_stream_end=on_stream_end, ) if response is not None: await self.bus.publish_outbound(response) elif msg.channel == "cli": - await self.bus.publish_outbound(OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, - content="", metadata=msg.metadata or {}, - )) + await self.bus.publish_outbound( + OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content="", + metadata=msg.metadata or {}, + ) + ) except asyncio.CancelledError: logger.info("Task cancelled for session {}", msg.session_key) raise except Exception: logger.exception("Error processing message for session {}", msg.session_key) - await self.bus.publish_outbound(OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, - content="Sorry, I encountered an error.", - )) + await self.bus.publish_outbound( + OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content="Sorry, I encountered an error.", + ) + ) async def close_mcp(self) -> None: """Drain pending background archives, then close MCP connections.""" if self._background_tasks: await asyncio.gather(*self._background_tasks, return_exceptions=True) self._background_tasks.clear() - if self._mcp_stack: + for name, stack in self._mcp_stacks.items(): try: - await self._mcp_stack.aclose() + await stack.aclose() except (RuntimeError, BaseExceptionGroup): - pass # MCP SDK cancel scope cleanup is noisy but harmless - self._mcp_stack = None + logger.debug("MCP server '{}' cleanup error (can be ignored)", name) + self._mcp_stacks.clear() def _schedule_background(self, coro) -> None: """Schedule a coroutine as a tracked background task (drained on shutdown).""" @@ -498,8 +535,9 @@ class AgentLoop: """Process a single inbound message and return the response.""" # System messages: parse origin from chat_id ("channel:chat_id") if msg.channel == "system": - channel, chat_id = (msg.chat_id.split(":", 1) if ":" in msg.chat_id - else ("cli", msg.chat_id)) + channel, chat_id = ( + msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id) + ) logger.info("Processing system message from {}", msg.sender_id) key = f"{channel}:{chat_id}" session = self.sessions.get_or_create(key) @@ -520,15 +558,21 @@ class AgentLoop: current_role=current_role, ) final_content, _, all_msgs, _ = await self._run_agent_loop( - messages, session=session, channel=channel, chat_id=chat_id, + messages, + session=session, + channel=channel, + chat_id=chat_id, message_id=msg.metadata.get("message_id"), ) self._save_turn(session, all_msgs, 1 + len(history)) self._clear_runtime_checkpoint(session) self.sessions.save(session) self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session)) - return OutboundMessage(channel=channel, chat_id=chat_id, - content=final_content or "Background task completed.") + return OutboundMessage( + channel=channel, + chat_id=chat_id, + content=final_content or "Background task completed.", + ) preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview) @@ -560,16 +604,22 @@ class AgentLoop: current_message=msg.content, session_summary=pending, media=msg.media if msg.media else None, - channel=msg.channel, chat_id=msg.chat_id, + channel=msg.channel, + chat_id=msg.chat_id, ) async def _bus_progress(content: str, *, tool_hint: bool = False) -> None: meta = dict(msg.metadata or {}) meta["_progress"] = True meta["_tool_hint"] = tool_hint - await self.bus.publish_outbound(OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta, - )) + await self.bus.publish_outbound( + OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content=content, + metadata=meta, + ) + ) final_content, _, all_msgs, stop_reason = await self._run_agent_loop( initial_messages, @@ -577,7 +627,8 @@ class AgentLoop: on_stream=on_stream, on_stream_end=on_stream_end, session=session, - channel=msg.channel, chat_id=msg.chat_id, + channel=msg.channel, + chat_id=msg.chat_id, message_id=msg.metadata.get("message_id"), ) @@ -599,7 +650,9 @@ class AgentLoop: if on_stream is not None and stop_reason != "error": meta["_streamed"] = True return OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, content=final_content, + channel=msg.channel, + chat_id=msg.chat_id, + content=final_content, metadata=meta, ) @@ -625,10 +678,9 @@ class AgentLoop: ): continue - if ( - block.get("type") == "image_url" - and block.get("image_url", {}).get("url", "").startswith("data:image/") - ): + if block.get("type") == "image_url" and block.get("image_url", {}).get( + "url", "" + ).startswith("data:image/"): path = (block.get("_meta") or {}).get("path", "") filtered.append({"type": "text", "text": image_placeholder_text(path)}) continue @@ -647,6 +699,7 @@ class AgentLoop: def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None: """Save new-turn messages into session, truncating large tool results.""" from datetime import datetime + for m in messages[skip:]: entry = dict(m) role, content = entry.get("role"), entry.get("content") @@ -736,13 +789,15 @@ class AgentLoop: continue tool_id = tool_call.get("id") name = ((tool_call.get("function") or {}).get("name")) or "tool" - restored_messages.append({ - "role": "tool", - "tool_call_id": tool_id, - "name": name, - "content": "Error: Task interrupted before this tool finished.", - "timestamp": datetime.now().isoformat(), - }) + restored_messages.append( + { + "role": "tool", + "tool_call_id": tool_id, + "name": name, + "content": "Error: Task interrupted before this tool finished.", + "timestamp": datetime.now().isoformat(), + } + ) overlap = 0 max_overlap = min(len(session.messages), len(restored_messages)) @@ -774,6 +829,9 @@ class AgentLoop: await self._connect_mcp() msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content) return await self._process_message( - msg, session_key=session_key, on_progress=on_progress, - on_stream=on_stream, on_stream_end=on_stream_end, + msg, + session_key=session_key, + on_progress=on_progress, + on_stream=on_stream, + on_stream_end=on_stream_end, ) diff --git a/nanobot/agent/tools/mcp.py b/nanobot/agent/tools/mcp.py index 86df9d744..1b5a71322 100644 --- a/nanobot/agent/tools/mcp.py +++ b/nanobot/agent/tools/mcp.py @@ -57,9 +57,7 @@ def _normalize_schema_for_openai(schema: Any) -> dict[str, Any]: if "properties" in normalized and isinstance(normalized["properties"], dict): normalized["properties"] = { - name: _normalize_schema_for_openai(prop) - if isinstance(prop, dict) - else prop + name: _normalize_schema_for_openai(prop) if isinstance(prop, dict) else prop for name, prop in normalized["properties"].items() } @@ -138,9 +136,7 @@ class MCPToolWrapper(Tool): class MCPResourceWrapper(Tool): """Wraps an MCP resource URI as a read-only nanobot Tool.""" - def __init__( - self, session, server_name: str, resource_def, resource_timeout: int = 30 - ): + def __init__(self, session, server_name: str, resource_def, resource_timeout: int = 30): self._session = session self._uri = resource_def.uri self._name = f"mcp_{server_name}_resource_{resource_def.name}" @@ -211,9 +207,7 @@ class MCPResourceWrapper(Tool): class MCPPromptWrapper(Tool): """Wraps an MCP prompt as a read-only nanobot Tool.""" - def __init__( - self, session, server_name: str, prompt_def, prompt_timeout: int = 30 - ): + def __init__(self, session, server_name: str, prompt_def, prompt_timeout: int = 30): self._session = session self._prompt_name = prompt_def.name self._name = f"mcp_{server_name}_prompt_{prompt_def.name}" @@ -266,9 +260,7 @@ class MCPPromptWrapper(Tool): timeout=self._prompt_timeout, ) except asyncio.TimeoutError: - logger.warning( - "MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout - ) + logger.warning("MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout) return f"(MCP prompt call timed out after {self._prompt_timeout}s)" except asyncio.CancelledError: task = asyncio.current_task() @@ -279,13 +271,17 @@ class MCPPromptWrapper(Tool): except McpError as exc: logger.error( "MCP prompt '{}' failed: code={} message={}", - self._name, exc.error.code, exc.error.message, + self._name, + exc.error.code, + exc.error.message, ) return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])" except Exception as exc: logger.exception( "MCP prompt '{}' failed: {}: {}", - self._name, type(exc).__name__, exc, + self._name, + type(exc).__name__, + exc, ) return f"(MCP prompt call failed: {type(exc).__name__})" @@ -307,35 +303,44 @@ class MCPPromptWrapper(Tool): async def connect_mcp_servers( - mcp_servers: dict, registry: ToolRegistry, stack: AsyncExitStack -) -> None: - """Connect to configured MCP servers and register their tools, resources, and prompts.""" + mcp_servers: dict, registry: ToolRegistry +) -> dict[str, AsyncExitStack]: + """Connect to configured MCP servers and register their tools, resources, prompts. + + Returns a dict mapping server name -> its dedicated AsyncExitStack. + Each server gets its own stack and runs in its own task to prevent + cancel scope conflicts when multiple MCP servers are configured. + """ from mcp import ClientSession, StdioServerParameters from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client from mcp.client.streamable_http import streamable_http_client - for name, cfg in mcp_servers.items(): + async def connect_single_server(name: str, cfg) -> tuple[str, AsyncExitStack | None]: + server_stack = AsyncExitStack() + await server_stack.__aenter__() + try: transport_type = cfg.type if not transport_type: if cfg.command: transport_type = "stdio" elif cfg.url: - # Convention: URLs ending with /sse use SSE transport; others use streamableHttp transport_type = ( "sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp" ) else: logger.warning("MCP server '{}': no command or url configured, skipping", name) - continue + await server_stack.aclose() + return name, None if transport_type == "stdio": params = StdioServerParameters( command=cfg.command, args=cfg.args, env=cfg.env or None ) - read, write = await stack.enter_async_context(stdio_client(params)) + read, write = await server_stack.enter_async_context(stdio_client(params)) elif transport_type == "sse": + def httpx_client_factory( headers: dict[str, str] | None = None, timeout: httpx.Timeout | None = None, @@ -353,27 +358,26 @@ async def connect_mcp_servers( auth=auth, ) - read, write = await stack.enter_async_context( + read, write = await server_stack.enter_async_context( sse_client(cfg.url, httpx_client_factory=httpx_client_factory) ) elif transport_type == "streamableHttp": - # Always provide an explicit httpx client so MCP HTTP transport does not - # inherit httpx's default 5s timeout and preempt the higher-level tool timeout. - http_client = await stack.enter_async_context( + http_client = await server_stack.enter_async_context( httpx.AsyncClient( headers=cfg.headers or None, follow_redirects=True, timeout=None, ) ) - read, write, _ = await stack.enter_async_context( + read, write, _ = await server_stack.enter_async_context( streamable_http_client(cfg.url, http_client=http_client) ) else: logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type) - continue + await server_stack.aclose() + return name, None - session = await stack.enter_async_context(ClientSession(read, write)) + session = await server_stack.enter_async_context(ClientSession(read, write)) await session.initialize() tools = await session.list_tools() @@ -418,7 +422,6 @@ async def connect_mcp_servers( ", ".join(available_wrapped_names) or "(none)", ) - # --- Register resources --- try: resources_result = await session.list_resources() for resource in resources_result.resources: @@ -433,7 +436,6 @@ async def connect_mcp_servers( except Exception as e: logger.debug("MCP server '{}': resources not supported or failed: {}", name, e) - # --- Register prompts --- try: prompts_result = await session.list_prompts() for prompt in prompts_result.prompts: @@ -442,14 +444,38 @@ async def connect_mcp_servers( ) registry.register(wrapper) registered_count += 1 - logger.debug( - "MCP: registered prompt '{}' from server '{}'", wrapper.name, name - ) + logger.debug("MCP: registered prompt '{}' from server '{}'", wrapper.name, name) except Exception as e: logger.debug("MCP server '{}': prompts not supported or failed: {}", name, e) logger.info( "MCP server '{}': connected, {} capabilities registered", name, registered_count ) + return name, server_stack + except Exception as e: logger.error("MCP server '{}': failed to connect: {}", name, e) + try: + await server_stack.aclose() + except Exception: + pass + return name, None + + server_stacks: dict[str, AsyncExitStack] = {} + + tasks: list[asyncio.Task] = [] + for name, cfg in mcp_servers.items(): + task = asyncio.create_task(connect_single_server(name, cfg)) + tasks.append(task) + + results = await asyncio.gather(*tasks, return_exceptions=True) + + for i, result in enumerate(results): + name = list(mcp_servers.keys())[i] + if isinstance(result, BaseException): + if not isinstance(result, asyncio.CancelledError): + logger.error("MCP server '{}' connection task failed: {}", name, result) + elif result is not None and result[1] is not None: + server_stacks[result[0]] = result[1] + + return server_stacks diff --git a/tests/agent/test_mcp_connection.py b/tests/agent/test_mcp_connection.py new file mode 100644 index 000000000..e7d0a7854 --- /dev/null +++ b/tests/agent/test_mcp_connection.py @@ -0,0 +1,44 @@ +"""Tests for MCP connection lifecycle in AgentLoop.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from nanobot.agent.loop import AgentLoop +from nanobot.bus.queue import MessageBus + + +def _make_loop(tmp_path, *, mcp_servers: dict | None = None) -> AgentLoop: + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + provider.generation.max_tokens = 4096 + return AgentLoop( + bus=bus, + provider=provider, + workspace=tmp_path, + model="test-model", + mcp_servers=mcp_servers or {"test": object()}, + ) + + +@pytest.mark.asyncio +async def test_connect_mcp_retries_when_no_servers_connect(tmp_path, monkeypatch: pytest.MonkeyPatch): + loop = _make_loop(tmp_path) + attempts = 0 + + async def _fake_connect(_servers, _registry): + nonlocal attempts + attempts += 1 + return {} + + monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _fake_connect) + + await loop._connect_mcp() + await loop._connect_mcp() + + assert attempts == 2 + assert loop._mcp_connected is False + assert loop._mcp_stacks == {} diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index 67382d9ea..da90c4d0d 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -271,15 +271,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_raw_names( ) -> None: fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"]) registry = ToolRegistry() - stack = AsyncExitStack() - await stack.__aenter__() - try: - await connect_mcp_servers( - {"test": MCPServerConfig(command="fake", enabled_tools=["demo"])}, - registry, - stack, - ) - finally: + stacks = await connect_mcp_servers( + {"test": MCPServerConfig(command="fake", enabled_tools=["demo"])}, + registry, + ) + for stack in stacks.values(): await stack.aclose() assert registry.tool_names == ["mcp_test_demo"] @@ -291,15 +287,11 @@ async def test_connect_mcp_servers_enabled_tools_defaults_to_all( ) -> None: fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"]) registry = ToolRegistry() - stack = AsyncExitStack() - await stack.__aenter__() - try: - await connect_mcp_servers( - {"test": MCPServerConfig(command="fake")}, - registry, - stack, - ) - finally: + stacks = await connect_mcp_servers( + {"test": MCPServerConfig(command="fake")}, + registry, + ) + for stack in stacks.values(): await stack.aclose() assert registry.tool_names == ["mcp_test_demo", "mcp_test_other"] @@ -311,15 +303,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_wrapped_names( ) -> None: fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"]) registry = ToolRegistry() - stack = AsyncExitStack() - await stack.__aenter__() - try: - await connect_mcp_servers( - {"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])}, - registry, - stack, - ) - finally: + stacks = await connect_mcp_servers( + {"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])}, + registry, + ) + for stack in stacks.values(): await stack.aclose() assert registry.tool_names == ["mcp_test_demo"] @@ -331,15 +319,11 @@ async def test_connect_mcp_servers_enabled_tools_empty_list_registers_none( ) -> None: fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"]) registry = ToolRegistry() - stack = AsyncExitStack() - await stack.__aenter__() - try: - await connect_mcp_servers( - {"test": MCPServerConfig(command="fake", enabled_tools=[])}, - registry, - stack, - ) - finally: + stacks = await connect_mcp_servers( + {"test": MCPServerConfig(command="fake", enabled_tools=[])}, + registry, + ) + for stack in stacks.values(): await stack.aclose() assert registry.tool_names == [] @@ -358,15 +342,11 @@ async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries( monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning) - stack = AsyncExitStack() - await stack.__aenter__() - try: - await connect_mcp_servers( - {"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])}, - registry, - stack, - ) - finally: + stacks = await connect_mcp_servers( + {"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])}, + registry, + ) + for stack in stacks.values(): await stack.aclose() assert registry.tool_names == [] @@ -376,6 +356,46 @@ async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries( assert "Available wrapped names: mcp_test_demo" in warnings[-1] +@pytest.mark.asyncio +async def test_connect_mcp_servers_one_failure_does_not_block_others( + monkeypatch: pytest.MonkeyPatch, +) -> None: + sessions = {"good": _make_fake_session(["demo"])} + + class _SelectiveClientSession: + def __init__(self, read: object, _write: object) -> None: + self._session = sessions[read] + + async def __aenter__(self) -> object: + return self._session + + async def __aexit__(self, exc_type, exc, tb) -> bool: + return False + + @asynccontextmanager + async def _selective_stdio_client(params: object): + if params.command == "bad": + raise RuntimeError("boom") + yield params.command, object() + + monkeypatch.setattr(sys.modules["mcp"], "ClientSession", _SelectiveClientSession) + monkeypatch.setattr(sys.modules["mcp.client.stdio"], "stdio_client", _selective_stdio_client) + + registry = ToolRegistry() + stacks = await connect_mcp_servers( + { + "good": MCPServerConfig(command="good"), + "bad": MCPServerConfig(command="bad"), + }, + registry, + ) + for stack in stacks.values(): + await stack.aclose() + + assert registry.tool_names == ["mcp_good_demo"] + assert set(stacks) == {"good"} + + # --------------------------------------------------------------------------- # MCPResourceWrapper tests # --------------------------------------------------------------------------- @@ -389,9 +409,7 @@ def _make_resource_def( return SimpleNamespace(name=name, uri=uri, description=description) -def _make_resource_wrapper( - session: object, *, timeout: float = 0.1 -) -> MCPResourceWrapper: +def _make_resource_wrapper(session: object, *, timeout: float = 0.1) -> MCPResourceWrapper: return MCPResourceWrapper(session, "srv", _make_resource_def(), resource_timeout=timeout) @@ -434,9 +452,7 @@ async def test_resource_wrapper_execute_handles_timeout() -> None: await asyncio.sleep(1) return SimpleNamespace(contents=[]) - wrapper = _make_resource_wrapper( - SimpleNamespace(read_resource=read_resource), timeout=0.01 - ) + wrapper = _make_resource_wrapper(SimpleNamespace(read_resource=read_resource), timeout=0.01) result = await wrapper.execute() assert result == "(MCP resource read timed out after 0.01s)" @@ -464,20 +480,14 @@ def _make_prompt_def( return SimpleNamespace(name=name, description=description, arguments=arguments) -def _make_prompt_wrapper( - session: object, *, timeout: float = 0.1 -) -> MCPPromptWrapper: - return MCPPromptWrapper( - session, "srv", _make_prompt_def(), prompt_timeout=timeout - ) +def _make_prompt_wrapper(session: object, *, timeout: float = 0.1) -> MCPPromptWrapper: + return MCPPromptWrapper(session, "srv", _make_prompt_def(), prompt_timeout=timeout) def test_prompt_wrapper_properties() -> None: arg1 = SimpleNamespace(name="topic", required=True) arg2 = SimpleNamespace(name="style", required=False) - wrapper = MCPPromptWrapper( - None, "myserver", _make_prompt_def(arguments=[arg1, arg2]) - ) + wrapper = MCPPromptWrapper(None, "myserver", _make_prompt_def(arguments=[arg1, arg2])) assert wrapper.name == "mcp_myserver_prompt_myprompt" assert "[MCP Prompt]" in wrapper.description assert "A test prompt" in wrapper.description @@ -528,9 +538,7 @@ async def test_prompt_wrapper_execute_handles_timeout() -> None: await asyncio.sleep(1) return SimpleNamespace(messages=[]) - wrapper = _make_prompt_wrapper( - SimpleNamespace(get_prompt=get_prompt), timeout=0.01 - ) + wrapper = _make_prompt_wrapper(SimpleNamespace(get_prompt=get_prompt), timeout=0.01) result = await wrapper.execute() assert result == "(MCP prompt call timed out after 0.01s)" @@ -616,15 +624,11 @@ async def test_connect_registers_resources_and_prompts( prompt_names=["prompt_c"], ) registry = ToolRegistry() - stack = AsyncExitStack() - await stack.__aenter__() - try: - await connect_mcp_servers( - {"test": MCPServerConfig(command="fake")}, - registry, - stack, - ) - finally: + stacks = await connect_mcp_servers( + {"test": MCPServerConfig(command="fake")}, + registry, + ) + for stack in stacks.values(): await stack.aclose() assert "mcp_test_tool_a" in registry.tool_names