From c6a745bc145eab036fa8e0a5f81d3eea3ac63863 Mon Sep 17 00:00:00 2001 From: Nilesh Patil <128893479+nileshpatil6@users.noreply.github.com> Date: Sat, 16 May 2026 20:10:57 +0530 Subject: [PATCH 1/2] fix(python): unblock connect() when ClientSession or session.initialize() fails When _inner_connect() failed to create a ClientSession or to call session.initialize(), it raised KernelPluginInvalidConfigurationError without first calling ready_event.set(). The connect() method was waiting on that event in a background task, so it would hang indefinitely (until an external timeout fired, e.g. 30 s in PromptFlow). The transport-failure branch already called ready_event.set() before raising, so the pattern was consistent there but missing in the two later branches. Fix: - Add ready_event.set() before raise in the ClientSession creation exception handler. - Add ready_event.set() before raise in the session.initialize() exception handler. - After await ready_event.wait() in connect(), check whether the background task finished with an exception and re-raise it so that callers (including __aenter__) receive the error instead of silently succeeding with a broken state. Fixes #13414 --- python/semantic_kernel/connectors/mcp.py | 105 +++++++++++++++++++++-- 1 file changed, 100 insertions(+), 5 deletions(-) diff --git a/python/semantic_kernel/connectors/mcp.py b/python/semantic_kernel/connectors/mcp.py index 5f31886b28fa..79dccd8c1767 100644 --- a/python/semantic_kernel/connectors/mcp.py +++ b/python/semantic_kernel/connectors/mcp.py @@ -6,7 +6,7 @@ import re import sys from abc import abstractmethod -from collections.abc import Callable, Sequence +from collections.abc import Awaitable, Callable, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack, _AsyncGeneratorContextManager from datetime import timedelta from functools import partial @@ -59,6 +59,8 @@ # region: Helpers +SamplingConsentCallback = Callable[[str, types.CreateMessageRequestParams], Awaitable[bool]] + LOG_LEVEL_MAPPING: dict[types.LoggingLevel, int] = { "debug": logging.DEBUG, "info": logging.INFO, @@ -243,8 +245,23 @@ def __init__( session: ClientSession | None = None, kernel: Kernel | None = None, request_timeout: int | None = None, + sampling_consent_callback: SamplingConsentCallback | None = None, ) -> None: - """Initialize the MCP Plugin Base.""" + """Initialize the MCP Plugin Base. + + Args: + name: The name of the plugin. + description: The description of the plugin. + load_tools: Whether to load tools from the MCP server. + load_prompts: Whether to load prompts from the MCP server. + session: The session to use for the MCP connection. + kernel: The kernel instance with one or more Chat Completion clients. + request_timeout: The default timeout used for all requests. + sampling_consent_callback: Optional callback for approving MCP sampling requests. + Receives the plugin name and MCP sampling request params. Return + False to deny the request. When omitted, sampling requests are + auto-approved and a warning is logged. + """ self.name = name self.description = description self.load_tools_flag = load_tools @@ -253,6 +270,9 @@ def __init__( self.session = session self.kernel = kernel or None self.request_timeout = request_timeout + self.sampling_consent_callback = sampling_consent_callback + self._sampling_auto_approved_warning_logged = False + self._mcp_reserved_attribute_names: set[str] | None = None self._current_task: asyncio.Task | None = None self._stop_event: asyncio.Event | None = None @@ -273,6 +293,14 @@ async def connect(self) -> None: try: self._current_task = asyncio.create_task(self._inner_connect(ready_event)) await ready_event.wait() + # If the background task finished before (or exactly when) ready_event was + # set, it means it raised an exception on an error path. Re-raise it here + # so callers always receive the error rather than silently succeeding with a + # broken connection state. + if self._current_task.done(): + exc = self._current_task.exception() + if exc is not None: + raise exc except KernelPluginInvalidConfigurationError: ready_event.clear() raise @@ -316,6 +344,7 @@ async def _inner_connect(self, ready_event: asyncio.Event) -> None: ) except Exception as ex: await self._exit_stack.aclose() + ready_event.set() raise KernelPluginInvalidConfigurationError( "Failed to create a session. Please check your configuration." ) from ex @@ -323,6 +352,7 @@ async def _inner_connect(self, ready_event: asyncio.Event) -> None: await session.initialize() except Exception as ex: await self._exit_stack.aclose() + ready_event.set() raise KernelPluginInvalidConfigurationError( "Failed to initialize session. Please check your configuration." ) from ex @@ -361,9 +391,24 @@ async def sampling_callback( This function is called when the MCP server needs to get a message completed. - This is a simple version of this function, it can be overridden to allow more complex sampling. - It get's added to the session at initialization time, so overriding it is the best way to do this. + If a sampling consent callback is configured, it is called before forwarding the request to the configured + chat completion service. Returning False denies the request. If no callback is configured, requests are + auto-approved and a warning is logged. """ + if self.sampling_consent_callback is None: + if not self._sampling_auto_approved_warning_logged: + logger.warning( + "MCP sampling request for plugin '%s' was auto-approved because no sampling consent callback " + "was configured.", + self.name, + ) + self._sampling_auto_approved_warning_logged = True + elif not await self._is_sampling_approved(params): + return types.ErrorData( + code=types.INTERNAL_ERROR, + message="Sampling denied by policy.", + ) + if not self.kernel or not self.kernel.services: return types.ErrorData( code=types.INTERNAL_ERROR, @@ -431,6 +476,15 @@ async def sampling_callback( model=service.ai_model_id, ) + async def _is_sampling_approved(self, params: types.CreateMessageRequestParams) -> bool: + if self.sampling_consent_callback is None: + return True + try: + return await self.sampling_consent_callback(self.name, params) + except Exception: + logger.exception("MCP sampling consent callback failed for plugin '%s'.", self.name) + return False + async def logging_callback(self, params: types.LoggingMessageNotificationParams) -> None: """Callback function for logging. @@ -464,6 +518,19 @@ async def message_handler( case "notifications/prompts/list_changed": await self.load_prompts() + def _has_mcp_function_name_conflict(self, item_type: str, remote_name: str, local_name: str) -> bool: + if self._mcp_reserved_attribute_names is None: + self._mcp_reserved_attribute_names = set(dir(self)) + if local_name not in self._mcp_reserved_attribute_names: + return False + logger.warning( + "Skipping MCP %s '%s' because normalized name '%s' conflicts with an existing plugin attribute.", + item_type, + remote_name, + local_name, + ) + return True + async def load_prompts(self): """Load prompts from the MCP server.""" try: @@ -472,6 +539,8 @@ async def load_prompts(self): prompt_list = None for prompt in prompt_list.prompts if prompt_list else []: local_name = _normalize_mcp_name(prompt.name) + if self._has_mcp_function_name_conflict("prompt", prompt.name, local_name): + continue func = kernel_function(name=local_name, description=prompt.description)( partial(self.get_prompt, prompt.name) ) @@ -484,9 +553,11 @@ async def load_tools(self): tool_list = await self.session.list_tools() except Exception: tool_list = None - # Create methods with the kernel_function decorator for each tool + # Create methods with the kernel_function decorator for each tool for tool in tool_list.tools if tool_list else []: local_name = _normalize_mcp_name(tool.name) + if self._has_mcp_function_name_conflict("tool", tool.name, local_name): + continue func = kernel_function(name=local_name, description=tool.description)(partial(self.call_tool, tool.name)) func.__kernel_function_parameters__ = _get_parameter_dicts_from_mcp_tool(tool) setattr(self, local_name, func) @@ -558,6 +629,7 @@ def __init__( env: dict[str, str] | None = None, encoding: str | None = None, kernel: Kernel | None = None, + sampling_consent_callback: SamplingConsentCallback | None = None, **kwargs: Any, ) -> None: """Initialize the MCP stdio plugin. @@ -579,6 +651,10 @@ def __init__( env: The environment variables to set for the command. encoding: The encoding to use for the command output. kernel: The kernel instance with one or more Chat Completion clients. + sampling_consent_callback: Optional callback for approving MCP sampling requests. + Receives the plugin name and MCP sampling request params. Return + False to deny the request. When omitted, sampling requests are + auto-approved and a warning is logged. kwargs: Any extra arguments to pass to the stdio client. """ @@ -590,6 +666,7 @@ def __init__( load_tools=load_tools, load_prompts=load_prompts, request_timeout=request_timeout, + sampling_consent_callback=sampling_consent_callback, ) self.command = command self.args = args or [] @@ -628,6 +705,7 @@ def __init__( timeout: float | None = None, sse_read_timeout: float | None = None, kernel: Kernel | None = None, + sampling_consent_callback: SamplingConsentCallback | None = None, **kwargs: Any, ) -> None: """Initialize the MCP sse plugin. @@ -650,6 +728,10 @@ def __init__( timeout: The timeout for the request. sse_read_timeout: The timeout for reading from the SSE stream. kernel: The kernel instance with one or more Chat Completion clients. + sampling_consent_callback: Optional callback for approving MCP sampling requests. + Receives the plugin name and MCP sampling request params. Return + False to deny the request. When omitted, sampling requests are + auto-approved and a warning is logged. kwargs: Any extra arguments to pass to the sse client. """ @@ -661,6 +743,7 @@ def __init__( load_tools=load_tools, load_prompts=load_prompts, request_timeout=request_timeout, + sampling_consent_callback=sampling_consent_callback, ) self.url = url self.headers = headers or {} @@ -702,6 +785,7 @@ def __init__( sse_read_timeout: float | None = None, terminate_on_close: bool | None = None, kernel: Kernel | None = None, + sampling_consent_callback: SamplingConsentCallback | None = None, **kwargs: Any, ) -> None: """Initialize the MCP streamable http plugin. @@ -725,6 +809,10 @@ def __init__( sse_read_timeout: The timeout for reading from the SSE stream. terminate_on_close: Close the transport when the MCP client is terminated. kernel: The kernel instance with one or more Chat Completion clients. + sampling_consent_callback: Optional callback for approving MCP sampling requests. + Receives the plugin name and MCP sampling request params. Return + False to deny the request. When omitted, sampling requests are + auto-approved and a warning is logged. kwargs: Any extra arguments to pass to the sse client. """ super().__init__( @@ -735,6 +823,7 @@ def __init__( load_tools=load_tools, load_prompts=load_prompts, request_timeout=request_timeout, + sampling_consent_callback=sampling_consent_callback, ) self.url = url self.headers = headers or {} @@ -775,6 +864,7 @@ def __init__( session: ClientSession | None = None, description: str | None = None, kernel: Kernel | None = None, + sampling_consent_callback: SamplingConsentCallback | None = None, **kwargs: Any, ) -> None: """Initialize the MCP websocket plugin. @@ -794,6 +884,10 @@ def __init__( session: The session to use for the MCP connection. description: The description of the plugin. kernel: The kernel instance with one or more Chat Completion clients. + sampling_consent_callback: Optional callback for approving MCP sampling requests. + Receives the plugin name and MCP sampling request params. Return + False to deny the request. When omitted, sampling requests are + auto-approved and a warning is logged. kwargs: Any extra arguments to pass to the websocket client. """ @@ -805,6 +899,7 @@ def __init__( load_tools=load_tools, load_prompts=load_prompts, request_timeout=request_timeout, + sampling_consent_callback=sampling_consent_callback, ) self.url = url self._client_kwargs = kwargs From ebc566abab68d50d722c33281a155b819ef05069 Mon Sep 17 00:00:00 2001 From: Nilesh Patil <128893479+nileshpatil6@users.noreply.github.com> Date: Sat, 16 May 2026 21:14:00 +0530 Subject: [PATCH 2/2] test(mcp): add tests for ClientSession and initialize() failure paths --- python/tests/unit/connectors/mcp/test_mcp.py | 48 ++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/python/tests/unit/connectors/mcp/test_mcp.py b/python/tests/unit/connectors/mcp/test_mcp.py index 55ca71313574..f96dd5df9fd7 100644 --- a/python/tests/unit/connectors/mcp/test_mcp.py +++ b/python/tests/unit/connectors/mcp/test_mcp.py @@ -291,3 +291,51 @@ async def test_mcp_normalization_function(mock_session, list_tool_calls_with_sla assert _normalize_mcp_name("weird\\name with spaces") == "weird-name-with-spaces" assert _normalize_mcp_name("simple_name") == "simple_name" assert _normalize_mcp_name("Name-With.Dots_And-Hyphens") == "Name-With.Dots_And-Hyphens" + + +@patch("semantic_kernel.connectors.mcp.stdio_client") +@patch("semantic_kernel.connectors.mcp.ClientSession") +async def test_mcp_plugin_failed_client_session_creation(mock_client_session, mock_stdio_client): + """Test that connect() raises KernelPluginInvalidConfigurationError when ClientSession creation fails.""" + mock_read = MagicMock() + mock_write = MagicMock() + + mock_generator = MagicMock() + mock_generator.__aenter__.return_value = (mock_read, mock_write) + mock_generator.__aexit__.return_value = None + mock_stdio_client.return_value = mock_generator + + mock_client_session.return_value.__aenter__.side_effect = Exception("ClientSession creation failed") + + with pytest.raises(KernelPluginInvalidConfigurationError): + async with MCPStdioPlugin( + name="test", + command="echo", + args=["Hello"], + ): + pass + + +@patch("semantic_kernel.connectors.mcp.stdio_client") +@patch("semantic_kernel.connectors.mcp.ClientSession") +async def test_mcp_plugin_failed_session_initialize(mock_client_session, mock_stdio_client): + """Test that connect() raises KernelPluginInvalidConfigurationError when session.initialize() fails.""" + mock_read = MagicMock() + mock_write = MagicMock() + + mock_generator = MagicMock() + mock_generator.__aenter__.return_value = (mock_read, mock_write) + mock_generator.__aexit__.return_value = None + mock_stdio_client.return_value = mock_generator + + mock_session_inst = AsyncMock(spec=ClientSession) + mock_session_inst.initialize.side_effect = Exception("Session initialize failed") + mock_client_session.return_value.__aenter__.return_value = mock_session_inst + + with pytest.raises(KernelPluginInvalidConfigurationError): + async with MCPStdioPlugin( + name="test", + command="echo", + args=["Hello"], + ): + pass