Skip to content

Introduce IOCP based AsyncIO for Windows #117

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 25 additions & 25 deletions Sources/Subprocess/API.swift
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ public func run<
output: try output.createPipe(),
error: try error.createPipe()
) { execution, inputIO, outputIO, errorIO in
var inputIOBox: TrackedPlatformDiskIO? = consume inputIO
var outputIOBox: TrackedPlatformDiskIO? = consume outputIO
var errorIOBox: TrackedPlatformDiskIO? = consume errorIO
var inputIOBox: IOChannel? = consume inputIO
var outputIOBox: IOChannel? = consume outputIO
var errorIOBox: IOChannel? = consume errorIO

// Write input, capture output and error in parallel
async let stdout = try output.captureOutput(from: outputIOBox.take())
Expand Down Expand Up @@ -177,12 +177,12 @@ public func run<Result, Input: InputProtocol, Output: OutputProtocol, Error: Out
output: try output.createPipe(),
error: try error.createPipe()
) { execution, inputIO, outputIO, errorIO in
var inputIOBox: TrackedPlatformDiskIO? = consume inputIO
var inputIOBox: IOChannel? = consume inputIO
return try await withThrowingTaskGroup(
of: Void.self,
returning: Result.self
) { group in
var inputIOContainer: TrackedPlatformDiskIO? = inputIOBox.take()
var inputIOContainer: IOChannel? = inputIOBox.take()
group.addTask {
if let inputIO = inputIOContainer.take() {
let writer = StandardInputWriter(diskIO: inputIO)
Expand Down Expand Up @@ -237,13 +237,13 @@ public func run<Result, Input: InputProtocol, Error: OutputProtocol>(
output: try output.createPipe(),
error: try error.createPipe()
) { execution, inputIO, outputIO, errorIO in
var inputIOBox: TrackedPlatformDiskIO? = consume inputIO
var outputIOBox: TrackedPlatformDiskIO? = consume outputIO
var inputIOBox: IOChannel? = consume inputIO
var outputIOBox: IOChannel? = consume outputIO
return try await withThrowingTaskGroup(
of: Void.self,
returning: Result.self
) { group in
var inputIOContainer: TrackedPlatformDiskIO? = inputIOBox.take()
var inputIOContainer: IOChannel? = inputIOBox.take()
group.addTask {
if let inputIO = inputIOContainer.take() {
let writer = StandardInputWriter(diskIO: inputIO)
Expand All @@ -253,7 +253,7 @@ public func run<Result, Input: InputProtocol, Error: OutputProtocol>(
}

// Body runs in the same isolation
let outputSequence = AsyncBufferSequence(diskIO: outputIOBox.take()!.consumeDiskIO())
let outputSequence = AsyncBufferSequence(diskIO: outputIOBox.take()!.consumeIOChannel())
let result = try await body(execution, outputSequence)
try await group.waitForAll()
return result
Expand Down Expand Up @@ -299,13 +299,13 @@ public func run<Result, Input: InputProtocol, Output: OutputProtocol>(
output: try output.createPipe(),
error: try error.createPipe()
) { execution, inputIO, outputIO, errorIO in
var inputIOBox: TrackedPlatformDiskIO? = consume inputIO
var errorIOBox: TrackedPlatformDiskIO? = consume errorIO
var inputIOBox: IOChannel? = consume inputIO
var errorIOBox: IOChannel? = consume errorIO
return try await withThrowingTaskGroup(
of: Void.self,
returning: Result.self
) { group in
var inputIOContainer: TrackedPlatformDiskIO? = inputIOBox.take()
var inputIOContainer: IOChannel? = inputIOBox.take()
group.addTask {
if let inputIO = inputIOContainer.take() {
let writer = StandardInputWriter(diskIO: inputIO)
Expand All @@ -315,7 +315,7 @@ public func run<Result, Input: InputProtocol, Output: OutputProtocol>(
}

// Body runs in the same isolation
let errorSequence = AsyncBufferSequence(diskIO: errorIOBox.take()!.consumeDiskIO())
let errorSequence = AsyncBufferSequence(diskIO: errorIOBox.take()!.consumeIOChannel())
let result = try await body(execution, errorSequence)
try await group.waitForAll()
return result
Expand Down Expand Up @@ -363,7 +363,7 @@ public func run<Result, Error: OutputProtocol>(
error: try error.createPipe()
) { execution, inputIO, outputIO, errorIO in
let writer = StandardInputWriter(diskIO: inputIO!)
let outputSequence = AsyncBufferSequence(diskIO: outputIO!.consumeDiskIO())
let outputSequence = AsyncBufferSequence(diskIO: outputIO!.consumeIOChannel())
return try await body(execution, writer, outputSequence)
}
}
Expand Down Expand Up @@ -408,7 +408,7 @@ public func run<Result, Output: OutputProtocol>(
error: try error.createPipe()
) { execution, inputIO, outputIO, errorIO in
let writer = StandardInputWriter(diskIO: inputIO!)
let errorSequence = AsyncBufferSequence(diskIO: errorIO!.consumeDiskIO())
let errorSequence = AsyncBufferSequence(diskIO: errorIO!.consumeIOChannel())
return try await body(execution, writer, errorSequence)
}
}
Expand Down Expand Up @@ -460,8 +460,8 @@ public func run<Result>(
error: try error.createPipe()
) { execution, inputIO, outputIO, errorIO in
let writer = StandardInputWriter(diskIO: inputIO!)
let outputSequence = AsyncBufferSequence(diskIO: outputIO!.consumeDiskIO())
let errorSequence = AsyncBufferSequence(diskIO: errorIO!.consumeDiskIO())
let outputSequence = AsyncBufferSequence(diskIO: outputIO!.consumeIOChannel())
let errorSequence = AsyncBufferSequence(diskIO: errorIO!.consumeIOChannel())
return try await body(execution, writer, outputSequence, errorSequence)
}
}
Expand Down Expand Up @@ -497,16 +497,16 @@ public func run<
error: try error.createPipe()
) { (execution, inputIO, outputIO, errorIO) -> RunResult in
// Write input, capture output and error in parallel
var inputIOBox: TrackedPlatformDiskIO? = consume inputIO
var outputIOBox: TrackedPlatformDiskIO? = consume outputIO
var errorIOBox: TrackedPlatformDiskIO? = consume errorIO
var inputIOBox: IOChannel? = consume inputIO
var outputIOBox: IOChannel? = consume outputIO
var errorIOBox: IOChannel? = consume errorIO
return try await withThrowingTaskGroup(
of: OutputCapturingState<Output.OutputType, Error.OutputType>?.self,
returning: RunResult.self
) { group in
var inputIOContainer: TrackedPlatformDiskIO? = inputIOBox.take()
var outputIOContainer: TrackedPlatformDiskIO? = outputIOBox.take()
var errorIOContainer: TrackedPlatformDiskIO? = errorIOBox.take()
var inputIOContainer: IOChannel? = inputIOBox.take()
var outputIOContainer: IOChannel? = outputIOBox.take()
var errorIOContainer: IOChannel? = errorIOBox.take()
group.addTask {
if let writeFd = inputIOContainer.take() {
let writer = StandardInputWriter(diskIO: writeFd)
Expand Down Expand Up @@ -580,8 +580,8 @@ public func run<Result>(
error: try error.createPipe()
) { execution, inputIO, outputIO, errorIO in
let writer = StandardInputWriter(diskIO: inputIO!)
let outputSequence = AsyncBufferSequence(diskIO: outputIO!.consumeDiskIO())
let errorSequence = AsyncBufferSequence(diskIO: errorIO!.consumeDiskIO())
let outputSequence = AsyncBufferSequence(diskIO: outputIO!.consumeIOChannel())
let errorSequence = AsyncBufferSequence(diskIO: errorIO!.consumeIOChannel())
return try await body(execution, writer, outputSequence, errorSequence)
}
}
Expand Down
45 changes: 23 additions & 22 deletions Sources/Subprocess/AsyncBufferSequence.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@
internal import Dispatch
#endif

public struct AsyncBufferSequence: AsyncSequence, Sendable {
public struct AsyncBufferSequence: AsyncSequence, @unchecked Sendable {
public typealias Failure = any Swift.Error
public typealias Element = Buffer

#if os(Windows)
internal typealias DiskIO = FileDescriptor
#else
#if canImport(Darwin)
internal typealias DiskIO = DispatchIO
#elseif canImport(WinSDK)
internal typealias DiskIO = HANDLE
#else
internal typealias DiskIO = FileDescriptor
#endif

@_nonSendable
Expand All @@ -47,15 +49,18 @@ public struct AsyncBufferSequence: AsyncSequence, Sendable {
return self.buffer.removeFirst()
}
// Read more data
let data = try await self.diskIO.read(
upToLength: readBufferSize
let data = try await AsyncIO.shared.read(
from: self.diskIO,
upTo: readBufferSize
)
guard let data else {
// We finished reading. Close the file descriptor now
#if os(Windows)
try self.diskIO.close()
#if canImport(Darwin)
try _safelyClose(.dispatchIO(self.diskIO))
#elseif canImport(WinSDK)
try _safelyClose(.handle(self.diskIO))
#else
self.diskIO.close()
try _safelyClose(.fileDescriptor(self.diskIO))
#endif
return nil
}
Expand Down Expand Up @@ -132,17 +137,7 @@ extension AsyncBufferSequence {
self.eofReached = true
return nil
}
#if os(Windows)
// Cast data to CodeUnit type
let result = buffer.withUnsafeBytes { ptr in
return Array(
UnsafeBufferPointer<Encoding.CodeUnit>(
start: ptr.bindMemory(to: Encoding.CodeUnit.self).baseAddress!,
count: ptr.count / MemoryLayout<Encoding.CodeUnit>.size
)
)
}
#else
#if canImport(Darwin)
// Unfortunately here we _have to_ copy the bytes out because
// DispatchIO (rightfully) reuses buffer, which means `buffer.data`
// has the same address on all iterations, therefore we can't directly
Expand All @@ -157,7 +152,13 @@ extension AsyncBufferSequence {
UnsafeBufferPointer(start: ptr.baseAddress?.assumingMemoryBound(to: Encoding.CodeUnit.self), count: elementCount)
)
}

#else
// Cast data to CodeUnitg type
let result = buffer.withUnsafeBytes { ptr in
return ptr.withMemoryRebound(to: Encoding.CodeUnit.self) { codeUnitPtr in
return Array(codeUnitPtr)
}
}
#endif
return result.isEmpty ? nil : result
}
Expand Down Expand Up @@ -340,7 +341,7 @@ private let _pageSize: Int = {
Int(_subprocess_vm_size())
}()
#elseif canImport(WinSDK)
import WinSDK
@preconcurrency import WinSDK
private let _pageSize: Int = {
var sysInfo: SYSTEM_INFO = SYSTEM_INFO()
GetSystemInfo(&sysInfo)
Expand Down
47 changes: 26 additions & 21 deletions Sources/Subprocess/Buffer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,8 @@
extension AsyncBufferSequence {
/// A immutable collection of bytes
public struct Buffer: Sendable {
#if os(Windows)
internal let data: [UInt8]

internal init(data: [UInt8]) {
self.data = data
}

internal static func createFrom(_ data: [UInt8]) -> [Buffer] {
return [.init(data: data)]
}
#else
// We need to keep the backingData alive while _ContiguousBufferView is alive
#if canImport(Darwin)
// We need to keep the backingData alive while Slice is alive
internal let backingData: DispatchData
internal let data: DispatchData._ContiguousBufferView

Expand All @@ -45,7 +35,17 @@ extension AsyncBufferSequence {
}
return slices.map{ .init(data: $0, backingData: data) }
}
#endif
#else
internal let data: [UInt8]

internal init(data: [UInt8]) {
self.data = data
}

internal static func createFrom(_ data: [UInt8]) -> [Buffer] {
return [.init(data: data)]
}
#endif // canImport(Darwin)
}
}

Expand Down Expand Up @@ -92,26 +92,23 @@ extension AsyncBufferSequence.Buffer {

// MARK: - Hashable, Equatable
extension AsyncBufferSequence.Buffer: Equatable, Hashable {
#if os(Windows)
// Compiler generated conformances
#else
#if canImport(Darwin)
public static func == (lhs: AsyncBufferSequence.Buffer, rhs: AsyncBufferSequence.Buffer) -> Bool {
return lhs.data.elementsEqual(rhs.data)
return lhs.data == rhs.data
}

public func hash(into hasher: inout Hasher) {
self.data.withUnsafeBytes { ptr in
hasher.combine(bytes: ptr)
}
hasher.combine(self.data)
}
#endif
// else Compiler generated conformances
}

// MARK: - DispatchData.Block
#if canImport(Darwin) || canImport(Glibc) || canImport(Android) || canImport(Musl)
extension DispatchData {
/// Unfortunately `DispatchData.Region` is not available on Linux, hence our own wrapper
internal struct _ContiguousBufferView: @unchecked Sendable, RandomAccessCollection {
internal struct _ContiguousBufferView: @unchecked Sendable, RandomAccessCollection, Hashable {
typealias Element = UInt8

internal let bytes: UnsafeBufferPointer<UInt8>
Expand All @@ -127,6 +124,14 @@ extension DispatchData {
return try body(UnsafeRawBufferPointer(self.bytes))
}

internal func hash(into hasher: inout Hasher) {
hasher.combine(bytes: UnsafeRawBufferPointer(self.bytes))
}

internal static func == (lhs: DispatchData._ContiguousBufferView, rhs: DispatchData._ContiguousBufferView) -> Bool {
return lhs.bytes.elementsEqual(rhs.bytes)
}

subscript(position: Int) -> UInt8 {
_read {
yield self.bytes[position]
Expand Down
1 change: 1 addition & 0 deletions Sources/Subprocess/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ target_sources(Subprocess PRIVATE
Result.swift
IO/Output.swift
IO/Input.swift
IO/AsyncIO.swift
Span+Subprocess.swift
AsyncBufferSequence.swift
API.swift
Expand Down
Loading