Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WriteAsync cancellation throws an error with the calls completed status if possible #2170

Merged
merged 2 commits into from
Jun 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion src/Grpc.Net.Client/Internal/GrpcCall.NonGeneric.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -53,6 +53,7 @@ public DefaultDeserializationContext DeserializationContext

public string? RequestGrpcEncoding { get; internal set; }

public abstract Task<Status> CallTask { get; }
public abstract CancellationToken CancellationToken { get; }
public abstract Type RequestType { get; }
public abstract Type ResponseType { get; }
Expand All @@ -64,6 +65,29 @@ protected GrpcCall(CallOptions options, GrpcChannel channel)
Logger = channel.LoggerFactory.CreateLogger(LoggerName);
}

public Exception CreateCanceledStatusException(Exception? ex = null)
{
var status = (CallTask.IsCompletedSuccessfully()) ? CallTask.Result : new Status(StatusCode.Cancelled, string.Empty, ex);
return CreateRpcException(status);
}

public CancellationToken GetCanceledToken(CancellationToken methodCancellationToken)
{
if (methodCancellationToken.IsCancellationRequested)
{
return methodCancellationToken;
}
else if (Options.CancellationToken.IsCancellationRequested)
{
return Options.CancellationToken;
}
else if (CancellationToken.IsCancellationRequested)
{
return CancellationToken;
}
return CancellationToken.None;
}

internal RpcException CreateRpcException(Status status)
{
// This code can be called from a background task.
Expand All @@ -84,6 +108,23 @@ internal RpcException CreateRpcException(Status status)
return new RpcException(status, trailers ?? Metadata.Empty);
}

public Exception CreateFailureStatusException(Status status)
{
if (Channel.ThrowOperationCanceledOnCancellation &&
(status.StatusCode == StatusCode.DeadlineExceeded || status.StatusCode == StatusCode.Cancelled))
{
// Convert status response of DeadlineExceeded to OperationCanceledException when
// ThrowOperationCanceledOnCancellation is true.
// This avoids a race between the client-side timer and the server status throwing different
// errors on deadline exceeded.
return new OperationCanceledException();
}
else
{
return CreateRpcException(status);
}
}

protected bool TryGetTrailers([NotNullWhen(true)] out Metadata? trailers)
{
if (Trailers == null)
Expand Down
42 changes: 1 addition & 41 deletions src/Grpc.Net.Client/Internal/GrpcCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ private void ValidateDeadline(DateTime? deadline)
}
}

public Task<Status> CallTask => _callTcs.Task;
public override Task<Status> CallTask => _callTcs.Task;

public override CancellationToken CancellationToken => _callCts.Token;

Expand Down Expand Up @@ -248,12 +248,6 @@ public void EnsureNotDisposed()
}
}

public Exception CreateCanceledStatusException(Exception? ex = null)
{
var status = (CallTask.IsCompletedSuccessfully()) ? CallTask.Result : new Status(StatusCode.Cancelled, string.Empty, ex);
return CreateRpcException(status);
}

private void FinishResponseAndCleanUp(Status status)
{
ResponseFinished = true;
Expand Down Expand Up @@ -760,23 +754,6 @@ public Exception EnsureUserCancellationTokenReported(Exception ex, CancellationT
return ex;
}

public CancellationToken GetCanceledToken(CancellationToken methodCancellationToken)
{
if (methodCancellationToken.IsCancellationRequested)
{
return methodCancellationToken;
}
else if (Options.CancellationToken.IsCancellationRequested)
{
return Options.CancellationToken;
}
else if (CancellationToken.IsCancellationRequested)
{
return CancellationToken;
}
return CancellationToken.None;
}

private void SetFailedResult(Status status)
{
CompatibilityHelpers.Assert(_responseTcs != null);
Expand All @@ -795,23 +772,6 @@ private void SetFailedResult(Status status)
}
}

public Exception CreateFailureStatusException(Status status)
{
if (Channel.ThrowOperationCanceledOnCancellation &&
(status.StatusCode == StatusCode.DeadlineExceeded || status.StatusCode == StatusCode.Cancelled))
{
// Convert status response of DeadlineExceeded to OperationCanceledException when
// ThrowOperationCanceledOnCancellation is true.
// This avoids a race between the client-side timer and the server status throwing different
// errors on deadline exceeded.
return new OperationCanceledException();
}
else
{
return CreateRpcException(status);
}
}

private (bool diagnosticSourceEnabled, Activity? activity) InitializeCall(HttpRequestMessage request, TimeSpan? timeout)
{
GrpcCallLog.StartingCall(Logger, Method.Type, request.RequestUri!);
Expand Down
10 changes: 8 additions & 2 deletions src/Grpc.Net.Client/Internal/Http/PushStreamContent.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -72,4 +72,10 @@ protected override bool TryComputeLength(out long length)
// Hacky. ReadAsStreamAsync does not complete until SerializeToStreamAsync finishes.
// WARNING: Will run SerializeToStreamAsync again on .NET Framework.
internal Task PushComplete => ReadAsStreamAsync();
}

// Internal for testing.
internal Task SerializeToStreamAsync(Stream stream)
{
return SerializeToStreamAsync(stream, context: null);
}
}
47 changes: 32 additions & 15 deletions src/Grpc.Net.Client/Internal/StreamExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand All @@ -23,6 +23,7 @@
using System.Runtime.InteropServices;
using Grpc.Core;
using Grpc.Net.Compression;
using Grpc.Shared;
using Microsoft.Extensions.Logging;

#if NETSTANDARD2_0
Expand Down Expand Up @@ -318,10 +319,9 @@ public static async ValueTask WriteMessageAsync<TMessage>(
{
GrpcCallLog.ErrorSendingMessage(call.Logger, ex);

// Cancellation from disposing response while waiting for WriteAsync can throw ObjectDisposedException.
if (ex is ObjectDisposedException && call.CancellationToken.IsCancellationRequested)
if (TryCreateCallCompleteException(ex, call, out var statusException))
{
throw new OperationCanceledException();
throw statusException;
}

throw;
Expand All @@ -342,24 +342,41 @@ public static async ValueTask WriteMessageAsync(
{
GrpcCallLog.SendingMessage(call.Logger);

try
{
// Sending the header+content in a single WriteAsync call has significant performance benefits
// https://github.com/dotnet/runtime/issues/35184#issuecomment-626304981
await stream.WriteAsync(data, cancellationToken).ConfigureAwait(false);
}
catch (ObjectDisposedException) when (call.CancellationToken.IsCancellationRequested)
{
// Cancellation from disposing response while waiting for WriteAsync can throw ObjectDisposedException.
throw new OperationCanceledException();
}
// Sending the header+content in a single WriteAsync call has significant performance benefits
// https://github.com/dotnet/runtime/issues/35184#issuecomment-626304981
await stream.WriteAsync(data, cancellationToken).ConfigureAwait(false);

GrpcCallLog.MessageSent(call.Logger);
}
catch (Exception ex)
{
GrpcCallLog.ErrorSendingMessage(call.Logger, ex);

if (TryCreateCallCompleteException(ex, call, out var statusException))
{
throw statusException;
}

throw;
}
}

private static bool TryCreateCallCompleteException(Exception originalException, GrpcCall call, [NotNullWhen(true)] out Exception? exception)
{
// The call may have been completed while WriteAsync was running and caused WriteAsync to throw.
// In this situation, report the call's completed status.
//
// Replace exception with the status error if:
// 1. The original exception is one Stream.WriteAsync throws if the call was completed during a write, and
// 2. The call has already been successfully completed.
if (originalException is OperationCanceledException or ObjectDisposedException &&
call.CallTask.IsCompletedSuccessfully())
{
exception = call.CreateFailureStatusException(call.CallTask.Result);
return true;
}

exception = null;
return false;
}
}
2 changes: 1 addition & 1 deletion test/FunctionalTests/Client/StreamingTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down
103 changes: 101 additions & 2 deletions test/Grpc.Net.Client.Tests/AsyncClientStreamingCallTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -107,7 +107,6 @@ public async Task AsyncClientStreamingCall_Success_RequestContentSent()
var responseTask = call.ResponseAsync;
Assert.IsFalse(responseTask.IsCompleted, "Response not returned until client stream is complete.");


await call.RequestStream.WriteAsync(new HelloRequest { Name = "1" }).DefaultTimeout();
await call.RequestStream.WriteAsync(new HelloRequest { Name = "2" }).DefaultTimeout();

Expand Down Expand Up @@ -268,6 +267,106 @@ public async Task ClientStreamWriter_WriteAfterResponseHasFinished_ErrorThrown()
Assert.AreEqual("Hello world", result.Message);
}

[Test]
public async Task AsyncClientStreamingCall_ErrorWhileWriting_StatusExceptionThrown()
{
// Arrange
PushStreamContent<HelloRequest, HelloReply>? content = null;

var responseTcs = new TaskCompletionSource<HttpResponseMessage>(TaskCreationOptions.RunContinuationsAsynchronously);
var httpClient = ClientTestHelpers.CreateTestClient(request =>
{
content = (PushStreamContent<HelloRequest, HelloReply>)request.Content!;
return responseTcs.Task;
});

var invoker = HttpClientCallInvokerFactory.Create(httpClient);

// Act

// Client starts call
var call = invoker.AsyncClientStreamingCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions());
// Client starts request stream write
var writeTask = call.RequestStream.WriteAsync(new HelloRequest());

// Simulate HttpClient starting to accept the write. Stream.WriteAsync is blocked.
var writeSyncPoint = new SyncPoint(runContinuationsAsynchronously: true);
var testStream = new TestStream(writeSyncPoint);
var serializeToStreamTask = content!.SerializeToStreamAsync(testStream);

// Server completes response.
await writeSyncPoint.WaitForSyncPoint().DefaultTimeout();
responseTcs.SetResult(ResponseUtils.CreateResponse(HttpStatusCode.OK, new ByteArrayContent(Array.Empty<byte>()), grpcStatusCode: StatusCode.InvalidArgument));

await ExceptionAssert.ThrowsAsync<RpcException>(() => call.ResponseAsync).DefaultTimeout();
Assert.AreEqual(StatusCode.InvalidArgument, call.GetStatus().StatusCode);

// Unblock Stream.WriteAsync
writeSyncPoint.Continue();

// Get error thrown from write task. It should have the status returned by the server.
var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => writeTask).DefaultTimeout();

// Assert
Assert.AreEqual(StatusCode.InvalidArgument, ex.StatusCode);
Assert.AreEqual(StatusCode.InvalidArgument, call.GetStatus().StatusCode);
Assert.AreEqual(string.Empty, call.GetStatus().Detail);
}

private sealed class TestStream : Stream
{
private readonly SyncPoint _writeSyncPoint;

public TestStream(SyncPoint writeSyncPoint)
{
_writeSyncPoint = writeSyncPoint;
}

public override bool CanRead { get; }
public override bool CanSeek { get; }
public override bool CanWrite { get; }
public override long Length { get; }
public override long Position { get; set; }

public override void Flush()
{
}

public override int Read(byte[] buffer, int offset, int count)
{
throw new NotImplementedException();
}

public override long Seek(long offset, SeekOrigin origin)
{
throw new NotImplementedException();
}

public override void SetLength(long value)
{
throw new NotImplementedException();
}

public override void Write(byte[] buffer, int offset, int count)
{
throw new NotImplementedException();
}

#if !NET472_OR_GREATER
public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
await _writeSyncPoint.WaitToContinue();
throw new OperationCanceledException();
}
#else
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _writeSyncPoint.WaitToContinue();
throw new OperationCanceledException();
}
#endif
}

[Test]
public async Task ClientStreamWriter_CancelledBeforeCallStarts_ThrowsError()
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -318,6 +318,7 @@ public TestGrpcCall(CallOptions options, GrpcChannel channel) : base(options, ch
public override Type RequestType { get; } = typeof(int);
public override Type ResponseType { get; } = typeof(string);
public override CancellationToken CancellationToken { get; }
public override Task<Status> CallTask => Task.FromResult(Status.DefaultCancelled);
}

private GrpcCallSerializationContext CreateSerializationContext(string? requestGrpcEncoding = null, int? maxSendMessageSize = null)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -65,5 +65,6 @@ public TestGrpcCall(CallOptions options, GrpcChannel channel, Type type) : base(
public override Type RequestType => _type;
public override Type ResponseType => _type;
public override CancellationToken CancellationToken { get; }
public override Task<Status> CallTask => Task.FromResult(Status.DefaultCancelled);
}
}