mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-30 14:56:01 +00:00
Merge PR #3019: fix(mcp): support multiple MCP servers
fix(mcp): support multiple MCP servers
This commit is contained in:
commit
e7bbbe98f4
@ -44,6 +44,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
UNIFIED_SESSION_KEY = "unified:default"
|
UNIFIED_SESSION_KEY = "unified:default"
|
||||||
|
|
||||||
|
|
||||||
class _LoopHook(AgentHook):
|
class _LoopHook(AgentHook):
|
||||||
"""Core hook for the main loop."""
|
"""Core hook for the main loop."""
|
||||||
|
|
||||||
@ -77,7 +78,7 @@ class _LoopHook(AgentHook):
|
|||||||
prev_clean = strip_think(self._stream_buf)
|
prev_clean = strip_think(self._stream_buf)
|
||||||
self._stream_buf += delta
|
self._stream_buf += delta
|
||||||
new_clean = strip_think(self._stream_buf)
|
new_clean = strip_think(self._stream_buf)
|
||||||
incremental = new_clean[len(prev_clean):]
|
incremental = new_clean[len(prev_clean) :]
|
||||||
if incremental and self._on_stream:
|
if incremental and self._on_stream:
|
||||||
await self._on_stream(incremental)
|
await self._on_stream(incremental)
|
||||||
|
|
||||||
@ -113,6 +114,7 @@ class _LoopHook(AgentHook):
|
|||||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||||
return self._loop._strip_think(content)
|
return self._loop._strip_think(content)
|
||||||
|
|
||||||
|
|
||||||
class AgentLoop:
|
class AgentLoop:
|
||||||
"""
|
"""
|
||||||
The agent loop is the core processing engine.
|
The agent loop is the core processing engine.
|
||||||
@ -198,7 +200,7 @@ class AgentLoop:
|
|||||||
self._unified_session = unified_session
|
self._unified_session = unified_session
|
||||||
self._running = False
|
self._running = False
|
||||||
self._mcp_servers = mcp_servers or {}
|
self._mcp_servers = mcp_servers or {}
|
||||||
self._mcp_stack: AsyncExitStack | None = None
|
self._mcp_stacks: dict[str, AsyncExitStack] = {}
|
||||||
self._mcp_connected = False
|
self._mcp_connected = False
|
||||||
self._mcp_connecting = False
|
self._mcp_connecting = False
|
||||||
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||||
@ -235,24 +237,34 @@ class AgentLoop:
|
|||||||
|
|
||||||
def _register_default_tools(self) -> None:
|
def _register_default_tools(self) -> None:
|
||||||
"""Register the default set of tools."""
|
"""Register the default set of tools."""
|
||||||
allowed_dir = self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
|
allowed_dir = (
|
||||||
|
self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
|
||||||
|
)
|
||||||
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
|
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
|
||||||
self.tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read))
|
self.tools.register(
|
||||||
|
ReadFileTool(
|
||||||
|
workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read
|
||||||
|
)
|
||||||
|
)
|
||||||
for cls in (WriteFileTool, EditFileTool, ListDirTool):
|
for cls in (WriteFileTool, EditFileTool, ListDirTool):
|
||||||
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||||
for cls in (GlobTool, GrepTool):
|
for cls in (GlobTool, GrepTool):
|
||||||
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||||
if self.exec_config.enable:
|
if self.exec_config.enable:
|
||||||
self.tools.register(ExecTool(
|
self.tools.register(
|
||||||
working_dir=str(self.workspace),
|
ExecTool(
|
||||||
timeout=self.exec_config.timeout,
|
working_dir=str(self.workspace),
|
||||||
restrict_to_workspace=self.restrict_to_workspace,
|
timeout=self.exec_config.timeout,
|
||||||
sandbox=self.exec_config.sandbox,
|
restrict_to_workspace=self.restrict_to_workspace,
|
||||||
path_append=self.exec_config.path_append,
|
sandbox=self.exec_config.sandbox,
|
||||||
allowed_env_keys=self.exec_config.allowed_env_keys,
|
path_append=self.exec_config.path_append,
|
||||||
))
|
allowed_env_keys=self.exec_config.allowed_env_keys,
|
||||||
|
)
|
||||||
|
)
|
||||||
if self.web_config.enable:
|
if self.web_config.enable:
|
||||||
self.tools.register(WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy))
|
self.tools.register(
|
||||||
|
WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy)
|
||||||
|
)
|
||||||
self.tools.register(WebFetchTool(proxy=self.web_config.proxy))
|
self.tools.register(WebFetchTool(proxy=self.web_config.proxy))
|
||||||
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
|
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
|
||||||
self.tools.register(SpawnTool(manager=self.subagents))
|
self.tools.register(SpawnTool(manager=self.subagents))
|
||||||
@ -267,19 +279,19 @@ class AgentLoop:
|
|||||||
return
|
return
|
||||||
self._mcp_connecting = True
|
self._mcp_connecting = True
|
||||||
from nanobot.agent.tools.mcp import connect_mcp_servers
|
from nanobot.agent.tools.mcp import connect_mcp_servers
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._mcp_stack = AsyncExitStack()
|
self._mcp_stacks = await connect_mcp_servers(self._mcp_servers, self.tools)
|
||||||
await self._mcp_stack.__aenter__()
|
if self._mcp_stacks:
|
||||||
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
|
self._mcp_connected = True
|
||||||
self._mcp_connected = True
|
else:
|
||||||
|
logger.warning("No MCP servers connected successfully (will retry next message)")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.warning("MCP connection cancelled (will retry next message)")
|
||||||
|
self._mcp_stacks.clear()
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.error("Failed to connect MCP servers (will retry next message): {}", e)
|
logger.error("Failed to connect MCP servers (will retry next message): {}", e)
|
||||||
if self._mcp_stack:
|
self._mcp_stacks.clear()
|
||||||
try:
|
|
||||||
await self._mcp_stack.aclose()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
self._mcp_stack = None
|
|
||||||
finally:
|
finally:
|
||||||
self._mcp_connecting = False
|
self._mcp_connecting = False
|
||||||
|
|
||||||
@ -296,6 +308,7 @@ class AgentLoop:
|
|||||||
if not text:
|
if not text:
|
||||||
return None
|
return None
|
||||||
from nanobot.utils.helpers import strip_think
|
from nanobot.utils.helpers import strip_think
|
||||||
|
|
||||||
return strip_think(text) or None
|
return strip_think(text) or None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -334,9 +347,7 @@ class AgentLoop:
|
|||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
)
|
)
|
||||||
hook: AgentHook = (
|
hook: AgentHook = (
|
||||||
CompositeHook([loop_hook] + self._extra_hooks)
|
CompositeHook([loop_hook] + self._extra_hooks) if self._extra_hooks else loop_hook
|
||||||
if self._extra_hooks
|
|
||||||
else loop_hook
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _checkpoint(payload: dict[str, Any]) -> None:
|
async def _checkpoint(payload: dict[str, Any]) -> None:
|
||||||
@ -344,23 +355,25 @@ class AgentLoop:
|
|||||||
return
|
return
|
||||||
self._set_runtime_checkpoint(session, payload)
|
self._set_runtime_checkpoint(session, payload)
|
||||||
|
|
||||||
result = await self.runner.run(AgentRunSpec(
|
result = await self.runner.run(
|
||||||
initial_messages=initial_messages,
|
AgentRunSpec(
|
||||||
tools=self.tools,
|
initial_messages=initial_messages,
|
||||||
model=self.model,
|
tools=self.tools,
|
||||||
max_iterations=self.max_iterations,
|
model=self.model,
|
||||||
max_tool_result_chars=self.max_tool_result_chars,
|
max_iterations=self.max_iterations,
|
||||||
hook=hook,
|
max_tool_result_chars=self.max_tool_result_chars,
|
||||||
error_message="Sorry, I encountered an error calling the AI model.",
|
hook=hook,
|
||||||
concurrent_tools=True,
|
error_message="Sorry, I encountered an error calling the AI model.",
|
||||||
workspace=self.workspace,
|
concurrent_tools=True,
|
||||||
session_key=session.key if session else None,
|
workspace=self.workspace,
|
||||||
context_window_tokens=self.context_window_tokens,
|
session_key=session.key if session else None,
|
||||||
context_block_limit=self.context_block_limit,
|
context_window_tokens=self.context_window_tokens,
|
||||||
provider_retry_mode=self.provider_retry_mode,
|
context_block_limit=self.context_block_limit,
|
||||||
progress_callback=on_progress,
|
provider_retry_mode=self.provider_retry_mode,
|
||||||
checkpoint_callback=_checkpoint,
|
progress_callback=on_progress,
|
||||||
))
|
checkpoint_callback=_checkpoint,
|
||||||
|
)
|
||||||
|
)
|
||||||
self._last_usage = result.usage
|
self._last_usage = result.usage
|
||||||
if result.stop_reason == "max_iterations":
|
if result.stop_reason == "max_iterations":
|
||||||
logger.warning("Max iterations ({}) reached", self.max_iterations)
|
logger.warning("Max iterations ({}) reached", self.max_iterations)
|
||||||
@ -399,10 +412,19 @@ class AgentLoop:
|
|||||||
continue
|
continue
|
||||||
# Compute the effective session key before dispatching
|
# Compute the effective session key before dispatching
|
||||||
# This ensures /stop command can find tasks correctly when unified session is enabled
|
# This ensures /stop command can find tasks correctly when unified session is enabled
|
||||||
effective_key = UNIFIED_SESSION_KEY if self._unified_session and not msg.session_key_override else msg.session_key
|
effective_key = (
|
||||||
|
UNIFIED_SESSION_KEY
|
||||||
|
if self._unified_session and not msg.session_key_override
|
||||||
|
else msg.session_key
|
||||||
|
)
|
||||||
task = asyncio.create_task(self._dispatch(msg))
|
task = asyncio.create_task(self._dispatch(msg))
|
||||||
self._active_tasks.setdefault(effective_key, []).append(task)
|
self._active_tasks.setdefault(effective_key, []).append(task)
|
||||||
task.add_done_callback(lambda t, k=effective_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
|
task.add_done_callback(
|
||||||
|
lambda t, k=effective_key: self._active_tasks.get(k, [])
|
||||||
|
and self._active_tasks[k].remove(t)
|
||||||
|
if t in self._active_tasks.get(k, [])
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
async def _dispatch(self, msg: InboundMessage) -> None:
|
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||||
"""Process a message: per-session serial, cross-session concurrent."""
|
"""Process a message: per-session serial, cross-session concurrent."""
|
||||||
@ -425,11 +447,14 @@ class AgentLoop:
|
|||||||
meta = dict(msg.metadata or {})
|
meta = dict(msg.metadata or {})
|
||||||
meta["_stream_delta"] = True
|
meta["_stream_delta"] = True
|
||||||
meta["_stream_id"] = _current_stream_id()
|
meta["_stream_id"] = _current_stream_id()
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
OutboundMessage(
|
||||||
content=delta,
|
channel=msg.channel,
|
||||||
metadata=meta,
|
chat_id=msg.chat_id,
|
||||||
))
|
content=delta,
|
||||||
|
metadata=meta,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
async def on_stream_end(*, resuming: bool = False) -> None:
|
async def on_stream_end(*, resuming: bool = False) -> None:
|
||||||
nonlocal stream_segment
|
nonlocal stream_segment
|
||||||
@ -437,44 +462,56 @@ class AgentLoop:
|
|||||||
meta["_stream_end"] = True
|
meta["_stream_end"] = True
|
||||||
meta["_resuming"] = resuming
|
meta["_resuming"] = resuming
|
||||||
meta["_stream_id"] = _current_stream_id()
|
meta["_stream_id"] = _current_stream_id()
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
OutboundMessage(
|
||||||
content="",
|
channel=msg.channel,
|
||||||
metadata=meta,
|
chat_id=msg.chat_id,
|
||||||
))
|
content="",
|
||||||
|
metadata=meta,
|
||||||
|
)
|
||||||
|
)
|
||||||
stream_segment += 1
|
stream_segment += 1
|
||||||
|
|
||||||
response = await self._process_message(
|
response = await self._process_message(
|
||||||
msg, on_stream=on_stream, on_stream_end=on_stream_end,
|
msg,
|
||||||
|
on_stream=on_stream,
|
||||||
|
on_stream_end=on_stream_end,
|
||||||
)
|
)
|
||||||
if response is not None:
|
if response is not None:
|
||||||
await self.bus.publish_outbound(response)
|
await self.bus.publish_outbound(response)
|
||||||
elif msg.channel == "cli":
|
elif msg.channel == "cli":
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
OutboundMessage(
|
||||||
content="", metadata=msg.metadata or {},
|
channel=msg.channel,
|
||||||
))
|
chat_id=msg.chat_id,
|
||||||
|
content="",
|
||||||
|
metadata=msg.metadata or {},
|
||||||
|
)
|
||||||
|
)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info("Task cancelled for session {}", msg.session_key)
|
logger.info("Task cancelled for session {}", msg.session_key)
|
||||||
raise
|
raise
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error processing message for session {}", msg.session_key)
|
logger.exception("Error processing message for session {}", msg.session_key)
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
OutboundMessage(
|
||||||
content="Sorry, I encountered an error.",
|
channel=msg.channel,
|
||||||
))
|
chat_id=msg.chat_id,
|
||||||
|
content="Sorry, I encountered an error.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
async def close_mcp(self) -> None:
|
async def close_mcp(self) -> None:
|
||||||
"""Drain pending background archives, then close MCP connections."""
|
"""Drain pending background archives, then close MCP connections."""
|
||||||
if self._background_tasks:
|
if self._background_tasks:
|
||||||
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||||
self._background_tasks.clear()
|
self._background_tasks.clear()
|
||||||
if self._mcp_stack:
|
for name, stack in self._mcp_stacks.items():
|
||||||
try:
|
try:
|
||||||
await self._mcp_stack.aclose()
|
await stack.aclose()
|
||||||
except (RuntimeError, BaseExceptionGroup):
|
except (RuntimeError, BaseExceptionGroup):
|
||||||
pass # MCP SDK cancel scope cleanup is noisy but harmless
|
logger.debug("MCP server '{}' cleanup error (can be ignored)", name)
|
||||||
self._mcp_stack = None
|
self._mcp_stacks.clear()
|
||||||
|
|
||||||
def _schedule_background(self, coro) -> None:
|
def _schedule_background(self, coro) -> None:
|
||||||
"""Schedule a coroutine as a tracked background task (drained on shutdown)."""
|
"""Schedule a coroutine as a tracked background task (drained on shutdown)."""
|
||||||
@ -498,8 +535,9 @@ class AgentLoop:
|
|||||||
"""Process a single inbound message and return the response."""
|
"""Process a single inbound message and return the response."""
|
||||||
# System messages: parse origin from chat_id ("channel:chat_id")
|
# System messages: parse origin from chat_id ("channel:chat_id")
|
||||||
if msg.channel == "system":
|
if msg.channel == "system":
|
||||||
channel, chat_id = (msg.chat_id.split(":", 1) if ":" in msg.chat_id
|
channel, chat_id = (
|
||||||
else ("cli", msg.chat_id))
|
msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id)
|
||||||
|
)
|
||||||
logger.info("Processing system message from {}", msg.sender_id)
|
logger.info("Processing system message from {}", msg.sender_id)
|
||||||
key = f"{channel}:{chat_id}"
|
key = f"{channel}:{chat_id}"
|
||||||
session = self.sessions.get_or_create(key)
|
session = self.sessions.get_or_create(key)
|
||||||
@ -520,15 +558,21 @@ class AgentLoop:
|
|||||||
current_role=current_role,
|
current_role=current_role,
|
||||||
)
|
)
|
||||||
final_content, _, all_msgs, _ = await self._run_agent_loop(
|
final_content, _, all_msgs, _ = await self._run_agent_loop(
|
||||||
messages, session=session, channel=channel, chat_id=chat_id,
|
messages,
|
||||||
|
session=session,
|
||||||
|
channel=channel,
|
||||||
|
chat_id=chat_id,
|
||||||
message_id=msg.metadata.get("message_id"),
|
message_id=msg.metadata.get("message_id"),
|
||||||
)
|
)
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
self._clear_runtime_checkpoint(session)
|
self._clear_runtime_checkpoint(session)
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
||||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
return OutboundMessage(
|
||||||
content=final_content or "Background task completed.")
|
channel=channel,
|
||||||
|
chat_id=chat_id,
|
||||||
|
content=final_content or "Background task completed.",
|
||||||
|
)
|
||||||
|
|
||||||
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
|
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
|
||||||
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
|
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||||
@ -560,16 +604,22 @@ class AgentLoop:
|
|||||||
current_message=msg.content,
|
current_message=msg.content,
|
||||||
session_summary=pending,
|
session_summary=pending,
|
||||||
media=msg.media if msg.media else None,
|
media=msg.media if msg.media else None,
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
channel=msg.channel,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
|
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||||
meta = dict(msg.metadata or {})
|
meta = dict(msg.metadata or {})
|
||||||
meta["_progress"] = True
|
meta["_progress"] = True
|
||||||
meta["_tool_hint"] = tool_hint
|
meta["_tool_hint"] = tool_hint
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(
|
||||||
channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
|
OutboundMessage(
|
||||||
))
|
channel=msg.channel,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
|
content=content,
|
||||||
|
metadata=meta,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
final_content, _, all_msgs, stop_reason = await self._run_agent_loop(
|
final_content, _, all_msgs, stop_reason = await self._run_agent_loop(
|
||||||
initial_messages,
|
initial_messages,
|
||||||
@ -577,7 +627,8 @@ class AgentLoop:
|
|||||||
on_stream=on_stream,
|
on_stream=on_stream,
|
||||||
on_stream_end=on_stream_end,
|
on_stream_end=on_stream_end,
|
||||||
session=session,
|
session=session,
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
channel=msg.channel,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
message_id=msg.metadata.get("message_id"),
|
message_id=msg.metadata.get("message_id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -599,7 +650,9 @@ class AgentLoop:
|
|||||||
if on_stream is not None and stop_reason != "error":
|
if on_stream is not None and stop_reason != "error":
|
||||||
meta["_streamed"] = True
|
meta["_streamed"] = True
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
|
channel=msg.channel,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
|
content=final_content,
|
||||||
metadata=meta,
|
metadata=meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -625,10 +678,9 @@ class AgentLoop:
|
|||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if (
|
if block.get("type") == "image_url" and block.get("image_url", {}).get(
|
||||||
block.get("type") == "image_url"
|
"url", ""
|
||||||
and block.get("image_url", {}).get("url", "").startswith("data:image/")
|
).startswith("data:image/"):
|
||||||
):
|
|
||||||
path = (block.get("_meta") or {}).get("path", "")
|
path = (block.get("_meta") or {}).get("path", "")
|
||||||
filtered.append({"type": "text", "text": image_placeholder_text(path)})
|
filtered.append({"type": "text", "text": image_placeholder_text(path)})
|
||||||
continue
|
continue
|
||||||
@ -647,6 +699,7 @@ class AgentLoop:
|
|||||||
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
|
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
|
||||||
"""Save new-turn messages into session, truncating large tool results."""
|
"""Save new-turn messages into session, truncating large tool results."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
for m in messages[skip:]:
|
for m in messages[skip:]:
|
||||||
entry = dict(m)
|
entry = dict(m)
|
||||||
role, content = entry.get("role"), entry.get("content")
|
role, content = entry.get("role"), entry.get("content")
|
||||||
@ -736,13 +789,15 @@ class AgentLoop:
|
|||||||
continue
|
continue
|
||||||
tool_id = tool_call.get("id")
|
tool_id = tool_call.get("id")
|
||||||
name = ((tool_call.get("function") or {}).get("name")) or "tool"
|
name = ((tool_call.get("function") or {}).get("name")) or "tool"
|
||||||
restored_messages.append({
|
restored_messages.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tool_id,
|
"role": "tool",
|
||||||
"name": name,
|
"tool_call_id": tool_id,
|
||||||
"content": "Error: Task interrupted before this tool finished.",
|
"name": name,
|
||||||
"timestamp": datetime.now().isoformat(),
|
"content": "Error: Task interrupted before this tool finished.",
|
||||||
})
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
overlap = 0
|
overlap = 0
|
||||||
max_overlap = min(len(session.messages), len(restored_messages))
|
max_overlap = min(len(session.messages), len(restored_messages))
|
||||||
@ -774,6 +829,9 @@ class AgentLoop:
|
|||||||
await self._connect_mcp()
|
await self._connect_mcp()
|
||||||
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
||||||
return await self._process_message(
|
return await self._process_message(
|
||||||
msg, session_key=session_key, on_progress=on_progress,
|
msg,
|
||||||
on_stream=on_stream, on_stream_end=on_stream_end,
|
session_key=session_key,
|
||||||
|
on_progress=on_progress,
|
||||||
|
on_stream=on_stream,
|
||||||
|
on_stream_end=on_stream_end,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -57,9 +57,7 @@ def _normalize_schema_for_openai(schema: Any) -> dict[str, Any]:
|
|||||||
|
|
||||||
if "properties" in normalized and isinstance(normalized["properties"], dict):
|
if "properties" in normalized and isinstance(normalized["properties"], dict):
|
||||||
normalized["properties"] = {
|
normalized["properties"] = {
|
||||||
name: _normalize_schema_for_openai(prop)
|
name: _normalize_schema_for_openai(prop) if isinstance(prop, dict) else prop
|
||||||
if isinstance(prop, dict)
|
|
||||||
else prop
|
|
||||||
for name, prop in normalized["properties"].items()
|
for name, prop in normalized["properties"].items()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -138,9 +136,7 @@ class MCPToolWrapper(Tool):
|
|||||||
class MCPResourceWrapper(Tool):
|
class MCPResourceWrapper(Tool):
|
||||||
"""Wraps an MCP resource URI as a read-only nanobot Tool."""
|
"""Wraps an MCP resource URI as a read-only nanobot Tool."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, session, server_name: str, resource_def, resource_timeout: int = 30):
|
||||||
self, session, server_name: str, resource_def, resource_timeout: int = 30
|
|
||||||
):
|
|
||||||
self._session = session
|
self._session = session
|
||||||
self._uri = resource_def.uri
|
self._uri = resource_def.uri
|
||||||
self._name = f"mcp_{server_name}_resource_{resource_def.name}"
|
self._name = f"mcp_{server_name}_resource_{resource_def.name}"
|
||||||
@ -211,9 +207,7 @@ class MCPResourceWrapper(Tool):
|
|||||||
class MCPPromptWrapper(Tool):
|
class MCPPromptWrapper(Tool):
|
||||||
"""Wraps an MCP prompt as a read-only nanobot Tool."""
|
"""Wraps an MCP prompt as a read-only nanobot Tool."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, session, server_name: str, prompt_def, prompt_timeout: int = 30):
|
||||||
self, session, server_name: str, prompt_def, prompt_timeout: int = 30
|
|
||||||
):
|
|
||||||
self._session = session
|
self._session = session
|
||||||
self._prompt_name = prompt_def.name
|
self._prompt_name = prompt_def.name
|
||||||
self._name = f"mcp_{server_name}_prompt_{prompt_def.name}"
|
self._name = f"mcp_{server_name}_prompt_{prompt_def.name}"
|
||||||
@ -266,9 +260,7 @@ class MCPPromptWrapper(Tool):
|
|||||||
timeout=self._prompt_timeout,
|
timeout=self._prompt_timeout,
|
||||||
)
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.warning(
|
logger.warning("MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout)
|
||||||
"MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout
|
|
||||||
)
|
|
||||||
return f"(MCP prompt call timed out after {self._prompt_timeout}s)"
|
return f"(MCP prompt call timed out after {self._prompt_timeout}s)"
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
task = asyncio.current_task()
|
task = asyncio.current_task()
|
||||||
@ -279,13 +271,17 @@ class MCPPromptWrapper(Tool):
|
|||||||
except McpError as exc:
|
except McpError as exc:
|
||||||
logger.error(
|
logger.error(
|
||||||
"MCP prompt '{}' failed: code={} message={}",
|
"MCP prompt '{}' failed: code={} message={}",
|
||||||
self._name, exc.error.code, exc.error.message,
|
self._name,
|
||||||
|
exc.error.code,
|
||||||
|
exc.error.message,
|
||||||
)
|
)
|
||||||
return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])"
|
return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])"
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"MCP prompt '{}' failed: {}: {}",
|
"MCP prompt '{}' failed: {}: {}",
|
||||||
self._name, type(exc).__name__, exc,
|
self._name,
|
||||||
|
type(exc).__name__,
|
||||||
|
exc,
|
||||||
)
|
)
|
||||||
return f"(MCP prompt call failed: {type(exc).__name__})"
|
return f"(MCP prompt call failed: {type(exc).__name__})"
|
||||||
|
|
||||||
@ -307,35 +303,44 @@ class MCPPromptWrapper(Tool):
|
|||||||
|
|
||||||
|
|
||||||
async def connect_mcp_servers(
|
async def connect_mcp_servers(
|
||||||
mcp_servers: dict, registry: ToolRegistry, stack: AsyncExitStack
|
mcp_servers: dict, registry: ToolRegistry
|
||||||
) -> None:
|
) -> dict[str, AsyncExitStack]:
|
||||||
"""Connect to configured MCP servers and register their tools, resources, and prompts."""
|
"""Connect to configured MCP servers and register their tools, resources, prompts.
|
||||||
|
|
||||||
|
Returns a dict mapping server name -> its dedicated AsyncExitStack.
|
||||||
|
Each server gets its own stack and runs in its own task to prevent
|
||||||
|
cancel scope conflicts when multiple MCP servers are configured.
|
||||||
|
"""
|
||||||
from mcp import ClientSession, StdioServerParameters
|
from mcp import ClientSession, StdioServerParameters
|
||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
from mcp.client.stdio import stdio_client
|
from mcp.client.stdio import stdio_client
|
||||||
from mcp.client.streamable_http import streamable_http_client
|
from mcp.client.streamable_http import streamable_http_client
|
||||||
|
|
||||||
for name, cfg in mcp_servers.items():
|
async def connect_single_server(name: str, cfg) -> tuple[str, AsyncExitStack | None]:
|
||||||
|
server_stack = AsyncExitStack()
|
||||||
|
await server_stack.__aenter__()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
transport_type = cfg.type
|
transport_type = cfg.type
|
||||||
if not transport_type:
|
if not transport_type:
|
||||||
if cfg.command:
|
if cfg.command:
|
||||||
transport_type = "stdio"
|
transport_type = "stdio"
|
||||||
elif cfg.url:
|
elif cfg.url:
|
||||||
# Convention: URLs ending with /sse use SSE transport; others use streamableHttp
|
|
||||||
transport_type = (
|
transport_type = (
|
||||||
"sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp"
|
"sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning("MCP server '{}': no command or url configured, skipping", name)
|
logger.warning("MCP server '{}': no command or url configured, skipping", name)
|
||||||
continue
|
await server_stack.aclose()
|
||||||
|
return name, None
|
||||||
|
|
||||||
if transport_type == "stdio":
|
if transport_type == "stdio":
|
||||||
params = StdioServerParameters(
|
params = StdioServerParameters(
|
||||||
command=cfg.command, args=cfg.args, env=cfg.env or None
|
command=cfg.command, args=cfg.args, env=cfg.env or None
|
||||||
)
|
)
|
||||||
read, write = await stack.enter_async_context(stdio_client(params))
|
read, write = await server_stack.enter_async_context(stdio_client(params))
|
||||||
elif transport_type == "sse":
|
elif transport_type == "sse":
|
||||||
|
|
||||||
def httpx_client_factory(
|
def httpx_client_factory(
|
||||||
headers: dict[str, str] | None = None,
|
headers: dict[str, str] | None = None,
|
||||||
timeout: httpx.Timeout | None = None,
|
timeout: httpx.Timeout | None = None,
|
||||||
@ -353,27 +358,26 @@ async def connect_mcp_servers(
|
|||||||
auth=auth,
|
auth=auth,
|
||||||
)
|
)
|
||||||
|
|
||||||
read, write = await stack.enter_async_context(
|
read, write = await server_stack.enter_async_context(
|
||||||
sse_client(cfg.url, httpx_client_factory=httpx_client_factory)
|
sse_client(cfg.url, httpx_client_factory=httpx_client_factory)
|
||||||
)
|
)
|
||||||
elif transport_type == "streamableHttp":
|
elif transport_type == "streamableHttp":
|
||||||
# Always provide an explicit httpx client so MCP HTTP transport does not
|
http_client = await server_stack.enter_async_context(
|
||||||
# inherit httpx's default 5s timeout and preempt the higher-level tool timeout.
|
|
||||||
http_client = await stack.enter_async_context(
|
|
||||||
httpx.AsyncClient(
|
httpx.AsyncClient(
|
||||||
headers=cfg.headers or None,
|
headers=cfg.headers or None,
|
||||||
follow_redirects=True,
|
follow_redirects=True,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
read, write, _ = await stack.enter_async_context(
|
read, write, _ = await server_stack.enter_async_context(
|
||||||
streamable_http_client(cfg.url, http_client=http_client)
|
streamable_http_client(cfg.url, http_client=http_client)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type)
|
logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type)
|
||||||
continue
|
await server_stack.aclose()
|
||||||
|
return name, None
|
||||||
|
|
||||||
session = await stack.enter_async_context(ClientSession(read, write))
|
session = await server_stack.enter_async_context(ClientSession(read, write))
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
|
|
||||||
tools = await session.list_tools()
|
tools = await session.list_tools()
|
||||||
@ -418,7 +422,6 @@ async def connect_mcp_servers(
|
|||||||
", ".join(available_wrapped_names) or "(none)",
|
", ".join(available_wrapped_names) or "(none)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- Register resources ---
|
|
||||||
try:
|
try:
|
||||||
resources_result = await session.list_resources()
|
resources_result = await session.list_resources()
|
||||||
for resource in resources_result.resources:
|
for resource in resources_result.resources:
|
||||||
@ -433,7 +436,6 @@ async def connect_mcp_servers(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("MCP server '{}': resources not supported or failed: {}", name, e)
|
logger.debug("MCP server '{}': resources not supported or failed: {}", name, e)
|
||||||
|
|
||||||
# --- Register prompts ---
|
|
||||||
try:
|
try:
|
||||||
prompts_result = await session.list_prompts()
|
prompts_result = await session.list_prompts()
|
||||||
for prompt in prompts_result.prompts:
|
for prompt in prompts_result.prompts:
|
||||||
@ -442,14 +444,38 @@ async def connect_mcp_servers(
|
|||||||
)
|
)
|
||||||
registry.register(wrapper)
|
registry.register(wrapper)
|
||||||
registered_count += 1
|
registered_count += 1
|
||||||
logger.debug(
|
logger.debug("MCP: registered prompt '{}' from server '{}'", wrapper.name, name)
|
||||||
"MCP: registered prompt '{}' from server '{}'", wrapper.name, name
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("MCP server '{}': prompts not supported or failed: {}", name, e)
|
logger.debug("MCP server '{}': prompts not supported or failed: {}", name, e)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"MCP server '{}': connected, {} capabilities registered", name, registered_count
|
"MCP server '{}': connected, {} capabilities registered", name, registered_count
|
||||||
)
|
)
|
||||||
|
return name, server_stack
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("MCP server '{}': failed to connect: {}", name, e)
|
logger.error("MCP server '{}': failed to connect: {}", name, e)
|
||||||
|
try:
|
||||||
|
await server_stack.aclose()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return name, None
|
||||||
|
|
||||||
|
server_stacks: dict[str, AsyncExitStack] = {}
|
||||||
|
|
||||||
|
tasks: list[asyncio.Task] = []
|
||||||
|
for name, cfg in mcp_servers.items():
|
||||||
|
task = asyncio.create_task(connect_single_server(name, cfg))
|
||||||
|
tasks.append(task)
|
||||||
|
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
name = list(mcp_servers.keys())[i]
|
||||||
|
if isinstance(result, BaseException):
|
||||||
|
if not isinstance(result, asyncio.CancelledError):
|
||||||
|
logger.error("MCP server '{}' connection task failed: {}", name, result)
|
||||||
|
elif result is not None and result[1] is not None:
|
||||||
|
server_stacks[result[0]] = result[1]
|
||||||
|
|
||||||
|
return server_stacks
|
||||||
|
|||||||
44
tests/agent/test_mcp_connection.py
Normal file
44
tests/agent/test_mcp_connection.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
"""Tests for MCP connection lifecycle in AgentLoop."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
|
||||||
|
|
||||||
|
def _make_loop(tmp_path, *, mcp_servers: dict | None = None) -> AgentLoop:
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
provider.generation.max_tokens = 4096
|
||||||
|
return AgentLoop(
|
||||||
|
bus=bus,
|
||||||
|
provider=provider,
|
||||||
|
workspace=tmp_path,
|
||||||
|
model="test-model",
|
||||||
|
mcp_servers=mcp_servers or {"test": object()},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_connect_mcp_retries_when_no_servers_connect(tmp_path, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
loop = _make_loop(tmp_path)
|
||||||
|
attempts = 0
|
||||||
|
|
||||||
|
async def _fake_connect(_servers, _registry):
|
||||||
|
nonlocal attempts
|
||||||
|
attempts += 1
|
||||||
|
return {}
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _fake_connect)
|
||||||
|
|
||||||
|
await loop._connect_mcp()
|
||||||
|
await loop._connect_mcp()
|
||||||
|
|
||||||
|
assert attempts == 2
|
||||||
|
assert loop._mcp_connected is False
|
||||||
|
assert loop._mcp_stacks == {}
|
||||||
@ -271,15 +271,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_raw_names(
|
|||||||
) -> None:
|
) -> None:
|
||||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
stack = AsyncExitStack()
|
stacks = await connect_mcp_servers(
|
||||||
await stack.__aenter__()
|
{"test": MCPServerConfig(command="fake", enabled_tools=["demo"])},
|
||||||
try:
|
registry,
|
||||||
await connect_mcp_servers(
|
)
|
||||||
{"test": MCPServerConfig(command="fake", enabled_tools=["demo"])},
|
for stack in stacks.values():
|
||||||
registry,
|
|
||||||
stack,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
await stack.aclose()
|
await stack.aclose()
|
||||||
|
|
||||||
assert registry.tool_names == ["mcp_test_demo"]
|
assert registry.tool_names == ["mcp_test_demo"]
|
||||||
@ -291,15 +287,11 @@ async def test_connect_mcp_servers_enabled_tools_defaults_to_all(
|
|||||||
) -> None:
|
) -> None:
|
||||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
stack = AsyncExitStack()
|
stacks = await connect_mcp_servers(
|
||||||
await stack.__aenter__()
|
{"test": MCPServerConfig(command="fake")},
|
||||||
try:
|
registry,
|
||||||
await connect_mcp_servers(
|
)
|
||||||
{"test": MCPServerConfig(command="fake")},
|
for stack in stacks.values():
|
||||||
registry,
|
|
||||||
stack,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
await stack.aclose()
|
await stack.aclose()
|
||||||
|
|
||||||
assert registry.tool_names == ["mcp_test_demo", "mcp_test_other"]
|
assert registry.tool_names == ["mcp_test_demo", "mcp_test_other"]
|
||||||
@ -311,15 +303,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_wrapped_names(
|
|||||||
) -> None:
|
) -> None:
|
||||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
stack = AsyncExitStack()
|
stacks = await connect_mcp_servers(
|
||||||
await stack.__aenter__()
|
{"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])},
|
||||||
try:
|
registry,
|
||||||
await connect_mcp_servers(
|
)
|
||||||
{"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])},
|
for stack in stacks.values():
|
||||||
registry,
|
|
||||||
stack,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
await stack.aclose()
|
await stack.aclose()
|
||||||
|
|
||||||
assert registry.tool_names == ["mcp_test_demo"]
|
assert registry.tool_names == ["mcp_test_demo"]
|
||||||
@ -331,15 +319,11 @@ async def test_connect_mcp_servers_enabled_tools_empty_list_registers_none(
|
|||||||
) -> None:
|
) -> None:
|
||||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
stack = AsyncExitStack()
|
stacks = await connect_mcp_servers(
|
||||||
await stack.__aenter__()
|
{"test": MCPServerConfig(command="fake", enabled_tools=[])},
|
||||||
try:
|
registry,
|
||||||
await connect_mcp_servers(
|
)
|
||||||
{"test": MCPServerConfig(command="fake", enabled_tools=[])},
|
for stack in stacks.values():
|
||||||
registry,
|
|
||||||
stack,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
await stack.aclose()
|
await stack.aclose()
|
||||||
|
|
||||||
assert registry.tool_names == []
|
assert registry.tool_names == []
|
||||||
@ -358,15 +342,11 @@ async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries(
|
|||||||
|
|
||||||
monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning)
|
monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning)
|
||||||
|
|
||||||
stack = AsyncExitStack()
|
stacks = await connect_mcp_servers(
|
||||||
await stack.__aenter__()
|
{"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])},
|
||||||
try:
|
registry,
|
||||||
await connect_mcp_servers(
|
)
|
||||||
{"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])},
|
for stack in stacks.values():
|
||||||
registry,
|
|
||||||
stack,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
await stack.aclose()
|
await stack.aclose()
|
||||||
|
|
||||||
assert registry.tool_names == []
|
assert registry.tool_names == []
|
||||||
@ -376,6 +356,46 @@ async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries(
|
|||||||
assert "Available wrapped names: mcp_test_demo" in warnings[-1]
|
assert "Available wrapped names: mcp_test_demo" in warnings[-1]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_connect_mcp_servers_one_failure_does_not_block_others(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
sessions = {"good": _make_fake_session(["demo"])}
|
||||||
|
|
||||||
|
class _SelectiveClientSession:
|
||||||
|
def __init__(self, read: object, _write: object) -> None:
|
||||||
|
self._session = sessions[read]
|
||||||
|
|
||||||
|
async def __aenter__(self) -> object:
|
||||||
|
return self._session
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _selective_stdio_client(params: object):
|
||||||
|
if params.command == "bad":
|
||||||
|
raise RuntimeError("boom")
|
||||||
|
yield params.command, object()
|
||||||
|
|
||||||
|
monkeypatch.setattr(sys.modules["mcp"], "ClientSession", _SelectiveClientSession)
|
||||||
|
monkeypatch.setattr(sys.modules["mcp.client.stdio"], "stdio_client", _selective_stdio_client)
|
||||||
|
|
||||||
|
registry = ToolRegistry()
|
||||||
|
stacks = await connect_mcp_servers(
|
||||||
|
{
|
||||||
|
"good": MCPServerConfig(command="good"),
|
||||||
|
"bad": MCPServerConfig(command="bad"),
|
||||||
|
},
|
||||||
|
registry,
|
||||||
|
)
|
||||||
|
for stack in stacks.values():
|
||||||
|
await stack.aclose()
|
||||||
|
|
||||||
|
assert registry.tool_names == ["mcp_good_demo"]
|
||||||
|
assert set(stacks) == {"good"}
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# MCPResourceWrapper tests
|
# MCPResourceWrapper tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@ -389,9 +409,7 @@ def _make_resource_def(
|
|||||||
return SimpleNamespace(name=name, uri=uri, description=description)
|
return SimpleNamespace(name=name, uri=uri, description=description)
|
||||||
|
|
||||||
|
|
||||||
def _make_resource_wrapper(
|
def _make_resource_wrapper(session: object, *, timeout: float = 0.1) -> MCPResourceWrapper:
|
||||||
session: object, *, timeout: float = 0.1
|
|
||||||
) -> MCPResourceWrapper:
|
|
||||||
return MCPResourceWrapper(session, "srv", _make_resource_def(), resource_timeout=timeout)
|
return MCPResourceWrapper(session, "srv", _make_resource_def(), resource_timeout=timeout)
|
||||||
|
|
||||||
|
|
||||||
@ -434,9 +452,7 @@ async def test_resource_wrapper_execute_handles_timeout() -> None:
|
|||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
return SimpleNamespace(contents=[])
|
return SimpleNamespace(contents=[])
|
||||||
|
|
||||||
wrapper = _make_resource_wrapper(
|
wrapper = _make_resource_wrapper(SimpleNamespace(read_resource=read_resource), timeout=0.01)
|
||||||
SimpleNamespace(read_resource=read_resource), timeout=0.01
|
|
||||||
)
|
|
||||||
result = await wrapper.execute()
|
result = await wrapper.execute()
|
||||||
assert result == "(MCP resource read timed out after 0.01s)"
|
assert result == "(MCP resource read timed out after 0.01s)"
|
||||||
|
|
||||||
@ -464,20 +480,14 @@ def _make_prompt_def(
|
|||||||
return SimpleNamespace(name=name, description=description, arguments=arguments)
|
return SimpleNamespace(name=name, description=description, arguments=arguments)
|
||||||
|
|
||||||
|
|
||||||
def _make_prompt_wrapper(
|
def _make_prompt_wrapper(session: object, *, timeout: float = 0.1) -> MCPPromptWrapper:
|
||||||
session: object, *, timeout: float = 0.1
|
return MCPPromptWrapper(session, "srv", _make_prompt_def(), prompt_timeout=timeout)
|
||||||
) -> MCPPromptWrapper:
|
|
||||||
return MCPPromptWrapper(
|
|
||||||
session, "srv", _make_prompt_def(), prompt_timeout=timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_wrapper_properties() -> None:
|
def test_prompt_wrapper_properties() -> None:
|
||||||
arg1 = SimpleNamespace(name="topic", required=True)
|
arg1 = SimpleNamespace(name="topic", required=True)
|
||||||
arg2 = SimpleNamespace(name="style", required=False)
|
arg2 = SimpleNamespace(name="style", required=False)
|
||||||
wrapper = MCPPromptWrapper(
|
wrapper = MCPPromptWrapper(None, "myserver", _make_prompt_def(arguments=[arg1, arg2]))
|
||||||
None, "myserver", _make_prompt_def(arguments=[arg1, arg2])
|
|
||||||
)
|
|
||||||
assert wrapper.name == "mcp_myserver_prompt_myprompt"
|
assert wrapper.name == "mcp_myserver_prompt_myprompt"
|
||||||
assert "[MCP Prompt]" in wrapper.description
|
assert "[MCP Prompt]" in wrapper.description
|
||||||
assert "A test prompt" in wrapper.description
|
assert "A test prompt" in wrapper.description
|
||||||
@ -528,9 +538,7 @@ async def test_prompt_wrapper_execute_handles_timeout() -> None:
|
|||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
return SimpleNamespace(messages=[])
|
return SimpleNamespace(messages=[])
|
||||||
|
|
||||||
wrapper = _make_prompt_wrapper(
|
wrapper = _make_prompt_wrapper(SimpleNamespace(get_prompt=get_prompt), timeout=0.01)
|
||||||
SimpleNamespace(get_prompt=get_prompt), timeout=0.01
|
|
||||||
)
|
|
||||||
result = await wrapper.execute()
|
result = await wrapper.execute()
|
||||||
assert result == "(MCP prompt call timed out after 0.01s)"
|
assert result == "(MCP prompt call timed out after 0.01s)"
|
||||||
|
|
||||||
@ -616,15 +624,11 @@ async def test_connect_registers_resources_and_prompts(
|
|||||||
prompt_names=["prompt_c"],
|
prompt_names=["prompt_c"],
|
||||||
)
|
)
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
stack = AsyncExitStack()
|
stacks = await connect_mcp_servers(
|
||||||
await stack.__aenter__()
|
{"test": MCPServerConfig(command="fake")},
|
||||||
try:
|
registry,
|
||||||
await connect_mcp_servers(
|
)
|
||||||
{"test": MCPServerConfig(command="fake")},
|
for stack in stacks.values():
|
||||||
registry,
|
|
||||||
stack,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
await stack.aclose()
|
await stack.aclose()
|
||||||
|
|
||||||
assert "mcp_test_tool_a" in registry.tool_names
|
assert "mcp_test_tool_a" in registry.tool_names
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user