mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-27 13:25:52 +00:00
fix(mcp): support multiple MCP servers by connecting each in isolated task
Each MCP server now connects in its own asyncio.Task to isolate anyio cancel scopes and prevent 'exit cancel scope in different task' errors when multiple servers (especially mixed transport types) are configured. Changes: - connect_mcp_servers() returns dict[str, AsyncExitStack] instead of None - Each server runs in separate task via asyncio.gather() - AgentLoop uses _mcp_stacks dict to track per-server stacks - Tests updated to handle new API
This commit is contained in:
parent
9bccfa63d2
commit
a167959027
@ -43,6 +43,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
UNIFIED_SESSION_KEY = "unified:default"
|
UNIFIED_SESSION_KEY = "unified:default"
|
||||||
|
|
||||||
|
|
||||||
class _LoopHook(AgentHook):
|
class _LoopHook(AgentHook):
|
||||||
"""Core hook for the main loop."""
|
"""Core hook for the main loop."""
|
||||||
|
|
||||||
@ -76,7 +77,7 @@ class _LoopHook(AgentHook):
|
|||||||
prev_clean = strip_think(self._stream_buf)
|
prev_clean = strip_think(self._stream_buf)
|
||||||
self._stream_buf += delta
|
self._stream_buf += delta
|
||||||
new_clean = strip_think(self._stream_buf)
|
new_clean = strip_think(self._stream_buf)
|
||||||
incremental = new_clean[len(prev_clean):]
|
incremental = new_clean[len(prev_clean) :]
|
||||||
if incremental and self._on_stream:
|
if incremental and self._on_stream:
|
||||||
await self._on_stream(incremental)
|
await self._on_stream(incremental)
|
||||||
|
|
||||||
@ -112,6 +113,7 @@ class _LoopHook(AgentHook):
|
|||||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||||
return self._loop._strip_think(content)
|
return self._loop._strip_think(content)
|
||||||
|
|
||||||
|
|
||||||
class AgentLoop:
|
class AgentLoop:
|
||||||
"""
|
"""
|
||||||
The agent loop is the core processing engine.
|
The agent loop is the core processing engine.
|
||||||
@ -196,7 +198,7 @@ class AgentLoop:
|
|||||||
self._unified_session = unified_session
|
self._unified_session = unified_session
|
||||||
self._running = False
|
self._running = False
|
||||||
self._mcp_servers = mcp_servers or {}
|
self._mcp_servers = mcp_servers or {}
|
||||||
self._mcp_stack: AsyncExitStack | None = None
|
self._mcp_stacks: dict[str, AsyncExitStack] = {}
|
||||||
self._mcp_connected = False
|
self._mcp_connected = False
|
||||||
self._mcp_connecting = False
|
self._mcp_connecting = False
|
||||||
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||||
@ -228,24 +230,34 @@ class AgentLoop:
|
|||||||
|
|
||||||
def _register_default_tools(self) -> None:
|
def _register_default_tools(self) -> None:
|
||||||
"""Register the default set of tools."""
|
"""Register the default set of tools."""
|
||||||
allowed_dir = self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
|
allowed_dir = (
|
||||||
|
self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
|
||||||
|
)
|
||||||
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
|
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
|
||||||
self.tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read))
|
self.tools.register(
|
||||||
|
ReadFileTool(
|
||||||
|
workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read
|
||||||
|
)
|
||||||
|
)
|
||||||
for cls in (WriteFileTool, EditFileTool, ListDirTool):
|
for cls in (WriteFileTool, EditFileTool, ListDirTool):
|
||||||
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||||
for cls in (GlobTool, GrepTool):
|
for cls in (GlobTool, GrepTool):
|
||||||
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||||
if self.exec_config.enable:
|
if self.exec_config.enable:
|
||||||
self.tools.register(ExecTool(
|
self.tools.register(
|
||||||
working_dir=str(self.workspace),
|
ExecTool(
|
||||||
timeout=self.exec_config.timeout,
|
working_dir=str(self.workspace),
|
||||||
restrict_to_workspace=self.restrict_to_workspace,
|
timeout=self.exec_config.timeout,
|
||||||
sandbox=self.exec_config.sandbox,
|
restrict_to_workspace=self.restrict_to_workspace,
|
||||||
path_append=self.exec_config.path_append,
|
sandbox=self.exec_config.sandbox,
|
||||||
allowed_env_keys=self.exec_config.allowed_env_keys,
|
path_append=self.exec_config.path_append,
|
||||||
))
|
allowed_env_keys=self.exec_config.allowed_env_keys,
|
||||||
|
)
|
||||||
|
)
|
||||||
if self.web_config.enable:
|
if self.web_config.enable:
|
||||||
self.tools.register(WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy))
|
self.tools.register(
|
||||||
|
WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy)
|
||||||
|
)
|
||||||
self.tools.register(WebFetchTool(proxy=self.web_config.proxy))
|
self.tools.register(WebFetchTool(proxy=self.web_config.proxy))
|
||||||
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
|
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
|
||||||
self.tools.register(SpawnTool(manager=self.subagents))
|
self.tools.register(SpawnTool(manager=self.subagents))
|
||||||
@ -260,19 +272,16 @@ class AgentLoop:
|
|||||||
return
|
return
|
||||||
self._mcp_connecting = True
|
self._mcp_connecting = True
|
||||||
from nanobot.agent.tools.mcp import connect_mcp_servers
|
from nanobot.agent.tools.mcp import connect_mcp_servers
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._mcp_stack = AsyncExitStack()
|
self._mcp_stacks = await connect_mcp_servers(self._mcp_servers, self.tools)
|
||||||
await self._mcp_stack.__aenter__()
|
|
||||||
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
|
|
||||||
self._mcp_connected = True
|
self._mcp_connected = True
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.warning("MCP connection cancelled (will retry next message)")
|
||||||
|
self._mcp_stacks.clear()
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.error("Failed to connect MCP servers (will retry next message): {}", e)
|
logger.error("Failed to connect MCP servers (will retry next message): {}", e)
|
||||||
if self._mcp_stack:
|
self._mcp_stacks.clear()
|
||||||
try:
|
|
||||||
await self._mcp_stack.aclose()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
self._mcp_stack = None
|
|
||||||
finally:
|
finally:
|
||||||
self._mcp_connecting = False
|
self._mcp_connecting = False
|
||||||
|
|
||||||
@ -289,6 +298,7 @@ class AgentLoop:
|
|||||||
if not text:
|
if not text:
|
||||||
return None
|
return None
|
||||||
from nanobot.utils.helpers import strip_think
|
from nanobot.utils.helpers import strip_think
|
||||||
|
|
||||||
return strip_think(text) or None
|
return strip_think(text) or None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -327,9 +337,7 @@ class AgentLoop:
|
|||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
)
|
)
|
||||||
hook: AgentHook = (
|
hook: AgentHook = (
|
||||||
CompositeHook([loop_hook] + self._extra_hooks)
|
CompositeHook([loop_hook] + self._extra_hooks) if self._extra_hooks else loop_hook
|
||||||
if self._extra_hooks
|
|
||||||
else loop_hook
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _checkpoint(payload: dict[str, Any]) -> None:
|
async def _checkpoint(payload: dict[str, Any]) -> None:
|
||||||
@ -337,23 +345,25 @@ class AgentLoop:
|
|||||||
return
|
return
|
||||||
self._set_runtime_checkpoint(session, payload)
|
self._set_runtime_checkpoint(session, payload)
|
||||||
|
|
||||||
result = await self.runner.run(AgentRunSpec(
|
result = await self.runner.run(
|
||||||
initial_messages=initial_messages,
|
AgentRunSpec(
|
||||||
tools=self.tools,
|
initial_messages=initial_messages,
|
||||||
model=self.model,
|
tools=self.tools,
|
||||||
max_iterations=self.max_iterations,
|
model=self.model,
|
||||||
max_tool_result_chars=self.max_tool_result_chars,
|
max_iterations=self.max_iterations,
|
||||||
hook=hook,
|
max_tool_result_chars=self.max_tool_result_chars,
|
||||||
error_message="Sorry, I encountered an error calling the AI model.",
|
hook=hook,
|
||||||
concurrent_tools=True,
|
error_message="Sorry, I encountered an error calling the AI model.",
|
||||||
workspace=self.workspace,
|
concurrent_tools=True,
|
||||||
session_key=session.key if session else None,
|
workspace=self.workspace,
|
||||||
context_window_tokens=self.context_window_tokens,
|
session_key=session.key if session else None,
|
||||||
context_block_limit=self.context_block_limit,
|
context_window_tokens=self.context_window_tokens,
|
||||||
provider_retry_mode=self.provider_retry_mode,
|
context_block_limit=self.context_block_limit,
|
||||||
progress_callback=on_progress,
|
provider_retry_mode=self.provider_retry_mode,
|
||||||
checkpoint_callback=_checkpoint,
|
progress_callback=on_progress,
|
||||||
))
|
checkpoint_callback=_checkpoint,
|
||||||
|
)
|
||||||
|
)
|
||||||
self._last_usage = result.usage
|
self._last_usage = result.usage
|
||||||
if result.stop_reason == "max_iterations":
|
if result.stop_reason == "max_iterations":
|
||||||
logger.warning("Max iterations ({}) reached", self.max_iterations)
|
logger.warning("Max iterations ({}) reached", self.max_iterations)
|
||||||
@ -391,10 +401,19 @@ class AgentLoop:
|
|||||||
continue
|
continue
|
||||||
# Compute the effective session key before dispatching
|
# Compute the effective session key before dispatching
|
||||||
# This ensures /stop command can find tasks correctly when unified session is enabled
|
# This ensures /stop command can find tasks correctly when unified session is enabled
|
||||||
effective_key = UNIFIED_SESSION_KEY if self._unified_session and not msg.session_key_override else msg.session_key
|
effective_key = (
|
||||||
|
UNIFIED_SESSION_KEY
|
||||||
|
if self._unified_session and not msg.session_key_override
|
||||||
|
else msg.session_key
|
||||||
|
)
|
||||||
task = asyncio.create_task(self._dispatch(msg))
|
task = asyncio.create_task(self._dispatch(msg))
|
||||||
self._active_tasks.setdefault(effective_key, []).append(task)
|
self._active_tasks.setdefault(effective_key, []).append(task)
|
||||||
task.add_done_callback(lambda t, k=effective_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
|
task.add_done_callback(
|
||||||
|
lambda t, k=effective_key: self._active_tasks.get(k, [])
|
||||||
|
and self._active_tasks[k].remove(t)
|
||||||
|
if t in self._active_tasks.get(k, [])
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
async def _dispatch(self, msg: InboundMessage) -> None:
|
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||||
"""Process a message: per-session serial, cross-session concurrent."""
|
"""Process a message: per-session serial, cross-session concurrent."""
|
||||||
@ -417,11 +436,14 @@ class AgentLoop:
|
|||||||
meta = dict(msg.metadata or {})
|
meta = dict(msg.metadata or {})
|
||||||
meta["_stream_delta"] = True
|
meta["_stream_delta"] = True
|
||||||
meta["_stream_id"] = _current_stream_id()
|
meta["_stream_id"] = _current_stream_id()
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
OutboundMessage(
|
||||||
content=delta,
|
channel=msg.channel,
|
||||||
metadata=meta,
|
chat_id=msg.chat_id,
|
||||||
))
|
content=delta,
|
||||||
|
metadata=meta,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
async def on_stream_end(*, resuming: bool = False) -> None:
|
async def on_stream_end(*, resuming: bool = False) -> None:
|
||||||
nonlocal stream_segment
|
nonlocal stream_segment
|
||||||
@ -429,44 +451,56 @@ class AgentLoop:
|
|||||||
meta["_stream_end"] = True
|
meta["_stream_end"] = True
|
||||||
meta["_resuming"] = resuming
|
meta["_resuming"] = resuming
|
||||||
meta["_stream_id"] = _current_stream_id()
|
meta["_stream_id"] = _current_stream_id()
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
OutboundMessage(
|
||||||
content="",
|
channel=msg.channel,
|
||||||
metadata=meta,
|
chat_id=msg.chat_id,
|
||||||
))
|
content="",
|
||||||
|
metadata=meta,
|
||||||
|
)
|
||||||
|
)
|
||||||
stream_segment += 1
|
stream_segment += 1
|
||||||
|
|
||||||
response = await self._process_message(
|
response = await self._process_message(
|
||||||
msg, on_stream=on_stream, on_stream_end=on_stream_end,
|
msg,
|
||||||
|
on_stream=on_stream,
|
||||||
|
on_stream_end=on_stream_end,
|
||||||
)
|
)
|
||||||
if response is not None:
|
if response is not None:
|
||||||
await self.bus.publish_outbound(response)
|
await self.bus.publish_outbound(response)
|
||||||
elif msg.channel == "cli":
|
elif msg.channel == "cli":
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
OutboundMessage(
|
||||||
content="", metadata=msg.metadata or {},
|
channel=msg.channel,
|
||||||
))
|
chat_id=msg.chat_id,
|
||||||
|
content="",
|
||||||
|
metadata=msg.metadata or {},
|
||||||
|
)
|
||||||
|
)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info("Task cancelled for session {}", msg.session_key)
|
logger.info("Task cancelled for session {}", msg.session_key)
|
||||||
raise
|
raise
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error processing message for session {}", msg.session_key)
|
logger.exception("Error processing message for session {}", msg.session_key)
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
OutboundMessage(
|
||||||
content="Sorry, I encountered an error.",
|
channel=msg.channel,
|
||||||
))
|
chat_id=msg.chat_id,
|
||||||
|
content="Sorry, I encountered an error.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
async def close_mcp(self) -> None:
|
async def close_mcp(self) -> None:
|
||||||
"""Drain pending background archives, then close MCP connections."""
|
"""Drain pending background archives, then close MCP connections."""
|
||||||
if self._background_tasks:
|
if self._background_tasks:
|
||||||
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||||
self._background_tasks.clear()
|
self._background_tasks.clear()
|
||||||
if self._mcp_stack:
|
for name, stack in self._mcp_stacks.items():
|
||||||
try:
|
try:
|
||||||
await self._mcp_stack.aclose()
|
await stack.aclose()
|
||||||
except (RuntimeError, BaseExceptionGroup):
|
except (RuntimeError, BaseExceptionGroup):
|
||||||
pass # MCP SDK cancel scope cleanup is noisy but harmless
|
logger.debug("MCP server '{}' cleanup error (can be ignored)", name)
|
||||||
self._mcp_stack = None
|
self._mcp_stacks.clear()
|
||||||
|
|
||||||
def _schedule_background(self, coro) -> None:
|
def _schedule_background(self, coro) -> None:
|
||||||
"""Schedule a coroutine as a tracked background task (drained on shutdown)."""
|
"""Schedule a coroutine as a tracked background task (drained on shutdown)."""
|
||||||
@ -490,8 +524,9 @@ class AgentLoop:
|
|||||||
"""Process a single inbound message and return the response."""
|
"""Process a single inbound message and return the response."""
|
||||||
# System messages: parse origin from chat_id ("channel:chat_id")
|
# System messages: parse origin from chat_id ("channel:chat_id")
|
||||||
if msg.channel == "system":
|
if msg.channel == "system":
|
||||||
channel, chat_id = (msg.chat_id.split(":", 1) if ":" in msg.chat_id
|
channel, chat_id = (
|
||||||
else ("cli", msg.chat_id))
|
msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id)
|
||||||
|
)
|
||||||
logger.info("Processing system message from {}", msg.sender_id)
|
logger.info("Processing system message from {}", msg.sender_id)
|
||||||
key = f"{channel}:{chat_id}"
|
key = f"{channel}:{chat_id}"
|
||||||
session = self.sessions.get_or_create(key)
|
session = self.sessions.get_or_create(key)
|
||||||
@ -503,19 +538,27 @@ class AgentLoop:
|
|||||||
current_role = "assistant" if msg.sender_id == "subagent" else "user"
|
current_role = "assistant" if msg.sender_id == "subagent" else "user"
|
||||||
messages = self.context.build_messages(
|
messages = self.context.build_messages(
|
||||||
history=history,
|
history=history,
|
||||||
current_message=msg.content, channel=channel, chat_id=chat_id,
|
current_message=msg.content,
|
||||||
|
channel=channel,
|
||||||
|
chat_id=chat_id,
|
||||||
current_role=current_role,
|
current_role=current_role,
|
||||||
)
|
)
|
||||||
final_content, _, all_msgs, _ = await self._run_agent_loop(
|
final_content, _, all_msgs, _ = await self._run_agent_loop(
|
||||||
messages, session=session, channel=channel, chat_id=chat_id,
|
messages,
|
||||||
|
session=session,
|
||||||
|
channel=channel,
|
||||||
|
chat_id=chat_id,
|
||||||
message_id=msg.metadata.get("message_id"),
|
message_id=msg.metadata.get("message_id"),
|
||||||
)
|
)
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
self._clear_runtime_checkpoint(session)
|
self._clear_runtime_checkpoint(session)
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
||||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
return OutboundMessage(
|
||||||
content=final_content or "Background task completed.")
|
channel=channel,
|
||||||
|
chat_id=chat_id,
|
||||||
|
content=final_content or "Background task completed.",
|
||||||
|
)
|
||||||
|
|
||||||
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
|
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
|
||||||
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
|
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||||
@ -543,16 +586,22 @@ class AgentLoop:
|
|||||||
history=history,
|
history=history,
|
||||||
current_message=msg.content,
|
current_message=msg.content,
|
||||||
media=msg.media if msg.media else None,
|
media=msg.media if msg.media else None,
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
channel=msg.channel,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
|
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||||
meta = dict(msg.metadata or {})
|
meta = dict(msg.metadata or {})
|
||||||
meta["_progress"] = True
|
meta["_progress"] = True
|
||||||
meta["_tool_hint"] = tool_hint
|
meta["_tool_hint"] = tool_hint
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(
|
||||||
channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
|
OutboundMessage(
|
||||||
))
|
channel=msg.channel,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
|
content=content,
|
||||||
|
metadata=meta,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
final_content, _, all_msgs, stop_reason = await self._run_agent_loop(
|
final_content, _, all_msgs, stop_reason = await self._run_agent_loop(
|
||||||
initial_messages,
|
initial_messages,
|
||||||
@ -560,7 +609,8 @@ class AgentLoop:
|
|||||||
on_stream=on_stream,
|
on_stream=on_stream,
|
||||||
on_stream_end=on_stream_end,
|
on_stream_end=on_stream_end,
|
||||||
session=session,
|
session=session,
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
channel=msg.channel,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
message_id=msg.metadata.get("message_id"),
|
message_id=msg.metadata.get("message_id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -582,7 +632,9 @@ class AgentLoop:
|
|||||||
if on_stream is not None and stop_reason != "error":
|
if on_stream is not None and stop_reason != "error":
|
||||||
meta["_streamed"] = True
|
meta["_streamed"] = True
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
|
channel=msg.channel,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
|
content=final_content,
|
||||||
metadata=meta,
|
metadata=meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -608,10 +660,9 @@ class AgentLoop:
|
|||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if (
|
if block.get("type") == "image_url" and block.get("image_url", {}).get(
|
||||||
block.get("type") == "image_url"
|
"url", ""
|
||||||
and block.get("image_url", {}).get("url", "").startswith("data:image/")
|
).startswith("data:image/"):
|
||||||
):
|
|
||||||
path = (block.get("_meta") or {}).get("path", "")
|
path = (block.get("_meta") or {}).get("path", "")
|
||||||
filtered.append({"type": "text", "text": image_placeholder_text(path)})
|
filtered.append({"type": "text", "text": image_placeholder_text(path)})
|
||||||
continue
|
continue
|
||||||
@ -630,6 +681,7 @@ class AgentLoop:
|
|||||||
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
|
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
|
||||||
"""Save new-turn messages into session, truncating large tool results."""
|
"""Save new-turn messages into session, truncating large tool results."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
for m in messages[skip:]:
|
for m in messages[skip:]:
|
||||||
entry = dict(m)
|
entry = dict(m)
|
||||||
role, content = entry.get("role"), entry.get("content")
|
role, content = entry.get("role"), entry.get("content")
|
||||||
@ -644,7 +696,9 @@ class AgentLoop:
|
|||||||
continue
|
continue
|
||||||
entry["content"] = filtered
|
entry["content"] = filtered
|
||||||
elif role == "user":
|
elif role == "user":
|
||||||
if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
|
if isinstance(content, str) and content.startswith(
|
||||||
|
ContextBuilder._RUNTIME_CONTEXT_TAG
|
||||||
|
):
|
||||||
# Strip the runtime-context prefix, keep only the user text.
|
# Strip the runtime-context prefix, keep only the user text.
|
||||||
parts = content.split("\n\n", 1)
|
parts = content.split("\n\n", 1)
|
||||||
if len(parts) > 1 and parts[1].strip():
|
if len(parts) > 1 and parts[1].strip():
|
||||||
@ -708,13 +762,15 @@ class AgentLoop:
|
|||||||
continue
|
continue
|
||||||
tool_id = tool_call.get("id")
|
tool_id = tool_call.get("id")
|
||||||
name = ((tool_call.get("function") or {}).get("name")) or "tool"
|
name = ((tool_call.get("function") or {}).get("name")) or "tool"
|
||||||
restored_messages.append({
|
restored_messages.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tool_id,
|
"role": "tool",
|
||||||
"name": name,
|
"tool_call_id": tool_id,
|
||||||
"content": "Error: Task interrupted before this tool finished.",
|
"name": name,
|
||||||
"timestamp": datetime.now().isoformat(),
|
"content": "Error: Task interrupted before this tool finished.",
|
||||||
})
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
overlap = 0
|
overlap = 0
|
||||||
max_overlap = min(len(session.messages), len(restored_messages))
|
max_overlap = min(len(session.messages), len(restored_messages))
|
||||||
@ -746,6 +802,9 @@ class AgentLoop:
|
|||||||
await self._connect_mcp()
|
await self._connect_mcp()
|
||||||
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
||||||
return await self._process_message(
|
return await self._process_message(
|
||||||
msg, session_key=session_key, on_progress=on_progress,
|
msg,
|
||||||
on_stream=on_stream, on_stream_end=on_stream_end,
|
session_key=session_key,
|
||||||
|
on_progress=on_progress,
|
||||||
|
on_stream=on_stream,
|
||||||
|
on_stream_end=on_stream_end,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -57,9 +57,7 @@ def _normalize_schema_for_openai(schema: Any) -> dict[str, Any]:
|
|||||||
|
|
||||||
if "properties" in normalized and isinstance(normalized["properties"], dict):
|
if "properties" in normalized and isinstance(normalized["properties"], dict):
|
||||||
normalized["properties"] = {
|
normalized["properties"] = {
|
||||||
name: _normalize_schema_for_openai(prop)
|
name: _normalize_schema_for_openai(prop) if isinstance(prop, dict) else prop
|
||||||
if isinstance(prop, dict)
|
|
||||||
else prop
|
|
||||||
for name, prop in normalized["properties"].items()
|
for name, prop in normalized["properties"].items()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -138,9 +136,7 @@ class MCPToolWrapper(Tool):
|
|||||||
class MCPResourceWrapper(Tool):
|
class MCPResourceWrapper(Tool):
|
||||||
"""Wraps an MCP resource URI as a read-only nanobot Tool."""
|
"""Wraps an MCP resource URI as a read-only nanobot Tool."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, session, server_name: str, resource_def, resource_timeout: int = 30):
|
||||||
self, session, server_name: str, resource_def, resource_timeout: int = 30
|
|
||||||
):
|
|
||||||
self._session = session
|
self._session = session
|
||||||
self._uri = resource_def.uri
|
self._uri = resource_def.uri
|
||||||
self._name = f"mcp_{server_name}_resource_{resource_def.name}"
|
self._name = f"mcp_{server_name}_resource_{resource_def.name}"
|
||||||
@ -211,9 +207,7 @@ class MCPResourceWrapper(Tool):
|
|||||||
class MCPPromptWrapper(Tool):
|
class MCPPromptWrapper(Tool):
|
||||||
"""Wraps an MCP prompt as a read-only nanobot Tool."""
|
"""Wraps an MCP prompt as a read-only nanobot Tool."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, session, server_name: str, prompt_def, prompt_timeout: int = 30):
|
||||||
self, session, server_name: str, prompt_def, prompt_timeout: int = 30
|
|
||||||
):
|
|
||||||
self._session = session
|
self._session = session
|
||||||
self._prompt_name = prompt_def.name
|
self._prompt_name = prompt_def.name
|
||||||
self._name = f"mcp_{server_name}_prompt_{prompt_def.name}"
|
self._name = f"mcp_{server_name}_prompt_{prompt_def.name}"
|
||||||
@ -266,9 +260,7 @@ class MCPPromptWrapper(Tool):
|
|||||||
timeout=self._prompt_timeout,
|
timeout=self._prompt_timeout,
|
||||||
)
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.warning(
|
logger.warning("MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout)
|
||||||
"MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout
|
|
||||||
)
|
|
||||||
return f"(MCP prompt call timed out after {self._prompt_timeout}s)"
|
return f"(MCP prompt call timed out after {self._prompt_timeout}s)"
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
task = asyncio.current_task()
|
task = asyncio.current_task()
|
||||||
@ -279,13 +271,17 @@ class MCPPromptWrapper(Tool):
|
|||||||
except McpError as exc:
|
except McpError as exc:
|
||||||
logger.error(
|
logger.error(
|
||||||
"MCP prompt '{}' failed: code={} message={}",
|
"MCP prompt '{}' failed: code={} message={}",
|
||||||
self._name, exc.error.code, exc.error.message,
|
self._name,
|
||||||
|
exc.error.code,
|
||||||
|
exc.error.message,
|
||||||
)
|
)
|
||||||
return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])"
|
return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])"
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"MCP prompt '{}' failed: {}: {}",
|
"MCP prompt '{}' failed: {}: {}",
|
||||||
self._name, type(exc).__name__, exc,
|
self._name,
|
||||||
|
type(exc).__name__,
|
||||||
|
exc,
|
||||||
)
|
)
|
||||||
return f"(MCP prompt call failed: {type(exc).__name__})"
|
return f"(MCP prompt call failed: {type(exc).__name__})"
|
||||||
|
|
||||||
@ -307,35 +303,44 @@ class MCPPromptWrapper(Tool):
|
|||||||
|
|
||||||
|
|
||||||
async def connect_mcp_servers(
|
async def connect_mcp_servers(
|
||||||
mcp_servers: dict, registry: ToolRegistry, stack: AsyncExitStack
|
mcp_servers: dict, registry: ToolRegistry
|
||||||
) -> None:
|
) -> dict[str, AsyncExitStack]:
|
||||||
"""Connect to configured MCP servers and register their tools, resources, and prompts."""
|
"""Connect to configured MCP servers and register their tools, resources, prompts.
|
||||||
|
|
||||||
|
Returns a dict mapping server name -> its dedicated AsyncExitStack.
|
||||||
|
Each server gets its own stack and runs in its own task to prevent
|
||||||
|
cancel scope conflicts when multiple MCP servers are configured.
|
||||||
|
"""
|
||||||
from mcp import ClientSession, StdioServerParameters
|
from mcp import ClientSession, StdioServerParameters
|
||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
from mcp.client.stdio import stdio_client
|
from mcp.client.stdio import stdio_client
|
||||||
from mcp.client.streamable_http import streamable_http_client
|
from mcp.client.streamable_http import streamable_http_client
|
||||||
|
|
||||||
for name, cfg in mcp_servers.items():
|
async def connect_single_server(name: str, cfg) -> tuple[str, AsyncExitStack | None]:
|
||||||
|
server_stack = AsyncExitStack()
|
||||||
|
await server_stack.__aenter__()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
transport_type = cfg.type
|
transport_type = cfg.type
|
||||||
if not transport_type:
|
if not transport_type:
|
||||||
if cfg.command:
|
if cfg.command:
|
||||||
transport_type = "stdio"
|
transport_type = "stdio"
|
||||||
elif cfg.url:
|
elif cfg.url:
|
||||||
# Convention: URLs ending with /sse use SSE transport; others use streamableHttp
|
|
||||||
transport_type = (
|
transport_type = (
|
||||||
"sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp"
|
"sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning("MCP server '{}': no command or url configured, skipping", name)
|
logger.warning("MCP server '{}': no command or url configured, skipping", name)
|
||||||
continue
|
await server_stack.aclose()
|
||||||
|
return name, None
|
||||||
|
|
||||||
if transport_type == "stdio":
|
if transport_type == "stdio":
|
||||||
params = StdioServerParameters(
|
params = StdioServerParameters(
|
||||||
command=cfg.command, args=cfg.args, env=cfg.env or None
|
command=cfg.command, args=cfg.args, env=cfg.env or None
|
||||||
)
|
)
|
||||||
read, write = await stack.enter_async_context(stdio_client(params))
|
read, write = await server_stack.enter_async_context(stdio_client(params))
|
||||||
elif transport_type == "sse":
|
elif transport_type == "sse":
|
||||||
|
|
||||||
def httpx_client_factory(
|
def httpx_client_factory(
|
||||||
headers: dict[str, str] | None = None,
|
headers: dict[str, str] | None = None,
|
||||||
timeout: httpx.Timeout | None = None,
|
timeout: httpx.Timeout | None = None,
|
||||||
@ -353,27 +358,26 @@ async def connect_mcp_servers(
|
|||||||
auth=auth,
|
auth=auth,
|
||||||
)
|
)
|
||||||
|
|
||||||
read, write = await stack.enter_async_context(
|
read, write = await server_stack.enter_async_context(
|
||||||
sse_client(cfg.url, httpx_client_factory=httpx_client_factory)
|
sse_client(cfg.url, httpx_client_factory=httpx_client_factory)
|
||||||
)
|
)
|
||||||
elif transport_type == "streamableHttp":
|
elif transport_type == "streamableHttp":
|
||||||
# Always provide an explicit httpx client so MCP HTTP transport does not
|
http_client = await server_stack.enter_async_context(
|
||||||
# inherit httpx's default 5s timeout and preempt the higher-level tool timeout.
|
|
||||||
http_client = await stack.enter_async_context(
|
|
||||||
httpx.AsyncClient(
|
httpx.AsyncClient(
|
||||||
headers=cfg.headers or None,
|
headers=cfg.headers or None,
|
||||||
follow_redirects=True,
|
follow_redirects=True,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
read, write, _ = await stack.enter_async_context(
|
read, write, _ = await server_stack.enter_async_context(
|
||||||
streamable_http_client(cfg.url, http_client=http_client)
|
streamable_http_client(cfg.url, http_client=http_client)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type)
|
logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type)
|
||||||
continue
|
await server_stack.aclose()
|
||||||
|
return name, None
|
||||||
|
|
||||||
session = await stack.enter_async_context(ClientSession(read, write))
|
session = await server_stack.enter_async_context(ClientSession(read, write))
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
|
|
||||||
tools = await session.list_tools()
|
tools = await session.list_tools()
|
||||||
@ -418,7 +422,6 @@ async def connect_mcp_servers(
|
|||||||
", ".join(available_wrapped_names) or "(none)",
|
", ".join(available_wrapped_names) or "(none)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- Register resources ---
|
|
||||||
try:
|
try:
|
||||||
resources_result = await session.list_resources()
|
resources_result = await session.list_resources()
|
||||||
for resource in resources_result.resources:
|
for resource in resources_result.resources:
|
||||||
@ -433,7 +436,6 @@ async def connect_mcp_servers(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("MCP server '{}': resources not supported or failed: {}", name, e)
|
logger.debug("MCP server '{}': resources not supported or failed: {}", name, e)
|
||||||
|
|
||||||
# --- Register prompts ---
|
|
||||||
try:
|
try:
|
||||||
prompts_result = await session.list_prompts()
|
prompts_result = await session.list_prompts()
|
||||||
for prompt in prompts_result.prompts:
|
for prompt in prompts_result.prompts:
|
||||||
@ -442,14 +444,38 @@ async def connect_mcp_servers(
|
|||||||
)
|
)
|
||||||
registry.register(wrapper)
|
registry.register(wrapper)
|
||||||
registered_count += 1
|
registered_count += 1
|
||||||
logger.debug(
|
logger.debug("MCP: registered prompt '{}' from server '{}'", wrapper.name, name)
|
||||||
"MCP: registered prompt '{}' from server '{}'", wrapper.name, name
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("MCP server '{}': prompts not supported or failed: {}", name, e)
|
logger.debug("MCP server '{}': prompts not supported or failed: {}", name, e)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"MCP server '{}': connected, {} capabilities registered", name, registered_count
|
"MCP server '{}': connected, {} capabilities registered", name, registered_count
|
||||||
)
|
)
|
||||||
|
return name, server_stack
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("MCP server '{}': failed to connect: {}", name, e)
|
logger.error("MCP server '{}': failed to connect: {}", name, e)
|
||||||
|
try:
|
||||||
|
await server_stack.aclose()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return name, None
|
||||||
|
|
||||||
|
server_stacks: dict[str, AsyncExitStack] = {}
|
||||||
|
|
||||||
|
tasks: list[asyncio.Task] = []
|
||||||
|
for name, cfg in mcp_servers.items():
|
||||||
|
task = asyncio.create_task(connect_single_server(name, cfg))
|
||||||
|
tasks.append(task)
|
||||||
|
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
name = list(mcp_servers.keys())[i]
|
||||||
|
if isinstance(result, BaseException):
|
||||||
|
if not isinstance(result, asyncio.CancelledError):
|
||||||
|
logger.error("MCP server '{}' connection task failed: {}", name, result)
|
||||||
|
elif result is not None and result[1] is not None:
|
||||||
|
server_stacks[result[0]] = result[1]
|
||||||
|
|
||||||
|
return server_stacks
|
||||||
|
|||||||
@ -271,15 +271,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_raw_names(
|
|||||||
) -> None:
|
) -> None:
|
||||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
stack = AsyncExitStack()
|
stacks = await connect_mcp_servers(
|
||||||
await stack.__aenter__()
|
{"test": MCPServerConfig(command="fake", enabled_tools=["demo"])},
|
||||||
try:
|
registry,
|
||||||
await connect_mcp_servers(
|
)
|
||||||
{"test": MCPServerConfig(command="fake", enabled_tools=["demo"])},
|
for stack in stacks.values():
|
||||||
registry,
|
|
||||||
stack,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
await stack.aclose()
|
await stack.aclose()
|
||||||
|
|
||||||
assert registry.tool_names == ["mcp_test_demo"]
|
assert registry.tool_names == ["mcp_test_demo"]
|
||||||
@ -291,15 +287,11 @@ async def test_connect_mcp_servers_enabled_tools_defaults_to_all(
|
|||||||
) -> None:
|
) -> None:
|
||||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
stack = AsyncExitStack()
|
stacks = await connect_mcp_servers(
|
||||||
await stack.__aenter__()
|
{"test": MCPServerConfig(command="fake")},
|
||||||
try:
|
registry,
|
||||||
await connect_mcp_servers(
|
)
|
||||||
{"test": MCPServerConfig(command="fake")},
|
for stack in stacks.values():
|
||||||
registry,
|
|
||||||
stack,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
await stack.aclose()
|
await stack.aclose()
|
||||||
|
|
||||||
assert registry.tool_names == ["mcp_test_demo", "mcp_test_other"]
|
assert registry.tool_names == ["mcp_test_demo", "mcp_test_other"]
|
||||||
@ -311,15 +303,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_wrapped_names(
|
|||||||
) -> None:
|
) -> None:
|
||||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
stack = AsyncExitStack()
|
stacks = await connect_mcp_servers(
|
||||||
await stack.__aenter__()
|
{"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])},
|
||||||
try:
|
registry,
|
||||||
await connect_mcp_servers(
|
)
|
||||||
{"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])},
|
for stack in stacks.values():
|
||||||
registry,
|
|
||||||
stack,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
await stack.aclose()
|
await stack.aclose()
|
||||||
|
|
||||||
assert registry.tool_names == ["mcp_test_demo"]
|
assert registry.tool_names == ["mcp_test_demo"]
|
||||||
@ -331,15 +319,11 @@ async def test_connect_mcp_servers_enabled_tools_empty_list_registers_none(
|
|||||||
) -> None:
|
) -> None:
|
||||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
stack = AsyncExitStack()
|
stacks = await connect_mcp_servers(
|
||||||
await stack.__aenter__()
|
{"test": MCPServerConfig(command="fake", enabled_tools=[])},
|
||||||
try:
|
registry,
|
||||||
await connect_mcp_servers(
|
)
|
||||||
{"test": MCPServerConfig(command="fake", enabled_tools=[])},
|
for stack in stacks.values():
|
||||||
registry,
|
|
||||||
stack,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
await stack.aclose()
|
await stack.aclose()
|
||||||
|
|
||||||
assert registry.tool_names == []
|
assert registry.tool_names == []
|
||||||
@ -358,15 +342,11 @@ async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries(
|
|||||||
|
|
||||||
monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning)
|
monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning)
|
||||||
|
|
||||||
stack = AsyncExitStack()
|
stacks = await connect_mcp_servers(
|
||||||
await stack.__aenter__()
|
{"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])},
|
||||||
try:
|
registry,
|
||||||
await connect_mcp_servers(
|
)
|
||||||
{"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])},
|
for stack in stacks.values():
|
||||||
registry,
|
|
||||||
stack,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
await stack.aclose()
|
await stack.aclose()
|
||||||
|
|
||||||
assert registry.tool_names == []
|
assert registry.tool_names == []
|
||||||
@ -389,9 +369,7 @@ def _make_resource_def(
|
|||||||
return SimpleNamespace(name=name, uri=uri, description=description)
|
return SimpleNamespace(name=name, uri=uri, description=description)
|
||||||
|
|
||||||
|
|
||||||
def _make_resource_wrapper(
|
def _make_resource_wrapper(session: object, *, timeout: float = 0.1) -> MCPResourceWrapper:
|
||||||
session: object, *, timeout: float = 0.1
|
|
||||||
) -> MCPResourceWrapper:
|
|
||||||
return MCPResourceWrapper(session, "srv", _make_resource_def(), resource_timeout=timeout)
|
return MCPResourceWrapper(session, "srv", _make_resource_def(), resource_timeout=timeout)
|
||||||
|
|
||||||
|
|
||||||
@ -434,9 +412,7 @@ async def test_resource_wrapper_execute_handles_timeout() -> None:
|
|||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
return SimpleNamespace(contents=[])
|
return SimpleNamespace(contents=[])
|
||||||
|
|
||||||
wrapper = _make_resource_wrapper(
|
wrapper = _make_resource_wrapper(SimpleNamespace(read_resource=read_resource), timeout=0.01)
|
||||||
SimpleNamespace(read_resource=read_resource), timeout=0.01
|
|
||||||
)
|
|
||||||
result = await wrapper.execute()
|
result = await wrapper.execute()
|
||||||
assert result == "(MCP resource read timed out after 0.01s)"
|
assert result == "(MCP resource read timed out after 0.01s)"
|
||||||
|
|
||||||
@ -464,20 +440,14 @@ def _make_prompt_def(
|
|||||||
return SimpleNamespace(name=name, description=description, arguments=arguments)
|
return SimpleNamespace(name=name, description=description, arguments=arguments)
|
||||||
|
|
||||||
|
|
||||||
def _make_prompt_wrapper(
|
def _make_prompt_wrapper(session: object, *, timeout: float = 0.1) -> MCPPromptWrapper:
|
||||||
session: object, *, timeout: float = 0.1
|
return MCPPromptWrapper(session, "srv", _make_prompt_def(), prompt_timeout=timeout)
|
||||||
) -> MCPPromptWrapper:
|
|
||||||
return MCPPromptWrapper(
|
|
||||||
session, "srv", _make_prompt_def(), prompt_timeout=timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_wrapper_properties() -> None:
|
def test_prompt_wrapper_properties() -> None:
|
||||||
arg1 = SimpleNamespace(name="topic", required=True)
|
arg1 = SimpleNamespace(name="topic", required=True)
|
||||||
arg2 = SimpleNamespace(name="style", required=False)
|
arg2 = SimpleNamespace(name="style", required=False)
|
||||||
wrapper = MCPPromptWrapper(
|
wrapper = MCPPromptWrapper(None, "myserver", _make_prompt_def(arguments=[arg1, arg2]))
|
||||||
None, "myserver", _make_prompt_def(arguments=[arg1, arg2])
|
|
||||||
)
|
|
||||||
assert wrapper.name == "mcp_myserver_prompt_myprompt"
|
assert wrapper.name == "mcp_myserver_prompt_myprompt"
|
||||||
assert "[MCP Prompt]" in wrapper.description
|
assert "[MCP Prompt]" in wrapper.description
|
||||||
assert "A test prompt" in wrapper.description
|
assert "A test prompt" in wrapper.description
|
||||||
@ -528,9 +498,7 @@ async def test_prompt_wrapper_execute_handles_timeout() -> None:
|
|||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
return SimpleNamespace(messages=[])
|
return SimpleNamespace(messages=[])
|
||||||
|
|
||||||
wrapper = _make_prompt_wrapper(
|
wrapper = _make_prompt_wrapper(SimpleNamespace(get_prompt=get_prompt), timeout=0.01)
|
||||||
SimpleNamespace(get_prompt=get_prompt), timeout=0.01
|
|
||||||
)
|
|
||||||
result = await wrapper.execute()
|
result = await wrapper.execute()
|
||||||
assert result == "(MCP prompt call timed out after 0.01s)"
|
assert result == "(MCP prompt call timed out after 0.01s)"
|
||||||
|
|
||||||
@ -616,15 +584,11 @@ async def test_connect_registers_resources_and_prompts(
|
|||||||
prompt_names=["prompt_c"],
|
prompt_names=["prompt_c"],
|
||||||
)
|
)
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
stack = AsyncExitStack()
|
stacks = await connect_mcp_servers(
|
||||||
await stack.__aenter__()
|
{"test": MCPServerConfig(command="fake")},
|
||||||
try:
|
registry,
|
||||||
await connect_mcp_servers(
|
)
|
||||||
{"test": MCPServerConfig(command="fake")},
|
for stack in stacks.values():
|
||||||
registry,
|
|
||||||
stack,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
await stack.aclose()
|
await stack.aclose()
|
||||||
|
|
||||||
assert "mcp_test_tool_a" in registry.tool_names
|
assert "mcp_test_tool_a" in registry.tool_names
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user