Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 38 additions & 61 deletions Sources/MCP/Base/Transports/NetworkTransport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -511,9 +511,6 @@ import Logging
var messageWithNewline = message
messageWithNewline.append(UInt8(ascii: "\n"))

// Use a local actor-isolated variable to track continuation state
var sendContinuationResumed = false

try await withCheckedThrowingContinuation {
[weak self] (continuation: CheckedContinuation<Void, Swift.Error>) in
guard let self = self else {
Expand All @@ -526,54 +523,43 @@ import Logging
contentContext: .defaultMessage,
isComplete: true,
completion: .contentProcessed { [weak self] error in
guard let self = self else { return }

Task { @MainActor in
if !sendContinuationResumed {
sendContinuationResumed = true
if let error = error {
self.logger.error("Send error: \(error)")

// Check if we should attempt to reconnect on send failure
let isStopping = await self.isStopping // Await actor-isolated property
if !isStopping && self.reconnectionConfig.enabled {
let isConnected = await self.isConnected
if isConnected {
if error.isConnectionLost {
self.logger.warning(
"Connection appears broken, will attempt to reconnect..."
)

// Schedule connection restart
Task { [weak self] in // Operate on self's executor
guard let self = self else { return }

await self.setIsConnected(false)

try? await Task.sleep(for: .milliseconds(500))

let currentIsStopping = await self.isStopping
if !currentIsStopping {
// Cancel the connection, then attempt to reconnect fully.
self.connection.cancel()
try? await self.connect()
}
}
}
}
}

continuation.resume(
throwing: MCPError.internalError("Send error: \(error)"))
} else {
continuation.resume()
if let error = error {
continuation.resume(
throwing: MCPError.internalError("Send error: \(error)"))
// Handle reconnection on the actor's executor
if let self {
Task {
await self.handleSendError(error)
}
}
} else {
continuation.resume()
}
})
}
}

/// Handles reconnection logic after a send error.
///
/// This method is actor-isolated, so it safely accesses `isStopping`,
/// `isConnected`, and `reconnectionConfig` without data races.
private func handleSendError(_ error: NWError) async {
logger.error("Send error: \(error)")
guard !isStopping && reconnectionConfig.enabled && isConnected else { return }
guard error.isConnectionLost else { return }

logger.warning("Connection appears broken, will attempt to reconnect...")
setIsConnected(false)
connection.cancel()

guard !isStopping else { return }
do {
try await connect()
} catch {
logger.error("Reconnection failed: \(error)")
}
}

/// Receives data in an async sequence
///
/// This returns an AsyncThrowingStream that emits Data objects representing
Expand Down Expand Up @@ -747,8 +733,6 @@ import Logging
/// - Returns: The received data chunk
/// - Throws: Network errors or transport failures
private func receiveData() async throws -> Data {
var receiveContinuationResumed = false

return try await withCheckedThrowingContinuation {
[weak self] (continuation: CheckedContinuation<Data, Swift.Error>) in
guard let self = self else {
Expand All @@ -759,21 +743,14 @@ import Logging
let maxLength = bufferConfig.maxReceiveBufferSize ?? Int.max
connection.receive(minimumIncompleteLength: 1, maximumLength: maxLength) {
content, _, isComplete, error in
Task { @MainActor in
if !receiveContinuationResumed {
receiveContinuationResumed = true
if let error = error {
continuation.resume(throwing: MCPError.transportError(error))
} else if let content = content {
continuation.resume(returning: content)
} else if isComplete {
self.logger.trace("Connection completed by peer")
continuation.resume(throwing: MCPError.connectionClosed)
} else {
// EOF: Resume with empty data instead of throwing an error
continuation.resume(returning: Data())
}
}
if let error = error {
continuation.resume(throwing: MCPError.transportError(error))
} else if let content = content {
continuation.resume(returning: content)
} else if isComplete {
continuation.resume(throwing: MCPError.connectionClosed)
} else {
continuation.resume(returning: Data())
}
}
}
Expand Down
133 changes: 133 additions & 0 deletions Tests/MCPTests/NetworkTransportTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,18 @@ import Testing
queueDataForReceiving(data)
}

/// Set a send-only error that does not change connection state.
/// Unlike `simulateFailure`, this keeps the connection in `.ready` state
/// so that `send()` proceeds to the NWConnection callback path.
func setSendError(_ error: NWError?) {
mockError = error
}

/// Clear any injected error without changing connection state.
func clearError() {
mockError = nil
}

/// Get all sent data
func getSentData() -> [Data] {
return sentData
Expand Down Expand Up @@ -661,5 +673,126 @@ import Testing
// Verify connection is cleaned up
#expect(weakConnection == nil, "Connection was not properly cleaned up")
}
@Test("Concurrent sends do not cause data races")
func testConcurrentSendsNoCrash() async throws {
let mockConnection = MockNetworkConnection()
let transport = NetworkTransport(
mockConnection,
heartbeatConfig: .disabled,
reconnectionConfig: .disabled
)

try await transport.connect()

// Fire many concurrent sends to surface any data race in continuation handling.
// Before the fix, mutable `var` flags captured by @Sendable closures could race.
try await withThrowingTaskGroup(of: Void.self) { group in
for i in 0..<100 {
group.addTask {
let msg = #"{"id":\#(i)}"#.data(using: .utf8)!
try await transport.send(msg)
}
}
try await group.waitForAll()
}

let sentData = mockConnection.getSentData()
#expect(sentData.count == 100)

await transport.disconnect()
}

@Test("Send error resumes continuation immediately")
func testSendErrorResumesContinuation() async throws {
let mockConnection = MockNetworkConnection()
let transport = NetworkTransport(
mockConnection,
heartbeatConfig: .disabled,
reconnectionConfig: .disabled
)

try await transport.connect()

// Inject a connection-lost error for the next send
mockConnection.setSendError(NWError.posix(POSIXErrorCode(rawValue: 57)!))

// send() must throw without hanging — continuation is resumed directly
// in the NWConnection callback, not deferred via Task { @MainActor in }
do {
try await transport.send(#"{"test":"error"}"#.data(using: .utf8)!)
Issue.record("Expected send to throw on error")
} catch {
#expect(error is MCPError)
}

await transport.disconnect()
}

@Test("Receive error resumes continuation without crash")
func testReceiveErrorResumesContinuation() async throws {
let mockConnection = MockNetworkConnection()
let transport = NetworkTransport(
mockConnection,
heartbeatConfig: .disabled,
reconnectionConfig: .disabled
)

try await transport.connect()

// Inject receive error after connection is established
mockConnection.simulateFailure(error: NWError.posix(POSIXErrorCode.ECONNRESET))

let stream = await transport.receive()
var receivedError = false

do {
for try await _ in stream {
break
}
} catch {
receivedError = true
}

#expect(receivedError, "Expected receive stream to surface the error")

await transport.disconnect()
}

@Test("Concurrent sends with intermittent errors")
func testConcurrentSendsWithErrors() async throws {
let mockConnection = MockNetworkConnection()
let transport = NetworkTransport(
mockConnection,
heartbeatConfig: .disabled,
reconnectionConfig: .disabled
)

try await transport.connect()

// Mix of successful and failing sends should not crash or deadlock.
// This tests that continuation resume is always called exactly once.
var successCount = 0
var errorCount = 0

for i in 0..<20 {
if i == 10 {
mockConnection.setSendError(NWError.posix(POSIXErrorCode(rawValue: 57)!))
} else if i == 11 {
mockConnection.clearError()
}

do {
try await transport.send(#"{"id":\#(i)}"#.data(using: .utf8)!)
successCount += 1
} catch {
errorCount += 1
}
}

#expect(successCount > 0)
#expect(errorCount > 0)

await transport.disconnect()
}
}
#endif
Loading