fix(mcp): forward prompt arg descriptions & standardise error format

- Propagate `description` from MCP prompt arguments into the JSON
  Schema so LLMs can better understand prompt parameters.
- Align generic-exception error message with tool/resource wrappers
  (drop redundant `{exc}` detail).
- Extend test fixture to mock `mcp.shared.exceptions.McpError`.
- Add tests for argument description forwarding and McpError handling.

Made-with: Cursor
This commit is contained in:
Xubin Ren 2026-04-07 16:25:20 +00:00 committed by Xubin Ren
parent 7cc527cf65
commit 8871a57b4c
2 changed files with 39 additions and 2 deletions

View File

@ -228,7 +228,10 @@ class MCPPromptWrapper(Tool):
properties: dict[str, Any] = {}
required: list[str] = []
for arg in prompt_def.arguments or []:
properties[arg.name] = {"type": "string"}
prop: dict[str, Any] = {"type": "string"}
if getattr(arg, "description", None):
prop["description"] = arg.description
properties[arg.name] = prop
if arg.required:
required.append(arg.name)
self._parameters: dict[str, Any] = {
@ -284,7 +287,7 @@ class MCPPromptWrapper(Tool):
"MCP prompt '{}' failed: {}: {}",
self._name, type(exc).__name__, exc,
)
return f"(MCP prompt call failed: {type(exc).__name__}: {exc})"
return f"(MCP prompt call failed: {type(exc).__name__})"
parts: list[str] = []
for message in result.messages:

View File

@ -93,6 +93,18 @@ def _fake_mcp_module(
monkeypatch.setitem(sys.modules, "mcp.client.sse", sse_mod)
monkeypatch.setitem(sys.modules, "mcp.client.streamable_http", streamable_http_mod)
shared_mod = ModuleType("mcp.shared")
exc_mod = ModuleType("mcp.shared.exceptions")
class _FakeMcpError(Exception):
def __init__(self, code: int = -1, message: str = "error"):
self.error = SimpleNamespace(code=code, message=message)
super().__init__(message)
exc_mod.McpError = _FakeMcpError
monkeypatch.setitem(sys.modules, "mcp.shared", shared_mod)
monkeypatch.setitem(sys.modules, "mcp.shared.exceptions", exc_mod)
def _make_wrapper(session: object, *, timeout: float = 0.1) -> MCPToolWrapper:
tool_def = SimpleNamespace(
@ -481,6 +493,15 @@ def test_prompt_wrapper_no_arguments() -> None:
assert wrapper.parameters == {"type": "object", "properties": {}, "required": []}
def test_prompt_wrapper_preserves_argument_descriptions() -> None:
arg = SimpleNamespace(name="topic", required=True, description="The subject to discuss")
wrapper = MCPPromptWrapper(None, "srv", _make_prompt_def(arguments=[arg]))
assert wrapper.parameters["properties"]["topic"] == {
"type": "string",
"description": "The subject to discuss",
}
@pytest.mark.asyncio
async def test_prompt_wrapper_execute_returns_text() -> None:
async def get_prompt(name: str, arguments: dict | None = None) -> object:
@ -514,6 +535,19 @@ async def test_prompt_wrapper_execute_handles_timeout() -> None:
assert result == "(MCP prompt call timed out after 0.01s)"
@pytest.mark.asyncio
async def test_prompt_wrapper_execute_handles_mcp_error() -> None:
from mcp.shared.exceptions import McpError
async def get_prompt(name: str, arguments: dict | None = None) -> object:
raise McpError(code=42, message="invalid argument")
wrapper = _make_prompt_wrapper(SimpleNamespace(get_prompt=get_prompt))
result = await wrapper.execute()
assert "invalid argument" in result
assert "code 42" in result
@pytest.mark.asyncio
async def test_prompt_wrapper_execute_handles_error() -> None:
async def get_prompt(name: str, arguments: dict | None = None) -> object: