From 38c4250af3e6883570e66da8cd3bc83c8d94868d Mon Sep 17 00:00:00 2001
From: Yufeng He <40085740+he-yufeng@users.noreply.github.com>
Date: Wed, 20 May 2026 02:58:55 +0800
Subject: [PATCH] Python: avoid duplicate MCP tools during reload

---
 python/packages/core/agent_framework/_mcp.py | 204 ++++++++++---------
 python/packages/core/tests/core/test_mcp.py  |  89 +++++++++
 2 files changed, 196 insertions(+), 97 deletions(-)

diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py
index 35ccb1d58a..b98075cf0b 100644
--- a/python/packages/core/agent_framework/_mcp.py
+++ b/python/packages/core/agent_framework/_mcp.py
@@ -255,6 +255,7 @@ def __init__(
         self._exit_stack = AsyncExitStack()
         self._lifecycle_lock = asyncio.Lock()
         self._lifecycle_request_lock = asyncio.Lock()
+        self._function_load_lock = asyncio.Lock()
         self._lifecycle_queue: asyncio.Queue[tuple[str, bool, asyncio.Future[None]]] | None = None
         self._lifecycle_owner_task: asyncio.Task[None] | None = None
         self.session = session
@@ -975,44 +976,45 @@ async def load_prompts(self) -> None:
         """
         from mcp import types
 
-        # Track existing function names to prevent duplicates
-        existing_names = {func.name for func in self._functions}
-
-        params: types.PaginatedRequestParams | None = None
-        while True:
-            # Ensure connection is still valid before each page request
-            await self._ensure_connected()
-
-            prompt_list = await self.session.list_prompts(params=params)  # type: ignore[union-attr]
-
-            for prompt in prompt_list.prompts:
-                normalized_name = _normalize_mcp_name(prompt.name)
-                local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
-
-                # Skip if already loaded
-                if local_name in existing_names:
-                    continue
-
-                input_model = _get_input_model_from_mcp_prompt(prompt)
-                approval_mode = self._determine_approval_mode(local_name, normalized_name, prompt.name)
-                func: FunctionTool = FunctionTool(
-                    func=partial(self.get_prompt, prompt.name),
-                    name=local_name,
-                    description=prompt.description or "",
-                    approval_mode=approval_mode,
-                    input_model=input_model,
-                    additional_properties={
-                        _MCP_REMOTE_NAME_KEY: prompt.name,
-                        _MCP_NORMALIZED_NAME_KEY: normalized_name,
-                    },
-                )
-                self._functions.append(func)
-                existing_names.add(local_name)
+        async with self._function_load_lock:
+            # Track existing function names to prevent duplicates
+            existing_names = {func.name for func in self._functions}
 
-            # Check if there are more pages
-            if not prompt_list or not prompt_list.nextCursor:
-                break
-            params = types.PaginatedRequestParams(cursor=prompt_list.nextCursor)
+            params: types.PaginatedRequestParams | None = None
+            while True:
+                # Ensure connection is still valid before each page request
+                await self._ensure_connected()
+
+                prompt_list = await self.session.list_prompts(params=params)  # type: ignore[union-attr]
+
+                for prompt in prompt_list.prompts:
+                    normalized_name = _normalize_mcp_name(prompt.name)
+                    local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
+
+                    # Skip if already loaded
+                    if local_name in existing_names:
+                        continue
+
+                    input_model = _get_input_model_from_mcp_prompt(prompt)
+                    approval_mode = self._determine_approval_mode(local_name, normalized_name, prompt.name)
+                    func: FunctionTool = FunctionTool(
+                        func=partial(self.get_prompt, prompt.name),
+                        name=local_name,
+                        description=prompt.description or "",
+                        approval_mode=approval_mode,
+                        input_model=input_model,
+                        additional_properties={
+                            _MCP_REMOTE_NAME_KEY: prompt.name,
+                            _MCP_NORMALIZED_NAME_KEY: normalized_name,
+                        },
+                    )
+                    self._functions.append(func)
+                    existing_names.add(local_name)
+
+                # Check if there are more pages
+                if not prompt_list or not prompt_list.nextCursor:
+                    break
+                params = types.PaginatedRequestParams(cursor=prompt_list.nextCursor)
 
     async def load_tools(self) -> None:
         """Load tools from the MCP server.
@@ -1025,67 +1027,75 @@ async def load_tools(self) -> None:
         """
         from mcp import types
 
-        # Track existing function names to prevent duplicates
-        existing_names = {func.name for func in self._functions}
-        self._tool_call_meta_by_name.clear()
-
-        params: types.PaginatedRequestParams | None = None
-        while True:
-            # Ensure connection is still valid before each page request
-            await self._ensure_connected()
-
-            tool_list = await self.session.list_tools(params=params)  # type: ignore[union-attr]
-
-            for tool in tool_list.tools:
-                if tool.meta is not None:
-                    self._tool_call_meta_by_name[tool.name] = dict(tool.meta)
-
-                normalized_name = _normalize_mcp_name(tool.name)
-                local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
-
-                # Skip if already loaded
-                if local_name in existing_names:
-                    continue
-
-                approval_mode = self._determine_approval_mode(local_name, normalized_name, tool.name)
-                # Normalize inputSchema: ensure "properties" exists for object schemas.
-                # Some MCP servers (e.g. zero-argument tools) omit "properties",
-                # which causes OpenAI API to reject the schema with a 400 error.
-                # Guard against non-conforming MCP servers that send inputSchema=None
-                # despite the MCP spec typing it as dict[str, Any].
-                input_schema = dict(tool.inputSchema or {})
-                if input_schema.get("type") == "object" and "properties" not in input_schema:
-                    input_schema["properties"] = {}
-
-                async def _call_tool_with_runtime_kwargs(
-                    ctx: FunctionInvocationContext,
-                    *,
-                    _remote_tool_name: str = tool.name,
-                    **kwargs: Any,
-                ) -> str | list[Content]:
-                    call_kwargs = dict(ctx.kwargs)
-                    call_kwargs.update(kwargs)
-                    return await self.call_tool(_remote_tool_name, **call_kwargs)
-
-                # Create FunctionTools out of each tool
-                func: FunctionTool = FunctionTool(
-                    func=_call_tool_with_runtime_kwargs,
-                    name=local_name,
-                    description=tool.description or "",
-                    approval_mode=approval_mode,
-                    input_model=input_schema,
-                    additional_properties={
-                        _MCP_REMOTE_NAME_KEY: tool.name,
-                        _MCP_NORMALIZED_NAME_KEY: normalized_name,
-                    },
-                )
-                self._functions.append(func)
-                existing_names.add(local_name)
+        async with self._function_load_lock:
+            # NOTE(review): asyncio.Lock is not reentrant. If _ensure_connected()
+            # can reconnect and re-enter load_tools()/load_prompts(), this would
+            # deadlock while the lock is held -- TODO confirm.
+            # Collect metadata locally and publish it only after every page has
+            # loaded, so a failed or in-flight reload never exposes a partial map.
+            tool_call_meta_by_name: dict[str, dict[str, Any]] = {}
+            # Track existing function names to prevent duplicates
+            existing_names = {func.name for func in self._functions}
+
+            params: types.PaginatedRequestParams | None = None
+            while True:
+                # Ensure connection is still valid before each page request
+                await self._ensure_connected()
+
+                tool_list = await self.session.list_tools(params=params)  # type: ignore[union-attr]
+
+                for tool in tool_list.tools:
+                    if tool.meta is not None:
+                        tool_call_meta_by_name[tool.name] = dict(tool.meta)
+
+                    normalized_name = _normalize_mcp_name(tool.name)
+                    local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
+
+                    # Skip if already loaded
+                    if local_name in existing_names:
+                        continue
+
+                    approval_mode = self._determine_approval_mode(local_name, normalized_name, tool.name)
+                    # Normalize inputSchema: ensure "properties" exists for object schemas.
+                    # Some MCP servers (e.g. zero-argument tools) omit "properties",
+                    # which causes OpenAI API to reject the schema with a 400 error.
+                    # Guard against non-conforming MCP servers that send inputSchema=None
+                    # despite the MCP spec typing it as dict[str, Any].
+                    input_schema = dict(tool.inputSchema or {})
+                    if input_schema.get("type") == "object" and "properties" not in input_schema:
+                        input_schema["properties"] = {}
+
+                    async def _call_tool_with_runtime_kwargs(
+                        ctx: FunctionInvocationContext,
+                        *,
+                        _remote_tool_name: str = tool.name,
+                        **kwargs: Any,
+                    ) -> str | list[Content]:
+                        call_kwargs = dict(ctx.kwargs)
+                        call_kwargs.update(kwargs)
+                        return await self.call_tool(_remote_tool_name, **call_kwargs)
+
+                    # Create FunctionTools out of each tool
+                    func: FunctionTool = FunctionTool(
+                        func=_call_tool_with_runtime_kwargs,
+                        name=local_name,
+                        description=tool.description or "",
+                        approval_mode=approval_mode,
+                        input_model=input_schema,
+                        additional_properties={
+                            _MCP_REMOTE_NAME_KEY: tool.name,
+                            _MCP_NORMALIZED_NAME_KEY: normalized_name,
+                        },
+                    )
+                    self._functions.append(func)
+                    existing_names.add(local_name)
+
+                # Check if there are more pages
+                if not tool_list or not tool_list.nextCursor:
+                    break
+                params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)
 
-            # Check if there are more pages
-            if not tool_list or not tool_list.nextCursor:
-                break
-            params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)
+            self._tool_call_meta_by_name = tool_call_meta_by_name
 
     async def _close_on_owner(self) -> None:
         # Cancel any pending reload tasks before tearing down the session.
diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py
index aea479ff86..d646f5abb6 100644
--- a/python/packages/core/tests/core/test_mcp.py
+++ b/python/packages/core/tests/core/test_mcp.py
@@ -3202,6 +3202,95 @@ async def mock_list_tools(params=None):
     assert [f.name for f in tool._functions] == ["tool_1", "tool_2", "tool_3"]
 
 
+async def test_load_tools_concurrent_reload_does_not_duplicate_tools():
+    from unittest.mock import AsyncMock, MagicMock
+
+    from agent_framework._mcp import MCPTool
+
+    tool = MCPTool(name="test_tool")
+    mock_session = AsyncMock()
+    mock_session.send_ping = AsyncMock(return_value=None)
+    tool.session = mock_session
+    tool.load_tools_flag = True
+
+    page = MagicMock()
+    page.tools = [
+        types.Tool(
+            name="tool_1",
+            description="First tool",
+            inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
+            _meta={"echo": "tool_1"},
+        ),
+    ]
+    page.nextCursor = None
+
+    async def mock_list_tools(params=None):
+        assert params is None
+        await asyncio.sleep(0)
+        return page
+
+    mock_session.list_tools = AsyncMock(side_effect=mock_list_tools)
+
+    await asyncio.wait_for(asyncio.gather(tool.load_tools(), tool.load_tools()), timeout=1)
+
+    assert mock_session.list_tools.call_count == 2
+    assert [f.name for f in tool._functions] == ["tool_1"]
+    assert tool._tool_call_meta_by_name == {"tool_1": {"echo": "tool_1"}}
+
+
+async def test_load_tools_concurrent_paginated_reload_preserves_meta():
+    from unittest.mock import AsyncMock, MagicMock
+
+    from agent_framework._mcp import MCPTool
+
+    tool = MCPTool(name="test_tool")
+    mock_session = AsyncMock()
+    mock_session.send_ping = AsyncMock(return_value=None)
+    tool.session = mock_session
+    tool.load_tools_flag = True
+
+    page1 = MagicMock()
+    page1.tools = [
+        types.Tool(
+            name="tool_1",
+            description="First tool",
+            inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
+            _meta={"echo": "tool_1"},
+        )
+    ]
+    page1.nextCursor = "cursor_page2"
+
+    page2 = MagicMock()
+    page2.tools = [
+        types.Tool(
+            name="tool_2",
+            description="Second tool",
+            inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
+            _meta={"echo": "tool_2"},
+        )
+    ]
+    page2.nextCursor = None
+
+    async def mock_list_tools(params=None):
+        await asyncio.sleep(0)
+        if params is None:
+            return page1
+        if params.cursor == "cursor_page2":
+            return page2
+        raise ValueError("Unexpected cursor value")
+
+    mock_session.list_tools = AsyncMock(side_effect=mock_list_tools)
+
+    await asyncio.wait_for(asyncio.gather(tool.load_tools(), tool.load_tools()), timeout=1)
+
+    assert mock_session.list_tools.call_count == 4
+    assert [f.name for f in tool._functions] == ["tool_1", "tool_2"]
+    assert tool._tool_call_meta_by_name == {
+        "tool_1": {"echo": "tool_1"},
+        "tool_2": {"echo": "tool_2"},
+    }
+
+
 async def test_load_prompts_pagination_with_duplicates():
     """Test that load_prompts prevents duplicates across paginated results."""
     from unittest.mock import AsyncMock, MagicMock