Skip to content

Commit ca917fa

Browse files
committed
Implement MCP client opentelemetry tracing
finish client spans Inject tracer_provider into servers
1 parent c0ae32f commit ca917fa

File tree

8 files changed

+806
-6
lines changed

8 files changed

+806
-6
lines changed

src/mcp/client/client.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from dataclasses import KW_ONLY, dataclass, field
77
from typing import Any
88

9+
from opentelemetry.trace import TracerProvider
10+
911
from mcp.client._memory import InMemoryTransport
1012
from mcp.client._transport import Transport
1113
from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
@@ -95,6 +97,8 @@ async def main():
9597
elicitation_callback: ElicitationFnT | None = None
9698
"""Callback for handling elicitation requests."""
9799

100+
tracer_provider: TracerProvider | None = None
101+
98102
_session: ClientSession | None = field(init=False, default=None)
99103
_exit_stack: AsyncExitStack | None = field(init=False, default=None)
100104
_transport: Transport = field(init=False)
@@ -126,6 +130,7 @@ async def __aenter__(self) -> Client:
126130
message_handler=self.message_handler,
127131
client_info=self.client_info,
128132
elicitation_callback=self.elicitation_callback,
133+
tracer_provider=self.tracer_provider,
129134
)
130135
)
131136

src/mcp/client/session.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import anyio.lowlevel
77
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
8+
from opentelemetry.trace import TracerProvider
89
from pydantic import TypeAdapter
910

1011
from mcp import types
@@ -121,8 +122,11 @@ def __init__(
121122
*,
122123
sampling_capabilities: types.SamplingCapability | None = None,
123124
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
125+
tracer_provider: TracerProvider | None = None,
124126
) -> None:
125-
super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds)
127+
super().__init__(
128+
read_stream, write_stream, read_timeout_seconds=read_timeout_seconds, tracer_provider=tracer_provider
129+
)
126130
self._client_info = client_info or DEFAULT_CLIENT_INFO
127131
self._sampling_callback = sampling_callback or _default_sampling_callback
128132
self._sampling_capabilities = sampling_capabilities

src/mcp/server/lowlevel/server.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ async def main():
4646

4747
import anyio
4848
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
49+
from opentelemetry.trace import TracerProvider
4950
from starlette.applications import Starlette
5051
from starlette.middleware import Middleware
5152
from starlette.middleware.authentication import AuthenticationMiddleware
@@ -182,6 +183,7 @@ def __init__(
182183
Awaitable[None],
183184
]
184185
| None = None,
186+
tracer_provider: TracerProvider | None = None,
185187
):
186188
self.name = name
187189
self.version = version
@@ -197,6 +199,7 @@ def __init__(
197199
] = {}
198200
self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None
199201
self._session_manager: StreamableHTTPSessionManager | None = None
202+
self._tracer_provider = tracer_provider
200203
logger.debug("Initializing server %r", name)
201204

202205
# Populate internal handler dicts from on_* kwargs
@@ -378,6 +381,7 @@ async def run(
378381
write_stream,
379382
initialization_options,
380383
stateless=stateless,
384+
tracer_provider=self._tracer_provider,
381385
)
382386
)
383387

src/mcp/server/mcpserver/server.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import anyio
1414
import pydantic_core
15+
from opentelemetry.trace import TracerProvider
1516
from pydantic.networks import AnyUrl
1617
from pydantic_settings import BaseSettings, SettingsConfigDict
1718
from starlette.applications import Starlette
@@ -144,6 +145,7 @@ def __init__(
144145
warn_on_duplicate_prompts: bool = True,
145146
lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None = None,
146147
auth: AuthSettings | None = None,
148+
tracer_provider: TracerProvider | None = None,
147149
):
148150
self.settings = Settings(
149151
debug=debug,
@@ -176,6 +178,7 @@ def __init__(
176178
# TODO(Marcelo): It seems there's a type mismatch between the lifespan type from an MCPServer and Server.
177179
# We need to create a Lifespan type that is a generic on the server type, like Starlette does.
178180
lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), # type: ignore
181+
tracer_provider=tracer_provider,
179182
)
180183
# Validate auth configuration
181184
if self.settings.auth is not None:

src/mcp/server/session.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult:
3434
import anyio
3535
import anyio.lowlevel
3636
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
37+
from opentelemetry.trace import TracerProvider
3738
from pydantic import AnyUrl, TypeAdapter
3839

3940
from mcp import types
@@ -83,8 +84,9 @@ def __init__(
8384
write_stream: MemoryObjectSendStream[SessionMessage],
8485
init_options: InitializationOptions,
8586
stateless: bool = False,
87+
tracer_provider: TracerProvider | None = None,
8688
) -> None:
87-
super().__init__(read_stream, write_stream)
89+
super().__init__(read_stream, write_stream, tracer_provider=tracer_provider)
8890
self._stateless = stateless
8991
self._initialization_state = (
9092
InitializationState.Initialized if stateless else InitializationState.NotInitialized

src/mcp/shared/_otel_utils.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import contextlib
2+
from collections.abc import Iterator
3+
from typing import Any
4+
5+
from opentelemetry.trace import Span, SpanKind, StatusCode, Tracer
6+
7+
from mcp import types
8+
from mcp.shared.exceptions import MCPError
9+
10+
# OTel Semantic Conventions for MCP and GenAI
11+
# See: https://github.com/open-telemetry/semantic-conventions/blob/v1.40.0/docs/gen-ai/mcp.md
12+
MCP_METHOD_NAME = "mcp.method.name"
13+
MCP_RESOURCE_URI = "mcp.resource.uri"
14+
JSONRPC_REQUEST_ID = "jsonrpc.request.id"
15+
16+
GEN_AI_TOOL_NAME = "gen_ai.tool.name"
17+
GEN_AI_OPERATION_NAME = "gen_ai.operation.name"
18+
GEN_AI_PROMPT_NAME = "gen_ai.prompt.name"
19+
20+
ERROR_TYPE = "error.type"
21+
RPC_RESPONSE_STATUS_CODE = "rpc.response.status_code"
22+
23+
24+
def _get_span_name(
    request: types.ClientRequest | types.ServerRequest | types.ClientNotification | types.ServerNotification,
) -> str:
    """Build the span name for an MCP request or notification.

    Tool calls and prompt fetches are named ``"{method} {name}"`` using the
    tool/prompt name from the request params. Every other message, and any
    tool/prompt message whose name is empty, uses the bare method string.
    """
    method = request.method

    # Only these two request kinds carry a target name worth appending.
    if not isinstance(request, (types.CallToolRequest, types.GetPromptRequest)):
        return method

    name = request.params.name
    return f"{method} {name}" if name else method
40+
41+
42+
def _get_common_attributes(
    request: types.ClientRequest | types.ServerRequest | types.ClientNotification | types.ServerNotification,
    *,
    json_rpc_request_id: int | str | None = None,
) -> dict[str, Any]:
    """Collect the span attributes shared by client and server spans.

    Args:
        request: The MCP request or notification being traced.
        json_rpc_request_id: JSON-RPC id of the message, if it has one. It is
            stored in its string form.

    Returns:
        A dict that always holds the method name, plus the request id when
        given and the tool name / prompt name / resource URI for the message
        kinds that carry one.
    """
    attrs: dict[str, Any] = {MCP_METHOD_NAME: request.method}

    if json_rpc_request_id is not None:
        attrs[JSONRPC_REQUEST_ID] = str(json_rpc_request_id)

    if isinstance(request, types.CallToolRequest):
        attrs[GEN_AI_TOOL_NAME] = request.params.name
        attrs[GEN_AI_OPERATION_NAME] = "execute_tool"
    elif isinstance(request, types.GetPromptRequest):
        attrs[GEN_AI_PROMPT_NAME] = request.params.name
    elif isinstance(
        request,
        (
            types.ReadResourceRequest,
            types.SubscribeRequest,
            types.UnsubscribeRequest,
            types.ResourceUpdatedNotification,
        ),
    ):
        # All resource-scoped messages expose the URI at params.uri.
        attrs[MCP_RESOURCE_URI] = request.params.uri

    return attrs
69+
70+
71+
# Maps JSON-RPC / MCP error codes to the short names reported as `error.type`.
# Codes missing from this table are reported as their numeric value in string form.
_ERROR_NAMES: dict[int, str] = {
    types.INVALID_PARAMS: "invalid_params",
    types.METHOD_NOT_FOUND: "method_not_found",
    types.CONNECTION_CLOSED: "connection_closed",
    types.REQUEST_TIMEOUT: "timeout",
    types.PARSE_ERROR: "parse_error",
    types.INTERNAL_ERROR: "internal_error",
    types.INVALID_REQUEST: "invalid_request",
    types.URL_ELICITATION_REQUIRED: "url_elicitation_required",
}


def _record_error_data(span: Span, e: types.ErrorData, record_status: bool = True) -> None:
    """Record an MCP protocol error on the span and, optionally, set its status.

    https://github.com/open-telemetry/semantic-conventions/blob/v1.40.0/docs/general/recording-errors.md

    Args:
        span: The span to annotate. Nothing is written if it is not recording.
        e: The protocol error; its `code` and `message` are read.
        record_status: When true (the default) the span status is set to
            ERROR with the error message as description. Pass false to record
            only the error attributes and leave the status untouched.
    """
    if not span.is_recording():
        return

    span.set_attribute(ERROR_TYPE, _ERROR_NAMES.get(e.code, str(e.code)))
    span.set_attribute(RPC_RESPONSE_STATUS_CODE, str(e.code))
    # Previously the flag was accepted but ignored, so status was always set.
    if record_status:
        span.set_status(status=StatusCode.ERROR, description=e.message)
94+
95+
96+
@contextlib.contextmanager
def mcp_client_span(
    tracer: Tracer,
    request: types.ClientRequest | types.ServerRequest | types.ClientNotification | types.ServerNotification,
    *,
    json_rpc_request_id: int | str | None = None,
) -> Iterator[Span]:
    """Start an MCP client span and make it the current span.

    https://github.com/open-telemetry/semantic-conventions/blob/v1.40.0/docs/gen-ai/mcp.md#client

    Args:
        tracer: Tracer used to create the span.
        request: The outgoing MCP request or notification; determines the
            span name and attributes.
        json_rpc_request_id: JSON-RPC id of the request, if any.

    Yields:
        The started span, active for the duration of the ``with`` body.

    Raises:
        Whatever the ``with`` body raises. The exception is re-raised
        unchanged after the failure has been recorded on the span.
    """
    span_name = _get_span_name(request)
    attributes = _get_common_attributes(request, json_rpc_request_id=json_rpc_request_id)

    with tracer.start_as_current_span(
        span_name,
        kind=SpanKind.CLIENT,
        attributes=attributes,
        # Status is set explicitly below so that the protocol-level description
        # written for an MCPError is never replaced by a generic one on exit.
        set_status_on_exception=False,
    ) as span:
        try:
            yield span
        except MCPError as mcp_error:
            _record_error_data(span, mcp_error.error)
            # The exception event itself is left to start_as_current_span
            # (record_exception is on by default there), so a plain re-raise
            # records it exactly once.
            raise
        except Exception as exc:
            # Non-protocol failures (transport errors, validation errors, ...)
            # must also mark the span as failed; with set_status_on_exception
            # disabled nothing else would do so.
            if span.is_recording():
                span.set_attribute(ERROR_TYPE, type(exc).__qualname__)
                span.set_status(status=StatusCode.ERROR, description=str(exc))
            raise

src/mcp/shared/session.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1212
from opentelemetry import context as otel_context
1313
from opentelemetry.propagate import extract, inject
14+
from opentelemetry.trace import TracerProvider, get_tracer
1415
from pydantic import BaseModel, TypeAdapter
1516
from typing_extensions import Self
1617

18+
from mcp.shared._otel_utils import mcp_client_span
1719
from mcp.shared.exceptions import MCPError
1820
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
1921
from mcp.shared.response_router import ResponseRouter
@@ -190,6 +192,8 @@ def __init__(
190192
write_stream: MemoryObjectSendStream[SessionMessage],
191193
# If none, reading will never time out
192194
read_timeout_seconds: float | None = None,
195+
*,
196+
tracer_provider: TracerProvider | None = None,
193197
) -> None:
194198
self._read_stream = read_stream
195199
self._write_stream = write_stream
@@ -200,6 +204,7 @@ def __init__(
200204
self._progress_callbacks = {}
201205
self._response_routers = []
202206
self._exit_stack = AsyncExitStack()
207+
self._tracer = get_tracer("mcp", tracer_provider=tracer_provider)
203208

204209
def add_response_router(self, router: ResponseRouter) -> None:
205210
"""Register a response router to handle responses for non-standard requests.
@@ -268,10 +273,10 @@ async def send_request(
268273
# Store the callback for this request
269274
self._progress_callbacks[request_id] = progress_callback
270275

271-
# Propagate opentelemetry trace context
272-
self._inject_otel_context(request_data)
276+
async def make_request():
277+
# Propagate opentelemetry trace context
278+
self._inject_otel_context(request_data)
273279

274-
try:
275280
jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data)
276281
await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata))
277282

@@ -291,6 +296,9 @@ async def send_request(
291296
else:
292297
return result_type.model_validate(response_or_error.result, by_name=False)
293298

299+
try:
300+
with mcp_client_span(self._tracer, request, json_rpc_request_id=request_id):
301+
return await make_request()
294302
finally:
295303
self._response_streams.pop(request_id, None)
296304
self._progress_callbacks.pop(request_id, None)
@@ -315,7 +323,9 @@ async def send_notification(
315323
message=jsonrpc_notification,
316324
metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None,
317325
)
318-
await self._write_stream.send(session_message)
326+
327+
with mcp_client_span(self._tracer, notification):
328+
await self._write_stream.send(session_message)
319329

320330
def _inject_otel_context(self, request: dict[str, Any]) -> None:
321331
"""Propagate OpenTelemetry context in `_meta`.

0 commit comments

Comments
 (0)