fix(mcp): support multiple MCP servers by connecting each in isolated task

Each MCP server now connects in its own asyncio.Task to isolate anyio
cancel scopes and prevent 'exit cancel scope in different task' errors
when multiple servers (especially mixed transport types) are configured.

Changes:
- connect_mcp_servers() returns dict[str, AsyncExitStack] instead of None
- Each server runs in separate task via asyncio.gather()
- AgentLoop uses _mcp_stacks dict to track per-server stacks
- Tests updated to handle new API
This commit is contained in:
worenidewen 2026-04-10 23:51:50 +08:00
parent 9bccfa63d2
commit a167959027
3 changed files with 247 additions and 198 deletions

View File

@ -43,6 +43,7 @@ if TYPE_CHECKING:
UNIFIED_SESSION_KEY = "unified:default"
class _LoopHook(AgentHook):
"""Core hook for the main loop."""
@ -76,7 +77,7 @@ class _LoopHook(AgentHook):
prev_clean = strip_think(self._stream_buf)
self._stream_buf += delta
new_clean = strip_think(self._stream_buf)
incremental = new_clean[len(prev_clean):]
incremental = new_clean[len(prev_clean) :]
if incremental and self._on_stream:
await self._on_stream(incremental)
@ -112,6 +113,7 @@ class _LoopHook(AgentHook):
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
return self._loop._strip_think(content)
class AgentLoop:
"""
The agent loop is the core processing engine.
@ -196,7 +198,7 @@ class AgentLoop:
self._unified_session = unified_session
self._running = False
self._mcp_servers = mcp_servers or {}
self._mcp_stack: AsyncExitStack | None = None
self._mcp_stacks: dict[str, AsyncExitStack] = {}
self._mcp_connected = False
self._mcp_connecting = False
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
@ -228,24 +230,34 @@ class AgentLoop:
def _register_default_tools(self) -> None:
"""Register the default set of tools."""
allowed_dir = self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
allowed_dir = (
self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
)
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
self.tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read))
self.tools.register(
ReadFileTool(
workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read
)
)
for cls in (WriteFileTool, EditFileTool, ListDirTool):
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
for cls in (GlobTool, GrepTool):
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
if self.exec_config.enable:
self.tools.register(ExecTool(
working_dir=str(self.workspace),
timeout=self.exec_config.timeout,
restrict_to_workspace=self.restrict_to_workspace,
sandbox=self.exec_config.sandbox,
path_append=self.exec_config.path_append,
allowed_env_keys=self.exec_config.allowed_env_keys,
))
self.tools.register(
ExecTool(
working_dir=str(self.workspace),
timeout=self.exec_config.timeout,
restrict_to_workspace=self.restrict_to_workspace,
sandbox=self.exec_config.sandbox,
path_append=self.exec_config.path_append,
allowed_env_keys=self.exec_config.allowed_env_keys,
)
)
if self.web_config.enable:
self.tools.register(WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy))
self.tools.register(
WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy)
)
self.tools.register(WebFetchTool(proxy=self.web_config.proxy))
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
self.tools.register(SpawnTool(manager=self.subagents))
@ -260,19 +272,16 @@ class AgentLoop:
return
self._mcp_connecting = True
from nanobot.agent.tools.mcp import connect_mcp_servers
try:
self._mcp_stack = AsyncExitStack()
await self._mcp_stack.__aenter__()
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
self._mcp_stacks = await connect_mcp_servers(self._mcp_servers, self.tools)
self._mcp_connected = True
except asyncio.CancelledError:
logger.warning("MCP connection cancelled (will retry next message)")
self._mcp_stacks.clear()
except BaseException as e:
logger.error("Failed to connect MCP servers (will retry next message): {}", e)
if self._mcp_stack:
try:
await self._mcp_stack.aclose()
except Exception:
pass
self._mcp_stack = None
self._mcp_stacks.clear()
finally:
self._mcp_connecting = False
@ -289,6 +298,7 @@ class AgentLoop:
if not text:
return None
from nanobot.utils.helpers import strip_think
return strip_think(text) or None
@staticmethod
@ -327,9 +337,7 @@ class AgentLoop:
message_id=message_id,
)
hook: AgentHook = (
CompositeHook([loop_hook] + self._extra_hooks)
if self._extra_hooks
else loop_hook
CompositeHook([loop_hook] + self._extra_hooks) if self._extra_hooks else loop_hook
)
async def _checkpoint(payload: dict[str, Any]) -> None:
@ -337,23 +345,25 @@ class AgentLoop:
return
self._set_runtime_checkpoint(session, payload)
result = await self.runner.run(AgentRunSpec(
initial_messages=initial_messages,
tools=self.tools,
model=self.model,
max_iterations=self.max_iterations,
max_tool_result_chars=self.max_tool_result_chars,
hook=hook,
error_message="Sorry, I encountered an error calling the AI model.",
concurrent_tools=True,
workspace=self.workspace,
session_key=session.key if session else None,
context_window_tokens=self.context_window_tokens,
context_block_limit=self.context_block_limit,
provider_retry_mode=self.provider_retry_mode,
progress_callback=on_progress,
checkpoint_callback=_checkpoint,
))
result = await self.runner.run(
AgentRunSpec(
initial_messages=initial_messages,
tools=self.tools,
model=self.model,
max_iterations=self.max_iterations,
max_tool_result_chars=self.max_tool_result_chars,
hook=hook,
error_message="Sorry, I encountered an error calling the AI model.",
concurrent_tools=True,
workspace=self.workspace,
session_key=session.key if session else None,
context_window_tokens=self.context_window_tokens,
context_block_limit=self.context_block_limit,
provider_retry_mode=self.provider_retry_mode,
progress_callback=on_progress,
checkpoint_callback=_checkpoint,
)
)
self._last_usage = result.usage
if result.stop_reason == "max_iterations":
logger.warning("Max iterations ({}) reached", self.max_iterations)
@ -391,10 +401,19 @@ class AgentLoop:
continue
# Compute the effective session key before dispatching
# This ensures /stop command can find tasks correctly when unified session is enabled
effective_key = UNIFIED_SESSION_KEY if self._unified_session and not msg.session_key_override else msg.session_key
effective_key = (
UNIFIED_SESSION_KEY
if self._unified_session and not msg.session_key_override
else msg.session_key
)
task = asyncio.create_task(self._dispatch(msg))
self._active_tasks.setdefault(effective_key, []).append(task)
task.add_done_callback(lambda t, k=effective_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
task.add_done_callback(
lambda t, k=effective_key: self._active_tasks.get(k, [])
and self._active_tasks[k].remove(t)
if t in self._active_tasks.get(k, [])
else None
)
async def _dispatch(self, msg: InboundMessage) -> None:
"""Process a message: per-session serial, cross-session concurrent."""
@ -417,11 +436,14 @@ class AgentLoop:
meta = dict(msg.metadata or {})
meta["_stream_delta"] = True
meta["_stream_id"] = _current_stream_id()
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content=delta,
metadata=meta,
))
await self.bus.publish_outbound(
OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content=delta,
metadata=meta,
)
)
async def on_stream_end(*, resuming: bool = False) -> None:
nonlocal stream_segment
@ -429,44 +451,56 @@ class AgentLoop:
meta["_stream_end"] = True
meta["_resuming"] = resuming
meta["_stream_id"] = _current_stream_id()
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="",
metadata=meta,
))
await self.bus.publish_outbound(
OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content="",
metadata=meta,
)
)
stream_segment += 1
response = await self._process_message(
msg, on_stream=on_stream, on_stream_end=on_stream_end,
msg,
on_stream=on_stream,
on_stream_end=on_stream_end,
)
if response is not None:
await self.bus.publish_outbound(response)
elif msg.channel == "cli":
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="", metadata=msg.metadata or {},
))
await self.bus.publish_outbound(
OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content="",
metadata=msg.metadata or {},
)
)
except asyncio.CancelledError:
logger.info("Task cancelled for session {}", msg.session_key)
raise
except Exception:
logger.exception("Error processing message for session {}", msg.session_key)
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="Sorry, I encountered an error.",
))
await self.bus.publish_outbound(
OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content="Sorry, I encountered an error.",
)
)
async def close_mcp(self) -> None:
"""Drain pending background archives, then close MCP connections."""
if self._background_tasks:
await asyncio.gather(*self._background_tasks, return_exceptions=True)
self._background_tasks.clear()
if self._mcp_stack:
for name, stack in self._mcp_stacks.items():
try:
await self._mcp_stack.aclose()
await stack.aclose()
except (RuntimeError, BaseExceptionGroup):
pass # MCP SDK cancel scope cleanup is noisy but harmless
self._mcp_stack = None
logger.debug("MCP server '{}' cleanup error (can be ignored)", name)
self._mcp_stacks.clear()
def _schedule_background(self, coro) -> None:
"""Schedule a coroutine as a tracked background task (drained on shutdown)."""
@ -490,8 +524,9 @@ class AgentLoop:
"""Process a single inbound message and return the response."""
# System messages: parse origin from chat_id ("channel:chat_id")
if msg.channel == "system":
channel, chat_id = (msg.chat_id.split(":", 1) if ":" in msg.chat_id
else ("cli", msg.chat_id))
channel, chat_id = (
msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id)
)
logger.info("Processing system message from {}", msg.sender_id)
key = f"{channel}:{chat_id}"
session = self.sessions.get_or_create(key)
@ -503,19 +538,27 @@ class AgentLoop:
current_role = "assistant" if msg.sender_id == "subagent" else "user"
messages = self.context.build_messages(
history=history,
current_message=msg.content, channel=channel, chat_id=chat_id,
current_message=msg.content,
channel=channel,
chat_id=chat_id,
current_role=current_role,
)
final_content, _, all_msgs, _ = await self._run_agent_loop(
messages, session=session, channel=channel, chat_id=chat_id,
messages,
session=session,
channel=channel,
chat_id=chat_id,
message_id=msg.metadata.get("message_id"),
)
self._save_turn(session, all_msgs, 1 + len(history))
self._clear_runtime_checkpoint(session)
self.sessions.save(session)
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
return OutboundMessage(channel=channel, chat_id=chat_id,
content=final_content or "Background task completed.")
return OutboundMessage(
channel=channel,
chat_id=chat_id,
content=final_content or "Background task completed.",
)
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
@ -543,16 +586,22 @@ class AgentLoop:
history=history,
current_message=msg.content,
media=msg.media if msg.media else None,
channel=msg.channel, chat_id=msg.chat_id,
channel=msg.channel,
chat_id=msg.chat_id,
)
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
meta = dict(msg.metadata or {})
meta["_progress"] = True
meta["_tool_hint"] = tool_hint
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
))
await self.bus.publish_outbound(
OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content=content,
metadata=meta,
)
)
final_content, _, all_msgs, stop_reason = await self._run_agent_loop(
initial_messages,
@ -560,7 +609,8 @@ class AgentLoop:
on_stream=on_stream,
on_stream_end=on_stream_end,
session=session,
channel=msg.channel, chat_id=msg.chat_id,
channel=msg.channel,
chat_id=msg.chat_id,
message_id=msg.metadata.get("message_id"),
)
@ -582,7 +632,9 @@ class AgentLoop:
if on_stream is not None and stop_reason != "error":
meta["_streamed"] = True
return OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
channel=msg.channel,
chat_id=msg.chat_id,
content=final_content,
metadata=meta,
)
@ -608,10 +660,9 @@ class AgentLoop:
):
continue
if (
block.get("type") == "image_url"
and block.get("image_url", {}).get("url", "").startswith("data:image/")
):
if block.get("type") == "image_url" and block.get("image_url", {}).get(
"url", ""
).startswith("data:image/"):
path = (block.get("_meta") or {}).get("path", "")
filtered.append({"type": "text", "text": image_placeholder_text(path)})
continue
@ -630,6 +681,7 @@ class AgentLoop:
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
"""Save new-turn messages into session, truncating large tool results."""
from datetime import datetime
for m in messages[skip:]:
entry = dict(m)
role, content = entry.get("role"), entry.get("content")
@ -644,7 +696,9 @@ class AgentLoop:
continue
entry["content"] = filtered
elif role == "user":
if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
if isinstance(content, str) and content.startswith(
ContextBuilder._RUNTIME_CONTEXT_TAG
):
# Strip the runtime-context prefix, keep only the user text.
parts = content.split("\n\n", 1)
if len(parts) > 1 and parts[1].strip():
@ -708,13 +762,15 @@ class AgentLoop:
continue
tool_id = tool_call.get("id")
name = ((tool_call.get("function") or {}).get("name")) or "tool"
restored_messages.append({
"role": "tool",
"tool_call_id": tool_id,
"name": name,
"content": "Error: Task interrupted before this tool finished.",
"timestamp": datetime.now().isoformat(),
})
restored_messages.append(
{
"role": "tool",
"tool_call_id": tool_id,
"name": name,
"content": "Error: Task interrupted before this tool finished.",
"timestamp": datetime.now().isoformat(),
}
)
overlap = 0
max_overlap = min(len(session.messages), len(restored_messages))
@ -746,6 +802,9 @@ class AgentLoop:
await self._connect_mcp()
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
return await self._process_message(
msg, session_key=session_key, on_progress=on_progress,
on_stream=on_stream, on_stream_end=on_stream_end,
msg,
session_key=session_key,
on_progress=on_progress,
on_stream=on_stream,
on_stream_end=on_stream_end,
)

View File

@ -57,9 +57,7 @@ def _normalize_schema_for_openai(schema: Any) -> dict[str, Any]:
if "properties" in normalized and isinstance(normalized["properties"], dict):
normalized["properties"] = {
name: _normalize_schema_for_openai(prop)
if isinstance(prop, dict)
else prop
name: _normalize_schema_for_openai(prop) if isinstance(prop, dict) else prop
for name, prop in normalized["properties"].items()
}
@ -138,9 +136,7 @@ class MCPToolWrapper(Tool):
class MCPResourceWrapper(Tool):
"""Wraps an MCP resource URI as a read-only nanobot Tool."""
def __init__(
self, session, server_name: str, resource_def, resource_timeout: int = 30
):
def __init__(self, session, server_name: str, resource_def, resource_timeout: int = 30):
self._session = session
self._uri = resource_def.uri
self._name = f"mcp_{server_name}_resource_{resource_def.name}"
@ -211,9 +207,7 @@ class MCPResourceWrapper(Tool):
class MCPPromptWrapper(Tool):
"""Wraps an MCP prompt as a read-only nanobot Tool."""
def __init__(
self, session, server_name: str, prompt_def, prompt_timeout: int = 30
):
def __init__(self, session, server_name: str, prompt_def, prompt_timeout: int = 30):
self._session = session
self._prompt_name = prompt_def.name
self._name = f"mcp_{server_name}_prompt_{prompt_def.name}"
@ -266,9 +260,7 @@ class MCPPromptWrapper(Tool):
timeout=self._prompt_timeout,
)
except asyncio.TimeoutError:
logger.warning(
"MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout
)
logger.warning("MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout)
return f"(MCP prompt call timed out after {self._prompt_timeout}s)"
except asyncio.CancelledError:
task = asyncio.current_task()
@ -279,13 +271,17 @@ class MCPPromptWrapper(Tool):
except McpError as exc:
logger.error(
"MCP prompt '{}' failed: code={} message={}",
self._name, exc.error.code, exc.error.message,
self._name,
exc.error.code,
exc.error.message,
)
return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])"
except Exception as exc:
logger.exception(
"MCP prompt '{}' failed: {}: {}",
self._name, type(exc).__name__, exc,
self._name,
type(exc).__name__,
exc,
)
return f"(MCP prompt call failed: {type(exc).__name__})"
@ -307,35 +303,44 @@ class MCPPromptWrapper(Tool):
async def connect_mcp_servers(
mcp_servers: dict, registry: ToolRegistry, stack: AsyncExitStack
) -> None:
"""Connect to configured MCP servers and register their tools, resources, and prompts."""
mcp_servers: dict, registry: ToolRegistry
) -> dict[str, AsyncExitStack]:
"""Connect to configured MCP servers and register their tools, resources, prompts.
Returns a dict mapping server name -> its dedicated AsyncExitStack.
Each server gets its own stack and runs in its own task to prevent
cancel scope conflicts when multiple MCP servers are configured.
"""
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamable_http_client
for name, cfg in mcp_servers.items():
async def connect_single_server(name: str, cfg) -> tuple[str, AsyncExitStack | None]:
server_stack = AsyncExitStack()
await server_stack.__aenter__()
try:
transport_type = cfg.type
if not transport_type:
if cfg.command:
transport_type = "stdio"
elif cfg.url:
# Convention: URLs ending with /sse use SSE transport; others use streamableHttp
transport_type = (
"sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp"
)
else:
logger.warning("MCP server '{}': no command or url configured, skipping", name)
continue
await server_stack.aclose()
return name, None
if transport_type == "stdio":
params = StdioServerParameters(
command=cfg.command, args=cfg.args, env=cfg.env or None
)
read, write = await stack.enter_async_context(stdio_client(params))
read, write = await server_stack.enter_async_context(stdio_client(params))
elif transport_type == "sse":
def httpx_client_factory(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
@ -353,27 +358,26 @@ async def connect_mcp_servers(
auth=auth,
)
read, write = await stack.enter_async_context(
read, write = await server_stack.enter_async_context(
sse_client(cfg.url, httpx_client_factory=httpx_client_factory)
)
elif transport_type == "streamableHttp":
# Always provide an explicit httpx client so MCP HTTP transport does not
# inherit httpx's default 5s timeout and preempt the higher-level tool timeout.
http_client = await stack.enter_async_context(
http_client = await server_stack.enter_async_context(
httpx.AsyncClient(
headers=cfg.headers or None,
follow_redirects=True,
timeout=None,
)
)
read, write, _ = await stack.enter_async_context(
read, write, _ = await server_stack.enter_async_context(
streamable_http_client(cfg.url, http_client=http_client)
)
else:
logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type)
continue
await server_stack.aclose()
return name, None
session = await stack.enter_async_context(ClientSession(read, write))
session = await server_stack.enter_async_context(ClientSession(read, write))
await session.initialize()
tools = await session.list_tools()
@ -418,7 +422,6 @@ async def connect_mcp_servers(
", ".join(available_wrapped_names) or "(none)",
)
# --- Register resources ---
try:
resources_result = await session.list_resources()
for resource in resources_result.resources:
@ -433,7 +436,6 @@ async def connect_mcp_servers(
except Exception as e:
logger.debug("MCP server '{}': resources not supported or failed: {}", name, e)
# --- Register prompts ---
try:
prompts_result = await session.list_prompts()
for prompt in prompts_result.prompts:
@ -442,14 +444,38 @@ async def connect_mcp_servers(
)
registry.register(wrapper)
registered_count += 1
logger.debug(
"MCP: registered prompt '{}' from server '{}'", wrapper.name, name
)
logger.debug("MCP: registered prompt '{}' from server '{}'", wrapper.name, name)
except Exception as e:
logger.debug("MCP server '{}': prompts not supported or failed: {}", name, e)
logger.info(
"MCP server '{}': connected, {} capabilities registered", name, registered_count
)
return name, server_stack
except Exception as e:
logger.error("MCP server '{}': failed to connect: {}", name, e)
try:
await server_stack.aclose()
except Exception:
pass
return name, None
server_stacks: dict[str, AsyncExitStack] = {}
tasks: list[asyncio.Task] = []
for name, cfg in mcp_servers.items():
task = asyncio.create_task(connect_single_server(name, cfg))
tasks.append(task)
results = await asyncio.gather(*tasks, return_exceptions=True)
for i, result in enumerate(results):
name = list(mcp_servers.keys())[i]
if isinstance(result, BaseException):
if not isinstance(result, asyncio.CancelledError):
logger.error("MCP server '{}' connection task failed: {}", name, result)
elif result is not None and result[1] is not None:
server_stacks[result[0]] = result[1]
return server_stacks

View File

@ -271,15 +271,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_raw_names(
) -> None:
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
registry = ToolRegistry()
stack = AsyncExitStack()
await stack.__aenter__()
try:
await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=["demo"])},
registry,
stack,
)
finally:
stacks = await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=["demo"])},
registry,
)
for stack in stacks.values():
await stack.aclose()
assert registry.tool_names == ["mcp_test_demo"]
@ -291,15 +287,11 @@ async def test_connect_mcp_servers_enabled_tools_defaults_to_all(
) -> None:
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
registry = ToolRegistry()
stack = AsyncExitStack()
await stack.__aenter__()
try:
await connect_mcp_servers(
{"test": MCPServerConfig(command="fake")},
registry,
stack,
)
finally:
stacks = await connect_mcp_servers(
{"test": MCPServerConfig(command="fake")},
registry,
)
for stack in stacks.values():
await stack.aclose()
assert registry.tool_names == ["mcp_test_demo", "mcp_test_other"]
@ -311,15 +303,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_wrapped_names(
) -> None:
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
registry = ToolRegistry()
stack = AsyncExitStack()
await stack.__aenter__()
try:
await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])},
registry,
stack,
)
finally:
stacks = await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])},
registry,
)
for stack in stacks.values():
await stack.aclose()
assert registry.tool_names == ["mcp_test_demo"]
@ -331,15 +319,11 @@ async def test_connect_mcp_servers_enabled_tools_empty_list_registers_none(
) -> None:
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
registry = ToolRegistry()
stack = AsyncExitStack()
await stack.__aenter__()
try:
await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=[])},
registry,
stack,
)
finally:
stacks = await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=[])},
registry,
)
for stack in stacks.values():
await stack.aclose()
assert registry.tool_names == []
@ -358,15 +342,11 @@ async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries(
monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning)
stack = AsyncExitStack()
await stack.__aenter__()
try:
await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])},
registry,
stack,
)
finally:
stacks = await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])},
registry,
)
for stack in stacks.values():
await stack.aclose()
assert registry.tool_names == []
@ -389,9 +369,7 @@ def _make_resource_def(
return SimpleNamespace(name=name, uri=uri, description=description)
def _make_resource_wrapper(
session: object, *, timeout: float = 0.1
) -> MCPResourceWrapper:
def _make_resource_wrapper(session: object, *, timeout: float = 0.1) -> MCPResourceWrapper:
return MCPResourceWrapper(session, "srv", _make_resource_def(), resource_timeout=timeout)
@ -434,9 +412,7 @@ async def test_resource_wrapper_execute_handles_timeout() -> None:
await asyncio.sleep(1)
return SimpleNamespace(contents=[])
wrapper = _make_resource_wrapper(
SimpleNamespace(read_resource=read_resource), timeout=0.01
)
wrapper = _make_resource_wrapper(SimpleNamespace(read_resource=read_resource), timeout=0.01)
result = await wrapper.execute()
assert result == "(MCP resource read timed out after 0.01s)"
@ -464,20 +440,14 @@ def _make_prompt_def(
return SimpleNamespace(name=name, description=description, arguments=arguments)
def _make_prompt_wrapper(
session: object, *, timeout: float = 0.1
) -> MCPPromptWrapper:
return MCPPromptWrapper(
session, "srv", _make_prompt_def(), prompt_timeout=timeout
)
def _make_prompt_wrapper(session: object, *, timeout: float = 0.1) -> MCPPromptWrapper:
return MCPPromptWrapper(session, "srv", _make_prompt_def(), prompt_timeout=timeout)
def test_prompt_wrapper_properties() -> None:
arg1 = SimpleNamespace(name="topic", required=True)
arg2 = SimpleNamespace(name="style", required=False)
wrapper = MCPPromptWrapper(
None, "myserver", _make_prompt_def(arguments=[arg1, arg2])
)
wrapper = MCPPromptWrapper(None, "myserver", _make_prompt_def(arguments=[arg1, arg2]))
assert wrapper.name == "mcp_myserver_prompt_myprompt"
assert "[MCP Prompt]" in wrapper.description
assert "A test prompt" in wrapper.description
@ -528,9 +498,7 @@ async def test_prompt_wrapper_execute_handles_timeout() -> None:
await asyncio.sleep(1)
return SimpleNamespace(messages=[])
wrapper = _make_prompt_wrapper(
SimpleNamespace(get_prompt=get_prompt), timeout=0.01
)
wrapper = _make_prompt_wrapper(SimpleNamespace(get_prompt=get_prompt), timeout=0.01)
result = await wrapper.execute()
assert result == "(MCP prompt call timed out after 0.01s)"
@ -616,15 +584,11 @@ async def test_connect_registers_resources_and_prompts(
prompt_names=["prompt_c"],
)
registry = ToolRegistry()
stack = AsyncExitStack()
await stack.__aenter__()
try:
await connect_mcp_servers(
{"test": MCPServerConfig(command="fake")},
registry,
stack,
)
finally:
stacks = await connect_mcp_servers(
{"test": MCPServerConfig(command="fake")},
registry,
)
for stack in stacks.values():
await stack.aclose()
assert "mcp_test_tool_a" in registry.tool_names