worenidewen a167959027 fix(mcp): support multiple MCP servers by connecting each in isolated task
Each MCP server now connects in its own asyncio.Task to isolate anyio
cancel scopes and prevent 'exit cancel scope in different task' errors
when multiple servers (especially mixed transport types) are configured.

Changes:
- connect_mcp_servers() returns dict[str, AsyncExitStack] instead of None
- Each server runs in separate task via asyncio.gather()
- AgentLoop uses _mcp_stacks dict to track per-server stacks
- Tests updated to handle new API
2026-04-10 23:51:50 +08:00

482 lines
18 KiB
Python

"""MCP client: connects to MCP servers and wraps their tools as native nanobot tools."""
import asyncio
from contextlib import AsyncExitStack
from typing import Any
import httpx
from loguru import logger
from nanobot.agent.tools.base import Tool
from nanobot.agent.tools.registry import ToolRegistry
def _extract_nullable_branch(options: Any) -> tuple[dict[str, Any], bool] | None:
"""Return the single non-null branch for nullable unions."""
if not isinstance(options, list):
return None
non_null: list[dict[str, Any]] = []
saw_null = False
for option in options:
if not isinstance(option, dict):
return None
if option.get("type") == "null":
saw_null = True
continue
non_null.append(option)
if saw_null and len(non_null) == 1:
return non_null[0], True
return None
def _normalize_schema_for_openai(schema: Any) -> dict[str, Any]:
    """Normalize only nullable JSON Schema patterns for tool definitions.

    Collapses ``type: [X, "null"]`` into ``type: X`` plus ``nullable: True``,
    and folds single-non-null ``oneOf``/``anyOf`` unions into the parent
    schema. Recurses into ``properties`` and ``items``. Non-dict input is
    replaced with an empty object schema; object schemas are guaranteed to
    carry ``properties`` and ``required`` keys on return.
    """
    if not isinstance(schema, dict):
        return {"type": "object", "properties": {}}

    out = dict(schema)

    # Collapse a ["X", "null"] type list into a plain type + nullable flag.
    declared = out.get("type")
    if isinstance(declared, list):
        remaining = [entry for entry in declared if entry != "null"]
        if "null" in declared and len(remaining) == 1:
            out["type"] = remaining[0]
            out["nullable"] = True

    # Fold a nullable oneOf/anyOf union into the parent; first match wins.
    for union_key in ("oneOf", "anyOf"):
        extracted = _extract_nullable_branch(out.get(union_key))
        if extracted is None:
            continue
        branch, _ = extracted
        collapsed = {k: v for k, v in out.items() if k != union_key}
        collapsed.update(branch)
        collapsed["nullable"] = True
        out = collapsed
        break

    # Recurse into nested schemas.
    props = out.get("properties")
    if isinstance(props, dict):
        out["properties"] = {
            field: _normalize_schema_for_openai(sub) if isinstance(sub, dict) else sub
            for field, sub in props.items()
        }
    items = out.get("items")
    if isinstance(items, dict):
        out["items"] = _normalize_schema_for_openai(items)

    if out.get("type") != "object":
        return out
    out.setdefault("properties", {})
    out.setdefault("required", [])
    return out
class MCPToolWrapper(Tool):
    """Wraps a single MCP server tool as a nanobot Tool.

    The remote tool is exposed under ``mcp_<server>_<tool>`` and every call
    is forwarded to the MCP client session with a hard timeout.
    """

    def __init__(self, session, server_name: str, tool_def, tool_timeout: int = 30):
        self._session = session
        self._original_name = tool_def.name
        self._name = f"mcp_{server_name}_{tool_def.name}"
        self._description = tool_def.description or tool_def.name
        schema = tool_def.inputSchema or {"type": "object", "properties": {}}
        self._parameters = _normalize_schema_for_openai(schema)
        self._tool_timeout = tool_timeout

    @property
    def name(self) -> str:
        return self._name

    @property
    def description(self) -> str:
        return self._description

    @property
    def parameters(self) -> dict[str, Any]:
        return self._parameters

    async def execute(self, **kwargs: Any) -> str:
        from mcp import types

        try:
            result = await asyncio.wait_for(
                self._session.call_tool(self._original_name, arguments=kwargs),
                timeout=self._tool_timeout,
            )
        except asyncio.TimeoutError:
            logger.warning("MCP tool '{}' timed out after {}s", self._name, self._tool_timeout)
            return f"(MCP tool call timed out after {self._tool_timeout}s)"
        except asyncio.CancelledError:
            # The MCP SDK's anyio cancel scopes can leak CancelledError on
            # timeout/failure. Propagate only a genuine external cancel
            # (e.g. /stop); otherwise report and keep the agent loop alive.
            current = asyncio.current_task()
            if current is not None and current.cancelling() > 0:
                raise
            logger.warning("MCP tool '{}' was cancelled by server/SDK", self._name)
            return "(MCP tool call was cancelled)"
        except Exception as exc:
            logger.exception(
                "MCP tool '{}' failed: {}: {}",
                self._name,
                type(exc).__name__,
                exc,
            )
            return f"(MCP tool call failed: {type(exc).__name__})"

        # Render text blocks verbatim; stringify anything else.
        rendered = [
            block.text if isinstance(block, types.TextContent) else str(block)
            for block in result.content
        ]
        return "\n".join(rendered) or "(no output)"
class MCPResourceWrapper(Tool):
    """Wraps an MCP resource URI as a read-only nanobot Tool.

    The resource URI is fixed at construction time, so the tool accepts no
    arguments; invoking it performs a timed ``read_resource`` call.
    """

    def __init__(self, session, server_name: str, resource_def, resource_timeout: int = 30):
        self._session = session
        self._uri = resource_def.uri
        self._name = f"mcp_{server_name}_resource_{resource_def.name}"
        desc = resource_def.description or resource_def.name
        self._description = f"[MCP Resource] {desc}\nURI: {self._uri}"
        # No parameters: the URI to read is baked in.
        self._parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
        self._resource_timeout = resource_timeout

    @property
    def name(self) -> str:
        return self._name

    @property
    def description(self) -> str:
        return self._description

    @property
    def parameters(self) -> dict[str, Any]:
        return self._parameters

    @property
    def read_only(self) -> bool:
        # Reading a resource never mutates server state.
        return True

    async def execute(self, **kwargs: Any) -> str:
        from mcp import types

        try:
            result = await asyncio.wait_for(
                self._session.read_resource(self._uri),
                timeout=self._resource_timeout,
            )
        except asyncio.TimeoutError:
            logger.warning(
                "MCP resource '{}' timed out after {}s", self._name, self._resource_timeout
            )
            return f"(MCP resource read timed out after {self._resource_timeout}s)"
        except asyncio.CancelledError:
            # Swallow SDK-internal cancellations; re-raise genuine external ones.
            current = asyncio.current_task()
            if current is not None and current.cancelling() > 0:
                raise
            logger.warning("MCP resource '{}' was cancelled by server/SDK", self._name)
            return "(MCP resource read was cancelled)"
        except Exception as exc:
            logger.exception(
                "MCP resource '{}' failed: {}: {}",
                self._name,
                type(exc).__name__,
                exc,
            )
            return f"(MCP resource read failed: {type(exc).__name__})"

        rendered: list[str] = []
        for block in result.contents:
            if isinstance(block, types.TextResourceContents):
                rendered.append(block.text)
            elif isinstance(block, types.BlobResourceContents):
                # Binary payloads are summarized, not inlined.
                rendered.append(f"[Binary resource: {len(block.blob)} bytes]")
            else:
                rendered.append(str(block))
        return "\n".join(rendered) or "(no output)"
class MCPPromptWrapper(Tool):
    """Wraps an MCP prompt as a read-only nanobot Tool.

    Calling the tool fetches the prompt template from the server with the
    provided arguments filled in and returns the rendered text.
    """

    def __init__(self, session, server_name: str, prompt_def, prompt_timeout: int = 30):
        self._session = session
        self._prompt_name = prompt_def.name
        self._name = f"mcp_{server_name}_prompt_{prompt_def.name}"
        desc = prompt_def.description or prompt_def.name
        self._description = (
            f"[MCP Prompt] {desc}\n"
            "Returns a filled prompt template that can be used as a workflow guide."
        )
        self._prompt_timeout = prompt_timeout
        # Translate declared prompt arguments into a JSON Schema; every
        # argument is surfaced as a string parameter.
        properties: dict[str, Any] = {}
        required: list[str] = []
        for arg in prompt_def.arguments or []:
            spec: dict[str, Any] = {"type": "string"}
            arg_description = getattr(arg, "description", None)
            if arg_description:
                spec["description"] = arg_description
            properties[arg.name] = spec
            if arg.required:
                required.append(arg.name)
        self._parameters: dict[str, Any] = {
            "type": "object",
            "properties": properties,
            "required": required,
        }

    @property
    def name(self) -> str:
        return self._name

    @property
    def description(self) -> str:
        return self._description

    @property
    def parameters(self) -> dict[str, Any]:
        return self._parameters

    @property
    def read_only(self) -> bool:
        # Fetching a prompt template never mutates server state.
        return True

    async def execute(self, **kwargs: Any) -> str:
        from mcp import types
        from mcp.shared.exceptions import McpError

        try:
            result = await asyncio.wait_for(
                self._session.get_prompt(self._prompt_name, arguments=kwargs),
                timeout=self._prompt_timeout,
            )
        except asyncio.TimeoutError:
            logger.warning("MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout)
            return f"(MCP prompt call timed out after {self._prompt_timeout}s)"
        except asyncio.CancelledError:
            # Swallow SDK-internal cancellations; re-raise genuine external ones.
            current = asyncio.current_task()
            if current is not None and current.cancelling() > 0:
                raise
            logger.warning("MCP prompt '{}' was cancelled by server/SDK", self._name)
            return "(MCP prompt call was cancelled)"
        except McpError as exc:
            # Protocol-level failures carry a structured code + message.
            logger.error(
                "MCP prompt '{}' failed: code={} message={}",
                self._name,
                exc.error.code,
                exc.error.message,
            )
            return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])"
        except Exception as exc:
            logger.exception(
                "MCP prompt '{}' failed: {}: {}",
                self._name,
                type(exc).__name__,
                exc,
            )
            return f"(MCP prompt call failed: {type(exc).__name__})"

        rendered: list[str] = []
        for message in result.messages:
            content = message.content
            if isinstance(content, types.TextContent):
                # content is a single ContentBlock (not a list) in MCP SDK >= 1.x.
                rendered.append(content.text)
            elif isinstance(content, list):
                # Defensive: handle list-shaped content as well.
                for block in content:
                    if isinstance(block, types.TextContent):
                        rendered.append(block.text)
                    else:
                        rendered.append(str(block))
            else:
                rendered.append(str(content))
        return "\n".join(rendered) or "(no output)"
async def connect_mcp_servers(
    mcp_servers: dict, registry: ToolRegistry
) -> dict[str, AsyncExitStack]:
    """Connect to configured MCP servers and register their tools, resources, prompts.

    Args:
        mcp_servers: Mapping of server name -> server config (transport type,
            command/url, headers, enabledTools, timeouts).
        registry: Registry that receives the wrapped capabilities.

    Returns:
        Dict mapping server name -> its dedicated AsyncExitStack.
        Each server gets its own stack and connects in its own asyncio.Task so
        the MCP SDK's anyio cancel scopes stay confined to a single task,
        preventing "exit cancel scope in a different task" errors when
        multiple servers (especially mixed transport types) are configured.

    Servers that fail to connect are logged and skipped; this function does
    not raise for ordinary connection errors.
    """
    from mcp import ClientSession, StdioServerParameters
    from mcp.client.sse import sse_client
    from mcp.client.stdio import stdio_client
    # NOTE(review): some mcp SDK releases expose this helper as
    # `streamablehttp_client` (no underscore) — confirm against the pinned
    # SDK version before upgrading.
    from mcp.client.streamable_http import streamable_http_client

    async def connect_single_server(name: str, cfg) -> tuple[str, AsyncExitStack | None]:
        """Connect one server and register its capabilities.

        Returns (name, stack) on success and (name, None) on skip/failure;
        the stack is closed before returning None.
        """
        server_stack = AsyncExitStack()
        await server_stack.__aenter__()
        try:
            # Infer the transport when the config does not state it:
            # a command means stdio; a URL ending in /sse means SSE.
            transport_type = cfg.type
            if not transport_type:
                if cfg.command:
                    transport_type = "stdio"
                elif cfg.url:
                    transport_type = (
                        "sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp"
                    )
                else:
                    logger.warning("MCP server '{}': no command or url configured, skipping", name)
                    await server_stack.aclose()
                    return name, None

            if transport_type == "stdio":
                params = StdioServerParameters(
                    command=cfg.command, args=cfg.args, env=cfg.env or None
                )
                read, write = await server_stack.enter_async_context(stdio_client(params))
            elif transport_type == "sse":

                def httpx_client_factory(
                    headers: dict[str, str] | None = None,
                    timeout: httpx.Timeout | None = None,
                    auth: httpx.Auth | None = None,
                ) -> httpx.AsyncClient:
                    # Merge configured headers with SDK-supplied ones;
                    # SDK values win on conflict.
                    merged_headers = {
                        "Accept": "application/json, text/event-stream",
                        **(cfg.headers or {}),
                        **(headers or {}),
                    }
                    return httpx.AsyncClient(
                        headers=merged_headers or None,
                        follow_redirects=True,
                        timeout=timeout,
                        auth=auth,
                    )

                read, write = await server_stack.enter_async_context(
                    sse_client(cfg.url, httpx_client_factory=httpx_client_factory)
                )
            elif transport_type == "streamableHttp":
                http_client = await server_stack.enter_async_context(
                    httpx.AsyncClient(
                        headers=cfg.headers or None,
                        follow_redirects=True,
                        timeout=None,
                    )
                )
                read, write, _ = await server_stack.enter_async_context(
                    streamable_http_client(cfg.url, http_client=http_client)
                )
            else:
                logger.warning(
                    "MCP server '{}': unknown transport type '{}'", name, transport_type
                )
                await server_stack.aclose()
                return name, None

            session = await server_stack.enter_async_context(ClientSession(read, write))
            await session.initialize()

            tools = await session.list_tools()
            # assumes cfg.enabled_tools is an iterable (possibly empty) — TODO confirm
            enabled_tools = set(cfg.enabled_tools)
            allow_all_tools = "*" in enabled_tools
            registered_count = 0
            matched_enabled_tools: set[str] = set()
            available_raw_names = [tool_def.name for tool_def in tools.tools]
            available_wrapped_names = [f"mcp_{name}_{tool_def.name}" for tool_def in tools.tools]
            for tool_def in tools.tools:
                wrapped_name = f"mcp_{name}_{tool_def.name}"
                # enabledTools entries may use either raw or wrapped names.
                if (
                    not allow_all_tools
                    and tool_def.name not in enabled_tools
                    and wrapped_name not in enabled_tools
                ):
                    logger.debug(
                        "MCP: skipping tool '{}' from server '{}' (not in enabledTools)",
                        wrapped_name,
                        name,
                    )
                    continue
                wrapper = MCPToolWrapper(session, name, tool_def, tool_timeout=cfg.tool_timeout)
                registry.register(wrapper)
                logger.debug("MCP: registered tool '{}' from server '{}'", wrapper.name, name)
                registered_count += 1
                if enabled_tools:
                    if tool_def.name in enabled_tools:
                        matched_enabled_tools.add(tool_def.name)
                    if wrapped_name in enabled_tools:
                        matched_enabled_tools.add(wrapped_name)
            # Warn about enabledTools entries that matched nothing so typos
            # in the config surface immediately.
            if enabled_tools and not allow_all_tools:
                unmatched_enabled_tools = sorted(enabled_tools - matched_enabled_tools)
                if unmatched_enabled_tools:
                    logger.warning(
                        "MCP server '{}': enabledTools entries not found: {}. Available raw names: {}. "
                        "Available wrapped names: {}",
                        name,
                        ", ".join(unmatched_enabled_tools),
                        ", ".join(available_raw_names) or "(none)",
                        ", ".join(available_wrapped_names) or "(none)",
                    )

            # Resources and prompts are optional MCP capabilities; servers
            # that do not implement them raise, which we treat as "none".
            try:
                resources_result = await session.list_resources()
                for resource in resources_result.resources:
                    wrapper = MCPResourceWrapper(
                        session, name, resource, resource_timeout=cfg.tool_timeout
                    )
                    registry.register(wrapper)
                    registered_count += 1
                    logger.debug(
                        "MCP: registered resource '{}' from server '{}'", wrapper.name, name
                    )
            except Exception as e:
                logger.debug("MCP server '{}': resources not supported or failed: {}", name, e)
            try:
                prompts_result = await session.list_prompts()
                for prompt in prompts_result.prompts:
                    wrapper = MCPPromptWrapper(
                        session, name, prompt, prompt_timeout=cfg.tool_timeout
                    )
                    registry.register(wrapper)
                    registered_count += 1
                    logger.debug("MCP: registered prompt '{}' from server '{}'", wrapper.name, name)
            except Exception as e:
                logger.debug("MCP server '{}': prompts not supported or failed: {}", name, e)

            logger.info(
                "MCP server '{}': connected, {} capabilities registered", name, registered_count
            )
            return name, server_stack
        except Exception as e:
            logger.error("MCP server '{}': failed to connect: {}", name, e)
            try:
                await server_stack.aclose()
            except Exception:
                pass
            return name, None

    # Pair results with names via zip rather than rebuilding
    # list(mcp_servers.keys()) on every iteration (was O(n^2)).
    names = list(mcp_servers)
    tasks = [
        asyncio.create_task(connect_single_server(name, cfg))
        for name, cfg in mcp_servers.items()
    ]
    results = await asyncio.gather(*tasks, return_exceptions=True)
    server_stacks: dict[str, AsyncExitStack] = {}
    for name, result in zip(names, results):
        if isinstance(result, BaseException):
            # CancelledError means the whole connect was cancelled; anything
            # else is an unexpected failure of the connection task itself.
            if not isinstance(result, asyncio.CancelledError):
                logger.error("MCP server '{}' connection task failed: {}", name, result)
            continue
        server_name, stack = result
        if stack is not None:
            server_stacks[server_name] = stack
    return server_stacks