mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-14 15:09:55 +00:00
fix(mcp): support multiple MCP servers by connecting each in isolated task
Each MCP server now connects in its own asyncio.Task to isolate anyio cancel scopes and prevent 'exit cancel scope in different task' errors when multiple servers (especially mixed transport types) are configured. Changes: - connect_mcp_servers() returns dict[str, AsyncExitStack] instead of None - Each server runs in separate task via asyncio.gather() - AgentLoop uses _mcp_stacks dict to track per-server stacks - Tests updated to handle new API
This commit is contained in:
parent
9bccfa63d2
commit
a167959027
@ -43,6 +43,7 @@ if TYPE_CHECKING:
|
||||
|
||||
UNIFIED_SESSION_KEY = "unified:default"
|
||||
|
||||
|
||||
class _LoopHook(AgentHook):
|
||||
"""Core hook for the main loop."""
|
||||
|
||||
@ -76,7 +77,7 @@ class _LoopHook(AgentHook):
|
||||
prev_clean = strip_think(self._stream_buf)
|
||||
self._stream_buf += delta
|
||||
new_clean = strip_think(self._stream_buf)
|
||||
incremental = new_clean[len(prev_clean):]
|
||||
incremental = new_clean[len(prev_clean) :]
|
||||
if incremental and self._on_stream:
|
||||
await self._on_stream(incremental)
|
||||
|
||||
@ -112,6 +113,7 @@ class _LoopHook(AgentHook):
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
return self._loop._strip_think(content)
|
||||
|
||||
|
||||
class AgentLoop:
|
||||
"""
|
||||
The agent loop is the core processing engine.
|
||||
@ -196,7 +198,7 @@ class AgentLoop:
|
||||
self._unified_session = unified_session
|
||||
self._running = False
|
||||
self._mcp_servers = mcp_servers or {}
|
||||
self._mcp_stack: AsyncExitStack | None = None
|
||||
self._mcp_stacks: dict[str, AsyncExitStack] = {}
|
||||
self._mcp_connected = False
|
||||
self._mcp_connecting = False
|
||||
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||
@ -228,24 +230,34 @@ class AgentLoop:
|
||||
|
||||
def _register_default_tools(self) -> None:
|
||||
"""Register the default set of tools."""
|
||||
allowed_dir = self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
|
||||
allowed_dir = (
|
||||
self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
|
||||
)
|
||||
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
|
||||
self.tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read))
|
||||
self.tools.register(
|
||||
ReadFileTool(
|
||||
workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read
|
||||
)
|
||||
)
|
||||
for cls in (WriteFileTool, EditFileTool, ListDirTool):
|
||||
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
for cls in (GlobTool, GrepTool):
|
||||
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
if self.exec_config.enable:
|
||||
self.tools.register(ExecTool(
|
||||
working_dir=str(self.workspace),
|
||||
timeout=self.exec_config.timeout,
|
||||
restrict_to_workspace=self.restrict_to_workspace,
|
||||
sandbox=self.exec_config.sandbox,
|
||||
path_append=self.exec_config.path_append,
|
||||
allowed_env_keys=self.exec_config.allowed_env_keys,
|
||||
))
|
||||
self.tools.register(
|
||||
ExecTool(
|
||||
working_dir=str(self.workspace),
|
||||
timeout=self.exec_config.timeout,
|
||||
restrict_to_workspace=self.restrict_to_workspace,
|
||||
sandbox=self.exec_config.sandbox,
|
||||
path_append=self.exec_config.path_append,
|
||||
allowed_env_keys=self.exec_config.allowed_env_keys,
|
||||
)
|
||||
)
|
||||
if self.web_config.enable:
|
||||
self.tools.register(WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy))
|
||||
self.tools.register(
|
||||
WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy)
|
||||
)
|
||||
self.tools.register(WebFetchTool(proxy=self.web_config.proxy))
|
||||
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
|
||||
self.tools.register(SpawnTool(manager=self.subagents))
|
||||
@ -260,19 +272,16 @@ class AgentLoop:
|
||||
return
|
||||
self._mcp_connecting = True
|
||||
from nanobot.agent.tools.mcp import connect_mcp_servers
|
||||
|
||||
try:
|
||||
self._mcp_stack = AsyncExitStack()
|
||||
await self._mcp_stack.__aenter__()
|
||||
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
|
||||
self._mcp_stacks = await connect_mcp_servers(self._mcp_servers, self.tools)
|
||||
self._mcp_connected = True
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("MCP connection cancelled (will retry next message)")
|
||||
self._mcp_stacks.clear()
|
||||
except BaseException as e:
|
||||
logger.error("Failed to connect MCP servers (will retry next message): {}", e)
|
||||
if self._mcp_stack:
|
||||
try:
|
||||
await self._mcp_stack.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
self._mcp_stack = None
|
||||
self._mcp_stacks.clear()
|
||||
finally:
|
||||
self._mcp_connecting = False
|
||||
|
||||
@ -289,6 +298,7 @@ class AgentLoop:
|
||||
if not text:
|
||||
return None
|
||||
from nanobot.utils.helpers import strip_think
|
||||
|
||||
return strip_think(text) or None
|
||||
|
||||
@staticmethod
|
||||
@ -327,9 +337,7 @@ class AgentLoop:
|
||||
message_id=message_id,
|
||||
)
|
||||
hook: AgentHook = (
|
||||
CompositeHook([loop_hook] + self._extra_hooks)
|
||||
if self._extra_hooks
|
||||
else loop_hook
|
||||
CompositeHook([loop_hook] + self._extra_hooks) if self._extra_hooks else loop_hook
|
||||
)
|
||||
|
||||
async def _checkpoint(payload: dict[str, Any]) -> None:
|
||||
@ -337,23 +345,25 @@ class AgentLoop:
|
||||
return
|
||||
self._set_runtime_checkpoint(session, payload)
|
||||
|
||||
result = await self.runner.run(AgentRunSpec(
|
||||
initial_messages=initial_messages,
|
||||
tools=self.tools,
|
||||
model=self.model,
|
||||
max_iterations=self.max_iterations,
|
||||
max_tool_result_chars=self.max_tool_result_chars,
|
||||
hook=hook,
|
||||
error_message="Sorry, I encountered an error calling the AI model.",
|
||||
concurrent_tools=True,
|
||||
workspace=self.workspace,
|
||||
session_key=session.key if session else None,
|
||||
context_window_tokens=self.context_window_tokens,
|
||||
context_block_limit=self.context_block_limit,
|
||||
provider_retry_mode=self.provider_retry_mode,
|
||||
progress_callback=on_progress,
|
||||
checkpoint_callback=_checkpoint,
|
||||
))
|
||||
result = await self.runner.run(
|
||||
AgentRunSpec(
|
||||
initial_messages=initial_messages,
|
||||
tools=self.tools,
|
||||
model=self.model,
|
||||
max_iterations=self.max_iterations,
|
||||
max_tool_result_chars=self.max_tool_result_chars,
|
||||
hook=hook,
|
||||
error_message="Sorry, I encountered an error calling the AI model.",
|
||||
concurrent_tools=True,
|
||||
workspace=self.workspace,
|
||||
session_key=session.key if session else None,
|
||||
context_window_tokens=self.context_window_tokens,
|
||||
context_block_limit=self.context_block_limit,
|
||||
provider_retry_mode=self.provider_retry_mode,
|
||||
progress_callback=on_progress,
|
||||
checkpoint_callback=_checkpoint,
|
||||
)
|
||||
)
|
||||
self._last_usage = result.usage
|
||||
if result.stop_reason == "max_iterations":
|
||||
logger.warning("Max iterations ({}) reached", self.max_iterations)
|
||||
@ -391,10 +401,19 @@ class AgentLoop:
|
||||
continue
|
||||
# Compute the effective session key before dispatching
|
||||
# This ensures /stop command can find tasks correctly when unified session is enabled
|
||||
effective_key = UNIFIED_SESSION_KEY if self._unified_session and not msg.session_key_override else msg.session_key
|
||||
effective_key = (
|
||||
UNIFIED_SESSION_KEY
|
||||
if self._unified_session and not msg.session_key_override
|
||||
else msg.session_key
|
||||
)
|
||||
task = asyncio.create_task(self._dispatch(msg))
|
||||
self._active_tasks.setdefault(effective_key, []).append(task)
|
||||
task.add_done_callback(lambda t, k=effective_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
|
||||
task.add_done_callback(
|
||||
lambda t, k=effective_key: self._active_tasks.get(k, [])
|
||||
and self._active_tasks[k].remove(t)
|
||||
if t in self._active_tasks.get(k, [])
|
||||
else None
|
||||
)
|
||||
|
||||
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||
"""Process a message: per-session serial, cross-session concurrent."""
|
||||
@ -417,11 +436,14 @@ class AgentLoop:
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_stream_delta"] = True
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content=delta,
|
||||
metadata=meta,
|
||||
))
|
||||
await self.bus.publish_outbound(
|
||||
OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content=delta,
|
||||
metadata=meta,
|
||||
)
|
||||
)
|
||||
|
||||
async def on_stream_end(*, resuming: bool = False) -> None:
|
||||
nonlocal stream_segment
|
||||
@ -429,44 +451,56 @@ class AgentLoop:
|
||||
meta["_stream_end"] = True
|
||||
meta["_resuming"] = resuming
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="",
|
||||
metadata=meta,
|
||||
))
|
||||
await self.bus.publish_outbound(
|
||||
OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content="",
|
||||
metadata=meta,
|
||||
)
|
||||
)
|
||||
stream_segment += 1
|
||||
|
||||
response = await self._process_message(
|
||||
msg, on_stream=on_stream, on_stream_end=on_stream_end,
|
||||
msg,
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
)
|
||||
if response is not None:
|
||||
await self.bus.publish_outbound(response)
|
||||
elif msg.channel == "cli":
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="", metadata=msg.metadata or {},
|
||||
))
|
||||
await self.bus.publish_outbound(
|
||||
OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content="",
|
||||
metadata=msg.metadata or {},
|
||||
)
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Task cancelled for session {}", msg.session_key)
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Error processing message for session {}", msg.session_key)
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="Sorry, I encountered an error.",
|
||||
))
|
||||
await self.bus.publish_outbound(
|
||||
OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content="Sorry, I encountered an error.",
|
||||
)
|
||||
)
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
"""Drain pending background archives, then close MCP connections."""
|
||||
if self._background_tasks:
|
||||
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||
self._background_tasks.clear()
|
||||
if self._mcp_stack:
|
||||
for name, stack in self._mcp_stacks.items():
|
||||
try:
|
||||
await self._mcp_stack.aclose()
|
||||
await stack.aclose()
|
||||
except (RuntimeError, BaseExceptionGroup):
|
||||
pass # MCP SDK cancel scope cleanup is noisy but harmless
|
||||
self._mcp_stack = None
|
||||
logger.debug("MCP server '{}' cleanup error (can be ignored)", name)
|
||||
self._mcp_stacks.clear()
|
||||
|
||||
def _schedule_background(self, coro) -> None:
|
||||
"""Schedule a coroutine as a tracked background task (drained on shutdown)."""
|
||||
@ -490,8 +524,9 @@ class AgentLoop:
|
||||
"""Process a single inbound message and return the response."""
|
||||
# System messages: parse origin from chat_id ("channel:chat_id")
|
||||
if msg.channel == "system":
|
||||
channel, chat_id = (msg.chat_id.split(":", 1) if ":" in msg.chat_id
|
||||
else ("cli", msg.chat_id))
|
||||
channel, chat_id = (
|
||||
msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id)
|
||||
)
|
||||
logger.info("Processing system message from {}", msg.sender_id)
|
||||
key = f"{channel}:{chat_id}"
|
||||
session = self.sessions.get_or_create(key)
|
||||
@ -503,19 +538,27 @@ class AgentLoop:
|
||||
current_role = "assistant" if msg.sender_id == "subagent" else "user"
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=msg.content, channel=channel, chat_id=chat_id,
|
||||
current_message=msg.content,
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
current_role=current_role,
|
||||
)
|
||||
final_content, _, all_msgs, _ = await self._run_agent_loop(
|
||||
messages, session=session, channel=channel, chat_id=chat_id,
|
||||
messages,
|
||||
session=session,
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
)
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self._clear_runtime_checkpoint(session)
|
||||
self.sessions.save(session)
|
||||
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
||||
content=final_content or "Background task completed.")
|
||||
return OutboundMessage(
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
content=final_content or "Background task completed.",
|
||||
)
|
||||
|
||||
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
|
||||
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||
@ -543,16 +586,22 @@ class AgentLoop:
|
||||
history=history,
|
||||
current_message=msg.content,
|
||||
media=msg.media if msg.media else None,
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
)
|
||||
|
||||
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_progress"] = True
|
||||
meta["_tool_hint"] = tool_hint
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
|
||||
))
|
||||
await self.bus.publish_outbound(
|
||||
OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content=content,
|
||||
metadata=meta,
|
||||
)
|
||||
)
|
||||
|
||||
final_content, _, all_msgs, stop_reason = await self._run_agent_loop(
|
||||
initial_messages,
|
||||
@ -560,7 +609,8 @@ class AgentLoop:
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
session=session,
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
)
|
||||
|
||||
@ -582,7 +632,9 @@ class AgentLoop:
|
||||
if on_stream is not None and stop_reason != "error":
|
||||
meta["_streamed"] = True
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content=final_content,
|
||||
metadata=meta,
|
||||
)
|
||||
|
||||
@ -608,10 +660,9 @@ class AgentLoop:
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
block.get("type") == "image_url"
|
||||
and block.get("image_url", {}).get("url", "").startswith("data:image/")
|
||||
):
|
||||
if block.get("type") == "image_url" and block.get("image_url", {}).get(
|
||||
"url", ""
|
||||
).startswith("data:image/"):
|
||||
path = (block.get("_meta") or {}).get("path", "")
|
||||
filtered.append({"type": "text", "text": image_placeholder_text(path)})
|
||||
continue
|
||||
@ -630,6 +681,7 @@ class AgentLoop:
|
||||
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
|
||||
"""Save new-turn messages into session, truncating large tool results."""
|
||||
from datetime import datetime
|
||||
|
||||
for m in messages[skip:]:
|
||||
entry = dict(m)
|
||||
role, content = entry.get("role"), entry.get("content")
|
||||
@ -644,7 +696,9 @@ class AgentLoop:
|
||||
continue
|
||||
entry["content"] = filtered
|
||||
elif role == "user":
|
||||
if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
|
||||
if isinstance(content, str) and content.startswith(
|
||||
ContextBuilder._RUNTIME_CONTEXT_TAG
|
||||
):
|
||||
# Strip the runtime-context prefix, keep only the user text.
|
||||
parts = content.split("\n\n", 1)
|
||||
if len(parts) > 1 and parts[1].strip():
|
||||
@ -708,13 +762,15 @@ class AgentLoop:
|
||||
continue
|
||||
tool_id = tool_call.get("id")
|
||||
name = ((tool_call.get("function") or {}).get("name")) or "tool"
|
||||
restored_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_id,
|
||||
"name": name,
|
||||
"content": "Error: Task interrupted before this tool finished.",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
})
|
||||
restored_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_id,
|
||||
"name": name,
|
||||
"content": "Error: Task interrupted before this tool finished.",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
overlap = 0
|
||||
max_overlap = min(len(session.messages), len(restored_messages))
|
||||
@ -746,6 +802,9 @@ class AgentLoop:
|
||||
await self._connect_mcp()
|
||||
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
||||
return await self._process_message(
|
||||
msg, session_key=session_key, on_progress=on_progress,
|
||||
on_stream=on_stream, on_stream_end=on_stream_end,
|
||||
msg,
|
||||
session_key=session_key,
|
||||
on_progress=on_progress,
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
)
|
||||
|
||||
@ -57,9 +57,7 @@ def _normalize_schema_for_openai(schema: Any) -> dict[str, Any]:
|
||||
|
||||
if "properties" in normalized and isinstance(normalized["properties"], dict):
|
||||
normalized["properties"] = {
|
||||
name: _normalize_schema_for_openai(prop)
|
||||
if isinstance(prop, dict)
|
||||
else prop
|
||||
name: _normalize_schema_for_openai(prop) if isinstance(prop, dict) else prop
|
||||
for name, prop in normalized["properties"].items()
|
||||
}
|
||||
|
||||
@ -138,9 +136,7 @@ class MCPToolWrapper(Tool):
|
||||
class MCPResourceWrapper(Tool):
|
||||
"""Wraps an MCP resource URI as a read-only nanobot Tool."""
|
||||
|
||||
def __init__(
|
||||
self, session, server_name: str, resource_def, resource_timeout: int = 30
|
||||
):
|
||||
def __init__(self, session, server_name: str, resource_def, resource_timeout: int = 30):
|
||||
self._session = session
|
||||
self._uri = resource_def.uri
|
||||
self._name = f"mcp_{server_name}_resource_{resource_def.name}"
|
||||
@ -211,9 +207,7 @@ class MCPResourceWrapper(Tool):
|
||||
class MCPPromptWrapper(Tool):
|
||||
"""Wraps an MCP prompt as a read-only nanobot Tool."""
|
||||
|
||||
def __init__(
|
||||
self, session, server_name: str, prompt_def, prompt_timeout: int = 30
|
||||
):
|
||||
def __init__(self, session, server_name: str, prompt_def, prompt_timeout: int = 30):
|
||||
self._session = session
|
||||
self._prompt_name = prompt_def.name
|
||||
self._name = f"mcp_{server_name}_prompt_{prompt_def.name}"
|
||||
@ -266,9 +260,7 @@ class MCPPromptWrapper(Tool):
|
||||
timeout=self._prompt_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout
|
||||
)
|
||||
logger.warning("MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout)
|
||||
return f"(MCP prompt call timed out after {self._prompt_timeout}s)"
|
||||
except asyncio.CancelledError:
|
||||
task = asyncio.current_task()
|
||||
@ -279,13 +271,17 @@ class MCPPromptWrapper(Tool):
|
||||
except McpError as exc:
|
||||
logger.error(
|
||||
"MCP prompt '{}' failed: code={} message={}",
|
||||
self._name, exc.error.code, exc.error.message,
|
||||
self._name,
|
||||
exc.error.code,
|
||||
exc.error.message,
|
||||
)
|
||||
return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])"
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"MCP prompt '{}' failed: {}: {}",
|
||||
self._name, type(exc).__name__, exc,
|
||||
self._name,
|
||||
type(exc).__name__,
|
||||
exc,
|
||||
)
|
||||
return f"(MCP prompt call failed: {type(exc).__name__})"
|
||||
|
||||
@ -307,35 +303,44 @@ class MCPPromptWrapper(Tool):
|
||||
|
||||
|
||||
async def connect_mcp_servers(
|
||||
mcp_servers: dict, registry: ToolRegistry, stack: AsyncExitStack
|
||||
) -> None:
|
||||
"""Connect to configured MCP servers and register their tools, resources, and prompts."""
|
||||
mcp_servers: dict, registry: ToolRegistry
|
||||
) -> dict[str, AsyncExitStack]:
|
||||
"""Connect to configured MCP servers and register their tools, resources, prompts.
|
||||
|
||||
Returns a dict mapping server name -> its dedicated AsyncExitStack.
|
||||
Each server gets its own stack and runs in its own task to prevent
|
||||
cancel scope conflicts when multiple MCP servers are configured.
|
||||
"""
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.streamable_http import streamable_http_client
|
||||
|
||||
for name, cfg in mcp_servers.items():
|
||||
async def connect_single_server(name: str, cfg) -> tuple[str, AsyncExitStack | None]:
|
||||
server_stack = AsyncExitStack()
|
||||
await server_stack.__aenter__()
|
||||
|
||||
try:
|
||||
transport_type = cfg.type
|
||||
if not transport_type:
|
||||
if cfg.command:
|
||||
transport_type = "stdio"
|
||||
elif cfg.url:
|
||||
# Convention: URLs ending with /sse use SSE transport; others use streamableHttp
|
||||
transport_type = (
|
||||
"sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp"
|
||||
)
|
||||
else:
|
||||
logger.warning("MCP server '{}': no command or url configured, skipping", name)
|
||||
continue
|
||||
await server_stack.aclose()
|
||||
return name, None
|
||||
|
||||
if transport_type == "stdio":
|
||||
params = StdioServerParameters(
|
||||
command=cfg.command, args=cfg.args, env=cfg.env or None
|
||||
)
|
||||
read, write = await stack.enter_async_context(stdio_client(params))
|
||||
read, write = await server_stack.enter_async_context(stdio_client(params))
|
||||
elif transport_type == "sse":
|
||||
|
||||
def httpx_client_factory(
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: httpx.Timeout | None = None,
|
||||
@ -353,27 +358,26 @@ async def connect_mcp_servers(
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
read, write = await stack.enter_async_context(
|
||||
read, write = await server_stack.enter_async_context(
|
||||
sse_client(cfg.url, httpx_client_factory=httpx_client_factory)
|
||||
)
|
||||
elif transport_type == "streamableHttp":
|
||||
# Always provide an explicit httpx client so MCP HTTP transport does not
|
||||
# inherit httpx's default 5s timeout and preempt the higher-level tool timeout.
|
||||
http_client = await stack.enter_async_context(
|
||||
http_client = await server_stack.enter_async_context(
|
||||
httpx.AsyncClient(
|
||||
headers=cfg.headers or None,
|
||||
follow_redirects=True,
|
||||
timeout=None,
|
||||
)
|
||||
)
|
||||
read, write, _ = await stack.enter_async_context(
|
||||
read, write, _ = await server_stack.enter_async_context(
|
||||
streamable_http_client(cfg.url, http_client=http_client)
|
||||
)
|
||||
else:
|
||||
logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type)
|
||||
continue
|
||||
await server_stack.aclose()
|
||||
return name, None
|
||||
|
||||
session = await stack.enter_async_context(ClientSession(read, write))
|
||||
session = await server_stack.enter_async_context(ClientSession(read, write))
|
||||
await session.initialize()
|
||||
|
||||
tools = await session.list_tools()
|
||||
@ -418,7 +422,6 @@ async def connect_mcp_servers(
|
||||
", ".join(available_wrapped_names) or "(none)",
|
||||
)
|
||||
|
||||
# --- Register resources ---
|
||||
try:
|
||||
resources_result = await session.list_resources()
|
||||
for resource in resources_result.resources:
|
||||
@ -433,7 +436,6 @@ async def connect_mcp_servers(
|
||||
except Exception as e:
|
||||
logger.debug("MCP server '{}': resources not supported or failed: {}", name, e)
|
||||
|
||||
# --- Register prompts ---
|
||||
try:
|
||||
prompts_result = await session.list_prompts()
|
||||
for prompt in prompts_result.prompts:
|
||||
@ -442,14 +444,38 @@ async def connect_mcp_servers(
|
||||
)
|
||||
registry.register(wrapper)
|
||||
registered_count += 1
|
||||
logger.debug(
|
||||
"MCP: registered prompt '{}' from server '{}'", wrapper.name, name
|
||||
)
|
||||
logger.debug("MCP: registered prompt '{}' from server '{}'", wrapper.name, name)
|
||||
except Exception as e:
|
||||
logger.debug("MCP server '{}': prompts not supported or failed: {}", name, e)
|
||||
|
||||
logger.info(
|
||||
"MCP server '{}': connected, {} capabilities registered", name, registered_count
|
||||
)
|
||||
return name, server_stack
|
||||
|
||||
except Exception as e:
|
||||
logger.error("MCP server '{}': failed to connect: {}", name, e)
|
||||
try:
|
||||
await server_stack.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
return name, None
|
||||
|
||||
server_stacks: dict[str, AsyncExitStack] = {}
|
||||
|
||||
tasks: list[asyncio.Task] = []
|
||||
for name, cfg in mcp_servers.items():
|
||||
task = asyncio.create_task(connect_single_server(name, cfg))
|
||||
tasks.append(task)
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for i, result in enumerate(results):
|
||||
name = list(mcp_servers.keys())[i]
|
||||
if isinstance(result, BaseException):
|
||||
if not isinstance(result, asyncio.CancelledError):
|
||||
logger.error("MCP server '{}' connection task failed: {}", name, result)
|
||||
elif result is not None and result[1] is not None:
|
||||
server_stacks[result[0]] = result[1]
|
||||
|
||||
return server_stacks
|
||||
|
||||
@ -271,15 +271,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_raw_names(
|
||||
) -> None:
|
||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||
registry = ToolRegistry()
|
||||
stack = AsyncExitStack()
|
||||
await stack.__aenter__()
|
||||
try:
|
||||
await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake", enabled_tools=["demo"])},
|
||||
registry,
|
||||
stack,
|
||||
)
|
||||
finally:
|
||||
stacks = await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake", enabled_tools=["demo"])},
|
||||
registry,
|
||||
)
|
||||
for stack in stacks.values():
|
||||
await stack.aclose()
|
||||
|
||||
assert registry.tool_names == ["mcp_test_demo"]
|
||||
@ -291,15 +287,11 @@ async def test_connect_mcp_servers_enabled_tools_defaults_to_all(
|
||||
) -> None:
|
||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||
registry = ToolRegistry()
|
||||
stack = AsyncExitStack()
|
||||
await stack.__aenter__()
|
||||
try:
|
||||
await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake")},
|
||||
registry,
|
||||
stack,
|
||||
)
|
||||
finally:
|
||||
stacks = await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake")},
|
||||
registry,
|
||||
)
|
||||
for stack in stacks.values():
|
||||
await stack.aclose()
|
||||
|
||||
assert registry.tool_names == ["mcp_test_demo", "mcp_test_other"]
|
||||
@ -311,15 +303,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_wrapped_names(
|
||||
) -> None:
|
||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||
registry = ToolRegistry()
|
||||
stack = AsyncExitStack()
|
||||
await stack.__aenter__()
|
||||
try:
|
||||
await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])},
|
||||
registry,
|
||||
stack,
|
||||
)
|
||||
finally:
|
||||
stacks = await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])},
|
||||
registry,
|
||||
)
|
||||
for stack in stacks.values():
|
||||
await stack.aclose()
|
||||
|
||||
assert registry.tool_names == ["mcp_test_demo"]
|
||||
@ -331,15 +319,11 @@ async def test_connect_mcp_servers_enabled_tools_empty_list_registers_none(
|
||||
) -> None:
|
||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||
registry = ToolRegistry()
|
||||
stack = AsyncExitStack()
|
||||
await stack.__aenter__()
|
||||
try:
|
||||
await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake", enabled_tools=[])},
|
||||
registry,
|
||||
stack,
|
||||
)
|
||||
finally:
|
||||
stacks = await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake", enabled_tools=[])},
|
||||
registry,
|
||||
)
|
||||
for stack in stacks.values():
|
||||
await stack.aclose()
|
||||
|
||||
assert registry.tool_names == []
|
||||
@ -358,15 +342,11 @@ async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries(
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning)
|
||||
|
||||
stack = AsyncExitStack()
|
||||
await stack.__aenter__()
|
||||
try:
|
||||
await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])},
|
||||
registry,
|
||||
stack,
|
||||
)
|
||||
finally:
|
||||
stacks = await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])},
|
||||
registry,
|
||||
)
|
||||
for stack in stacks.values():
|
||||
await stack.aclose()
|
||||
|
||||
assert registry.tool_names == []
|
||||
@ -389,9 +369,7 @@ def _make_resource_def(
|
||||
return SimpleNamespace(name=name, uri=uri, description=description)
|
||||
|
||||
|
||||
def _make_resource_wrapper(
|
||||
session: object, *, timeout: float = 0.1
|
||||
) -> MCPResourceWrapper:
|
||||
def _make_resource_wrapper(session: object, *, timeout: float = 0.1) -> MCPResourceWrapper:
|
||||
return MCPResourceWrapper(session, "srv", _make_resource_def(), resource_timeout=timeout)
|
||||
|
||||
|
||||
@ -434,9 +412,7 @@ async def test_resource_wrapper_execute_handles_timeout() -> None:
|
||||
await asyncio.sleep(1)
|
||||
return SimpleNamespace(contents=[])
|
||||
|
||||
wrapper = _make_resource_wrapper(
|
||||
SimpleNamespace(read_resource=read_resource), timeout=0.01
|
||||
)
|
||||
wrapper = _make_resource_wrapper(SimpleNamespace(read_resource=read_resource), timeout=0.01)
|
||||
result = await wrapper.execute()
|
||||
assert result == "(MCP resource read timed out after 0.01s)"
|
||||
|
||||
@ -464,20 +440,14 @@ def _make_prompt_def(
|
||||
return SimpleNamespace(name=name, description=description, arguments=arguments)
|
||||
|
||||
|
||||
def _make_prompt_wrapper(
|
||||
session: object, *, timeout: float = 0.1
|
||||
) -> MCPPromptWrapper:
|
||||
return MCPPromptWrapper(
|
||||
session, "srv", _make_prompt_def(), prompt_timeout=timeout
|
||||
)
|
||||
def _make_prompt_wrapper(session: object, *, timeout: float = 0.1) -> MCPPromptWrapper:
|
||||
return MCPPromptWrapper(session, "srv", _make_prompt_def(), prompt_timeout=timeout)
|
||||
|
||||
|
||||
def test_prompt_wrapper_properties() -> None:
|
||||
arg1 = SimpleNamespace(name="topic", required=True)
|
||||
arg2 = SimpleNamespace(name="style", required=False)
|
||||
wrapper = MCPPromptWrapper(
|
||||
None, "myserver", _make_prompt_def(arguments=[arg1, arg2])
|
||||
)
|
||||
wrapper = MCPPromptWrapper(None, "myserver", _make_prompt_def(arguments=[arg1, arg2]))
|
||||
assert wrapper.name == "mcp_myserver_prompt_myprompt"
|
||||
assert "[MCP Prompt]" in wrapper.description
|
||||
assert "A test prompt" in wrapper.description
|
||||
@ -528,9 +498,7 @@ async def test_prompt_wrapper_execute_handles_timeout() -> None:
|
||||
await asyncio.sleep(1)
|
||||
return SimpleNamespace(messages=[])
|
||||
|
||||
wrapper = _make_prompt_wrapper(
|
||||
SimpleNamespace(get_prompt=get_prompt), timeout=0.01
|
||||
)
|
||||
wrapper = _make_prompt_wrapper(SimpleNamespace(get_prompt=get_prompt), timeout=0.01)
|
||||
result = await wrapper.execute()
|
||||
assert result == "(MCP prompt call timed out after 0.01s)"
|
||||
|
||||
@ -616,15 +584,11 @@ async def test_connect_registers_resources_and_prompts(
|
||||
prompt_names=["prompt_c"],
|
||||
)
|
||||
registry = ToolRegistry()
|
||||
stack = AsyncExitStack()
|
||||
await stack.__aenter__()
|
||||
try:
|
||||
await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake")},
|
||||
registry,
|
||||
stack,
|
||||
)
|
||||
finally:
|
||||
stacks = await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake")},
|
||||
registry,
|
||||
)
|
||||
for stack in stacks.values():
|
||||
await stack.aclose()
|
||||
|
||||
assert "mcp_test_tool_a" in registry.tool_names
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user