diff --git a/Sources/MCP/Base/Transports/NetworkTransport.swift b/Sources/MCP/Base/Transports/NetworkTransport.swift index 62b623c8..ea55c66a 100644 --- a/Sources/MCP/Base/Transports/NetworkTransport.swift +++ b/Sources/MCP/Base/Transports/NetworkTransport.swift @@ -511,9 +511,6 @@ import Logging var messageWithNewline = message messageWithNewline.append(UInt8(ascii: "\n")) - // Use a local actor-isolated variable to track continuation state - var sendContinuationResumed = false - try await withCheckedThrowingContinuation { [weak self] (continuation: CheckedContinuation) in guard let self = self else { @@ -526,54 +523,43 @@ import Logging contentContext: .defaultMessage, isComplete: true, completion: .contentProcessed { [weak self] error in - guard let self = self else { return } - - Task { @MainActor in - if !sendContinuationResumed { - sendContinuationResumed = true - if let error = error { - self.logger.error("Send error: \(error)") - - // Check if we should attempt to reconnect on send failure - let isStopping = await self.isStopping // Await actor-isolated property - if !isStopping && self.reconnectionConfig.enabled { - let isConnected = await self.isConnected - if isConnected { - if error.isConnectionLost { - self.logger.warning( - "Connection appears broken, will attempt to reconnect..." - ) - - // Schedule connection restart - Task { [weak self] in // Operate on self's executor - guard let self = self else { return } - - await self.setIsConnected(false) - - try? await Task.sleep(for: .milliseconds(500)) - - let currentIsStopping = await self.isStopping - if !currentIsStopping { - // Cancel the connection, then attempt to reconnect fully. - self.connection.cancel() - try? await self.connect() - } - } - } - } - } - - continuation.resume( - throwing: MCPError.internalError("Send error: \(error)")) - } else { - continuation.resume() + if let error = error { + continuation.resume( + throwing: MCPError.internalError("Send error: \(error)")) + // Handle reconnection on the actor's executor + if let self { + Task { + await self.handleSendError(error) } } + } else { + continuation.resume() } }) } } + /// Handles reconnection logic after a send error. + /// + /// This method is actor-isolated, so it safely accesses `isStopping`, + /// `isConnected`, and `reconnectionConfig` without data races. + private func handleSendError(_ error: NWError) async { + logger.error("Send error: \(error)") + guard !isStopping && reconnectionConfig.enabled && isConnected else { return } + guard error.isConnectionLost else { return } + + logger.warning("Connection appears broken, will attempt to reconnect...") + setIsConnected(false) + connection.cancel() + + guard !isStopping else { return } + do { + try await connect() + } catch { + logger.error("Reconnection failed: \(error)") + } + } + /// Receives data in an async sequence /// /// This returns an AsyncThrowingStream that emits Data objects representing @@ -747,8 +733,6 @@ import Logging /// - Returns: The received data chunk /// - Throws: Network errors or transport failures private func receiveData() async throws -> Data { - var receiveContinuationResumed = false - return try await withCheckedThrowingContinuation { [weak self] (continuation: CheckedContinuation) in guard let self = self else { @@ -759,21 +743,14 @@ import Logging let maxLength = bufferConfig.maxReceiveBufferSize ?? Int.max connection.receive(minimumIncompleteLength: 1, maximumLength: maxLength) { content, _, isComplete, error in - Task { @MainActor in - if !receiveContinuationResumed { - receiveContinuationResumed = true - if let error = error { - continuation.resume(throwing: MCPError.transportError(error)) - } else if let content = content { - continuation.resume(returning: content) - } else if isComplete { - self.logger.trace("Connection completed by peer") - continuation.resume(throwing: MCPError.connectionClosed) - } else { - // EOF: Resume with empty data instead of throwing an error - continuation.resume(returning: Data()) - } - } + if let error = error { + continuation.resume(throwing: MCPError.transportError(error)) + } else if let content = content { + continuation.resume(returning: content) + } else if isComplete { + continuation.resume(throwing: MCPError.connectionClosed) + } else { + continuation.resume(returning: Data()) } } } diff --git a/Tests/MCPTests/NetworkTransportTests.swift b/Tests/MCPTests/NetworkTransportTests.swift index 9c02c90c..93515971 100644 --- a/Tests/MCPTests/NetworkTransportTests.swift +++ b/Tests/MCPTests/NetworkTransportTests.swift @@ -160,6 +160,18 @@ import Testing queueDataForReceiving(data) } + /// Set a send-only error that does not change connection state. + /// Unlike `simulateFailure`, this keeps the connection in `.ready` state + /// so that `send()` proceeds to the NWConnection callback path. + func setSendError(_ error: NWError?) { + mockError = error + } + + /// Clear any injected error without changing connection state. + func clearError() { + mockError = nil + } + /// Get all sent data func getSentData() -> [Data] { return sentData @@ -661,5 +673,126 @@ import Testing // Verify connection is cleaned up #expect(weakConnection == nil, "Connection was not properly cleaned up") } + @Test("Concurrent sends do not cause data races") + func testConcurrentSendsNoCrash() async throws { + let mockConnection = MockNetworkConnection() + let transport = NetworkTransport( + mockConnection, + heartbeatConfig: .disabled, + reconnectionConfig: .disabled + ) + + try await transport.connect() + + // Fire many concurrent sends to surface any data race in continuation handling. + // Before the fix, mutable `var` flags captured by @Sendable closures could race. + try await withThrowingTaskGroup(of: Void.self) { group in + for i in 0..<100 { + group.addTask { + let msg = #"{"id":\#(i)}"#.data(using: .utf8)! + try await transport.send(msg) + } + } + try await group.waitForAll() + } + + let sentData = mockConnection.getSentData() + #expect(sentData.count == 100) + + await transport.disconnect() + } + + @Test("Send error resumes continuation immediately") + func testSendErrorResumesContinuation() async throws { + let mockConnection = MockNetworkConnection() + let transport = NetworkTransport( + mockConnection, + heartbeatConfig: .disabled, + reconnectionConfig: .disabled + ) + + try await transport.connect() + + // Inject a connection-lost error for the next send + mockConnection.setSendError(NWError.posix(POSIXErrorCode(rawValue: 57)!)) + + // send() must throw without hanging — continuation is resumed directly + // in the NWConnection callback, not deferred via Task { @MainActor in } + do { + try await transport.send(#"{"test":"error"}"#.data(using: .utf8)!) + Issue.record("Expected send to throw on error") + } catch { + #expect(error is MCPError) + } + + await transport.disconnect() + } + + @Test("Receive error resumes continuation without crash") + func testReceiveErrorResumesContinuation() async throws { + let mockConnection = MockNetworkConnection() + let transport = NetworkTransport( + mockConnection, + heartbeatConfig: .disabled, + reconnectionConfig: .disabled + ) + + try await transport.connect() + + // Inject receive error after connection is established + mockConnection.simulateFailure(error: NWError.posix(POSIXErrorCode.ECONNRESET)) + + let stream = await transport.receive() + var receivedError = false + + do { + for try await _ in stream { + break + } + } catch { + receivedError = true + } + + #expect(receivedError, "Expected receive stream to surface the error") + + await transport.disconnect() + } + + @Test("Concurrent sends with intermittent errors") + func testConcurrentSendsWithErrors() async throws { + let mockConnection = MockNetworkConnection() + let transport = NetworkTransport( + mockConnection, + heartbeatConfig: .disabled, + reconnectionConfig: .disabled + ) + + try await transport.connect() + + // Mix of successful and failing sends should not crash or deadlock. + // This tests that continuation resume is always called exactly once. + var successCount = 0 + var errorCount = 0 + + for i in 0..<20 { + if i == 10 { + mockConnection.setSendError(NWError.posix(POSIXErrorCode(rawValue: 57)!)) + } else if i == 11 { + mockConnection.clearError() + } + + do { + try await transport.send(#"{"id":\#(i)}"#.data(using: .utf8)!) + successCount += 1 + } catch { + errorCount += 1 + } + } + + #expect(successCount > 0) + #expect(errorCount > 0) + + await transport.disconnect() + } } #endif