diff --git a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs index 60950dfa5..fa9ee97aa 100644 --- a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs @@ -80,6 +80,8 @@ public override async Task SendMessageAsync( messageId = messageWithId.Id.ToString(); } + LogTransportSendingMessageSensitive(message); + using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint); StreamableHttpClientSessionTransport.CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, sessionId: null, protocolVersion: null); var response = await _httpClient.SendAsync(httpRequestMessage, message, cancellationToken).ConfigureAwait(false); diff --git a/src/ModelContextProtocol.Core/Client/StreamClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamClientSessionTransport.cs index c896bd433..d582abe31 100644 --- a/src/ModelContextProtocol.Core/Client/StreamClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamClientSessionTransport.cs @@ -105,6 +105,8 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation var json = JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage); + LogTransportSendingMessageSensitive(Name, json); + using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false); try { diff --git a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs index 017512589..c33d35322 100644 --- a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs @@ -75,6 +75,8 @@ internal async Task SendHttpRequestAsync(JsonRpcMessage mes $"Call {nameof(McpClient)}.{nameof(McpClient.ResumeSessionAsync)} to resume existing sessions."); } + LogTransportSendingMessageSensitive(message); + using var sendCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _connectionCts.Token); cancellationToken = sendCts.Token; diff --git a/src/ModelContextProtocol.Core/Protocol/TransportBase.cs b/src/ModelContextProtocol.Core/Protocol/TransportBase.cs index 97897b53f..e3e8e8c8b 100644 --- a/src/ModelContextProtocol.Core/Protocol/TransportBase.cs +++ b/src/ModelContextProtocol.Core/Protocol/TransportBase.cs @@ -1,6 +1,7 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using System.Diagnostics; +using System.Text.Json; using System.Threading.Channels; namespace ModelContextProtocol.Protocol; @@ -166,6 +167,21 @@ protected void SetDisconnected(Exception? error = null) [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} transport send failed for message ID '{MessageId}'.")] private protected partial void LogTransportSendFailed(string endpointName, string messageId, Exception exception); + [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} transport sending message. Message: '{Message}'.")] + private protected partial void LogTransportSendingMessageSensitive(string endpointName, string message); + + /// + /// Logs a sending message at Trace level if trace logging is enabled. + /// + /// The JSON-RPC message to log. + private protected void LogTransportSendingMessageSensitive(JsonRpcMessage message) + { + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogTransportSendingMessageSensitive(Name, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage)); + } + } + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} transport reading messages.")] private protected partial void LogTransportEnteringReadMessagesLoop(string endpointName); diff --git a/src/ModelContextProtocol.Core/Server/StreamServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamServerTransport.cs index 7747d7f18..1ab106e26 100644 --- a/src/ModelContextProtocol.Core/Server/StreamServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamServerTransport.cs @@ -74,7 +74,9 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation try { - await JsonSerializer.SerializeAsync(_outputStream, message, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)), cancellationToken).ConfigureAwait(false); + var json = JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage); + LogTransportSendingMessageSensitive(Name, json); + await _outputStream.WriteAsync(Encoding.UTF8.GetBytes(json), cancellationToken).ConfigureAwait(false); await _outputStream.WriteAsync(s_newlineBytes, cancellationToken).ConfigureAwait(false); await _outputStream.FlushAsync(cancellationToken).ConfigureAwait(false); } diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs index cbe44da15..3f5756620 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs @@ -1,4 +1,5 @@ -using ModelContextProtocol.Protocol; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Utils; using System.IO.Pipelines; @@ -21,6 +22,14 @@ public StdioServerTransportTests(ITestOutputHelper testOutputHelper) InitializationTimeout = TimeSpan.FromSeconds(10), ServerInstructions = "Test Instructions" }; + + // Override the LoggerFactory to use Trace level for testing Trace-level logging + LoggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder => + { + builder.AddProvider(XunitLoggerProvider); + builder.AddProvider(MockLoggerProvider); + builder.SetMinimumLevel(LogLevel.Trace); + }); } [Fact(Skip="https://github.com/modelcontextprotocol/csharp-sdk/issues/143")] @@ -193,4 +202,59 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters() Assert.True(magnifyingGlassFound, "Magnifying glass emoji not found in result"); Assert.True(rocketFound, "Rocket emoji not found in result"); } + + [Fact] + public async Task SendMessageAsync_Should_Log_At_Trace_Level() + { + // Arrange + using var output = new MemoryStream(); + + await using var transport = new StreamServerTransport( + new Pipe().Reader.AsStream(), + output, + loggerFactory: LoggerFactory); + + // Act + var message = new JsonRpcRequest { Method = "test", Id = new RequestId(44) }; + await transport.SendMessageAsync(message, TestContext.Current.CancellationToken); + + // Assert + var traceLogMessages = MockLoggerProvider.LogMessages + .Where(x => x.LogLevel == LogLevel.Trace && x.Message.Contains("transport sending message")) + .ToList(); + + Assert.NotEmpty(traceLogMessages); + Assert.Contains(traceLogMessages, x => x.Message.Contains("\"method\":\"test\"") && x.Message.Contains("\"id\":44")); + } + + [Fact] + public async Task ReadMessagesAsync_Should_Log_Received_At_Trace_Level() + { + // Arrange + var message = new JsonRpcRequest { Method = "test", Id = new RequestId(99) }; + var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions); + + Pipe pipe = new(); + using var input = pipe.Reader.AsStream(); + + await using var transport = new StreamServerTransport( + input, + Stream.Null, + loggerFactory: LoggerFactory); + + // Act + await pipe.Writer.WriteAsync(Encoding.UTF8.GetBytes($"{json}\n"), TestContext.Current.CancellationToken); + + // Wait for the message to be processed + var canRead = await transport.MessageReader.WaitToReadAsync(TestContext.Current.CancellationToken); + Assert.True(canRead, "Nothing to read here from transport message reader"); + + // Assert + var traceLogMessages = MockLoggerProvider.LogMessages + .Where(x => x.LogLevel == LogLevel.Trace && x.Message.Contains("transport received message")) + .ToList(); + + Assert.NotEmpty(traceLogMessages); + Assert.Contains(traceLogMessages, x => x.Message.Contains("\"method\":\"test\"") && x.Message.Contains("\"id\":99")); + } }