Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 90 additions & 5 deletions python/semantic_kernel/connectors/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import re
import sys
from abc import abstractmethod
from collections.abc import Callable, Sequence
from collections.abc import Awaitable, Callable, Sequence
from contextlib import AbstractAsyncContextManager, AsyncExitStack, _AsyncGeneratorContextManager
from datetime import timedelta
from functools import partial
Expand Down Expand Up @@ -59,6 +59,8 @@

# region: Helpers

SamplingConsentCallback = Callable[[str, types.CreateMessageRequestParams], Awaitable[bool]]

LOG_LEVEL_MAPPING: dict[types.LoggingLevel, int] = {
"debug": logging.DEBUG,
"info": logging.INFO,
Expand Down Expand Up @@ -243,8 +245,23 @@ def __init__(
session: ClientSession | None = None,
kernel: Kernel | None = None,
request_timeout: int | None = None,
sampling_consent_callback: SamplingConsentCallback | None = None,
) -> None:
"""Initialize the MCP Plugin Base."""
"""Initialize the MCP Plugin Base.

Args:
name: The name of the plugin.
description: The description of the plugin.
load_tools: Whether to load tools from the MCP server.
load_prompts: Whether to load prompts from the MCP server.
session: The session to use for the MCP connection.
kernel: The kernel instance with one or more Chat Completion clients.
request_timeout: The default timeout used for all requests.
sampling_consent_callback: Optional callback for approving MCP sampling requests.
Receives the plugin name and MCP sampling request params. Return
False to deny the request. When omitted, sampling requests are
auto-approved and a warning is logged.
"""
self.name = name
self.description = description
self.load_tools_flag = load_tools
Expand All @@ -253,6 +270,9 @@ def __init__(
self.session = session
self.kernel = kernel or None
self.request_timeout = request_timeout
self.sampling_consent_callback = sampling_consent_callback
self._sampling_auto_approved_warning_logged = False
self._mcp_reserved_attribute_names: set[str] | None = None
self._current_task: asyncio.Task | None = None
self._stop_event: asyncio.Event | None = None

Expand Down Expand Up @@ -361,9 +381,24 @@ async def sampling_callback(

This function is called when the MCP server needs to get a message completed.

This is a simple version of this function, it can be overridden to allow more complex sampling.
It get's added to the session at initialization time, so overriding it is the best way to do this.
If a sampling consent callback is configured, it is called before forwarding the request to the configured
chat completion service. Returning False denies the request. If no callback is configured, requests are
auto-approved and a warning is logged.
"""
if self.sampling_consent_callback is None:
if not self._sampling_auto_approved_warning_logged:
logger.warning(
"MCP sampling request for plugin '%s' was auto-approved because no sampling consent callback "
"was configured.",
self.name,
)
self._sampling_auto_approved_warning_logged = True
elif not await self._is_sampling_approved(params):
return types.ErrorData(
code=types.INTERNAL_ERROR,
message="Sampling denied by policy.",
)

if not self.kernel or not self.kernel.services:
return types.ErrorData(
code=types.INTERNAL_ERROR,
Expand Down Expand Up @@ -431,6 +466,15 @@ async def sampling_callback(
model=service.ai_model_id,
)

async def _is_sampling_approved(self, params: types.CreateMessageRequestParams) -> bool:
if self.sampling_consent_callback is None:
return True
try:
return await self.sampling_consent_callback(self.name, params)
except Exception:
logger.exception("MCP sampling consent callback failed for plugin '%s'.", self.name)
return False

async def logging_callback(self, params: types.LoggingMessageNotificationParams) -> None:
"""Callback function for logging.

Expand Down Expand Up @@ -464,6 +508,19 @@ async def message_handler(
case "notifications/prompts/list_changed":
await self.load_prompts()

def _has_mcp_function_name_conflict(self, item_type: str, remote_name: str, local_name: str) -> bool:
if self._mcp_reserved_attribute_names is None:
self._mcp_reserved_attribute_names = set(dir(self))
if local_name not in self._mcp_reserved_attribute_names:
return False
logger.warning(
"Skipping MCP %s '%s' because normalized name '%s' conflicts with an existing plugin attribute.",
item_type,
remote_name,
local_name,
)
return True
Comment thread
moonbox3 marked this conversation as resolved.

async def load_prompts(self):
"""Load prompts from the MCP server."""
try:
Expand All @@ -472,6 +529,8 @@ async def load_prompts(self):
prompt_list = None
for prompt in prompt_list.prompts if prompt_list else []:
local_name = _normalize_mcp_name(prompt.name)
if self._has_mcp_function_name_conflict("prompt", prompt.name, local_name):
continue
func = kernel_function(name=local_name, description=prompt.description)(
partial(self.get_prompt, prompt.name)
)
Expand All @@ -484,9 +543,11 @@ async def load_tools(self):
tool_list = await self.session.list_tools()
except Exception:
tool_list = None
# Create methods with the kernel_function decorator for each tool
# Create methods with the kernel_function decorator for each tool
for tool in tool_list.tools if tool_list else []:
local_name = _normalize_mcp_name(tool.name)
if self._has_mcp_function_name_conflict("tool", tool.name, local_name):
continue
func = kernel_function(name=local_name, description=tool.description)(partial(self.call_tool, tool.name))
func.__kernel_function_parameters__ = _get_parameter_dicts_from_mcp_tool(tool)
setattr(self, local_name, func)
Expand Down Expand Up @@ -558,6 +619,7 @@ def __init__(
env: dict[str, str] | None = None,
encoding: str | None = None,
kernel: Kernel | None = None,
sampling_consent_callback: SamplingConsentCallback | None = None,
**kwargs: Any,
) -> None:
"""Initialize the MCP stdio plugin.
Expand All @@ -579,6 +641,10 @@ def __init__(
env: The environment variables to set for the command.
encoding: The encoding to use for the command output.
kernel: The kernel instance with one or more Chat Completion clients.
sampling_consent_callback: Optional callback for approving MCP sampling requests.
Receives the plugin name and MCP sampling request params. Return
False to deny the request. When omitted, sampling requests are
auto-approved and a warning is logged.
kwargs: Any extra arguments to pass to the stdio client.

"""
Expand All @@ -590,6 +656,7 @@ def __init__(
load_tools=load_tools,
load_prompts=load_prompts,
request_timeout=request_timeout,
sampling_consent_callback=sampling_consent_callback,
)
self.command = command
self.args = args or []
Expand Down Expand Up @@ -628,6 +695,7 @@ def __init__(
timeout: float | None = None,
sse_read_timeout: float | None = None,
kernel: Kernel | None = None,
sampling_consent_callback: SamplingConsentCallback | None = None,
**kwargs: Any,
) -> None:
"""Initialize the MCP sse plugin.
Expand All @@ -650,6 +718,10 @@ def __init__(
timeout: The timeout for the request.
sse_read_timeout: The timeout for reading from the SSE stream.
kernel: The kernel instance with one or more Chat Completion clients.
sampling_consent_callback: Optional callback for approving MCP sampling requests.
Receives the plugin name and MCP sampling request params. Return
False to deny the request. When omitted, sampling requests are
auto-approved and a warning is logged.
kwargs: Any extra arguments to pass to the sse client.

"""
Expand All @@ -661,6 +733,7 @@ def __init__(
load_tools=load_tools,
load_prompts=load_prompts,
request_timeout=request_timeout,
sampling_consent_callback=sampling_consent_callback,
)
self.url = url
self.headers = headers or {}
Expand Down Expand Up @@ -702,6 +775,7 @@ def __init__(
sse_read_timeout: float | None = None,
terminate_on_close: bool | None = None,
kernel: Kernel | None = None,
sampling_consent_callback: SamplingConsentCallback | None = None,
**kwargs: Any,
) -> None:
"""Initialize the MCP streamable http plugin.
Expand All @@ -725,6 +799,10 @@ def __init__(
sse_read_timeout: The timeout for reading from the SSE stream.
terminate_on_close: Close the transport when the MCP client is terminated.
kernel: The kernel instance with one or more Chat Completion clients.
sampling_consent_callback: Optional callback for approving MCP sampling requests.
Receives the plugin name and MCP sampling request params. Return
False to deny the request. When omitted, sampling requests are
auto-approved and a warning is logged.
kwargs: Any extra arguments to pass to the sse client.
"""
super().__init__(
Expand All @@ -735,6 +813,7 @@ def __init__(
load_tools=load_tools,
load_prompts=load_prompts,
request_timeout=request_timeout,
sampling_consent_callback=sampling_consent_callback,
)
self.url = url
self.headers = headers or {}
Expand Down Expand Up @@ -775,6 +854,7 @@ def __init__(
session: ClientSession | None = None,
description: str | None = None,
kernel: Kernel | None = None,
sampling_consent_callback: SamplingConsentCallback | None = None,
**kwargs: Any,
) -> None:
"""Initialize the MCP websocket plugin.
Expand All @@ -794,6 +874,10 @@ def __init__(
session: The session to use for the MCP connection.
description: The description of the plugin.
kernel: The kernel instance with one or more Chat Completion clients.
sampling_consent_callback: Optional callback for approving MCP sampling requests.
Receives the plugin name and MCP sampling request params. Return
False to deny the request. When omitted, sampling requests are
auto-approved and a warning is logged.
kwargs: Any extra arguments to pass to the websocket client.

"""
Expand All @@ -805,6 +889,7 @@ def __init__(
load_tools=load_tools,
load_prompts=load_prompts,
request_timeout=request_timeout,
sampling_consent_callback=sampling_consent_callback,
)
self.url = url
self._client_kwargs = kwargs
Expand Down
115 changes: 115 additions & 0 deletions python/tests/unit/connectors/mcp/test_mcp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.

import logging
import re
from typing import TYPE_CHECKING
from unittest.mock import AsyncMock, MagicMock, patch
Expand Down Expand Up @@ -89,6 +90,120 @@ async def test_mcp_plugin_session_initialized(plugin_class, plugin_args):
assert not mock_session.initialize.called


async def test_mcp_sampling_denied_by_consent_callback():
sampling_consent_callback = AsyncMock(return_value=False)
plugin = MCPSsePlugin(
name="TestMCPPlugin",
url="http://localhost:8080/sse",
sampling_consent_callback=sampling_consent_callback,
)
params = types.CreateMessageRequestParams(
messages=[types.SamplingMessage(role="user", content=types.TextContent(type="text", text="hello"))],
systemPrompt="server instructions",
maxTokens=100,
)

result = await plugin.sampling_callback(MagicMock(), params)

sampling_consent_callback.assert_awaited_once_with("TestMCPPlugin", params)
assert isinstance(result, types.ErrorData)
assert result.message == "Sampling denied by policy."


async def test_mcp_sampling_consent_callback_error_denies_request(caplog):
sampling_consent_callback = AsyncMock(side_effect=RuntimeError("policy failure"))
plugin = MCPSsePlugin(
name="TestMCPPlugin",
url="http://localhost:8080/sse",
sampling_consent_callback=sampling_consent_callback,
)
params = types.CreateMessageRequestParams(
messages=[types.SamplingMessage(role="user", content=types.TextContent(type="text", text="hello"))],
systemPrompt="server instructions",
maxTokens=100,
)

with caplog.at_level(logging.ERROR, logger="semantic_kernel.connectors.mcp"):
result = await plugin.sampling_callback(MagicMock(), params)

sampling_consent_callback.assert_awaited_once_with("TestMCPPlugin", params)
assert isinstance(result, types.ErrorData)
assert result.message == "Sampling denied by policy."
assert "MCP sampling consent callback failed" in caplog.text


async def test_mcp_sampling_without_consent_callback_logs_auto_approve_warning(caplog):
plugin = MCPSsePlugin(name="TestMCPPlugin", url="http://localhost:8080/sse")
params = types.CreateMessageRequestParams(
messages=[types.SamplingMessage(role="user", content=types.TextContent(type="text", text="hello"))],
systemPrompt="server instructions",
maxTokens=100,
)

with caplog.at_level(logging.WARNING, logger="semantic_kernel.connectors.mcp"):
result = await plugin.sampling_callback(MagicMock(), params)

assert isinstance(result, types.ErrorData)
assert "auto-approved because no sampling consent callback was configured" in caplog.text


async def test_mcp_tool_and_prompt_names_do_not_shadow_plugin_attributes():
kernel = MagicMock()
plugin = MCPSsePlugin(name="TestMCPPlugin", url="http://localhost:8080/sse", kernel=kernel)
session = AsyncMock(spec=ClientSession)
session.list_tools.return_value = ListToolsResult(
tools=[
Tool(name="kernel", description="reserved", inputSchema={}),
Tool(name="safe_tool", description="safe", inputSchema={}),
]
)
session.list_prompts.return_value = types.ListPromptsResult(
prompts=[
types.Prompt(name="session", description="reserved", arguments=[]),
types.Prompt(name="safe_prompt", description="safe", arguments=[]),
]
)
plugin.session = session

await plugin.load_tools()

assert plugin.kernel is kernel
assert hasattr(plugin, "safe_tool")

await plugin.load_prompts()

assert plugin.session is session
assert hasattr(plugin, "safe_prompt")


async def test_mcp_tool_and_prompt_names_can_reload_existing_mcp_functions():
plugin = MCPSsePlugin(name="TestMCPPlugin", url="http://localhost:8080/sse")
session = AsyncMock(spec=ClientSession)
session.list_tools.side_effect = [
ListToolsResult(tools=[Tool(name="safe_tool", description="first tool", inputSchema={})]),
ListToolsResult(tools=[Tool(name="safe_tool", description="second tool", inputSchema={})]),
]
session.list_prompts.side_effect = [
types.ListPromptsResult(prompts=[types.Prompt(name="safe_prompt", description="first prompt", arguments=[])]),
types.ListPromptsResult(prompts=[types.Prompt(name="safe_prompt", description="second prompt", arguments=[])]),
]
plugin.session = session

await plugin.load_tools()
first_tool = plugin.safe_tool
await plugin.load_tools()

assert plugin.safe_tool is not first_tool
assert plugin.safe_tool.__kernel_function_description__ == "second tool"

await plugin.load_prompts()
first_prompt = plugin.safe_prompt
await plugin.load_prompts()

assert plugin.safe_prompt is not first_prompt
assert plugin.safe_prompt.__kernel_function_description__ == "second prompt"


async def test_mcp_plugin_failed_get_session():
with (
patch("semantic_kernel.connectors.mcp.stdio_client") as mock_stdio_client,
Expand Down
Loading