# NOTE: `_call_tool` below instantiates `FunctionInvocationContext`, so that name must be
# imported at module level at runtime (`from ._middleware import FunctionInvocationContext`),
# not only under `TYPE_CHECKING`.
def as_progressive_tools(
    self,
    list_tool_name: str = "list_mcp_tools",
    call_tool_name: str = "call_mcp",
) -> list[FunctionTool]:
    """Expose this MCP server in a progressive discovery mode.

    Instead of exposing every remote tool schema upfront, the model receives a small
    stable surface:
    - A discovery tool to list available tools and their schemas.
    - A dispatch tool to call a specific tool by name.

    This is useful for large MCP servers where exposing all tool schemas upfront
    would add significant token overhead. The SDK still owns connection lifecycle,
    allowed_tools filtering, result parsing, exceptions, and OTel propagation.

    Args:
        list_tool_name: Name for the discovery tool. Defaults to "list_mcp_tools".
        call_tool_name: Name for the dispatch tool. Defaults to "call_mcp".

    Returns:
        A list of exactly two FunctionTools to pass to an Agent: the discovery
        tool first, the dispatch tool second.

    Raises:
        ValueError: If ``list_tool_name`` and ``call_tool_name`` are identical, since
            the two returned tools could then not be told apart by name.
    """
    if list_tool_name == call_tool_name:
        raise ValueError(f"list_tool_name and call_tool_name must differ; both are '{list_tool_name}'.")

    def _resolve_tool(tool: str) -> FunctionTool | None:
        """Return the function addressed by ``tool``, or None when nothing matches."""
        # Read self.functions at call time so the lookup sees the current tool set.
        functions = list(self.functions)

        # Pass 1: the exposed name, which is what the discovery tool advertises.
        # It takes precedence so an alias on an earlier function can never shadow it.
        for func in functions:
            if func.name == tool:
                return func

        # Pass 2: fall back to the normalized / remote aliases stored on the function.
        for func in functions:
            props = func.additional_properties or {}
            aliases = (props.get(_MCP_NORMALIZED_NAME_KEY), props.get(_MCP_REMOTE_NAME_KEY))
            # A missing alias (None) must never count as a match.
            if any(alias is not None and alias == tool for alias in aliases):
                return func

        return None

    async def _list_tools(server: str | None = None) -> str:
        """List available tools on this MCP server.

        Args:
            server: The name of the server to list tools for.
                Must match this server's name if provided.
        """
        if server and server != self.name:
            return json.dumps([])

        return json.dumps(
            [
                {
                    "name": func.name,
                    "description": func.description,
                    "parameters": func.parameters(),
                }
                for func in self.functions
            ],
            separators=(",", ":"),
        )

    async def _call_tool(
        server: str,
        tool: str,
        arguments: dict[str, Any] | None = None,
        context: FunctionInvocationContext | None = None,
    ) -> Any:
        """Call a specific tool on this MCP server.

        Note:
            Any approval_mode or middleware configured on the underlying target tool
            are enforced at the call_mcp wrapper tool level, as call_mcp is the
            actual FunctionTool that traverses the agent execution pipeline.

        Args:
            server: The name of the server. Must match this server's name.
            tool: The name of the tool to call.
            arguments: The arguments to pass to the tool.
            context: The framework function invocation context.
        """
        if server != self.name:
            raise ToolExecutionException(f"Unknown server '{server}'. This dispatcher is for server '{self.name}'.")

        target_func = _resolve_tool(tool)
        if target_func is None:
            raise ToolExecutionException(f"Tool '{tool}' not found or not allowed on server '{self.name}'.")

        call_arguments = arguments or {}

        # Create a fresh context for the target tool so that FunctionTool.invoke's
        # in-place mutations (context.function, context.arguments, context.kwargs)
        # do not corrupt the wrapper call_mcp's context that middleware may still
        # be observing after call_next() returns.
        target_context = FunctionInvocationContext(
            function=target_func,
            arguments=call_arguments,
            session=context.session if context is not None else None,
            kwargs=context.kwargs if context is not None else None,
        )
        return await target_func.invoke(arguments=call_arguments, context=target_context)

    list_tool = FunctionTool(
        name=list_tool_name,
        description=f"List available tools on the {self.name} MCP server.",
        func=_list_tools,
        approval_mode="never_require",
    )

    # When approval_mode is a dict (MCPSpecificApproval with per-tool allow/deny
    # lists), the framework's _try_execute_function_calls cannot interpret it as
    # a wrapper-level policy string and will silently bypass approval. Normalise
    # any dict value to "always_require" so the dispatch wrapper is always gated
    # conservatively.
    wrapper_approval_mode = "always_require" if isinstance(self.approval_mode, dict) else self.approval_mode

    call_tool = FunctionTool(
        name=call_tool_name,
        description=f"Call a specific tool on the {self.name} MCP server.",
        func=_call_tool,
        approval_mode=wrapper_approval_mode,
    )

    return [list_tool, call_tool]
# region Progressive Tools


def _progressive_server(*functions, **mcp_kwargs):
    """Build an MCPTool named "my-server" whose loaded functions are ``functions``.

    ``_functions`` is assigned directly so that no real MCP connection is needed.
    """
    server = MCPTool(name="my-server", **mcp_kwargs)
    server._functions = list(functions)
    return server


def _executed_tool(description="A tool"):
    """Return a FunctionTool named "tool_a" whose coroutine returns "Executed!"."""

    async def _mock_func():
        return "Executed!"

    return FunctionTool(name="tool_a", description=description, func=_mock_func)


async def _list_progressive_tools(list_tool, **arguments):
    """Invoke the discovery tool and return its decoded JSON payload."""
    res = await list_tool.invoke(arguments=arguments)
    assert isinstance(res, list)
    return json.loads(res[0].text)


async def _dispatch_tool_a_via_framework(call_tool, middleware_pipeline=None):
    """Run a ``call_mcp`` -> ``tool_a`` function call through ``_try_execute_function_calls``."""
    from agent_framework._tools import _try_execute_function_calls, normalize_function_invocation_configuration

    function_call = Content.from_function_call(
        call_id="test_call_id",
        name=call_tool.name,
        arguments={"server": "my-server", "tool": "tool_a"},
    )
    # Only forward the pipeline when one is supplied, so the no-middleware tests
    # keep exercising the framework's own default for that parameter.
    extra = {} if middleware_pipeline is None else {"middleware_pipeline": middleware_pipeline}
    results, _ = await _try_execute_function_calls(
        custom_args={},
        attempt_idx=0,
        function_calls=[function_call],
        tools=[call_tool],
        config=normalize_function_invocation_configuration(None),
        **extra,
    )
    return results


def _assert_single_approval_request(results):
    """Assert ``results`` is exactly one approval request for ``call_mcp`` and return it."""
    assert len(results) == 1
    request = results[0]
    assert request.type == "function_approval_request"
    assert request.function_call is not None
    assert request.function_call.name == "call_mcp"
    return request


def test_as_progressive_tools_returns_two_tools():
    prog_tools = MCPTool(name="my-server").as_progressive_tools()
    assert len(prog_tools) == 2
    assert [t.name for t in prog_tools] == ["list_mcp_tools", "call_mcp"]


def test_as_progressive_tools_custom_names():
    prog_tools = MCPTool(name="my-server").as_progressive_tools(
        list_tool_name="list_custom",
        call_tool_name="call_custom",
    )
    assert len(prog_tools) == 2
    assert [t.name for t in prog_tools] == ["list_custom", "call_custom"]


@pytest.mark.asyncio
async def test_progressive_list_mcp_tools_returns_all_tools():
    server = _progressive_server(
        FunctionTool(name="tool_a", description="First tool", func=lambda: None),
        FunctionTool(name="tool_b", description="Second tool", func=lambda: None),
    )
    list_tool = server.as_progressive_tools()[0]

    # list_mcp_tools returns JSON
    listed = await _list_progressive_tools(list_tool)
    assert [(entry["name"], entry["description"]) for entry in listed] == [
        ("tool_a", "First tool"),
        ("tool_b", "Second tool"),
    ]
    # Every entry also carries the tool's parameter schema.
    assert all("parameters" in entry for entry in listed)


@pytest.mark.asyncio
async def test_progressive_list_mcp_tools_filter_by_server():
    server = _progressive_server(FunctionTool(name="tool_a", description="First tool", func=lambda: None))
    list_tool = server.as_progressive_tools()[0]

    # Specifying the correct server returns the tools
    listed = await _list_progressive_tools(list_tool, server="my-server")
    assert len(listed) == 1
    assert listed[0]["name"] == "tool_a"


@pytest.mark.asyncio
async def test_progressive_list_mcp_tools_wrong_server_returns_empty():
    server = _progressive_server(FunctionTool(name="tool_a", description="First tool", func=lambda: None))
    list_tool = server.as_progressive_tools()[0]

    # Specifying a different server returns an empty list
    listed = await _list_progressive_tools(list_tool, server="other-server")
    assert len(listed) == 0


@pytest.mark.asyncio
async def test_progressive_call_mcp_dispatches_allowed_tool():
    called_with = {}

    async def mock_func(x: int):
        called_with["x"] = x
        return "Success!"

    server = _progressive_server(FunctionTool(name="tool_a", description="First tool", func=mock_func))
    call_tool = server.as_progressive_tools()[1]

    res = await call_tool.invoke(arguments={"server": "my-server", "tool": "tool_a", "arguments": {"x": 1}})
    assert isinstance(res, list)
    assert res[0].text == "Success!"
    assert called_with == {"x": 1}


@pytest.mark.asyncio
async def test_progressive_call_mcp_with_no_arguments():
    called_with = {}

    async def mock_func():
        called_with["called"] = True
        return "Success!"

    server = _progressive_server(FunctionTool(name="tool_a", description="First tool", func=mock_func))
    call_tool = server.as_progressive_tools()[1]

    # Should default arguments to {}
    res = await call_tool.invoke(arguments={"server": "my-server", "tool": "tool_a"})
    assert isinstance(res, list)
    assert res[0].text == "Success!"
    assert called_with == {"called": True}


@pytest.mark.asyncio
async def test_progressive_call_mcp_rejects_disallowed_tool():
    server = _progressive_server()  # No allowed tools
    call_tool = server.as_progressive_tools()[1]

    with pytest.raises(ToolExecutionException, match="not found or not allowed"):
        await call_tool.invoke(arguments={"server": "my-server", "tool": "tool_a"})


@pytest.mark.asyncio
async def test_progressive_call_mcp_rejects_wrong_server():
    call_tool = MCPTool(name="my-server").as_progressive_tools()[1]

    with pytest.raises(ToolExecutionException) as exc_info:
        await call_tool.invoke(arguments={"server": "wrong-server", "tool": "tool_a"})

    message = str(exc_info.value)
    assert "Unknown server" in message
    assert "This dispatcher is for server 'my-server'" in message


@pytest.mark.asyncio
async def test_progressive_call_mcp_preserves_error_handling():
    async def _failing_func():
        raise ToolExecutionException("Inner error")

    server = _progressive_server(FunctionTool(name="tool_a", description="Failing tool", func=_failing_func))
    call_tool = server.as_progressive_tools()[1]

    with pytest.raises(ToolExecutionException, match="Inner error"):
        await call_tool.invoke(arguments={"server": "my-server", "tool": "tool_a"})


@pytest.mark.asyncio
async def test_progressive_call_mcp_respects_approval_mode():
    server = _progressive_server(
        _executed_tool(description="A tool that requires approval"),
        approval_mode="always_require",
    )
    call_tool = server.as_progressive_tools()[1]

    # Assert the wrapper inherited approval_mode automatically
    assert call_tool.approval_mode == "always_require"

    # AND assert it actually triggers an approval request (behavioral check)
    request = _assert_single_approval_request(await _dispatch_tool_a_via_framework(call_tool))
    arguments = request.function_call.arguments
    assert isinstance(arguments, dict)
    assert arguments.get("server") == "my-server"
    assert arguments.get("tool") == "tool_a"


@pytest.mark.asyncio
async def test_progressive_call_mcp_executes_middleware():
    from agent_framework._middleware import FunctionMiddlewarePipeline

    call_tool = _progressive_server(_executed_tool()).as_progressive_tools()[1]

    executed_functions = []

    class TestMiddleware(FunctionMiddleware):
        async def process(self, context: FunctionInvocationContext, call_next):
            executed_functions.append(context.function.name)
            await call_next()

    results = await _dispatch_tool_a_via_framework(
        call_tool,
        middleware_pipeline=FunctionMiddlewarePipeline(TestMiddleware()),
    )

    assert len(results) == 1
    assert results[0].type == "function_result"
    assert results[0].result == "Executed!"

    # Middleware runs for call_mcp wrapper. Nested tool invocation does not
    # trigger redundant middleware execution.
    assert executed_functions == ["call_mcp"]


@pytest.mark.asyncio
async def test_progressive_call_mcp_dict_approval_mode_normalised():
    """When MCPTool.approval_mode is a dict (MCPSpecificApproval), as_progressive_tools()
    must normalise it to "always_require" for the call_mcp wrapper so that
    _try_execute_function_calls can enforce it as a concrete policy string."""
    # Dict-style approval_mode (MCPSpecificApproval)
    server = _progressive_server(
        _executed_tool(),
        approval_mode={"always_require_approval": ["tool_a"], "never_require_approval": None},
    )
    call_tool = server.as_progressive_tools()[1]

    # Attribute check: dict must be normalised to the conservative "always_require" string
    assert call_tool.approval_mode == "always_require"

    # Behavioral check: executing call_mcp actually triggers an approval request
    _assert_single_approval_request(await _dispatch_tool_a_via_framework(call_tool))


@pytest.mark.asyncio
async def test_progressive_call_mcp_context_not_mutated():
    """The wrapper call_mcp must not forward its own FunctionInvocationContext into
    target_func.invoke(). FunctionTool.invoke mutates context in-place, so sharing
    the same object would corrupt any middleware that inspects context after call_next()."""
    from agent_framework._middleware import FunctionMiddlewarePipeline

    call_tool = _progressive_server(_executed_tool()).as_progressive_tools()[1]

    # Capture the context state that the middleware sees BEFORE and AFTER call_next()
    context_function_before: list[str] = []
    context_function_after: list[str] = []

    class ContextSnapshotMiddleware(FunctionMiddleware):
        async def process(self, context: FunctionInvocationContext, call_next):
            context_function_before.append(context.function.name)
            await call_next()
            # After the nested invoke, the wrapper's context.function must still
            # point to call_mcp — not to the inner tool_a.
            context_function_after.append(context.function.name)

    results = await _dispatch_tool_a_via_framework(
        call_tool,
        middleware_pipeline=FunctionMiddlewarePipeline(ContextSnapshotMiddleware()),
    )

    assert len(results) == 1
    assert results[0].type == "function_result"
    assert results[0].result == "Executed!"

    # Before call_next: context belongs to call_mcp
    assert context_function_before == ["call_mcp"]
    # After call_next: context.function must STILL be call_mcp (not mutated to tool_a)
    assert context_function_after == ["call_mcp"]


# endregion