Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce scatter/gather file I/O Windows API requirements et. al. #57424

Merged
merged 13 commits into from
Aug 31, 2021
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO.Strategies;
using System.Numerics;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -421,82 +423,185 @@ private static ValueTask<long> ReadScatterAtOffsetAsync(SafeFileHandle handle, I
return ScheduleSyncReadScatterAtOffsetAsync(handle, buffers, fileOffset, cancellationToken);
}

if (CanUseScatterGatherWindowsAPIs(handle))
switch (buffers.Count)
{
long totalBytes = 0;
int buffersCount = buffers.Count;
for (int i = 0; i < buffersCount; i++)
case 0:
return CastValueTask(ReadAtOffsetAsync(handle, Memory<byte>.Empty, fileOffset, cancellationToken));
case 1:
return CastValueTask(ReadAtOffsetAsync(handle, buffers[0], fileOffset, cancellationToken));
}

if (CanUseScatterGatherWindowsAPIs(handle)
&& TryPrepareScatterGatherBuffers(buffers, default(MemoryHandler), out MemoryHandle[]? pinnedBuffers, out int totalBytes))
{
try
{
totalBytes += buffers[i].Length;
return ReadScatterAtOffsetSingleSyscallAsync(handle, pinnedBuffers, fileOffset, totalBytes,
cancellationToken);
}

if (totalBytes <= int.MaxValue) // the ReadFileScatter API uses int, not long
catch
{
return ReadScatterAtOffsetSingleSyscallAsync(handle, buffers, fileOffset, (int)totalBytes, cancellationToken);
foreach (MemoryHandle memoryHandle in pinnedBuffers)
{
memoryHandle.Dispose();
}

throw;
}
}

return ReadScatterAtOffsetMultipleSyscallsAsync(handle, buffers, fileOffset, cancellationToken);

static async ValueTask<long> CastValueTask(ValueTask<int> task) =>
// we have to await it because we can't cast a VT<int> to VT<long>
await task.ConfigureAwait(false);
}

// Abstracts away the type signature incompatibility between Memory and ReadOnlyMemory.
private interface IMemoryHandler<T>
teo-tsirpanis marked this conversation as resolved.
Show resolved Hide resolved
{
int GetLength(in T memory);
MemoryHandle Pin(in T memory);
}

private struct MemoryHandler : IMemoryHandler<Memory<byte>>
{
public int GetLength(in Memory<byte> memory) => memory.Length;
public MemoryHandle Pin(in Memory<byte> memory) => memory.Pin();
}

private struct ReadOnlyMemoryHandler : IMemoryHandler<ReadOnlyMemory<byte>>
{
public int GetLength(in ReadOnlyMemory<byte> memory) => memory.Length;
public MemoryHandle Pin(in ReadOnlyMemory<byte> memory) => memory.Pin();
}

// From https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-readfilescatter:
// "The file handle must be created with the GENERIC_READ right, and the FILE_FLAG_OVERLAPPED and FILE_FLAG_NO_BUFFERING flags."
// "The file handle must be created with [...] the FILE_FLAG_OVERLAPPED and FILE_FLAG_NO_BUFFERING flags."
private static bool CanUseScatterGatherWindowsAPIs(SafeFileHandle handle)
=> handle.IsAsync && ((handle.GetFileOptions() & SafeFileHandle.NoBuffering) != 0);

private static async ValueTask<long> ReadScatterAtOffsetSingleSyscallAsync(SafeFileHandle handle, IReadOnlyList<Memory<byte>> buffers, long fileOffset, int totalBytes, CancellationToken cancellationToken)
// From the same source:
// "Each buffer must be at least the size of a system memory page and must be aligned on a system
// memory page size boundary. The system reads/writes one system memory page of data into/from each buffer."
// This method returns true if the buffers can be used by
// the Windows scatter/gather API, which happens when they are:
// 1. aligned at page size boundaries
// 2. exactly one page long each (our own requirement to prevent partial reads)
// 3. not bigger than 2^32 - 1 in total
// This function is also responsible for pinning the buffers if they
// are suitable and they must be unpinned after the I/O operation completes.
// The total size of the buffers is also returned.
private static bool TryPrepareScatterGatherBuffers<T, THandler>(IReadOnlyList<T> buffers,
THandler handler, [NotNullWhen(true)] out MemoryHandle[]? pinnedBuffers, out int totalBytes)
where THandler: struct, IMemoryHandler<T>
{
int pageSize = Environment.SystemPageSize;
teo-tsirpanis marked this conversation as resolved.
Show resolved Hide resolved
Debug.Assert(BitOperations.IsPow2(pageSize), "Page size is not a power of two.");
// We take advantage of the fact that the page size is
// a power of two to avoid an expensive modulo operation.
teo-tsirpanis marked this conversation as resolved.
Show resolved Hide resolved
long alignedAtPageSizeMask = pageSize - 1;
int buffersCount = buffers.Count;
if (buffersCount == 1)
pinnedBuffers = new MemoryHandle[buffersCount];

try
{
// we have to await it because we can't cast a VT<int> to VT<long>
return await ReadAtOffsetAsync(handle, buffers[0], fileOffset, cancellationToken).ConfigureAwait(false);
long totalBytes64 = 0;
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
for (int i = 0; i < buffersCount; i++)
{
T buffer = buffers[i];
int length = handler.GetLength(in buffer);
totalBytes64 += length;
if (length != pageSize || totalBytes64 > int.MaxValue)
{
goto Failure;
}

MemoryHandle handle = pinnedBuffers[i] = handler.Pin(in buffer);
unsafe
{
if (((long)handle.Pointer & alignedAtPageSizeMask) != 0)
{
goto Failure;
}
}
}

totalBytes = (int)totalBytes64;
return true;

Failure:
foreach (MemoryHandle handle in pinnedBuffers)
teo-tsirpanis marked this conversation as resolved.
Show resolved Hide resolved
{
handle.Dispose();
}

pinnedBuffers = null;
totalBytes = 0;
return false;
}
catch
{
if (pinnedBuffers != null)
{
foreach (MemoryHandle handle in pinnedBuffers)
{
handle.Dispose();
}
}

throw;
}
}

private static async ValueTask<long> ReadScatterAtOffsetSingleSyscallAsync(SafeFileHandle handle, MemoryHandle[] pinnedBuffers, long fileOffset, int totalBytes, CancellationToken cancellationToken)
{
int buffersCount = pinnedBuffers.Length;
// "The array must contain enough elements to store nNumberOfBytesToWrite bytes of data, and one element for the terminating NULL. "
long[] fileSegments = new long[buffersCount + 1];
fileSegments[buffersCount] = 0;

MemoryHandle[] memoryHandles = new MemoryHandle[buffersCount];
MemoryHandle pinnedSegments = fileSegments.AsMemory().Pin();
GCHandle pinnedSegments = default;

try
{
pinnedSegments = GCHandle.Alloc(fileSegments, GCHandleType.Pinned);

for (int i = 0; i < buffersCount; i++)
{
Memory<byte> buffer = buffers[i];
MemoryHandle memoryHandle = buffer.Pin();
memoryHandles[i] = memoryHandle;

unsafe // awaits can't be in an unsafe context
{
fileSegments[i] = new IntPtr(memoryHandle.Pointer).ToInt64();
fileSegments[i] = new IntPtr(pinnedBuffers[i].Pointer).ToInt64();
}
}

return await ReadFileScatterAsync(handle, pinnedSegments, totalBytes, fileOffset, cancellationToken).ConfigureAwait(false);
return await ReadFileScatterAsync(handle, pinnedSegments.AddrOfPinnedObject(), totalBytes, fileOffset, cancellationToken).ConfigureAwait(false);
}
finally
{
foreach (MemoryHandle memoryHandle in memoryHandles)
foreach (MemoryHandle memoryHandle in pinnedBuffers)
{
memoryHandle.Dispose();
}
pinnedSegments.Dispose();

if (pinnedSegments.IsAllocated)
{
pinnedSegments.Free();
}
}
}

private static unsafe ValueTask<int> ReadFileScatterAsync(SafeFileHandle handle, MemoryHandle pinnedSegments, int bytesToRead, long fileOffset, CancellationToken cancellationToken)
private static unsafe ValueTask<int> ReadFileScatterAsync(SafeFileHandle handle, IntPtr segmentsPtr, int bytesToRead, long fileOffset, CancellationToken cancellationToken)
{
handle.EnsureThreadPoolBindingInitialized();

SafeFileHandle.OverlappedValueTaskSource vts = handle.GetOverlappedValueTaskSource();
try
{
NativeOverlapped* nativeOverlapped = vts.PrepareForOperation(Memory<byte>.Empty, fileOffset);
Debug.Assert(pinnedSegments.Pointer != null);
Debug.Assert(segmentsPtr != IntPtr.Zero);

if (Interop.Kernel32.ReadFileScatter(handle, (long*)pinnedSegments.Pointer, bytesToRead, IntPtr.Zero, nativeOverlapped) == 0)
if (Interop.Kernel32.ReadFileScatter(handle, (long*)segmentsPtr, bytesToRead, IntPtr.Zero, nativeOverlapped) == 0)
{
// The operation failed, or it's pending.
int errorCode = FileStreamHelpers.GetLastWin32ErrorAndDisposeHandleIfInvalid(handle);
Expand Down Expand Up @@ -562,17 +667,30 @@ private static ValueTask WriteGatherAtOffsetAsync(SafeFileHandle handle, IReadOn
return ScheduleSyncWriteGatherAtOffsetAsync(handle, buffers, fileOffset, cancellationToken);
}

if (CanUseScatterGatherWindowsAPIs(handle))
switch (buffers.Count)
{
long totalBytes = 0;
for (int i = 0; i < buffers.Count; i++)
case 0:
return WriteAtOffsetAsync(handle, ReadOnlyMemory<byte>.Empty, fileOffset, cancellationToken);
case 1:
return WriteAtOffsetAsync(handle, buffers[0], fileOffset, cancellationToken);
}

if (CanUseScatterGatherWindowsAPIs(handle)
&& TryPrepareScatterGatherBuffers(buffers, default(ReadOnlyMemoryHandler), out MemoryHandle[]? pinnedBuffers, out int totalBytes))
{
try
{
totalBytes += buffers[i].Length;
return WriteGatherAtOffsetSingleSyscallAsync(handle, pinnedBuffers, fileOffset, totalBytes,
cancellationToken);
}

if (totalBytes <= int.MaxValue) // the ReadFileScatter API uses int, not long
catch
{
return WriteGatherAtOffsetSingleSyscallAsync(handle, buffers, fileOffset, (int)totalBytes, cancellationToken);
foreach (MemoryHandle memoryHandle in pinnedBuffers)
{
memoryHandle.Dispose();
}

throw;
}
}

Expand All @@ -591,64 +709,55 @@ private static async ValueTask WriteGatherAtOffsetMultipleSyscallsAsync(SafeFile
}
}

private static ValueTask WriteGatherAtOffsetSingleSyscallAsync(SafeFileHandle handle, IReadOnlyList<ReadOnlyMemory<byte>> buffers, long fileOffset, int totalBytes, CancellationToken cancellationToken)
private static async ValueTask WriteGatherAtOffsetSingleSyscallAsync(SafeFileHandle handle, MemoryHandle[] pinnedBuffers, long fileOffset, int totalBytes, CancellationToken cancellationToken)
{
if (buffers.Count == 1)
{
return WriteAtOffsetAsync(handle, buffers[0], fileOffset, cancellationToken);
}
// "The array must contain enough elements to store nNumberOfBytesToWrite bytes of data, and one element for the terminating NULL. "
int buffersCount = pinnedBuffers.Length;
long[] fileSegments = new long[buffersCount + 1];
fileSegments[buffersCount] = 0;

return Core(handle, buffers, fileOffset, totalBytes, cancellationToken);
GCHandle pinnedSegments = default;
teo-tsirpanis marked this conversation as resolved.
Show resolved Hide resolved

static async ValueTask Core(SafeFileHandle handle, IReadOnlyList<ReadOnlyMemory<byte>> buffers, long fileOffset, int totalBytes, CancellationToken cancellationToken)
try
{
// "The array must contain enough elements to store nNumberOfBytesToWrite bytes of data, and one element for the terminating NULL. "
int buffersCount = buffers.Count;
long[] fileSegments = new long[buffersCount + 1];
fileSegments[buffersCount] = 0;

MemoryHandle[] memoryHandles = new MemoryHandle[buffersCount];
MemoryHandle pinnedSegments = fileSegments.AsMemory().Pin();
pinnedSegments = GCHandle.Alloc(fileSegments, GCHandleType.Pinned);

try
for (int i = 0; i < buffersCount; i++)
{
for (int i = 0; i < buffersCount; i++)
unsafe // awaits can't be in an unsafe context
{
ReadOnlyMemory<byte> buffer = buffers[i];
MemoryHandle memoryHandle = buffer.Pin();
memoryHandles[i] = memoryHandle;

unsafe // awaits can't be in an unsafe context
{
fileSegments[i] = new IntPtr(memoryHandle.Pointer).ToInt64();
}
fileSegments[i] = new IntPtr(pinnedBuffers[i].Pointer).ToInt64();
}
}

await WriteFileGatherAsync(handle, pinnedSegments, totalBytes, fileOffset, cancellationToken).ConfigureAwait(false);
await WriteFileGatherAsync(handle, pinnedSegments.AddrOfPinnedObject(), totalBytes, fileOffset, cancellationToken).ConfigureAwait(false);
}
finally
{
foreach (MemoryHandle memoryHandle in pinnedBuffers)
{
memoryHandle.Dispose();
}
finally

if (pinnedSegments.IsAllocated)
{
foreach (MemoryHandle memoryHandle in memoryHandles)
{
memoryHandle.Dispose();
}
pinnedSegments.Dispose();
pinnedSegments.Free();
}
}
}

private static unsafe ValueTask WriteFileGatherAsync(SafeFileHandle handle, MemoryHandle pinnedSegments, int bytesToWrite, long fileOffset, CancellationToken cancellationToken)
private static unsafe ValueTask WriteFileGatherAsync(SafeFileHandle handle, IntPtr segmentsPtr, int bytesToWrite, long fileOffset, CancellationToken cancellationToken)
{
handle.EnsureThreadPoolBindingInitialized();

SafeFileHandle.OverlappedValueTaskSource vts = handle.GetOverlappedValueTaskSource();
try
{
NativeOverlapped* nativeOverlapped = vts.PrepareForOperation(ReadOnlyMemory<byte>.Empty, fileOffset);
Debug.Assert(pinnedSegments.Pointer != null);
Debug.Assert(segmentsPtr != IntPtr.Zero);

// Queue an async WriteFile operation.
if (Interop.Kernel32.WriteFileGather(handle, (long*)pinnedSegments.Pointer, bytesToWrite, IntPtr.Zero, nativeOverlapped) == 0)
if (Interop.Kernel32.WriteFileGather(handle, (long*)segmentsPtr, bytesToWrite, IntPtr.Zero, nativeOverlapped) == 0)
{
// The operation failed, or it's pending.
int errorCode = FileStreamHelpers.GetLastWin32ErrorAndDisposeHandleIfInvalid(handle);
Expand Down