Skip to content

Commit

Permalink
Add | Adding disposable stack temp ref struct and use (#1818)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wraith2 committed Nov 3, 2022
1 parent f684185 commit 760510c
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 120 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@
<Compile Include="..\..\src\Microsoft\Data\SqlClient\DataClassification\SensitivityClassification.cs">
<Link>Microsoft\Data\SqlClient\DataClassification\SensitivityClassification.cs</Link>
</Compile>
<Compile Include="..\..\src\Microsoft\Data\SqlClient\DisposableTemporaryOnStack.cs">
<Link>Microsoft\Data\SqlClient\DisposableTemporaryOnStack.cs</Link>
</Compile>
<Compile Include="..\..\src\Microsoft\Data\SqlClient\EnclaveDelegate.cs">
<Link>Microsoft\Data\SqlClient\EnclaveDelegate.cs</Link>
</Compile>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4409,6 +4409,7 @@ private void AssertReaderState(bool requireData, bool permitAsync, int? columnIn
public override Task<bool> NextResultAsync(CancellationToken cancellationToken)
{
using (TryEventScope.Create("SqlDataReader.NextResultAsync | API | Object Id {0}", ObjectID))
using (var registrationHolder = new DisposableTemporaryOnStack<CancellationTokenRegistration>())
{
TaskCompletionSource<bool> source = new TaskCompletionSource<bool>();

Expand All @@ -4418,15 +4419,14 @@ public override Task<bool> NextResultAsync(CancellationToken cancellationToken)
return source.Task;
}

CancellationTokenRegistration registration = default;
if (cancellationToken.CanBeCanceled)
{
if (cancellationToken.IsCancellationRequested)
{
source.SetCanceled();
return source.Task;
}
registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command);
registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command));
}

Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null);
Expand All @@ -4444,7 +4444,7 @@ public override Task<bool> NextResultAsync(CancellationToken cancellationToken)
return source.Task;
}

return InvokeAsyncCall(new HasNextResultAsyncCallContext(this, source, registration));
return InvokeAsyncCall(new HasNextResultAsyncCallContext(this, source, registrationHolder.Take()));
}
}

Expand Down Expand Up @@ -4739,17 +4739,17 @@ out bytesRead
public override Task<bool> ReadAsync(CancellationToken cancellationToken)
{
using (TryEventScope.Create("SqlDataReader.ReadAsync | API | Object Id {0}", ObjectID))
using (var registrationHolder = new DisposableTemporaryOnStack<CancellationTokenRegistration>())
{
if (IsClosed)
{
return Task.FromException<bool>(ADP.ExceptionWithStackTrace(ADP.DataReaderClosed()));
}

// Register first to catch any already expired tokens to be able to trigger cancellation event.
CancellationTokenRegistration registration = default;
if (cancellationToken.CanBeCanceled)
{
registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command);
registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command));
}

// If user's token is canceled, return a canceled task
Expand Down Expand Up @@ -4862,7 +4862,7 @@ public override Task<bool> ReadAsync(CancellationToken cancellationToken)

Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == default, "cached ReadAsyncCallContext was not properly disposed");

context.Set(this, source, registration);
context.Set(this, source, registrationHolder.Take());
context._hasMoreData = more;
context._hasReadRowToken = rowTokenRead;

Expand Down Expand Up @@ -5000,49 +5000,51 @@ override public Task<bool> IsDBNullAsync(int i, CancellationToken cancellationTo
return Task.FromException<bool>(ex);
}

// Setup and check for pending task
TaskCompletionSource<bool> source = new TaskCompletionSource<bool>();
Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null);
if (original != null)
using (var registrationHolder = new DisposableTemporaryOnStack<CancellationTokenRegistration>())
{
source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending()));
return source.Task;
}
// Setup and check for pending task
TaskCompletionSource<bool> source = new TaskCompletionSource<bool>();
Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null);
if (original != null)
{
source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending()));
return source.Task;
}

// Check if cancellation due to close is requested (this needs to be done after setting _currentTask)
if (_cancelAsyncOnCloseToken.IsCancellationRequested)
{
source.SetCanceled();
_currentTask = null;
return source.Task;
}
// Check if cancellation due to close is requested (this needs to be done after setting _currentTask)
if (_cancelAsyncOnCloseToken.IsCancellationRequested)
{
source.SetCanceled();
_currentTask = null;
return source.Task;
}

// Setup cancellations
CancellationTokenRegistration registration = default;
if (cancellationToken.CanBeCanceled)
{
registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command);
}
// Setup cancellations
if (cancellationToken.CanBeCanceled)
{
registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command));
}

IsDBNullAsyncCallContext context = null;
if (_connection?.InnerConnection is SqlInternalConnection sqlInternalConnection)
{
context = Interlocked.Exchange(ref sqlInternalConnection.CachedDataReaderIsDBNullContext, null);
}
if (context is null)
{
context = new IsDBNullAsyncCallContext();
}
IsDBNullAsyncCallContext context = null;
if (_connection?.InnerConnection is SqlInternalConnection sqlInternalConnection)
{
context = Interlocked.Exchange(ref sqlInternalConnection.CachedDataReaderIsDBNullContext, null);
}
if (context is null)
{
context = new IsDBNullAsyncCallContext();
}

Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == default, "cached ISDBNullAsync context not properly disposed");
Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == default, "cached ISDBNullAsync context not properly disposed");

context.Set(this, source, registration);
context._columnIndex = i;
context.Set(this, source, registrationHolder.Take());
context._columnIndex = i;

// Setup async
PrepareAsyncInvocation(useSnapshot: true);
// Setup async
PrepareAsyncInvocation(useSnapshot: true);

return InvokeAsyncCall(context);
return InvokeAsyncCall(context);
}
}
}

Expand Down Expand Up @@ -5147,37 +5149,39 @@ override public Task<T> GetFieldValueAsync<T>(int i, CancellationToken cancellat
return Task.FromException<T>(ex);
}

// Setup and check for pending task
TaskCompletionSource<T> source = new TaskCompletionSource<T>();
Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null);
if (original != null)
using (var registrationHolder = new DisposableTemporaryOnStack<CancellationTokenRegistration>())
{
source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending()));
return source.Task;
}
// Setup and check for pending task
TaskCompletionSource<T> source = new TaskCompletionSource<T>();
Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null);
if (original != null)
{
source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending()));
return source.Task;
}

// Check if cancellation due to close is requested (this needs to be done after setting _currentTask)
if (_cancelAsyncOnCloseToken.IsCancellationRequested)
{
source.SetCanceled();
_currentTask = null;
return source.Task;
}
// Check if cancellation due to close is requested (this needs to be done after setting _currentTask)
if (_cancelAsyncOnCloseToken.IsCancellationRequested)
{
source.SetCanceled();
_currentTask = null;
return source.Task;
}

// Setup cancellations
CancellationTokenRegistration registration = default;
if (cancellationToken.CanBeCanceled)
{
registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command);
}
// Setup cancellations
if (cancellationToken.CanBeCanceled)
{
registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command));
}

// Setup async
PrepareAsyncInvocation(useSnapshot: true);
// Setup async
PrepareAsyncInvocation(useSnapshot: true);

GetFieldValueAsyncCallContext<T> context = new GetFieldValueAsyncCallContext<T>(this, source, registration);
context._columnIndex = i;
GetFieldValueAsyncCallContext<T> context = new GetFieldValueAsyncCallContext<T>(this, source, registrationHolder.Take());
context._columnIndex = i;

return InvokeAsyncCall(context);
return InvokeAsyncCall(context);
}
}

private static Task<T> GetFieldValueAsyncExecute<T>(Task task, object state)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@
<Compile Include="..\..\src\Microsoft\Data\SqlClient\DataClassification\SensitivityClassification.cs">
<Link>Microsoft\Data\SqlClient\DataClassification\SensitivityClassification.cs</Link>
</Compile>
<Compile Include="..\..\src\Microsoft\Data\SqlClient\DisposableTemporaryOnStack.cs">
<Link>Microsoft\Data\SqlClient\DisposableTemporaryOnStack.cs</Link>
</Compile>
<Compile Include="..\..\src\Microsoft\Data\SqlClient\EnclaveDelegate.cs">
<Link>Microsoft\Data\SqlClient\EnclaveDelegate.cs</Link>
</Compile>
Expand Down
Loading

0 comments on commit 760510c

Please sign in to comment.