diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index 35ccb1d58a..a8d81e0f37 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -255,6 +255,7 @@ def __init__( self._exit_stack = AsyncExitStack() self._lifecycle_lock = asyncio.Lock() self._lifecycle_request_lock = asyncio.Lock() + self._load_tools_lock = asyncio.Lock() self._lifecycle_queue: asyncio.Queue[tuple[str, bool, asyncio.Future[None]]] | None = None self._lifecycle_owner_task: asyncio.Task[None] | None = None self.session = session @@ -1023,6 +1024,10 @@ async def load_tools(self) -> None: Raises: ToolExecutionException: If the MCP server is not connected. """ + async with self._load_tools_lock: + await self._load_tools_unlocked() + + async def _load_tools_unlocked(self) -> None: from mcp import types # Track existing function names to prevent duplicates diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 0fc5867d79..ce23a93452 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -3202,6 +3202,51 @@ async def mock_list_tools(params=None): assert [f.name for f in tool._functions] == ["tool_1", "tool_2", "tool_3"] +async def test_load_tools_serializes_concurrent_reloads(): + tool = MCPTool(name="test_tool") + + mock_session = AsyncMock() + tool.session = mock_session + + page = Mock() + page.tools = [ + types.Tool( + name="tool_1", + description="First tool", + inputSchema={"type": "object", "properties": {"param": {"type": "string"}}}, + ), + ] + page.nextCursor = None + + first_call_started = asyncio.Event() + release_first_call = asyncio.Event() + call_count = 0 + + async def mock_list_tools(params=None): + nonlocal call_count + assert params is None + call_count += 1 + if call_count == 1: + first_call_started.set() + await release_first_call.wait() + return page + + mock_session.list_tools = AsyncMock(side_effect=mock_list_tools) + + first_reload = asyncio.create_task(tool.load_tools()) + await first_call_started.wait() + + second_reload = asyncio.create_task(tool.load_tools()) + await asyncio.sleep(0) + assert not second_reload.done() + + release_first_call.set() + await asyncio.gather(first_reload, second_reload) + + assert mock_session.list_tools.call_count == 2 + assert [func.name for func in tool._functions] == ["tool_1"] + + async def test_load_prompts_pagination_with_duplicates(): """Test that load_prompts prevents duplicates across paginated results.""" from unittest.mock import AsyncMock, MagicMock