diff --git a/test/Grpc.Net.Client.Tests/AsyncDuplexStreamingCallTests.cs b/test/Grpc.Net.Client.Tests/AsyncDuplexStreamingCallTests.cs index 53aae9f99..12708e58a 100644 --- a/test/Grpc.Net.Client.Tests/AsyncDuplexStreamingCallTests.cs +++ b/test/Grpc.Net.Client.Tests/AsyncDuplexStreamingCallTests.cs @@ -24,6 +24,8 @@ using Grpc.Net.Client.Tests.Infrastructure; using Grpc.Shared; using Grpc.Tests.Shared; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using NUnit.Framework; namespace Grpc.Net.Client.Tests; @@ -177,4 +179,82 @@ await streamContent.AddDataAndWait(await ClientTestHelpers.GetResponseDataAsync( Assert.IsTrue(moveNextTask4.IsCompleted); Assert.IsFalse(await moveNextTask3.DefaultTimeout()); } + + [Test] + public async Task AsyncDuplexStreamingCall_CancellationDisposeRace_Success() + { + // Arrange + var services = new ServiceCollection(); + services.AddNUnitLogger(); + var loggerFactory = services.BuildServiceProvider().GetRequiredService(); + var logger = loggerFactory.CreateLogger(GetType()); + + for (int i = 0; i < 20; i++) + { + // Let's mimic a real call first to get GrpcCall.RunCall where we need to for reproducing the deadlock. + var streamContent = new SyncPointMemoryStream(); + var requestContentTcs = new TaskCompletionSource>(TaskCreationOptions.RunContinuationsAsynchronously); + + PushStreamContent? content = null; + + var handler = TestHttpMessageHandler.Create(async request => + { + content = (PushStreamContent)request.Content!; + var streamTask = content.ReadAsStreamAsync(); + requestContentTcs.SetResult(streamTask); + + // Wait for RequestStream.CompleteAsync() + await streamTask; + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, new StreamContent(streamContent)); + }); + var channel = GrpcChannel.ForAddress("http://localhost", new GrpcChannelOptions + { + HttpHandler = handler, + LoggerFactory = loggerFactory + }); + var invoker = channel.CreateCallInvoker(); + + var cts = new CancellationTokenSource(); + + var call = invoker.AsyncDuplexStreamingCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(cancellationToken: cts.Token)); + await call.RequestStream.WriteAsync(new HelloRequest { Name = "1" }).DefaultTimeout(); + await call.RequestStream.CompleteAsync().DefaultTimeout(); + + // Let's read a response + var deserializationContext = new DefaultDeserializationContext(); + var requestContent = await await requestContentTcs.Task.DefaultTimeout(); + var requestMessage = await StreamSerializationHelper.ReadMessageAsync( + requestContent, + ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer, + GrpcProtocolConstants.IdentityGrpcEncoding, + maximumMessageSize: null, + GrpcProtocolConstants.DefaultCompressionProviders, + singleMessage: false, + CancellationToken.None).DefaultTimeout(); + Assert.AreEqual("1", requestMessage!.Name); + + var actTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var cancellationTask = Task.Run(async () => + { + await actTcs.Task; + cts.Cancel(); + }); + var disposingTask = Task.Run(async () => + { + await actTcs.Task; + channel.Dispose(); + }); + + // Small pause to make sure we're waiting at the TCS everywhere. + await Task.Delay(50); + + // Act + actTcs.SetResult(true); + + // Assert + await Task.WhenAll(cancellationTask, disposingTask).DefaultTimeout(); + } + } }