diff --git a/python/semantic_kernel/connectors/mcp.py b/python/semantic_kernel/connectors/mcp.py index 5f31886b28fa..6d7f8d2e182d 100644 --- a/python/semantic_kernel/connectors/mcp.py +++ b/python/semantic_kernel/connectors/mcp.py @@ -6,7 +6,7 @@ import re import sys from abc import abstractmethod -from collections.abc import Callable, Sequence +from collections.abc import Awaitable, Callable, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack, _AsyncGeneratorContextManager from datetime import timedelta from functools import partial @@ -59,6 +59,8 @@ # region: Helpers +SamplingConsentCallback = Callable[[str, types.CreateMessageRequestParams], Awaitable[bool]] + LOG_LEVEL_MAPPING: dict[types.LoggingLevel, int] = { "debug": logging.DEBUG, "info": logging.INFO, @@ -243,8 +245,23 @@ def __init__( session: ClientSession | None = None, kernel: Kernel | None = None, request_timeout: int | None = None, + sampling_consent_callback: SamplingConsentCallback | None = None, ) -> None: - """Initialize the MCP Plugin Base.""" + """Initialize the MCP Plugin Base. + + Args: + name: The name of the plugin. + description: The description of the plugin. + load_tools: Whether to load tools from the MCP server. + load_prompts: Whether to load prompts from the MCP server. + session: The session to use for the MCP connection. + kernel: The kernel instance with one or more Chat Completion clients. + request_timeout: The default timeout used for all requests. + sampling_consent_callback: Optional callback for approving MCP sampling requests. + Receives the plugin name and MCP sampling request params. Return + False to deny the request. When omitted, sampling requests are + auto-approved and a warning is logged. + """ self.name = name self.description = description self.load_tools_flag = load_tools @@ -253,6 +270,9 @@ def __init__( self.session = session self.kernel = kernel or None self.request_timeout = request_timeout + self.sampling_consent_callback = sampling_consent_callback + self._sampling_auto_approved_warning_logged = False + self._mcp_reserved_attribute_names: set[str] | None = None self._current_task: asyncio.Task | None = None self._stop_event: asyncio.Event | None = None @@ -361,9 +381,24 @@ async def sampling_callback( This function is called when the MCP server needs to get a message completed. - This is a simple version of this function, it can be overridden to allow more complex sampling. - It get's added to the session at initialization time, so overriding it is the best way to do this. + If a sampling consent callback is configured, it is called before forwarding the request to the configured + chat completion service. Returning False denies the request. If no callback is configured, requests are + auto-approved and a warning is logged. """ + if self.sampling_consent_callback is None: + if not self._sampling_auto_approved_warning_logged: + logger.warning( + "MCP sampling request for plugin '%s' was auto-approved because no sampling consent callback " + "was configured.", + self.name, + ) + self._sampling_auto_approved_warning_logged = True + elif not await self._is_sampling_approved(params): + return types.ErrorData( + code=types.INTERNAL_ERROR, + message="Sampling denied by policy.", + ) + if not self.kernel or not self.kernel.services: return types.ErrorData( code=types.INTERNAL_ERROR, @@ -431,6 +466,15 @@ async def sampling_callback( model=service.ai_model_id, ) + async def _is_sampling_approved(self, params: types.CreateMessageRequestParams) -> bool: + if self.sampling_consent_callback is None: + return True + try: + return await self.sampling_consent_callback(self.name, params) + except Exception: + logger.exception("MCP sampling consent callback failed for plugin '%s'.", self.name) + return False + async def logging_callback(self, params: types.LoggingMessageNotificationParams) -> None: """Callback function for logging. @@ -464,6 +508,19 @@ async def message_handler( case "notifications/prompts/list_changed": await self.load_prompts() + def _has_mcp_function_name_conflict(self, item_type: str, remote_name: str, local_name: str) -> bool: + if self._mcp_reserved_attribute_names is None: + self._mcp_reserved_attribute_names = set(dir(self)) + if local_name not in self._mcp_reserved_attribute_names: + return False + logger.warning( + "Skipping MCP %s '%s' because normalized name '%s' conflicts with an existing plugin attribute.", + item_type, + remote_name, + local_name, + ) + return True + async def load_prompts(self): """Load prompts from the MCP server.""" try: @@ -472,6 +529,8 @@ async def load_prompts(self): prompt_list = None for prompt in prompt_list.prompts if prompt_list else []: local_name = _normalize_mcp_name(prompt.name) + if self._has_mcp_function_name_conflict("prompt", prompt.name, local_name): + continue func = kernel_function(name=local_name, description=prompt.description)( partial(self.get_prompt, prompt.name) ) @@ -484,9 +543,11 @@ async def load_tools(self): tool_list = await self.session.list_tools() except Exception: tool_list = None - # Create methods with the kernel_function decorator for each tool + # Create methods with the kernel_function decorator for each tool for tool in tool_list.tools if tool_list else []: local_name = _normalize_mcp_name(tool.name) + if self._has_mcp_function_name_conflict("tool", tool.name, local_name): + continue func = kernel_function(name=local_name, description=tool.description)(partial(self.call_tool, tool.name)) func.__kernel_function_parameters__ = _get_parameter_dicts_from_mcp_tool(tool) setattr(self, local_name, func) @@ -558,6 +619,7 @@ def __init__( env: dict[str, str] | None = None, encoding: str | None = None, kernel: Kernel | None = None, + sampling_consent_callback: SamplingConsentCallback | None = None, **kwargs: Any, ) -> None: """Initialize the MCP stdio plugin. @@ -579,6 +641,10 @@ def __init__( env: The environment variables to set for the command. encoding: The encoding to use for the command output. kernel: The kernel instance with one or more Chat Completion clients. + sampling_consent_callback: Optional callback for approving MCP sampling requests. + Receives the plugin name and MCP sampling request params. Return + False to deny the request. When omitted, sampling requests are + auto-approved and a warning is logged. kwargs: Any extra arguments to pass to the stdio client. """ @@ -590,6 +656,7 @@ def __init__( load_tools=load_tools, load_prompts=load_prompts, request_timeout=request_timeout, + sampling_consent_callback=sampling_consent_callback, ) self.command = command self.args = args or [] @@ -628,6 +695,7 @@ def __init__( timeout: float | None = None, sse_read_timeout: float | None = None, kernel: Kernel | None = None, + sampling_consent_callback: SamplingConsentCallback | None = None, **kwargs: Any, ) -> None: """Initialize the MCP sse plugin. @@ -650,6 +718,10 @@ def __init__( timeout: The timeout for the request. sse_read_timeout: The timeout for reading from the SSE stream. kernel: The kernel instance with one or more Chat Completion clients. + sampling_consent_callback: Optional callback for approving MCP sampling requests. + Receives the plugin name and MCP sampling request params. Return + False to deny the request. When omitted, sampling requests are + auto-approved and a warning is logged. kwargs: Any extra arguments to pass to the sse client. """ @@ -661,6 +733,7 @@ def __init__( load_tools=load_tools, load_prompts=load_prompts, request_timeout=request_timeout, + sampling_consent_callback=sampling_consent_callback, ) self.url = url self.headers = headers or {} @@ -702,6 +775,7 @@ def __init__( sse_read_timeout: float | None = None, terminate_on_close: bool | None = None, kernel: Kernel | None = None, + sampling_consent_callback: SamplingConsentCallback | None = None, **kwargs: Any, ) -> None: """Initialize the MCP streamable http plugin. @@ -725,6 +799,10 @@ def __init__( sse_read_timeout: The timeout for reading from the SSE stream. terminate_on_close: Close the transport when the MCP client is terminated. kernel: The kernel instance with one or more Chat Completion clients. + sampling_consent_callback: Optional callback for approving MCP sampling requests. + Receives the plugin name and MCP sampling request params. Return + False to deny the request. When omitted, sampling requests are + auto-approved and a warning is logged. kwargs: Any extra arguments to pass to the sse client. """ super().__init__( @@ -735,6 +813,7 @@ def __init__( load_tools=load_tools, load_prompts=load_prompts, request_timeout=request_timeout, + sampling_consent_callback=sampling_consent_callback, ) self.url = url self.headers = headers or {} @@ -775,6 +854,7 @@ def __init__( session: ClientSession | None = None, description: str | None = None, kernel: Kernel | None = None, + sampling_consent_callback: SamplingConsentCallback | None = None, **kwargs: Any, ) -> None: """Initialize the MCP websocket plugin. @@ -794,6 +874,10 @@ def __init__( session: The session to use for the MCP connection. description: The description of the plugin. kernel: The kernel instance with one or more Chat Completion clients. + sampling_consent_callback: Optional callback for approving MCP sampling requests. + Receives the plugin name and MCP sampling request params. Return + False to deny the request. When omitted, sampling requests are + auto-approved and a warning is logged. kwargs: Any extra arguments to pass to the websocket client. """ @@ -805,6 +889,7 @@ def __init__( load_tools=load_tools, load_prompts=load_prompts, request_timeout=request_timeout, + sampling_consent_callback=sampling_consent_callback, ) self.url = url self._client_kwargs = kwargs diff --git a/python/tests/unit/connectors/mcp/test_mcp.py b/python/tests/unit/connectors/mcp/test_mcp.py index 55ca71313574..dc8ea38330d3 100644 --- a/python/tests/unit/connectors/mcp/test_mcp.py +++ b/python/tests/unit/connectors/mcp/test_mcp.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. +import logging import re from typing import TYPE_CHECKING from unittest.mock import AsyncMock, MagicMock, patch @@ -89,6 +90,120 @@ async def test_mcp_plugin_session_initialized(plugin_class, plugin_args): assert not mock_session.initialize.called +async def test_mcp_sampling_denied_by_consent_callback(): + sampling_consent_callback = AsyncMock(return_value=False) + plugin = MCPSsePlugin( + name="TestMCPPlugin", + url="http://localhost:8080/sse", + sampling_consent_callback=sampling_consent_callback, + ) + params = types.CreateMessageRequestParams( + messages=[types.SamplingMessage(role="user", content=types.TextContent(type="text", text="hello"))], + systemPrompt="server instructions", + maxTokens=100, + ) + + result = await plugin.sampling_callback(MagicMock(), params) + + sampling_consent_callback.assert_awaited_once_with("TestMCPPlugin", params) + assert isinstance(result, types.ErrorData) + assert result.message == "Sampling denied by policy." + + +async def test_mcp_sampling_consent_callback_error_denies_request(caplog): + sampling_consent_callback = AsyncMock(side_effect=RuntimeError("policy failure")) + plugin = MCPSsePlugin( + name="TestMCPPlugin", + url="http://localhost:8080/sse", + sampling_consent_callback=sampling_consent_callback, + ) + params = types.CreateMessageRequestParams( + messages=[types.SamplingMessage(role="user", content=types.TextContent(type="text", text="hello"))], + systemPrompt="server instructions", + maxTokens=100, + ) + + with caplog.at_level(logging.ERROR, logger="semantic_kernel.connectors.mcp"): + result = await plugin.sampling_callback(MagicMock(), params) + + sampling_consent_callback.assert_awaited_once_with("TestMCPPlugin", params) + assert isinstance(result, types.ErrorData) + assert result.message == "Sampling denied by policy." + assert "MCP sampling consent callback failed" in caplog.text + + +async def test_mcp_sampling_without_consent_callback_logs_auto_approve_warning(caplog): + plugin = MCPSsePlugin(name="TestMCPPlugin", url="http://localhost:8080/sse") + params = types.CreateMessageRequestParams( + messages=[types.SamplingMessage(role="user", content=types.TextContent(type="text", text="hello"))], + systemPrompt="server instructions", + maxTokens=100, + ) + + with caplog.at_level(logging.WARNING, logger="semantic_kernel.connectors.mcp"): + result = await plugin.sampling_callback(MagicMock(), params) + + assert isinstance(result, types.ErrorData) + assert "auto-approved because no sampling consent callback was configured" in caplog.text + + +async def test_mcp_tool_and_prompt_names_do_not_shadow_plugin_attributes(): + kernel = MagicMock() + plugin = MCPSsePlugin(name="TestMCPPlugin", url="http://localhost:8080/sse", kernel=kernel) + session = AsyncMock(spec=ClientSession) + session.list_tools.return_value = ListToolsResult( + tools=[ + Tool(name="kernel", description="reserved", inputSchema={}), + Tool(name="safe_tool", description="safe", inputSchema={}), + ] + ) + session.list_prompts.return_value = types.ListPromptsResult( + prompts=[ + types.Prompt(name="session", description="reserved", arguments=[]), + types.Prompt(name="safe_prompt", description="safe", arguments=[]), + ] + ) + plugin.session = session + + await plugin.load_tools() + + assert plugin.kernel is kernel + assert hasattr(plugin, "safe_tool") + + await plugin.load_prompts() + + assert plugin.session is session + assert hasattr(plugin, "safe_prompt") + + +async def test_mcp_tool_and_prompt_names_can_reload_existing_mcp_functions(): + plugin = MCPSsePlugin(name="TestMCPPlugin", url="http://localhost:8080/sse") + session = AsyncMock(spec=ClientSession) + session.list_tools.side_effect = [ + ListToolsResult(tools=[Tool(name="safe_tool", description="first tool", inputSchema={})]), + ListToolsResult(tools=[Tool(name="safe_tool", description="second tool", inputSchema={})]), + ] + session.list_prompts.side_effect = [ + types.ListPromptsResult(prompts=[types.Prompt(name="safe_prompt", description="first prompt", arguments=[])]), + types.ListPromptsResult(prompts=[types.Prompt(name="safe_prompt", description="second prompt", arguments=[])]), + ] + plugin.session = session + + await plugin.load_tools() + first_tool = plugin.safe_tool + await plugin.load_tools() + + assert plugin.safe_tool is not first_tool + assert plugin.safe_tool.__kernel_function_description__ == "second tool" + + await plugin.load_prompts() + first_prompt = plugin.safe_prompt + await plugin.load_prompts() + + assert plugin.safe_prompt is not first_prompt + assert plugin.safe_prompt.__kernel_function_description__ == "second prompt" + + async def test_mcp_plugin_failed_get_session(): with ( patch("semantic_kernel.connectors.mcp.stdio_client") as mock_stdio_client,