Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 102 additions & 97 deletions python/packages/core/agent_framework/_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ def __init__(
self._exit_stack = AsyncExitStack()
self._lifecycle_lock = asyncio.Lock()
self._lifecycle_request_lock = asyncio.Lock()
self._function_load_lock = asyncio.Lock()
self._lifecycle_queue: asyncio.Queue[tuple[str, bool, asyncio.Future[None]]] | None = None
self._lifecycle_owner_task: asyncio.Task[None] | None = None
self.session = session
Expand Down Expand Up @@ -975,44 +976,45 @@ async def load_prompts(self) -> None:
"""
from mcp import types

# Track existing function names to prevent duplicates
existing_names = {func.name for func in self._functions}

params: types.PaginatedRequestParams | None = None
while True:
# Ensure connection is still valid before each page request
await self._ensure_connected()

prompt_list = await self.session.list_prompts(params=params) # type: ignore[union-attr]

for prompt in prompt_list.prompts:
normalized_name = _normalize_mcp_name(prompt.name)
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)

# Skip if already loaded
if local_name in existing_names:
continue

input_model = _get_input_model_from_mcp_prompt(prompt)
approval_mode = self._determine_approval_mode(local_name, normalized_name, prompt.name)
func: FunctionTool = FunctionTool(
func=partial(self.get_prompt, prompt.name),
name=local_name,
description=prompt.description or "",
approval_mode=approval_mode,
input_model=input_model,
additional_properties={
_MCP_REMOTE_NAME_KEY: prompt.name,
_MCP_NORMALIZED_NAME_KEY: normalized_name,
},
)
self._functions.append(func)
existing_names.add(local_name)
async with self._function_load_lock:
# Track existing function names to prevent duplicates
existing_names = {func.name for func in self._functions}

# Check if there are more pages
if not prompt_list or not prompt_list.nextCursor:
break
params = types.PaginatedRequestParams(cursor=prompt_list.nextCursor)
params: types.PaginatedRequestParams | None = None
while True:
# Ensure connection is still valid before each page request
await self._ensure_connected()

prompt_list = await self.session.list_prompts(params=params) # type: ignore[union-attr]

for prompt in prompt_list.prompts:
normalized_name = _normalize_mcp_name(prompt.name)
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)

# Skip if already loaded
if local_name in existing_names:
continue

input_model = _get_input_model_from_mcp_prompt(prompt)
approval_mode = self._determine_approval_mode(local_name, normalized_name, prompt.name)
func: FunctionTool = FunctionTool(
func=partial(self.get_prompt, prompt.name),
name=local_name,
description=prompt.description or "",
approval_mode=approval_mode,
input_model=input_model,
additional_properties={
_MCP_REMOTE_NAME_KEY: prompt.name,
_MCP_NORMALIZED_NAME_KEY: normalized_name,
},
)
self._functions.append(func)
existing_names.add(local_name)

# Check if there are more pages
if not prompt_list or not prompt_list.nextCursor:
break
params = types.PaginatedRequestParams(cursor=prompt_list.nextCursor)

async def load_tools(self) -> None:
"""Load tools from the MCP server.
Expand All @@ -1025,67 +1027,70 @@ async def load_tools(self) -> None:
"""
from mcp import types

# Track existing function names to prevent duplicates
existing_names = {func.name for func in self._functions}
self._tool_call_meta_by_name.clear()

params: types.PaginatedRequestParams | None = None
while True:
# Ensure connection is still valid before each page request
await self._ensure_connected()

tool_list = await self.session.list_tools(params=params) # type: ignore[union-attr]

for tool in tool_list.tools:
if tool.meta is not None:
self._tool_call_meta_by_name[tool.name] = dict(tool.meta)

normalized_name = _normalize_mcp_name(tool.name)
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)

# Skip if already loaded
if local_name in existing_names:
continue

approval_mode = self._determine_approval_mode(local_name, normalized_name, tool.name)
# Normalize inputSchema: ensure "properties" exists for object schemas.
# Some MCP servers (e.g. zero-argument tools) omit "properties",
# which causes OpenAI API to reject the schema with a 400 error.
# Guard against non-conforming MCP servers that send inputSchema=None
# despite the MCP spec typing it as dict[str, Any].
input_schema = dict(tool.inputSchema or {})
if input_schema.get("type") == "object" and "properties" not in input_schema:
input_schema["properties"] = {}

async def _call_tool_with_runtime_kwargs(
ctx: FunctionInvocationContext,
*,
_remote_tool_name: str = tool.name,
**kwargs: Any,
) -> str | list[Content]:
call_kwargs = dict(ctx.kwargs)
call_kwargs.update(kwargs)
return await self.call_tool(_remote_tool_name, **call_kwargs)

# Create FunctionTools out of each tool
func: FunctionTool = FunctionTool(
func=_call_tool_with_runtime_kwargs,
name=local_name,
description=tool.description or "",
approval_mode=approval_mode,
input_model=input_schema,
additional_properties={
_MCP_REMOTE_NAME_KEY: tool.name,
_MCP_NORMALIZED_NAME_KEY: normalized_name,
},
)
self._functions.append(func)
existing_names.add(local_name)
async with self._function_load_lock:
tool_call_meta_by_name: dict[str, dict[str, Any]] = {}
# Track existing function names to prevent duplicates
existing_names = {func.name for func in self._functions}

params: types.PaginatedRequestParams | None = None
while True:
# Ensure connection is still valid before each page request
await self._ensure_connected()

tool_list = await self.session.list_tools(params=params) # type: ignore[union-attr]

for tool in tool_list.tools:
if tool.meta is not None:
tool_call_meta_by_name[tool.name] = dict(tool.meta)

normalized_name = _normalize_mcp_name(tool.name)
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)

# Skip if already loaded
if local_name in existing_names:
continue

approval_mode = self._determine_approval_mode(local_name, normalized_name, tool.name)
# Normalize inputSchema: ensure "properties" exists for object schemas.
# Some MCP servers (e.g. zero-argument tools) omit "properties",
# which causes OpenAI API to reject the schema with a 400 error.
# Guard against non-conforming MCP servers that send inputSchema=None
# despite the MCP spec typing it as dict[str, Any].
input_schema = dict(tool.inputSchema or {})
if input_schema.get("type") == "object" and "properties" not in input_schema:
input_schema["properties"] = {}

async def _call_tool_with_runtime_kwargs(
ctx: FunctionInvocationContext,
*,
_remote_tool_name: str = tool.name,
**kwargs: Any,
) -> str | list[Content]:
call_kwargs = dict(ctx.kwargs)
call_kwargs.update(kwargs)
return await self.call_tool(_remote_tool_name, **call_kwargs)

# Create FunctionTools out of each tool
func: FunctionTool = FunctionTool(
func=_call_tool_with_runtime_kwargs,
name=local_name,
description=tool.description or "",
approval_mode=approval_mode,
input_model=input_schema,
additional_properties={
_MCP_REMOTE_NAME_KEY: tool.name,
_MCP_NORMALIZED_NAME_KEY: normalized_name,
},
)
self._functions.append(func)
existing_names.add(local_name)

# Check if there are more pages
if not tool_list or not tool_list.nextCursor:
break
params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)

# Check if there are more pages
if not tool_list or not tool_list.nextCursor:
break
params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)
self._tool_call_meta_by_name = tool_call_meta_by_name

async def _close_on_owner(self) -> None:
# Cancel any pending reload tasks before tearing down the session.
Expand Down
89 changes: 89 additions & 0 deletions python/packages/core/tests/core/test_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3202,6 +3202,95 @@ async def mock_list_tools(params=None):
assert [f.name for f in tool._functions] == ["tool_1", "tool_2", "tool_3"]


async def test_load_tools_concurrent_reload_does_not_duplicate_tools():
from unittest.mock import AsyncMock, MagicMock

from agent_framework._mcp import MCPTool

tool = MCPTool(name="test_tool")
mock_session = AsyncMock()
mock_session.send_ping = AsyncMock(return_value=None)
tool.session = mock_session
tool.load_tools_flag = True

page = MagicMock()
page.tools = [
types.Tool(
name="tool_1",
description="First tool",
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
_meta={"echo": "tool_1"},
),
]
page.nextCursor = None

async def mock_list_tools(params=None):
assert params is None
await asyncio.sleep(0)
return page

mock_session.list_tools = AsyncMock(side_effect=mock_list_tools)

await asyncio.wait_for(asyncio.gather(tool.load_tools(), tool.load_tools()), timeout=1)

assert mock_session.list_tools.call_count == 2
assert [f.name for f in tool._functions] == ["tool_1"]
assert tool._tool_call_meta_by_name == {"tool_1": {"echo": "tool_1"}}


async def test_load_tools_concurrent_paginated_reload_preserves_meta():
from unittest.mock import AsyncMock, MagicMock

from agent_framework._mcp import MCPTool

tool = MCPTool(name="test_tool")
mock_session = AsyncMock()
mock_session.send_ping = AsyncMock(return_value=None)
tool.session = mock_session
tool.load_tools_flag = True

page1 = MagicMock()
page1.tools = [
types.Tool(
name="tool_1",
description="First tool",
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
_meta={"echo": "tool_1"},
)
]
page1.nextCursor = "cursor_page2"

page2 = MagicMock()
page2.tools = [
types.Tool(
name="tool_2",
description="Second tool",
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
_meta={"echo": "tool_2"},
)
]
page2.nextCursor = None

async def mock_list_tools(params=None):
await asyncio.sleep(0)
if params is None:
return page1
if params.cursor == "cursor_page2":
return page2
raise ValueError("Unexpected cursor value")

mock_session.list_tools = AsyncMock(side_effect=mock_list_tools)

await asyncio.wait_for(asyncio.gather(tool.load_tools(), tool.load_tools()), timeout=1)

assert mock_session.list_tools.call_count == 4
assert [f.name for f in tool._functions] == ["tool_1", "tool_2"]
assert tool._tool_call_meta_by_name == {
"tool_1": {"echo": "tool_1"},
"tool_2": {"echo": "tool_2"},
}


async def test_load_prompts_pagination_with_duplicates():
"""Test that load_prompts prevents duplicates across paginated results."""
from unittest.mock import AsyncMock, MagicMock
Expand Down