diff --git a/Sources/Subprocess/API.swift b/Sources/Subprocess/API.swift index b7788d6..6710d0a 100644 --- a/Sources/Subprocess/API.swift +++ b/Sources/Subprocess/API.swift @@ -105,9 +105,9 @@ public func run< output: try output.createPipe(), error: try error.createPipe() ) { execution, inputIO, outputIO, errorIO in - var inputIOBox: TrackedPlatformDiskIO? = consume inputIO - var outputIOBox: TrackedPlatformDiskIO? = consume outputIO - var errorIOBox: TrackedPlatformDiskIO? = consume errorIO + var inputIOBox: IOChannel? = consume inputIO + var outputIOBox: IOChannel? = consume outputIO + var errorIOBox: IOChannel? = consume errorIO // Write input, capture output and error in parallel async let stdout = try output.captureOutput(from: outputIOBox.take()) @@ -177,12 +177,12 @@ public func run( output: try output.createPipe(), error: try error.createPipe() ) { execution, inputIO, outputIO, errorIO in - var inputIOBox: TrackedPlatformDiskIO? = consume inputIO - var outputIOBox: TrackedPlatformDiskIO? = consume outputIO + var inputIOBox: IOChannel? = consume inputIO + var outputIOBox: IOChannel? = consume outputIO return try await withThrowingTaskGroup( of: Void.self, returning: Result.self ) { group in - var inputIOContainer: TrackedPlatformDiskIO? = inputIOBox.take() + var inputIOContainer: IOChannel? = inputIOBox.take() group.addTask { if let inputIO = inputIOContainer.take() { let writer = StandardInputWriter(diskIO: inputIO) @@ -253,7 +253,7 @@ public func run( } // Body runs in the same isolation - let outputSequence = AsyncBufferSequence(diskIO: outputIOBox.take()!.consumeDiskIO()) + let outputSequence = AsyncBufferSequence(diskIO: outputIOBox.take()!.consumeIOChannel()) let result = try await body(execution, outputSequence) try await group.waitForAll() return result @@ -299,13 +299,13 @@ public func run( output: try output.createPipe(), error: try error.createPipe() ) { execution, inputIO, outputIO, errorIO in - var inputIOBox: TrackedPlatformDiskIO? = consume inputIO - var errorIOBox: TrackedPlatformDiskIO? = consume errorIO + var inputIOBox: IOChannel? = consume inputIO + var errorIOBox: IOChannel? = consume errorIO return try await withThrowingTaskGroup( of: Void.self, returning: Result.self ) { group in - var inputIOContainer: TrackedPlatformDiskIO? = inputIOBox.take() + var inputIOContainer: IOChannel? = inputIOBox.take() group.addTask { if let inputIO = inputIOContainer.take() { let writer = StandardInputWriter(diskIO: inputIO) @@ -315,7 +315,7 @@ public func run( } // Body runs in the same isolation - let errorSequence = AsyncBufferSequence(diskIO: errorIOBox.take()!.consumeDiskIO()) + let errorSequence = AsyncBufferSequence(diskIO: errorIOBox.take()!.consumeIOChannel()) let result = try await body(execution, errorSequence) try await group.waitForAll() return result @@ -363,7 +363,7 @@ public func run( error: try error.createPipe() ) { execution, inputIO, outputIO, errorIO in let writer = StandardInputWriter(diskIO: inputIO!) - let outputSequence = AsyncBufferSequence(diskIO: outputIO!.consumeDiskIO()) + let outputSequence = AsyncBufferSequence(diskIO: outputIO!.consumeIOChannel()) return try await body(execution, writer, outputSequence) } } @@ -408,7 +408,7 @@ public func run( error: try error.createPipe() ) { execution, inputIO, outputIO, errorIO in let writer = StandardInputWriter(diskIO: inputIO!) - let errorSequence = AsyncBufferSequence(diskIO: errorIO!.consumeDiskIO()) + let errorSequence = AsyncBufferSequence(diskIO: errorIO!.consumeIOChannel()) return try await body(execution, writer, errorSequence) } } @@ -460,8 +460,8 @@ public func run( error: try error.createPipe() ) { execution, inputIO, outputIO, errorIO in let writer = StandardInputWriter(diskIO: inputIO!) - let outputSequence = AsyncBufferSequence(diskIO: outputIO!.consumeDiskIO()) - let errorSequence = AsyncBufferSequence(diskIO: errorIO!.consumeDiskIO()) + let outputSequence = AsyncBufferSequence(diskIO: outputIO!.consumeIOChannel()) + let errorSequence = AsyncBufferSequence(diskIO: errorIO!.consumeIOChannel()) return try await body(execution, writer, outputSequence, errorSequence) } } @@ -497,16 +497,16 @@ public func run< error: try error.createPipe() ) { (execution, inputIO, outputIO, errorIO) -> RunResult in // Write input, capture output and error in parallel - var inputIOBox: TrackedPlatformDiskIO? = consume inputIO - var outputIOBox: TrackedPlatformDiskIO? = consume outputIO - var errorIOBox: TrackedPlatformDiskIO? = consume errorIO + var inputIOBox: IOChannel? = consume inputIO + var outputIOBox: IOChannel? = consume outputIO + var errorIOBox: IOChannel? = consume errorIO return try await withThrowingTaskGroup( of: OutputCapturingState?.self, returning: RunResult.self ) { group in - var inputIOContainer: TrackedPlatformDiskIO? = inputIOBox.take() - var outputIOContainer: TrackedPlatformDiskIO? = outputIOBox.take() - var errorIOContainer: TrackedPlatformDiskIO? = errorIOBox.take() + var inputIOContainer: IOChannel? = inputIOBox.take() + var outputIOContainer: IOChannel? = outputIOBox.take() + var errorIOContainer: IOChannel? = errorIOBox.take() group.addTask { if let writeFd = inputIOContainer.take() { let writer = StandardInputWriter(diskIO: writeFd) @@ -580,8 +580,8 @@ public func run( error: try error.createPipe() ) { execution, inputIO, outputIO, errorIO in let writer = StandardInputWriter(diskIO: inputIO!) - let outputSequence = AsyncBufferSequence(diskIO: outputIO!.consumeDiskIO()) - let errorSequence = AsyncBufferSequence(diskIO: errorIO!.consumeDiskIO()) + let outputSequence = AsyncBufferSequence(diskIO: outputIO!.consumeIOChannel()) + let errorSequence = AsyncBufferSequence(diskIO: errorIO!.consumeIOChannel()) return try await body(execution, writer, outputSequence, errorSequence) } } diff --git a/Sources/Subprocess/AsyncBufferSequence.swift b/Sources/Subprocess/AsyncBufferSequence.swift index 39fb38b..0ca7e6e 100644 --- a/Sources/Subprocess/AsyncBufferSequence.swift +++ b/Sources/Subprocess/AsyncBufferSequence.swift @@ -19,14 +19,16 @@ internal import Dispatch #endif -public struct AsyncBufferSequence: AsyncSequence, Sendable { +public struct AsyncBufferSequence: AsyncSequence, @unchecked Sendable { public typealias Failure = any Swift.Error public typealias Element = Buffer - #if os(Windows) - internal typealias DiskIO = FileDescriptor - #else + #if canImport(Darwin) internal typealias DiskIO = DispatchIO + #elseif canImport(WinSDK) + internal typealias DiskIO = HANDLE + #else + internal typealias DiskIO = FileDescriptor #endif @_nonSendable @@ -47,15 +49,18 @@ public struct AsyncBufferSequence: AsyncSequence, Sendable { return self.buffer.removeFirst() } // Read more data - let data = try await self.diskIO.read( - upToLength: readBufferSize + let data = try await AsyncIO.shared.read( + from: self.diskIO, + upTo: readBufferSize ) guard let data else { // We finished reading. Close the file descriptor now - #if os(Windows) - try self.diskIO.close() + #if canImport(Darwin) + try _safelyClose(.dispatchIO(self.diskIO)) + #elseif canImport(WinSDK) + try _safelyClose(.handle(self.diskIO)) #else - self.diskIO.close() + try _safelyClose(.fileDescriptor(self.diskIO)) #endif return nil } @@ -132,17 +137,7 @@ extension AsyncBufferSequence { self.eofReached = true return nil } - #if os(Windows) - // Cast data to CodeUnit type - let result = buffer.withUnsafeBytes { ptr in - return Array( - UnsafeBufferPointer( - start: ptr.bindMemory(to: Encoding.CodeUnit.self).baseAddress!, - count: ptr.count / MemoryLayout.size - ) - ) - } - #else + #if canImport(Darwin) // Unfortunately here we _have to_ copy the bytes out because // DispatchIO (rightfully) reuses buffer, which means `buffer.data` // has the same address on all iterations, therefore we can't directly @@ -157,7 +152,13 @@ extension AsyncBufferSequence { UnsafeBufferPointer(start: ptr.baseAddress?.assumingMemoryBound(to: Encoding.CodeUnit.self), count: elementCount) ) } - + #else + // Cast data to CodeUnitg type + let result = buffer.withUnsafeBytes { ptr in + return ptr.withMemoryRebound(to: Encoding.CodeUnit.self) { codeUnitPtr in + return Array(codeUnitPtr) + } + } #endif return result.isEmpty ? nil : result } @@ -340,7 +341,7 @@ private let _pageSize: Int = { Int(_subprocess_vm_size()) }() #elseif canImport(WinSDK) -import WinSDK +@preconcurrency import WinSDK private let _pageSize: Int = { var sysInfo: SYSTEM_INFO = SYSTEM_INFO() GetSystemInfo(&sysInfo) diff --git a/Sources/Subprocess/Buffer.swift b/Sources/Subprocess/Buffer.swift index 94b8f52..292fac4 100644 --- a/Sources/Subprocess/Buffer.swift +++ b/Sources/Subprocess/Buffer.swift @@ -17,18 +17,8 @@ extension AsyncBufferSequence { /// A immutable collection of bytes public struct Buffer: Sendable { - #if os(Windows) - internal let data: [UInt8] - - internal init(data: [UInt8]) { - self.data = data - } - - internal static func createFrom(_ data: [UInt8]) -> [Buffer] { - return [.init(data: data)] - } - #else - // We need to keep the backingData alive while _ContiguousBufferView is alive + #if canImport(Darwin) + // We need to keep the backingData alive while Slice is alive internal let backingData: DispatchData internal let data: DispatchData._ContiguousBufferView @@ -45,7 +35,17 @@ extension AsyncBufferSequence { } return slices.map{ .init(data: $0, backingData: data) } } - #endif + #else + internal let data: [UInt8] + + internal init(data: [UInt8]) { + self.data = data + } + + internal static func createFrom(_ data: [UInt8]) -> [Buffer] { + return [.init(data: data)] + } + #endif // canImport(Darwin) } } @@ -92,26 +92,23 @@ extension AsyncBufferSequence.Buffer { // MARK: - Hashable, Equatable extension AsyncBufferSequence.Buffer: Equatable, Hashable { - #if os(Windows) - // Compiler generated conformances - #else + #if canImport(Darwin) public static func == (lhs: AsyncBufferSequence.Buffer, rhs: AsyncBufferSequence.Buffer) -> Bool { - return lhs.data.elementsEqual(rhs.data) + return lhs.data == rhs.data } public func hash(into hasher: inout Hasher) { - self.data.withUnsafeBytes { ptr in - hasher.combine(bytes: ptr) - } + hasher.combine(self.data) } #endif + // else Compiler generated conformances } // MARK: - DispatchData.Block #if canImport(Darwin) || canImport(Glibc) || canImport(Android) || canImport(Musl) extension DispatchData { /// Unfortunately `DispatchData.Region` is not available on Linux, hence our own wrapper - internal struct _ContiguousBufferView: @unchecked Sendable, RandomAccessCollection { + internal struct _ContiguousBufferView: @unchecked Sendable, RandomAccessCollection, Hashable { typealias Element = UInt8 internal let bytes: UnsafeBufferPointer @@ -127,6 +124,14 @@ extension DispatchData { return try body(UnsafeRawBufferPointer(self.bytes)) } + internal func hash(into hasher: inout Hasher) { + hasher.combine(bytes: UnsafeRawBufferPointer(self.bytes)) + } + + internal static func == (lhs: DispatchData._ContiguousBufferView, rhs: DispatchData._ContiguousBufferView) -> Bool { + return lhs.bytes.elementsEqual(rhs.bytes) + } + subscript(position: Int) -> UInt8 { _read { yield self.bytes[position] diff --git a/Sources/Subprocess/CMakeLists.txt b/Sources/Subprocess/CMakeLists.txt index ce78541..58bf205 100644 --- a/Sources/Subprocess/CMakeLists.txt +++ b/Sources/Subprocess/CMakeLists.txt @@ -17,6 +17,7 @@ target_sources(Subprocess PRIVATE Result.swift IO/Output.swift IO/Input.swift + IO/AsyncIO.swift Span+Subprocess.swift AsyncBufferSequence.swift API.swift diff --git a/Sources/Subprocess/Configuration.swift b/Sources/Subprocess/Configuration.swift index 5396506..5293c8d 100644 --- a/Sources/Subprocess/Configuration.swift +++ b/Sources/Subprocess/Configuration.swift @@ -24,7 +24,7 @@ import Glibc #elseif canImport(Musl) import Musl #elseif canImport(WinSDK) -import WinSDK +@preconcurrency import WinSDK #endif internal import Dispatch @@ -64,7 +64,7 @@ public struct Configuration: Sendable { output: consuming CreatedPipe, error: consuming CreatedPipe, isolation: isolated (any Actor)? = #isolation, - _ body: ((Execution, consuming TrackedPlatformDiskIO?, consuming TrackedPlatformDiskIO?, consuming TrackedPlatformDiskIO?) async throws -> Result) + _ body: ((Execution, consuming IOChannel?, consuming IOChannel?, consuming IOChannel?) async throws -> Result) ) async throws -> ExecutionResult { let spawnResults = try self.spawn( withInput: input, @@ -139,12 +139,12 @@ extension Configuration { /// Close each input individually, and throw the first error if there's multiple errors thrown @Sendable internal func safelyCloseMultiple( - inputRead: consuming TrackedFileDescriptor?, - inputWrite: consuming TrackedFileDescriptor?, - outputRead: consuming TrackedFileDescriptor?, - outputWrite: consuming TrackedFileDescriptor?, - errorRead: consuming TrackedFileDescriptor?, - errorWrite: consuming TrackedFileDescriptor? + inputRead: consuming IODescriptor?, + inputWrite: consuming IODescriptor?, + outputRead: consuming IODescriptor?, + outputWrite: consuming IODescriptor?, + errorRead: consuming IODescriptor?, + errorWrite: consuming IODescriptor? ) throws { var possibleError: (any Swift.Error)? = nil @@ -495,15 +495,15 @@ extension Configuration { /// via `SpawnResult` to perform actual reads internal struct SpawnResult: ~Copyable { let execution: Execution - var _inputWriteEnd: TrackedPlatformDiskIO? - var _outputReadEnd: TrackedPlatformDiskIO? - var _errorReadEnd: TrackedPlatformDiskIO? + var _inputWriteEnd: IOChannel? + var _outputReadEnd: IOChannel? + var _errorReadEnd: IOChannel? init( execution: Execution, - inputWriteEnd: consuming TrackedPlatformDiskIO?, - outputReadEnd: consuming TrackedPlatformDiskIO?, - errorReadEnd: consuming TrackedPlatformDiskIO? + inputWriteEnd: consuming IOChannel?, + outputReadEnd: consuming IOChannel?, + errorReadEnd: consuming IOChannel? ) { self.execution = execution self._inputWriteEnd = consume inputWriteEnd @@ -511,15 +511,15 @@ extension Configuration { self._errorReadEnd = consume errorReadEnd } - mutating func inputWriteEnd() -> TrackedPlatformDiskIO? { + mutating func inputWriteEnd() -> IOChannel? { return self._inputWriteEnd.take() } - mutating func outputReadEnd() -> TrackedPlatformDiskIO? { + mutating func outputReadEnd() -> IOChannel? { return self._outputReadEnd.take() } - mutating func errorReadEnd() -> TrackedPlatformDiskIO? { + mutating func errorReadEnd() -> IOChannel? { return self._errorReadEnd.take() } } @@ -581,36 +581,47 @@ internal enum StringOrRawBytes: Sendable, Hashable { } } -/// A wrapped `FileDescriptor` and whether it should be closed -/// automatically when done. -internal struct TrackedFileDescriptor: ~Copyable { - internal var closeWhenDone: Bool - internal let fileDescriptor: FileDescriptor - - internal init( - _ fileDescriptor: FileDescriptor, - closeWhenDone: Bool - ) { - self.fileDescriptor = fileDescriptor - self.closeWhenDone = closeWhenDone - } - - #if os(Windows) - consuming func consumeDiskIO() -> FileDescriptor { - let result = self.fileDescriptor - // Transfer the ownership out and therefor - // don't perform close on deinit - self.closeWhenDone = false - return result - } +internal enum _CloseTarget { + #if canImport(WinSDK) + case handle(HANDLE) #endif + case fileDescriptor(FileDescriptor) + case dispatchIO(DispatchIO) +} - internal mutating func safelyClose() throws { - guard self.closeWhenDone else { - return +internal func _safelyClose(_ target: _CloseTarget) throws { + switch target { + #if canImport(WinSDK) + case .handle(let handle): + /// Windows does not provide a “deregistration” API (the reverse of + /// `CreateIoCompletionPort`) for handles and it it reuses HANDLE + /// values once they are closed. Since we rely on the handle value + /// as the completion key for `CreateIoCompletionPort`, we should + /// remove the registration when the handle is closed to allow + /// new registration to proceed if the handle is reused. + AsyncIO.shared.removeRegistration(for: handle) + guard CloseHandle(handle) else { + let error = GetLastError() + // Getting `ERROR_INVALID_HANDLE` suggests that the file descriptor + // might have been closed unexpectedly. This can pose security risks + // if another part of the code inadvertently reuses the same file descriptor + // number. This problem is especially concerning on Unix systems due to POSIX’s + // guarantee of using the lowest available file descriptor number, making reuse + // more probable. We use `fatalError` upon receiving `.badFileDescriptor` + // to prevent accidentally closing a different file descriptor. + guard error != ERROR_INVALID_HANDLE else { + fatalError( + "HANDLE \(handle) is already closed" + ) + } + let subprocessError = SubprocessError( + code: .init(.asyncIOFailed("Failed to close HANDLE")), + underlyingError: .init(rawValue: error) + ) + throw subprocessError } - closeWhenDone = false - + #endif + case .fileDescriptor(let fileDescriptor): do { try fileDescriptor.close() } catch { @@ -632,6 +643,69 @@ internal struct TrackedFileDescriptor: ~Copyable { // Throw other kinds of errors to allow user to catch them throw error } + case .dispatchIO(let dispatchIO): + dispatchIO.close() + } +} + +/// `IODescriptor` wraps platform-specific `FileDescriptor`, +/// which is used to establish a connection to the standard input/output (IO) +/// system during the process of spawning a child process. Unlike `IODescriptor`, +/// the `IODescriptor` does not support data read/write operations; +/// its primary function is to facilitate the spawning of child processes +/// by providing a platform-specific file descriptor. +internal struct IODescriptor: ~Copyable { + #if canImport(WinSDK) + typealias Descriptor = HANDLE + #else + typealias Descriptor = FileDescriptor + #endif + + internal var closeWhenDone: Bool + internal let descriptor: Descriptor + + internal init( + _ descriptor: Descriptor, + closeWhenDone: Bool + ) { + self.descriptor = descriptor + self.closeWhenDone = closeWhenDone + } + + consuming func createIOChannel() -> IOChannel { + let shouldClose = self.closeWhenDone + self.closeWhenDone = false + #if canImport(Darwin) + // Transferring out the ownership of fileDescriptor means we don't have go close here + let closeFd = self.descriptor + let dispatchIO: DispatchIO = DispatchIO( + type: .stream, + fileDescriptor: self.platformDescriptor(), + queue: .global(), + cleanupHandler: { error in + // Close the file descriptor + if shouldClose { + try? closeFd.close() + } + } + ) + return IOChannel(dispatchIO, closeWhenDone: shouldClose) + #else + return IOChannel(self.descriptor, closeWhenDone: shouldClose) + #endif + } + + internal mutating func safelyClose() throws { + guard self.closeWhenDone else { + return + } + closeWhenDone = false + + #if canImport(WinSDK) + try _safelyClose(.handle(self.descriptor)) + #else + try _safelyClose(.fileDescriptor(self.descriptor)) + #endif } deinit { @@ -639,77 +713,178 @@ internal struct TrackedFileDescriptor: ~Copyable { return } - fatalError("FileDescriptor \(self.fileDescriptor.rawValue) was not closed") + fatalError("FileDescriptor \(self.descriptor) was not closed") } internal func platformDescriptor() -> PlatformFileDescriptor { - return self.fileDescriptor.platformDescriptor + #if canImport(WinSDK) + return self.descriptor + #else + return self.descriptor.platformDescriptor + #endif } } -#if !os(Windows) -/// A wrapped `DispatchIO` and whether it should be closed -/// automatically when done. -internal struct TrackedDispatchIO: ~Copyable { +internal struct IOChannel: ~Copyable, @unchecked Sendable { + #if canImport(WinSDK) + typealias Channel = HANDLE + #elseif canImport(Darwin) + typealias Channel = DispatchIO + #else + typealias Channel = FileDescriptor + #endif + internal var closeWhenDone: Bool - internal var dispatchIO: DispatchIO + internal let channel: Channel internal init( - _ dispatchIO: DispatchIO, + _ channel: Channel, closeWhenDone: Bool ) { - self.dispatchIO = dispatchIO + self.channel = channel self.closeWhenDone = closeWhenDone } - consuming func consumeDiskIO() -> DispatchIO { - let result = self.dispatchIO - // Transfer the ownership out and therefor - // don't perform close on deinit - self.closeWhenDone = false - return result - } - internal mutating func safelyClose() throws { guard self.closeWhenDone else { return } closeWhenDone = false - dispatchIO.close() - } - deinit { - guard self.closeWhenDone else { - return - } + #if canImport(WinSDK) + try _safelyClose(.handle(self.channel)) + #elseif canImport(Darwin) + try _safelyClose(.dispatchIO(self.channel)) + #else + try _safelyClose(.fileDescriptor(self.channel)) + #endif + } - fatalError("DispatchIO \(self.dispatchIO) was not closed") + internal consuming func consumeIOChannel() -> Channel { + let result = self.channel + // Transfer the ownership out and therefor + // don't perform close on deinit + self.closeWhenDone = false + return result } } -#endif internal struct CreatedPipe: ~Copyable { - internal var _readFileDescriptor: TrackedFileDescriptor? - internal var _writeFileDescriptor: TrackedFileDescriptor? + internal enum Purpose: CustomStringConvertible { + /// This pipe is used for standard input. This option maps to + /// `PIPE_ACCESS_OUTBOUND` on Windows where child only reads, + /// parent only writes. + case input + /// This pipe is used for standard output and standard error. + /// This option maps to `PIPE_ACCESS_INBOUND` on Windows where + /// child only writes, parent only reads. + case output + + var description: String { + switch self { + case .input: + return "input" + case .output: + return "output" + } + } + } + + internal var _readFileDescriptor: IODescriptor? + internal var _writeFileDescriptor: IODescriptor? internal init( - readFileDescriptor: consuming TrackedFileDescriptor?, - writeFileDescriptor: consuming TrackedFileDescriptor? + readFileDescriptor: consuming IODescriptor?, + writeFileDescriptor: consuming IODescriptor? ) { self._readFileDescriptor = readFileDescriptor self._writeFileDescriptor = writeFileDescriptor } - mutating func readFileDescriptor() -> TrackedFileDescriptor? { + mutating func readFileDescriptor() -> IODescriptor? { return self._readFileDescriptor.take() } - mutating func writeFileDescriptor() -> TrackedFileDescriptor? { + mutating func writeFileDescriptor() -> IODescriptor? { return self._writeFileDescriptor.take() } - internal init(closeWhenDone: Bool) throws { - let pipe = try FileDescriptor.ssp_pipe() + internal init(closeWhenDone: Bool, purpose: Purpose) throws { + #if canImport(WinSDK) + // On Windows, we need to create a named pipe + let pipeName = "\\\\.\\pipe\\subprocess-\(purpose)-\(Int.random(in: .min ..< .max))" + var saAttributes: SECURITY_ATTRIBUTES = SECURITY_ATTRIBUTES() + saAttributes.nLength = DWORD(MemoryLayout.size) + saAttributes.bInheritHandle = true + saAttributes.lpSecurityDescriptor = nil + + let parentEnd = pipeName.withCString( + encodedAs: UTF16.self + ) { pipeNameW in + // Use OVERLAPPED for async IO + var openMode: DWORD = DWORD(FILE_FLAG_OVERLAPPED) + switch purpose { + case .input: + openMode |= DWORD(PIPE_ACCESS_OUTBOUND) + case .output: + openMode |= DWORD(PIPE_ACCESS_INBOUND) + } + + return CreateNamedPipeW( + pipeNameW, + openMode, + DWORD(PIPE_TYPE_BYTE | PIPE_READMODE_BYTE | PIPE_WAIT), + 1, // Max instance, + DWORD(readBufferSize), + DWORD(readBufferSize), + 0, + &saAttributes + ) + } + guard let parentEnd, parentEnd != INVALID_HANDLE_VALUE else { + throw SubprocessError( + code: .init(.asyncIOFailed("CreateNamedPipeW failed")), + underlyingError: .init(rawValue: GetLastError()) + ) + } + + let childEnd = pipeName.withCString( + encodedAs: UTF16.self + ) { pipeNameW in + var targetAccess: DWORD = 0 + switch purpose { + case .input: + targetAccess = DWORD(GENERIC_READ) + case .output: + targetAccess = DWORD(GENERIC_WRITE) + } + + return CreateFileW( + pipeNameW, + targetAccess, + 0, + &saAttributes, + DWORD(OPEN_EXISTING), + DWORD(FILE_ATTRIBUTE_NORMAL | FILE_FLAG_OVERLAPPED), + nil + ) + } + guard let childEnd, childEnd != INVALID_HANDLE_VALUE else { + throw SubprocessError( + code: .init(.asyncIOFailed("CreateFileW failed")), + underlyingError: .init(rawValue: GetLastError()) + ) + } + switch purpose { + case .input: + self._readFileDescriptor = .init(childEnd, closeWhenDone: closeWhenDone) + self._writeFileDescriptor = .init(parentEnd, closeWhenDone: closeWhenDone) + case .output: + self._readFileDescriptor = .init(parentEnd, closeWhenDone: closeWhenDone) + self._writeFileDescriptor = .init(childEnd, closeWhenDone: closeWhenDone) + } + #else + let pipe = try FileDescriptor.pipe() self._readFileDescriptor = .init( pipe.readEnd, closeWhenDone: closeWhenDone @@ -718,6 +893,7 @@ internal struct CreatedPipe: ~Copyable { pipe.writeEnd, closeWhenDone: closeWhenDone ) + #endif } } diff --git a/Sources/Subprocess/Error.swift b/Sources/Subprocess/Error.swift index dde4468..5e4bd80 100644 --- a/Sources/Subprocess/Error.swift +++ b/Sources/Subprocess/Error.swift @@ -18,7 +18,7 @@ import Glibc #elseif canImport(Musl) import Musl #elseif canImport(WinSDK) -import WinSDK +@preconcurrency import WinSDK #endif /// Error thrown from Subprocess @@ -41,6 +41,7 @@ extension SubprocessError { case failedToWriteToSubprocess case failedToMonitorProcess case streamOutputExceedsLimit(Int) + case asyncIOFailed(String) // Signal case failedToSendSignal(Int32) // Windows Only @@ -67,18 +68,20 @@ extension SubprocessError { return 5 case .streamOutputExceedsLimit(_): return 6 - case .failedToSendSignal(_): + case .asyncIOFailed(_): return 7 - case .failedToTerminate: + case .failedToSendSignal(_): return 8 - case .failedToSuspend: + case .failedToTerminate: return 9 - case .failedToResume: + case .failedToSuspend: return 10 - case .failedToCreatePipe: + case .failedToResume: return 11 - case .invalidWindowsPath(_): + case .failedToCreatePipe: return 12 + case .invalidWindowsPath(_): + return 13 } } @@ -108,6 +111,8 @@ extension SubprocessError: CustomStringConvertible, CustomDebugStringConvertible return "Failed to monitor the state of child process with underlying error: \(self.underlyingError!)" case .streamOutputExceedsLimit(let limit): return "Failed to create output from current buffer because the output limit (\(limit)) was reached." + case .asyncIOFailed(let reason): + return "An error occurred within the AsyncIO subsystem: \(reason). Underlying error: \(self.underlyingError!)" case .failedToSendSignal(let signal): return "Failed to send signal \(signal) to the child process." case .failedToTerminate: diff --git a/Sources/Subprocess/Execution.swift b/Sources/Subprocess/Execution.swift index a21a170..66f8628 100644 --- a/Sources/Subprocess/Execution.swift +++ b/Sources/Subprocess/Execution.swift @@ -24,7 +24,7 @@ import Glibc #elseif canImport(Musl) import Musl #elseif canImport(WinSDK) -import WinSDK +@preconcurrency import WinSDK #endif /// An object that represents a subprocess that has been diff --git a/Sources/Subprocess/IO/AsyncIO.swift b/Sources/Subprocess/IO/AsyncIO.swift new file mode 100644 index 0000000..48fe4b3 --- /dev/null +++ b/Sources/Subprocess/IO/AsyncIO.swift @@ -0,0 +1,1058 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2025 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +#if canImport(System) +@preconcurrency import System +#else +@preconcurrency import SystemPackage +#endif + +/// Platform specific asynchronous read/write implementation + +// MARK: - Linux (epoll) +#if canImport(Glibc) || canImport(Android) || canImport(Musl) + +#if canImport(Glibc) +import Glibc +#elseif canImport(Android) +import Android +#elseif canImport(Musl) +import Musl +#endif + +import _SubprocessCShims +import Synchronization + +private typealias SignalStream = AsyncThrowingStream +private let _epollEventSize = 256 +private let _registration: Mutex< + [PlatformFileDescriptor : SignalStream.Continuation] +> = Mutex([:]) + +final class AsyncIO: Sendable { + + typealias OutputStream = AsyncThrowingStream + + private final class MonitorThreadContext { + let epollFileDescriptor: CInt + let shutdownFileDescriptor: CInt + + init( + epollFileDescriptor: CInt, + shutdownFileDescriptor: CInt + ) { + self.epollFileDescriptor = epollFileDescriptor + self.shutdownFileDescriptor = shutdownFileDescriptor + } + } + + private enum Event { + case read + case write + } + + private struct State { + let epollFileDescriptor: CInt + let shutdownFileDescriptor: CInt + let monitorThread: pthread_t + } + + static let shared: AsyncIO = AsyncIO() + + private let state: Result + + private init() { + // Create main epoll fd + let epollFileDescriptor = epoll_create1(CInt(EPOLL_CLOEXEC)) + guard epollFileDescriptor >= 0 else { + let error = SubprocessError( + code: .init(.asyncIOFailed("epoll_create1 failed")), + underlyingError: .init(rawValue: errno) + ) + self.state = .failure(error) + return + } + // Create shutdownFileDescriptor + let shutdownFileDescriptor = eventfd(0, CInt(EFD_NONBLOCK | EFD_CLOEXEC)) + guard shutdownFileDescriptor >= 0 else { + let error = SubprocessError( + code: .init(.asyncIOFailed("eventfd failed")), + underlyingError: .init(rawValue: errno) + ) + self.state = .failure(error) + return + } + + // Register shutdownFileDescriptor with epoll + var event = epoll_event( + events: EPOLLIN.rawValue, + data: epoll_data(fd: shutdownFileDescriptor) + ) + var rc = epoll_ctl( + epollFileDescriptor, + EPOLL_CTL_ADD, + shutdownFileDescriptor, + &event + ) + guard rc == 0 else { + let error = SubprocessError( + code: .init(.asyncIOFailed( + "failed to add shutdown fd \(shutdownFileDescriptor) to epoll list") + ), + underlyingError: .init(rawValue: errno) + ) + self.state = .failure(error) + return + } + + // Create thread data + let context = MonitorThreadContext( + epollFileDescriptor: epollFileDescriptor, + shutdownFileDescriptor: shutdownFileDescriptor + ) + let threadContext = Unmanaged.passRetained(context) + #if os(FreeBSD) || os(OpenBSD) + var thread: pthread_t? = nil + #else + var thread: pthread_t = pthread_t() + #endif + rc = pthread_create(&thread, nil, { args in + func reportError(_ error: SubprocessError) { + _registration.withLock { store in + for continuation in store.values { + continuation.finish(throwing: error) + } + } + } + + let unmanaged = Unmanaged.fromOpaque(args!) + let context = unmanaged.takeRetainedValue() + + var events: [epoll_event] = Array( + repeating: epoll_event(events: 0, data: epoll_data(fd: 0)), + count: _epollEventSize + ) + + // Enter the monitor loop + monitorLoop: while true { + let eventCount = epoll_wait( + context.epollFileDescriptor, + &events, + CInt(events.count), + -1 + ) + if eventCount < 0 { + if errno == EINTR || errno == EAGAIN { + continue // interrupted by signal; try again + } + // Report other errors + let error = SubprocessError( + code: .init(.asyncIOFailed( + "epoll_wait failed") + ), + underlyingError: .init(rawValue: errno) + ) + reportError(error) + break monitorLoop + } + + for index in 0 ..< Int(eventCount) { + let event = events[index] + let targetFileDescriptor = event.data.fd + // Breakout the monitor loop if we received shutdown + // from the shutdownFD + if targetFileDescriptor == context.shutdownFileDescriptor { + var buf: UInt64 = 0 + _ = _SubprocessCShims.read(context.shutdownFileDescriptor, &buf, MemoryLayout.size) + break monitorLoop + } + + // Notify the continuation + _registration.withLock { store in + if let continuation = store[targetFileDescriptor] { + continuation.yield(true) + } + } + } + } + + return nil + }, threadContext.toOpaque()) + guard rc == 0 else { + let error = SubprocessError( + code: .init(.asyncIOFailed("Failed to create monitor thread")), + underlyingError: .init(rawValue: rc) + ) + self.state = .failure(error) + return + } + + #if os(FreeBSD) || os(OpenBSD) + let monitorThread = thread! + #else + let monitorThread = thread + #endif + + let state = State( + epollFileDescriptor: epollFileDescriptor, + shutdownFileDescriptor: shutdownFileDescriptor, + monitorThread: monitorThread + ) + self.state = .success(state) + + atexit { + AsyncIO.shared.shutdown() + } + } + + private func shutdown() { + guard case .success(let currentState) = self.state else { + return + } + + var one: UInt64 = 1 + // Wake up the thread for shutdown + _ = _SubprocessCShims.write(currentState.shutdownFileDescriptor, &one, MemoryLayout.stride) + // Cleanup the monitor thread + pthread_join(currentState.monitorThread, nil) + } + + + private func registerFileDescriptor( + _ fileDescriptor: FileDescriptor, + for event: Event + ) -> SignalStream { + return SignalStream { continuation in + // If setup failed, nothing much we can do + switch self.state { + case .success(let state): + // Set file descriptor to be non blocking + let flags = fcntl(fileDescriptor.rawValue, F_GETFD) + guard flags != -1 else { + let error = SubprocessError( + code: .init(.asyncIOFailed( + "failed to get flags for \(fileDescriptor.rawValue)") + ), + underlyingError: .init(rawValue: errno) + ) + continuation.finish(throwing: error) + return + } + guard fcntl(fileDescriptor.rawValue, F_SETFL, flags | O_NONBLOCK) != -1 else { + let error = SubprocessError( + code: .init(.asyncIOFailed( + "failed to set \(fileDescriptor.rawValue) to be non-blocking") + ), + underlyingError: .init(rawValue: errno) + ) + continuation.finish(throwing: error) + return + } + // Register event + let targetEvent: EPOLL_EVENTS + switch event { + case .read: + targetEvent = EPOLLIN + case .write: + targetEvent = EPOLLOUT + } + + var event = epoll_event( + events: targetEvent.rawValue, + data: epoll_data(fd: fileDescriptor.rawValue) + ) + let rc = epoll_ctl( + state.epollFileDescriptor, + EPOLL_CTL_ADD, + fileDescriptor.rawValue, + &event + ) + if rc != 0 { + let error = SubprocessError( + code: .init(.asyncIOFailed( + "failed to add \(fileDescriptor.rawValue) to epoll list") + ), + underlyingError: .init(rawValue: errno) + ) + continuation.finish(throwing: error) + return + } + // Now save the continuation + _registration.withLock { storage in + storage[fileDescriptor.rawValue] = continuation + } + case .failure(let setupError): + continuation.finish(throwing: setupError) + return + } + } + } + + private func removeRegistration(for fileDescriptor: FileDescriptor) throws { + switch self.state { + case .success(let state): + let rc = epoll_ctl( + state.epollFileDescriptor, + EPOLL_CTL_DEL, + fileDescriptor.rawValue, + nil + ) + guard rc == 0 else { + throw SubprocessError( + code: .init(.asyncIOFailed( + "failed to remove \(fileDescriptor.rawValue) to epoll list") + ), + underlyingError: .init(rawValue: errno) + ) + } + _registration.withLock { store in + _ = store.removeValue(forKey: fileDescriptor.rawValue) + } + case .failure(let setupFailure): + throw setupFailure + } + } +} + +extension AsyncIO { + + protocol _ContiguousBytes { + var count: Int { get } + + func withUnsafeBytes( + _ body: (UnsafeRawBufferPointer) throws -> ResultType + ) rethrows -> ResultType + } + + func read( + from diskIO: borrowing IOChannel, + upTo maxLength: Int + ) async throws -> [UInt8]? { + return try await self.read(from: diskIO.channel, upTo: maxLength) + } + + func read( + from fileDescriptor: FileDescriptor, + upTo maxLength: Int + ) async throws -> [UInt8]? { + // If we are reading until EOF, start with readBufferSize + // and gradually increase buffer size + let bufferLength = maxLength == .max ? readBufferSize : maxLength + + var resultBuffer: [UInt8] = Array( + repeating: 0, count: bufferLength + ) + var readLength: Int = 0 + let signalStream = self.registerFileDescriptor(fileDescriptor, for: .read) + /// Outer loop: every iteration signals we are ready to read more data + for try await _ in signalStream { + /// Inner loop: repeatedly call `.read()` and read more data until: + /// 1. We reached EOF (read length is 0), in which case return the result + /// 2. We read `maxLength` bytes, in which case return the result + /// 3. `read()` returns -1 and sets `errno` to `EAGAIN` or `EWOULDBLOCK`. In + /// this case we `break` out of the inner loop and wait `.read()` to be + /// ready by `await`ing the next signal in the outer loop. + while true { + let bytesRead = resultBuffer.withUnsafeMutableBufferPointer { bufferPointer in + // Get a pointer to the memory at the specified offset + let targetCount = bufferPointer.count - readLength + + let offsetAddress = bufferPointer.baseAddress!.advanced(by: readLength) + + // Read directly into the buffer at the offset + return _SubprocessCShims.read(fileDescriptor.rawValue, offsetAddress, targetCount) + } + if bytesRead > 0 { + // Read some data + readLength += bytesRead + if maxLength == .max { + // Grow resultBuffer if needed + guard Double(readLength) > 0.8 * Double(resultBuffer.count) else { + continue + } + resultBuffer.append( + contentsOf: Array(repeating: 0, count: resultBuffer.count) + ) + } else if readLength >= maxLength { + // When we reached maxLength, return! + try self.removeRegistration(for: fileDescriptor) + return resultBuffer + } + } else if bytesRead == 0 { + // We reached EOF. Return whatever's left + try self.removeRegistration(for: fileDescriptor) + guard readLength > 0 else { + return nil + } + resultBuffer.removeLast(resultBuffer.count - readLength) + return resultBuffer + } else { + if errno == EAGAIN || errno == EWOULDBLOCK { + // No more data for now wait for the next signal + break + } else { + // Throw all other errors + try self.removeRegistration(for: fileDescriptor) + throw SubprocessError.UnderlyingError(rawValue: errno) + } + } + } + } + return resultBuffer + } + + func write( + _ array: [UInt8], + to diskIO: borrowing IOChannel + ) async throws -> Int { + return try await self._write(array, to: diskIO) + } + + func _write( + _ bytes: Bytes, + to diskIO: borrowing IOChannel + ) async throws -> Int { + let fileDescriptor = diskIO.channel + let signalStream = self.registerFileDescriptor(fileDescriptor, for: .write) + var writtenLength: Int = 0 + /// Outer loop: every iteration signals we are ready to read more data + for try await _ in signalStream { + /// Inner loop: repeatedly call `.write()` and write more data until: + /// 1. We've written bytes.count bytes. + /// 3. `.write()` returns -1 and sets `errno` to `EAGAIN` or `EWOULDBLOCK`. In + /// this case we `break` out of the inner loop and wait `.write()` to be + /// ready by `await`ing the next signal in the outer loop. + while true { + let written = bytes.withUnsafeBytes { ptr in + let remainingLength = ptr.count - writtenLength + let startPtr = ptr.baseAddress!.advanced(by: writtenLength) + return _SubprocessCShims.write(fileDescriptor.rawValue, startPtr, remainingLength) + } + if written > 0 { + writtenLength += written + if writtenLength >= bytes.count { + // Wrote all data + try self.removeRegistration(for: fileDescriptor) + return writtenLength + } + } else { + if errno == EAGAIN || errno == EWOULDBLOCK { + // No more data for now wait for the next signal + break + } else { + // Throw all other errors + try self.removeRegistration(for: fileDescriptor) + throw SubprocessError.UnderlyingError(rawValue: errno) + } + } + } + } + return 0 + } + + #if SubprocessSpan + func write( + _ span: borrowing RawSpan, + to diskIO: borrowing IOChannel + ) async throws -> Int { + let fileDescriptor = diskIO.channel + let signalStream = self.registerFileDescriptor(fileDescriptor, for: .write) + var writtenLength: Int = 0 + /// Outer loop: every iteration signals we are ready to read more data + for try await _ in signalStream { + /// Inner loop: repeatedly call `.write()` and write more data until: + /// 1. We've written bytes.count bytes. + /// 3. `.write()` returns -1 and sets `errno` to `EAGAIN` or `EWOULDBLOCK`. In + /// this case we `break` out of the inner loop and wait `.write()` to be + /// ready by `await`ing the next signal in the outer loop. + while true { + let written = span.withUnsafeBytes { ptr in + let remainingLength = ptr.count - writtenLength + let startPtr = ptr.baseAddress!.advanced(by: writtenLength) + return _SubprocessCShims.write(fileDescriptor.rawValue, startPtr, remainingLength) + } + if written > 0 { + writtenLength += written + if writtenLength >= span.byteCount { + // Wrote all data + try self.removeRegistration(for: fileDescriptor) + return writtenLength + } + } else { + if errno == EAGAIN || errno == EWOULDBLOCK { + // No more data for now wait for the next signal + break + } else { + // Throw all other errors + try self.removeRegistration(for: fileDescriptor) + throw SubprocessError.UnderlyingError(rawValue: errno) + } + } + } + } + return 0 + } + #endif +} + +extension Array : AsyncIO._ContiguousBytes where Element == UInt8 {} + +#endif // canImport(Glibc) || canImport(Android) || canImport(Musl) + +// MARK: - macOS (DispatchIO) +#if canImport(Darwin) + +internal import Dispatch + + +final class AsyncIO: Sendable { + static let shared: AsyncIO = AsyncIO() + + private init() {} + + internal func read( + from diskIO: borrowing IOChannel, + upTo maxLength: Int + ) async throws -> DispatchData? { + return try await self.read( + from: diskIO.channel, + upTo: maxLength, + ) + } + + internal func read( + from dispatchIO: DispatchIO, + upTo maxLength: Int + ) async throws -> DispatchData? { + return try await withCheckedThrowingContinuation { continuation in + var buffer: DispatchData = .empty + dispatchIO.read( + offset: 0, + length: maxLength, + queue: .global() + ) { done, data, error in + if error != 0 { + continuation.resume( + throwing: SubprocessError( + code: .init(.failedToReadFromSubprocess), + underlyingError: .init(rawValue: error) + ) + ) + return + } + if let data = data { + if buffer.isEmpty { + buffer = data + } else { + buffer.append(data) + } + } + if done { + if !buffer.isEmpty { + continuation.resume(returning: buffer) + } else { + continuation.resume(returning: nil) + } + } + } + } + } + + #if SubprocessSpan + internal func write( + _ span: borrowing RawSpan, + to diskIO: borrowing IOChannel + ) async throws -> Int { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let dispatchData = span.withUnsafeBytes { + return DispatchData( + bytesNoCopy: $0, + deallocator: .custom( + nil, + { + // noop + } + ) + ) + } + self.write(dispatchData, to: diskIO) { writtenLength, error in + if let error = error { + continuation.resume(throwing: error) + } else { + continuation.resume(returning: writtenLength) + } + } + } + } + #endif // SubprocessSpan + + internal func write( + _ array: [UInt8], + to diskIO: borrowing IOChannel + ) async throws -> Int { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let dispatchData = array.withUnsafeBytes { + return DispatchData( + bytesNoCopy: $0, + deallocator: .custom( + nil, + { + // noop + } + ) + ) + } + self.write(dispatchData, to: diskIO) { writtenLength, error in + if let error = error { + continuation.resume(throwing: error) + } else { + continuation.resume(returning: writtenLength) + } + } + } + } + + internal func write( + _ dispatchData: DispatchData, + to diskIO: borrowing IOChannel, + queue: DispatchQueue = .global(), + completion: @escaping (Int, Error?) -> Void + ) { + diskIO.channel.write( + offset: 0, + data: dispatchData, + queue: queue + ) { done, unwritten, error in + guard done else { + // Wait until we are done writing or encountered some error + return + } + + let unwrittenLength = unwritten?.count ?? 0 + let writtenLength = dispatchData.count - unwrittenLength + guard error != 0 else { + completion(writtenLength, nil) + return + } + completion( + writtenLength, + SubprocessError( + code: .init(.failedToWriteToSubprocess), + underlyingError: .init(rawValue: error) + ) + ) + } + } +} + +#endif + +// MARK: - Windows (I/O Completion Ports) +#if os(Windows) + +import Synchronization +internal import Dispatch +@preconcurrency import WinSDK + +private typealias SignalStream = AsyncThrowingStream +private let shutdownPort: UInt64 = .max +private let _registration: Mutex< + [UInt64 : SignalStream.Continuation] +> = Mutex([:]) + +final class AsyncIO: @unchecked Sendable { + + protocol _ContiguousBytes: Sendable { + var count: Int { get } + + func withUnsafeBytes( + _ body: (UnsafeRawBufferPointer + ) throws -> ResultType) rethrows -> ResultType + } + + private final class MonitorThreadContext { + let ioCompletionPort: HANDLE + + init(ioCompletionPort: HANDLE) { + self.ioCompletionPort = ioCompletionPort + } + } + + static let shared = AsyncIO() + + private let ioCompletionPort: Result + + private let monitorThread: Result + + private init() { + var maybeSetupError: SubprocessError? = nil + // Create the the completion port + guard let port = CreateIoCompletionPort( + INVALID_HANDLE_VALUE, nil, 0, 0 + ), port != INVALID_HANDLE_VALUE else { + let error = SubprocessError( + code: .init(.asyncIOFailed("CreateIoCompletionPort failed")), + underlyingError: .init(rawValue: GetLastError()) + ) + self.ioCompletionPort = .failure(error) + self.monitorThread = .failure(error) + return + } + self.ioCompletionPort = .success(port) + // Create monitor thread + let threadContext = MonitorThreadContext(ioCompletionPort: port) + let threadContextPtr = Unmanaged.passRetained(threadContext) + let threadHandle = CreateThread(nil, 0, { args in + func reportError(_ error: SubprocessError) { + _registration.withLock { store in + for continuation in store.values { + continuation.finish(throwing: error) + } + } + } + + let unmanaged = Unmanaged.fromOpaque(args!) + let context = unmanaged.takeRetainedValue() + + // Monitor loop + while true { + var bytesTransferred: DWORD = 0 + var targetFileDescriptor: UInt64 = 0 + var overlapped: LPOVERLAPPED? = nil + + let monitorResult = GetQueuedCompletionStatus( + context.ioCompletionPort, + &bytesTransferred, + &targetFileDescriptor, + &overlapped, + INFINITE + ) + if !monitorResult { + let lastError = GetLastError() + if lastError == ERROR_BROKEN_PIPE { + // We finished reading the handle. Signal EOF by + // finishing the stream. + // NOTE: here we deliberately leave now unused continuation + // in the store. Windows does not offer an API to remove a + // HANDLE from an IOCP port, therefore we leave the registration + // to signify the HANDLE has already been resisted. + _registration.withLock { store in + if let continuation = store[targetFileDescriptor] { + continuation.finish() + } + } + continue + } else { + let error = SubprocessError( + code: .init(.asyncIOFailed("GetQueuedCompletionStatus failed")), + underlyingError: .init(rawValue: lastError) + ) + reportError(error) + break + } + } + + // Breakout the monitor loop if we received shutdown from the shutdownFD + if targetFileDescriptor == shutdownPort { + break + } + // Notify the continuations + _registration.withLock { store in + if let continuation = store[targetFileDescriptor] { + continuation.yield(bytesTransferred) + } + } + } + return 0 + }, threadContextPtr.toOpaque(), 0, nil) + guard let threadHandle = threadHandle else { + let error = SubprocessError( + code: .init(.asyncIOFailed("CreateThread failed")), + underlyingError: .init(rawValue: GetLastError()) + ) + self.monitorThread = .failure(error) + return + } + self.monitorThread = .success(threadHandle) + + atexit { + AsyncIO.shared.shutdown() + } + } + + private func shutdown() { + // Post status to shutdown HANDLE + guard case .success(let ioPort) = ioCompletionPort, + case .success(let monitorThreadHandle) = monitorThread else { + return + } + PostQueuedCompletionStatus( + ioPort, + 0, + shutdownPort, + nil + ) + // Wait for monitor thread to exit + WaitForSingleObject(monitorThreadHandle, INFINITE); + CloseHandle(ioPort) + CloseHandle(monitorThreadHandle) + } + + private func registerHandle(_ handle: HANDLE) -> SignalStream { + return SignalStream { continuation in + switch self.ioCompletionPort { + case .success(let ioPort): + // Make sure thread setup also succeed + if case .failure(let error) = monitorThread { + continuation.finish(throwing: error) + return + } + let completionKey = UInt64(UInt(bitPattern: handle)) + // Windows does not offer an API to remove a handle + // from given ioCompletionPort. If this handle has already + // been registered we simply need to update the continuation + let registrationFound = _registration.withLock { storage in + if storage[completionKey] != nil { + // Old registration found. This means this handle has + // already been registered. We simply need to update + // the continuation saved + storage[completionKey] = continuation + return true + } else { + return false + } + } + if registrationFound { + return + } + + // Windows Documentation: The function returns the handle + // of the existing I/O completion port if successful + guard CreateIoCompletionPort( + handle, ioPort, completionKey, 0 + ) == ioPort else { + let error = SubprocessError( + code: .init(.asyncIOFailed("CreateIoCompletionPort failed")), + underlyingError: .init(rawValue: GetLastError()) + ) + continuation.finish(throwing: error) + return + } + // Now save the continuation + _registration.withLock { storage in + storage[completionKey] = continuation + } + case .failure(let error): + continuation.finish(throwing: error) + } + } + } + + internal func removeRegistration(for handle: HANDLE) { + let completionKey = UInt64(UInt(bitPattern: handle)) + _registration.withLock { storage in + storage.removeValue(forKey: completionKey) + } + } + + func read( + from diskIO: borrowing IOChannel, + upTo maxLength: Int + ) async throws -> [UInt8]? { + return try await self.read(from: diskIO.channel, upTo: maxLength) + } + + func read( + from handle: HANDLE, + upTo maxLength: Int + ) async throws -> [UInt8]? { + // If we are reading until EOF, start with readBufferSize + // and gradually increase buffer size + let bufferLength = maxLength == .max ? readBufferSize : maxLength + + var resultBuffer: [UInt8] = Array( + repeating: 0, count: bufferLength + ) + var readLength: Int = 0 + var signalStream = self.registerHandle(handle).makeAsyncIterator() + + while true { + var overlapped = _OVERLAPPED() + let succeed = try resultBuffer.withUnsafeMutableBufferPointer { bufferPointer in + // Get a pointer to the memory at the specified offset + // Windows ReadFile uses DWORD for target count, which means we can only + // read up to DWORD (aka UInt32) max. + let targetCount = min(bufferPointer.count - readLength, Int(UInt32.max)) + + let offsetAddress = bufferPointer.baseAddress!.advanced(by: readLength) + // Read directly into the buffer at the offset + return ReadFile( + handle, + offsetAddress, + DWORD(truncatingIfNeeded: targetCount), + nil, + &overlapped + ) + } + + if !succeed { + // It is expected `ReadFile` to return `false` in async mode. + // Make sure we only get `ERROR_IO_PENDING` or `ERROR_BROKEN_PIPE` + let lastError = GetLastError() + if lastError == ERROR_BROKEN_PIPE { + // We reached EOF + return nil + } + guard lastError == ERROR_IO_PENDING else { + let error = SubprocessError( + code: .init(.failedToReadFromSubprocess), + underlyingError: .init(rawValue: lastError) + ) + throw error + } + + } + // Now wait for read to finish + let bytesRead = try await signalStream.next() ?? 0 + + if bytesRead == 0 { + // We reached EOF. Return whatever's left + guard readLength > 0 else { + return nil + } + resultBuffer.removeLast(resultBuffer.count - readLength) + return resultBuffer + } else { + // Read some data + readLength += Int(bytesRead) + if maxLength == .max { + // Grow resultBuffer if needed + guard Double(readLength) > 0.8 * Double(resultBuffer.count) else { + continue + } + resultBuffer.append( + contentsOf: Array(repeating: 0, count: resultBuffer.count) + ) + } else if readLength >= maxLength { + // When we reached maxLength, return! + return resultBuffer + } + } + } + } + + func write( + _ array: [UInt8], + to diskIO: borrowing IOChannel + ) async throws -> Int { + return try await self._write(array, to: diskIO) + } + + #if SubprocessSpan + func write( + _ span: borrowing RawSpan, + to diskIO: borrowing IOChannel + ) async throws -> Int { + let handle = diskIO.channel + var signalStream = self.registerHandle(diskIO.channel).makeAsyncIterator() + var writtenLength: Int = 0 + while true { + var overlapped = _OVERLAPPED() + let succeed = try span.withUnsafeBytes { ptr in + // Windows WriteFile uses DWORD for target count + // which means we can only write up to DWORD max + let remainingLength = min( + ptr.count - writtenLength, Int(DWORD.max) + ) + let startPtr = ptr.baseAddress!.advanced(by: writtenLength) + return WriteFile( + handle, + startPtr, + DWORD(truncatingIfNeeded: remainingLength), + nil, + &overlapped + ) + } + if !succeed { + // It is expected `WriteFile` to return `false` in async mode. + // Make sure we only get `ERROR_IO_PENDING` + let lastError = GetLastError() + guard lastError == ERROR_IO_PENDING else { + let error = SubprocessError( + code: .init(.failedToWriteToSubprocess), + underlyingError: .init(rawValue: lastError) + ) + throw error + } + + } + // Now wait for read to finish + let bytesWritten: DWORD = try await signalStream.next() ?? 0 + + writtenLength += Int(bytesWritten) + if writtenLength >= span.byteCount { + return writtenLength + } + } + } + #endif // SubprocessSpan + + func _write( + _ bytes: Bytes, + to diskIO: borrowing IOChannel + ) async throws -> Int { + let handle = diskIO.channel + var signalStream = self.registerHandle(diskIO.channel).makeAsyncIterator() + var writtenLength: Int = 0 + while true { + var overlapped = _OVERLAPPED() + let succeed = try bytes.withUnsafeBytes { ptr in + // Windows WriteFile uses DWORD for target count + // which means we can only write up to DWORD max + let remainingLength = min( + ptr.count - writtenLength, Int(DWORD.max) + ) + let startPtr = ptr.baseAddress!.advanced(by: writtenLength) + return WriteFile( + handle, + startPtr, + DWORD(truncatingIfNeeded: remainingLength), + nil, + &overlapped + ) + } + + if !succeed { + // It is expected `WriteFile` to return `false` in async mode. + // Make sure we only get `ERROR_IO_PENDING` + let lastError = GetLastError() + guard lastError == ERROR_IO_PENDING else { + let error = SubprocessError( + code: .init(.failedToWriteToSubprocess), + underlyingError: .init(rawValue: lastError) + ) + throw error + } + } + // Now wait for read to finish + let bytesWritten: DWORD = try await signalStream.next() ?? 0 + writtenLength += Int(bytesWritten) + if writtenLength >= bytes.count { + return writtenLength + } + } + } +} + +extension Array : AsyncIO._ContiguousBytes where Element == UInt8 {} + +#endif + diff --git a/Sources/Subprocess/IO/Input.swift b/Sources/Subprocess/IO/Input.swift index 58bfe4d..715428e 100644 --- a/Sources/Subprocess/IO/Input.swift +++ b/Sources/Subprocess/IO/Input.swift @@ -15,6 +15,10 @@ @preconcurrency import SystemPackage #endif +#if canImport(WinSDK) +@preconcurrency import WinSDK +#endif + #if SubprocessFoundation #if canImport(Darwin) @@ -78,9 +82,14 @@ public struct FileDescriptorInput: InputProtocol { private let closeAfterSpawningProcess: Bool internal func createPipe() throws -> CreatedPipe { + #if canImport(WinSDK) + let readFd = HANDLE(bitPattern: _get_osfhandle(self.fileDescriptor.rawValue))! + #else + let readFd = self.fileDescriptor + #endif return CreatedPipe( readFileDescriptor: .init( - self.fileDescriptor, + readFd, closeWhenDone: self.closeAfterSpawningProcess ), writeFileDescriptor: nil @@ -203,7 +212,7 @@ extension InputProtocol { return try fdInput.createPipe() } // Base implementation - return try CreatedPipe(closeWhenDone: true) + return try CreatedPipe(closeWhenDone: true, purpose: .input) } } @@ -212,9 +221,9 @@ extension InputProtocol { /// A writer that writes to the standard input of the subprocess. public final actor StandardInputWriter: Sendable { - internal var diskIO: TrackedPlatformDiskIO + internal var diskIO: IOChannel - init(diskIO: consuming TrackedPlatformDiskIO) { + init(diskIO: consuming IOChannel) { self.diskIO = diskIO } @@ -224,7 +233,7 @@ public final actor StandardInputWriter: Sendable { public func write( _ array: [UInt8] ) async throws -> Int { - return try await self.diskIO.write(array) + return try await AsyncIO.shared.write(array, to: self.diskIO) } /// Write a `RawSpan` to the standard input of the subprocess. @@ -232,7 +241,7 @@ public final actor StandardInputWriter: Sendable { /// - Returns number of bytes written #if SubprocessSpan public func write(_ span: borrowing RawSpan) async throws -> Int { - return try await self.diskIO.write(span) + return try await AsyncIO.shared.write(span, to: self.diskIO) } #endif diff --git a/Sources/Subprocess/IO/Output.swift b/Sources/Subprocess/IO/Output.swift index 454eca4..223ccd6 100644 --- a/Sources/Subprocess/IO/Output.swift +++ b/Sources/Subprocess/IO/Output.swift @@ -14,6 +14,11 @@ #else @preconcurrency import SystemPackage #endif + +#if canImport(WinSDK) +@preconcurrency import WinSDK +#endif + internal import Dispatch // MARK: - Output @@ -85,10 +90,15 @@ public struct FileDescriptorOutput: OutputProtocol { private let fileDescriptor: FileDescriptor internal func createPipe() throws -> CreatedPipe { + #if canImport(WinSDK) + let writeFd = HANDLE(bitPattern: _get_osfhandle(self.fileDescriptor.rawValue))! + #else + let writeFd = self.fileDescriptor + #endif return CreatedPipe( readFileDescriptor: nil, writeFileDescriptor: .init( - self.fileDescriptor, + writeFd, closeWhenDone: self.closeAfterSpawningProcess ) ) @@ -140,16 +150,14 @@ public struct BytesOutput: OutputProtocol { public let maxSize: Int internal func captureOutput( - from diskIO: consuming TrackedPlatformDiskIO? + from diskIO: consuming IOChannel ) async throws -> [UInt8] { - #if os(Windows) - let result = try await diskIO?.fileDescriptor.read(upToLength: self.maxSize) ?? [] - try diskIO?.safelyClose() - return result - #else - let result = try await diskIO!.dispatchIO.read(upToLength: self.maxSize) - try diskIO?.safelyClose() + let result = try await AsyncIO.shared.read(from: diskIO, upTo: self.maxSize) + try diskIO.safelyClose() + #if canImport(Darwin) return result?.array() ?? [] + #else + return result ?? [] #endif } @@ -249,35 +257,44 @@ extension OutputProtocol { return try fdOutput.createPipe() } // Base pipe based implementation for everything else - return try CreatedPipe(closeWhenDone: true) + return try CreatedPipe(closeWhenDone: true, purpose: .output) } /// Capture the output from the subprocess up to maxSize @_disfavoredOverload internal func captureOutput( - from diskIO: consuming TrackedPlatformDiskIO? + from diskIO: consuming IOChannel? ) async throws -> OutputType { - if let bytesOutput = self as? BytesOutput { - return try await bytesOutput.captureOutput(from: diskIO) as! Self.OutputType - } - if OutputType.self == Void.self { return () as! OutputType } - #if os(Windows) - let result = try await diskIO?.fileDescriptor.read(upToLength: self.maxSize) - try diskIO?.safelyClose() - return try self.output(from: result ?? []) - #else - let result = try await diskIO!.dispatchIO.read(upToLength: self.maxSize) - try diskIO?.safelyClose() + // `diskIO` is only `nil` for any types that conform to `OutputProtocol` + // and have `Void` as ``OutputType` (i.e. `DiscardedOutput`). Since we + // made sure `OutputType` is not `Void` on the line above, `diskIO` + // must not be nil; otherwise, this is a programmer error. + guard var diskIO else { + fatalError( + "Internal Inconsistency Error: diskIO must not be nil when OutputType is not Void" + ) + } + + if let bytesOutput = self as? BytesOutput { + return try await bytesOutput.captureOutput(from: diskIO) as! Self.OutputType + } + // Force unwrap is safe here because only `OutputType.self == Void` would + // have nil `IOChannel` + let result = try await AsyncIO.shared.read(from: diskIO, upTo: self.maxSize) + try diskIO.safelyClose() + #if canImport(Darwin) return try self.output(from: result ?? .empty) + #else + return try self.output(from: result ?? []) #endif } } extension OutputProtocol where OutputType == Void { - internal func captureOutput(from fileDescriptor: consuming TrackedPlatformDiskIO?) async throws {} + internal func captureOutput(from fileDescriptor: consuming IOChannel?) async throws {} #if SubprocessSpan /// Convert the output from Data to expected output type @@ -293,34 +310,34 @@ extension OutputProtocol where OutputType == Void { #if SubprocessSpan extension OutputProtocol { - #if os(Windows) - internal func output(from data: [UInt8]) throws -> OutputType { + #if canImport(Darwin) + internal func output(from data: DispatchData) throws -> OutputType { guard !data.isEmpty else { let empty = UnsafeRawBufferPointer(start: nil, count: 0) let span = RawSpan(_unsafeBytes: empty) return try self.output(from: span) } - return try data.withUnsafeBufferPointer { ptr in - let span = RawSpan(_unsafeBytes: UnsafeRawBufferPointer(ptr)) + return try data.withUnsafeBytes { ptr in + let bufferPtr = UnsafeRawBufferPointer(start: ptr, count: data.count) + let span = RawSpan(_unsafeBytes: bufferPtr) return try self.output(from: span) } } #else - internal func output(from data: DispatchData) throws -> OutputType { + internal func output(from data: [UInt8]) throws -> OutputType { guard !data.isEmpty else { let empty = UnsafeRawBufferPointer(start: nil, count: 0) let span = RawSpan(_unsafeBytes: empty) return try self.output(from: span) } - return try data.withUnsafeBytes { ptr in - let bufferPtr = UnsafeRawBufferPointer(start: ptr, count: data.count) - let span = RawSpan(_unsafeBytes: bufferPtr) + return try data.withUnsafeBufferPointer { ptr in + let span = RawSpan(_unsafeBytes: UnsafeRawBufferPointer(ptr)) return try self.output(from: span) } } - #endif // os(Windows) + #endif // canImport(Darwin) } #endif diff --git a/Sources/Subprocess/Platforms/Subprocess+Darwin.swift b/Sources/Subprocess/Platforms/Subprocess+Darwin.swift index 7cd32c7..5309f1e 100644 --- a/Sources/Subprocess/Platforms/Subprocess+Darwin.swift +++ b/Sources/Subprocess/Platforms/Subprocess+Darwin.swift @@ -173,12 +173,12 @@ extension Configuration { var _outputPipe = outputPipeBox.take()! var _errorPipe = errorPipeBox.take()! - let inputReadFileDescriptor: TrackedFileDescriptor? = _inputPipe.readFileDescriptor() - let inputWriteFileDescriptor: TrackedFileDescriptor? = _inputPipe.writeFileDescriptor() - let outputReadFileDescriptor: TrackedFileDescriptor? = _outputPipe.readFileDescriptor() - let outputWriteFileDescriptor: TrackedFileDescriptor? = _outputPipe.writeFileDescriptor() - let errorReadFileDescriptor: TrackedFileDescriptor? = _errorPipe.readFileDescriptor() - let errorWriteFileDescriptor: TrackedFileDescriptor? = _errorPipe.writeFileDescriptor() + let inputReadFileDescriptor: IODescriptor? = _inputPipe.readFileDescriptor() + let inputWriteFileDescriptor: IODescriptor? = _inputPipe.writeFileDescriptor() + let outputReadFileDescriptor: IODescriptor? = _outputPipe.readFileDescriptor() + let outputWriteFileDescriptor: IODescriptor? = _outputPipe.writeFileDescriptor() + let errorReadFileDescriptor: IODescriptor? = _errorPipe.readFileDescriptor() + let errorWriteFileDescriptor: IODescriptor? = _errorPipe.writeFileDescriptor() for possibleExecutablePath in possiblePaths { var pid: pid_t = 0 @@ -442,9 +442,9 @@ extension Configuration { ) return SpawnResult( execution: execution, - inputWriteEnd: inputWriteFileDescriptor?.createPlatformDiskIO(), - outputReadEnd: outputReadFileDescriptor?.createPlatformDiskIO(), - errorReadEnd: errorReadFileDescriptor?.createPlatformDiskIO() + inputWriteEnd: inputWriteFileDescriptor?.createIOChannel(), + outputReadEnd: outputReadFileDescriptor?.createIOChannel(), + errorReadEnd: errorReadFileDescriptor?.createIOChannel() ) } diff --git a/Sources/Subprocess/Platforms/Subprocess+Linux.swift b/Sources/Subprocess/Platforms/Subprocess+Linux.swift index bf1e5b8..3b449ac 100644 --- a/Sources/Subprocess/Platforms/Subprocess+Linux.swift +++ b/Sources/Subprocess/Platforms/Subprocess+Linux.swift @@ -56,12 +56,12 @@ extension Configuration { var _outputPipe = outputPipeBox.take()! var _errorPipe = errorPipeBox.take()! - let inputReadFileDescriptor: TrackedFileDescriptor? = _inputPipe.readFileDescriptor() - let inputWriteFileDescriptor: TrackedFileDescriptor? = _inputPipe.writeFileDescriptor() - let outputReadFileDescriptor: TrackedFileDescriptor? = _outputPipe.readFileDescriptor() - let outputWriteFileDescriptor: TrackedFileDescriptor? = _outputPipe.writeFileDescriptor() - let errorReadFileDescriptor: TrackedFileDescriptor? = _errorPipe.readFileDescriptor() - let errorWriteFileDescriptor: TrackedFileDescriptor? = _errorPipe.writeFileDescriptor() + let inputReadFileDescriptor: IODescriptor? = _inputPipe.readFileDescriptor() + let inputWriteFileDescriptor: IODescriptor? = _inputPipe.writeFileDescriptor() + let outputReadFileDescriptor: IODescriptor? = _outputPipe.readFileDescriptor() + let outputWriteFileDescriptor: IODescriptor? = _outputPipe.writeFileDescriptor() + let errorReadFileDescriptor: IODescriptor? = _errorPipe.readFileDescriptor() + let errorWriteFileDescriptor: IODescriptor? = _errorPipe.writeFileDescriptor() for possibleExecutablePath in possiblePaths { var processGroupIDPtr: UnsafeMutablePointer? = nil @@ -154,9 +154,9 @@ extension Configuration { ) return SpawnResult( execution: execution, - inputWriteEnd: inputWriteFileDescriptor?.createPlatformDiskIO(), - outputReadEnd: outputReadFileDescriptor?.createPlatformDiskIO(), - errorReadEnd: errorReadFileDescriptor?.createPlatformDiskIO() + inputWriteEnd: inputWriteFileDescriptor?.createIOChannel(), + outputReadEnd: outputReadFileDescriptor?.createIOChannel(), + errorReadEnd: errorReadFileDescriptor?.createIOChannel() ) } diff --git a/Sources/Subprocess/Platforms/Subprocess+Unix.swift b/Sources/Subprocess/Platforms/Subprocess+Unix.swift index bb5da2f..122e5c5 100644 --- a/Sources/Subprocess/Platforms/Subprocess+Unix.swift +++ b/Sources/Subprocess/Platforms/Subprocess+Unix.swift @@ -375,146 +375,5 @@ extension FileDescriptor { } internal typealias PlatformFileDescriptor = CInt -internal typealias TrackedPlatformDiskIO = TrackedDispatchIO - -extension TrackedFileDescriptor { - internal consuming func createPlatformDiskIO() -> TrackedPlatformDiskIO { - let dispatchIO: DispatchIO = DispatchIO( - type: .stream, - fileDescriptor: self.platformDescriptor(), - queue: .global(), - cleanupHandler: { error in - // Close the file descriptor - if self.closeWhenDone { - try? self.safelyClose() - } - } - ) - return .init(dispatchIO, closeWhenDone: self.closeWhenDone) - } -} - -// MARK: - TrackedDispatchIO extensions -extension DispatchIO { - internal func read(upToLength maxLength: Int) async throws -> DispatchData? { - return try await withCheckedThrowingContinuation { continuation in - var buffer: DispatchData = .empty - self.read( - offset: 0, - length: maxLength, - queue: .global() - ) { done, data, error in - if error != 0 { - continuation.resume( - throwing: SubprocessError( - code: .init(.failedToReadFromSubprocess), - underlyingError: .init(rawValue: error) - ) - ) - return - } - if let data = data { - if buffer.isEmpty { - buffer = data - } else { - buffer.append(data) - } - } - if done { - if !buffer.isEmpty { - continuation.resume(returning: buffer) - } else { - continuation.resume(returning: nil) - } - } - } - } - } -} - -extension TrackedDispatchIO { - #if SubprocessSpan - internal func write( - _ span: borrowing RawSpan - ) async throws -> Int { - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - let dispatchData = span.withUnsafeBytes { - return DispatchData( - bytesNoCopy: $0, - deallocator: .custom( - nil, - { - // noop - } - ) - ) - } - self.write(dispatchData) { writtenLength, error in - if let error = error { - continuation.resume(throwing: error) - } else { - continuation.resume(returning: writtenLength) - } - } - } - } - #endif // SubprocessSpan - - internal func write( - _ array: [UInt8] - ) async throws -> Int { - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - let dispatchData = array.withUnsafeBytes { - return DispatchData( - bytesNoCopy: $0, - deallocator: .custom( - nil, - { - // noop - } - ) - ) - } - self.write(dispatchData) { writtenLength, error in - if let error = error { - continuation.resume(throwing: error) - } else { - continuation.resume(returning: writtenLength) - } - } - } - } - - internal func write( - _ dispatchData: DispatchData, - queue: DispatchQueue = .global(), - completion: @escaping (Int, Error?) -> Void - ) { - self.dispatchIO.write( - offset: 0, - data: dispatchData, - queue: queue - ) { done, unwritten, error in - guard done else { - // Wait until we are done writing or encountered some error - return - } - - let unwrittenLength = unwritten?.count ?? 0 - let writtenLength = dispatchData.count - unwrittenLength - guard error != 0 else { - completion(writtenLength, nil) - return - } - completion( - writtenLength, - SubprocessError( - code: .init(.failedToWriteToSubprocess), - underlyingError: .init(rawValue: error) - ) - ) - } - } -} #endif // canImport(Darwin) || canImport(Glibc) || canImport(Android) || canImport(Musl) diff --git a/Sources/Subprocess/Platforms/Subprocess+Windows.swift b/Sources/Subprocess/Platforms/Subprocess+Windows.swift index 3dfb00c..63229aa 100644 --- a/Sources/Subprocess/Platforms/Subprocess+Windows.swift +++ b/Sources/Subprocess/Platforms/Subprocess+Windows.swift @@ -11,7 +11,7 @@ #if canImport(WinSDK) -import WinSDK +@preconcurrency import WinSDK internal import Dispatch #if canImport(System) @preconcurrency import System @@ -48,37 +48,32 @@ extension Configuration { outputPipe: consuming CreatedPipe, errorPipe: consuming CreatedPipe ) throws -> SpawnResult { - var inputPipeBox: CreatedPipe? = consume inputPipe - var outputPipeBox: CreatedPipe? = consume outputPipe - var errorPipeBox: CreatedPipe? = consume errorPipe - - var _inputPipe = inputPipeBox.take()! - var _outputPipe = outputPipeBox.take()! - var _errorPipe = errorPipeBox.take()! - - let inputReadFileDescriptor: TrackedFileDescriptor? = _inputPipe.readFileDescriptor() - let inputWriteFileDescriptor: TrackedFileDescriptor? = _inputPipe.writeFileDescriptor() - let outputReadFileDescriptor: TrackedFileDescriptor? = _outputPipe.readFileDescriptor() - let outputWriteFileDescriptor: TrackedFileDescriptor? = _outputPipe.writeFileDescriptor() - let errorReadFileDescriptor: TrackedFileDescriptor? = _errorPipe.readFileDescriptor() - let errorWriteFileDescriptor: TrackedFileDescriptor? = _errorPipe.writeFileDescriptor() - - let ( - applicationName, - commandAndArgs, - environment, - intendedWorkingDir - ): (String?, String, String, String?) + var inputReadFileDescriptor: IODescriptor? = inputPipe.readFileDescriptor() + var inputWriteFileDescriptor: IODescriptor? = inputPipe.writeFileDescriptor() + var outputReadFileDescriptor: IODescriptor? = outputPipe.readFileDescriptor() + var outputWriteFileDescriptor: IODescriptor? = outputPipe.writeFileDescriptor() + var errorReadFileDescriptor: IODescriptor? = errorPipe.readFileDescriptor() + var errorWriteFileDescriptor: IODescriptor? = errorPipe.writeFileDescriptor() + + let applicationName: String? + let commandAndArgs: String + let environment: String + let intendedWorkingDir: String? do { - (applicationName, commandAndArgs, environment, intendedWorkingDir) = try self.preSpawn() + ( + applicationName, + commandAndArgs, + environment, + intendedWorkingDir + ) = try self.preSpawn() } catch { try self.safelyCloseMultiple( - inputRead: inputReadFileDescriptor, - inputWrite: inputWriteFileDescriptor, - outputRead: outputReadFileDescriptor, - outputWrite: outputWriteFileDescriptor, - errorRead: errorReadFileDescriptor, - errorWrite: errorWriteFileDescriptor + inputRead: inputReadFileDescriptor.take(), + inputWrite: inputWriteFileDescriptor.take(), + outputRead: outputReadFileDescriptor.take(), + outputWrite: outputWriteFileDescriptor.take(), + errorRead: errorReadFileDescriptor.take(), + errorWrite: errorWriteFileDescriptor.take() ) throw error } @@ -167,9 +162,9 @@ extension Configuration { return SpawnResult( execution: execution, - inputWriteEnd: inputWriteFileDescriptor?.createPlatformDiskIO(), - outputReadEnd: outputReadFileDescriptor?.createPlatformDiskIO(), - errorReadEnd: errorReadFileDescriptor?.createPlatformDiskIO() + inputWriteEnd: inputWriteFileDescriptor?.createIOChannel(), + outputReadEnd: outputReadFileDescriptor?.createIOChannel(), + errorReadEnd: errorReadFileDescriptor?.createIOChannel() ) } @@ -187,12 +182,12 @@ extension Configuration { var _outputPipe = outputPipeBox.take()! var _errorPipe = errorPipeBox.take()! - let inputReadFileDescriptor: TrackedFileDescriptor? = _inputPipe.readFileDescriptor() - let inputWriteFileDescriptor: TrackedFileDescriptor? = _inputPipe.writeFileDescriptor() - let outputReadFileDescriptor: TrackedFileDescriptor? = _outputPipe.readFileDescriptor() - let outputWriteFileDescriptor: TrackedFileDescriptor? = _outputPipe.writeFileDescriptor() - let errorReadFileDescriptor: TrackedFileDescriptor? = _errorPipe.readFileDescriptor() - let errorWriteFileDescriptor: TrackedFileDescriptor? = _errorPipe.writeFileDescriptor() + let inputReadFileDescriptor: IODescriptor? = _inputPipe.readFileDescriptor() + let inputWriteFileDescriptor: IODescriptor? = _inputPipe.writeFileDescriptor() + let outputReadFileDescriptor: IODescriptor? = _outputPipe.readFileDescriptor() + let outputWriteFileDescriptor: IODescriptor? = _outputPipe.writeFileDescriptor() + let errorReadFileDescriptor: IODescriptor? = _errorPipe.readFileDescriptor() + let errorWriteFileDescriptor: IODescriptor? = _errorPipe.writeFileDescriptor() let ( applicationName, @@ -311,9 +306,9 @@ extension Configuration { return SpawnResult( execution: execution, - inputWriteEnd: inputWriteFileDescriptor?.createPlatformDiskIO(), - outputReadEnd: outputReadFileDescriptor?.createPlatformDiskIO(), - errorReadEnd: errorReadFileDescriptor?.createPlatformDiskIO() + inputWriteEnd: inputWriteFileDescriptor?.createIOChannel(), + outputReadEnd: outputReadFileDescriptor?.createIOChannel(), + errorReadEnd: errorReadFileDescriptor?.createIOChannel() ) } } @@ -778,12 +773,12 @@ extension Configuration { } private func generateStartupInfo( - withInputRead inputReadFileDescriptor: borrowing TrackedFileDescriptor?, - inputWrite inputWriteFileDescriptor: borrowing TrackedFileDescriptor?, - outputRead outputReadFileDescriptor: borrowing TrackedFileDescriptor?, - outputWrite outputWriteFileDescriptor: borrowing TrackedFileDescriptor?, - errorRead errorReadFileDescriptor: borrowing TrackedFileDescriptor?, - errorWrite errorWriteFileDescriptor: borrowing TrackedFileDescriptor?, + withInputRead inputReadFileDescriptor: borrowing IODescriptor?, + inputWrite inputWriteFileDescriptor: borrowing IODescriptor?, + outputRead outputReadFileDescriptor: borrowing IODescriptor?, + outputWrite outputWriteFileDescriptor: borrowing IODescriptor?, + errorRead errorReadFileDescriptor: borrowing IODescriptor?, + errorWrite errorWriteFileDescriptor: borrowing IODescriptor?, ) throws -> STARTUPINFOW { var info: STARTUPINFOW = STARTUPINFOW() info.cb = DWORD(MemoryLayout.size) @@ -955,8 +950,6 @@ extension Configuration { // MARK: - Type alias internal typealias PlatformFileDescriptor = HANDLE -internal typealias TrackedPlatformDiskIO = TrackedFileDescriptor - // MARK: - Pipe Support extension FileDescriptor { // NOTE: Not the same as SwiftSystem's FileDescriptor.pipe, which has different behavior, @@ -1004,171 +997,6 @@ extension FileDescriptor { } } -extension FileDescriptor { - internal func read(upToLength maxLength: Int) async throws -> [UInt8]? { - return try await withCheckedThrowingContinuation { continuation in - self.readUntilEOF( - upToLength: maxLength - ) { result in - switch result { - case .failure(let error): - continuation.resume(throwing: error) - case .success(let bytes): - continuation.resume(returning: bytes.isEmpty ? nil : bytes) - } - } - } - } - - internal func readUntilEOF( - upToLength maxLength: Int, - resultHandler: @Sendable @escaping (Swift.Result<[UInt8], any (Error & Sendable)>) -> Void - ) { - DispatchQueue.global(qos: .userInitiated).async { - var totalBytesRead: Int = 0 - var lastError: DWORD? = nil - let values = [UInt8]( - unsafeUninitializedCapacity: maxLength - ) { buffer, initializedCount in - while true { - guard let baseAddress = buffer.baseAddress else { - initializedCount = 0 - break - } - let bufferPtr = baseAddress.advanced(by: totalBytesRead) - var bytesRead: DWORD = 0 - let readSucceed = ReadFile( - self.platformDescriptor, - UnsafeMutableRawPointer(mutating: bufferPtr), - DWORD(maxLength - totalBytesRead), - &bytesRead, - nil - ) - if !readSucceed { - // Windows throws ERROR_BROKEN_PIPE when the pipe is closed - let error = GetLastError() - if error == ERROR_BROKEN_PIPE { - // We are done reading - initializedCount = totalBytesRead - } else { - // We got some error - lastError = error - initializedCount = 0 - } - break - } else { - // We successfully read the current round - totalBytesRead += Int(bytesRead) - } - - if totalBytesRead >= maxLength { - initializedCount = min(maxLength, totalBytesRead) - break - } - } - } - if let lastError = lastError { - let windowsError = SubprocessError( - code: .init(.failedToReadFromSubprocess), - underlyingError: .init(rawValue: lastError) - ) - resultHandler(.failure(windowsError)) - } else { - resultHandler(.success(values)) - } - } - } -} - -extension TrackedFileDescriptor { - internal consuming func createPlatformDiskIO() -> TrackedPlatformDiskIO { - // TrackedPlatformDiskIO is a typealias of TrackedFileDescriptor on Windows (they're the same type) - // Just return the same object so we don't create a copy and try to double-close the fd. - return self - } - - internal func readUntilEOF( - upToLength maxLength: Int, - resultHandler: @Sendable @escaping (Swift.Result<[UInt8], any (Error & Sendable)>) -> Void - ) { - self.fileDescriptor.readUntilEOF( - upToLength: maxLength, - resultHandler: resultHandler - ) - } - -#if SubprocessSpan - internal func write( - _ span: borrowing RawSpan - ) async throws -> Int { - let fileDescriptor = self.fileDescriptor - return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - span.withUnsafeBytes { ptr in - // TODO: Use WriteFileEx for asyc here - Self.write( - ptr, - to: fileDescriptor - ) { writtenLength, error in - if let error = error { - continuation.resume(throwing: error) - } else { - continuation.resume(returning: writtenLength) - } - } - } - } - } -#endif - - internal func write( - _ array: [UInt8] - ) async throws -> Int { - try await withCheckedThrowingContinuation { continuation in - // TODO: Figure out a better way to asynchronously write - let fd = self.fileDescriptor - DispatchQueue.global(qos: .userInitiated).async { - array.withUnsafeBytes { - Self.write( - $0, - to: fd - ) { writtenLength, error in - if let error = error { - continuation.resume(throwing: error) - } else { - continuation.resume(returning: writtenLength) - } - } - } - } - } - } - - internal static func write( - _ ptr: UnsafeRawBufferPointer, - to fileDescriptor: FileDescriptor, - completion: @escaping (Int, Swift.Error?) -> Void - ) { - let handle = HANDLE(bitPattern: _get_osfhandle(fileDescriptor.rawValue))! - var writtenBytes: DWORD = 0 - let writeSucceed = WriteFile( - handle, - ptr.baseAddress, - DWORD(ptr.count), - &writtenBytes, - nil - ) - if !writeSucceed { - let error = SubprocessError( - code: .init(.failedToWriteToSubprocess), - underlyingError: .init(rawValue: GetLastError()) - ) - completion(Int(writtenBytes), error) - } else { - completion(Int(writtenBytes), nil) - } - } -} - extension Optional where Wrapped == String { fileprivate func withOptionalCString( encodedAs targetEncoding: Encoding.Type, diff --git a/Sources/Subprocess/SubprocessFoundation/Input+Foundation.swift b/Sources/Subprocess/SubprocessFoundation/Input+Foundation.swift index a1db8ad..c82d2d3 100644 --- a/Sources/Subprocess/SubprocessFoundation/Input+Foundation.swift +++ b/Sources/Subprocess/SubprocessFoundation/Input+Foundation.swift @@ -111,7 +111,7 @@ extension StandardInputWriter { public func write( _ data: Data ) async throws -> Int { - return try await self.diskIO.write(data) + return try await AsyncIO.shared.write(data, to: self.diskIO) } /// Write a AsyncSequence of Data to the standard input of the subprocess. @@ -128,35 +128,12 @@ extension StandardInputWriter { } } -#if os(Windows) -extension TrackedFileDescriptor { - internal func write( - _ data: Data - ) async throws -> Int { - let fileDescriptor = self.fileDescriptor - return try await withCheckedThrowingContinuation { continuation in - // TODO: Figure out a better way to asynchronously write - DispatchQueue.global(qos: .userInitiated).async { - data.withUnsafeBytes { - Self.write( - $0, - to: fileDescriptor - ) { writtenLength, error in - if let error = error { - continuation.resume(throwing: error) - } else { - continuation.resume(returning: writtenLength) - } - } - } - } - } - } -} -#else -extension TrackedDispatchIO { + +#if canImport(Darwin) +extension AsyncIO { internal func write( - _ data: Data + _ data: Data, + to diskIO: borrowing IOChannel ) async throws -> Int { try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in let dispatchData = data.withUnsafeBytes { @@ -170,7 +147,7 @@ extension TrackedDispatchIO { ) ) } - self.write(dispatchData) { writtenLength, error in + self.write(dispatchData, to: diskIO) { writtenLength, error in if let error = error { continuation.resume(throwing: error) } else { @@ -180,6 +157,17 @@ extension TrackedDispatchIO { } } } -#endif // os(Windows) +#else +extension Data : AsyncIO._ContiguousBytes { } + +extension AsyncIO { + internal func write( + _ data: Data, + to diskIO: borrowing IOChannel + ) async throws -> Int { + return try await self._write(data, to: diskIO) + } +} +#endif // canImport(Darwin) #endif // SubprocessFoundation diff --git a/Sources/Subprocess/Teardown.swift b/Sources/Subprocess/Teardown.swift index a4c0352..d5dd551 100644 --- a/Sources/Subprocess/Teardown.swift +++ b/Sources/Subprocess/Teardown.swift @@ -20,7 +20,7 @@ import Glibc #elseif canImport(Musl) import Musl #elseif canImport(WinSDK) -import WinSDK +@preconcurrency import WinSDK #endif /// A step in the graceful shutdown teardown sequence. diff --git a/Sources/_SubprocessCShims/include/process_shims.h b/Sources/_SubprocessCShims/include/process_shims.h index 35cbd2f..0ae4a5a 100644 --- a/Sources/_SubprocessCShims/include/process_shims.h +++ b/Sources/_SubprocessCShims/include/process_shims.h @@ -21,6 +21,11 @@ #include #endif +#if TARGET_OS_LINUX +#include +#include +#endif // TARGET_OS_LINUX + #if __has_include() vm_size_t _subprocess_vm_size(void); #endif diff --git a/Tests/SubprocessTests/SubprocessTests+Unix.swift b/Tests/SubprocessTests/SubprocessTests+Unix.swift index a50c4e8..25a464e 100644 --- a/Tests/SubprocessTests/SubprocessTests+Unix.swift +++ b/Tests/SubprocessTests/SubprocessTests+Unix.swift @@ -668,14 +668,11 @@ extension SubprocessUnixTests { var platformOptions = PlatformOptions() platformOptions.supplementaryGroups = Array(expectedGroups) let idResult = try await Subprocess.run( - .name("swift"), + .path("/usr/bin/swift"), arguments: [getgroupsSwift.string], platformOptions: platformOptions, - output: .string, - error: .string, + output: .string ) - let error = try #require(idResult.standardError) - try #require(error == "") #expect(idResult.terminationStatus.isSuccess) let ids = try #require( idResult.standardOutput diff --git a/Tests/SubprocessTests/SubprocessTests+Windows.swift b/Tests/SubprocessTests/SubprocessTests+Windows.swift index 3c615e9..3b9bed7 100644 --- a/Tests/SubprocessTests/SubprocessTests+Windows.swift +++ b/Tests/SubprocessTests/SubprocessTests+Windows.swift @@ -11,7 +11,7 @@ #if canImport(WinSDK) -import WinSDK +@preconcurrency import WinSDK import FoundationEssentials import Testing import Dispatch @@ -27,7 +27,7 @@ import TestResources @Suite(.serialized) struct SubprocessWindowsTests { - private let cmdExe: Subprocess.Executable = .path("C:\\Windows\\System32\\cmd.exe") + private let cmdExe: Subprocess.Executable = .name("cmd.exe") } // MARK: - Executable Tests @@ -87,7 +87,7 @@ extension SubprocessWindowsTests { Issue.record("Expected to throw POSIXError") } catch { guard let subprocessError = error as? SubprocessError, - let underlying = subprocessError.underlyingError + let underlying = subprocessError.underlyingError else { Issue.record("Expected CocoaError, got \(error)") return @@ -128,7 +128,6 @@ extension SubprocessWindowsTests { environment: .inherit, output: .string ) - #expect(result.terminationStatus.isSuccess) // As a sanity check, make sure there's // `C:\Windows\system32` in PATH // since we inherited the environment variables @@ -249,7 +248,6 @@ extension SubprocessWindowsTests { output: .data(limit: 2048 * 1024) ) - #expect(catResult.terminationStatus.isSuccess) // Make sure we read all bytes #expect( catResult.standardOutput == expected @@ -271,7 +269,6 @@ extension SubprocessWindowsTests { output: .data(limit: 2048 * 1024), error: .discarded ) - #expect(catResult.terminationStatus.isSuccess) // Make sure we read all bytes #expect( catResult.standardOutput == expected @@ -304,7 +301,6 @@ extension SubprocessWindowsTests { input: .sequence(stream), output: .data(limit: 2048 * 1024) ) - #expect(catResult.terminationStatus.isSuccess) #expect( catResult.standardOutput == expected ) @@ -327,7 +323,6 @@ extension SubprocessWindowsTests { } return buffer } - #expect(result.terminationStatus.isSuccess) #expect(result.value == expected) } @@ -364,7 +359,6 @@ extension SubprocessWindowsTests { } return buffer } - #expect(result.terminationStatus.isSuccess) #expect(result.value == expected) } } @@ -510,7 +504,7 @@ extension SubprocessWindowsTests { @Test func testPlatformOptionsCreateNewConsole() async throws { let parentConsole = GetConsoleWindow() let sameConsoleResult = try await Subprocess.run( - .path("C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe"), + .name("powershell.exe"), arguments: [ "-File", windowsTester.string, "-mode", "get-console-window", @@ -529,7 +523,7 @@ extension SubprocessWindowsTests { var platformOptions: Subprocess.PlatformOptions = .init() platformOptions.consoleBehavior = .createNew let differentConsoleResult = try await Subprocess.run( - .path("C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe"), + .name("powershell.exe"), arguments: [ "-File", windowsTester.string, "-mode", "get-console-window", @@ -551,7 +545,7 @@ extension SubprocessWindowsTests { var platformOptions: Subprocess.PlatformOptions = .init() platformOptions.consoleBehavior = .detach let detachConsoleResult = try await Subprocess.run( - .path("C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe"), + .name("powershell.exe"), arguments: [ "-File", windowsTester.string, "-mode", "get-console-window", @@ -575,7 +569,7 @@ extension SubprocessWindowsTests { } let parentConsole = GetConsoleWindow() let newConsoleResult = try await Subprocess.run( - .path("C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe"), + .name("powershell.exe"), arguments: [ "-File", windowsTester.string, "-mode", "get-console-window", @@ -606,7 +600,7 @@ extension SubprocessWindowsTests { } } let changeTitleResult = try await Subprocess.run( - .path("C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe"), + .name("powershell.exe"), arguments: [ "-Command", "$consoleTitle = [console]::Title; Write-Host $consoleTitle", ], @@ -654,7 +648,7 @@ extension SubprocessWindowsTests { // Now check the to make sure the process is actually suspended // Why not spawn another process to do that? var checkResult = try await Subprocess.run( - .path("C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe"), + .name("powershell.exe"), arguments: [ "-File", windowsTester.string, "-mode", "is-process-suspended", @@ -671,7 +665,7 @@ extension SubprocessWindowsTests { // Now resume the process try subprocess.resume() checkResult = try await Subprocess.run( - .path("C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe"), + .name("powershell.exe"), arguments: [ "-File", windowsTester.string, "-mode", "is-process-suspended", @@ -700,12 +694,13 @@ extension SubprocessWindowsTests { 0 ) let pid = try Subprocess.runDetached( - .path("C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe"), + .name("powershell.exe"), arguments: [ "-Command", "Write-Host $PID", ], output: writeFd ) + try writeFd.close() // Wait for process to finish guard let processHandle = OpenProcess( @@ -722,7 +717,6 @@ extension SubprocessWindowsTests { WaitForSingleObject(processHandle, INFINITE) // Up to 10 characters because Windows process IDs are DWORDs (UInt32), whose max value is 10 digits. - try writeFd.close() let data = try await readFd.readUntilEOF(upToLength: 10) let resultPID = try #require( String(data: data, encoding: .utf8) diff --git a/Tests/TestResources/TestResources.swift b/Tests/TestResources/TestResources.swift index 6a09808..7030862 100644 --- a/Tests/TestResources/TestResources.swift +++ b/Tests/TestResources/TestResources.swift @@ -10,7 +10,7 @@ //===----------------------------------------------------------------------===// #if canImport(WinSDK) -import WinSDK +@preconcurrency import WinSDK #endif // Confitionally require Foundation due to `Bundle.module`