fix(mcp): support multiple MCP servers by connecting each in isolated task

Each MCP server now connects in its own asyncio.Task to isolate anyio
cancel scopes and prevent 'exit cancel scope in different task' errors
when multiple servers (especially mixed transport types) are configured.

Changes:
- connect_mcp_servers() returns dict[str, AsyncExitStack] instead of None
- Each server runs in separate task via asyncio.gather()
- AgentLoop uses _mcp_stacks dict to track per-server stacks
- Tests updated to handle the new API
This commit is contained in:
worenidewen 2026-04-10 23:51:50 +08:00
parent 9bccfa63d2
commit a167959027
3 changed files with 247 additions and 198 deletions

View File

@ -43,6 +43,7 @@ if TYPE_CHECKING:
UNIFIED_SESSION_KEY = "unified:default" UNIFIED_SESSION_KEY = "unified:default"
class _LoopHook(AgentHook): class _LoopHook(AgentHook):
"""Core hook for the main loop.""" """Core hook for the main loop."""
@ -76,7 +77,7 @@ class _LoopHook(AgentHook):
prev_clean = strip_think(self._stream_buf) prev_clean = strip_think(self._stream_buf)
self._stream_buf += delta self._stream_buf += delta
new_clean = strip_think(self._stream_buf) new_clean = strip_think(self._stream_buf)
incremental = new_clean[len(prev_clean):] incremental = new_clean[len(prev_clean) :]
if incremental and self._on_stream: if incremental and self._on_stream:
await self._on_stream(incremental) await self._on_stream(incremental)
@ -112,6 +113,7 @@ class _LoopHook(AgentHook):
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
return self._loop._strip_think(content) return self._loop._strip_think(content)
class AgentLoop: class AgentLoop:
""" """
The agent loop is the core processing engine. The agent loop is the core processing engine.
@ -196,7 +198,7 @@ class AgentLoop:
self._unified_session = unified_session self._unified_session = unified_session
self._running = False self._running = False
self._mcp_servers = mcp_servers or {} self._mcp_servers = mcp_servers or {}
self._mcp_stack: AsyncExitStack | None = None self._mcp_stacks: dict[str, AsyncExitStack] = {}
self._mcp_connected = False self._mcp_connected = False
self._mcp_connecting = False self._mcp_connecting = False
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
@ -228,24 +230,34 @@ class AgentLoop:
def _register_default_tools(self) -> None: def _register_default_tools(self) -> None:
"""Register the default set of tools.""" """Register the default set of tools."""
allowed_dir = self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None allowed_dir = (
self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
)
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
self.tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read)) self.tools.register(
ReadFileTool(
workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read
)
)
for cls in (WriteFileTool, EditFileTool, ListDirTool): for cls in (WriteFileTool, EditFileTool, ListDirTool):
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir)) self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
for cls in (GlobTool, GrepTool): for cls in (GlobTool, GrepTool):
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir)) self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
if self.exec_config.enable: if self.exec_config.enable:
self.tools.register(ExecTool( self.tools.register(
working_dir=str(self.workspace), ExecTool(
timeout=self.exec_config.timeout, working_dir=str(self.workspace),
restrict_to_workspace=self.restrict_to_workspace, timeout=self.exec_config.timeout,
sandbox=self.exec_config.sandbox, restrict_to_workspace=self.restrict_to_workspace,
path_append=self.exec_config.path_append, sandbox=self.exec_config.sandbox,
allowed_env_keys=self.exec_config.allowed_env_keys, path_append=self.exec_config.path_append,
)) allowed_env_keys=self.exec_config.allowed_env_keys,
)
)
if self.web_config.enable: if self.web_config.enable:
self.tools.register(WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy)) self.tools.register(
WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy)
)
self.tools.register(WebFetchTool(proxy=self.web_config.proxy)) self.tools.register(WebFetchTool(proxy=self.web_config.proxy))
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound)) self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
self.tools.register(SpawnTool(manager=self.subagents)) self.tools.register(SpawnTool(manager=self.subagents))
@ -260,19 +272,16 @@ class AgentLoop:
return return
self._mcp_connecting = True self._mcp_connecting = True
from nanobot.agent.tools.mcp import connect_mcp_servers from nanobot.agent.tools.mcp import connect_mcp_servers
try: try:
self._mcp_stack = AsyncExitStack() self._mcp_stacks = await connect_mcp_servers(self._mcp_servers, self.tools)
await self._mcp_stack.__aenter__()
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
self._mcp_connected = True self._mcp_connected = True
except asyncio.CancelledError:
logger.warning("MCP connection cancelled (will retry next message)")
self._mcp_stacks.clear()
except BaseException as e: except BaseException as e:
logger.error("Failed to connect MCP servers (will retry next message): {}", e) logger.error("Failed to connect MCP servers (will retry next message): {}", e)
if self._mcp_stack: self._mcp_stacks.clear()
try:
await self._mcp_stack.aclose()
except Exception:
pass
self._mcp_stack = None
finally: finally:
self._mcp_connecting = False self._mcp_connecting = False
@ -289,6 +298,7 @@ class AgentLoop:
if not text: if not text:
return None return None
from nanobot.utils.helpers import strip_think from nanobot.utils.helpers import strip_think
return strip_think(text) or None return strip_think(text) or None
@staticmethod @staticmethod
@ -327,9 +337,7 @@ class AgentLoop:
message_id=message_id, message_id=message_id,
) )
hook: AgentHook = ( hook: AgentHook = (
CompositeHook([loop_hook] + self._extra_hooks) CompositeHook([loop_hook] + self._extra_hooks) if self._extra_hooks else loop_hook
if self._extra_hooks
else loop_hook
) )
async def _checkpoint(payload: dict[str, Any]) -> None: async def _checkpoint(payload: dict[str, Any]) -> None:
@ -337,23 +345,25 @@ class AgentLoop:
return return
self._set_runtime_checkpoint(session, payload) self._set_runtime_checkpoint(session, payload)
result = await self.runner.run(AgentRunSpec( result = await self.runner.run(
initial_messages=initial_messages, AgentRunSpec(
tools=self.tools, initial_messages=initial_messages,
model=self.model, tools=self.tools,
max_iterations=self.max_iterations, model=self.model,
max_tool_result_chars=self.max_tool_result_chars, max_iterations=self.max_iterations,
hook=hook, max_tool_result_chars=self.max_tool_result_chars,
error_message="Sorry, I encountered an error calling the AI model.", hook=hook,
concurrent_tools=True, error_message="Sorry, I encountered an error calling the AI model.",
workspace=self.workspace, concurrent_tools=True,
session_key=session.key if session else None, workspace=self.workspace,
context_window_tokens=self.context_window_tokens, session_key=session.key if session else None,
context_block_limit=self.context_block_limit, context_window_tokens=self.context_window_tokens,
provider_retry_mode=self.provider_retry_mode, context_block_limit=self.context_block_limit,
progress_callback=on_progress, provider_retry_mode=self.provider_retry_mode,
checkpoint_callback=_checkpoint, progress_callback=on_progress,
)) checkpoint_callback=_checkpoint,
)
)
self._last_usage = result.usage self._last_usage = result.usage
if result.stop_reason == "max_iterations": if result.stop_reason == "max_iterations":
logger.warning("Max iterations ({}) reached", self.max_iterations) logger.warning("Max iterations ({}) reached", self.max_iterations)
@ -391,10 +401,19 @@ class AgentLoop:
continue continue
# Compute the effective session key before dispatching # Compute the effective session key before dispatching
# This ensures /stop command can find tasks correctly when unified session is enabled # This ensures /stop command can find tasks correctly when unified session is enabled
effective_key = UNIFIED_SESSION_KEY if self._unified_session and not msg.session_key_override else msg.session_key effective_key = (
UNIFIED_SESSION_KEY
if self._unified_session and not msg.session_key_override
else msg.session_key
)
task = asyncio.create_task(self._dispatch(msg)) task = asyncio.create_task(self._dispatch(msg))
self._active_tasks.setdefault(effective_key, []).append(task) self._active_tasks.setdefault(effective_key, []).append(task)
task.add_done_callback(lambda t, k=effective_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None) task.add_done_callback(
lambda t, k=effective_key: self._active_tasks.get(k, [])
and self._active_tasks[k].remove(t)
if t in self._active_tasks.get(k, [])
else None
)
async def _dispatch(self, msg: InboundMessage) -> None: async def _dispatch(self, msg: InboundMessage) -> None:
"""Process a message: per-session serial, cross-session concurrent.""" """Process a message: per-session serial, cross-session concurrent."""
@ -417,11 +436,14 @@ class AgentLoop:
meta = dict(msg.metadata or {}) meta = dict(msg.metadata or {})
meta["_stream_delta"] = True meta["_stream_delta"] = True
meta["_stream_id"] = _current_stream_id() meta["_stream_id"] = _current_stream_id()
await self.bus.publish_outbound(OutboundMessage( await self.bus.publish_outbound(
channel=msg.channel, chat_id=msg.chat_id, OutboundMessage(
content=delta, channel=msg.channel,
metadata=meta, chat_id=msg.chat_id,
)) content=delta,
metadata=meta,
)
)
async def on_stream_end(*, resuming: bool = False) -> None: async def on_stream_end(*, resuming: bool = False) -> None:
nonlocal stream_segment nonlocal stream_segment
@ -429,44 +451,56 @@ class AgentLoop:
meta["_stream_end"] = True meta["_stream_end"] = True
meta["_resuming"] = resuming meta["_resuming"] = resuming
meta["_stream_id"] = _current_stream_id() meta["_stream_id"] = _current_stream_id()
await self.bus.publish_outbound(OutboundMessage( await self.bus.publish_outbound(
channel=msg.channel, chat_id=msg.chat_id, OutboundMessage(
content="", channel=msg.channel,
metadata=meta, chat_id=msg.chat_id,
)) content="",
metadata=meta,
)
)
stream_segment += 1 stream_segment += 1
response = await self._process_message( response = await self._process_message(
msg, on_stream=on_stream, on_stream_end=on_stream_end, msg,
on_stream=on_stream,
on_stream_end=on_stream_end,
) )
if response is not None: if response is not None:
await self.bus.publish_outbound(response) await self.bus.publish_outbound(response)
elif msg.channel == "cli": elif msg.channel == "cli":
await self.bus.publish_outbound(OutboundMessage( await self.bus.publish_outbound(
channel=msg.channel, chat_id=msg.chat_id, OutboundMessage(
content="", metadata=msg.metadata or {}, channel=msg.channel,
)) chat_id=msg.chat_id,
content="",
metadata=msg.metadata or {},
)
)
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info("Task cancelled for session {}", msg.session_key) logger.info("Task cancelled for session {}", msg.session_key)
raise raise
except Exception: except Exception:
logger.exception("Error processing message for session {}", msg.session_key) logger.exception("Error processing message for session {}", msg.session_key)
await self.bus.publish_outbound(OutboundMessage( await self.bus.publish_outbound(
channel=msg.channel, chat_id=msg.chat_id, OutboundMessage(
content="Sorry, I encountered an error.", channel=msg.channel,
)) chat_id=msg.chat_id,
content="Sorry, I encountered an error.",
)
)
async def close_mcp(self) -> None: async def close_mcp(self) -> None:
"""Drain pending background archives, then close MCP connections.""" """Drain pending background archives, then close MCP connections."""
if self._background_tasks: if self._background_tasks:
await asyncio.gather(*self._background_tasks, return_exceptions=True) await asyncio.gather(*self._background_tasks, return_exceptions=True)
self._background_tasks.clear() self._background_tasks.clear()
if self._mcp_stack: for name, stack in self._mcp_stacks.items():
try: try:
await self._mcp_stack.aclose() await stack.aclose()
except (RuntimeError, BaseExceptionGroup): except (RuntimeError, BaseExceptionGroup):
pass # MCP SDK cancel scope cleanup is noisy but harmless logger.debug("MCP server '{}' cleanup error (can be ignored)", name)
self._mcp_stack = None self._mcp_stacks.clear()
def _schedule_background(self, coro) -> None: def _schedule_background(self, coro) -> None:
"""Schedule a coroutine as a tracked background task (drained on shutdown).""" """Schedule a coroutine as a tracked background task (drained on shutdown)."""
@ -490,8 +524,9 @@ class AgentLoop:
"""Process a single inbound message and return the response.""" """Process a single inbound message and return the response."""
# System messages: parse origin from chat_id ("channel:chat_id") # System messages: parse origin from chat_id ("channel:chat_id")
if msg.channel == "system": if msg.channel == "system":
channel, chat_id = (msg.chat_id.split(":", 1) if ":" in msg.chat_id channel, chat_id = (
else ("cli", msg.chat_id)) msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id)
)
logger.info("Processing system message from {}", msg.sender_id) logger.info("Processing system message from {}", msg.sender_id)
key = f"{channel}:{chat_id}" key = f"{channel}:{chat_id}"
session = self.sessions.get_or_create(key) session = self.sessions.get_or_create(key)
@ -503,19 +538,27 @@ class AgentLoop:
current_role = "assistant" if msg.sender_id == "subagent" else "user" current_role = "assistant" if msg.sender_id == "subagent" else "user"
messages = self.context.build_messages( messages = self.context.build_messages(
history=history, history=history,
current_message=msg.content, channel=channel, chat_id=chat_id, current_message=msg.content,
channel=channel,
chat_id=chat_id,
current_role=current_role, current_role=current_role,
) )
final_content, _, all_msgs, _ = await self._run_agent_loop( final_content, _, all_msgs, _ = await self._run_agent_loop(
messages, session=session, channel=channel, chat_id=chat_id, messages,
session=session,
channel=channel,
chat_id=chat_id,
message_id=msg.metadata.get("message_id"), message_id=msg.metadata.get("message_id"),
) )
self._save_turn(session, all_msgs, 1 + len(history)) self._save_turn(session, all_msgs, 1 + len(history))
self._clear_runtime_checkpoint(session) self._clear_runtime_checkpoint(session)
self.sessions.save(session) self.sessions.save(session)
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session)) self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
return OutboundMessage(channel=channel, chat_id=chat_id, return OutboundMessage(
content=final_content or "Background task completed.") channel=channel,
chat_id=chat_id,
content=final_content or "Background task completed.",
)
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview) logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
@ -543,16 +586,22 @@ class AgentLoop:
history=history, history=history,
current_message=msg.content, current_message=msg.content,
media=msg.media if msg.media else None, media=msg.media if msg.media else None,
channel=msg.channel, chat_id=msg.chat_id, channel=msg.channel,
chat_id=msg.chat_id,
) )
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None: async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
meta = dict(msg.metadata or {}) meta = dict(msg.metadata or {})
meta["_progress"] = True meta["_progress"] = True
meta["_tool_hint"] = tool_hint meta["_tool_hint"] = tool_hint
await self.bus.publish_outbound(OutboundMessage( await self.bus.publish_outbound(
channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta, OutboundMessage(
)) channel=msg.channel,
chat_id=msg.chat_id,
content=content,
metadata=meta,
)
)
final_content, _, all_msgs, stop_reason = await self._run_agent_loop( final_content, _, all_msgs, stop_reason = await self._run_agent_loop(
initial_messages, initial_messages,
@ -560,7 +609,8 @@ class AgentLoop:
on_stream=on_stream, on_stream=on_stream,
on_stream_end=on_stream_end, on_stream_end=on_stream_end,
session=session, session=session,
channel=msg.channel, chat_id=msg.chat_id, channel=msg.channel,
chat_id=msg.chat_id,
message_id=msg.metadata.get("message_id"), message_id=msg.metadata.get("message_id"),
) )
@ -582,7 +632,9 @@ class AgentLoop:
if on_stream is not None and stop_reason != "error": if on_stream is not None and stop_reason != "error":
meta["_streamed"] = True meta["_streamed"] = True
return OutboundMessage( return OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, content=final_content, channel=msg.channel,
chat_id=msg.chat_id,
content=final_content,
metadata=meta, metadata=meta,
) )
@ -608,10 +660,9 @@ class AgentLoop:
): ):
continue continue
if ( if block.get("type") == "image_url" and block.get("image_url", {}).get(
block.get("type") == "image_url" "url", ""
and block.get("image_url", {}).get("url", "").startswith("data:image/") ).startswith("data:image/"):
):
path = (block.get("_meta") or {}).get("path", "") path = (block.get("_meta") or {}).get("path", "")
filtered.append({"type": "text", "text": image_placeholder_text(path)}) filtered.append({"type": "text", "text": image_placeholder_text(path)})
continue continue
@ -630,6 +681,7 @@ class AgentLoop:
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None: def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
"""Save new-turn messages into session, truncating large tool results.""" """Save new-turn messages into session, truncating large tool results."""
from datetime import datetime from datetime import datetime
for m in messages[skip:]: for m in messages[skip:]:
entry = dict(m) entry = dict(m)
role, content = entry.get("role"), entry.get("content") role, content = entry.get("role"), entry.get("content")
@ -644,7 +696,9 @@ class AgentLoop:
continue continue
entry["content"] = filtered entry["content"] = filtered
elif role == "user": elif role == "user":
if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG): if isinstance(content, str) and content.startswith(
ContextBuilder._RUNTIME_CONTEXT_TAG
):
# Strip the runtime-context prefix, keep only the user text. # Strip the runtime-context prefix, keep only the user text.
parts = content.split("\n\n", 1) parts = content.split("\n\n", 1)
if len(parts) > 1 and parts[1].strip(): if len(parts) > 1 and parts[1].strip():
@ -708,13 +762,15 @@ class AgentLoop:
continue continue
tool_id = tool_call.get("id") tool_id = tool_call.get("id")
name = ((tool_call.get("function") or {}).get("name")) or "tool" name = ((tool_call.get("function") or {}).get("name")) or "tool"
restored_messages.append({ restored_messages.append(
"role": "tool", {
"tool_call_id": tool_id, "role": "tool",
"name": name, "tool_call_id": tool_id,
"content": "Error: Task interrupted before this tool finished.", "name": name,
"timestamp": datetime.now().isoformat(), "content": "Error: Task interrupted before this tool finished.",
}) "timestamp": datetime.now().isoformat(),
}
)
overlap = 0 overlap = 0
max_overlap = min(len(session.messages), len(restored_messages)) max_overlap = min(len(session.messages), len(restored_messages))
@ -746,6 +802,9 @@ class AgentLoop:
await self._connect_mcp() await self._connect_mcp()
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content) msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
return await self._process_message( return await self._process_message(
msg, session_key=session_key, on_progress=on_progress, msg,
on_stream=on_stream, on_stream_end=on_stream_end, session_key=session_key,
on_progress=on_progress,
on_stream=on_stream,
on_stream_end=on_stream_end,
) )

View File

@ -57,9 +57,7 @@ def _normalize_schema_for_openai(schema: Any) -> dict[str, Any]:
if "properties" in normalized and isinstance(normalized["properties"], dict): if "properties" in normalized and isinstance(normalized["properties"], dict):
normalized["properties"] = { normalized["properties"] = {
name: _normalize_schema_for_openai(prop) name: _normalize_schema_for_openai(prop) if isinstance(prop, dict) else prop
if isinstance(prop, dict)
else prop
for name, prop in normalized["properties"].items() for name, prop in normalized["properties"].items()
} }
@ -138,9 +136,7 @@ class MCPToolWrapper(Tool):
class MCPResourceWrapper(Tool): class MCPResourceWrapper(Tool):
"""Wraps an MCP resource URI as a read-only nanobot Tool.""" """Wraps an MCP resource URI as a read-only nanobot Tool."""
def __init__( def __init__(self, session, server_name: str, resource_def, resource_timeout: int = 30):
self, session, server_name: str, resource_def, resource_timeout: int = 30
):
self._session = session self._session = session
self._uri = resource_def.uri self._uri = resource_def.uri
self._name = f"mcp_{server_name}_resource_{resource_def.name}" self._name = f"mcp_{server_name}_resource_{resource_def.name}"
@ -211,9 +207,7 @@ class MCPResourceWrapper(Tool):
class MCPPromptWrapper(Tool): class MCPPromptWrapper(Tool):
"""Wraps an MCP prompt as a read-only nanobot Tool.""" """Wraps an MCP prompt as a read-only nanobot Tool."""
def __init__( def __init__(self, session, server_name: str, prompt_def, prompt_timeout: int = 30):
self, session, server_name: str, prompt_def, prompt_timeout: int = 30
):
self._session = session self._session = session
self._prompt_name = prompt_def.name self._prompt_name = prompt_def.name
self._name = f"mcp_{server_name}_prompt_{prompt_def.name}" self._name = f"mcp_{server_name}_prompt_{prompt_def.name}"
@ -266,9 +260,7 @@ class MCPPromptWrapper(Tool):
timeout=self._prompt_timeout, timeout=self._prompt_timeout,
) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.warning( logger.warning("MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout)
"MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout
)
return f"(MCP prompt call timed out after {self._prompt_timeout}s)" return f"(MCP prompt call timed out after {self._prompt_timeout}s)"
except asyncio.CancelledError: except asyncio.CancelledError:
task = asyncio.current_task() task = asyncio.current_task()
@ -279,13 +271,17 @@ class MCPPromptWrapper(Tool):
except McpError as exc: except McpError as exc:
logger.error( logger.error(
"MCP prompt '{}' failed: code={} message={}", "MCP prompt '{}' failed: code={} message={}",
self._name, exc.error.code, exc.error.message, self._name,
exc.error.code,
exc.error.message,
) )
return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])" return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])"
except Exception as exc: except Exception as exc:
logger.exception( logger.exception(
"MCP prompt '{}' failed: {}: {}", "MCP prompt '{}' failed: {}: {}",
self._name, type(exc).__name__, exc, self._name,
type(exc).__name__,
exc,
) )
return f"(MCP prompt call failed: {type(exc).__name__})" return f"(MCP prompt call failed: {type(exc).__name__})"
@ -307,35 +303,44 @@ class MCPPromptWrapper(Tool):
async def connect_mcp_servers( async def connect_mcp_servers(
mcp_servers: dict, registry: ToolRegistry, stack: AsyncExitStack mcp_servers: dict, registry: ToolRegistry
) -> None: ) -> dict[str, AsyncExitStack]:
"""Connect to configured MCP servers and register their tools, resources, and prompts.""" """Connect to configured MCP servers and register their tools, resources, prompts.
Returns a dict mapping server name -> its dedicated AsyncExitStack.
Each server gets its own stack and runs in its own task to prevent
cancel scope conflicts when multiple MCP servers are configured.
"""
from mcp import ClientSession, StdioServerParameters from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamable_http_client from mcp.client.streamable_http import streamable_http_client
for name, cfg in mcp_servers.items(): async def connect_single_server(name: str, cfg) -> tuple[str, AsyncExitStack | None]:
server_stack = AsyncExitStack()
await server_stack.__aenter__()
try: try:
transport_type = cfg.type transport_type = cfg.type
if not transport_type: if not transport_type:
if cfg.command: if cfg.command:
transport_type = "stdio" transport_type = "stdio"
elif cfg.url: elif cfg.url:
# Convention: URLs ending with /sse use SSE transport; others use streamableHttp
transport_type = ( transport_type = (
"sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp" "sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp"
) )
else: else:
logger.warning("MCP server '{}': no command or url configured, skipping", name) logger.warning("MCP server '{}': no command or url configured, skipping", name)
continue await server_stack.aclose()
return name, None
if transport_type == "stdio": if transport_type == "stdio":
params = StdioServerParameters( params = StdioServerParameters(
command=cfg.command, args=cfg.args, env=cfg.env or None command=cfg.command, args=cfg.args, env=cfg.env or None
) )
read, write = await stack.enter_async_context(stdio_client(params)) read, write = await server_stack.enter_async_context(stdio_client(params))
elif transport_type == "sse": elif transport_type == "sse":
def httpx_client_factory( def httpx_client_factory(
headers: dict[str, str] | None = None, headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None, timeout: httpx.Timeout | None = None,
@ -353,27 +358,26 @@ async def connect_mcp_servers(
auth=auth, auth=auth,
) )
read, write = await stack.enter_async_context( read, write = await server_stack.enter_async_context(
sse_client(cfg.url, httpx_client_factory=httpx_client_factory) sse_client(cfg.url, httpx_client_factory=httpx_client_factory)
) )
elif transport_type == "streamableHttp": elif transport_type == "streamableHttp":
# Always provide an explicit httpx client so MCP HTTP transport does not http_client = await server_stack.enter_async_context(
# inherit httpx's default 5s timeout and preempt the higher-level tool timeout.
http_client = await stack.enter_async_context(
httpx.AsyncClient( httpx.AsyncClient(
headers=cfg.headers or None, headers=cfg.headers or None,
follow_redirects=True, follow_redirects=True,
timeout=None, timeout=None,
) )
) )
read, write, _ = await stack.enter_async_context( read, write, _ = await server_stack.enter_async_context(
streamable_http_client(cfg.url, http_client=http_client) streamable_http_client(cfg.url, http_client=http_client)
) )
else: else:
logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type) logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type)
continue await server_stack.aclose()
return name, None
session = await stack.enter_async_context(ClientSession(read, write)) session = await server_stack.enter_async_context(ClientSession(read, write))
await session.initialize() await session.initialize()
tools = await session.list_tools() tools = await session.list_tools()
@ -418,7 +422,6 @@ async def connect_mcp_servers(
", ".join(available_wrapped_names) or "(none)", ", ".join(available_wrapped_names) or "(none)",
) )
# --- Register resources ---
try: try:
resources_result = await session.list_resources() resources_result = await session.list_resources()
for resource in resources_result.resources: for resource in resources_result.resources:
@ -433,7 +436,6 @@ async def connect_mcp_servers(
except Exception as e: except Exception as e:
logger.debug("MCP server '{}': resources not supported or failed: {}", name, e) logger.debug("MCP server '{}': resources not supported or failed: {}", name, e)
# --- Register prompts ---
try: try:
prompts_result = await session.list_prompts() prompts_result = await session.list_prompts()
for prompt in prompts_result.prompts: for prompt in prompts_result.prompts:
@ -442,14 +444,38 @@ async def connect_mcp_servers(
) )
registry.register(wrapper) registry.register(wrapper)
registered_count += 1 registered_count += 1
logger.debug( logger.debug("MCP: registered prompt '{}' from server '{}'", wrapper.name, name)
"MCP: registered prompt '{}' from server '{}'", wrapper.name, name
)
except Exception as e: except Exception as e:
logger.debug("MCP server '{}': prompts not supported or failed: {}", name, e) logger.debug("MCP server '{}': prompts not supported or failed: {}", name, e)
logger.info( logger.info(
"MCP server '{}': connected, {} capabilities registered", name, registered_count "MCP server '{}': connected, {} capabilities registered", name, registered_count
) )
return name, server_stack
except Exception as e: except Exception as e:
logger.error("MCP server '{}': failed to connect: {}", name, e) logger.error("MCP server '{}': failed to connect: {}", name, e)
try:
await server_stack.aclose()
except Exception:
pass
return name, None
server_stacks: dict[str, AsyncExitStack] = {}
tasks: list[asyncio.Task] = []
for name, cfg in mcp_servers.items():
task = asyncio.create_task(connect_single_server(name, cfg))
tasks.append(task)
results = await asyncio.gather(*tasks, return_exceptions=True)
for i, result in enumerate(results):
name = list(mcp_servers.keys())[i]
if isinstance(result, BaseException):
if not isinstance(result, asyncio.CancelledError):
logger.error("MCP server '{}' connection task failed: {}", name, result)
elif result is not None and result[1] is not None:
server_stacks[result[0]] = result[1]
return server_stacks

View File

@ -271,15 +271,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_raw_names(
) -> None: ) -> None:
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"]) fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
registry = ToolRegistry() registry = ToolRegistry()
stack = AsyncExitStack() stacks = await connect_mcp_servers(
await stack.__aenter__() {"test": MCPServerConfig(command="fake", enabled_tools=["demo"])},
try: registry,
await connect_mcp_servers( )
{"test": MCPServerConfig(command="fake", enabled_tools=["demo"])}, for stack in stacks.values():
registry,
stack,
)
finally:
await stack.aclose() await stack.aclose()
assert registry.tool_names == ["mcp_test_demo"] assert registry.tool_names == ["mcp_test_demo"]
@ -291,15 +287,11 @@ async def test_connect_mcp_servers_enabled_tools_defaults_to_all(
) -> None: ) -> None:
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"]) fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
registry = ToolRegistry() registry = ToolRegistry()
stack = AsyncExitStack() stacks = await connect_mcp_servers(
await stack.__aenter__() {"test": MCPServerConfig(command="fake")},
try: registry,
await connect_mcp_servers( )
{"test": MCPServerConfig(command="fake")}, for stack in stacks.values():
registry,
stack,
)
finally:
await stack.aclose() await stack.aclose()
assert registry.tool_names == ["mcp_test_demo", "mcp_test_other"] assert registry.tool_names == ["mcp_test_demo", "mcp_test_other"]
@ -311,15 +303,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_wrapped_names(
) -> None: ) -> None:
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"]) fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
registry = ToolRegistry() registry = ToolRegistry()
stack = AsyncExitStack() stacks = await connect_mcp_servers(
await stack.__aenter__() {"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])},
try: registry,
await connect_mcp_servers( )
{"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])}, for stack in stacks.values():
registry,
stack,
)
finally:
await stack.aclose() await stack.aclose()
assert registry.tool_names == ["mcp_test_demo"] assert registry.tool_names == ["mcp_test_demo"]
@ -331,15 +319,11 @@ async def test_connect_mcp_servers_enabled_tools_empty_list_registers_none(
) -> None: ) -> None:
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"]) fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
registry = ToolRegistry() registry = ToolRegistry()
stack = AsyncExitStack() stacks = await connect_mcp_servers(
await stack.__aenter__() {"test": MCPServerConfig(command="fake", enabled_tools=[])},
try: registry,
await connect_mcp_servers( )
{"test": MCPServerConfig(command="fake", enabled_tools=[])}, for stack in stacks.values():
registry,
stack,
)
finally:
await stack.aclose() await stack.aclose()
assert registry.tool_names == [] assert registry.tool_names == []
@ -358,15 +342,11 @@ async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries(
monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning) monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning)
stack = AsyncExitStack() stacks = await connect_mcp_servers(
await stack.__aenter__() {"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])},
try: registry,
await connect_mcp_servers( )
{"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])}, for stack in stacks.values():
registry,
stack,
)
finally:
await stack.aclose() await stack.aclose()
assert registry.tool_names == [] assert registry.tool_names == []
@ -389,9 +369,7 @@ def _make_resource_def(
return SimpleNamespace(name=name, uri=uri, description=description) return SimpleNamespace(name=name, uri=uri, description=description)
def _make_resource_wrapper( def _make_resource_wrapper(session: object, *, timeout: float = 0.1) -> MCPResourceWrapper:
session: object, *, timeout: float = 0.1
) -> MCPResourceWrapper:
return MCPResourceWrapper(session, "srv", _make_resource_def(), resource_timeout=timeout) return MCPResourceWrapper(session, "srv", _make_resource_def(), resource_timeout=timeout)
@ -434,9 +412,7 @@ async def test_resource_wrapper_execute_handles_timeout() -> None:
await asyncio.sleep(1) await asyncio.sleep(1)
return SimpleNamespace(contents=[]) return SimpleNamespace(contents=[])
wrapper = _make_resource_wrapper( wrapper = _make_resource_wrapper(SimpleNamespace(read_resource=read_resource), timeout=0.01)
SimpleNamespace(read_resource=read_resource), timeout=0.01
)
result = await wrapper.execute() result = await wrapper.execute()
assert result == "(MCP resource read timed out after 0.01s)" assert result == "(MCP resource read timed out after 0.01s)"
@ -464,20 +440,14 @@ def _make_prompt_def(
return SimpleNamespace(name=name, description=description, arguments=arguments) return SimpleNamespace(name=name, description=description, arguments=arguments)
def _make_prompt_wrapper( def _make_prompt_wrapper(session: object, *, timeout: float = 0.1) -> MCPPromptWrapper:
session: object, *, timeout: float = 0.1 return MCPPromptWrapper(session, "srv", _make_prompt_def(), prompt_timeout=timeout)
) -> MCPPromptWrapper:
return MCPPromptWrapper(
session, "srv", _make_prompt_def(), prompt_timeout=timeout
)
def test_prompt_wrapper_properties() -> None: def test_prompt_wrapper_properties() -> None:
arg1 = SimpleNamespace(name="topic", required=True) arg1 = SimpleNamespace(name="topic", required=True)
arg2 = SimpleNamespace(name="style", required=False) arg2 = SimpleNamespace(name="style", required=False)
wrapper = MCPPromptWrapper( wrapper = MCPPromptWrapper(None, "myserver", _make_prompt_def(arguments=[arg1, arg2]))
None, "myserver", _make_prompt_def(arguments=[arg1, arg2])
)
assert wrapper.name == "mcp_myserver_prompt_myprompt" assert wrapper.name == "mcp_myserver_prompt_myprompt"
assert "[MCP Prompt]" in wrapper.description assert "[MCP Prompt]" in wrapper.description
assert "A test prompt" in wrapper.description assert "A test prompt" in wrapper.description
@ -528,9 +498,7 @@ async def test_prompt_wrapper_execute_handles_timeout() -> None:
await asyncio.sleep(1) await asyncio.sleep(1)
return SimpleNamespace(messages=[]) return SimpleNamespace(messages=[])
wrapper = _make_prompt_wrapper( wrapper = _make_prompt_wrapper(SimpleNamespace(get_prompt=get_prompt), timeout=0.01)
SimpleNamespace(get_prompt=get_prompt), timeout=0.01
)
result = await wrapper.execute() result = await wrapper.execute()
assert result == "(MCP prompt call timed out after 0.01s)" assert result == "(MCP prompt call timed out after 0.01s)"
@ -616,15 +584,11 @@ async def test_connect_registers_resources_and_prompts(
prompt_names=["prompt_c"], prompt_names=["prompt_c"],
) )
registry = ToolRegistry() registry = ToolRegistry()
stack = AsyncExitStack() stacks = await connect_mcp_servers(
await stack.__aenter__() {"test": MCPServerConfig(command="fake")},
try: registry,
await connect_mcp_servers( )
{"test": MCPServerConfig(command="fake")}, for stack in stacks.values():
registry,
stack,
)
finally:
await stack.aclose() await stack.aclose()
assert "mcp_test_tool_a" in registry.tool_names assert "mcp_test_tool_a" in registry.tool_names