mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-14 23:19:55 +00:00
Merge PR #3019: fix(mcp): support multiple MCP servers
fix(mcp): support multiple MCP servers
This commit is contained in:
commit
e7bbbe98f4
@ -44,6 +44,7 @@ if TYPE_CHECKING:
|
||||
|
||||
UNIFIED_SESSION_KEY = "unified:default"
|
||||
|
||||
|
||||
class _LoopHook(AgentHook):
|
||||
"""Core hook for the main loop."""
|
||||
|
||||
@ -77,7 +78,7 @@ class _LoopHook(AgentHook):
|
||||
prev_clean = strip_think(self._stream_buf)
|
||||
self._stream_buf += delta
|
||||
new_clean = strip_think(self._stream_buf)
|
||||
incremental = new_clean[len(prev_clean):]
|
||||
incremental = new_clean[len(prev_clean) :]
|
||||
if incremental and self._on_stream:
|
||||
await self._on_stream(incremental)
|
||||
|
||||
@ -113,6 +114,7 @@ class _LoopHook(AgentHook):
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
return self._loop._strip_think(content)
|
||||
|
||||
|
||||
class AgentLoop:
|
||||
"""
|
||||
The agent loop is the core processing engine.
|
||||
@ -198,7 +200,7 @@ class AgentLoop:
|
||||
self._unified_session = unified_session
|
||||
self._running = False
|
||||
self._mcp_servers = mcp_servers or {}
|
||||
self._mcp_stack: AsyncExitStack | None = None
|
||||
self._mcp_stacks: dict[str, AsyncExitStack] = {}
|
||||
self._mcp_connected = False
|
||||
self._mcp_connecting = False
|
||||
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||
@ -235,24 +237,34 @@ class AgentLoop:
|
||||
|
||||
def _register_default_tools(self) -> None:
|
||||
"""Register the default set of tools."""
|
||||
allowed_dir = self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
|
||||
allowed_dir = (
|
||||
self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
|
||||
)
|
||||
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
|
||||
self.tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read))
|
||||
self.tools.register(
|
||||
ReadFileTool(
|
||||
workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read
|
||||
)
|
||||
)
|
||||
for cls in (WriteFileTool, EditFileTool, ListDirTool):
|
||||
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
for cls in (GlobTool, GrepTool):
|
||||
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
if self.exec_config.enable:
|
||||
self.tools.register(ExecTool(
|
||||
working_dir=str(self.workspace),
|
||||
timeout=self.exec_config.timeout,
|
||||
restrict_to_workspace=self.restrict_to_workspace,
|
||||
sandbox=self.exec_config.sandbox,
|
||||
path_append=self.exec_config.path_append,
|
||||
allowed_env_keys=self.exec_config.allowed_env_keys,
|
||||
))
|
||||
self.tools.register(
|
||||
ExecTool(
|
||||
working_dir=str(self.workspace),
|
||||
timeout=self.exec_config.timeout,
|
||||
restrict_to_workspace=self.restrict_to_workspace,
|
||||
sandbox=self.exec_config.sandbox,
|
||||
path_append=self.exec_config.path_append,
|
||||
allowed_env_keys=self.exec_config.allowed_env_keys,
|
||||
)
|
||||
)
|
||||
if self.web_config.enable:
|
||||
self.tools.register(WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy))
|
||||
self.tools.register(
|
||||
WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy)
|
||||
)
|
||||
self.tools.register(WebFetchTool(proxy=self.web_config.proxy))
|
||||
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
|
||||
self.tools.register(SpawnTool(manager=self.subagents))
|
||||
@ -267,19 +279,19 @@ class AgentLoop:
|
||||
return
|
||||
self._mcp_connecting = True
|
||||
from nanobot.agent.tools.mcp import connect_mcp_servers
|
||||
|
||||
try:
|
||||
self._mcp_stack = AsyncExitStack()
|
||||
await self._mcp_stack.__aenter__()
|
||||
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
|
||||
self._mcp_connected = True
|
||||
self._mcp_stacks = await connect_mcp_servers(self._mcp_servers, self.tools)
|
||||
if self._mcp_stacks:
|
||||
self._mcp_connected = True
|
||||
else:
|
||||
logger.warning("No MCP servers connected successfully (will retry next message)")
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("MCP connection cancelled (will retry next message)")
|
||||
self._mcp_stacks.clear()
|
||||
except BaseException as e:
|
||||
logger.error("Failed to connect MCP servers (will retry next message): {}", e)
|
||||
if self._mcp_stack:
|
||||
try:
|
||||
await self._mcp_stack.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
self._mcp_stack = None
|
||||
self._mcp_stacks.clear()
|
||||
finally:
|
||||
self._mcp_connecting = False
|
||||
|
||||
@ -296,6 +308,7 @@ class AgentLoop:
|
||||
if not text:
|
||||
return None
|
||||
from nanobot.utils.helpers import strip_think
|
||||
|
||||
return strip_think(text) or None
|
||||
|
||||
@staticmethod
|
||||
@ -334,9 +347,7 @@ class AgentLoop:
|
||||
message_id=message_id,
|
||||
)
|
||||
hook: AgentHook = (
|
||||
CompositeHook([loop_hook] + self._extra_hooks)
|
||||
if self._extra_hooks
|
||||
else loop_hook
|
||||
CompositeHook([loop_hook] + self._extra_hooks) if self._extra_hooks else loop_hook
|
||||
)
|
||||
|
||||
async def _checkpoint(payload: dict[str, Any]) -> None:
|
||||
@ -344,23 +355,25 @@ class AgentLoop:
|
||||
return
|
||||
self._set_runtime_checkpoint(session, payload)
|
||||
|
||||
result = await self.runner.run(AgentRunSpec(
|
||||
initial_messages=initial_messages,
|
||||
tools=self.tools,
|
||||
model=self.model,
|
||||
max_iterations=self.max_iterations,
|
||||
max_tool_result_chars=self.max_tool_result_chars,
|
||||
hook=hook,
|
||||
error_message="Sorry, I encountered an error calling the AI model.",
|
||||
concurrent_tools=True,
|
||||
workspace=self.workspace,
|
||||
session_key=session.key if session else None,
|
||||
context_window_tokens=self.context_window_tokens,
|
||||
context_block_limit=self.context_block_limit,
|
||||
provider_retry_mode=self.provider_retry_mode,
|
||||
progress_callback=on_progress,
|
||||
checkpoint_callback=_checkpoint,
|
||||
))
|
||||
result = await self.runner.run(
|
||||
AgentRunSpec(
|
||||
initial_messages=initial_messages,
|
||||
tools=self.tools,
|
||||
model=self.model,
|
||||
max_iterations=self.max_iterations,
|
||||
max_tool_result_chars=self.max_tool_result_chars,
|
||||
hook=hook,
|
||||
error_message="Sorry, I encountered an error calling the AI model.",
|
||||
concurrent_tools=True,
|
||||
workspace=self.workspace,
|
||||
session_key=session.key if session else None,
|
||||
context_window_tokens=self.context_window_tokens,
|
||||
context_block_limit=self.context_block_limit,
|
||||
provider_retry_mode=self.provider_retry_mode,
|
||||
progress_callback=on_progress,
|
||||
checkpoint_callback=_checkpoint,
|
||||
)
|
||||
)
|
||||
self._last_usage = result.usage
|
||||
if result.stop_reason == "max_iterations":
|
||||
logger.warning("Max iterations ({}) reached", self.max_iterations)
|
||||
@ -399,10 +412,19 @@ class AgentLoop:
|
||||
continue
|
||||
# Compute the effective session key before dispatching
|
||||
# This ensures /stop command can find tasks correctly when unified session is enabled
|
||||
effective_key = UNIFIED_SESSION_KEY if self._unified_session and not msg.session_key_override else msg.session_key
|
||||
effective_key = (
|
||||
UNIFIED_SESSION_KEY
|
||||
if self._unified_session and not msg.session_key_override
|
||||
else msg.session_key
|
||||
)
|
||||
task = asyncio.create_task(self._dispatch(msg))
|
||||
self._active_tasks.setdefault(effective_key, []).append(task)
|
||||
task.add_done_callback(lambda t, k=effective_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
|
||||
task.add_done_callback(
|
||||
lambda t, k=effective_key: self._active_tasks.get(k, [])
|
||||
and self._active_tasks[k].remove(t)
|
||||
if t in self._active_tasks.get(k, [])
|
||||
else None
|
||||
)
|
||||
|
||||
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||
"""Process a message: per-session serial, cross-session concurrent."""
|
||||
@ -425,11 +447,14 @@ class AgentLoop:
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_stream_delta"] = True
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content=delta,
|
||||
metadata=meta,
|
||||
))
|
||||
await self.bus.publish_outbound(
|
||||
OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content=delta,
|
||||
metadata=meta,
|
||||
)
|
||||
)
|
||||
|
||||
async def on_stream_end(*, resuming: bool = False) -> None:
|
||||
nonlocal stream_segment
|
||||
@ -437,44 +462,56 @@ class AgentLoop:
|
||||
meta["_stream_end"] = True
|
||||
meta["_resuming"] = resuming
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="",
|
||||
metadata=meta,
|
||||
))
|
||||
await self.bus.publish_outbound(
|
||||
OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content="",
|
||||
metadata=meta,
|
||||
)
|
||||
)
|
||||
stream_segment += 1
|
||||
|
||||
response = await self._process_message(
|
||||
msg, on_stream=on_stream, on_stream_end=on_stream_end,
|
||||
msg,
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
)
|
||||
if response is not None:
|
||||
await self.bus.publish_outbound(response)
|
||||
elif msg.channel == "cli":
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="", metadata=msg.metadata or {},
|
||||
))
|
||||
await self.bus.publish_outbound(
|
||||
OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content="",
|
||||
metadata=msg.metadata or {},
|
||||
)
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Task cancelled for session {}", msg.session_key)
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Error processing message for session {}", msg.session_key)
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="Sorry, I encountered an error.",
|
||||
))
|
||||
await self.bus.publish_outbound(
|
||||
OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content="Sorry, I encountered an error.",
|
||||
)
|
||||
)
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
"""Drain pending background archives, then close MCP connections."""
|
||||
if self._background_tasks:
|
||||
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||
self._background_tasks.clear()
|
||||
if self._mcp_stack:
|
||||
for name, stack in self._mcp_stacks.items():
|
||||
try:
|
||||
await self._mcp_stack.aclose()
|
||||
await stack.aclose()
|
||||
except (RuntimeError, BaseExceptionGroup):
|
||||
pass # MCP SDK cancel scope cleanup is noisy but harmless
|
||||
self._mcp_stack = None
|
||||
logger.debug("MCP server '{}' cleanup error (can be ignored)", name)
|
||||
self._mcp_stacks.clear()
|
||||
|
||||
def _schedule_background(self, coro) -> None:
|
||||
"""Schedule a coroutine as a tracked background task (drained on shutdown)."""
|
||||
@ -498,8 +535,9 @@ class AgentLoop:
|
||||
"""Process a single inbound message and return the response."""
|
||||
# System messages: parse origin from chat_id ("channel:chat_id")
|
||||
if msg.channel == "system":
|
||||
channel, chat_id = (msg.chat_id.split(":", 1) if ":" in msg.chat_id
|
||||
else ("cli", msg.chat_id))
|
||||
channel, chat_id = (
|
||||
msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id)
|
||||
)
|
||||
logger.info("Processing system message from {}", msg.sender_id)
|
||||
key = f"{channel}:{chat_id}"
|
||||
session = self.sessions.get_or_create(key)
|
||||
@ -520,15 +558,21 @@ class AgentLoop:
|
||||
current_role=current_role,
|
||||
)
|
||||
final_content, _, all_msgs, _ = await self._run_agent_loop(
|
||||
messages, session=session, channel=channel, chat_id=chat_id,
|
||||
messages,
|
||||
session=session,
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
)
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self._clear_runtime_checkpoint(session)
|
||||
self.sessions.save(session)
|
||||
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
||||
content=final_content or "Background task completed.")
|
||||
return OutboundMessage(
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
content=final_content or "Background task completed.",
|
||||
)
|
||||
|
||||
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
|
||||
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||
@ -560,16 +604,22 @@ class AgentLoop:
|
||||
current_message=msg.content,
|
||||
session_summary=pending,
|
||||
media=msg.media if msg.media else None,
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
)
|
||||
|
||||
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_progress"] = True
|
||||
meta["_tool_hint"] = tool_hint
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
|
||||
))
|
||||
await self.bus.publish_outbound(
|
||||
OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content=content,
|
||||
metadata=meta,
|
||||
)
|
||||
)
|
||||
|
||||
final_content, _, all_msgs, stop_reason = await self._run_agent_loop(
|
||||
initial_messages,
|
||||
@ -577,7 +627,8 @@ class AgentLoop:
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
session=session,
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
)
|
||||
|
||||
@ -599,7 +650,9 @@ class AgentLoop:
|
||||
if on_stream is not None and stop_reason != "error":
|
||||
meta["_streamed"] = True
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content=final_content,
|
||||
metadata=meta,
|
||||
)
|
||||
|
||||
@ -625,10 +678,9 @@ class AgentLoop:
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
block.get("type") == "image_url"
|
||||
and block.get("image_url", {}).get("url", "").startswith("data:image/")
|
||||
):
|
||||
if block.get("type") == "image_url" and block.get("image_url", {}).get(
|
||||
"url", ""
|
||||
).startswith("data:image/"):
|
||||
path = (block.get("_meta") or {}).get("path", "")
|
||||
filtered.append({"type": "text", "text": image_placeholder_text(path)})
|
||||
continue
|
||||
@ -647,6 +699,7 @@ class AgentLoop:
|
||||
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
|
||||
"""Save new-turn messages into session, truncating large tool results."""
|
||||
from datetime import datetime
|
||||
|
||||
for m in messages[skip:]:
|
||||
entry = dict(m)
|
||||
role, content = entry.get("role"), entry.get("content")
|
||||
@ -736,13 +789,15 @@ class AgentLoop:
|
||||
continue
|
||||
tool_id = tool_call.get("id")
|
||||
name = ((tool_call.get("function") or {}).get("name")) or "tool"
|
||||
restored_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_id,
|
||||
"name": name,
|
||||
"content": "Error: Task interrupted before this tool finished.",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
})
|
||||
restored_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_id,
|
||||
"name": name,
|
||||
"content": "Error: Task interrupted before this tool finished.",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
overlap = 0
|
||||
max_overlap = min(len(session.messages), len(restored_messages))
|
||||
@ -774,6 +829,9 @@ class AgentLoop:
|
||||
await self._connect_mcp()
|
||||
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
||||
return await self._process_message(
|
||||
msg, session_key=session_key, on_progress=on_progress,
|
||||
on_stream=on_stream, on_stream_end=on_stream_end,
|
||||
msg,
|
||||
session_key=session_key,
|
||||
on_progress=on_progress,
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
)
|
||||
|
||||
@ -57,9 +57,7 @@ def _normalize_schema_for_openai(schema: Any) -> dict[str, Any]:
|
||||
|
||||
if "properties" in normalized and isinstance(normalized["properties"], dict):
|
||||
normalized["properties"] = {
|
||||
name: _normalize_schema_for_openai(prop)
|
||||
if isinstance(prop, dict)
|
||||
else prop
|
||||
name: _normalize_schema_for_openai(prop) if isinstance(prop, dict) else prop
|
||||
for name, prop in normalized["properties"].items()
|
||||
}
|
||||
|
||||
@ -138,9 +136,7 @@ class MCPToolWrapper(Tool):
|
||||
class MCPResourceWrapper(Tool):
|
||||
"""Wraps an MCP resource URI as a read-only nanobot Tool."""
|
||||
|
||||
def __init__(
|
||||
self, session, server_name: str, resource_def, resource_timeout: int = 30
|
||||
):
|
||||
def __init__(self, session, server_name: str, resource_def, resource_timeout: int = 30):
|
||||
self._session = session
|
||||
self._uri = resource_def.uri
|
||||
self._name = f"mcp_{server_name}_resource_{resource_def.name}"
|
||||
@ -211,9 +207,7 @@ class MCPResourceWrapper(Tool):
|
||||
class MCPPromptWrapper(Tool):
|
||||
"""Wraps an MCP prompt as a read-only nanobot Tool."""
|
||||
|
||||
def __init__(
|
||||
self, session, server_name: str, prompt_def, prompt_timeout: int = 30
|
||||
):
|
||||
def __init__(self, session, server_name: str, prompt_def, prompt_timeout: int = 30):
|
||||
self._session = session
|
||||
self._prompt_name = prompt_def.name
|
||||
self._name = f"mcp_{server_name}_prompt_{prompt_def.name}"
|
||||
@ -266,9 +260,7 @@ class MCPPromptWrapper(Tool):
|
||||
timeout=self._prompt_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout
|
||||
)
|
||||
logger.warning("MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout)
|
||||
return f"(MCP prompt call timed out after {self._prompt_timeout}s)"
|
||||
except asyncio.CancelledError:
|
||||
task = asyncio.current_task()
|
||||
@ -279,13 +271,17 @@ class MCPPromptWrapper(Tool):
|
||||
except McpError as exc:
|
||||
logger.error(
|
||||
"MCP prompt '{}' failed: code={} message={}",
|
||||
self._name, exc.error.code, exc.error.message,
|
||||
self._name,
|
||||
exc.error.code,
|
||||
exc.error.message,
|
||||
)
|
||||
return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])"
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"MCP prompt '{}' failed: {}: {}",
|
||||
self._name, type(exc).__name__, exc,
|
||||
self._name,
|
||||
type(exc).__name__,
|
||||
exc,
|
||||
)
|
||||
return f"(MCP prompt call failed: {type(exc).__name__})"
|
||||
|
||||
@ -307,35 +303,44 @@ class MCPPromptWrapper(Tool):
|
||||
|
||||
|
||||
async def connect_mcp_servers(
|
||||
mcp_servers: dict, registry: ToolRegistry, stack: AsyncExitStack
|
||||
) -> None:
|
||||
"""Connect to configured MCP servers and register their tools, resources, and prompts."""
|
||||
mcp_servers: dict, registry: ToolRegistry
|
||||
) -> dict[str, AsyncExitStack]:
|
||||
"""Connect to configured MCP servers and register their tools, resources, prompts.
|
||||
|
||||
Returns a dict mapping server name -> its dedicated AsyncExitStack.
|
||||
Each server gets its own stack and runs in its own task to prevent
|
||||
cancel scope conflicts when multiple MCP servers are configured.
|
||||
"""
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.streamable_http import streamable_http_client
|
||||
|
||||
for name, cfg in mcp_servers.items():
|
||||
async def connect_single_server(name: str, cfg) -> tuple[str, AsyncExitStack | None]:
|
||||
server_stack = AsyncExitStack()
|
||||
await server_stack.__aenter__()
|
||||
|
||||
try:
|
||||
transport_type = cfg.type
|
||||
if not transport_type:
|
||||
if cfg.command:
|
||||
transport_type = "stdio"
|
||||
elif cfg.url:
|
||||
# Convention: URLs ending with /sse use SSE transport; others use streamableHttp
|
||||
transport_type = (
|
||||
"sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp"
|
||||
)
|
||||
else:
|
||||
logger.warning("MCP server '{}': no command or url configured, skipping", name)
|
||||
continue
|
||||
await server_stack.aclose()
|
||||
return name, None
|
||||
|
||||
if transport_type == "stdio":
|
||||
params = StdioServerParameters(
|
||||
command=cfg.command, args=cfg.args, env=cfg.env or None
|
||||
)
|
||||
read, write = await stack.enter_async_context(stdio_client(params))
|
||||
read, write = await server_stack.enter_async_context(stdio_client(params))
|
||||
elif transport_type == "sse":
|
||||
|
||||
def httpx_client_factory(
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: httpx.Timeout | None = None,
|
||||
@ -353,27 +358,26 @@ async def connect_mcp_servers(
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
read, write = await stack.enter_async_context(
|
||||
read, write = await server_stack.enter_async_context(
|
||||
sse_client(cfg.url, httpx_client_factory=httpx_client_factory)
|
||||
)
|
||||
elif transport_type == "streamableHttp":
|
||||
# Always provide an explicit httpx client so MCP HTTP transport does not
|
||||
# inherit httpx's default 5s timeout and preempt the higher-level tool timeout.
|
||||
http_client = await stack.enter_async_context(
|
||||
http_client = await server_stack.enter_async_context(
|
||||
httpx.AsyncClient(
|
||||
headers=cfg.headers or None,
|
||||
follow_redirects=True,
|
||||
timeout=None,
|
||||
)
|
||||
)
|
||||
read, write, _ = await stack.enter_async_context(
|
||||
read, write, _ = await server_stack.enter_async_context(
|
||||
streamable_http_client(cfg.url, http_client=http_client)
|
||||
)
|
||||
else:
|
||||
logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type)
|
||||
continue
|
||||
await server_stack.aclose()
|
||||
return name, None
|
||||
|
||||
session = await stack.enter_async_context(ClientSession(read, write))
|
||||
session = await server_stack.enter_async_context(ClientSession(read, write))
|
||||
await session.initialize()
|
||||
|
||||
tools = await session.list_tools()
|
||||
@ -418,7 +422,6 @@ async def connect_mcp_servers(
|
||||
", ".join(available_wrapped_names) or "(none)",
|
||||
)
|
||||
|
||||
# --- Register resources ---
|
||||
try:
|
||||
resources_result = await session.list_resources()
|
||||
for resource in resources_result.resources:
|
||||
@ -433,7 +436,6 @@ async def connect_mcp_servers(
|
||||
except Exception as e:
|
||||
logger.debug("MCP server '{}': resources not supported or failed: {}", name, e)
|
||||
|
||||
# --- Register prompts ---
|
||||
try:
|
||||
prompts_result = await session.list_prompts()
|
||||
for prompt in prompts_result.prompts:
|
||||
@ -442,14 +444,38 @@ async def connect_mcp_servers(
|
||||
)
|
||||
registry.register(wrapper)
|
||||
registered_count += 1
|
||||
logger.debug(
|
||||
"MCP: registered prompt '{}' from server '{}'", wrapper.name, name
|
||||
)
|
||||
logger.debug("MCP: registered prompt '{}' from server '{}'", wrapper.name, name)
|
||||
except Exception as e:
|
||||
logger.debug("MCP server '{}': prompts not supported or failed: {}", name, e)
|
||||
|
||||
logger.info(
|
||||
"MCP server '{}': connected, {} capabilities registered", name, registered_count
|
||||
)
|
||||
return name, server_stack
|
||||
|
||||
except Exception as e:
|
||||
logger.error("MCP server '{}': failed to connect: {}", name, e)
|
||||
try:
|
||||
await server_stack.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
return name, None
|
||||
|
||||
server_stacks: dict[str, AsyncExitStack] = {}
|
||||
|
||||
tasks: list[asyncio.Task] = []
|
||||
for name, cfg in mcp_servers.items():
|
||||
task = asyncio.create_task(connect_single_server(name, cfg))
|
||||
tasks.append(task)
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for i, result in enumerate(results):
|
||||
name = list(mcp_servers.keys())[i]
|
||||
if isinstance(result, BaseException):
|
||||
if not isinstance(result, asyncio.CancelledError):
|
||||
logger.error("MCP server '{}' connection task failed: {}", name, result)
|
||||
elif result is not None and result[1] is not None:
|
||||
server_stacks[result[0]] = result[1]
|
||||
|
||||
return server_stacks
|
||||
|
||||
44
tests/agent/test_mcp_connection.py
Normal file
44
tests/agent/test_mcp_connection.py
Normal file
@ -0,0 +1,44 @@
|
||||
"""Tests for MCP connection lifecycle in AgentLoop."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
|
||||
def _make_loop(tmp_path, *, mcp_servers: dict | None = None) -> AgentLoop:
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.generation.max_tokens = 4096
|
||||
return AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
mcp_servers=mcp_servers or {"test": object()},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_mcp_retries_when_no_servers_connect(tmp_path, monkeypatch: pytest.MonkeyPatch):
|
||||
loop = _make_loop(tmp_path)
|
||||
attempts = 0
|
||||
|
||||
async def _fake_connect(_servers, _registry):
|
||||
nonlocal attempts
|
||||
attempts += 1
|
||||
return {}
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _fake_connect)
|
||||
|
||||
await loop._connect_mcp()
|
||||
await loop._connect_mcp()
|
||||
|
||||
assert attempts == 2
|
||||
assert loop._mcp_connected is False
|
||||
assert loop._mcp_stacks == {}
|
||||
@ -271,15 +271,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_raw_names(
|
||||
) -> None:
|
||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||
registry = ToolRegistry()
|
||||
stack = AsyncExitStack()
|
||||
await stack.__aenter__()
|
||||
try:
|
||||
await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake", enabled_tools=["demo"])},
|
||||
registry,
|
||||
stack,
|
||||
)
|
||||
finally:
|
||||
stacks = await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake", enabled_tools=["demo"])},
|
||||
registry,
|
||||
)
|
||||
for stack in stacks.values():
|
||||
await stack.aclose()
|
||||
|
||||
assert registry.tool_names == ["mcp_test_demo"]
|
||||
@ -291,15 +287,11 @@ async def test_connect_mcp_servers_enabled_tools_defaults_to_all(
|
||||
) -> None:
|
||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||
registry = ToolRegistry()
|
||||
stack = AsyncExitStack()
|
||||
await stack.__aenter__()
|
||||
try:
|
||||
await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake")},
|
||||
registry,
|
||||
stack,
|
||||
)
|
||||
finally:
|
||||
stacks = await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake")},
|
||||
registry,
|
||||
)
|
||||
for stack in stacks.values():
|
||||
await stack.aclose()
|
||||
|
||||
assert registry.tool_names == ["mcp_test_demo", "mcp_test_other"]
|
||||
@ -311,15 +303,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_wrapped_names(
|
||||
) -> None:
|
||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||
registry = ToolRegistry()
|
||||
stack = AsyncExitStack()
|
||||
await stack.__aenter__()
|
||||
try:
|
||||
await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])},
|
||||
registry,
|
||||
stack,
|
||||
)
|
||||
finally:
|
||||
stacks = await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])},
|
||||
registry,
|
||||
)
|
||||
for stack in stacks.values():
|
||||
await stack.aclose()
|
||||
|
||||
assert registry.tool_names == ["mcp_test_demo"]
|
||||
@ -331,15 +319,11 @@ async def test_connect_mcp_servers_enabled_tools_empty_list_registers_none(
|
||||
) -> None:
|
||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||
registry = ToolRegistry()
|
||||
stack = AsyncExitStack()
|
||||
await stack.__aenter__()
|
||||
try:
|
||||
await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake", enabled_tools=[])},
|
||||
registry,
|
||||
stack,
|
||||
)
|
||||
finally:
|
||||
stacks = await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake", enabled_tools=[])},
|
||||
registry,
|
||||
)
|
||||
for stack in stacks.values():
|
||||
await stack.aclose()
|
||||
|
||||
assert registry.tool_names == []
|
||||
@ -358,15 +342,11 @@ async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries(
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning)
|
||||
|
||||
stack = AsyncExitStack()
|
||||
await stack.__aenter__()
|
||||
try:
|
||||
await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])},
|
||||
registry,
|
||||
stack,
|
||||
)
|
||||
finally:
|
||||
stacks = await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])},
|
||||
registry,
|
||||
)
|
||||
for stack in stacks.values():
|
||||
await stack.aclose()
|
||||
|
||||
assert registry.tool_names == []
|
||||
@ -376,6 +356,46 @@ async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries(
|
||||
assert "Available wrapped names: mcp_test_demo" in warnings[-1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_mcp_servers_one_failure_does_not_block_others(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
sessions = {"good": _make_fake_session(["demo"])}
|
||||
|
||||
class _SelectiveClientSession:
|
||||
def __init__(self, read: object, _write: object) -> None:
|
||||
self._session = sessions[read]
|
||||
|
||||
async def __aenter__(self) -> object:
|
||||
return self._session
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
return False
|
||||
|
||||
@asynccontextmanager
|
||||
async def _selective_stdio_client(params: object):
|
||||
if params.command == "bad":
|
||||
raise RuntimeError("boom")
|
||||
yield params.command, object()
|
||||
|
||||
monkeypatch.setattr(sys.modules["mcp"], "ClientSession", _SelectiveClientSession)
|
||||
monkeypatch.setattr(sys.modules["mcp.client.stdio"], "stdio_client", _selective_stdio_client)
|
||||
|
||||
registry = ToolRegistry()
|
||||
stacks = await connect_mcp_servers(
|
||||
{
|
||||
"good": MCPServerConfig(command="good"),
|
||||
"bad": MCPServerConfig(command="bad"),
|
||||
},
|
||||
registry,
|
||||
)
|
||||
for stack in stacks.values():
|
||||
await stack.aclose()
|
||||
|
||||
assert registry.tool_names == ["mcp_good_demo"]
|
||||
assert set(stacks) == {"good"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MCPResourceWrapper tests
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -389,9 +409,7 @@ def _make_resource_def(
|
||||
return SimpleNamespace(name=name, uri=uri, description=description)
|
||||
|
||||
|
||||
def _make_resource_wrapper(
|
||||
session: object, *, timeout: float = 0.1
|
||||
) -> MCPResourceWrapper:
|
||||
def _make_resource_wrapper(session: object, *, timeout: float = 0.1) -> MCPResourceWrapper:
|
||||
return MCPResourceWrapper(session, "srv", _make_resource_def(), resource_timeout=timeout)
|
||||
|
||||
|
||||
@ -434,9 +452,7 @@ async def test_resource_wrapper_execute_handles_timeout() -> None:
|
||||
await asyncio.sleep(1)
|
||||
return SimpleNamespace(contents=[])
|
||||
|
||||
wrapper = _make_resource_wrapper(
|
||||
SimpleNamespace(read_resource=read_resource), timeout=0.01
|
||||
)
|
||||
wrapper = _make_resource_wrapper(SimpleNamespace(read_resource=read_resource), timeout=0.01)
|
||||
result = await wrapper.execute()
|
||||
assert result == "(MCP resource read timed out after 0.01s)"
|
||||
|
||||
@ -464,20 +480,14 @@ def _make_prompt_def(
|
||||
return SimpleNamespace(name=name, description=description, arguments=arguments)
|
||||
|
||||
|
||||
def _make_prompt_wrapper(
|
||||
session: object, *, timeout: float = 0.1
|
||||
) -> MCPPromptWrapper:
|
||||
return MCPPromptWrapper(
|
||||
session, "srv", _make_prompt_def(), prompt_timeout=timeout
|
||||
)
|
||||
def _make_prompt_wrapper(session: object, *, timeout: float = 0.1) -> MCPPromptWrapper:
|
||||
return MCPPromptWrapper(session, "srv", _make_prompt_def(), prompt_timeout=timeout)
|
||||
|
||||
|
||||
def test_prompt_wrapper_properties() -> None:
|
||||
arg1 = SimpleNamespace(name="topic", required=True)
|
||||
arg2 = SimpleNamespace(name="style", required=False)
|
||||
wrapper = MCPPromptWrapper(
|
||||
None, "myserver", _make_prompt_def(arguments=[arg1, arg2])
|
||||
)
|
||||
wrapper = MCPPromptWrapper(None, "myserver", _make_prompt_def(arguments=[arg1, arg2]))
|
||||
assert wrapper.name == "mcp_myserver_prompt_myprompt"
|
||||
assert "[MCP Prompt]" in wrapper.description
|
||||
assert "A test prompt" in wrapper.description
|
||||
@ -528,9 +538,7 @@ async def test_prompt_wrapper_execute_handles_timeout() -> None:
|
||||
await asyncio.sleep(1)
|
||||
return SimpleNamespace(messages=[])
|
||||
|
||||
wrapper = _make_prompt_wrapper(
|
||||
SimpleNamespace(get_prompt=get_prompt), timeout=0.01
|
||||
)
|
||||
wrapper = _make_prompt_wrapper(SimpleNamespace(get_prompt=get_prompt), timeout=0.01)
|
||||
result = await wrapper.execute()
|
||||
assert result == "(MCP prompt call timed out after 0.01s)"
|
||||
|
||||
@ -616,15 +624,11 @@ async def test_connect_registers_resources_and_prompts(
|
||||
prompt_names=["prompt_c"],
|
||||
)
|
||||
registry = ToolRegistry()
|
||||
stack = AsyncExitStack()
|
||||
await stack.__aenter__()
|
||||
try:
|
||||
await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake")},
|
||||
registry,
|
||||
stack,
|
||||
)
|
||||
finally:
|
||||
stacks = await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake")},
|
||||
registry,
|
||||
)
|
||||
for stack in stacks.values():
|
||||
await stack.aclose()
|
||||
|
||||
assert "mcp_test_tool_a" in registry.tool_names
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user