Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 100 additions & 5 deletions python/semantic_kernel/connectors/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import re
import sys
from abc import abstractmethod
from collections.abc import Callable, Sequence
from collections.abc import Awaitable, Callable, Sequence
from contextlib import AbstractAsyncContextManager, AsyncExitStack, _AsyncGeneratorContextManager
from datetime import timedelta
from functools import partial
Expand Down Expand Up @@ -59,6 +59,8 @@

# region: Helpers

SamplingConsentCallback = Callable[[str, types.CreateMessageRequestParams], Awaitable[bool]]

LOG_LEVEL_MAPPING: dict[types.LoggingLevel, int] = {
"debug": logging.DEBUG,
"info": logging.INFO,
Expand Down Expand Up @@ -243,8 +245,23 @@ def __init__(
session: ClientSession | None = None,
kernel: Kernel | None = None,
request_timeout: int | None = None,
sampling_consent_callback: SamplingConsentCallback | None = None,
) -> None:
"""Initialize the MCP Plugin Base."""
"""Initialize the MCP Plugin Base.

Args:
name: The name of the plugin.
description: The description of the plugin.
load_tools: Whether to load tools from the MCP server.
load_prompts: Whether to load prompts from the MCP server.
session: The session to use for the MCP connection.
kernel: The kernel instance with one or more Chat Completion clients.
request_timeout: The default timeout used for all requests.
sampling_consent_callback: Optional callback for approving MCP sampling requests.
Receives the plugin name and MCP sampling request params. Return
False to deny the request. When omitted, sampling requests are
auto-approved and a warning is logged.
"""
self.name = name
self.description = description
self.load_tools_flag = load_tools
Expand All @@ -253,6 +270,9 @@ def __init__(
self.session = session
self.kernel = kernel or None
self.request_timeout = request_timeout
self.sampling_consent_callback = sampling_consent_callback
self._sampling_auto_approved_warning_logged = False
self._mcp_reserved_attribute_names: set[str] | None = None
self._current_task: asyncio.Task | None = None
self._stop_event: asyncio.Event | None = None

Expand All @@ -273,6 +293,14 @@ async def connect(self) -> None:
try:
self._current_task = asyncio.create_task(self._inner_connect(ready_event))
await ready_event.wait()
# If the background task finished before (or exactly when) ready_event was
# set, it means it raised an exception on an error path. Re-raise it here
# so callers always receive the error rather than silently succeeding with a
# broken connection state.
if self._current_task.done():
exc = self._current_task.exception()
if exc is not None:
raise exc
Comment on lines +296 to +303
Comment on lines +300 to +303
except KernelPluginInvalidConfigurationError:
ready_event.clear()
raise
Expand Down Expand Up @@ -316,13 +344,15 @@ async def _inner_connect(self, ready_event: asyncio.Event) -> None:
)
except Exception as ex:
await self._exit_stack.aclose()
ready_event.set()
Comment on lines 346 to +347
raise KernelPluginInvalidConfigurationError(
"Failed to create a session. Please check your configuration."
) from ex
try:
await session.initialize()
except Exception as ex:
await self._exit_stack.aclose()
ready_event.set()
raise KernelPluginInvalidConfigurationError(
"Failed to initialize session. Please check your configuration."
) from ex
Expand Down Expand Up @@ -361,9 +391,24 @@ async def sampling_callback(

This function is called when the MCP server needs to get a message completed.

This is a simple version of this function, it can be overridden to allow more complex sampling.
It get's added to the session at initialization time, so overriding it is the best way to do this.
If a sampling consent callback is configured, it is called before forwarding the request to the configured
chat completion service. Returning False denies the request. If no callback is configured, requests are
auto-approved and a warning is logged.
"""
if self.sampling_consent_callback is None:
if not self._sampling_auto_approved_warning_logged:
logger.warning(
"MCP sampling request for plugin '%s' was auto-approved because no sampling consent callback "
"was configured.",
self.name,
)
self._sampling_auto_approved_warning_logged = True
elif not await self._is_sampling_approved(params):
return types.ErrorData(
code=types.INTERNAL_ERROR,
message="Sampling denied by policy.",
)

if not self.kernel or not self.kernel.services:
return types.ErrorData(
code=types.INTERNAL_ERROR,
Expand Down Expand Up @@ -431,6 +476,15 @@ async def sampling_callback(
model=service.ai_model_id,
)

async def _is_sampling_approved(self, params: types.CreateMessageRequestParams) -> bool:
if self.sampling_consent_callback is None:
return True
try:
return await self.sampling_consent_callback(self.name, params)
except Exception:
logger.exception("MCP sampling consent callback failed for plugin '%s'.", self.name)
return False

async def logging_callback(self, params: types.LoggingMessageNotificationParams) -> None:
"""Callback function for logging.

Expand Down Expand Up @@ -464,6 +518,19 @@ async def message_handler(
case "notifications/prompts/list_changed":
await self.load_prompts()

def _has_mcp_function_name_conflict(self, item_type: str, remote_name: str, local_name: str) -> bool:
if self._mcp_reserved_attribute_names is None:
self._mcp_reserved_attribute_names = set(dir(self))
if local_name not in self._mcp_reserved_attribute_names:
return False
logger.warning(
"Skipping MCP %s '%s' because normalized name '%s' conflicts with an existing plugin attribute.",
item_type,
remote_name,
local_name,
)
return True

async def load_prompts(self):
"""Load prompts from the MCP server."""
try:
Expand All @@ -472,6 +539,8 @@ async def load_prompts(self):
prompt_list = None
for prompt in prompt_list.prompts if prompt_list else []:
local_name = _normalize_mcp_name(prompt.name)
if self._has_mcp_function_name_conflict("prompt", prompt.name, local_name):
continue
func = kernel_function(name=local_name, description=prompt.description)(
partial(self.get_prompt, prompt.name)
)
Expand All @@ -484,9 +553,11 @@ async def load_tools(self):
tool_list = await self.session.list_tools()
except Exception:
tool_list = None
# Create methods with the kernel_function decorator for each tool
# Create methods with the kernel_function decorator for each tool
for tool in tool_list.tools if tool_list else []:
local_name = _normalize_mcp_name(tool.name)
if self._has_mcp_function_name_conflict("tool", tool.name, local_name):
continue
func = kernel_function(name=local_name, description=tool.description)(partial(self.call_tool, tool.name))
func.__kernel_function_parameters__ = _get_parameter_dicts_from_mcp_tool(tool)
setattr(self, local_name, func)
Expand Down Expand Up @@ -558,6 +629,7 @@ def __init__(
env: dict[str, str] | None = None,
encoding: str | None = None,
kernel: Kernel | None = None,
sampling_consent_callback: SamplingConsentCallback | None = None,
**kwargs: Any,
) -> None:
"""Initialize the MCP stdio plugin.
Expand All @@ -579,6 +651,10 @@ def __init__(
env: The environment variables to set for the command.
encoding: The encoding to use for the command output.
kernel: The kernel instance with one or more Chat Completion clients.
sampling_consent_callback: Optional callback for approving MCP sampling requests.
Receives the plugin name and MCP sampling request params. Return
False to deny the request. When omitted, sampling requests are
auto-approved and a warning is logged.
kwargs: Any extra arguments to pass to the stdio client.

"""
Expand All @@ -590,6 +666,7 @@ def __init__(
load_tools=load_tools,
load_prompts=load_prompts,
request_timeout=request_timeout,
sampling_consent_callback=sampling_consent_callback,
)
self.command = command
self.args = args or []
Expand Down Expand Up @@ -628,6 +705,7 @@ def __init__(
timeout: float | None = None,
sse_read_timeout: float | None = None,
kernel: Kernel | None = None,
sampling_consent_callback: SamplingConsentCallback | None = None,
**kwargs: Any,
) -> None:
"""Initialize the MCP sse plugin.
Expand All @@ -650,6 +728,10 @@ def __init__(
timeout: The timeout for the request.
sse_read_timeout: The timeout for reading from the SSE stream.
kernel: The kernel instance with one or more Chat Completion clients.
sampling_consent_callback: Optional callback for approving MCP sampling requests.
Receives the plugin name and MCP sampling request params. Return
False to deny the request. When omitted, sampling requests are
auto-approved and a warning is logged.
kwargs: Any extra arguments to pass to the sse client.

"""
Expand All @@ -661,6 +743,7 @@ def __init__(
load_tools=load_tools,
load_prompts=load_prompts,
request_timeout=request_timeout,
sampling_consent_callback=sampling_consent_callback,
)
self.url = url
self.headers = headers or {}
Expand Down Expand Up @@ -702,6 +785,7 @@ def __init__(
sse_read_timeout: float | None = None,
terminate_on_close: bool | None = None,
kernel: Kernel | None = None,
sampling_consent_callback: SamplingConsentCallback | None = None,
**kwargs: Any,
) -> None:
"""Initialize the MCP streamable http plugin.
Expand All @@ -725,6 +809,10 @@ def __init__(
sse_read_timeout: The timeout for reading from the SSE stream.
terminate_on_close: Close the transport when the MCP client is terminated.
kernel: The kernel instance with one or more Chat Completion clients.
sampling_consent_callback: Optional callback for approving MCP sampling requests.
Receives the plugin name and MCP sampling request params. Return
False to deny the request. When omitted, sampling requests are
auto-approved and a warning is logged.
kwargs: Any extra arguments to pass to the sse client.
"""
super().__init__(
Expand All @@ -735,6 +823,7 @@ def __init__(
load_tools=load_tools,
load_prompts=load_prompts,
request_timeout=request_timeout,
sampling_consent_callback=sampling_consent_callback,
)
self.url = url
self.headers = headers or {}
Expand Down Expand Up @@ -775,6 +864,7 @@ def __init__(
session: ClientSession | None = None,
description: str | None = None,
kernel: Kernel | None = None,
sampling_consent_callback: SamplingConsentCallback | None = None,
**kwargs: Any,
) -> None:
"""Initialize the MCP websocket plugin.
Expand All @@ -794,6 +884,10 @@ def __init__(
session: The session to use for the MCP connection.
description: The description of the plugin.
kernel: The kernel instance with one or more Chat Completion clients.
sampling_consent_callback: Optional callback for approving MCP sampling requests.
Receives the plugin name and MCP sampling request params. Return
False to deny the request. When omitted, sampling requests are
auto-approved and a warning is logged.
kwargs: Any extra arguments to pass to the websocket client.

"""
Expand All @@ -805,6 +899,7 @@ def __init__(
load_tools=load_tools,
load_prompts=load_prompts,
request_timeout=request_timeout,
sampling_consent_callback=sampling_consent_callback,
)
self.url = url
self._client_kwargs = kwargs
Expand Down
48 changes: 48 additions & 0 deletions python/tests/unit/connectors/mcp/test_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,51 @@ async def test_mcp_normalization_function(mock_session, list_tool_calls_with_sla
assert _normalize_mcp_name("weird\\name with spaces") == "weird-name-with-spaces"
assert _normalize_mcp_name("simple_name") == "simple_name"
assert _normalize_mcp_name("Name-With.Dots_And-Hyphens") == "Name-With.Dots_And-Hyphens"


@patch("semantic_kernel.connectors.mcp.stdio_client")
@patch("semantic_kernel.connectors.mcp.ClientSession")
async def test_mcp_plugin_failed_client_session_creation(mock_client_session, mock_stdio_client):
"""Test that connect() raises KernelPluginInvalidConfigurationError when ClientSession creation fails."""
mock_read = MagicMock()
mock_write = MagicMock()

mock_generator = MagicMock()
mock_generator.__aenter__.return_value = (mock_read, mock_write)
mock_generator.__aexit__.return_value = None
mock_stdio_client.return_value = mock_generator

mock_client_session.return_value.__aenter__.side_effect = Exception("ClientSession creation failed")

with pytest.raises(KernelPluginInvalidConfigurationError):
async with MCPStdioPlugin(
name="test",
command="echo",
args=["Hello"],
):
pass


@patch("semantic_kernel.connectors.mcp.stdio_client")
@patch("semantic_kernel.connectors.mcp.ClientSession")
async def test_mcp_plugin_failed_session_initialize(mock_client_session, mock_stdio_client):
"""Test that connect() raises KernelPluginInvalidConfigurationError when session.initialize() fails."""
mock_read = MagicMock()
mock_write = MagicMock()

mock_generator = MagicMock()
mock_generator.__aenter__.return_value = (mock_read, mock_write)
mock_generator.__aexit__.return_value = None
mock_stdio_client.return_value = mock_generator

mock_session_inst = AsyncMock(spec=ClientSession)
mock_session_inst.initialize.side_effect = Exception("Session initialize failed")
mock_client_session.return_value.__aenter__.return_value = mock_session_inst

with pytest.raises(KernelPluginInvalidConfigurationError):
async with MCPStdioPlugin(
name="test",
command="echo",
args=["Hello"],
):
pass
Loading