mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 14:23:58 +00:00
fix: cover MCP reconnect edge cases
maintainer edit: handle prompt sessions that report Connection closed outside McpError, and match reconnect registration prefixes with the same sanitization used by MCP wrapper names.
This commit is contained in:
parent
e9145b7acd
commit
d0eba7cd9d
@ -514,6 +514,13 @@ class MCPPromptWrapper(_MCPWrapperBase):
|
||||
)
|
||||
return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])"
|
||||
except Exception as exc:
|
||||
if await self._refresh_session_after_termination(
|
||||
exc,
|
||||
refreshed_session,
|
||||
"prompt",
|
||||
):
|
||||
refreshed_session = True
|
||||
continue
|
||||
if _is_transient(exc):
|
||||
if not retried_transient:
|
||||
retried_transient = True
|
||||
@ -1066,10 +1073,7 @@ def _server_signature(cfg: Any) -> Any:
|
||||
|
||||
|
||||
def _tool_prefix(server_name: str) -> str:
|
||||
safe_name = "".join(ch if ch.isalnum() or ch in {"_", "-"} else "_" for ch in server_name)
|
||||
while "__" in safe_name:
|
||||
safe_name = safe_name.replace("__", "_")
|
||||
return f"mcp_{safe_name}_"
|
||||
return _sanitize_name(f"mcp_{server_name}_")
|
||||
|
||||
|
||||
def _unregister_server_tools(state: Any, registry: ToolRegistry, server_name: str) -> int:
|
||||
|
||||
@ -288,6 +288,55 @@ async def test_mcp_tool_reconnects_after_session_terminated(
|
||||
assert loop.tools.get("mcp_remote_quote") is not old_tool
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_reconnect_handler_uses_sanitized_server_prefix(
|
||||
tmp_path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
loop = _make_loop(tmp_path, mcp_servers={"remote_": object()})
|
||||
connect_count = 0
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, index: int) -> None:
|
||||
self.index = index
|
||||
|
||||
async def call_tool(self, _name: str, arguments: dict[str, Any]) -> Any:
|
||||
assert arguments == {}
|
||||
if self.index == 1:
|
||||
raise McpError(ErrorData(code=-32000, message="Session terminated"))
|
||||
return SimpleNamespace(
|
||||
content=[mcp_types.TextContent(type="text", text="recovered")]
|
||||
)
|
||||
|
||||
async def _fake_connect(servers, registry):
|
||||
nonlocal connect_count
|
||||
stacks = {}
|
||||
for name in servers:
|
||||
connect_count += 1
|
||||
tool_def = SimpleNamespace(
|
||||
name="quote",
|
||||
description="quote tool",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
)
|
||||
registry.register(MCPToolWrapper(_FakeSession(connect_count), name, tool_def))
|
||||
stack = AsyncExitStack()
|
||||
await stack.__aenter__()
|
||||
stacks[name] = stack
|
||||
return stacks
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _fake_connect)
|
||||
|
||||
await loop._connect_mcp()
|
||||
old_tool = loop.tools.get("mcp_remote_quote")
|
||||
assert isinstance(old_tool, MCPToolWrapper)
|
||||
|
||||
output = await old_tool.execute()
|
||||
|
||||
assert output == "recovered"
|
||||
assert connect_count == 2
|
||||
assert loop.tools.get("mcp_remote_quote") is not old_tool
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_mcp_reconnect_reuses_fresh_session(
|
||||
tmp_path,
|
||||
|
||||
@ -464,3 +464,29 @@ async def test_prompt_reconnects_on_session_terminated():
|
||||
assert output == "fresh prompt"
|
||||
assert old_session.get_prompt.call_count == 1
|
||||
assert new_session.get_prompt.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_reconnects_on_connection_closed_exception():
|
||||
"""Prompt should reconnect when the SDK reports a closed session as a generic exception."""
|
||||
old_session = AsyncMock()
|
||||
old_session.get_prompt = AsyncMock(side_effect=RuntimeError("Connection closed"))
|
||||
new_session = AsyncMock()
|
||||
new_session.get_prompt = AsyncMock(return_value=_make_prompt_result("fresh prompt"))
|
||||
|
||||
wrapper = MCPPromptWrapper(old_session, "test_server", _make_prompt_def())
|
||||
replacement = MCPPromptWrapper(new_session, "test_server", _make_prompt_def())
|
||||
|
||||
async def reconnect(server_name: str, tool_name: str, stale_tool):
|
||||
assert server_name == "test_server"
|
||||
assert tool_name == "mcp_test_server_prompt_test_prompt"
|
||||
assert stale_tool is wrapper
|
||||
return replacement
|
||||
|
||||
wrapper.set_reconnect_handler(reconnect)
|
||||
|
||||
output = await wrapper.execute()
|
||||
|
||||
assert output == "fresh prompt"
|
||||
assert old_session.get_prompt.call_count == 1
|
||||
assert new_session.get_prompt.call_count == 1
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user