diff --git a/Sources/MCP/Base/Transports/NetworkTransport.swift b/Sources/MCP/Base/Transports/NetworkTransport.swift index 62b623c8..e18247a1 100644 --- a/Sources/MCP/Base/Transports/NetworkTransport.swift +++ b/Sources/MCP/Base/Transports/NetworkTransport.swift @@ -511,8 +511,10 @@ import Logging var messageWithNewline = message messageWithNewline.append(UInt8(ascii: "\n")) - // Use a local actor-isolated variable to track continuation state - var sendContinuationResumed = false + // Use a local variable to track continuation state + // This is safe because the completion handler should only be called once, + // and the continuation can only be resumed once regardless + nonisolated(unsafe) var sendContinuationResumed = false try await withCheckedThrowingContinuation { [weak self] (continuation: CheckedContinuation) in @@ -747,7 +749,10 @@ import Logging /// - Returns: The received data chunk /// - Throws: Network errors or transport failures private func receiveData() async throws -> Data { - var receiveContinuationResumed = false + // Use a local variable to track continuation state + // This is safe because the receive completion handler should only be called once, + // and the continuation can only be resumed once regardless + nonisolated(unsafe) var receiveContinuationResumed = false return try await withCheckedThrowingContinuation { [weak self] (continuation: CheckedContinuation) in diff --git a/Sources/MCP/Base/Transports/UnixSocketClientTransport.swift b/Sources/MCP/Base/Transports/UnixSocketClientTransport.swift new file mode 100644 index 00000000..74c2b79f --- /dev/null +++ b/Sources/MCP/Base/Transports/UnixSocketClientTransport.swift @@ -0,0 +1,305 @@ +import Foundation +import Logging + +#if canImport(System) + import System +#else + @preconcurrency import SystemPackage +#endif + +#if canImport(Darwin) + import Darwin.POSIX +#elseif canImport(Glibc) + import Glibc +#elseif canImport(Musl) + import Musl +#endif + +#if canImport(Darwin) || canImport(Glibc) || canImport(Musl) + /// Unix domain socket transport for MCP clients. + /// + /// Connects to an existing Unix domain socket and provides + /// communication for local connections. + /// + /// The transport uses newline-delimited messages and supports reconnection cycles. + /// + /// ## Usage + /// + /// ```swift + /// let transport = UnixSocketClientTransport( + /// path: "/tmp/mcp.sock" + /// ) + /// + /// // Start MCP server + /// try await server.start(transport: transport) + /// + /// ``` + /// + /// ## When to Use + /// + /// Use this transport when local only commincation is prefered. + /// + public actor UnixSocketClientTransport: Transport { + /// Maximum socket path length in bytes + /// + /// - SeeAlso: https://github.com/torvalds/linux/blob/master/include/uapi/linux/un.h#L7 + /// - SeeAlso: https://github.com/apple-oss-distributions/xnu/blob/main/bsd/sys/un.h#L79 + /// - SeeAlso: https://github.com/kraj/musl/blob/kraj/master/include/sys/un.h#L19 + /// + public static let socketPathMax: Int = + MemoryLayout.size(ofValue: sockaddr_un().sun_path) - 1 + + public nonisolated let logger: Logger + + // MARK: - State + private var terminated = false + private var started = false + + /// MARK: - Socket + private var socketDescriptor: FileDescriptor? + private let socketPath: String + + // MARK: - ASync + private var isConnected = false + private var messageStream: AsyncThrowingStream? + private var messageContinuation: + AsyncThrowingStream.Continuation? + + private var readLoopTask: Task? + + /// Creates a new Unix socket client transport + /// + /// - Parameters: + /// - path: File system path for the Unix socket + /// - logger: Optional logger instance + public init(path: String, logger: Logger? = nil) { + self.socketPath = path + self.logger = + logger + ?? Logger( + label: "mcp.transport.unix-socket.client", + factory: { _ in SwiftLogNoOpLogHandler() }) + + // TODO: verify closure + var continuation: AsyncThrowingStream.Continuation! + + self.messageStream = AsyncThrowingStream { continuation = $0 } + self.messageContinuation = continuation + + } + + // MARK: `Transport` comformance + + /// Connects to the Unix socket + /// + /// This method can be called multiple times to support reconnection cycles. + /// Each call recreates the message stream. + /// + /// - Throws: `MCPError.transportError` if connection fails + public func connect() async throws { + guard !isConnected else { return } + isConnected = true + + try validateSocketPath() + + let sockfd = socket(AF_UNIX, SOCK_STREAM, 0) + guard sockfd >= 0 else { + throw MCPError.transportError(Errno(rawValue: CInt(errno))) + } + + let fd = FileDescriptor(rawValue: sockfd) + + var addr = sockaddr_un() + addr.sun_family = sa_family_t(AF_UNIX) + let pathBytes = socketPath.utf8CString + _ = withUnsafeMutablePointer(to: &addr.sun_path) { ptr in + pathBytes.withUnsafeBufferPointer { buffer in + memcpy( + ptr, buffer.baseAddress, + min( + buffer.count, + UnixSocketClientTransport.socketPathMax)) + } + } + + let connectResult = withUnsafePointer(to: &addr) { addrPtr in + addrPtr.withMemoryRebound(to: sockaddr.self, capacity: 1) { + sockaddrPtr in + #if canImport(Darwin) + Darwin.connect( + sockfd, sockaddrPtr, + socklen_t(MemoryLayout.size)) + #elseif canImport(Glibc) + Glibc.connect( + sockfd, sockaddrPtr, + socklen_t(MemoryLayout.size)) + #else + Musl.connect( + sockfd, sockaddrPtr, + socklen_t(MemoryLayout.size)) + #endif + } + } + + guard connectResult >= 0 else { + try fd.close() + throw MCPError.transportError(Errno(rawValue: CInt(errno))) + } + + try setNonBlocking(fd) + self.socketDescriptor = fd + + // Create new stream for this connection (supports reconnection) + var continuation: AsyncThrowingStream.Continuation! + self.messageStream = AsyncThrowingStream { continuation = $0 } + self.messageContinuation = continuation + + // isConnected = true + readLoopTask = Task { await readLoop() } + logger.debug("Connected", metadata: ["path": "\(socketPath)"]) + } + + /// Disconnects from the Unix socket + public func disconnect() async { + guard isConnected else { return } + isConnected = false + + readLoopTask?.cancel() + await readLoopTask?.value + readLoopTask = nil + + if let socket = socketDescriptor { + try? socket.close() + socketDescriptor = nil + } + + messageContinuation?.finish() + messageContinuation = nil + messageStream = nil + + logger.debug("Disconnected") + } + + /// Sends data to the server + /// + /// - Parameter data: Data to send (newline will be appended automatically) + /// - Throws: `MCPError.transportError` if not connected or write fails + public func send(_ data: Data) async throws { + guard isConnected, let socket = socketDescriptor else { + throw MCPError.transportError(Errno(rawValue: ENOTCONN)) + } + + var messageWithNewline = data + messageWithNewline.append(UInt8(ascii: "\n")) + + var remaining = messageWithNewline + while !remaining.isEmpty { + do { + let written = try remaining.withUnsafeBytes { buffer in + try socket.write(UnsafeRawBufferPointer(buffer)) + } + if written > 0 { + remaining = remaining.dropFirst(written) + } + } catch let error + where MCPError.isResourceTemporarilyUnavailable(error) + { + try await Task.sleep(for: .milliseconds(10)) + continue + } catch { + throw MCPError.transportError(error) + } + } + } + + /// Receives data from the server + /// + /// Returns a stream of newline-delimited messages. + /// + /// - Returns: Async stream of received data + public func receive() -> AsyncThrowingStream { + guard let stream = messageStream else { + return AsyncThrowingStream { $0.finish() } + } + return stream + } + + // MARK: - Private Implementation + + private func readLoop() async { + let bufferSize = 4096 + var buffer = [UInt8](repeating: 0, count: bufferSize) + var pendingData = Data() + + guard let descriptor = socketDescriptor, + let continuation = messageContinuation + else { + return + } + + while isConnected && !Task.isCancelled { + do { + let bytesRead = try buffer.withUnsafeMutableBufferPointer { + pointer in + try descriptor.read( + into: UnsafeMutableRawBufferPointer(pointer) + ) + } + + if bytesRead == 0 { + logger.notice("Server closed connection") + break + } + + pendingData.append(Data(buffer[..= 0 else { + throw MCPError.transportError(Errno(rawValue: CInt(errno))) + } + + let result = fcntl(fd.rawValue, F_SETFL, flags | O_NONBLOCK) + guard result >= 0 else { + throw MCPError.transportError(Errno(rawValue: CInt(errno))) + } + } + } +#endif diff --git a/Sources/MCP/Base/Transports/UnixSocketServerTransport.swift b/Sources/MCP/Base/Transports/UnixSocketServerTransport.swift new file mode 100644 index 00000000..2f5c61ea --- /dev/null +++ b/Sources/MCP/Base/Transports/UnixSocketServerTransport.swift @@ -0,0 +1,485 @@ +import Foundation +import Logging +import NIOCore +import NIOPosix + +#if canImport(Darwin) + import Darwin.POSIX +#elseif canImport(Glibc) + import Glibc +#elseif canImport(Musl) + import Musl +#endif + +#if canImport(Darwin) || canImport(Glibc) || canImport(Musl) + /// Unix domain socket transport for MCP servers using SwiftNIO. + /// + /// Creates a Unix domain socket, binds to it, and accepts multiple client connections sequentially. + /// The transport uses newline-delimited messages and handles reconnections automatically. + /// + /// ## Usage + /// + /// ```swift + /// let transport = UnixSocketServerTransport( + /// path: "/tmp/mcp.sock", + /// cleanup: .removeExisting + /// ) + /// try await transport.connect() // Starts accepting clients + /// + /// // Use with MCP server + /// try await server.start(transport: transport) + /// ``` + /// + /// ## When to Use + /// + /// Use this transport when you need: + /// - Local-only communication (same machine) + /// - High-performance IPC + /// - File system permission-based security + /// - Multiple sequential client connections + /// + public actor UnixSocketServerTransport: Transport { + /// Maximum socket path length in bytes + /// + /// - SeeAlso: https://github.com/torvalds/linux/blob/master/include/uapi/linux/un.h#L7 + /// - SeeAlso: https://github.com/apple-oss-distributions/xnu/blob/main/bsd/sys/un.h#L79 + /// - SeeAlso: https://github.com/kraj/musl/blob/kraj/master/include/sys/un.h#L19 + /// + public static let socketPathMax: Int = + MemoryLayout.size(ofValue: sockaddr_un().sun_path) - 1 + + /// Strategy for handling existing socket files + public enum SocketCleanup: Sendable { + /// Fail if socket file exists + case failIfExists + /// Remove existing socket file before binding + case removeExisting + /// Try to reuse if socket is stale, otherwise fail + case reuseIfPossible + } + + public nonisolated let logger: Logger + + // MARK: - Configuration + private let socketPath: String + private let cleanup: SocketCleanup + + // MARK: - State + private var terminated = false + private var started = false + + // MARK: - NIO Components + private var eventLoopGroup: MultiThreadedEventLoopGroup? + private var serverChannel: Channel? + + // MARK: - Async Streams + private var messageStream: AsyncThrowingStream? + private var messageContinuation: + AsyncThrowingStream.Continuation? + + // MARK: - Current Client + private var currentChannel: Channel? + + // MARK: - Response Waiters + /// Maps request ID → continuation waiting for a response. + /// When the server calls `send()` with a response, the matching continuation is resumed. + private var responseWaiters: [String: CheckedContinuation] = [:] + + // MARK: - Init + + /// Creates a new Unix socket server transport + /// + /// - Parameters: + /// - path: File system path for the Unix socket + /// - cleanup: Strategy for handling existing socket files + /// - logger: Optional logger instance + public init(path: String, cleanup: SocketCleanup, logger: Logger? = nil) { + self.socketPath = path + self.cleanup = cleanup + self.logger = + logger + ?? Logger( + label: "mcp.transport.unix-socket.server", + factory: { _ in SwiftLogNoOpLogHandler() }) + } + + // MARK: - Transport Conformance + + /// Starts the server (creates socket, binds, listens, accepts clients) + /// + /// This method starts accepting clients continuously until disconnect() is called. + /// + /// - Throws: `MCPError.transportError` if socket creation fails + public func connect() async throws { + guard !started else { + // Idempotent: already started + return + } + guard !terminated else { + throw MCPError.connectionClosed + } + + try validateSocketPath() + try handleCleanup() + + // Create event loop group + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + self.eventLoopGroup = group + + // Create message stream + var continuation: AsyncThrowingStream.Continuation! + self.messageStream = AsyncThrowingStream { continuation = $0 } + self.messageContinuation = continuation + + // Create server bootstrap + let bootstrap = ServerBootstrap(group: group) + .serverChannelOption( + ChannelOptions.socketOption(.so_reuseaddr), value: 1 + ) + .childChannelInitializer { [weak self] channel in + guard let self = self else { + return channel.eventLoop.makeSucceededVoidFuture() + } + return channel.pipeline.addHandlers([ + ByteToMessageHandler(NewlineFrameDecoder()), + MessageToByteHandler(NewlineFrameEncoder()), + UnixSocketServerHandler(transport: self), + ]) + } + + do { + let channel = try await bootstrap.bind( + unixDomainSocketPath: socketPath + ).get() + self.serverChannel = channel + started = true + + logger.info("Server listening", metadata: ["path": "\(socketPath)"]) + } catch { + try? await group.shutdownGracefully() + self.eventLoopGroup = nil + throw MCPError.transportError(error) + } + } + + /// Stops the server and cleans up the socket file + public func disconnect() async { + await terminate() + } + + /// Sends data to the current client or routes to a waiting continuation. + /// + /// - Responses are matched by JSON-RPC ID to waiting continuations. + /// - If no waiter exists, the response is sent directly to the client. + /// - Notifications and requests are always sent directly to the client. + /// + /// - Parameter data: Data to send + /// - Throws: `MCPError.transportError` if not connected or write fails + public func send(_ data: Data) async throws { + guard !terminated else { + throw MCPError.connectionClosed + } + + // Classify the message for routing + if let kind = JSONRPCMessageKind(data: data) { + switch kind { + case .response(let id): + // Check if there's a waiter for this response + if let continuation = responseWaiters.removeValue( + forKey: id) + { + continuation.resume(returning: data) + return + } + // No waiter, fall through to send to socket + + case .notification, .request: + // Always send to socket + break + } + } + + // Send to client via socket + guard let channel = currentChannel else { + throw MCPError.transportError( + NSError( + domain: "mcp.unix-socket", code: Int(ENOTCONN), + userInfo: [ + NSLocalizedDescriptionKey: + "No client connected" + ])) + } + + var buffer = channel.allocator.buffer(capacity: data.count) + buffer.writeBytes(data) + + try await channel.writeAndFlush(buffer) + } + + /// Receives data from clients + /// + /// Returns a stream of newline-delimited messages. + /// + /// - Returns: Async stream of received data + public func receive() -> AsyncThrowingStream { + guard let stream = messageStream else { + return AsyncThrowingStream { $0.finish() } + } + return stream + } + + // MARK: - Internal Methods (called from handler) + + /// JSON-RPC notification sent when a new client connects. + /// MCP Server can handle this to reset initialization state. + public static let newConnectionNotification = Data( + #"{"jsonrpc":"2.0","method":"$/connection/didOpen"}"#.utf8) + + func handleNewClient(_ channel: Channel) { + self.currentChannel = channel + // Signal new connection to the MCP Server so it can reset state + messageContinuation?.yield(Self.newConnectionNotification) + logger.info("Client connected") + } + + func handleClientDisconnected() { + self.currentChannel = nil + // Note: We do NOT finish the message stream here. + // The stream lives for the lifetime of the server, not per client. + // This allows the MCP server layer to keep receiving from the same stream + // as clients connect and disconnect. + logger.info("Client disconnected") + } + + func handleIncomingData(_ data: Data) { + // Filter out empty messages + guard !data.isEmpty else { return } + messageContinuation?.yield(data) + } + + // MARK: - Private Implementation + + private func validateSocketPath() throws { + guard socketPath.utf8.count < UnixSocketServerTransport.socketPathMax else { + throw MCPError.internalError( + "Socket path too long: \(socketPath.utf8.count) bytes (max: \(UnixSocketServerTransport.socketPathMax))" + ) + } + } + + private func handleCleanup() throws { + switch cleanup { + case .failIfExists: + if access(socketPath, F_OK) == 0 { + throw MCPError.transportError( + NSError( + domain: "mcp.unix-socket", + code: Int(EADDRINUSE), + userInfo: [ + NSLocalizedDescriptionKey: + "Socket already exists: \(socketPath)" + ])) + } + case .removeExisting: + if access(socketPath, F_OK) == 0 { + unlink(socketPath) + } + case .reuseIfPossible: + if access(socketPath, F_OK) == 0 { + // Try to connect to see if socket is alive + let testResult = testSocketConnection() + if testResult { + throw MCPError.transportError( + NSError( + domain: "mcp.unix-socket", + code: Int(EADDRINUSE), + userInfo: [ + NSLocalizedDescriptionKey: + "Socket is in use: \(socketPath)" + ])) + } else { + // Stale socket, remove it + unlink(socketPath) + } + } + } + } + + private func testSocketConnection() -> Bool { + let testSock = socket(AF_UNIX, SOCK_STREAM, 0) + guard testSock >= 0 else { return false } + + defer { + #if canImport(Darwin) + Darwin.close(testSock) + #elseif canImport(Glibc) + Glibc.close(testSock) + #else + Musl.close(testSock) + #endif + } + + var addr = sockaddr_un() + addr.sun_family = sa_family_t(AF_UNIX) + let pathBytes = socketPath.utf8CString + _ = withUnsafeMutablePointer(to: &addr.sun_path) { ptr in + pathBytes.withUnsafeBufferPointer { buffer in + memcpy( + ptr, buffer.baseAddress, + min( + buffer.count, + UnixSocketServerTransport.socketPathMax)) + } + } + + let result = withUnsafePointer(to: &addr) { addrPtr in + addrPtr.withMemoryRebound(to: sockaddr.self, capacity: 1) { + sockaddrPtr in + #if canImport(Darwin) + Darwin.connect( + testSock, sockaddrPtr, + socklen_t(MemoryLayout.size)) + #elseif canImport(Glibc) + Glibc.connect( + testSock, sockaddrPtr, + socklen_t(MemoryLayout.size)) + #else + Musl.connect( + testSock, sockaddrPtr, + socklen_t(MemoryLayout.size)) + #endif + } + } + + return result >= 0 + } + + private func terminate() async { + guard !terminated else { return } + terminated = true + started = false + + // Cancel all waiting continuations + for (id, continuation) in responseWaiters { + continuation.resume(throwing: MCPError.connectionClosed) + logger.debug( + "Cancelled waiter for request", + metadata: ["requestID": "\(id)"]) + } + responseWaiters.removeAll() + + // Close server channel + if let channel = serverChannel { + try? await channel.close() + self.serverChannel = nil + } + + // Shutdown event loop group + if let group = eventLoopGroup { + try? await group.shutdownGracefully() + self.eventLoopGroup = nil + } + + // Clean up socket file + unlink(socketPath) + + messageContinuation?.finish() + messageContinuation = nil + messageStream = nil + currentChannel = nil + + logger.info("Server stopped", metadata: ["path": "\(socketPath)"]) + } + } + + // MARK: - NIO Channel Handlers + + /// Decodes newline-delimited frames + private final class NewlineFrameDecoder: ByteToMessageDecoder, @unchecked Sendable { + typealias InboundOut = ByteBuffer + + func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws + -> DecodingState + { + guard + let newlineIndex = buffer.readableBytesView.firstIndex( + of: UInt8(ascii: "\n")) + else { + return .needMoreData + } + + let length = newlineIndex - buffer.readerIndex + guard let frame = buffer.readSlice(length: length) else { + return .needMoreData + } + + // Skip the newline + buffer.moveReaderIndex(forwardBy: 1) + + context.fireChannelRead(self.wrapInboundOut(frame)) + return .continue + } + + func decodeLast( + context: ChannelHandlerContext, buffer: inout ByteBuffer, seenEOF: Bool + ) throws -> DecodingState { + // Process any remaining data + if buffer.readableBytes > 0 { + let frame = buffer.readSlice(length: buffer.readableBytes)! + context.fireChannelRead(self.wrapInboundOut(frame)) + } + return .needMoreData + } + } + + /// Encodes frames with newline delimiter + private final class NewlineFrameEncoder: MessageToByteEncoder, @unchecked Sendable { + typealias OutboundIn = ByteBuffer + + func encode(data: ByteBuffer, out: inout ByteBuffer) throws { + out.writeImmutableBuffer(data) + out.writeInteger(UInt8(ascii: "\n")) + } + } + + /// Handles client connections and data + private final class UnixSocketServerHandler: ChannelInboundHandler, @unchecked Sendable { + typealias InboundIn = ByteBuffer + + private let transport: UnixSocketServerTransport + + init(transport: UnixSocketServerTransport) { + self.transport = transport + } + + func channelActive(context: ChannelHandlerContext) { + let channel = context.channel + Task.detached { + await self.transport.handleNewClient(channel) + } + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + var buffer = self.unwrapInboundIn(data) + guard let bytes = buffer.readBytes(length: buffer.readableBytes) else { + return + } + + let data = Data(bytes) + Task.detached { + await self.transport.handleIncomingData(data) + } + } + + func channelInactive(context: ChannelHandlerContext) { + Task.detached { + await self.transport.handleClientDisconnected() + } + } + + func errorCaught(context: ChannelHandlerContext, error: Error) { + // Log error and close channel + context.close(promise: nil) + } + } +#endif diff --git a/Sources/MCP/Base/Transports/UnixSocketTransport.swift b/Sources/MCP/Base/Transports/UnixSocketTransport.swift new file mode 100644 index 00000000..940e3d9f --- /dev/null +++ b/Sources/MCP/Base/Transports/UnixSocketTransport.swift @@ -0,0 +1,554 @@ +import Foundation +import Logging + +#if canImport(System) + import System +#else + @preconcurrency import SystemPackage +#endif + +// Import for specific low-level operations not yet in Swift System +#if canImport(Darwin) + import Darwin.POSIX +#elseif canImport(Glibc) + import Glibc +#elseif canImport(Musl) + import Musl +#endif + +#if canImport(Darwin) || canImport(Glibc) || canImport(Musl) + /// An implementation of Unix domain socket transport for MCP. + /// + /// Unix domain sockets provide high-performance inter-process communication (IPC) + /// for processes on the same machine. They offer better performance than TCP/IP + /// for local communication and use file system permissions for security. + /// + /// This transport supports both client and server modes: + /// - **Client mode**: Connects to an existing Unix socket + /// - **Server mode**: Creates a Unix socket, binds to it, and accepts a single client connection + /// + /// The transport uses newline-delimited messages, matching the stdio transport protocol. + /// + /// - Important: This transport is available on Apple platforms and Linux distributions with glibc + /// (Ubuntu, Debian, Fedora, CentOS, RHEL). + /// + /// ## Example Usage (Client) + /// + /// ```swift + /// import MCP + /// + /// // Connect to a server socket + /// let transport = UnixSocketTransport( + /// path: "/tmp/mcp.sock", + /// mode: .client + /// ) + /// try await transport.connect() + /// ``` + /// + /// ## Example Usage (Server) + /// + /// ```swift + /// import MCP + /// + /// // Create a server socket + /// let transport = UnixSocketTransport( + /// path: "/tmp/mcp.sock", + /// mode: .server(cleanup: .removeExisting) + /// ) + /// try await transport.connect() + /// ``` + public actor UnixSocketTransport: Transport { + + #if canImport(Darwin) + /// Ref: https://github.com/apple-oss-distributions/xnu/blob/main/bsd/sys/un.h#L79 + public static let socketPathMax: Int = MemoryLayout.size(ofValue: sockaddr_un().sun_path) - 1 + #elseif canImport(Glibc) + /// Ref: https://github.com/torvalds/linux/blob/master/include/uapi/linux/un.h#L7 + public static let socketPathMax: Int = MemoryLayout.size(ofValue: sockaddr_un().sun_path) + #elseif canImport(Musl) + /// Ref: https://github.com/torvalds/linux/blob/master/include/uapi/linux/un.h#L7 + public static let socketPathMax: Int = MemoryLayout.size(ofValue: sockaddr_un().sun_path) + #endif + + /// Mode of operation (client or server) + public enum Mode: Sendable { + /// Client mode: connects to existing Unix socket + case client + + /// Server mode: creates and binds to Unix socket + /// - Parameter cleanup: Strategy for handling existing socket files + case server(cleanup: SocketCleanup) + } + + /// Strategy for handling existing socket files in server mode + public enum SocketCleanup: Sendable { + /// Fail if socket file exists + case failIfExists + + /// Remove existing socket file before binding + case removeExisting + + /// Try to reuse if socket is still alive, otherwise remove + case reuseIfPossible + } + + private let socketPath: String + private let mode: Mode + /// Socket descriptor for listening (server) or connection (client) + private var socketDescriptor: FileDescriptor? + /// Client connection descriptor (server mode only) + private var clientDescriptor: FileDescriptor? + + /// Logger instance for transport-related events + public nonisolated let logger: Logger + + private var isConnected = false + private let messageStream: AsyncThrowingStream + private let messageContinuation: AsyncThrowingStream.Continuation + + /// Creates a new Unix socket transport + /// + /// - Parameters: + /// - path: File system path for the Unix socket + /// - mode: Operation mode (client or server) + /// - logger: Optional logger instance for transport events + public init(path: String, mode: Mode, logger: Logger? = nil) { + self.socketPath = path + self.mode = mode + self.logger = + logger + ?? Logger( + label: "mcp.transport.unix-socket", + factory: { _ in SwiftLogNoOpLogHandler() }) + + // Create message stream + var continuation: AsyncThrowingStream.Continuation! + self.messageStream = AsyncThrowingStream { continuation = $0 } + self.messageContinuation = continuation + } + + /// Establishes connection with the transport + /// + /// For client mode, this connects to the existing Unix socket. + /// For server mode, this creates the socket, binds to it, and waits for a client connection. + /// + /// - Throws: Error if the connection cannot be established + public func connect() async throws { + guard !isConnected else { return } + + switch mode { + case .client: + try await connectClient() + case .server(let cleanup): + try await connectServer(cleanup: cleanup) + } + + isConnected = true + logger.debug("Transport connected successfully", metadata: ["path": "\(socketPath)"]) + + // Start reading loop in background + Task { + await readLoop() + } + } + + /// Validates that the socket path length is within the sockaddr_un limit + /// + /// - Throws: `MCPError.internalError` if the path exceeds platform specific `socketPathMax` bytes + private func validateSocketPath() throws { + guard socketPath.utf8.count < UnixSocketTransport.socketPathMax else { + throw MCPError.internalError( + "Socket path too long: \(socketPath.utf8.count) bytes") + } + } + + /// Connects to an existing Unix socket (client mode) + private func connectClient() async throws { + try validateSocketPath() + + // Create socket + #if canImport(Darwin) || canImport(Glibc) || canImport(Musl) + let sockfd = socket(AF_UNIX, SOCK_STREAM, 0) + guard sockfd >= 0 else { + throw MCPError.transportError(Errno(rawValue: CInt(errno))) + } + + let fd = FileDescriptor(rawValue: sockfd) + + // Build socket address + var addr = sockaddr_un() + addr.sun_family = sa_family_t(AF_UNIX) + + let pathBytes = socketPath.utf8CString + _ = withUnsafeMutablePointer(to: &addr.sun_path) { ptr in + pathBytes.withUnsafeBufferPointer { buffer in + memcpy(ptr, buffer.baseAddress, min(buffer.count, UnixSocketTransport.socketPathMax)) + } + } + + // Connect to socket + let connectResult = withUnsafePointer(to: &addr) { addrPtr in + addrPtr.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockaddrPtr in + #if canImport(Darwin) + Darwin.connect(sockfd, sockaddrPtr, socklen_t(MemoryLayout.size)) + #elseif canImport(Glibc) + Glibc.connect(sockfd, sockaddrPtr, socklen_t(MemoryLayout.size)) + #else + Musl.connect(sockfd, sockaddrPtr, socklen_t(MemoryLayout.size)) + #endif + } + } + + guard connectResult >= 0 else { + try fd.close() + throw MCPError.transportError(Errno(rawValue: CInt(errno))) + } + + // Set non-blocking mode + try setNonBlocking(fileDescriptor: fd) + + self.socketDescriptor = fd + logger.debug("Client connected to Unix socket", metadata: ["path": "\(socketPath)"]) + #else + throw MCPError.internalError("Unix sockets not supported on this platform") + #endif + } + + /// Creates a Unix socket and accepts a client connection (server mode) + private func connectServer(cleanup: SocketCleanup) async throws { + try validateSocketPath() + + #if canImport(Darwin) || canImport(Glibc) || canImport(Musl) + // Handle cleanup strategy + switch cleanup { + case .failIfExists: + // Check if file exists + if access(socketPath, F_OK) == 0 { + throw MCPError.transportError( + NSError( + domain: "mcp.transport.unix_socket", + code: Int(EADDRINUSE), + userInfo: [ + NSLocalizedDescriptionKey: + "Socket file already exists: \(socketPath)" + ] + )) + } + case .removeExisting: + // Always remove if exists + if access(socketPath, F_OK) == 0 { + unlink(socketPath) + logger.debug( + "Removed existing socket file", metadata: ["path": "\(socketPath)"]) + } + case .reuseIfPossible: + // Try to connect - if it fails, remove the stale socket + if access(socketPath, F_OK) == 0 { + let testSock = socket(AF_UNIX, SOCK_STREAM, 0) + if testSock >= 0 { + var testAddr = sockaddr_un() + testAddr.sun_family = sa_family_t(AF_UNIX) + let pathBytes = socketPath.utf8CString + _ = withUnsafeMutablePointer(to: &testAddr.sun_path) { ptr in + pathBytes.withUnsafeBufferPointer { buffer in + memcpy(ptr, buffer.baseAddress, min(buffer.count, UnixSocketTransport.socketPathMax)) + } + } + + let testResult = withUnsafePointer(to: &testAddr) { addrPtr in + addrPtr.withMemoryRebound(to: sockaddr.self, capacity: 1) { + sockaddrPtr in + #if canImport(Darwin) + Darwin.connect( + testSock, sockaddrPtr, + socklen_t(MemoryLayout.size)) + #elseif canImport(Glibc) + Glibc.connect( + testSock, sockaddrPtr, + socklen_t(MemoryLayout.size)) + #else + Musl.connect( + testSock, sockaddrPtr, + socklen_t(MemoryLayout.size)) + #endif + } + } + + #if canImport(Darwin) + Darwin.close(testSock) + #elseif canImport(Glibc) + Glibc.close(testSock) + #else + Musl.close(testSock) + #endif + + if testResult < 0 { + // Socket is stale, remove it + unlink(socketPath) + logger.debug( + "Removed stale socket file", metadata: ["path": "\(socketPath)"] + ) + } else { + // Socket is alive, fail + throw MCPError.transportError( + NSError( + domain: "mcp.transport.unix_socket", + code: Int(EADDRINUSE), + userInfo: [ + NSLocalizedDescriptionKey: + "Socket is already in use: \(socketPath)" + ] + )) + } + } + } + } + + // Create socket + let sockfd = socket(AF_UNIX, SOCK_STREAM, 0) + guard sockfd >= 0 else { + throw MCPError.transportError(Errno(rawValue: CInt(errno))) + } + + let fd = FileDescriptor(rawValue: sockfd) + + // Build socket address + var addr = sockaddr_un() + addr.sun_family = sa_family_t(AF_UNIX) + + let pathBytes = socketPath.utf8CString + _ = withUnsafeMutablePointer(to: &addr.sun_path) { ptr in + pathBytes.withUnsafeBufferPointer { buffer in + memcpy(ptr, buffer.baseAddress, min(buffer.count, UnixSocketTransport.socketPathMax)) + } + } + + // Bind socket + let bindResult = withUnsafePointer(to: &addr) { addrPtr in + addrPtr.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockaddrPtr in + #if canImport(Darwin) + Darwin.bind(sockfd, sockaddrPtr, socklen_t(MemoryLayout.size)) + #elseif canImport(Glibc) + Glibc.bind(sockfd, sockaddrPtr, socklen_t(MemoryLayout.size)) + #else + Musl.bind(sockfd, sockaddrPtr, socklen_t(MemoryLayout.size)) + #endif + } + } + + guard bindResult >= 0 else { + try fd.close() + throw MCPError.transportError(Errno(rawValue: CInt(errno))) + } + + // Listen for connections + let listenResult = listen(sockfd, 1) + guard listenResult >= 0 else { + try fd.close() + unlink(socketPath) + throw MCPError.transportError(Errno(rawValue: CInt(errno))) + } + + // Set listening socket to non-blocking so accept() doesn't block + try setNonBlocking(fileDescriptor: fd) + + logger.debug("Server listening on Unix socket", metadata: ["path": "\(socketPath)"]) + + // Accept client connection (with retry loop for non-blocking) + var clientfd: Int32 = -1 + while clientfd < 0 { + clientfd = accept(sockfd, nil, nil) + if clientfd < 0 { + let error = Errno(rawValue: CInt(errno)) + if error == .resourceTemporarilyUnavailable { + // No client yet, sleep and retry + try? await Task.sleep(for: .milliseconds(10)) + continue + } else { + // Real error + try fd.close() + unlink(socketPath) + throw MCPError.transportError(error) + } + } + } + + let clientFd = FileDescriptor(rawValue: clientfd) + + // Set non-blocking mode on client descriptor + try setNonBlocking(fileDescriptor: clientFd) + + self.socketDescriptor = fd + self.clientDescriptor = clientFd + logger.debug("Server accepted client connection") + #else + throw MCPError.internalError("Unix sockets not supported on this platform") + #endif + } + + /// Configures a file descriptor for non-blocking I/O + /// + /// - Parameter fileDescriptor: The file descriptor to configure + /// - Throws: Error if the operation fails + private func setNonBlocking(fileDescriptor: FileDescriptor) throws { + #if canImport(Darwin) || canImport(Glibc) || canImport(Musl) + // Get current flags + let flags = fcntl(fileDescriptor.rawValue, F_GETFL) + guard flags >= 0 else { + throw MCPError.transportError(Errno(rawValue: CInt(errno))) + } + + // Set non-blocking flag + let result = fcntl(fileDescriptor.rawValue, F_SETFL, flags | O_NONBLOCK) + guard result >= 0 else { + throw MCPError.transportError(Errno(rawValue: CInt(errno))) + } + #else + // For platforms where non-blocking operations aren't supported + throw MCPError.internalError( + "Setting non-blocking mode not supported on this platform") + #endif + } + + /// Continuous loop that reads and processes incoming messages + /// + /// This method runs in the background while the transport is connected, + /// parsing complete messages delimited by newlines and yielding them + /// to the message stream. + private func readLoop() async { + let bufferSize = 4096 + var buffer = [UInt8](repeating: 0, count: bufferSize) + var pendingData = Data() + + // Read from client descriptor (server) or socket descriptor (client) + let readDescriptor: FileDescriptor? = + clientDescriptor != nil ? clientDescriptor : socketDescriptor + + guard let descriptor = readDescriptor else { + messageContinuation.finish() + return + } + + while isConnected && !Task.isCancelled { + do { + let bytesRead = try buffer.withUnsafeMutableBufferPointer { pointer in + try descriptor.read(into: UnsafeMutableRawBufferPointer(pointer)) + } + + if bytesRead == 0 { + logger.notice("EOF received") + break + } + + pendingData.append(Data(buffer[.. 0 { + remaining = remaining.dropFirst(written) + } + } catch let error where MCPError.isResourceTemporarilyUnavailable(error) { + try await Task.sleep(for: .milliseconds(10)) + continue + } catch { + throw MCPError.transportError(error) + } + } + } + + /// Receives messages from the transport. + /// + /// Messages may be individual JSON-RPC requests, notifications, responses, + /// or batches containing multiple requests/notifications encoded as JSON arrays. + /// Each message is guaranteed to be a complete JSON object or array. + /// + /// - Returns: An AsyncThrowingStream of Data objects representing JSON-RPC messages + public func receive() -> AsyncThrowingStream { + return messageStream + } + } +#endif diff --git a/Sources/MCP/Server/Server.swift b/Sources/MCP/Server/Server.swift index 9ad2cc31..257d35d2 100644 --- a/Sources/MCP/Server/Server.swift +++ b/Sources/MCP/Server/Server.swift @@ -836,6 +836,13 @@ public actor Server { "Processing notification", metadata: ["method": "\(message.method)"]) + // Handle transport-level connection notification (e.g., from Unix socket server) + // This resets session state to allow re-initialization for a new client + if message.method == "$/connection/didOpen" { + await resetSessionState() + return + } + if configuration.strict { // Check initialization state unless this is an initialized notification if message.method != InitializedNotification.name { @@ -954,6 +961,16 @@ public actor Server { self.isInitialized = true } + /// Resets session state to allow re-initialization. + /// Called when a new client connects to a multi-client transport (e.g., Unix socket server). + private func resetSessionState() async { + self.isInitialized = false + self.clientInfo = nil + self.clientCapabilities = nil + self.protocolVersion = nil + await self.logger?.debug("Session state reset for new client connection") + } + /// Cancel and remove a pending request task private func removePendingRequest(id: ID) -> Task, Error>? { pendingRequestTasks.removeValue(forKey: id) diff --git a/Tests/MCPTests/UnixSocketClientTransportTests.swift b/Tests/MCPTests/UnixSocketClientTransportTests.swift new file mode 100644 index 00000000..7f37f2e7 --- /dev/null +++ b/Tests/MCPTests/UnixSocketClientTransportTests.swift @@ -0,0 +1,904 @@ +import Foundation +import Testing + +@testable import MCP + +#if canImport(System) + import System +#else + @preconcurrency import SystemPackage +#endif + +#if canImport(Darwin) + import Darwin.POSIX +#elseif canImport(Glibc) + import Glibc +#elseif canImport(Musl) + import Musl +#endif + +#if canImport(Darwin) || canImport(Glibc) || canImport(Musl) + /// Simple test server for testing UnixSocketClientTransport + actor TestUnixSocketServer { + private var socketDescriptor: FileDescriptor? + private var clientDescriptor: FileDescriptor? + private let path: String + private var isRunning = false + private var messageStream: AsyncThrowingStream? + private var messageContinuation: AsyncThrowingStream.Continuation? + private var readTask: Task? + + init(path: String) { + self.path = path + } + + func start() async throws { + guard !isRunning else { return } + + // Clean up existing socket + unlink(path) + + let sockfd = socket(AF_UNIX, SOCK_STREAM, 0) + guard sockfd >= 0 else { + throw MCPError.transportError(Errno(rawValue: CInt(errno))) + } + + let fd = FileDescriptor(rawValue: sockfd) + + // Set SO_REUSEADDR + var reuseAddr: Int32 = 1 + _ = setsockopt( + sockfd, SOL_SOCKET, SO_REUSEADDR, + &reuseAddr, socklen_t(MemoryLayout.size)) + + var addr = sockaddr_un() + addr.sun_family = sa_family_t(AF_UNIX) + let pathBytes = path.utf8CString + _ = withUnsafeMutablePointer(to: &addr.sun_path) { ptr in + pathBytes.withUnsafeBufferPointer { buffer in + memcpy(ptr, buffer.baseAddress, min(buffer.count, 103)) + } + } + + let bindResult = withUnsafePointer(to: &addr) { addrPtr in + addrPtr.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockaddrPtr in + #if canImport(Darwin) + Darwin.bind( + sockfd, sockaddrPtr, + socklen_t(MemoryLayout.size)) + #elseif canImport(Glibc) + Glibc.bind( + sockfd, sockaddrPtr, + socklen_t(MemoryLayout.size)) + #else + Musl.bind( + sockfd, sockaddrPtr, + socklen_t(MemoryLayout.size)) + #endif + } + } + + guard bindResult >= 0 else { + try fd.close() + throw MCPError.transportError(Errno(rawValue: CInt(errno))) + } + + let listenResult = listen(sockfd, 1) + guard listenResult >= 0 else { + try fd.close() + unlink(path) + throw MCPError.transportError(Errno(rawValue: CInt(errno))) + } + + // Set non-blocking + let flags = fcntl(sockfd, F_GETFL) + _ = fcntl(sockfd, F_SETFL, flags | O_NONBLOCK) + + self.socketDescriptor = fd + + // Accept one client + var clientfd: Int32 = -1 + while clientfd < 0 && !Task.isCancelled { + clientfd = accept(sockfd, nil, nil) + if clientfd >= 0 { + break + } + let error = Errno(rawValue: CInt(errno)) + if error == .resourceTemporarilyUnavailable || error == .wouldBlock { + try await Task.sleep(for: .milliseconds(10)) + continue + } else { + try fd.close() + unlink(path) + throw MCPError.transportError(error) + } + } + + guard clientfd >= 0 else { + try fd.close() + unlink(path) + throw MCPError.internalError("Accept cancelled") + } + + let clientFd = FileDescriptor(rawValue: clientfd) + let clientFlags = fcntl(clientfd, F_GETFL) + _ = fcntl(clientfd, F_SETFL, clientFlags | O_NONBLOCK) + + self.clientDescriptor = clientFd + + // Create stream + var continuation: AsyncThrowingStream.Continuation! + self.messageStream = AsyncThrowingStream { continuation = $0 } + self.messageContinuation = continuation + + isRunning = true + readTask = Task { await self.readLoop() } + } + + func stop() async { + guard isRunning else { return } + isRunning = false + + readTask?.cancel() + await readTask?.value + readTask = nil + + if let client = clientDescriptor { + try? client.close() + clientDescriptor = nil + } + + if let socket = socketDescriptor { + try? socket.close() + socketDescriptor = nil + } + + unlink(path) + messageContinuation?.finish() + messageContinuation = nil + messageStream = nil + } + + func send(_ data: Data) async throws { + guard isRunning, let client = clientDescriptor else { + throw MCPError.transportError(Errno(rawValue: ENOTCONN)) + } + + var messageWithNewline = data + messageWithNewline.append(UInt8(ascii: "\n")) + + var remaining = messageWithNewline + while !remaining.isEmpty { + do { + let written = try remaining.withUnsafeBytes { buffer in + try client.write(UnsafeRawBufferPointer(buffer)) + } + if written > 0 { + remaining = remaining.dropFirst(written) + } + } catch let error where MCPError.isResourceTemporarilyUnavailable(error) { + try await Task.sleep(for: .milliseconds(10)) + continue + } catch { + throw MCPError.transportError(error) + } + } + } + + func receive() -> AsyncThrowingStream { + guard let stream = messageStream else { + return AsyncThrowingStream { $0.finish() } + } + return stream + } + + private func readLoop() async { + let bufferSize = 4096 + var buffer = [UInt8](repeating: 0, count: bufferSize) + var pendingData = Data() + + guard let descriptor = clientDescriptor, let continuation = messageContinuation + else { + return + } + + while isRunning && !Task.isCancelled { + do { + let bytesRead = try buffer.withUnsafeMutableBufferPointer { pointer in + try descriptor.read(into: UnsafeMutableRawBufferPointer(pointer)) + } + + if bytesRead == 0 { + break + } + + pendingData.append(Data(buffer[.. String { + // Use a short path to avoid exceeding socket path limits + let uuid = UUID().uuidString.replacingOccurrences(of: "-", with: "") + let shortID = String(uuid.prefix(8)) + return "/tmp/mcp-\(shortID).sock" + } + + /// Cleanup helper to remove socket file + private func cleanup(_ path: String) { + unlink(path) + } + + @Test("Client Connect to Server") + func testClientConnectToServer() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = TestUnixSocketServer(path: path) + let client = UnixSocketClientTransport(path: path) + + // Start server in background + let serverTask = Task { + try await server.start() + } + + // Give server time to start listening + try await Task.sleep(for: .milliseconds(100)) + + // Connect client + try await client.connect() + + // Wait for server to accept + _ = try await serverTask.value + + // Verify connection succeeded + #expect(true) + + await client.disconnect() + await server.stop() + } + + @Test("Client Send Message to Server") + func testClientSendMessageToServer() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = TestUnixSocketServer( + path: path) + let client = UnixSocketClientTransport(path: path) + + // Start server + let serverTask = Task { + try await server.start() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + _ = try await serverTask.value + + // Send message from client + let message = #"{"jsonrpc":"2.0","method":"ping","id":1}"# + try await client.send(message.data(using: .utf8)!) + + // Receive on server + let stream = await server.receive() + var iterator = stream.makeAsyncIterator() + + let received = try await iterator.next() + #expect(received == message.data(using: .utf8)!) + + await client.disconnect() + await server.stop() + } + + @Test("Client Receive Message from Server") + func testClientReceiveMessageFromServer() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = TestUnixSocketServer( + path: path) + let client = UnixSocketClientTransport(path: path) + + // Start server + let serverTask = Task { + try await server.start() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + _ = try await serverTask.value + + // Send message from server + let message = #"{"jsonrpc":"2.0","result":"pong","id":1}"# + try await server.send(message.data(using: .utf8)!) + + // Receive on client + let stream = await client.receive() + var iterator = stream.makeAsyncIterator() + + let received = try await iterator.next() + #expect(received == message.data(using: .utf8)!) + + await client.disconnect() + await server.stop() + } + + @Test("Client Bidirectional Communication") + func testClientBidirectionalCommunication() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = TestUnixSocketServer( + path: path) + let client = UnixSocketClientTransport(path: path) + + // Start server + let serverTask = Task { + try await server.start() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + _ = try await serverTask.value + + // Client sends request + let request = #"{"jsonrpc":"2.0","method":"ping","id":1}"# + try await client.send(request.data(using: .utf8)!) + + // Server receives and sends response + let serverStream = await server.receive() + var serverIterator = serverStream.makeAsyncIterator() + + let receivedRequest = try await serverIterator.next() + #expect(receivedRequest == request.data(using: .utf8)!) + + let response = #"{"jsonrpc":"2.0","result":"pong","id":1}"# + try await server.send(response.data(using: .utf8)!) + + // Client receives response + let clientStream = await client.receive() + var clientIterator = clientStream.makeAsyncIterator() + + let receivedResponse = try await clientIterator.next() + #expect(receivedResponse == response.data(using: .utf8)!) + + await client.disconnect() + await server.stop() + } + + @Test("Client Multiple Messages") + func testClientMultipleMessages() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = TestUnixSocketServer( + path: path) + let client = UnixSocketClientTransport(path: path) + + // Start server + let serverTask = Task { + try await server.start() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + _ = try await serverTask.value + + // Send multiple messages + let messages = [ + #"{"jsonrpc":"2.0","method":"test1","id":1}"#, + #"{"jsonrpc":"2.0","method":"test2","id":2}"#, + #"{"jsonrpc":"2.0","method":"test3","id":3}"#, + ] + + for message in messages { + try await client.send(message.data(using: .utf8)!) + } + + // Receive all messages + let stream = await server.receive() + var iterator = stream.makeAsyncIterator() + + for expectedMessage in messages { + let received = try await iterator.next() + #expect(received == expectedMessage.data(using: .utf8)!) + } + + await client.disconnect() + await server.stop() + } + + @Test("Client Large Message") + func testClientLargeMessage() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = TestUnixSocketServer( + path: path) + let client = UnixSocketClientTransport(path: path) + + // Start server + let serverTask = Task { + try await server.start() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + _ = try await serverTask.value + + // Create a large message (20KB) + let largeData = String(repeating: "x", count: 20000) + let message = + #"{"jsonrpc":"2.0","method":"test","params":{"data":"\#(largeData)"},"id":1}"# + + try await client.send(message.data(using: .utf8)!) + + // Receive on server + let stream = await server.receive() + var iterator = stream.makeAsyncIterator() + + let received = try await iterator.next() + #expect(received == message.data(using: .utf8)!) + + await client.disconnect() + await server.stop() + } + + @Test("Client Invalid Socket Path (Too Long)") + func testClientInvalidSocketPathTooLong() async throws { + // Create a path that exceeds platform limits + let longPath = "/tmp/" + String(repeating: "x", count: 200) + ".sock" + let client = UnixSocketClientTransport(path: longPath) + + do { + try await client.connect() + #expect(Bool(false), "Expected connect to throw an error") + } catch { + #expect(error is MCPError) + if case .internalError(let msg) = error as? MCPError { + #expect(msg?.contains("Socket path too long") == true) + } + } + + await client.disconnect() + } + + @Test("Client Connection Failure (No Server)") + func testClientConnectionFailureNoServer() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let client = UnixSocketClientTransport(path: path) + + do { + try await client.connect() + #expect(Bool(false), "Expected connect to throw an error") + } catch { + // Expected to fail - no server listening + #expect(error is MCPError) + } + + await client.disconnect() + } + + @Test("Client Send Error (Disconnected)") + func testClientSendErrorDisconnected() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = TestUnixSocketServer( + path: path) + let client = UnixSocketClientTransport(path: path) + + // Connect and disconnect + let serverTask = Task { + try await server.start() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + _ = try await serverTask.value + await client.disconnect() + + // Try to send after disconnect - should fail + do { + try await client.send("test".data(using: .utf8)!) + #expect(Bool(false), "Expected send to throw an error") + } catch { + #expect(error is MCPError) + } + + await server.stop() + } + + @Test("Client Connection Lifecycle") + func testClientConnectionLifecycle() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = TestUnixSocketServer( + path: path) + let client = UnixSocketClientTransport(path: path) + + // Connect + let serverTask = Task { + try await server.start() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + _ = try await serverTask.value + + // Send a message + try await client.send(#"{"test":"data"}"#.data(using: .utf8)!) + + // Disconnect + await client.disconnect() + await server.stop() + + // Verify we got here without errors + #expect(true) + } + + @Test("Client Reconnection") + func testClientReconnection() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + // First connection + let server1 = TestUnixSocketServer( + path: path) + let client = UnixSocketClientTransport(path: path) + + let serverTask1 = Task { + try await server1.start() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + try await serverTask1.value + + // Send a message + let message1 = #"{"test":"first"}"# + try await client.send(message1.data(using: .utf8)!) + + // Disconnect + await client.disconnect() + await server1.stop() + + // Wait a bit + try await Task.sleep(for: .milliseconds(100)) + + // Second connection (reconnect) + let server2 = TestUnixSocketServer( + path: path) + + let serverTask2 = Task { + try await server2.start() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + _ = try await serverTask2.value + + // Send another message + let message2 = #"{"test":"second"}"# + try await client.send(message2.data(using: .utf8)!) + + // Receive on server + let stream = await server2.receive() + var iterator = stream.makeAsyncIterator() + + let received = try await iterator.next() + #expect(received == message2.data(using: .utf8)!) + + await client.disconnect() + await server2.stop() + } + + @Test("Client Multiple Connect Calls Are Idempotent") + func testClientMultipleConnectCallsIdempotent() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = TestUnixSocketServer( + path: path) + let client = UnixSocketClientTransport(path: path) + + // Start server + let serverTask = Task { + try await server.start() + } + + try await Task.sleep(for: .milliseconds(100)) + + // Connect multiple times - should be idempotent + try await client.connect() + try await client.connect() // Should return early + try await client.connect() // Should return early + + _ = try await serverTask.value + + await client.disconnect() + await server.stop() + } + + @Test("Client Disconnect During Receive") + func testClientDisconnectDuringReceive() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = TestUnixSocketServer( + path: path) + let client = UnixSocketClientTransport(path: path) + + // Start server + let serverTask = Task { + try await server.start() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + _ = try await serverTask.value + + // Start a task to receive messages + let receiveTask = Task { + var count = 0 + for try await _ in await client.receive() { + count += 1 + if count > 10 { + // Prevent infinite loop in test + break + } + } + } + + // Let the receive loop start + try await Task.sleep(for: .milliseconds(100)) + + // Disconnect while receiving + await client.disconnect() + + // Wait for the receive task to complete + _ = await receiveTask.result + + await server.stop() + } + + @Test("Client Partial Message Reception") + func testClientPartialMessageReception() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = TestUnixSocketServer( + path: path) + let client = UnixSocketClientTransport(path: path) + + // Start server + let serverTask = Task { + try await server.start() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + _ = try await serverTask.value + + // Server sends a complete message + let message = #"{"jsonrpc":"2.0","result":"ok","id":1}"# + try await server.send(message.data(using: .utf8)!) + + // Client receives the message (readLoop handles partial reads internally) + let stream = await client.receive() + var iterator = stream.makeAsyncIterator() + + let received = try await iterator.next() + #expect(received == message.data(using: .utf8)!) + + await client.disconnect() + await server.stop() + } + + @Test("Client Receive After Disconnect Returns Empty Stream") + func testClientReceiveAfterDisconnect() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = TestUnixSocketServer( + path: path) + let client = UnixSocketClientTransport(path: path) + + // Connect + let serverTask = Task { + try await server.start() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + _ = try await serverTask.value + + // Disconnect before receiving + await client.disconnect() + + // Create receive stream after disconnect + let messages = await client.receive() + var messageCount = 0 + + for try await _ in messages { + messageCount += 1 + } + + // Stream should complete immediately + #expect(messageCount == 0) + + await server.stop() + } + + @Test("Client Server Close Detection") + func testClientServerCloseDetection() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = TestUnixSocketServer( + path: path) + let client = UnixSocketClientTransport(path: path) + + // Start server + let serverTask = Task { + try await server.start() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + _ = try await serverTask.value + + // Start receiving on client + let receiveTask = Task { + var messageCount = 0 + for try await _ in await client.receive() { + messageCount += 1 + } + return messageCount + } + + // Give time for receive loop to start + try await Task.sleep(for: .milliseconds(50)) + + // Server closes connection + await server.stop() + + // Client should detect the close + let count = try await receiveTask.value + #expect(count == 0) + + await client.disconnect() + } + + @Test("Client Socket Path Max Constant") + func testClientSocketPathMaxConstant() { + // Verify the constant is reasonable + #if canImport(Darwin) + #expect(UnixSocketClientTransport.socketPathMax == 103) + #elseif canImport(Glibc) || canImport(Musl) + #expect(UnixSocketClientTransport.socketPathMax >= 107) + #endif + } + + @Test("Client Non-Blocking Socket") + func testClientNonBlockingSocket() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = TestUnixSocketServer( + path: path) + let client = UnixSocketClientTransport(path: path) + + // Start server + let serverTask = Task { + try await server.start() + } + + try await Task.sleep(for: .milliseconds(100)) + + // Client connects with non-blocking socket + try await client.connect() + _ = try await serverTask.value + + // Send multiple messages rapidly + for i in 0..<10 { + try await client.send(#"{"id":\#(i)}"#.data(using: .utf8)!) + } + + // Verify non-blocking operation succeeded + #expect(true) + + await client.disconnect() + await server.stop() + } + + @Test("Client Empty Message Handling") + func testClientEmptyMessageHandling() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = TestUnixSocketServer( + path: path) + let client = UnixSocketClientTransport(path: path) + + // Start server + let serverTask = Task { + try await server.start() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + _ = try await serverTask.value + + // Send empty data + try await client.send(Data()) + + // Server should not receive anything (empty line is filtered out) + let stream = await server.receive() + var iterator = stream.makeAsyncIterator() + + // Send a real message to unblock the iterator + try await Task.sleep(for: .milliseconds(50)) + let message = #"{"test":"real"}"# + try await client.send(message.data(using: .utf8)!) + + let received = try await iterator.next() + #expect(received == message.data(using: .utf8)!) + + await client.disconnect() + await server.stop() + } + + @Test("Client Multiple Disconnects Are Safe") + func testClientMultipleDisconnectsSafe() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = TestUnixSocketServer( + path: path) + let client = UnixSocketClientTransport(path: path) + + // Connect + let serverTask = Task { + try await server.start() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + _ = try await serverTask.value + + // Multiple disconnects should be safe + await client.disconnect() + await client.disconnect() // Should be safe + await client.disconnect() // Should be safe + + await server.stop() + } + } +#endif diff --git a/Tests/MCPTests/UnixSocketServerTransportTests.swift b/Tests/MCPTests/UnixSocketServerTransportTests.swift new file mode 100644 index 00000000..780df8e3 --- /dev/null +++ b/Tests/MCPTests/UnixSocketServerTransportTests.swift @@ -0,0 +1,706 @@ +import Foundation +import Testing + +@testable import MCP + +#if canImport(System) + import System +#else + @preconcurrency import SystemPackage +#endif + +#if canImport(Darwin) + import Darwin.POSIX +#elseif canImport(Glibc) + import Glibc +#elseif canImport(Musl) + import Musl +#endif + +#if canImport(Darwin) || canImport(Glibc) || canImport(Musl) + @Suite("Unix Socket Server Transport Tests") + struct UnixSocketServerTransportTests { + /// Generates a unique temporary socket path + private func tempSocketPath() -> String { + // Use a short path to avoid exceeding socket path limits + let uuid = UUID().uuidString.replacingOccurrences(of: "-", with: "") + let shortID = String(uuid.prefix(8)) + return "/tmp/srv-\(shortID).sock" + } + + /// Cleanup helper to remove socket file + private func cleanup(_ path: String) { + unlink(path) + } + + /// Checks if data is a connection notification (should be filtered out in tests) + private func isConnectionNotification(_ data: Data) -> Bool { + data == UnixSocketServerTransport.newConnectionNotification + } + + @Test("Server Accept Client Connection") + func testServerAcceptClientConnection() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = UnixSocketServerTransport( + path: path, cleanup: .removeExisting) + let client = UnixSocketClientTransport(path: path) + + // Start server in background + let serverTask = Task { + try await server.connect() + } + + // Give server time to start listening + try await Task.sleep(for: .milliseconds(100)) + + // Connect client + try await client.connect() + + // Wait for server to accept + _ = try await serverTask.value + + // Verify connection succeeded + #expect(Bool(true)) + + await client.disconnect() + await server.disconnect() + } + + @Test("Server Receive Message from Client") + func testServerReceiveMessageFromClient() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = UnixSocketServerTransport( + path: path, cleanup: .removeExisting) + let client = UnixSocketClientTransport(path: path) + + // Start server + let serverTask = Task { + try await server.connect() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + _ = try await serverTask.value + + // Send message from client + let message = #"{"jsonrpc":"2.0","method":"ping","id":1}"# + try await client.send(message.data(using: .utf8)!) + + // Receive on server (skip connection notification) + let stream = await server.receive() + var iterator = stream.makeAsyncIterator() + + var received = try await iterator.next() + if let data = received, isConnectionNotification(data) { + received = try await iterator.next() + } + #expect(received == message.data(using: .utf8)!) + + await client.disconnect() + await server.disconnect() + } + + @Test("Server Send Message to Client") + func testServerSendMessageToClient() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = UnixSocketServerTransport( + path: path, cleanup: .removeExisting) + let client = UnixSocketClientTransport(path: path) + + // Start server + let serverTask = Task { + try await server.connect() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + _ = try await serverTask.value + + // Give time for server to process the connection + try await Task.sleep(for: .milliseconds(50)) + + // Send message from server + let message = #"{"jsonrpc":"2.0","result":"pong","id":1}"# + try await server.send(message.data(using: .utf8)!) + + // Receive on client + let stream = await client.receive() + var iterator = stream.makeAsyncIterator() + + let received = try await iterator.next() + #expect(received == message.data(using: .utf8)!) + + await client.disconnect() + await server.disconnect() + } + + @Test("Server Bidirectional Communication") + func testServerBidirectionalCommunication() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = UnixSocketServerTransport( + path: path, cleanup: .removeExisting) + let client = UnixSocketClientTransport(path: path) + + // Start server + let serverTask = Task { + try await server.connect() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + _ = try await serverTask.value + + // Client sends request + let request = #"{"jsonrpc":"2.0","method":"ping","id":1}"# + try await client.send(request.data(using: .utf8)!) + + // Server receives and sends response (skip connection notification) + let serverStream = await server.receive() + var serverIterator = serverStream.makeAsyncIterator() + + var receivedRequest = try await serverIterator.next() + if let data = receivedRequest, isConnectionNotification(data) { + receivedRequest = try await serverIterator.next() + } + #expect(receivedRequest == request.data(using: .utf8)!) + + let response = #"{"jsonrpc":"2.0","result":"pong","id":1}"# + try await server.send(response.data(using: .utf8)!) + + // Client receives response + let clientStream = await client.receive() + var clientIterator = clientStream.makeAsyncIterator() + + let receivedResponse = try await clientIterator.next() + #expect(receivedResponse == response.data(using: .utf8)!) + + await client.disconnect() + await server.disconnect() + } + + @Test("Server Multiple Messages from Client") + func testServerMultipleMessagesFromClient() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = UnixSocketServerTransport( + path: path, cleanup: .removeExisting) + let client = UnixSocketClientTransport(path: path) + + // Start server + let serverTask = Task { + try await server.connect() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + _ = try await serverTask.value + + // Send multiple messages + let messages = [ + #"{"jsonrpc":"2.0","method":"test1","id":1}"#, + #"{"jsonrpc":"2.0","method":"test2","id":2}"#, + #"{"jsonrpc":"2.0","method":"test3","id":3}"#, + ] + + for message in messages { + try await client.send(message.data(using: .utf8)!) + } + + // Receive all messages (skip connection notification) + let stream = await server.receive() + var iterator = stream.makeAsyncIterator() + + // Skip connection notification + var first = try await iterator.next() + if let data = first, isConnectionNotification(data) { + first = try await iterator.next() + } + #expect(first == messages[0].data(using: .utf8)!) + + for expectedMessage in messages.dropFirst() { + let received = try await iterator.next() + #expect(received == expectedMessage.data(using: .utf8)!) + } + + await client.disconnect() + await server.disconnect() + } + + @Test("Server Large Message") + func testServerLargeMessage() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = UnixSocketServerTransport( + path: path, cleanup: .removeExisting) + let client = UnixSocketClientTransport(path: path) + + // Start server + let serverTask = Task { + try await server.connect() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + _ = try await serverTask.value + + // Create a large message (20KB) + let largeData = String(repeating: "x", count: 20000) + let message = + #"{"jsonrpc":"2.0","method":"test","params":{"data":"\#(largeData)"},"id":1}"# + + try await client.send(message.data(using: .utf8)!) + + // Receive on server (skip connection notification) + let stream = await server.receive() + var iterator = stream.makeAsyncIterator() + + var received = try await iterator.next() + if let data = received, isConnectionNotification(data) { + received = try await iterator.next() + } + #expect(received == message.data(using: .utf8)!) + + await client.disconnect() + await server.disconnect() + } + + @Test("Server Invalid Socket Path (Too Long)") + func testServerInvalidSocketPathTooLong() async throws { + // Create a path that exceeds platform limits + let longPath = "/tmp/" + String(repeating: "x", count: 200) + ".sock" + let server = UnixSocketServerTransport(path: longPath, cleanup: .removeExisting) + + do { + try await server.connect() + #expect(Bool(false), "Expected connect to throw an error") + } catch { + #expect(error is MCPError) + if case .internalError(let msg) = error as? MCPError { + #expect(msg?.contains("Socket path too long") == true) + } + } + + await server.disconnect() + } + + @Test("Server Socket Cleanup - Remove Existing") + func testServerSocketCleanupRemoveExisting() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + // Create first server and client + let server1 = UnixSocketServerTransport( + path: path, cleanup: .removeExisting) + let client1 = UnixSocketClientTransport(path: path) + + let serverTask1 = Task { try await server1.connect() } + try await Task.sleep(for: .milliseconds(100)) + try await client1.connect() + _ = try await serverTask1.value + + // Disconnect both + await client1.disconnect() + await server1.disconnect() + + // Wait a bit + try await Task.sleep(for: .milliseconds(50)) + + // Create second server with removeExisting - should succeed + let server2 = UnixSocketServerTransport( + path: path, cleanup: .removeExisting) + let client2 = UnixSocketClientTransport(path: path) + + let serverTask2 = Task { try await server2.connect() } + try await Task.sleep(for: .milliseconds(100)) + try await client2.connect() + _ = try await serverTask2.value + + await client2.disconnect() + await server2.disconnect() + } + + @Test("Server Socket Cleanup - Fail If Exists") + func testServerSocketCleanupFailIfExists() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + // Create a dummy file at the socket path + let fileData = Data("test".utf8) + #if canImport(Darwin) + try fileData.write(to: URL(fileURLWithPath: path)) + #else + let fd = open(path, O_CREAT | O_WRONLY, 0o644) + close(fd) + #endif + + // Try to create server with failIfExists - should fail + let server = UnixSocketServerTransport(path: path, cleanup: .failIfExists) + + do { + try await server.connect() + #expect(Bool(false), "Expected connect to throw an error") + } catch { + // Expected to fail + #expect(error is MCPError) + } + + await server.disconnect() + } + + @Test("Server Socket Cleanup - Reuse Stale Socket") + func testServerSocketCleanupReuseStale() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + // Create a stale socket file (not a real socket) + let fileData = Data("stale".utf8) + #if canImport(Darwin) + try fileData.write(to: URL(fileURLWithPath: path)) + #else + let fd = open(path, O_CREAT | O_WRONLY, 0o644) + close(fd) + #endif + + // Server with reuseIfPossible should remove stale file and succeed + let server = UnixSocketServerTransport( + path: path, cleanup: .reuseIfPossible) + let client = UnixSocketClientTransport(path: path) + + let serverTask = Task { try await server.connect() } + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + _ = try await serverTask.value + + await client.disconnect() + await server.disconnect() + } + + @Test("Server Send Error (Disconnected)") + func testServerSendErrorDisconnected() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = UnixSocketServerTransport( + path: path, cleanup: .removeExisting) + let client = UnixSocketClientTransport(path: path) + + // Connect and disconnect + let serverTask = Task { + try await server.connect() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + _ = try await serverTask.value + await server.disconnect() + + // Try to send after disconnect - should fail + do { + try await server.send("test".data(using: .utf8)!) + #expect(Bool(false), "Expected send to throw an error") + } catch { + #expect(error is MCPError) + } + + await client.disconnect() + } + + @Test("Server Connection Lifecycle") + func testServerConnectionLifecycle() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = UnixSocketServerTransport( + path: path, cleanup: .removeExisting) + let client = UnixSocketClientTransport(path: path) + + // Connect + let serverTask = Task { + try await server.connect() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + _ = try await serverTask.value + + // Send a message + try await client.send(#"{"test":"data"}"#.data(using: .utf8)!) + + // Disconnect + await server.disconnect() + await client.disconnect() + + // Verify socket file is cleaned up + #expect(access(path, F_OK) != 0) + } + + @Test("Server Accepts Multiple Sequential Clients") + func testServerAcceptsMultipleSequentialClients() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = UnixSocketServerTransport( + path: path, cleanup: .removeExisting) + + // Start server + let serverTask = Task { + try await server.connect() + } + + try await Task.sleep(for: .milliseconds(100)) + _ = try await serverTask.value + + let msg1 = #"{"jsonrpc":"2.0","method":"test1","id":1}"# + let msg2 = #"{"jsonrpc":"2.0","method":"test2","id":2}"# + let connectionNotification = UnixSocketServerTransport.newConnectionNotification + + // Collect messages in background - returns array when done (filter out connection notifications) + let receiveTask = Task { () -> [Data] in + var messages: [Data] = [] + for try await data in await server.receive() { + // Skip connection notifications + if data != connectionNotification { + messages.append(data) + } + if messages.count >= 2 { + break + } + } + return messages + } + + // First client connects, sends message, disconnects + let client1 = UnixSocketClientTransport(path: path) + try await client1.connect() + try await Task.sleep(for: .milliseconds(50)) + try await client1.send(msg1.data(using: .utf8)!) + try await Task.sleep(for: .milliseconds(50)) + await client1.disconnect() + + // Wait a bit for server to process disconnect + try await Task.sleep(for: .milliseconds(100)) + + // Second client connects and sends message + let client2 = UnixSocketClientTransport(path: path) + try await client2.connect() + try await Task.sleep(for: .milliseconds(50)) + try await client2.send(msg2.data(using: .utf8)!) + + // Wait for messages to be received + let receivedMessages = try await receiveTask.value + + // Verify both messages were received + #expect(receivedMessages.count == 2) + #expect(receivedMessages[0] == msg1.data(using: .utf8)!) + #expect(receivedMessages[1] == msg2.data(using: .utf8)!) + + await client2.disconnect() + await server.disconnect() + } + + @Test("Server Three Sequential Clients With Requests") + func testServerThreeSequentialClientsWithRequests() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = UnixSocketServerTransport( + path: path, cleanup: .removeExisting) + + // Start server + try await server.connect() + + // Messages for each client (simulating initialize requests) + let requests = [ + #"{"jsonrpc":"2.0","method":"initialize","id":1,"params":{"protocolVersion":"2024-11-05"}}"#, + #"{"jsonrpc":"2.0","method":"initialize","id":1,"params":{"protocolVersion":"2024-11-05"}}"#, + #"{"jsonrpc":"2.0","method":"initialize","id":1,"params":{"protocolVersion":"2024-11-05"}}"#, + ] + + let connectionNotification = UnixSocketServerTransport.newConnectionNotification + + // Collect messages in background (filter out connection notifications) + let receiveTask = Task { () -> [Data] in + var messages: [Data] = [] + for try await data in await server.receive() { + // Skip connection notifications + if data != connectionNotification { + messages.append(data) + } + if messages.count >= 3 { + break + } + } + return messages + } + + // Process 3 sequential clients + for (index, request) in requests.enumerated() { + let client = UnixSocketClientTransport(path: path) + try await client.connect() + try await Task.sleep(for: .milliseconds(50)) + + // Send request + try await client.send(request.data(using: .utf8)!) + + // Simulate server sending response (no error) + let response = #"{"jsonrpc":"2.0","result":{"protocolVersion":"2024-11-05"},"id":1}"# + try await Task.sleep(for: .milliseconds(50)) + try await server.send(response.data(using: .utf8)!) + + // Client receives response + let clientStream = await client.receive() + var iterator = clientStream.makeAsyncIterator() + let receivedResponse = try await iterator.next() + + // Verify response has no error + #expect(receivedResponse != nil) + let responseString = String(data: receivedResponse!, encoding: .utf8)! + #expect(!responseString.contains("\"error\""), "Client \(index + 1) should not receive error") + #expect(responseString.contains("\"result\""), "Client \(index + 1) should receive result") + + await client.disconnect() + try await Task.sleep(for: .milliseconds(100)) + } + + // Verify all 3 requests were received + let receivedMessages = try await receiveTask.value + #expect(receivedMessages.count == 3) + + // Verify each request is valid JSON-RPC (no errors in what we received) + for (index, msg) in receivedMessages.enumerated() { + let msgString = String(data: msg, encoding: .utf8)! + #expect(msgString.contains("\"method\":\"initialize\""), "Message \(index + 1) should be initialize request") + #expect(!msgString.contains("\"error\""), "Message \(index + 1) should not contain error") + } + + await server.disconnect() + } + + @Test("Server Socket Path Max Constant") + func testServerSocketPathMaxConstant() { + // Verify the constant is reasonable + #if canImport(Darwin) + #expect(UnixSocketServerTransport.socketPathMax == 103) + #elseif canImport(Glibc) || canImport(Musl) + #expect(UnixSocketServerTransport.socketPathMax >= 107) + #endif + } + + @Test("Server Non-Blocking Socket") + func testServerNonBlockingSocket() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = UnixSocketServerTransport( + path: path, cleanup: .removeExisting) + let client = UnixSocketClientTransport(path: path) + + // Start server + let serverTask = Task { + try await server.connect() + } + + try await Task.sleep(for: .milliseconds(100)) + + // Client connects with non-blocking socket + try await client.connect() + _ = try await serverTask.value + + // Send multiple messages rapidly + for i in 0..<10 { + try await client.send(#"{"id":\#(i)}"#.data(using: .utf8)!) + } + + // Verify non-blocking operation succeeded + #expect(Bool(true)) + + await client.disconnect() + await server.disconnect() + } + + @Test("Server Multiple Connect Calls Are Idempotent") + func testServerMultipleConnectCallsIdempotent() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = UnixSocketServerTransport( + path: path, cleanup: .removeExisting) + let client = UnixSocketClientTransport(path: path) + + // Start server + let serverTask = Task { + try await server.connect() + } + + try await Task.sleep(for: .milliseconds(100)) + + // Connect client + try await client.connect() + _ = try await serverTask.value + + // Multiple connect calls should be idempotent (return early) + try await server.connect() + try await server.connect() + + await client.disconnect() + await server.disconnect() + } + + @Test("Server Empty Message Handling") + func testServerEmptyMessageHandling() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = UnixSocketServerTransport( + path: path, cleanup: .removeExisting) + let client = UnixSocketClientTransport(path: path) + + // Start server + let serverTask = Task { + try await server.connect() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + _ = try await serverTask.value + + // Send empty data + try await client.send(Data()) + + // Server should not receive anything (empty line is filtered out) + let stream = await server.receive() + var iterator = stream.makeAsyncIterator() + + // First message will be connection notification, skip it + var received = try await iterator.next() + if let data = received, isConnectionNotification(data) { + // Send a real message to unblock the iterator + try await Task.sleep(for: .milliseconds(50)) + let message = #"{"test":"real"}"# + try await client.send(message.data(using: .utf8)!) + + received = try await iterator.next() + #expect(received == message.data(using: .utf8)!) + } else { + // No connection notification, proceed as before + try await Task.sleep(for: .milliseconds(50)) + let message = #"{"test":"real"}"# + try await client.send(message.data(using: .utf8)!) + #expect(received == message.data(using: .utf8)!) + } + + await client.disconnect() + await server.disconnect() + } + } +#endif diff --git a/Tests/MCPTests/UnixSocketTransportTests.swift b/Tests/MCPTests/UnixSocketTransportTests.swift new file mode 100644 index 00000000..6ed8e0b7 --- /dev/null +++ b/Tests/MCPTests/UnixSocketTransportTests.swift @@ -0,0 +1,440 @@ +import Foundation +import Testing + +@testable import MCP + +#if canImport(System) + import System +#else + @preconcurrency import SystemPackage +#endif + +#if canImport(Darwin) + import Darwin.POSIX +#elseif canImport(Glibc) + import Glibc +#elseif canImport(Musl) + import Musl +#endif + +@Suite("Unix Socket Transport Tests") +struct UnixSocketTransportTests { + /// Generates a unique temporary socket path + private func tempSocketPath() -> String { + #if canImport(Darwin) + return NSTemporaryDirectory() + "mcp-test-\(UUID().uuidString).sock" + #else + return "/tmp/mcp-test-\(UUID().uuidString).sock" + #endif + } + + /// Cleanup helper to remove socket file + private func cleanup(_ path: String) { + unlink(path) + } + + @Test("Client-Server Connection") + func testClientServerConnection() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = UnixSocketTransport( + path: path, mode: .server(cleanup: .removeExisting)) + let client = UnixSocketTransport(path: path, mode: .client) + + // Start server in background + let serverTask = Task { + try await server.connect() + } + + // Give server time to start listening + try await Task.sleep(for: .milliseconds(100)) + + // Connect client + try await client.connect() + + // Wait for server to accept + try await serverTask.value + + // Verify both are connected + #expect(true) // If we got here, connection succeeded + + await server.disconnect() + await client.disconnect() + } + + @Test("Send Message Client to Server") + func testSendMessageClientToServer() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = UnixSocketTransport( + path: path, mode: .server(cleanup: .removeExisting)) + let client = UnixSocketTransport(path: path, mode: .client) + + // Start server + let serverTask = Task { + try await server.connect() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + try await serverTask.value + + // Send message from client + let message = #"{"jsonrpc":"2.0","method":"test","id":1}"# + try await client.send(message.data(using: .utf8)!) + + // Receive on server + let stream = await server.receive() + var iterator = stream.makeAsyncIterator() + + let received = try await iterator.next() + #expect(received == message.data(using: .utf8)!) + + await server.disconnect() + await client.disconnect() + } + + @Test("Send Message Server to Client") + func testSendMessageServerToClient() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = UnixSocketTransport( + path: path, mode: .server(cleanup: .removeExisting)) + let client = UnixSocketTransport(path: path, mode: .client) + + // Start server + let serverTask = Task { + try await server.connect() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + try await serverTask.value + + // Send message from server + let message = #"{"jsonrpc":"2.0","result":"ok","id":1}"# + try await server.send(message.data(using: .utf8)!) + + // Receive on client + let stream = await client.receive() + var iterator = stream.makeAsyncIterator() + + let received = try await iterator.next() + #expect(received == message.data(using: .utf8)!) + + await server.disconnect() + await client.disconnect() + } + + @Test("Bidirectional Communication") + func testBidirectionalCommunication() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = UnixSocketTransport( + path: path, mode: .server(cleanup: .removeExisting)) + let client = UnixSocketTransport(path: path, mode: .client) + + // Start server + let serverTask = Task { + try await server.connect() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + try await serverTask.value + + // Client sends request + let request = #"{"jsonrpc":"2.0","method":"ping","id":1}"# + try await client.send(request.data(using: .utf8)!) + + // Server receives and sends response + let serverStream = await server.receive() + var serverIterator = serverStream.makeAsyncIterator() + + let receivedRequest = try await serverIterator.next() + #expect(receivedRequest == request.data(using: .utf8)!) + + let response = #"{"jsonrpc":"2.0","result":"pong","id":1}"# + try await server.send(response.data(using: .utf8)!) + + // Client receives response + let clientStream = await client.receive() + var clientIterator = clientStream.makeAsyncIterator() + + let receivedResponse = try await clientIterator.next() + #expect(receivedResponse == response.data(using: .utf8)!) + + await server.disconnect() + await client.disconnect() + } + + @Test("Multiple Messages") + func testMultipleMessages() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = UnixSocketTransport( + path: path, mode: .server(cleanup: .removeExisting)) + let client = UnixSocketTransport(path: path, mode: .client) + + // Start server + let serverTask = Task { + try await server.connect() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + try await serverTask.value + + // Send multiple messages + let messages = [ + #"{"jsonrpc":"2.0","method":"test1","id":1}"#, + #"{"jsonrpc":"2.0","method":"test2","id":2}"#, + #"{"jsonrpc":"2.0","method":"test3","id":3}"#, + ] + + for message in messages { + try await client.send(message.data(using: .utf8)!) + } + + // Receive all messages + let stream = await server.receive() + var iterator = stream.makeAsyncIterator() + + for expectedMessage in messages { + let received = try await iterator.next() + #expect(received == expectedMessage.data(using: .utf8)!) + } + + await server.disconnect() + await client.disconnect() + } + + @Test("Large Message") + func testLargeMessage() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = UnixSocketTransport( + path: path, mode: .server(cleanup: .removeExisting)) + let client = UnixSocketTransport(path: path, mode: .client) + + // Start server + let serverTask = Task { + try await server.connect() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + try await serverTask.value + + // Create a large message (20KB) + let largeData = String(repeating: "x", count: 20000) + let message = + #"{"jsonrpc":"2.0","method":"test","params":{"data":"\#(largeData)"},"id":1}"# + + try await client.send(message.data(using: .utf8)!) + + // Receive on server + let stream = await server.receive() + var iterator = stream.makeAsyncIterator() + + let received = try await iterator.next() + #expect(received == message.data(using: .utf8)!) + + await server.disconnect() + await client.disconnect() + } + + @Test("Socket Cleanup - Remove Existing") + func testSocketCleanupRemoveExisting() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + // Create first server and connect with a client + let server1 = try UnixSocketTransport( + path: path, mode: .server(cleanup: .removeExisting)) + let client1 = try UnixSocketTransport(path: path, mode: .client) + + let serverTask1 = Task { try await server1.connect() } + try await Task.sleep(for: .milliseconds(100)) + try await client1.connect() + try await serverTask1.value + + // Disconnect both + await client1.disconnect() + await server1.disconnect() + + // Socket file should not exist after disconnect + #expect(access(path, F_OK) != 0) + + // Create second server with removeExisting - should succeed even if file exists + let server2 = try UnixSocketTransport( + path: path, mode: .server(cleanup: .removeExisting)) + let client2 = try UnixSocketTransport(path: path, mode: .client) + + let serverTask2 = Task { try await server2.connect() } + try await Task.sleep(for: .milliseconds(100)) + try await client2.connect() + try await serverTask2.value + + await client2.disconnect() + await server2.disconnect() + } + + @Test("Socket Cleanup - Fail If Exists") + func testSocketCleanupFailIfExists() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + // Create a dummy file at the socket path + let fileData = Data("test".utf8) + #if canImport(Darwin) + try fileData.write(to: URL(fileURLWithPath: path)) + #else + let fd = open(path, O_CREAT | O_WRONLY, 0o644) + close(fd) + #endif + + // Try to create server with failIfExists - should fail + let server = UnixSocketTransport(path: path, mode: .server(cleanup: .failIfExists)) + + do { + try await server.connect() + #expect(Bool(false), "Expected connect to throw an error") + } catch { + // Expected to fail + #expect(error is MCPError) + } + + await server.disconnect() + } + + @Test("Socket Cleanup - Reuse If Possible (Stale Socket)") + func testSocketCleanupReuseStale() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + // Create a stale socket file (not a real socket) + let fileData = Data("stale".utf8) + #if canImport(Darwin) + try fileData.write(to: URL(fileURLWithPath: path)) + #else + let fd = open(path, O_CREAT | O_WRONLY, 0o644) + close(fd) + #endif + + // Server with reuseIfPossible should remove stale file and succeed + let server = UnixSocketTransport( + path: path, mode: .server(cleanup: .reuseIfPossible)) + let client = UnixSocketTransport(path: path, mode: .client) + + let serverTask = Task { try await server.connect() } + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + try await serverTask.value + + await client.disconnect() + await server.disconnect() + } + + @Test("Invalid Socket Path (Too Long)") + func testInvalidSocketPathTooLong() async throws { + // Create a path that exceeds platform limits (Darwin: 104, Linux: 108 bytes) + let longPath = "/tmp/" + String(repeating: "x", count: 200) + ".sock" + let transport = UnixSocketTransport(path: longPath, mode: .client) + + do { + try await transport.connect() + #expect(Bool(false), "Expected connect to throw an error") + } catch { + #expect(error is MCPError) + if case .internalError(let msg) = error as? MCPError { + #expect(msg?.contains("Socket path too long") == true) + } + } + + await transport.disconnect() + } + + @Test("Client Connection Failure (No Server)") + func testClientConnectionFailureNoServer() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let client = UnixSocketTransport(path: path, mode: .client) + + do { + try await client.connect() + #expect(Bool(false), "Expected connect to throw an error") + } catch { + // Expected to fail - no server listening + #expect(error is MCPError) + } + + await client.disconnect() + } + + @Test("Send Error (Disconnected)") + func testSendErrorDisconnected() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = UnixSocketTransport( + path: path, mode: .server(cleanup: .removeExisting)) + let client = UnixSocketTransport(path: path, mode: .client) + + // Connect and disconnect + let serverTask = Task { + try await server.connect() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + try await serverTask.value + await client.disconnect() + + // Try to send after disconnect - should fail + do { + try await client.send("test".data(using: .utf8)!) + #expect(Bool(false), "Expected send to throw an error") + } catch { + #expect(error is MCPError) + } + + await server.disconnect() + } + + @Test("Connection Lifecycle") + func testConnectionLifecycle() async throws { + let path = tempSocketPath() + defer { cleanup(path) } + + let server = UnixSocketTransport( + path: path, mode: .server(cleanup: .removeExisting)) + let client = UnixSocketTransport(path: path, mode: .client) + + // Connect + let serverTask = Task { + try await server.connect() + } + + try await Task.sleep(for: .milliseconds(100)) + try await client.connect() + try await serverTask.value + + // Send a message + try await client.send(#"{"test":"data"}"#.data(using: .utf8)!) + + // Disconnect + await client.disconnect() + await server.disconnect() + + // Verify socket file is cleaned up + #expect(access(path, F_OK) != 0) + } +}