mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-14 14:54:06 +00:00
fix: cover MCP reconnect edge cases
maintainer edit: handle prompt sessions that report Connection closed outside McpError, and match reconnect registration prefixes with the same sanitization used by MCP wrapper names.
This commit is contained in:
parent
e9145b7acd
commit
d0eba7cd9d
@ -514,6 +514,13 @@ class MCPPromptWrapper(_MCPWrapperBase):
|
|||||||
)
|
)
|
||||||
return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])"
|
return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])"
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
if await self._refresh_session_after_termination(
|
||||||
|
exc,
|
||||||
|
refreshed_session,
|
||||||
|
"prompt",
|
||||||
|
):
|
||||||
|
refreshed_session = True
|
||||||
|
continue
|
||||||
if _is_transient(exc):
|
if _is_transient(exc):
|
||||||
if not retried_transient:
|
if not retried_transient:
|
||||||
retried_transient = True
|
retried_transient = True
|
||||||
@ -1066,10 +1073,7 @@ def _server_signature(cfg: Any) -> Any:
|
|||||||
|
|
||||||
|
|
||||||
def _tool_prefix(server_name: str) -> str:
|
def _tool_prefix(server_name: str) -> str:
|
||||||
safe_name = "".join(ch if ch.isalnum() or ch in {"_", "-"} else "_" for ch in server_name)
|
return _sanitize_name(f"mcp_{server_name}_")
|
||||||
while "__" in safe_name:
|
|
||||||
safe_name = safe_name.replace("__", "_")
|
|
||||||
return f"mcp_{safe_name}_"
|
|
||||||
|
|
||||||
|
|
||||||
def _unregister_server_tools(state: Any, registry: ToolRegistry, server_name: str) -> int:
|
def _unregister_server_tools(state: Any, registry: ToolRegistry, server_name: str) -> int:
|
||||||
|
|||||||
@ -288,6 +288,55 @@ async def test_mcp_tool_reconnects_after_session_terminated(
|
|||||||
assert loop.tools.get("mcp_remote_quote") is not old_tool
|
assert loop.tools.get("mcp_remote_quote") is not old_tool
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mcp_reconnect_handler_uses_sanitized_server_prefix(
|
||||||
|
tmp_path,
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
):
|
||||||
|
loop = _make_loop(tmp_path, mcp_servers={"remote_": object()})
|
||||||
|
connect_count = 0
|
||||||
|
|
||||||
|
class _FakeSession:
|
||||||
|
def __init__(self, index: int) -> None:
|
||||||
|
self.index = index
|
||||||
|
|
||||||
|
async def call_tool(self, _name: str, arguments: dict[str, Any]) -> Any:
|
||||||
|
assert arguments == {}
|
||||||
|
if self.index == 1:
|
||||||
|
raise McpError(ErrorData(code=-32000, message="Session terminated"))
|
||||||
|
return SimpleNamespace(
|
||||||
|
content=[mcp_types.TextContent(type="text", text="recovered")]
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _fake_connect(servers, registry):
|
||||||
|
nonlocal connect_count
|
||||||
|
stacks = {}
|
||||||
|
for name in servers:
|
||||||
|
connect_count += 1
|
||||||
|
tool_def = SimpleNamespace(
|
||||||
|
name="quote",
|
||||||
|
description="quote tool",
|
||||||
|
inputSchema={"type": "object", "properties": {}},
|
||||||
|
)
|
||||||
|
registry.register(MCPToolWrapper(_FakeSession(connect_count), name, tool_def))
|
||||||
|
stack = AsyncExitStack()
|
||||||
|
await stack.__aenter__()
|
||||||
|
stacks[name] = stack
|
||||||
|
return stacks
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _fake_connect)
|
||||||
|
|
||||||
|
await loop._connect_mcp()
|
||||||
|
old_tool = loop.tools.get("mcp_remote_quote")
|
||||||
|
assert isinstance(old_tool, MCPToolWrapper)
|
||||||
|
|
||||||
|
output = await old_tool.execute()
|
||||||
|
|
||||||
|
assert output == "recovered"
|
||||||
|
assert connect_count == 2
|
||||||
|
assert loop.tools.get("mcp_remote_quote") is not old_tool
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_concurrent_mcp_reconnect_reuses_fresh_session(
|
async def test_concurrent_mcp_reconnect_reuses_fresh_session(
|
||||||
tmp_path,
|
tmp_path,
|
||||||
|
|||||||
@ -464,3 +464,29 @@ async def test_prompt_reconnects_on_session_terminated():
|
|||||||
assert output == "fresh prompt"
|
assert output == "fresh prompt"
|
||||||
assert old_session.get_prompt.call_count == 1
|
assert old_session.get_prompt.call_count == 1
|
||||||
assert new_session.get_prompt.call_count == 1
|
assert new_session.get_prompt.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prompt_reconnects_on_connection_closed_exception():
|
||||||
|
"""Prompt should reconnect when the SDK reports a closed session as a generic exception."""
|
||||||
|
old_session = AsyncMock()
|
||||||
|
old_session.get_prompt = AsyncMock(side_effect=RuntimeError("Connection closed"))
|
||||||
|
new_session = AsyncMock()
|
||||||
|
new_session.get_prompt = AsyncMock(return_value=_make_prompt_result("fresh prompt"))
|
||||||
|
|
||||||
|
wrapper = MCPPromptWrapper(old_session, "test_server", _make_prompt_def())
|
||||||
|
replacement = MCPPromptWrapper(new_session, "test_server", _make_prompt_def())
|
||||||
|
|
||||||
|
async def reconnect(server_name: str, tool_name: str, stale_tool):
|
||||||
|
assert server_name == "test_server"
|
||||||
|
assert tool_name == "mcp_test_server_prompt_test_prompt"
|
||||||
|
assert stale_tool is wrapper
|
||||||
|
return replacement
|
||||||
|
|
||||||
|
wrapper.set_reconnect_handler(reconnect)
|
||||||
|
|
||||||
|
output = await wrapper.execute()
|
||||||
|
|
||||||
|
assert output == "fresh prompt"
|
||||||
|
assert old_session.get_prompt.call_count == 1
|
||||||
|
assert new_session.get_prompt.call_count == 1
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user