From 349c91fe9780a7925eb0de76788944e2194f9770 Mon Sep 17 00:00:00 2001 From: James Newton-King Date: Tue, 2 Feb 2021 22:48:02 +1300 Subject: [PATCH] Add gRPC retries to client --- Directory.Build.props | 2 +- examples/Retrier/Client/Client.csproj | 20 + examples/Retrier/Client/Program.cs | 195 +++++ examples/Retrier/Proto/retry.proto | 32 + examples/Retrier/Retrier.sln | 43 + examples/Retrier/Server/Program.cs | 38 + examples/Retrier/Server/Server.csproj | 13 + .../Retrier/Server/Services/RetrierService.cs | 75 ++ examples/Retrier/Server/Startup.cs | 48 ++ .../Server/appsettings.Development.json | 10 + examples/Retrier/Server/appsettings.json | 13 + .../InterceptorRegistration.cs | 2 +- .../Internal/GrpcWebProtocolHelpers.cs | 2 +- .../Configuration/ConfigObject.cs | 68 ++ .../Configuration/HedgingPolicy.cs | 84 ++ .../Configuration/MethodConfig.cs | 96 +++ .../Configuration/MethodName.cs | 79 ++ .../Configuration/RetryPolicy.cs | 102 +++ .../Configuration/RetryThrottlingPolicy.cs | 60 ++ .../Configuration/ServiceConfig.cs | 66 ++ src/Grpc.Net.Client/GrpcChannel.cs | 161 +++- src/Grpc.Net.Client/GrpcChannelOptions.cs | 53 ++ .../Internal/Configuration/ConfigProperty.cs | 60 ++ .../Internal/Configuration/ConvertHelpers.cs | 106 +++ .../Internal/Configuration/IConfigValue.cs | 25 + .../Internal/Configuration/Values.cs | 97 +++ .../DefaultChannelCredentialsConfigurator.cs | 62 ++ .../Internal/GrpcCall.NonGeneric.cs | 80 ++ src/Grpc.Net.Client/Internal/GrpcCall.cs | 249 +++--- .../Internal/GrpcMethodInfo.cs | 5 +- .../Internal/GrpcProtocolConstants.cs | 10 +- .../Internal/GrpcProtocolHelpers.cs | 11 + .../Internal/Http/PushStreamContent.cs | 22 +- .../Internal/Http/PushUnaryContent.cs | 23 +- ...UnaryContent.cs => WinHttpUnaryContent.cs} | 74 +- .../Internal/HttpClientCallInvoker.cs | 54 +- .../Internal/HttpContentClientStreamReader.cs | 2 +- .../Internal/HttpContentClientStreamWriter.cs | 24 +- src/Grpc.Net.Client/Internal/IGrpcCall.cs | 49 ++ .../Internal/Retry/ChannelRetryThrottling.cs | 78 ++ .../Internal/Retry/CommitReason.cs | 34 + .../Internal/Retry/DeadlineGrpcCall.cs | 135 +++ .../Internal/Retry/HedgingCall.cs | 410 +++++++++ .../Internal/Retry/RetryCall.cs | 357 ++++++++ .../Internal/Retry/RetryCallBase.Log.cs | 120 +++ .../Internal/Retry/RetryCallBase.cs | 483 +++++++++++ .../Retry/RetryCallBaseClientStreamReader.cs | 46 ++ .../Retry/RetryCallBaseClientStreamWriter.cs | 51 ++ .../Internal/StreamExtensions.cs | 5 +- src/Shared/CommonGrpcProtocolHelpers.cs | 4 +- .../Client/CancellationTests.cs | 2 +- .../Client/EventSourceTests.cs | 4 +- test/FunctionalTests/Client/HedgingTests.cs | 548 ++++++++++++ test/FunctionalTests/Client/RetryTests.cs | 562 +++++++++++++ test/FunctionalTests/FunctionalTestBase.cs | 19 +- .../Grpc.AspNetCore.FunctionalTests.csproj | 1 + .../Server/ClientStreamingMethodTests.cs | 2 +- test/FunctionalTests/Server/DeadlineTests.cs | 11 +- .../Web/Base64PipeReaderTests.cs | 2 +- .../AsyncUnaryCallTests.cs | 42 +- .../Grpc.Net.Client.Tests.csproj | 2 + .../Grpc.Net.Client.Tests/GrpcChannelTests.cs | 21 + .../HttpContentClientStreamReaderTests.cs | 5 +- .../HttpClientCallInvokerFactory.cs | 7 +- .../Infrastructure/WinHttpHandler.cs | 28 + .../Retry/ChannelRetryThrottlingTests.cs | 49 ++ .../Retry/HedgingCallTests.cs | 371 +++++++++ .../Retry/HedgingTests.cs | 654 +++++++++++++++ .../Grpc.Net.Client.Tests/Retry/RetryTests.cs | 781 ++++++++++++++++++ .../ServiceConfigTests.cs | 129 +++ .../Base64ResponseStreamTests.cs | 4 +- test/Shared/ClientTestHelpers.cs | 9 +- test/Shared/ExceptionAssert.cs | 5 + test/Shared/ServiceConfigHelpers.cs | 108 +++ test/Shared/TestHelpers.cs | 4 +- .../InteropTestsWebsite/TestServiceImpl.cs | 1 + 76 files changed, 7017 insertions(+), 292 deletions(-) create mode 100644 examples/Retrier/Client/Client.csproj create mode 100644 examples/Retrier/Client/Program.cs create mode 100644 examples/Retrier/Proto/retry.proto create mode 100644 examples/Retrier/Retrier.sln create mode 100644 examples/Retrier/Server/Program.cs create mode 100644 examples/Retrier/Server/Server.csproj create mode 100644 examples/Retrier/Server/Services/RetrierService.cs create mode 100644 examples/Retrier/Server/Startup.cs create mode 100644 examples/Retrier/Server/appsettings.Development.json create mode 100644 examples/Retrier/Server/appsettings.json create mode 100644 src/Grpc.Net.Client/Configuration/ConfigObject.cs create mode 100644 src/Grpc.Net.Client/Configuration/HedgingPolicy.cs create mode 100644 src/Grpc.Net.Client/Configuration/MethodConfig.cs create mode 100644 src/Grpc.Net.Client/Configuration/MethodName.cs create mode 100644 src/Grpc.Net.Client/Configuration/RetryPolicy.cs create mode 100644 src/Grpc.Net.Client/Configuration/RetryThrottlingPolicy.cs create mode 100644 src/Grpc.Net.Client/Configuration/ServiceConfig.cs create mode 100644 src/Grpc.Net.Client/Internal/Configuration/ConfigProperty.cs create mode 100644 src/Grpc.Net.Client/Internal/Configuration/ConvertHelpers.cs create mode 100644 src/Grpc.Net.Client/Internal/Configuration/IConfigValue.cs create mode 100644 src/Grpc.Net.Client/Internal/Configuration/Values.cs create mode 100644 src/Grpc.Net.Client/Internal/DefaultChannelCredentialsConfigurator.cs rename src/Grpc.Net.Client/Internal/Http/{LengthUnaryContent.cs => WinHttpUnaryContent.cs} (60%) create mode 100644 src/Grpc.Net.Client/Internal/IGrpcCall.cs create mode 100644 src/Grpc.Net.Client/Internal/Retry/ChannelRetryThrottling.cs create mode 100644 src/Grpc.Net.Client/Internal/Retry/CommitReason.cs create mode 100644 src/Grpc.Net.Client/Internal/Retry/DeadlineGrpcCall.cs create mode 100644 src/Grpc.Net.Client/Internal/Retry/HedgingCall.cs create mode 100644 src/Grpc.Net.Client/Internal/Retry/RetryCall.cs create mode 100644 src/Grpc.Net.Client/Internal/Retry/RetryCallBase.Log.cs create mode 100644 src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs create mode 100644 src/Grpc.Net.Client/Internal/Retry/RetryCallBaseClientStreamReader.cs create mode 100644 src/Grpc.Net.Client/Internal/Retry/RetryCallBaseClientStreamWriter.cs create mode 100644 test/FunctionalTests/Client/HedgingTests.cs create mode 100644 test/FunctionalTests/Client/RetryTests.cs create mode 100644 test/Grpc.Net.Client.Tests/Infrastructure/WinHttpHandler.cs create mode 100644 test/Grpc.Net.Client.Tests/Retry/ChannelRetryThrottlingTests.cs create mode 100644 test/Grpc.Net.Client.Tests/Retry/HedgingCallTests.cs create mode 100644 test/Grpc.Net.Client.Tests/Retry/HedgingTests.cs create mode 100644 test/Grpc.Net.Client.Tests/Retry/RetryTests.cs create mode 100644 test/Grpc.Net.Client.Tests/ServiceConfigTests.cs create mode 100644 test/Shared/ServiceConfigHelpers.cs diff --git a/Directory.Build.props b/Directory.Build.props index 0fb0ae78c..0434644c6 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -18,7 +18,7 @@ $(WarningsNotAsErrors);CS1591 true - 8.0 + 9.0 enable diff --git a/examples/Retrier/Client/Client.csproj b/examples/Retrier/Client/Client.csproj new file mode 100644 index 000000000..0d1dea897 --- /dev/null +++ b/examples/Retrier/Client/Client.csproj @@ -0,0 +1,20 @@ + + + + Exe + net5.0 + + + + + + + + + + + + + + + diff --git a/examples/Retrier/Client/Program.cs b/examples/Retrier/Client/Program.cs new file mode 100644 index 000000000..2477ee3f9 --- /dev/null +++ b/examples/Retrier/Client/Program.cs @@ -0,0 +1,195 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Google.Protobuf.WellKnownTypes; +using Grpc.Core; +using Grpc.Net.Client; +using Grpc.Net.Client.Configuration; +using Retry; + +namespace Client +{ + public class Program + { + static async Task Main(string[] args) + { + using var channel = CreateChannel(); + var client = new Retrier.RetrierClient(channel); + + //await UnaryRetry(client); + await StreamingRetry(client); + + Console.WriteLine("Shutting down"); + Console.WriteLine("Press any key to exit..."); + Console.ReadKey(); + } + + private static async Task UnaryRetry(Retrier.RetrierClient client) + { + Console.WriteLine("Delivering packages..."); + foreach (var product in Products) + { + try + { + var package = new Package { Name = product }; + var call = client.DeliverPackageAsync(package); + var response = await call; + + #region Print success + Console.ForegroundColor = ConsoleColor.Green; + Console.Write(response.Message); + Console.ResetColor(); + Console.Write(" " + await GetRetryCount(call.ResponseHeadersAsync)); + Console.WriteLine(); + #endregion + } + catch (RpcException ex) + { + #region Print failure + Console.ForegroundColor = ConsoleColor.Red; + Console.WriteLine(ex.Status.Detail); + Console.ResetColor(); + #endregion + } + + await Task.Delay(TimeSpan.FromSeconds(0.2)); + } + } + + private static async Task StreamingRetry(Retrier.RetrierClient client) + { + var call = client.MessageUpload(); + + try + { + var lines = ImportantMessage.Split(Environment.NewLine); + for (var i = 0; i < lines.Length; i++) + { + #region Print percentage + if (i % 2 == 0) + { + Console.WriteLine((int)((i + 1D) / lines.Length * 100) + "% uploaded"); + } + #endregion + + await call.RequestStream.WriteAsync(new StringValue { Value = lines[i] }); + await Task.Delay(TimeSpan.FromSeconds(0.1)); + } + await call.RequestStream.CompleteAsync(); + + Console.WriteLine("Upload complete"); + Console.WriteLine("Press any key to download important message..."); + Console.ReadKey(); + + #region Print success + var count = 0; + await foreach (var line in call.ResponseStream.ReadAllAsync()) + { + Console.ForegroundColor = (ConsoleColor)(count % 7) + 9; + Console.WriteLine(line.Value); + await Task.Delay(TimeSpan.FromSeconds(0.2)); + count++; + } + Console.ResetColor(); + Console.Write(" " + await GetRetryCount(call.ResponseHeadersAsync)); + Console.WriteLine(); + #endregion + } + catch (RpcException ex) + { + #region Print failure + Console.ForegroundColor = ConsoleColor.Red; + Console.WriteLine(ex.Status.Detail); + Console.ResetColor(); + #endregion + } + } + + private static GrpcChannel CreateChannel() + { + var options = new GrpcChannelOptions + { + ServiceConfig = new ServiceConfig + { + MethodConfigs = + { + new MethodConfig + { + Names = { MethodName.Default }, + } + } + } + }; + return GrpcChannel.ForAddress("http://localhost:5000", options); + } + + private static async Task GetRetryCount(Task responseHeadersTask) + { + var headers = await responseHeadersTask; + var previousAttemptCount = headers.GetValue("grpc-previous-rpc-attempts"); + return previousAttemptCount != null ? $"(retry count: {previousAttemptCount})" : string.Empty; + } + + private static readonly IList Products = new List + { + "Secrets of Silicon Valley", + "The Busy Executive's Database Guide", + "Emotional Security: A New Algorithm", + "Prolonged Data Deprivation: Four Case Studies", + "Cooking with Computers: Surreptitious Balance Sheets", + "Silicon Valley Gastronomic Treats", + "Sushi, Anyone?", + "Fifty Years in Buckingham Palace Kitchens", + "But Is It User Friendly?", + "You Can Combat Computer Stress!", + "Is Anger the Enemy?", + "Life Without Fear", + "The Gourmet Microwave", + "Onions, Leeks, and Garlic: Cooking Secrets of the Mediterranean", + "The Psychology of Computer Cooking", + "Straight Talk About Computers", + "Computer Phobic AND Non-Phobic Individuals: Behavior Variations", + "Net Etiquette" + }; + + private static readonly string ImportantMessage =@" + _____ _____ _____ + | __ \| __ \ / ____| + __ _| |__) | |__) | | + / _` | _ /| ___/| | + | (_| | | \ \| | | |____ + \__, |_| \_\_| \_____| + __/ | + |___/ + _ + (_) + _ ___ + | / __| + | \__ \ _ + |_|___/ | | + ___ ___ ___ | | + / __/ _ \ / _ \| | + | (_| (_) | (_) | | + \___\___/ \___/|_| + + "; + } +} diff --git a/examples/Retrier/Proto/retry.proto b/examples/Retrier/Proto/retry.proto new file mode 100644 index 000000000..2ba421457 --- /dev/null +++ b/examples/Retrier/Proto/retry.proto @@ -0,0 +1,32 @@ +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +import "google/protobuf/wrappers.proto"; + +package retry; + +service Retrier { + rpc DeliverPackage (Package) returns (Response); + rpc MessageUpload (stream google.protobuf.StringValue) returns (stream google.protobuf.StringValue); +} + +message Package { + string name = 1; +} + +message Response { + string message = 1; +} diff --git a/examples/Retrier/Retrier.sln b/examples/Retrier/Retrier.sln new file mode 100644 index 000000000..fd620020c --- /dev/null +++ b/examples/Retrier/Retrier.sln @@ -0,0 +1,43 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 16 +VisualStudioVersion = 16.0.29230.61 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Server", "Server\Server.csproj", "{534AC5F8-2DF2-40BD-87A5-B3D8310118C4}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Client", "Client\Client.csproj", "{48A1D3BC-A14B-436A-8822-6DE2BEF8B747}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Grpc.Net.Client", "..\..\src\Grpc.Net.Client\Grpc.Net.Client.csproj", "{F001F7FD-21F7-42E5-BFB6-D0136ACA8869}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Grpc.Net.Common", "..\..\src\Grpc.Net.Common\Grpc.Net.Common.csproj", "{EB47A4E0-1AED-4D44-8BF6-BC7AE00D4058}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {534AC5F8-2DF2-40BD-87A5-B3D8310118C4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {534AC5F8-2DF2-40BD-87A5-B3D8310118C4}.Debug|Any CPU.Build.0 = Debug|Any CPU + {534AC5F8-2DF2-40BD-87A5-B3D8310118C4}.Release|Any CPU.ActiveCfg = Release|Any CPU + {534AC5F8-2DF2-40BD-87A5-B3D8310118C4}.Release|Any CPU.Build.0 = Release|Any CPU + {48A1D3BC-A14B-436A-8822-6DE2BEF8B747}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {48A1D3BC-A14B-436A-8822-6DE2BEF8B747}.Debug|Any CPU.Build.0 = Debug|Any CPU + {48A1D3BC-A14B-436A-8822-6DE2BEF8B747}.Release|Any CPU.ActiveCfg = Release|Any CPU + {48A1D3BC-A14B-436A-8822-6DE2BEF8B747}.Release|Any CPU.Build.0 = Release|Any CPU + {F001F7FD-21F7-42E5-BFB6-D0136ACA8869}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {F001F7FD-21F7-42E5-BFB6-D0136ACA8869}.Debug|Any CPU.Build.0 = Debug|Any CPU + {F001F7FD-21F7-42E5-BFB6-D0136ACA8869}.Release|Any CPU.ActiveCfg = Release|Any CPU + {F001F7FD-21F7-42E5-BFB6-D0136ACA8869}.Release|Any CPU.Build.0 = Release|Any CPU + {EB47A4E0-1AED-4D44-8BF6-BC7AE00D4058}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {EB47A4E0-1AED-4D44-8BF6-BC7AE00D4058}.Debug|Any CPU.Build.0 = Debug|Any CPU + {EB47A4E0-1AED-4D44-8BF6-BC7AE00D4058}.Release|Any CPU.ActiveCfg = Release|Any CPU + {EB47A4E0-1AED-4D44-8BF6-BC7AE00D4058}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {D22B3129-3BFB-41FA-9FCE-E45EBEF8C2DD} + EndGlobalSection +EndGlobal diff --git a/examples/Retrier/Server/Program.cs b/examples/Retrier/Server/Program.cs new file mode 100644 index 000000000..8ec497bbf --- /dev/null +++ b/examples/Retrier/Server/Program.cs @@ -0,0 +1,38 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.Hosting; + +namespace Server +{ + public class Program + { + public static void Main(string[] args) + { + CreateHostBuilder(args).Build().Run(); + } + + public static IHostBuilder CreateHostBuilder(string[] args) => + Host.CreateDefaultBuilder(args) + .ConfigureWebHostDefaults(webBuilder => + { + webBuilder.UseStartup(); + }); + } +} diff --git a/examples/Retrier/Server/Server.csproj b/examples/Retrier/Server/Server.csproj new file mode 100644 index 000000000..63e9d3525 --- /dev/null +++ b/examples/Retrier/Server/Server.csproj @@ -0,0 +1,13 @@ + + + + net5.0 + + + + + + + + + diff --git a/examples/Retrier/Server/Services/RetrierService.cs b/examples/Retrier/Server/Services/RetrierService.cs new file mode 100644 index 000000000..bb34a8011 --- /dev/null +++ b/examples/Retrier/Server/Services/RetrierService.cs @@ -0,0 +1,75 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Google.Protobuf.WellKnownTypes; +using Grpc.Core; +using Retry; + +namespace Server +{ + public class RetrierService : Retrier.RetrierBase + { + private readonly Random _random = new Random(); + + public override Task DeliverPackage(Package request, ServerCallContext context) + { + const double deliveryChance = 0.5; + if (_random.NextDouble() > deliveryChance) + { + throw new RpcException(new Status(StatusCode.Unavailable, $"- {request.Name}")); + } + + return Task.FromResult(new Response + { + Message = $"+ {request.Name}" + }); + } + + public override async Task MessageUpload( + IAsyncStreamReader requestStream, + IServerStreamWriter responseStream, + ServerCallContext context) + { + const double deliveryChance = 0.95; + + // Receive chunks + var chunks = new List(); + await foreach (var chunk in requestStream.ReadAllAsync()) + { + if (_random.NextDouble() > deliveryChance) + { + throw new RpcException(new Status(StatusCode.Unavailable, $"Message chunk not delivered.")); + } + + chunks.Add(chunk.Value); + } + + // Write chunks + foreach (var chunk in chunks) + { + await responseStream.WriteAsync(new StringValue + { + Value = chunk + }); + } + } + } +} diff --git a/examples/Retrier/Server/Startup.cs b/examples/Retrier/Server/Startup.cs new file mode 100644 index 000000000..739f2fbf0 --- /dev/null +++ b/examples/Retrier/Server/Startup.cs @@ -0,0 +1,48 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; + +namespace Server +{ + public class Startup + { + public void ConfigureServices(IServiceCollection services) + { + services.AddGrpc(); + } + + public void Configure(IApplicationBuilder app, IWebHostEnvironment env) + { + if (env.IsDevelopment()) + { + app.UseDeveloperExceptionPage(); + } + + app.UseRouting(); + + app.UseEndpoints(endpoints => + { + endpoints.MapGrpcService(); + }); + } + } +} diff --git a/examples/Retrier/Server/appsettings.Development.json b/examples/Retrier/Server/appsettings.Development.json new file mode 100644 index 000000000..fe20c40cc --- /dev/null +++ b/examples/Retrier/Server/appsettings.Development.json @@ -0,0 +1,10 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Debug", + "System": "Information", + "Grpc": "Information", + "Microsoft": "Information" + } + } +} diff --git a/examples/Retrier/Server/appsettings.json b/examples/Retrier/Server/appsettings.json new file mode 100644 index 000000000..f5f63744b --- /dev/null +++ b/examples/Retrier/Server/appsettings.json @@ -0,0 +1,13 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Information" + } + }, + "AllowedHosts": "*", + "Kestrel": { + "EndpointDefaults": { + "Protocols": "Http2" + } + } +} diff --git a/src/Grpc.AspNetCore.Server/InterceptorRegistration.cs b/src/Grpc.AspNetCore.Server/InterceptorRegistration.cs index 828c032a2..51ee404d9 100644 --- a/src/Grpc.AspNetCore.Server/InterceptorRegistration.cs +++ b/src/Grpc.AspNetCore.Server/InterceptorRegistration.cs @@ -50,7 +50,7 @@ internal InterceptorRegistration( { throw new ArgumentNullException(nameof(arguments)); } - for (int i = 0; i < arguments.Length; i++) + for (var i = 0; i < arguments.Length; i++) { if (arguments[i] == null) { diff --git a/src/Grpc.AspNetCore.Web/Internal/GrpcWebProtocolHelpers.cs b/src/Grpc.AspNetCore.Web/Internal/GrpcWebProtocolHelpers.cs index 1950c5469..00d28233c 100644 --- a/src/Grpc.AspNetCore.Web/Internal/GrpcWebProtocolHelpers.cs +++ b/src/Grpc.AspNetCore.Web/Internal/GrpcWebProtocolHelpers.cs @@ -121,7 +121,7 @@ private static void WriteTrailersContent(Span buffer, IHeaderDictionary tr // gRPC-Web protocol says that names should be lower-case and grpc-web JS client // will check for 'grpc-status' and 'grpc-message' in trailers with lower-case key. // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-WEB.md#protocol-differences-vs-grpc-over-http2 - for (int i = 0; i < kv.Key.Length; i++) + for (var i = 0; i < kv.Key.Length; i++) { char c = kv.Key[i]; currentBuffer[i] = (byte)((uint)(c - 'A') <= ('Z' - 'A') ? c | 0x20 : c); diff --git a/src/Grpc.Net.Client/Configuration/ConfigObject.cs b/src/Grpc.Net.Client/Configuration/ConfigObject.cs new file mode 100644 index 000000000..297494b67 --- /dev/null +++ b/src/Grpc.Net.Client/Configuration/ConfigObject.cs @@ -0,0 +1,68 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System.Collections; +using System.Collections.Generic; +using Grpc.Net.Client.Internal.Configuration; + +namespace Grpc.Net.Client.Configuration +{ + /// + /// Represents a configuration object. Implementations provide strongly typed wrappers over + /// collections of untyped values. + /// + public abstract class ConfigObject : IConfigValue + { + /// + /// Gets the underlying configuration values. + /// + public IDictionary Inner { get; } + + internal ConfigObject() : this(new Dictionary()) + { + } + + internal ConfigObject(IDictionary inner) + { + Inner = inner; + } + + object IConfigValue.Inner => Inner; + + internal T? GetValue(string key) + { + if (Inner.TryGetValue(key, out var value)) + { + return (T?)value; + } + return default; + } + + internal void SetValue(string key, T? value) + { + if (value == null) + { + Inner.Remove(key); + } + else + { + Inner[key] = value; + } + } + } +} diff --git a/src/Grpc.Net.Client/Configuration/HedgingPolicy.cs b/src/Grpc.Net.Client/Configuration/HedgingPolicy.cs new file mode 100644 index 000000000..7832dda16 --- /dev/null +++ b/src/Grpc.Net.Client/Configuration/HedgingPolicy.cs @@ -0,0 +1,84 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using Grpc.Core; +using Grpc.Net.Client.Internal.Configuration; + +namespace Grpc.Net.Client.Configuration +{ + /// + /// The hedging policy for outgoing calls. Hedged calls may execute more than + /// once on the server, so only idempotent methods should specify a hedging + /// policy. + /// + public sealed class HedgingPolicy : ConfigObject + { + internal const string MaxAttemptsPropertyName = "maxAttempts"; + internal const string HedgingDelayPropertyName = "hedgingDelay"; + internal const string NonFatalStatusCodesPropertyName = "nonFatalStatusCodes"; + + private ConfigProperty, IList> _nonFatalStatusCodes = + new(i => new Values(i ?? new List(), s => ConvertHelpers.ConvertStatusCode(s), s => ConvertHelpers.ConvertStatusCode(s.ToString()!)), NonFatalStatusCodesPropertyName); + + /// + /// Initializes a new instance of the class. + /// + public HedgingPolicy() { } + internal HedgingPolicy(IDictionary inner) : base(inner) { } + + /// + /// The hedging policy will send up to this number of calls. + /// This number represents the total number of all attempts, including + /// the original attempt. + /// + /// This field is required and must be greater than 1. + /// This value is limited by . + /// + public int? MaxAttempts + { + get => GetValue(MaxAttemptsPropertyName); + set => SetValue(MaxAttemptsPropertyName, value); + } + + /// + /// The first call will be sent immediately, but the subsequent + /// hedged call will be sent at intervals of every delay. Set this + /// to 0 to immediately send all hedged calls. + /// + public TimeSpan? HedgingDelay + { + get => ConvertHelpers.ConvertDurationText(GetValue(HedgingDelayPropertyName)); + set => SetValue(HedgingDelayPropertyName, ConvertHelpers.ToDurationText(value)); + } + + /// + /// The set of status codes which indicate other hedged calls may still + /// succeed. If a non-fatal status code is returned by the server, hedged + /// calls will continue. Otherwise, outstanding requests will be canceled and + /// the error returned to the client application layer. + /// + /// Specifying status codes is optional. + /// + public IList NonFatalStatusCodes + { + get => _nonFatalStatusCodes.GetValue(this)!; + } + } +} diff --git a/src/Grpc.Net.Client/Configuration/MethodConfig.cs b/src/Grpc.Net.Client/Configuration/MethodConfig.cs new file mode 100644 index 000000000..9a85fa8dd --- /dev/null +++ b/src/Grpc.Net.Client/Configuration/MethodConfig.cs @@ -0,0 +1,96 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System.Collections.Generic; +using Grpc.Net.Client.Internal.Configuration; + +namespace Grpc.Net.Client.Configuration +{ + /// + /// Configuration for a method. + /// The collection is used to determine which methods this configuration applies to. + /// + public sealed class MethodConfig : ConfigObject + { + private const string NamePropertyName = "name"; + private const string RetryPolicyPropertyName = "retryPolicy"; + private const string HedgingPolicyPropertyName = "hedgingPolicy"; + + private ConfigProperty, IList> _names = + new(i => new Values(i ?? new List(), s => s.Inner, s => new MethodName((IDictionary)s)), NamePropertyName); + + private ConfigProperty> _retryPolicy = + new(i => i != null ? new RetryPolicy(i) : null, RetryPolicyPropertyName); + + private ConfigProperty> _hedgingPolicy = + new(i => i != null ? new HedgingPolicy(i) : null, HedgingPolicyPropertyName); + + /// + /// Initializes a new instance of the class. + /// + public MethodConfig() { } + internal MethodConfig(IDictionary inner) : base(inner) { } + + /// + /// Gets or sets the retry policy for outgoing calls. + /// A retry policy can't be combined with . + /// + public RetryPolicy? RetryPolicy + { + get => _retryPolicy.GetValue(this); + set => _retryPolicy.SetValue(this, value); + } + + /// + /// Gets or sets the hedging policy for outgoing calls. Hedged calls may execute + /// more than once on the server, so only idempotent methods should specify a hedging + /// policy. A hedging policy can't be combined with . + /// + public HedgingPolicy? HedgingPolicy + { + get => _hedgingPolicy.GetValue(this); + set => _hedgingPolicy.SetValue(this, value); + } + + /// + /// Gets a collection of names which determine the calls the method config will apply to. + /// A without names won't be used. Each name must be unique + /// across an entire . + /// + /// + /// + /// If a name's property isn't set then the method config is the default + /// for all methods for the specified service. + /// + /// + /// If a name's property isn't set then must also be unset, + /// and the method config is the default for all methods on all services. + /// represents this global default name. + /// + /// + /// When determining which method config to use for a given RPC, the most specific match wins. A method config + /// with a configured that exactly matches a call's method and service will be used + /// instead of a service or global default method config. + /// + /// + public IList Names + { + get => _names.GetValue(this)!; + } + } +} diff --git a/src/Grpc.Net.Client/Configuration/MethodName.cs b/src/Grpc.Net.Client/Configuration/MethodName.cs new file mode 100644 index 000000000..1479d20e9 --- /dev/null +++ b/src/Grpc.Net.Client/Configuration/MethodName.cs @@ -0,0 +1,79 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System.Collections.Generic; +using System.Collections.ObjectModel; + +namespace Grpc.Net.Client.Configuration +{ + /// + /// The name of a method. Used to configure what calls a applies to using + /// the collection. + /// + /// + /// + /// If a name's property isn't set then the method config is the default + /// for all methods for the specified service. + /// + /// + /// If a name's property isn't set then must also be unset, + /// and the method config is the default for all methods on all services. + /// represents this global default name. + /// + /// + /// When determining which method config to use for a given RPC, the most specific match wins. A method config + /// with a configured that exactly matches a call's method and service will be used + /// instead of a service or global default method config. + /// + /// + public sealed class MethodName + : ConfigObject + { + /// + /// A global default name. + /// + public static readonly MethodName Default = new MethodName(new ReadOnlyDictionary(new Dictionary())); + + private const string ServicePropertyName = "service"; + private const string MethodPropertyName = "method"; + + /// + /// Initializes a new instance of the class. + /// + public MethodName() { } + internal MethodName(IDictionary inner) : base(inner) { } + + /// + /// Gets or sets the service name. + /// + public string? Service + { + get => GetValue(ServicePropertyName); + set => SetValue(ServicePropertyName, value); + } + + /// + /// Gets or sets the method name. + /// + public string? Method + { + get => GetValue(MethodPropertyName); + set => SetValue(MethodPropertyName, value); + } + } +} diff --git a/src/Grpc.Net.Client/Configuration/RetryPolicy.cs b/src/Grpc.Net.Client/Configuration/RetryPolicy.cs new file mode 100644 index 000000000..0655e63f6 --- /dev/null +++ b/src/Grpc.Net.Client/Configuration/RetryPolicy.cs @@ -0,0 +1,102 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using Grpc.Core; +using Grpc.Net.Client.Internal.Configuration; + +namespace Grpc.Net.Client.Configuration +{ + /// + /// The retry policy for outgoing calls. + /// + public sealed class RetryPolicy : ConfigObject + { + internal const string MaxAttemptsPropertyName = "maxAttempts"; + internal const string InitialBackoffPropertyName = "initialBackoff"; + internal const string MaxBackoffPropertyName = "maxBackoff"; + internal const string BackoffMultiplierPropertyName = "backoffMultiplier"; + internal const string RetryableStatusCodesPropertyName = "retryableStatusCodes"; + + private ConfigProperty, IList> _retryableStatusCodes = + new(i => new Values(i ?? new List(), s => ConvertHelpers.ConvertStatusCode(s), s => ConvertHelpers.ConvertStatusCode(s.ToString()!)), RetryableStatusCodesPropertyName); + + /// + /// Initializes a new instance of the class. + /// + public RetryPolicy() { } + internal RetryPolicy(IDictionary inner) : base(inner) { } + + /// + /// Gets or sets the maximum number of call attempts. This value includes the original attempt. + /// This property is required and must be greater than 1. + /// This value is limited by . + /// + public int? MaxAttempts + { + get => GetValue(MaxAttemptsPropertyName); + set => SetValue(MaxAttemptsPropertyName, value); + } + + /// + /// Gets or sets the initial backoff. + /// A randomized delay between 0 and the current backoff value will determine when the next + /// retry attempt is made. + /// + /// The backoff will be multiplied by after each retry + /// attempt and will increase exponentially when the multiplier is greater than 1. + /// + /// + public TimeSpan? InitialBackoff + { + get => ConvertHelpers.ConvertDurationText(GetValue(InitialBackoffPropertyName)); + set => SetValue(InitialBackoffPropertyName, ConvertHelpers.ToDurationText(value)); + } + + /// + /// Gets or sets the maximum backoff. + /// The maximum backoff places an upper limit on exponential backoff growth. + /// + public TimeSpan? MaxBackoff + { + get => ConvertHelpers.ConvertDurationText(GetValue(MaxBackoffPropertyName)); + set => SetValue(MaxBackoffPropertyName, ConvertHelpers.ToDurationText(value)); + } + + /// + /// Gets or sets the backoff multiplier. + /// The backoff will be multiplied by after each retry + /// attempt and will increase exponentially when the multiplier is greater than 1. + /// + public double? BackoffMultiplier + { + get => GetValue(BackoffMultiplierPropertyName); + set => SetValue(BackoffMultiplierPropertyName, value); + } + + /// + /// Gets a collection of status codes which may be retried. + /// At least one status code is required. + /// + public IList RetryableStatusCodes + { + get => _retryableStatusCodes.GetValue(this)!; + } + } +} diff --git a/src/Grpc.Net.Client/Configuration/RetryThrottlingPolicy.cs b/src/Grpc.Net.Client/Configuration/RetryThrottlingPolicy.cs new file mode 100644 index 000000000..291ec3ccf --- /dev/null +++ b/src/Grpc.Net.Client/Configuration/RetryThrottlingPolicy.cs @@ -0,0 +1,60 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System.Collections.Generic; + +namespace Grpc.Net.Client.Configuration +{ + /// + /// The retry throttling policy for a server. + /// + public sealed class RetryThrottlingPolicy : ConfigObject + { + internal const string MaxTokensPropertyName = "maxTokens"; + internal const string TokenRatioPropertyName = "tokenRatio"; + + /// + /// Initializes a new instance of the class. + /// + public RetryThrottlingPolicy() { } + internal RetryThrottlingPolicy(IDictionary inner) : base(inner) { } + + /// + /// Gets or sets the maximum number of tokens. + /// The number of tokens starts at and the token count will + /// always be between 0 and . + /// This property is required and must be greater than zero. + /// + public int? MaxTokens + { + get => GetValue(MaxTokensPropertyName); + set => SetValue(MaxTokensPropertyName, value); + } + + /// + /// Gets or sets the amount of tokens to add on each successful call. Typically this will + /// be some number between 0 and 1, e.g., 0.1. + /// This property is required and must be greater than zero. Up to 3 decimal places are supported. + /// + public double? TokenRatio + { + get => GetValue(TokenRatioPropertyName); + set => SetValue(TokenRatioPropertyName, value); + } + } +} diff --git a/src/Grpc.Net.Client/Configuration/ServiceConfig.cs b/src/Grpc.Net.Client/Configuration/ServiceConfig.cs new file mode 100644 index 000000000..93f7afa4d --- /dev/null +++ b/src/Grpc.Net.Client/Configuration/ServiceConfig.cs @@ -0,0 +1,66 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System.Collections.Generic; +using Grpc.Net.Client.Internal.Configuration; + +namespace Grpc.Net.Client.Configuration +{ + /// + /// A represents information about a service. + /// + public sealed class ServiceConfig : ConfigObject + { + private const string MethodConfigPropertyName = "method_config"; + private const string RetryThrottlingPropertyName = "retryThrottling"; + + private ConfigProperty, IList> _methods = + new(i => new Values(i ?? new List(), s => s.Inner, s => new MethodConfig((IDictionary)s)), MethodConfigPropertyName); + + private ConfigProperty> _retryThrottling = + new(i => i != null ? new RetryThrottlingPolicy(i) : null, RetryThrottlingPropertyName); + + /// + /// Initializes a new instance of the class. + /// + public ServiceConfig() { } + internal ServiceConfig(IDictionary inner) : base(inner) { } + + /// + /// Gets a collection of instances. This collection is used to specify + /// configuration on a per-method basis. determines which calls + /// a method config applies to. + /// + public IList MethodConfigs + { + get => _methods.GetValue(this)!; + } + + /// + /// Gets or sets the retry throttling policy. + /// If a is provided, gRPC will automatically throttle + /// retry attempts and hedged RPCs when the client's ratio of failures to + /// successes exceeds a threshold. + /// + public RetryThrottlingPolicy? RetryThrottling + { + get => _retryThrottling.GetValue(this); + set => _retryThrottling.SetValue(this, value); + } + } +} diff --git a/src/Grpc.Net.Client/GrpcChannel.cs b/src/Grpc.Net.Client/GrpcChannel.cs index 77ddc0cf5..7e9a686d9 100644 --- a/src/Grpc.Net.Client/GrpcChannel.cs +++ b/src/Grpc.Net.Client/GrpcChannel.cs @@ -23,10 +23,15 @@ using System.Net.Http; using Grpc.Core; using Grpc.Net.Client.Internal; +using Grpc.Net.Client.Configuration; +using GrpcServiceConfig = Grpc.Net.Client.Configuration.ServiceConfig; using Grpc.Net.Compression; using Grpc.Shared; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using Grpc.Net.Client.Internal.Retry; +using System.Threading; +using System.Diagnostics; namespace Grpc.Net.Client { @@ -38,9 +43,14 @@ namespace Grpc.Net.Client public sealed class GrpcChannel : ChannelBase, IDisposable { internal const int DefaultMaxReceiveMessageSize = 1024 * 1024 * 4; // 4 MB + internal const int DefaultMaxRetryAttempts = 5; + internal const long DefaultMaxRetryBufferSize = 1024 * 1024 * 16; // 16 MB + internal const long DefaultMaxRetryBufferPerCallSize = 1024 * 1024; // 1 MB private readonly ConcurrentDictionary _methodInfoCache; private readonly Func _createMethodInfoFunc; + private readonly Dictionary? _serviceConfigMethods; + private readonly object? _retryBufferLock; // Internal for testing internal readonly HashSet ActiveCalls; @@ -49,6 +59,9 @@ public sealed class GrpcChannel : ChannelBase, IDisposable internal bool IsWinHttp { get; } internal int? SendMaxMessageSize { get; } internal int? ReceiveMaxMessageSize { get; } + internal int? MaxRetryAttempts { get; } + internal long? MaxRetryBufferSize { get; } + internal long? MaxRetryBufferPerCallSize { get; } internal ILoggerFactory LoggerFactory { get; } internal bool ThrowOperationCanceledOnCancellation { get; } internal bool? IsSecure { get; } @@ -56,6 +69,11 @@ public sealed class GrpcChannel : ChannelBase, IDisposable internal Dictionary CompressionProviders { get; } internal string MessageAcceptEncoding { get; } internal bool Disposed { get; private set; } + internal GrpcServiceConfig? ServiceConfig { get; } + + // Stateful + internal ChannelRetryThrottling? RetryThrottling { get; } + internal long CurrentRetryBufferSize; // Options that are set in unit tests internal ISystemClock Clock = SystemClock.Instance; @@ -80,12 +98,20 @@ internal GrpcChannel(Uri address, GrpcChannelOptions channelOptions) : base(addr IsWinHttp = channelOptions.HttpHandler != null ? HttpHandlerFactory.HasHttpHandlerType(channelOptions.HttpHandler, "System.Net.Http.WinHttpHandler") : false; SendMaxMessageSize = channelOptions.MaxSendMessageSize; ReceiveMaxMessageSize = channelOptions.MaxReceiveMessageSize; + MaxRetryAttempts = channelOptions.MaxRetryAttempts; + MaxRetryBufferSize = channelOptions.MaxRetryBufferSize; + MaxRetryBufferPerCallSize = channelOptions.MaxRetryBufferPerCallSize; CompressionProviders = ResolveCompressionProviders(channelOptions.CompressionProviders); MessageAcceptEncoding = GrpcProtocolHelpers.GetMessageAcceptEncoding(CompressionProviders); LoggerFactory = channelOptions.LoggerFactory ?? NullLoggerFactory.Instance; ThrowOperationCanceledOnCancellation = channelOptions.ThrowOperationCanceledOnCancellation; _createMethodInfoFunc = CreateMethodInfo; ActiveCalls = new HashSet(); + // TODO(JamesNK): Underlying service config data is not copied + ServiceConfig = channelOptions.ServiceConfig; + RetryThrottling = ServiceConfig?.RetryThrottling != null ? new ChannelRetryThrottling(ServiceConfig.RetryThrottling) : null; + _serviceConfigMethods = (ServiceConfig != null) ? CreateServiceConfigMethods(ServiceConfig) : null; + _retryBufferLock = (ServiceConfig != null) ? new object() : null; if (channelOptions.Credentials != null) { @@ -99,6 +125,27 @@ internal GrpcChannel(Uri address, GrpcChannelOptions channelOptions) : base(addr } } + private static Dictionary CreateServiceConfigMethods(GrpcServiceConfig serviceConfig) + { + var configs = new Dictionary(); + for (var i = 0; i < serviceConfig.MethodConfigs.Count; i++) + { + var methodConfig = serviceConfig.MethodConfigs[i]; + for (var j = 0; j < methodConfig.Names.Count; j++) + { + var name = methodConfig.Names[j]; + var methodKey = new MethodKey(name.Service, name.Method); + if (configs.ContainsKey(methodKey)) + { + throw new InvalidOperationException($"Duplicate method config found. Service: '{name.Service}', method: '{name.Method}'."); + } + configs[methodKey] = methodConfig; + } + } + + return configs; + } + private static HttpMessageInvoker CreateInternalHttpInvoker(HttpMessageHandler? handler) { // HttpMessageInvoker should always dispose handler if Disposed is called on it. @@ -144,8 +191,31 @@ private GrpcMethodInfo CreateMethodInfo(IMethod method) { var uri = new Uri(method.FullName, UriKind.Relative); var scope = new GrpcCallScope(method.Type, uri); + var methodConfig = ResolveMethodConfig(method); + + return new GrpcMethodInfo(scope, new Uri(Address, uri), methodConfig); + } + + private MethodConfig? ResolveMethodConfig(IMethod method) + { + if (_serviceConfigMethods != null) + { + MethodConfig? methodConfig; + if (_serviceConfigMethods.TryGetValue(new MethodKey(method.ServiceName, method.Name), out methodConfig)) + { + return methodConfig; + } + if (_serviceConfigMethods.TryGetValue(new MethodKey(method.ServiceName, null), out methodConfig)) + { + return methodConfig; + } + if (_serviceConfigMethods.TryGetValue(new MethodKey(null, null), out methodConfig)) + { + return methodConfig; + } + } - return new GrpcMethodInfo(scope, new Uri(Address, uri)); + return null; } private static Dictionary ResolveCompressionProviders(IList? compressionProviders) @@ -156,7 +226,7 @@ private static Dictionary ResolveCompressionProvid } var resolvedCompressionProviders = new Dictionary(StringComparer.Ordinal); - for (int i = 0; i < compressionProviders.Count; i++) + for (var i = 0; i < compressionProviders.Count; i++) { var compressionProvider = compressionProviders[i]; if (!resolvedCompressionProviders.ContainsKey(compressionProvider.EncodingName)) @@ -199,47 +269,6 @@ public override CallInvoker CreateCallInvoker() return invoker; } - private class DefaultChannelCredentialsConfigurator : ChannelCredentialsConfiguratorBase - { - public bool? IsSecure { get; private set; } - public List? CallCredentials { get; private set; } - - public override void SetCompositeCredentials(object state, ChannelCredentials channelCredentials, CallCredentials callCredentials) - { - channelCredentials.InternalPopulateConfiguration(this, null); - - if (callCredentials != null) - { - if (CallCredentials == null) - { - CallCredentials = new List(); - } - - CallCredentials.Add(callCredentials); - } - } - - public override void SetInsecureCredentials(object state) - { - IsSecure = false; - } - - public override void SetSslCredentials(object state, string rootCertificates, KeyCertificatePair keyCertificatePair, VerifyPeerCallback verifyPeerCallback) - { - if (!string.IsNullOrEmpty(rootCertificates) || - keyCertificatePair != null || - verifyPeerCallback != null) - { - throw new InvalidOperationException( - $"{nameof(SslCredentials)} with non-null arguments is not supported by {nameof(GrpcChannel)}. " + - $"{nameof(GrpcChannel)} uses HttpClient to make gRPC calls and HttpClient automatically loads root certificates from the operating system certificate store. " + - $"Client certificates should be configured on HttpClient. See https://aka.ms/AA6we64 for details."); - } - - IsSecure = true; - } - } - /// /// Creates a for the specified address. /// @@ -330,5 +359,51 @@ public void Dispose() } Disposed = true; } + + internal bool TryAddToRetryBuffer(long messageSize) + { + CompatibilityExtensions.Assert(_retryBufferLock != null); + + lock (_retryBufferLock) + { + if (CurrentRetryBufferSize + messageSize > MaxRetryBufferSize) + { + return false; + } + + CurrentRetryBufferSize += messageSize; + return true; + } + } + + internal void RemoveFromRetryBuffer(long messageSize) + { + CompatibilityExtensions.Assert(_retryBufferLock != null); + + lock (_retryBufferLock) + { + CurrentRetryBufferSize -= messageSize; + } + } + + private struct MethodKey : IEquatable + { + public MethodKey(string? service, string? method) + { + Service = service; + Method = method; + } + + public string? Service { get; } + public string? Method { get; } + + public override bool Equals(object? obj) => obj is MethodKey n ? Equals(n) : false; + + public bool Equals(MethodKey other) => other.Service == Service && other.Method == Method; + + public override int GetHashCode() => + (Service != null ? StringComparer.Ordinal.GetHashCode(Service) : 0) ^ + (Method != null ? StringComparer.Ordinal.GetHashCode(Method) : 0); + } } } diff --git a/src/Grpc.Net.Client/GrpcChannelOptions.cs b/src/Grpc.Net.Client/GrpcChannelOptions.cs index 4cb89c2c5..ce581331d 100644 --- a/src/Grpc.Net.Client/GrpcChannelOptions.cs +++ b/src/Grpc.Net.Client/GrpcChannelOptions.cs @@ -20,6 +20,7 @@ using System.Collections.Generic; using System.Net.Http; using Grpc.Core; +using Grpc.Net.Client.Configuration; using Grpc.Net.Compression; using Microsoft.Extensions.Logging; @@ -65,6 +66,47 @@ public sealed class GrpcChannelOptions /// public int? MaxReceiveMessageSize { get; set; } + /// + /// Gets or sets the maximum retry attempts. This value limits any retry and hedging attempt values specified in + /// the service config. + /// + /// Setting this value alone doesn't enable retries. Retries are enabled in the service config, which can be done + /// using . + /// + /// + /// A null value removes the maximum retry attempts limit. Defaults to 5. + /// + /// + public int? MaxRetryAttempts { get; set; } + + /// + /// Gets or sets the maximum buffer size in bytes that can be used to store sent messages when retrying + /// or hedging calls. If the buffer limit is exceeded then no more retry attempts are made and all + /// hedging calls but one will be canceled. This limit is applied across all calls made using the channel. + /// + /// Setting this value alone doesn't enable retries. Retries are enabled in the service config, which can be done + /// using . + /// + /// + /// A null value removes the maximum retry buffer size limit. Defaults to 16,777,216 (16 MB). + /// + /// + public long? MaxRetryBufferSize { get; set; } + + /// + /// Gets or sets the maximum buffer size in bytes that can be used to store sent messages when retrying + /// or hedging calls. If the buffer limit is exceeded then no more retry attempts are made and all + /// hedging calls but one will be canceled. This limit is applied to one call. + /// + /// Setting this value alone doesn't enable retries. Retries are enabled in the service config, which can be done + /// using . + /// + /// + /// A null value removes the maximum retry buffer size limit per call. Defaults to 1,048,576 (1 MB). + /// + /// + public long? MaxRetryBufferPerCallSize { get; set; } + /// /// Gets or sets a collection of compression providers. /// @@ -127,12 +169,23 @@ public sealed class GrpcChannelOptions /// public bool ThrowOperationCanceledOnCancellation { get; set; } + /// + /// Gets or sets the service config for a gRPC channel. A service config allows service owners to publish parameters + /// to be automatically used by all clients of their service. A service config can also be specified by a client + /// using this property. + /// Note: experimental API that can change or be removed without any prior notice. + /// + public ServiceConfig? ServiceConfig { get; set; } + /// /// Initializes a new instance of the class. /// public GrpcChannelOptions() { MaxReceiveMessageSize = GrpcChannel.DefaultMaxReceiveMessageSize; + MaxRetryAttempts = GrpcChannel.DefaultMaxRetryAttempts; + MaxRetryBufferSize = GrpcChannel.DefaultMaxRetryBufferSize; + MaxRetryBufferPerCallSize = GrpcChannel.DefaultMaxRetryBufferPerCallSize; } } } diff --git a/src/Grpc.Net.Client/Internal/Configuration/ConfigProperty.cs b/src/Grpc.Net.Client/Internal/Configuration/ConfigProperty.cs new file mode 100644 index 000000000..56ad27474 --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Configuration/ConfigProperty.cs @@ -0,0 +1,60 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using Grpc.Net.Client.Configuration; + +namespace Grpc.Net.Client.Internal.Configuration +{ + internal struct ConfigProperty where TValue : IConfigValue + { + private TValue? _value; + private readonly Func _valueFactory; + private readonly string _key; + + public ConfigProperty(Func valueFactory, string key) + { + _value = default; + _valueFactory = valueFactory; + _key = key; + } + + public TValue? GetValue(ConfigObject inner) + { + if (_value == null) + { + var innerValue = inner.GetValue(_key); + _value = _valueFactory(innerValue); + + if (_value != null && innerValue == null) + { + // Set newly created value + SetValue(inner, _value); + } + } + + return _value; + } + + public void SetValue(ConfigObject inner, TValue? value) + { + _value = value; + inner.SetValue(_key, _value?.Inner); + } + } +} diff --git a/src/Grpc.Net.Client/Internal/Configuration/ConvertHelpers.cs b/src/Grpc.Net.Client/Internal/Configuration/ConvertHelpers.cs new file mode 100644 index 000000000..998dcda1f --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Configuration/ConvertHelpers.cs @@ -0,0 +1,106 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Globalization; +using Grpc.Core; + +namespace Grpc.Net.Client.Internal.Configuration +{ + internal static class ConvertHelpers + { + public static string ConvertStatusCode(StatusCode statusCode) + { + return statusCode switch + { + StatusCode.OK => "OK", + StatusCode.Cancelled => "CANCELLED", + StatusCode.Unknown => "UNKNOWN", + StatusCode.InvalidArgument => "INVALID_ARGUMENT", + StatusCode.DeadlineExceeded => "DEADLINE_EXCEEDED", + StatusCode.NotFound => "NOT_FOUND", + StatusCode.AlreadyExists => "ALREADY_EXISTS", + StatusCode.PermissionDenied => "PERMISSION_DENIED", + StatusCode.Unauthenticated => "UNAUTHENTICATED", + StatusCode.ResourceExhausted => "RESOURCE_EXHAUSTED", + StatusCode.FailedPrecondition => "FAILED_PRECONDITION", + StatusCode.Aborted => "ABORTED", + StatusCode.OutOfRange => "OUT_OF_RANGE", + StatusCode.Unimplemented => "UNIMPLEMENTED", + StatusCode.Internal => "INTERNAL", + StatusCode.Unavailable => "UNAVAILABLE", + StatusCode.DataLoss => "DATA_LOSS", + _ => throw new InvalidOperationException($"Unexpected status code: {statusCode}") + }; + } + + public static StatusCode ConvertStatusCode(string statusCode) + { + return statusCode.ToUpperInvariant() switch + { + "OK" => StatusCode.OK, + "CANCELLED" => StatusCode.Cancelled, + "UNKNOWN" => StatusCode.Unknown, + "INVALID_ARGUMENT" => StatusCode.InvalidArgument, + "DEADLINE_EXCEEDED" => StatusCode.DeadlineExceeded, + "NOT_FOUND" => StatusCode.NotFound, + "ALREADY_EXISTS" => StatusCode.AlreadyExists, + "PERMISSION_DENIED" => StatusCode.PermissionDenied, + "UNAUTHENTICATED" => StatusCode.Unauthenticated, + "RESOURCE_EXHAUSTED" => StatusCode.ResourceExhausted, + "FAILED_PRECONDITION" => StatusCode.FailedPrecondition, + "ABORTED" => StatusCode.Aborted, + "OUT_OF_RANGE" => StatusCode.OutOfRange, + "UNIMPLEMENTED" => StatusCode.Unimplemented, + "INTERNAL" => StatusCode.Internal, + "UNAVAILABLE" => StatusCode.Unavailable, + "DATA_LOSS" => StatusCode.DataLoss, + _ => int.TryParse(statusCode, out var number) + ? (StatusCode)number + : throw new InvalidOperationException($"Unexpected status code: {statusCode}") + }; + } + + public static TimeSpan? ConvertDurationText(string? text) + { + if (text == null) + { + return null; + } + + if (text[text.Length - 1] == 's') + { + return TimeSpan.FromSeconds(Convert.ToDouble(text.Substring(0, text.Length - 1), CultureInfo.InvariantCulture)); + } + else + { + throw new FormatException($"'{text}' isn't a valid duration."); + } + } + + public static string? ToDurationText(TimeSpan? value) + { + if (value == null) + { + return null; + } + + return value.GetValueOrDefault().TotalSeconds + "s"; + } + } +} diff --git a/src/Grpc.Net.Client/Internal/Configuration/IConfigValue.cs b/src/Grpc.Net.Client/Internal/Configuration/IConfigValue.cs new file mode 100644 index 000000000..e3a06371d --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Configuration/IConfigValue.cs @@ -0,0 +1,25 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +namespace Grpc.Net.Client.Internal.Configuration +{ + internal interface IConfigValue + { + object Inner { get; } + } +} diff --git a/src/Grpc.Net.Client/Internal/Configuration/Values.cs b/src/Grpc.Net.Client/Internal/Configuration/Values.cs new file mode 100644 index 000000000..825f96356 --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Configuration/Values.cs @@ -0,0 +1,97 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections; +using System.Collections.Generic; + +namespace Grpc.Net.Client.Internal.Configuration +{ + internal class Values : IList, IConfigValue + { + internal readonly IList Inner; + + private readonly IList _values; + internal readonly Func _convertTo; + internal readonly Func _convertFrom; + + public Values(IList inner, Func convertTo, Func convertFrom) + { + Inner = inner; + _values = new List(); + _convertTo = convertTo; + _convertFrom = convertFrom; + + foreach (var item in Inner) + { + _values.Add(_convertFrom(item)); + } + } + + public T this[int index] + { + get => _values[index]; + set + { + _values[index] = value; + Inner[index] = _convertTo(value); + } + } + + public int Count => Inner.Count; + public bool IsReadOnly => Inner.IsReadOnly; + + object IConfigValue.Inner => Inner; + + public void Add(T item) + { + _values.Add(item); + Inner.Add(_convertTo(item)); + } + + public void Clear() + { + _values.Clear(); + Inner.Clear(); + } + + public bool Contains(T item) => _values.Contains(item); + + public void CopyTo(T[] array, int arrayIndex) => _values.CopyTo(array, arrayIndex); + + public IEnumerator GetEnumerator() => _values.GetEnumerator(); + + public int IndexOf(T item) => _values.IndexOf(item); + + public void Insert(int index, T item) + { + _values.Insert(index, item); + Inner.Insert(index, _convertTo(item)); + } + + public bool Remove(T item) => _values.Remove(item) && Inner.Remove(_convertTo(item)); + + public void RemoveAt(int index) + { + _values.RemoveAt(index); + Inner.RemoveAt(index); + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } +} diff --git a/src/Grpc.Net.Client/Internal/DefaultChannelCredentialsConfigurator.cs b/src/Grpc.Net.Client/Internal/DefaultChannelCredentialsConfigurator.cs new file mode 100644 index 000000000..d31fc1840 --- /dev/null +++ b/src/Grpc.Net.Client/Internal/DefaultChannelCredentialsConfigurator.cs @@ -0,0 +1,62 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using Grpc.Core; + +namespace Grpc.Net.Client.Internal +{ + internal class DefaultChannelCredentialsConfigurator : ChannelCredentialsConfiguratorBase + { + public bool? IsSecure { get; private set; } + public List? CallCredentials { get; private set; } + + public override void SetCompositeCredentials(object state, ChannelCredentials channelCredentials, CallCredentials callCredentials) + { + channelCredentials.InternalPopulateConfiguration(this, null); + + if (callCredentials != null) + { + if (CallCredentials == null) + { + CallCredentials = new List(); + } + + CallCredentials.Add(callCredentials); + } + } + + public override void SetInsecureCredentials(object state) => IsSecure = false; + + public override void SetSslCredentials(object state, string rootCertificates, KeyCertificatePair keyCertificatePair, VerifyPeerCallback verifyPeerCallback) + { + if (!string.IsNullOrEmpty(rootCertificates) || + keyCertificatePair != null || + verifyPeerCallback != null) + { + throw new InvalidOperationException( + $"{nameof(SslCredentials)} with non-null arguments is not supported by {nameof(GrpcChannel)}. " + + $"{nameof(GrpcChannel)} uses HttpClient to make gRPC calls and HttpClient automatically loads root certificates from the operating system certificate store. " + + $"Client certificates should be configured on HttpClient. See https://aka.ms/AA6we64 for details."); + } + + IsSecure = true; + } + } +} diff --git a/src/Grpc.Net.Client/Internal/GrpcCall.NonGeneric.cs b/src/Grpc.Net.Client/Internal/GrpcCall.NonGeneric.cs index 23750c374..cdf6c37c3 100644 --- a/src/Grpc.Net.Client/Internal/GrpcCall.NonGeneric.cs +++ b/src/Grpc.Net.Client/Internal/GrpcCall.NonGeneric.cs @@ -19,6 +19,7 @@ using System; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Net; using System.Net.Http; using Grpc.Core; using Grpc.Shared; @@ -90,5 +91,84 @@ protected bool TryGetTrailers([NotNullWhen(true)] out Metadata? trailers) trailers = Trailers; return true; } + + internal static Status? ValidateHeaders(HttpResponseMessage httpResponse, out Metadata? trailers) + { + // gRPC status can be returned in the header when there is no message (e.g. unimplemented status) + // An explicitly specified status header has priority over other failing statuses + if (GrpcProtocolHelpers.TryGetStatusCore(httpResponse.Headers, out var status)) + { + // Trailers are in the header because there is no message. + // Note that some default headers will end up in the trailers (e.g. Date, Server). + trailers = GrpcProtocolHelpers.BuildMetadata(httpResponse.Headers); + return status; + } + + trailers = null; + + // ALPN negotiation is sending HTTP/1.1 and HTTP/2. + // Check that the response wasn't downgraded to HTTP/1.1. + if (httpResponse.Version < GrpcProtocolConstants.Http2Version) + { + return new Status(StatusCode.Internal, $"Bad gRPC response. Response protocol downgraded to HTTP/{httpResponse.Version.ToString(2)}."); + } + + if (httpResponse.StatusCode != HttpStatusCode.OK) + { + var statusCode = MapHttpStatusToGrpcCode(httpResponse.StatusCode); + return new Status(statusCode, "Bad gRPC response. HTTP status code: " + (int)httpResponse.StatusCode); + } + + if (httpResponse.Content?.Headers.ContentType == null) + { + return new Status(StatusCode.Cancelled, "Bad gRPC response. Response did not have a content-type header."); + } + + var grpcEncoding = httpResponse.Content.Headers.ContentType; + if (!CommonGrpcProtocolHelpers.IsContentType(GrpcProtocolConstants.GrpcContentType, grpcEncoding?.MediaType)) + { + return new Status(StatusCode.Cancelled, "Bad gRPC response. Invalid content-type value: " + grpcEncoding); + } + + // Call is still in progress + return null; + } + + private static StatusCode MapHttpStatusToGrpcCode(HttpStatusCode httpStatusCode) + { + switch (httpStatusCode) + { + case HttpStatusCode.BadRequest: // 400 +#if !NETSTANDARD2_0 + case HttpStatusCode.RequestHeaderFieldsTooLarge: // 431 +#else + case (HttpStatusCode)431: +#endif + return StatusCode.Internal; + case HttpStatusCode.Unauthorized: // 401 + return StatusCode.Unauthenticated; + case HttpStatusCode.Forbidden: // 403 + return StatusCode.PermissionDenied; + case HttpStatusCode.NotFound: // 404 + return StatusCode.Unimplemented; +#if !NETSTANDARD2_0 + case HttpStatusCode.TooManyRequests: // 429 +#else + case (HttpStatusCode)429: +#endif + case HttpStatusCode.BadGateway: // 502 + case HttpStatusCode.ServiceUnavailable: // 503 + case HttpStatusCode.GatewayTimeout: // 504 + return StatusCode.Unavailable; + default: + if ((int)httpStatusCode >= 100 && (int)httpStatusCode < 200) + { + // 1xx. These headers should have been ignored. + return StatusCode.Internal; + } + + return StatusCode.Unknown; + } + } } } diff --git a/src/Grpc.Net.Client/Internal/GrpcCall.cs b/src/Grpc.Net.Client/Internal/GrpcCall.cs index 3bf17e799..30e96b851 100644 --- a/src/Grpc.Net.Client/Internal/GrpcCall.cs +++ b/src/Grpc.Net.Client/Internal/GrpcCall.cs @@ -18,8 +18,10 @@ using System; using System.Buffers; +using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Globalization; using System.IO; using System.Net; using System.Net.Http; @@ -27,6 +29,7 @@ using System.Threading.Tasks; using Grpc.Core; using Grpc.Net.Client.Internal.Http; +using Grpc.Net.Client.Configuration; using Grpc.Shared; using Microsoft.Extensions.Logging; @@ -36,18 +39,19 @@ namespace Grpc.Net.Client.Internal { - internal sealed partial class GrpcCall : GrpcCall, IDisposable + internal sealed partial class GrpcCall : GrpcCall, IGrpcCall where TRequest : class where TResponse : class { - private const string ErrorStartingCallMessage = "Error starting gRPC call."; + internal const string ErrorStartingCallMessage = "Error starting gRPC call."; private readonly CancellationTokenSource _callCts; private readonly TaskCompletionSource _callTcs; private readonly DateTime _deadline; private readonly GrpcMethodInfo _grpcMethodInfo; + private readonly int _previousAttempts; - private Task? _httpResponseTask; + internal Task? _httpResponseTask; private Task? _responseHeadersTask; private Timer? _deadlineTimer; private CancellationTokenRegistration? _ctsRegistration; @@ -57,10 +61,12 @@ internal sealed partial class GrpcCall : GrpcCall, IDisposa // These are set depending on the type of gRPC call private TaskCompletionSource? _responseTcs; + + public int MessagesWritten { get; private set; } public HttpContentClientStreamWriter? ClientStreamWriter { get; private set; } public HttpContentClientStreamReader? ClientStreamReader { get; private set; } - public GrpcCall(Method method, GrpcMethodInfo grpcMethodInfo, CallOptions options, GrpcChannel channel) + public GrpcCall(Method method, GrpcMethodInfo grpcMethodInfo, CallOptions options, GrpcChannel channel, int previousAttempts) : base(options, channel) { // Validate deadline before creating any objects that require cleanup @@ -72,10 +78,13 @@ public GrpcCall(Method method, GrpcMethodInfo grpcMethodInf Method = method; _grpcMethodInfo = grpcMethodInfo; _deadline = options.Deadline ?? DateTime.MaxValue; + _previousAttempts = previousAttempts; Channel.RegisterActiveCall(this); } + public MethodConfig? MethodConfig => _grpcMethodInfo.MethodConfig; + private void ValidateDeadline(DateTime? deadline) { if (deadline != null && deadline != DateTime.MaxValue && deadline != DateTime.MinValue && deadline.Value.Kind != DateTimeKind.Utc) @@ -94,40 +103,75 @@ public CancellationToken CancellationToken public override Type RequestType => typeof(TRequest); public override Type ResponseType => typeof(TResponse); - public void StartUnary(TRequest request) + IClientStreamWriter? IGrpcCall.ClientStreamWriter => ClientStreamWriter; + IAsyncStreamReader? IGrpcCall.ClientStreamReader => ClientStreamReader; + + public void StartUnary(TRequest request) => StartUnaryCore(CreatePushUnaryContent(request)); + + public void StartClientStreaming() + { + var clientStreamWriter = new HttpContentClientStreamWriter(this); + var content = new PushStreamContent(clientStreamWriter); + + StartClientStreamingCore(clientStreamWriter, content); + } + + public void StartServerStreaming(TRequest request) => StartServerStreamingCore(CreatePushUnaryContent(request)); + + private HttpContent CreatePushUnaryContent(TRequest request) + { + return !Channel.IsWinHttp + ? new PushUnaryContent(request, WriteAsync) + : new WinHttpUnaryContent(request, WriteAsync, this); + + ValueTask WriteAsync(TRequest request, Stream stream) + { + return WriteMessageAsync(stream, request, Options); + } + } + + public void StartDuplexStreaming() + { + var clientStreamWriter = new HttpContentClientStreamWriter(this); + var content = new PushStreamContent(clientStreamWriter); + + StartDuplexStreamingCore(clientStreamWriter, content); + } + + internal void StartUnaryCore(HttpContent content) { _responseTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var timeout = GetTimeout(); var message = CreateHttpRequestMessage(timeout); - SetMessageContent(request, message); + SetMessageContent(content, message); _ = RunCall(message, timeout); } - public void StartClientStreaming() + internal void StartServerStreamingCore(HttpContent content) { - _responseTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var timeout = GetTimeout(); var message = CreateHttpRequestMessage(timeout); - CreateWriter(message); + SetMessageContent(content, message); + ClientStreamReader = new HttpContentClientStreamReader(this); _ = RunCall(message, timeout); } - public void StartServerStreaming(TRequest request) + internal void StartClientStreamingCore(HttpContentClientStreamWriter clientStreamWriter, HttpContent content) { + _responseTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var timeout = GetTimeout(); var message = CreateHttpRequestMessage(timeout); - SetMessageContent(request, message); - ClientStreamReader = new HttpContentClientStreamReader(this); + SetWriter(message, clientStreamWriter, content); _ = RunCall(message, timeout); } - public void StartDuplexStreaming() + public void StartDuplexStreamingCore(HttpContentClientStreamWriter clientStreamWriter, HttpContent content) { var timeout = GetTimeout(); var message = CreateHttpRequestMessage(timeout); - CreateWriter(message); + SetWriter(message, clientStreamWriter, content); ClientStreamReader = new HttpContentClientStreamReader(this); _ = RunCall(message, timeout); } @@ -138,14 +182,14 @@ public void Dispose() { Disposed = true; - Cleanup(new Status(StatusCode.Cancelled, "gRPC call disposed.")); + Cleanup(GrpcProtocolConstants.DisposeCanceledStatus); } } /// /// Clean up can be called by: /// 1. The user. AsyncUnaryCall.Dispose et al will call this on Dispose - /// 2. will call dispose if errors fail validation + /// 2. will call dispose if errors fail validation /// 3. will call dispose /// private void Cleanup(Status status) @@ -254,7 +298,15 @@ private async Task GetResponseHeadersCoreAsync() await CallTask.ConfigureAwait(false); } - return GrpcProtocolHelpers.BuildMetadata(httpResponse.Headers); + var metadata = GrpcProtocolHelpers.BuildMetadata(httpResponse.Headers); + + // https://github.com/grpc/proposal/blob/master/A6-client-retries.md#exposed-retry-metadata + if (_previousAttempts > 0) + { + metadata.Add(GrpcProtocolConstants.RetryPreviousAttemptsHeader, _previousAttempts.ToString(CultureInfo.InvariantCulture)); + } + + return metadata; } catch (Exception ex) when (ResolveException(ErrorStartingCallMessage, ex, out _, out var resolvedException)) { @@ -281,85 +333,6 @@ public Task GetResponseAsync() return _responseTcs.Task; } - private Status? ValidateHeaders(HttpResponseMessage httpResponse) - { - GrpcCallLog.ResponseHeadersReceived(Logger); - - // gRPC status can be returned in the header when there is no message (e.g. unimplemented status) - // An explicitly specified status header has priority over other failing statuses - if (GrpcProtocolHelpers.TryGetStatusCore(httpResponse.Headers, out var status)) - { - // Trailers are in the header because there is no message. - // Note that some default headers will end up in the trailers (e.g. Date, Server). - Trailers = GrpcProtocolHelpers.BuildMetadata(httpResponse.Headers); - return status; - } - - // ALPN negotiation is sending HTTP/1.1 and HTTP/2. - // Check that the response wasn't downgraded to HTTP/1.1. - if (httpResponse.Version < GrpcProtocolConstants.Http2Version) - { - return new Status(StatusCode.Internal, $"Bad gRPC response. Response protocol downgraded to HTTP/{httpResponse.Version.ToString(2)}."); - } - - if (httpResponse.StatusCode != HttpStatusCode.OK) - { - var statusCode = MapHttpStatusToGrpcCode(httpResponse.StatusCode); - return new Status(statusCode, "Bad gRPC response. HTTP status code: " + (int)httpResponse.StatusCode); - } - - if (httpResponse.Content?.Headers.ContentType == null) - { - return new Status(StatusCode.Cancelled, "Bad gRPC response. Response did not have a content-type header."); - } - - var grpcEncoding = httpResponse.Content.Headers.ContentType; - if (!CommonGrpcProtocolHelpers.IsContentType(GrpcProtocolConstants.GrpcContentType, grpcEncoding?.MediaType)) - { - return new Status(StatusCode.Cancelled, "Bad gRPC response. Invalid content-type value: " + grpcEncoding); - } - - // Call is still in progress - return null; - } - - private static StatusCode MapHttpStatusToGrpcCode(HttpStatusCode httpStatusCode) - { - switch (httpStatusCode) - { - case HttpStatusCode.BadRequest: // 400 -#if !NETSTANDARD2_0 - case HttpStatusCode.RequestHeaderFieldsTooLarge: // 431 -#else - case (HttpStatusCode)431: -#endif - return StatusCode.Internal; - case HttpStatusCode.Unauthorized: // 401 - return StatusCode.Unauthenticated; - case HttpStatusCode.Forbidden: // 403 - return StatusCode.PermissionDenied; - case HttpStatusCode.NotFound: // 404 - return StatusCode.Unimplemented; -#if !NETSTANDARD2_0 - case HttpStatusCode.TooManyRequests: // 429 -#else - case (HttpStatusCode)429: -#endif - case HttpStatusCode.BadGateway: // 502 - case HttpStatusCode.ServiceUnavailable: // 503 - case HttpStatusCode.GatewayTimeout: // 504 - return StatusCode.Unavailable; - default: - if ((int)httpStatusCode >= 100 && (int)httpStatusCode < 200) - { - // 1xx. These headers should have been ignored. - return StatusCode.Internal; - } - - return StatusCode.Unknown; - } - } - public Metadata GetTrailers() { using (StartScope()) @@ -375,32 +348,17 @@ public Metadata GetTrailers() } } - private void SetMessageContent(TRequest request, HttpRequestMessage message) + private void SetMessageContent(HttpContent content, HttpRequestMessage message) { RequestGrpcEncoding = GrpcProtocolHelpers.GetRequestEncoding(message.Headers); - - if (!Channel.IsWinHttp) - { - message.Content = new PushUnaryContent( - request, - this, - GrpcProtocolConstants.GrpcContentTypeHeaderValue); - } - else - { - // WinHttp doesn't support streaming request data so a length needs to be specified. - message.Content = new LengthUnaryContent( - request, - this, - GrpcProtocolConstants.GrpcContentTypeHeaderValue); - } + message.Content = content; } public void CancelCallFromCancellationToken() { using (StartScope()) { - CancelCall(new Status(StatusCode.Cancelled, "Call canceled by the client.")); + CancelCall(GrpcProtocolConstants.ClientCanceledStatus); } } @@ -478,10 +436,11 @@ private async Task RunCall(HttpRequestMessage request, TimeSpan? timeout) } catch (Exception ex) { - // Don't log OperationCanceledException if deadline has exceeded. + // Don't log OperationCanceledException if deadline has exceeded + // or the call has been canceled. if (ex is OperationCanceledException && _callTcs.Task.IsCompletedSuccessfully() && - _callTcs.Task.Result.StatusCode == StatusCode.DeadlineExceeded) + (_callTcs.Task.Result.StatusCode == StatusCode.DeadlineExceeded || _callTcs.Task.Result.StatusCode == StatusCode.Cancelled)) { throw; } @@ -492,7 +451,12 @@ private async Task RunCall(HttpRequestMessage request, TimeSpan? timeout) } } - status = ValidateHeaders(HttpResponse); + GrpcCallLog.ResponseHeadersReceived(Logger); + status = ValidateHeaders(HttpResponse, out var trailers); + if (trailers != null) + { + Trailers = trailers; + } // A status means either the call has failed or grpc-status was returned in the response header if (status != null) @@ -651,10 +615,7 @@ internal bool ResolveException(string summary, Exception ex, [NotNull] out Statu } else { - var exceptionMessage = CommonGrpcProtocolHelpers.ConvertToRpcExceptionMessage(ex); - var statusCode = GrpcProtocolHelpers.ResolveRpcExceptionStatusCode(ex); - - status = new Status(statusCode, summary + " " + exceptionMessage, ex); + status = GrpcProtocolHelpers.CreateStatusFromException(summary, ex); resolvedException = CreateRpcException(status.Value); return true; } @@ -816,12 +777,11 @@ private async Task ReadCredentials(HttpRequestMessage request) } } - private void CreateWriter(HttpRequestMessage message) + private void SetWriter(HttpRequestMessage message, HttpContentClientStreamWriter clientStreamWriter, HttpContent content) { RequestGrpcEncoding = GrpcProtocolHelpers.GetRequestEncoding(message.Headers); - ClientStreamWriter = new HttpContentClientStreamWriter(this); - - message.Content = new PushStreamContent(ClientStreamWriter, GrpcProtocolConstants.GrpcContentTypeHeaderValue); + ClientStreamWriter = clientStreamWriter; + message.Content = content; } private HttpRequestMessage CreateHttpRequestMessage(TimeSpan? timeout) @@ -842,6 +802,12 @@ private HttpRequestMessage CreateHttpRequestMessage(TimeSpan? timeout) headers.TryAddWithoutValidation(GrpcProtocolConstants.TEHeader, GrpcProtocolConstants.TEHeaderValue); headers.TryAddWithoutValidation(GrpcProtocolConstants.MessageAcceptEncodingHeader, Channel.MessageAcceptEncoding); + // https://github.com/grpc/proposal/blob/master/A6-client-retries.md#exposed-retry-metadata + if (_previousAttempts > 0) + { + headers.TryAddWithoutValidation(GrpcProtocolConstants.RetryPreviousAttemptsHeader, _previousAttempts.ToString(CultureInfo.InvariantCulture)); + } + if (Options.Headers != null && Options.Headers.Count > 0) { foreach (var entry in Options.Headers) @@ -929,38 +895,26 @@ private void DeadlineExceeded() internal ValueTask WriteMessageAsync( Stream stream, - TRequest message, - Action contextualSerializer, - CallOptions callOptions) - { - return stream.WriteMessageAsync( - this, - message, - contextualSerializer, - callOptions); - } - - internal ValueTask WriteMessageAsync( - Stream stream, - TRequest message, - Action contextualSerializer, - CallOptions callOptions) where TSerializationContext : SerializationContext, IMemoryOwner + ReadOnlyMemory message, + CancellationToken cancellationToken) { + MessagesWritten++; return stream.WriteMessageAsync( this, message, - contextualSerializer, - callOptions); + cancellationToken); } internal ValueTask WriteMessageAsync( Stream stream, - ReadOnlyMemory message, + TRequest message, CallOptions callOptions) { + MessagesWritten++; return stream.WriteMessageAsync( this, message, + Method.RequestMarshaller.ContextualSerializer, callOptions); } @@ -981,5 +935,10 @@ internal ValueTask WriteMessageAsync( singleMessage, cancellationToken); } + + public Task WriteClientStreamAsync(Func, Stream, CallOptions, TState, ValueTask> writeFunc, TState state) + { + return ClientStreamWriter!.WriteAsync(writeFunc, state); + } } } diff --git a/src/Grpc.Net.Client/Internal/GrpcMethodInfo.cs b/src/Grpc.Net.Client/Internal/GrpcMethodInfo.cs index d495c4a7b..edcb9381f 100644 --- a/src/Grpc.Net.Client/Internal/GrpcMethodInfo.cs +++ b/src/Grpc.Net.Client/Internal/GrpcMethodInfo.cs @@ -18,6 +18,7 @@ using System; using Grpc.Core; +using Grpc.Net.Client.Configuration; namespace Grpc.Net.Client.Internal { @@ -26,13 +27,15 @@ namespace Grpc.Net.Client.Internal /// internal class GrpcMethodInfo { - public GrpcMethodInfo(GrpcCallScope logScope, Uri callUri) + public GrpcMethodInfo(GrpcCallScope logScope, Uri callUri, MethodConfig? methodConfig) { LogScope = logScope; CallUri = callUri; + MethodConfig = methodConfig; } public GrpcCallScope LogScope { get; } public Uri CallUri { get; } + public MethodConfig? MethodConfig { get; } } } diff --git a/src/Grpc.Net.Client/Internal/GrpcProtocolConstants.cs b/src/Grpc.Net.Client/Internal/GrpcProtocolConstants.cs index f96934dea..e19cb9869 100644 --- a/src/Grpc.Net.Client/Internal/GrpcProtocolConstants.cs +++ b/src/Grpc.Net.Client/Internal/GrpcProtocolConstants.cs @@ -22,6 +22,7 @@ using System.Linq; using System.Net.Http.Headers; using System.Reflection; +using Grpc.Core; using Grpc.Net.Compression; namespace Grpc.Net.Client.Internal @@ -46,9 +47,11 @@ internal static class GrpcProtocolConstants internal const string IdentityGrpcEncoding = "identity"; internal const string MessageAcceptEncodingHeader = "grpc-accept-encoding"; - internal const string CompressionRequestAlgorithmHeader = "grpc-internal-encoding-request"; + internal const string RetryPushbackHeader = "grpc-retry-pushback-ms"; + internal const string RetryPreviousAttemptsHeader = "grpc-previous-rpc-attempts"; + internal static readonly Dictionary DefaultCompressionProviders = new Dictionary(StringComparer.Ordinal) { ["gzip"] = new GzipCompressionProvider(System.IO.Compression.CompressionLevel.Fastest), @@ -65,6 +68,11 @@ internal static class GrpcProtocolConstants internal static readonly string TEHeader; internal static readonly string TEHeaderValue; + internal static readonly Status DeadlineExceededStatus = new Status(StatusCode.DeadlineExceeded, string.Empty); + internal static readonly Status ThrottledStatus = new Status(StatusCode.Cancelled, "Retries stopped because retry throttling is active."); + internal static readonly Status ClientCanceledStatus = new Status(StatusCode.Cancelled, "Call canceled by the client."); + internal static readonly Status DisposeCanceledStatus = new Status(StatusCode.Cancelled, "gRPC call disposed."); + internal static string GetMessageAcceptEncoding(Dictionary compressionProviders) { return IdentityGrpcEncoding + "," + diff --git a/src/Grpc.Net.Client/Internal/GrpcProtocolHelpers.cs b/src/Grpc.Net.Client/Internal/GrpcProtocolHelpers.cs index 445327574..eca30280f 100644 --- a/src/Grpc.Net.Client/Internal/GrpcProtocolHelpers.cs +++ b/src/Grpc.Net.Client/Internal/GrpcProtocolHelpers.cs @@ -402,6 +402,9 @@ public static StatusCode ResolveRpcExceptionStatusCode(Exception ex) } else if (current is IOException) { + // TODO(JamesNK): IOException is also returned for aborted requests. + // Need to think about what is the best status for aborted requests. + // IOException happens if there is a protocol mismatch. return StatusCode.Unavailable; } @@ -409,5 +412,13 @@ public static StatusCode ResolveRpcExceptionStatusCode(Exception ex) return StatusCode.Internal; } + + public static Status CreateStatusFromException(string summary, Exception ex) + { + var exceptionMessage = CommonGrpcProtocolHelpers.ConvertToRpcExceptionMessage(ex); + var statusCode = ResolveRpcExceptionStatusCode(ex); + + return new Status(statusCode, summary + " " + exceptionMessage, ex); + } } } diff --git a/src/Grpc.Net.Client/Internal/Http/PushStreamContent.cs b/src/Grpc.Net.Client/Internal/Http/PushStreamContent.cs index e8ed68ad0..d2aee84e3 100644 --- a/src/Grpc.Net.Client/Internal/Http/PushStreamContent.cs +++ b/src/Grpc.Net.Client/Internal/Http/PushStreamContent.cs @@ -20,9 +20,12 @@ using System.IO; using System.Net; using System.Net.Http; -using System.Net.Http.Headers; using System.Threading.Tasks; +#if NETSTANDARD2_0 +using ValueTask = System.Threading.Tasks.Task; +#endif + namespace Grpc.Net.Client.Internal.Http { internal class PushStreamContent : HttpContent @@ -30,19 +33,32 @@ internal class PushStreamContent : HttpContent where TResponse : class { private readonly HttpContentClientStreamWriter _streamWriter; + private readonly Func? _startCallback; - public PushStreamContent(HttpContentClientStreamWriter streamWriter, MediaTypeHeaderValue mediaType) + public PushStreamContent(HttpContentClientStreamWriter streamWriter) { - Headers.ContentType = mediaType; + Headers.ContentType = GrpcProtocolConstants.GrpcContentTypeHeaderValue; _streamWriter = streamWriter; } + public PushStreamContent( + HttpContentClientStreamWriter streamWriter, + Func? startCallback) : this(streamWriter) + { + _startCallback = startCallback; + } + protected override async Task SerializeToStreamAsync(Stream stream, TransportContext? context) { // Immediately flush request stream to send headers // https://github.com/dotnet/corefx/issues/39586#issuecomment-516210081 await stream.FlushAsync().ConfigureAwait(false); + if (_startCallback != null) + { + await _startCallback(stream).ConfigureAwait(false); + } + // Pass request stream to writer _streamWriter.WriteStreamTcs.TrySetResult(stream); diff --git a/src/Grpc.Net.Client/Internal/Http/PushUnaryContent.cs b/src/Grpc.Net.Client/Internal/Http/PushUnaryContent.cs index acedd9ce8..01a9ac761 100644 --- a/src/Grpc.Net.Client/Internal/Http/PushUnaryContent.cs +++ b/src/Grpc.Net.Client/Internal/Http/PushUnaryContent.cs @@ -16,40 +16,37 @@ #endregion +using System; using System.IO; using System.Net; using System.Net.Http; -using System.Net.Http.Headers; using System.Threading.Tasks; #if NETSTANDARD2_0 using ValueTask = System.Threading.Tasks.Task; #endif -namespace Grpc.Net.Client.Internal.Http +namespace Grpc.Net.Client.Internal { + // TODO: Still need generic args? internal class PushUnaryContent : HttpContent where TRequest : class where TResponse : class { - private readonly TRequest _content; - private readonly GrpcCall _call; + private readonly TRequest _request; + private readonly Func _startCallback; - public PushUnaryContent(TRequest content, GrpcCall call, MediaTypeHeaderValue mediaType) + public PushUnaryContent(TRequest request, Func startCallback) { - _content = content; - _call = call; - Headers.ContentType = mediaType; + _request = request; + _startCallback = startCallback; + Headers.ContentType = GrpcProtocolConstants.GrpcContentTypeHeaderValue; } protected override Task SerializeToStreamAsync(Stream stream, TransportContext? context) { #pragma warning disable CA2012 // Use ValueTasks correctly - var writeMessageTask = _call.WriteMessageAsync( - stream, - _content, - _call.Method.RequestMarshaller.ContextualSerializer, - _call.Options); + var writeMessageTask = _startCallback(_request, stream); #pragma warning restore CA2012 // Use ValueTasks correctly if (writeMessageTask.IsCompletedSuccessfully()) { diff --git a/src/Grpc.Net.Client/Internal/Http/LengthUnaryContent.cs b/src/Grpc.Net.Client/Internal/Http/WinHttpUnaryContent.cs similarity index 60% rename from src/Grpc.Net.Client/Internal/Http/LengthUnaryContent.cs rename to src/Grpc.Net.Client/Internal/Http/WinHttpUnaryContent.cs index 924dced98..945b6c2aa 100644 --- a/src/Grpc.Net.Client/Internal/Http/LengthUnaryContent.cs +++ b/src/Grpc.Net.Client/Internal/Http/WinHttpUnaryContent.cs @@ -37,64 +37,66 @@ namespace Grpc.Net.Client.Internal.Http /// The payload is then written directly to the request using specialized context /// and serializer method. /// - internal class LengthUnaryContent : HttpContent + internal class WinHttpUnaryContent : HttpContent where TRequest : class where TResponse : class { - private readonly TRequest _content; + private readonly TRequest _request; + private readonly Func _startCallback; private readonly GrpcCall _call; - private byte[]? _payload; - public LengthUnaryContent(TRequest content, GrpcCall call, MediaTypeHeaderValue mediaType) + public WinHttpUnaryContent(TRequest request, Func startCallback, GrpcCall call) { - _content = content; + _request = request; + _startCallback = startCallback; _call = call; - Headers.ContentType = mediaType; + Headers.ContentType = GrpcProtocolConstants.GrpcContentTypeHeaderValue; } - // Serialize message. Need to know size to prefix the length in the header. - private byte[] SerializePayload() + protected override Task SerializeToStreamAsync(Stream stream, TransportContext? context) { - var serializationContext = _call.SerializationContext; - serializationContext.CallOptions = _call.Options; - serializationContext.Initialize(); - - try - { - _call.Method.RequestMarshaller.ContextualSerializer(_content, serializationContext); - - return serializationContext.GetWrittenPayload().ToArray(); - } - finally +#pragma warning disable CA2012 // Use ValueTasks correctly + var writeMessageTask = _startCallback(_request, stream); +#pragma warning restore CA2012 // Use ValueTasks correctly + if (writeMessageTask.IsCompletedSuccessfully()) { - serializationContext.Reset(); + GrpcEventSource.Log.MessageSent(); + return Task.CompletedTask; } + + return WriteMessageCore(writeMessageTask); } - protected override async Task SerializeToStreamAsync(Stream stream, TransportContext? context) + private static async Task WriteMessageCore(ValueTask writeMessageTask) { - if (_payload == null) - { - _payload = SerializePayload(); - } - - await _call.WriteMessageAsync( - stream, - _payload, - _call.Options).ConfigureAwait(false); - + await writeMessageTask.ConfigureAwait(false); GrpcEventSource.Log.MessageSent(); } protected override bool TryComputeLength(out long length) { - if (_payload == null) + // This will serialize the request message again. + // Consider caching serialized content if it is a problem. + length = GetPayloadLength(); + return true; + } + + private int GetPayloadLength() + { + var serializationContext = _call.SerializationContext; + serializationContext.CallOptions = _call.Options; + serializationContext.Initialize(); + + try { - _payload = SerializePayload(); - } + _call.Method.RequestMarshaller.ContextualSerializer(_request, serializationContext); - length = _payload.Length; - return true; + return serializationContext.GetWrittenPayload().Length; + } + finally + { + serializationContext.Reset(); + } } } } diff --git a/src/Grpc.Net.Client/Internal/HttpClientCallInvoker.cs b/src/Grpc.Net.Client/Internal/HttpClientCallInvoker.cs index 2cec54229..ece8c2b5b 100644 --- a/src/Grpc.Net.Client/Internal/HttpClientCallInvoker.cs +++ b/src/Grpc.Net.Client/Internal/HttpClientCallInvoker.cs @@ -19,6 +19,8 @@ using System; using System.Threading.Tasks; using Grpc.Core; +using Grpc.Net.Client.Internal; +using Grpc.Net.Client.Internal.Retry; namespace Grpc.Net.Client.Internal { @@ -40,7 +42,7 @@ public HttpClientCallInvoker(GrpcChannel channel) /// public override AsyncClientStreamingCall AsyncClientStreamingCall(Method method, string host, CallOptions options) { - var call = CreateGrpcCall(method, options); + var call = CreateRootGrpcCall(Channel, method, options); call.StartClientStreaming(); return new AsyncClientStreamingCall( @@ -60,7 +62,7 @@ public override AsyncClientStreamingCall AsyncClientStreami /// public override AsyncDuplexStreamingCall AsyncDuplexStreamingCall(Method method, string host, CallOptions options) { - var call = CreateGrpcCall(method, options); + var call = CreateRootGrpcCall(Channel, method, options); call.StartDuplexStreaming(); return new AsyncDuplexStreamingCall( @@ -79,7 +81,7 @@ public override AsyncDuplexStreamingCall AsyncDuplexStreami /// public override AsyncServerStreamingCall AsyncServerStreamingCall(Method method, string host, CallOptions options, TRequest request) { - var call = CreateGrpcCall(method, options); + var call = CreateRootGrpcCall(Channel, method, options); call.StartServerStreaming(request); return new AsyncServerStreamingCall( @@ -96,7 +98,7 @@ public override AsyncServerStreamingCall AsyncServerStreamingCall public override AsyncUnaryCall AsyncUnaryCall(Method method, string host, CallOptions options, TRequest request) { - var call = CreateGrpcCall(method, options); + var call = CreateRootGrpcCall(Channel, method, options); call.StartUnary(request); return new AsyncUnaryCall( @@ -117,19 +119,47 @@ public override TResponse BlockingUnaryCall(Method CreateGrpcCall( + private static IGrpcCall CreateRootGrpcCall( + GrpcChannel channel, Method method, CallOptions options) where TRequest : class where TResponse : class { - if (Channel.Disposed) + var methodInfo = channel.GetCachedGrpcMethodInfo(method); + var retryPolicy = methodInfo.MethodConfig?.RetryPolicy; + var hedgingPolicy = methodInfo.MethodConfig?.HedgingPolicy; + + if (retryPolicy != null) + { + return new RetryCall(retryPolicy, channel, method, options); + } + else if (hedgingPolicy != null) + { + return new HedgingCall(hedgingPolicy, channel, method, options); + } + else + { + // No retry/hedge policy configured. Fast path! + return CreateGrpcCall(channel, method, options, previousAttempts: 0); + } + } + + public static GrpcCall CreateGrpcCall( + GrpcChannel channel, + Method method, + CallOptions options, + int previousAttempts) + where TRequest : class + where TResponse : class + { + if (channel.Disposed) { throw new ObjectDisposedException(nameof(GrpcChannel)); } - var methodInfo = Channel.GetCachedGrpcMethodInfo(method); - var call = new GrpcCall(method, methodInfo, options, Channel); + var methodInfo = channel.GetCachedGrpcMethodInfo(method); + var call = new GrpcCall(method, methodInfo, options, channel, previousAttempts); return call; } @@ -139,10 +169,10 @@ private static class Callbacks where TRequest : class where TResponse : class { - internal static readonly Func> GetResponseHeadersAsync = state => ((GrpcCall)state).GetResponseHeadersAsync(); - internal static readonly Func GetStatus = state => ((GrpcCall)state).GetStatus(); - internal static readonly Func GetTrailers = state => ((GrpcCall)state).GetTrailers(); - internal static readonly Action Dispose = state => ((GrpcCall)state).Dispose(); + internal static readonly Func> GetResponseHeadersAsync = state => ((IGrpcCall)state).GetResponseHeadersAsync(); + internal static readonly Func GetStatus = state => ((IGrpcCall)state).GetStatus(); + internal static readonly Func GetTrailers = state => ((IGrpcCall)state).GetTrailers(); + internal static readonly Action Dispose = state => ((IGrpcCall)state).Dispose(); } } } diff --git a/src/Grpc.Net.Client/Internal/HttpContentClientStreamReader.cs b/src/Grpc.Net.Client/Internal/HttpContentClientStreamReader.cs index fa040ccf0..736ac08f8 100644 --- a/src/Grpc.Net.Client/Internal/HttpContentClientStreamReader.cs +++ b/src/Grpc.Net.Client/Internal/HttpContentClientStreamReader.cs @@ -31,7 +31,7 @@ internal class HttpContentClientStreamReader : IAsyncStream where TRequest : class where TResponse : class { - // Getting logger name from generic type is slow + // Getting logger name from generic type is slow. Cached copy. private const string LoggerName = "Grpc.Net.Client.Internal.HttpContentClientStreamReader"; private static readonly Task FinishedTask = Task.FromResult(false); diff --git a/src/Grpc.Net.Client/Internal/HttpContentClientStreamWriter.cs b/src/Grpc.Net.Client/Internal/HttpContentClientStreamWriter.cs index 7b26faf23..0e5f30ecf 100644 --- a/src/Grpc.Net.Client/Internal/HttpContentClientStreamWriter.cs +++ b/src/Grpc.Net.Client/Internal/HttpContentClientStreamWriter.cs @@ -23,6 +23,10 @@ using Grpc.Core; using Microsoft.Extensions.Logging; +#if NETSTANDARD2_0 +using ValueTask = System.Threading.Tasks.Task; +#endif + namespace Grpc.Net.Client.Internal { internal class HttpContentClientStreamWriter : IClientStreamWriter @@ -88,6 +92,16 @@ public Task WriteAsync(TRequest message) throw new ArgumentNullException(nameof(message)); } + return WriteAsync(WriteMessageToStream, message); + + static ValueTask WriteMessageToStream(GrpcCall call, Stream writeStream, CallOptions callOptions, TRequest message) + { + return call.WriteMessageAsync(writeStream, message, callOptions); + } + } + + public Task WriteAsync(Func, Stream, CallOptions, TState, ValueTask> writeFunc, TState state) + { _call.EnsureNotDisposed(); lock (_writeLock) @@ -122,7 +136,7 @@ public Task WriteAsync(TRequest message) } // Save write task to track whether it is complete. Must be set inside lock. - _writeTask = WriteAsyncCore(message); + _writeTask = WriteAsyncCore(writeFunc, state); } } @@ -142,7 +156,7 @@ public void Dispose() public GrpcCall Call => _call; - private async Task WriteAsyncCore(TRequest message) + public async Task WriteAsyncCore(Func, Stream, CallOptions, TState, ValueTask> writeFunc, TState state) { try { @@ -157,11 +171,7 @@ private async Task WriteAsyncCore(TRequest message) callOptions = callOptions.WithWriteOptions(WriteOptions); } - await _call.WriteMessageAsync( - writeStream, - message, - _call.Method.RequestMarshaller.ContextualSerializer, - callOptions).ConfigureAwait(false); + await writeFunc(_call, writeStream, callOptions, state).ConfigureAwait(false); // Flush stream to ensure messages are sent immediately await writeStream.FlushAsync(callOptions.CancellationToken).ConfigureAwait(false); diff --git a/src/Grpc.Net.Client/Internal/IGrpcCall.cs b/src/Grpc.Net.Client/Internal/IGrpcCall.cs new file mode 100644 index 000000000..d796dd034 --- /dev/null +++ b/src/Grpc.Net.Client/Internal/IGrpcCall.cs @@ -0,0 +1,49 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.IO; +using System.Threading.Tasks; +using Grpc.Core; + +#if NETSTANDARD2_0 +using ValueTask = System.Threading.Tasks.Task; +#endif + +namespace Grpc.Net.Client.Internal +{ + internal interface IGrpcCall : IDisposable + where TRequest : class + where TResponse : class + { + Task GetResponseAsync(); + Task GetResponseHeadersAsync(); + Status GetStatus(); + Metadata GetTrailers(); + + IClientStreamWriter? ClientStreamWriter { get; } + IAsyncStreamReader? ClientStreamReader { get; } + + void StartUnary(TRequest request); + void StartClientStreaming(); + void StartServerStreaming(TRequest request); + void StartDuplexStreaming(); + + Task WriteClientStreamAsync(Func, Stream, CallOptions, TState, ValueTask> writeFunc, TState state); + } +} diff --git a/src/Grpc.Net.Client/Internal/Retry/ChannelRetryThrottling.cs b/src/Grpc.Net.Client/Internal/Retry/ChannelRetryThrottling.cs new file mode 100644 index 000000000..40b820e10 --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Retry/ChannelRetryThrottling.cs @@ -0,0 +1,78 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Threading; +using Grpc.Net.Client.Configuration; + +namespace Grpc.Net.Client.Internal.Retry +{ + internal class ChannelRetryThrottling + { + private readonly object _lock = new object(); + private readonly double _tokenRatio; + private readonly int _maxTokens; + private double _tokenCount; + private double _tokenThreshold; + + public ChannelRetryThrottling(RetryThrottlingPolicy retryThrottling) + { + if (retryThrottling.MaxTokens == null) + { + throw CreateException(RetryThrottlingPolicy.MaxTokensPropertyName); + } + if (retryThrottling.TokenRatio == null) + { + throw CreateException(RetryThrottlingPolicy.TokenRatioPropertyName); + } + + // Truncate token ratio to 3 decimal places + _tokenRatio = Math.Truncate(retryThrottling.TokenRatio.GetValueOrDefault() * 1000) / 1000; + + _maxTokens = retryThrottling.MaxTokens.GetValueOrDefault(); + _tokenCount = retryThrottling.MaxTokens.GetValueOrDefault(); + _tokenThreshold = _tokenCount / 2; + + static InvalidOperationException CreateException(string propertyName) + { + return new InvalidOperationException($"Retry throttling missing required property '{propertyName}'."); + } + } + + public bool IsRetryThrottlingActive() + { + return Volatile.Read(ref _tokenCount) <= _tokenThreshold; + } + + public void CallSuccess() + { + lock (_lock) + { + _tokenCount = Math.Min(_tokenCount + _tokenRatio, _maxTokens); + } + } + + public void CallFailure() + { + lock (_lock) + { + _tokenCount = Math.Max(_tokenCount - 1, 0); + } + } + } +} diff --git a/src/Grpc.Net.Client/Internal/Retry/CommitReason.cs b/src/Grpc.Net.Client/Internal/Retry/CommitReason.cs new file mode 100644 index 000000000..166e7032d --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Retry/CommitReason.cs @@ -0,0 +1,34 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + + +namespace Grpc.Net.Client.Internal.Retry +{ + internal enum CommitReason + { + ResponseHeadersReceived, + FatalStatusCode, + ExceededAttemptCount, + DeadlineExceeded, + Throttled, + BufferExceeded, + PushbackStop, + UnexpectedError, + Canceled + } +} diff --git a/src/Grpc.Net.Client/Internal/Retry/DeadlineGrpcCall.cs b/src/Grpc.Net.Client/Internal/Retry/DeadlineGrpcCall.cs new file mode 100644 index 000000000..e2728cec4 --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Retry/DeadlineGrpcCall.cs @@ -0,0 +1,135 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Grpc.Core; + +#if NETSTANDARD2_0 +using ValueTask = System.Threading.Tasks.Task; +#endif + +namespace Grpc.Net.Client.Internal.Retry +{ + internal sealed class StatusGrpcCall : IGrpcCall + where TRequest : class + where TResponse : class + { + private readonly Status _status; + private IClientStreamWriter? _clientStreamWriter; + private IAsyncStreamReader? _clientStreamReader; + + public IClientStreamWriter? ClientStreamWriter => _clientStreamWriter ??= new StatusClientStreamWriter(_status); + public IAsyncStreamReader? ClientStreamReader => _clientStreamReader ??= new StatusStreamReader(_status); + + public StatusGrpcCall(Status status) + { + _status = status; + } + + public void Dispose() + { + } + + public Task GetResponseAsync() + { + return Task.FromException(new RpcException(_status)); + } + + public Task GetResponseHeadersAsync() + { + return Task.FromException(new RpcException(_status)); + } + + public Status GetStatus() + { + return _status; + } + + public Metadata GetTrailers() + { + throw new InvalidOperationException("Can't get the call trailers because the call has not completed successfully."); + } + + public void StartClientStreaming() + { + throw new NotSupportedException(); + } + + public void StartDuplexStreaming() + { + throw new NotSupportedException(); + } + + public void StartServerStreaming(TRequest request) + { + throw new NotSupportedException(); + } + + public void StartUnary(TRequest request) + { + throw new NotSupportedException(); + } + + public Task WriteClientStreamAsync(Func, Stream, CallOptions, TState, ValueTask> writeFunc, TState state) + { + return Task.FromException(new RpcException(_status)); + } + + private sealed class StatusClientStreamWriter : IClientStreamWriter + { + private readonly Status _status; + + public WriteOptions? WriteOptions { get; set; } + + public StatusClientStreamWriter(Status status) + { + _status = status; + } + + public Task CompleteAsync() + { + return Task.FromException(new RpcException(_status)); + } + + public Task WriteAsync(TRequest message) + { + return Task.FromException(new RpcException(_status)); + } + } + + private sealed class StatusStreamReader : IAsyncStreamReader + { + private readonly Status _status; + + public TResponse Current { get; set; } = default!; + + public StatusStreamReader(Status status) + { + _status = status; + } + + public Task MoveNext(CancellationToken cancellationToken) + { + return Task.FromException(new RpcException(_status)); + } + } + } +} diff --git a/src/Grpc.Net.Client/Internal/Retry/HedgingCall.cs b/src/Grpc.Net.Client/Internal/Retry/HedgingCall.cs new file mode 100644 index 000000000..d4bedbb94 --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Retry/HedgingCall.cs @@ -0,0 +1,410 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Grpc.Core; +using Grpc.Net.Client.Configuration; +using Microsoft.Extensions.Logging; + +namespace Grpc.Net.Client.Internal.Retry +{ + internal sealed partial class HedgingCall : RetryCallBase + where TRequest : class + where TResponse : class + { + // Getting logger name from generic type is slow. Cached copy. + private const string LoggerName = "Grpc.Net.Client.Internal.HedgingCall"; + + private readonly HedgingPolicy _hedgingPolicy; + + private int _callsAttempted; + + private TaskCompletionSource? _pushbackReceivedTcs; + private TimeSpan? _pushbackDelay; + + // Internal for testing + internal List> _activeCalls { get; } + internal Task? CreateHedgingCallsTask { get; set; } + + public HedgingCall(HedgingPolicy hedgingPolicy, GrpcChannel channel, Method method, CallOptions options) + : base(channel, method, options, LoggerName, hedgingPolicy.MaxAttempts.GetValueOrDefault()) + { + _hedgingPolicy = hedgingPolicy; + _activeCalls = new List>(); + + if (_hedgingPolicy.HedgingDelay > TimeSpan.Zero) + { + _pushbackReceivedTcs = new TaskCompletionSource(TaskCreationOptions.None); + } + + ValidatePolicy(hedgingPolicy); + } + + private void ValidatePolicy(HedgingPolicy hedgingPolicy) + { + //if (retryThrottlingPolicy.MaxAttempts == null) + //{ + // throw CreateException(_method, RetryPolicy.MaxAttemptsPropertyName); + //} + //if (retryThrottlingPolicy.InitialBackoff == null) + //{ + // throw CreateException(_method, RetryPolicy.InitialBackoffPropertyName); + //} + //if (retryThrottlingPolicy.MaxBackoff == null) + //{ + // throw CreateException(_method, RetryPolicy.MaxBackoffPropertyName); + //} + //if (retryThrottlingPolicy.BackoffMultiplier == null) + //{ + // throw CreateException(_method, RetryPolicy.BackoffMultiplierPropertyName); + //} + //if (retryThrottlingPolicy.RetryableStatusCodes.Count == 0) + //{ + // throw new InvalidOperationException($"Retry policy for '{_method.FullName}' must have property '{RetryPolicy.RetryableStatusCodesPropertyName}' and must be non-empty."); + //} + + //static InvalidOperationException CreateException(IMethod method, string propertyName) + //{ + // return new InvalidOperationException($"Retry policy for '{method.FullName}' is missing required property '{propertyName}'."); + //} + } + + private async Task StartCall(Action> startCallFunc) + { + GrpcCall call; + lock (Lock) + { + if (CommitedCallTask.IsCompletedSuccessfully()) + { + // Call has already been commited. This could happen if written messages exceed + // buffer limits, which causes the call to immediately become commited and to clear buffers. + return; + } + + call = HttpClientCallInvoker.CreateGrpcCall(Channel, Method, Options, _callsAttempted); + _activeCalls.Add(call); + _callsAttempted++; + + startCallFunc(call); + + SetNewActiveCall(call); + } + + Status? responseStatus; + + try + { + call.CancellationToken.ThrowIfCancellationRequested(); + + CompatibilityExtensions.Assert(call._httpResponseTask != null, "Request should have be made if call is not preemptively cancelled."); + var httpResponse = await call._httpResponseTask.ConfigureAwait(false); + + responseStatus = GrpcCall.ValidateHeaders(httpResponse, out _); + } + catch (Exception ex) + { + call.ResolveException(GrpcCall.ErrorStartingCallMessage, ex, out responseStatus, out _); + } + + // Check to see the response returned from the server makes the call commited + // Null status code indicates the headers were valid and a "Response-Headers" response + // was received from the server. + // https://github.com/grpc/proposal/blob/master/A6-client-retries.md#when-retries-are-valid + if (responseStatus == null) + { + // Headers were returned. We're commited. + CommitCall(call, CommitReason.ResponseHeadersReceived); + + // Wait until the call has finished and then check its status code + // to update retry throttling tokens. + var status = await call.CallTask.ConfigureAwait(false); + if (status.StatusCode == StatusCode.OK) + { + // Success. Exit retry loop. + Channel.RetryThrottling?.CallSuccess(); + } + } + else + { + var status = responseStatus.Value; + + var retryPushbackMS = GetRetryPushback(call); + + if (retryPushbackMS < 0) + { + Channel.RetryThrottling?.CallFailure(); + } + else if (_hedgingPolicy.NonFatalStatusCodes.Contains(status.StatusCode)) + { + // Pushback doesn't do anything if we started with no delay and all calls + // have already been made when hedging starting. + if (retryPushbackMS >= 0 && _pushbackReceivedTcs != null) + { + lock (Lock) + { + _pushbackDelay = TimeSpan.FromMilliseconds(retryPushbackMS.GetValueOrDefault()); + _pushbackReceivedTcs.TrySetResult(null); + } + } + Channel.RetryThrottling?.CallFailure(); + } + else + { + CommitCall(call, CommitReason.FatalStatusCode); + } + } + + lock (Lock) + { + if (IsDeadlineExceeded()) + { + // Deadline has been exceeded so immediately commit call. + CommitCall(call, CommitReason.DeadlineExceeded); + } + else if (_activeCalls.Count == 1 && _callsAttempted >= MaxRetryAttempts) + { + // This is the last active call and no more will be made. + CommitCall(call, CommitReason.ExceededAttemptCount); + } + else if (_activeCalls.Count == 1 && (Channel.RetryThrottling?.IsRetryThrottlingActive() ?? false)) + { + // This is the last active call and throttling is active. + CommitCall(call, CommitReason.Throttled); + } + else + { + // Call isn't used and can be cancelled. + // Note that the call could have already been removed and disposed if the + // hedging call has been finalized or disposed. + if (_activeCalls.Remove(call)) + { + call.Dispose(); + } + } + } + } + + protected override void OnCommitCall(IGrpcCall call) + { + _activeCalls.Remove(call); + + CleanUpUnsynchronized(); + } + + private void CleanUpUnsynchronized() + { + while (_activeCalls.Count > 0) + { + _activeCalls[_activeCalls.Count - 1].Dispose(); + _activeCalls.RemoveAt(_activeCalls.Count - 1); + } + } + + protected override void StartCore(Action> startCallFunc) + { + var hedgingDelay = _hedgingPolicy.HedgingDelay.GetValueOrDefault(); + if (hedgingDelay == TimeSpan.Zero) + { + // If there is no delay then start all call immediately + while (_callsAttempted < MaxRetryAttempts) + { + _ = StartCall(startCallFunc); + + // Don't send additional calls if retry throttling is active. + if (Channel.RetryThrottling?.IsRetryThrottlingActive() ?? false) + { + Log.AdditionalCallsBlockedByRetryThrottling(Logger); + break; + } + + lock (Lock) + { + // Don't send additional calls if call has been commited. + if (CommitedCallTask.IsCompletedSuccessfully()) + { + break; + } + } + } + } + else + { + CreateHedgingCallsTask = CreateHedgingCalls(startCallFunc); + } + } + + private async Task CreateHedgingCalls(Action> startCallFunc) + { + Log.StartingRetryWorker(Logger); + + try + { + var hedgingDelay = _hedgingPolicy.HedgingDelay.GetValueOrDefault(); + + while (_callsAttempted < MaxRetryAttempts) + { + _ = StartCall(startCallFunc); + + await HedgingDelayAsync(hedgingDelay).ConfigureAwait(false); + + if (IsDeadlineExceeded()) + { + CommitCall(new StatusGrpcCall(new Status(StatusCode.DeadlineExceeded, string.Empty)), CommitReason.DeadlineExceeded); + break; + } + else + { + lock (Lock) + { + if (Channel.RetryThrottling?.IsRetryThrottlingActive() ?? false) + { + if (_activeCalls.Count == 0) + { + CommitCall(CreateStatusCall(GrpcProtocolConstants.ThrottledStatus), CommitReason.Throttled); + } + else + { + Log.AdditionalCallsBlockedByRetryThrottling(Logger); + } + break; + } + + // Don't send additional calls if call has been commited. + if (CommitedCallTask.IsCompletedSuccessfully()) + { + break; + } + } + } + } + } + catch (Exception ex) + { + HandleUnexpectedError(ex); + } + finally + { + Log.StoppingRetryWorker(Logger); + } + } + + private async Task HedgingDelayAsync(TimeSpan hedgingDelay) + { + while (true) + { + var tcs = _pushbackReceivedTcs; + if (tcs != null) + { + var completedTask = await Task.WhenAny(Task.Delay(hedgingDelay, CancellationTokenSource.Token), tcs.Task).ConfigureAwait(false); + if (completedTask != tcs.Task) + { + // Task.Delay won. Check CTS to see if it won because of cancellation. + CancellationTokenSource.Token.ThrowIfCancellationRequested(); + return; + } + + lock (Lock) + { + Debug.Assert(_pushbackDelay != null); + + // Use pushback value and delay again + hedgingDelay = _pushbackDelay.GetValueOrDefault(); + + _pushbackDelay = null; + _pushbackReceivedTcs = new TaskCompletionSource(TaskCreationOptions.None); + } + } + else + { + await Task.Delay(hedgingDelay).ConfigureAwait(false); + return; + } + } + } + + protected override void Dispose(bool disposing) + { + lock (Lock) + { + base.Dispose(disposing); + + CleanUpUnsynchronized(); + } + } + + public override Task ClientStreamCompleteAsync() + { + ClientStreamComplete = true; + + return DoClientStreamActionAsync(calls => + { + var completeTasks = new Task[calls.Count]; + for (var i = 0; i < calls.Count; i++) + { + completeTasks[i] = calls[i].ClientStreamWriter!.CompleteAsync(); + } + + //Console.WriteLine($"Completing"); + return Task.WhenAll(completeTasks); + }); + } + + public override async Task ClientStreamWriteAsync(TRequest message) + { + // TODO(JamesNK) - Not safe for multi-threading + await DoClientStreamActionAsync(calls => + { + var writeTasks = new Task[calls.Count]; + for (var i = 0; i < calls.Count; i++) + { + writeTasks[i] = calls[i].WriteClientStreamAsync(WriteNewMessage, message); + } + + //Console.WriteLine($"Writing client stream message to {writeTasks.Length} calls."); + return Task.WhenAll(writeTasks); + }).ConfigureAwait(false); + BufferedCurrentMessage = false; + } + + private Task DoClientStreamActionAsync(Func>, Task> action) + { + lock (Lock) + { + if (_activeCalls.Count > 0) + { + return action(_activeCalls); + } + else + { + return WaitForCallUnsynchronizedAsync(action); + } + } + + async Task WaitForCallUnsynchronizedAsync(Func>, Task> action) + { + var call = await GetActiveCallUnsynchronizedAsync(previousCall: null).ConfigureAwait(false); + await action(new[] { call! }).ConfigureAwait(false); + } + } + } +} diff --git a/src/Grpc.Net.Client/Internal/Retry/RetryCall.cs b/src/Grpc.Net.Client/Internal/Retry/RetryCall.cs new file mode 100644 index 000000000..376138618 --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Retry/RetryCall.cs @@ -0,0 +1,357 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Diagnostics; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Grpc.Core; +using Grpc.Net.Client.Configuration; +using Grpc.Shared; + +namespace Grpc.Net.Client.Internal.Retry +{ + internal sealed class RetryCall : RetryCallBase + where TRequest : class + where TResponse : class + { + // Getting logger name from generic type is slow. Cached copy. + private const string LoggerName = "Grpc.Net.Client.Internal.RetryCall"; + + private readonly RetryPolicy _retryPolicy; + + private readonly Random _random; + + private int _attemptCount; + private int _nextRetryDelayMilliseconds; + + private GrpcCall? _activeCall; + + public RetryCall(RetryPolicy retryPolicy, GrpcChannel channel, Method method, CallOptions options) + : base(channel, method, options, LoggerName, retryPolicy.MaxAttempts.GetValueOrDefault()) + { + _retryPolicy = retryPolicy; + + _random = new Random(); + + ValidatePolicy(retryPolicy); + + _nextRetryDelayMilliseconds = Convert.ToInt32(retryPolicy.InitialBackoff.GetValueOrDefault().TotalMilliseconds); + } + + private void ValidatePolicy(RetryPolicy retryPolicy) + { + if (retryPolicy.MaxAttempts == null) + { + throw CreateException(Method, RetryPolicy.MaxAttemptsPropertyName); + } + if (retryPolicy.InitialBackoff == null) + { + throw CreateException(Method, RetryPolicy.InitialBackoffPropertyName); + } + if (retryPolicy.MaxBackoff == null) + { + throw CreateException(Method, RetryPolicy.MaxBackoffPropertyName); + } + if (retryPolicy.BackoffMultiplier == null) + { + throw CreateException(Method, RetryPolicy.BackoffMultiplierPropertyName); + } + if (retryPolicy.RetryableStatusCodes.Count == 0) + { + throw new InvalidOperationException($"Retry policy for '{Method.FullName}' must have property '{RetryPolicy.RetryableStatusCodesPropertyName}' and must be non-empty."); + } + + static InvalidOperationException CreateException(IMethod method, string propertyName) + { + return new InvalidOperationException($"Retry policy for '{method.FullName}' is missing required property '{propertyName}'."); + } + } + + private int CalculateNextRetryDelay() + { + var nextMilliseconds = _nextRetryDelayMilliseconds * _retryPolicy.BackoffMultiplier.GetValueOrDefault(); + nextMilliseconds = Math.Min(nextMilliseconds, _retryPolicy.MaxBackoff.GetValueOrDefault().TotalMilliseconds); + + return Convert.ToInt32(nextMilliseconds); + } + + private CommitReason? EvaluateRetry(Status status, int? retryPushbackMilliseconds) + { + if (IsDeadlineExceeded()) + { + return CommitReason.DeadlineExceeded; + } + + if (Channel.RetryThrottling?.IsRetryThrottlingActive() ?? false) + { + return CommitReason.Throttled; + } + + if (_attemptCount >= MaxRetryAttempts) + { + return CommitReason.ExceededAttemptCount; + } + + if (retryPushbackMilliseconds != null) + { + if (retryPushbackMilliseconds >= 0) + { + return null; + } + else + { + return CommitReason.PushbackStop; + } + } + + if (!_retryPolicy.RetryableStatusCodes.Contains(status.StatusCode)) + { + return CommitReason.FatalStatusCode; + } + + return null; + } + + private async Task StartRetry(Action> startCallFunc) + { + Log.StartingRetryWorker(Logger); + + try + { + // This is the main retry loop. It will: + // 1. Check the result of the active call was successful. + // 2. If it was unsuccessful then evaluate if the call can be retried. + // 3. If it can be retried then start a new active call and begin again. + while (true) + { + GrpcCall currentCall; + lock (Lock) + { + // Start new call. + currentCall = _activeCall = HttpClientCallInvoker.CreateGrpcCall(Channel, Method, Options, _attemptCount); + startCallFunc(currentCall); + + if (CommitedCallTask.IsCompletedSuccessfully()) + { + // Call has already been commited. This could happen if written messages exceed + // buffer limits, which causes the call to immediately become commited and to clear buffers. + return; + } + + _attemptCount++; + + SetNewActiveCall(currentCall); + } + + Status? responseStatus; + + try + { + currentCall.CancellationToken.ThrowIfCancellationRequested(); + + CompatibilityExtensions.Assert(currentCall._httpResponseTask != null, "Request should have be made if call is not preemptively cancelled."); + var httpResponse = await currentCall._httpResponseTask.ConfigureAwait(false); + + responseStatus = GrpcCall.ValidateHeaders(httpResponse, out _); + } + catch (Exception ex) + { + currentCall.ResolveException(GrpcCall.ErrorStartingCallMessage, ex, out responseStatus, out _); + } + + // Check to see the response returned from the server makes the call commited + // Null status code indicates the headers were valid and a "Response-Headers" response + // was received from the server. + // https://github.com/grpc/proposal/blob/master/A6-client-retries.md#when-retries-are-valid + if (responseStatus == null) + { + // Headers were returned. We're commited. + CommitCall(currentCall, CommitReason.ResponseHeadersReceived); + + responseStatus = await currentCall.CallTask.ConfigureAwait(false); + if (responseStatus.GetValueOrDefault().StatusCode == StatusCode.OK) + { + // Success. Exit retry loop. + Channel.RetryThrottling?.CallSuccess(); + } + return; + } + + if (CommitedCallTask.IsCompletedSuccessfully()) + { + // Call has already been commited. This could happen if written messages exceed + // buffer limits, which causes the call to immediately become commited and to clear buffers. + return; + } + + Status status = responseStatus.Value; + + var retryPushbackMS = GetRetryPushback(currentCall); + + // Failures only could towards retry throttling if they have a known, retriable status. + // This stops non-transient statuses, e.g. INVALID_ARGUMENT, from triggering throttling. + if (_retryPolicy.RetryableStatusCodes.Contains(status.StatusCode) || + retryPushbackMS < 0) + { + Channel.RetryThrottling?.CallFailure(); + } + + var result = EvaluateRetry(status, retryPushbackMS); + Log.RetryEvaluated(Logger, status.StatusCode, _attemptCount, result == null); + + if (result == null) + { + TimeSpan delayDuration; + if (retryPushbackMS != null) + { + delayDuration = TimeSpan.FromMilliseconds(retryPushbackMS.GetValueOrDefault()); + _nextRetryDelayMilliseconds = retryPushbackMS.GetValueOrDefault(); + } + else + { + delayDuration = TimeSpan.FromMilliseconds(_random.Next(0, Convert.ToInt32(_nextRetryDelayMilliseconds))); + } + + Log.StartingRetryDelay(Logger, delayDuration); + await Task.Delay(delayDuration, CancellationTokenSource.Token).ConfigureAwait(false); + + _nextRetryDelayMilliseconds = CalculateNextRetryDelay(); + + // Check if dispose was called on call. + CancellationTokenSource.Token.ThrowIfCancellationRequested(); + + // Clean up the failed call. + currentCall.Dispose(); + } + else + { + // Handle the situation where the call failed with a non-deadline status, but retry + // didn't happen because of deadline exceeded. + IGrpcCall resolvedCall = (IsDeadlineExceeded() && !(currentCall.CallTask.IsCompletedSuccessfully() && currentCall.CallTask.Result.StatusCode == StatusCode.DeadlineExceeded)) + ? CreateStatusCall(GrpcProtocolConstants.DeadlineExceededStatus) + : currentCall; + + // Can't retry. + // Signal public API exceptions that they should finish throwing and then exit the retry loop. + CommitCall(resolvedCall, result.GetValueOrDefault()); + return; + } + } + } + catch (Exception ex) + { + HandleUnexpectedError(ex); + } + finally + { + Log.StoppingRetryWorker(Logger); + } + } + + protected override void OnCommitCall(IGrpcCall call) + { + _activeCall = null; + } + + protected override void Dispose(bool disposing) + { + lock (Lock) + { + base.Dispose(disposing); + + _activeCall?.Dispose(); + } + } + + protected override void StartCore(Action> startCallFunc) + { + _ = StartRetry(startCallFunc); + } + + public override Task ClientStreamCompleteAsync() + { + ClientStreamComplete = true; + + return DoClientStreamActionAsync(async call => + { + await call.ClientStreamWriter!.CompleteAsync().ConfigureAwait(false); + }); + } + + public override Task ClientStreamWriteAsync(TRequest message) + { + return DoClientStreamActionAsync(async call => + { + CompatibilityExtensions.Assert(call.ClientStreamWriter != null); + + if (ClientStreamWriteOptions != null) + { + call.ClientStreamWriter.WriteOptions = ClientStreamWriteOptions; + } + + await call.WriteClientStreamAsync(WriteNewMessage, message).ConfigureAwait(false); + BufferedCurrentMessage = false; + + if (ClientStreamComplete) + { + await call.ClientStreamWriter.CompleteAsync().ConfigureAwait(false); + } + }); + } + + private async Task DoClientStreamActionAsync(Func, Task> action) + { + var call = await GetActiveCallAsync(previousCall: null).ConfigureAwait(false); + while (true) + { + try + { + await action(call!).ConfigureAwait(false); + return; + } + catch + { + call = await GetActiveCallAsync(previousCall: call).ConfigureAwait(false); + if (call == null) + { + throw; + } + } + } + } + + private Task?> GetActiveCallAsync(IGrpcCall? previousCall) + { + Debug.Assert(NewActiveCallTcs != null); + + lock (Lock) + { + // Return currently active call if there is one, and its not the previous call. + if (_activeCall != null && previousCall != _activeCall) + { + return Task.FromResult?>(_activeCall); + } + + // Wait to see whether new call will be made + return GetActiveCallUnsynchronizedAsync(previousCall); + } + } + } +} diff --git a/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.Log.cs b/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.Log.cs new file mode 100644 index 000000000..19aa1a459 --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.Log.cs @@ -0,0 +1,120 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using Grpc.Core; +using Microsoft.Extensions.Logging; + +namespace Grpc.Net.Client.Internal.Retry +{ + internal partial class RetryCallBase : IGrpcCall + where TRequest : class + where TResponse : class + { + protected static class Log + { + private static readonly Action _retryEvaluated = + LoggerMessage.Define(LogLevel.Debug, new EventId(1, "RetryEvaluated"), "Evaluated retry for failed gRPC call. Status code: '{StatusCode}', Attempt: {AttemptCount}, Retry: {WillRetry}"); + + private static readonly Action _retryPushbackReceived = + LoggerMessage.Define(LogLevel.Debug, new EventId(2, "RetryPushbackReceived"), "Retry pushback of '{RetryPushback}' received from the failed gRPC call."); + + private static readonly Action _startingRetryDelay = + LoggerMessage.Define(LogLevel.Trace, new EventId(3, "StartingRetryDelay"), "Starting retry delay of {DelayDuration}."); + + private static readonly Action _errorRetryingCall = + LoggerMessage.Define(LogLevel.Error, new EventId(4, "ErrorRetryingCall"), "Error retrying gRPC call."); + + private static readonly Action _sendingBufferedMessages = + LoggerMessage.Define(LogLevel.Trace, new EventId(5, "SendingBufferedMessages"), "Sending {MessageCount} buffered messages from previous failed gRPC calls."); + + private static readonly Action _messageAddedToBuffer = + LoggerMessage.Define(LogLevel.Trace, new EventId(6, "MessageAddedToBuffer"), "Message with {MessageSize} bytes added to the buffer. There are {CallBufferSize} bytes buffered for this call."); + + private static readonly Action _callCommited = + LoggerMessage.Define(LogLevel.Debug, new EventId(7, "CallCommited"), "Call commited. Reason: {CommitReason}"); + + private static readonly Action _startingRetryWorker = + LoggerMessage.Define(LogLevel.Trace, new EventId(8, "StartingRetryWorker"), "Starting retry worker."); + + private static readonly Action _stoppingRetryWorker = + LoggerMessage.Define(LogLevel.Trace, new EventId(9, "StoppingRetryWorker"), "Stopping retry worker."); + + private static readonly Action _maxAttemptsLimited = + LoggerMessage.Define(LogLevel.Debug, new EventId(10, "MaxAttemptsLimited"), "The method has {ServiceConfigMaxAttempts} attempts specified in the service config. The number of attempts has been limited by channel configuration to {ChannelMaxAttempts}."); + + private static readonly Action _additionalCallsBlockedByRetryThrottling = + LoggerMessage.Define(LogLevel.Debug, new EventId(11, "AdditionalCallsBlockedByRetryThrottling"), "Additional calls blocked by retry throttling."); + + internal static void RetryEvaluated(ILogger logger, StatusCode statusCode, int attemptCount, bool willRetry) + { + _retryEvaluated(logger, statusCode, attemptCount, willRetry, null); + } + + internal static void RetryPushbackReceived(ILogger logger, string retryPushback) + { + _retryPushbackReceived(logger, retryPushback, null); + } + + internal static void StartingRetryDelay(ILogger logger, TimeSpan delayDuration) + { + _startingRetryDelay(logger, delayDuration, null); + } + + internal static void ErrorRetryingCall(ILogger logger, Exception ex) + { + _errorRetryingCall(logger, ex); + } + + internal static void SendingBufferedMessages(ILogger logger, int messageCount) + { + _sendingBufferedMessages(logger, messageCount, null); + } + + internal static void MessageAddedToBuffer(ILogger logger, int messageSize, long callBufferSize) + { + _messageAddedToBuffer(logger, messageSize, callBufferSize, null); + } + + internal static void CallCommited(ILogger logger, CommitReason commitReason) + { + _callCommited(logger, commitReason, null); + } + + internal static void StartingRetryWorker(ILogger logger) + { + _startingRetryWorker(logger, null); + } + + internal static void StoppingRetryWorker(ILogger logger) + { + _stoppingRetryWorker(logger, null); + } + + internal static void MaxAttemptsLimited(ILogger logger, int serviceConfigMaxAttempts, int channelMaxAttempts) + { + _maxAttemptsLimited(logger, serviceConfigMaxAttempts, channelMaxAttempts, null); + } + + internal static void AdditionalCallsBlockedByRetryThrottling(ILogger logger) + { + _additionalCallsBlockedByRetryThrottling(logger, null); + } + } + } +} diff --git a/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs b/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs new file mode 100644 index 000000000..97000254e --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs @@ -0,0 +1,483 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Grpc.Core; +using Grpc.Net.Client.Internal.Http; +using Grpc.Shared; +using Microsoft.Extensions.Logging; + +#if NETSTANDARD2_0 +using ValueTask = System.Threading.Tasks.Task; +#endif + +namespace Grpc.Net.Client.Internal.Retry +{ + internal abstract partial class RetryCallBase : IGrpcCall + where TRequest : class + where TResponse : class + { + private readonly TaskCompletionSource> _commitedCallTcs; + private RetryCallBaseClientStreamReader? _retryBaseClientStreamReader; + private RetryCallBaseClientStreamWriter? _retryBaseClientStreamWriter; + private CancellationTokenRegistration? _ctsRegistration; + + protected object Lock { get; } = new object(); + protected ILogger Logger { get; } + protected GrpcChannel Channel { get; } + protected Method Method { get; } + protected CallOptions Options { get; } + protected int MaxRetryAttempts { get; } + protected CancellationTokenSource CancellationTokenSource { get; } + protected TaskCompletionSource?>? NewActiveCallTcs { get; set; } + protected bool Disposed { get; private set; } + + public Task> CommitedCallTask => _commitedCallTcs.Task; + public IAsyncStreamReader? ClientStreamReader => _retryBaseClientStreamReader ??= new RetryCallBaseClientStreamReader(this); + public IClientStreamWriter? ClientStreamWriter => _retryBaseClientStreamWriter ??= new RetryCallBaseClientStreamWriter(this); + public WriteOptions? ClientStreamWriteOptions { get; internal set; } + + protected bool ClientStreamComplete { get; set; } + + protected List> BufferedMessages { get; } + protected long CurrentCallBufferSize { get; set; } + protected bool BufferedCurrentMessage { get; set; } + + protected RetryCallBase(GrpcChannel channel, Method method, CallOptions options, string loggerName, int retryAttempts) + { + Logger = channel.LoggerFactory.CreateLogger(loggerName); + Channel = channel; + Method = method; + Options = options; + _commitedCallTcs = new TaskCompletionSource>(TaskCreationOptions.RunContinuationsAsynchronously); + BufferedMessages = new List>(); + + if (options.CancellationToken.CanBeCanceled) + { + _ctsRegistration = options.CancellationToken.Register(state => ((RetryCallBase)state!).CancellationTokenSource.Cancel(), this); + } + + CancellationTokenSource = new CancellationTokenSource(); + + // TODO(JamesNK) - Check that large deadlines are supported. Might need to use a Timer here instead. + var deadline = Options.Deadline.GetValueOrDefault(DateTime.MaxValue); + if (deadline != DateTime.MaxValue) + { + var timeout = CommonGrpcProtocolHelpers.GetTimerDueTime(deadline - Channel.Clock.UtcNow, Channel.MaxTimerDueTime); + CancellationTokenSource.CancelAfter(TimeSpan.FromMilliseconds(timeout)); + } + + if (HasClientStream()) + { + NewActiveCallTcs = new TaskCompletionSource?>(TaskCreationOptions.None); + } + + if (retryAttempts > Channel.MaxRetryAttempts) + { + Log.MaxAttemptsLimited(Logger, retryAttempts, Channel.MaxRetryAttempts.GetValueOrDefault()); + MaxRetryAttempts = Channel.MaxRetryAttempts.GetValueOrDefault(); + } + else + { + MaxRetryAttempts = retryAttempts; + } + } + + public async Task GetResponseAsync() + { + var call = await CommitedCallTask.ConfigureAwait(false); + return await call.GetResponseAsync().ConfigureAwait(false); + } + + public async Task GetResponseHeadersAsync() + { + var call = await CommitedCallTask.ConfigureAwait(false); + return await call.GetResponseHeadersAsync().ConfigureAwait(false); + } + + public Status GetStatus() + { + if (CommitedCallTask.IsCompletedSuccessfully()) + { + return CommitedCallTask.Result.GetStatus(); + } + + throw new InvalidOperationException("Unable to get the status because the call is not complete."); + } + + public Metadata GetTrailers() + { + if (CommitedCallTask.IsCompletedSuccessfully()) + { + return CommitedCallTask.Result.GetTrailers(); + } + + throw new InvalidOperationException("Can't get the call trailers because the call has not completed successfully."); + } + + public void Dispose() => Dispose(true); + + public void StartUnary(TRequest request) + { + StartCore(call => call.StartUnaryCore(CreatePushUnaryContent(request, call))); + } + + public void StartClientStreaming() + { + StartCore(call => + { + var clientStreamWriter = new HttpContentClientStreamWriter(call); + var content = CreatePushStreamContent(call, clientStreamWriter); + call.StartClientStreamingCore(clientStreamWriter, content); + }); + } + + public void StartServerStreaming(TRequest request) + { + StartCore(call => call.StartServerStreamingCore(CreatePushUnaryContent(request, call))); + } + + public void StartDuplexStreaming() + { + StartCore(call => + { + var clientStreamWriter = new HttpContentClientStreamWriter(call); + var content = CreatePushStreamContent(call, clientStreamWriter); + call.StartDuplexStreamingCore(clientStreamWriter, content); + }); + } + + private HttpContent CreatePushUnaryContent(TRequest request, GrpcCall call) + { + return !Channel.IsWinHttp + ? new PushUnaryContent(request, WriteAsync) + : new WinHttpUnaryContent(request, WriteAsync, call); + + ValueTask WriteAsync(TRequest request, Stream stream) + { + return WriteNewMessage(call, stream, call.Options, request); + } + } + + private PushStreamContent CreatePushStreamContent(GrpcCall call, HttpContentClientStreamWriter clientStreamWriter) + { + return new PushStreamContent(clientStreamWriter, async requestStream => + { + ValueTask writeTask; + lock (Lock) + { + Log.SendingBufferedMessages(Logger, BufferedMessages.Count); + + if (BufferedMessages.Count == 0) + { +#if NETSTANDARD2_0 + writeTask = Task.CompletedTask; +#else + writeTask = default; +#endif + } + else if (BufferedMessages.Count == 1) + { + writeTask = call.WriteMessageAsync(requestStream, BufferedMessages[0], call.CancellationToken); + } + else + { + // Copy messages to a new collection in lock for thread-safety. + var bufferedMessageCopy = BufferedMessages.ToArray(); + writeTask = WriteBufferedMessages(call, requestStream, bufferedMessageCopy); + } + } + + await writeTask.ConfigureAwait(false); + + if (ClientStreamComplete) + { + await call.ClientStreamWriter!.CompleteAsync().ConfigureAwait(false); + } + }); + + static async ValueTask WriteBufferedMessages(GrpcCall call, Stream requestStream, ReadOnlyMemory[] bufferedMessages) + { + foreach (var writtenMessage in bufferedMessages) + { + await call.WriteMessageAsync(requestStream, writtenMessage, call.CancellationToken).ConfigureAwait(false); + } + } + } + + protected abstract void StartCore(Action> startCallFunc); + + public abstract Task ClientStreamCompleteAsync(); + + public abstract Task ClientStreamWriteAsync(TRequest message); + + protected bool IsDeadlineExceeded() + { + return Options.Deadline != null && Options.Deadline <= Channel.Clock.UtcNow; + } + + protected int? GetRetryPushback(GrpcCall call) + { + // https://github.com/grpc/proposal/blob/master/A6-client-retries.md#pushback + if (call.HttpResponse != null) + { + if (call.HttpResponse.Headers.TryGetValues(GrpcProtocolConstants.RetryPushbackHeader, out var values)) + { + var headerValue = values.Single(); + Log.RetryPushbackReceived(Logger, headerValue); + + // A non-integer value means the server wants retries to stop. + // Resolve non-integer value to a negative integer which also means stop. + return int.TryParse(headerValue, out var value) ? value : -1; + } + } + + return null; + } + + protected byte[] SerializePayload(GrpcCall call, CallOptions callOptions, TRequest request) + { + var serializationContext = call.SerializationContext; + serializationContext.CallOptions = callOptions; + serializationContext.Initialize(); + + try + { + call.Method.RequestMarshaller.ContextualSerializer(request, serializationContext); + + // Need to take a copy because the serialization context will returned a rented buffer. + return serializationContext.GetWrittenPayload().ToArray(); + } + finally + { + serializationContext.Reset(); + } + } + + protected async ValueTask WriteNewMessage(GrpcCall call, Stream writeStream, CallOptions callOptions, TRequest message) + { + // Serialize current message and add to the buffer. + ReadOnlyMemory messageData; + + lock (Lock) + { + if (!BufferedCurrentMessage) + { + messageData = SerializePayload(call, callOptions, message); + + // Don't buffer message data if the call has been commited. + if (!CommitedCallTask.IsCompletedSuccessfully()) + { + if (!TryAddToRetryBuffer(messageData)) + { + CommitCall(call, CommitReason.BufferExceeded); + } + else + { + BufferedCurrentMessage = true; + + Log.MessageAddedToBuffer(Logger, messageData.Length, CurrentCallBufferSize); + } + } + } + else + { + // There is a race between: + // 1. A client stream starting for a new call. It will write all buffered messages, and + // 2. Writing a new message here. The message may already have been buffered when the client + // stream started so we don't want to write it again. + // + // Check the client stream write count against he buffer message count to ensure all buffered + // messages haven't already been written. + if (call.MessagesWritten == BufferedMessages.Count) + { + return; + } + + messageData = BufferedMessages[BufferedMessages.Count - 1]; + } + } + + await call.WriteMessageAsync(writeStream, messageData, callOptions.CancellationToken).ConfigureAwait(false); + } + + protected void CommitCall(IGrpcCall call, CommitReason commitReason) + { + lock (Lock) + { + if (!CommitedCallTask.IsCompletedSuccessfully()) + { + OnCommitCall(call); + + // Log before committing for unit tests. + Log.CallCommited(Logger, commitReason); + + NewActiveCallTcs?.SetResult(null); + _commitedCallTcs.SetResult(call); + + ClearRetryBuffer(); + } + } + } + + protected abstract void OnCommitCall(IGrpcCall call); + + protected bool HasClientStream() + { + return Method.Type == MethodType.ClientStreaming || Method.Type == MethodType.DuplexStreaming; + } + + protected void SetNewActiveCall(IGrpcCall call) + { + Debug.Assert(!CommitedCallTask.IsCompletedSuccessfully()); + + if (NewActiveCallTcs != null) + { + // Run continuation synchronously so awaiters execute inside the lock + NewActiveCallTcs.SetResult(call); + NewActiveCallTcs = new TaskCompletionSource?>(TaskCreationOptions.None); + } + } + + Task IGrpcCall.WriteClientStreamAsync(Func, Stream, CallOptions, TState, ValueTask> writeFunc, TState state) + { + throw new NotSupportedException(); + } + + protected async Task?> GetActiveCallUnsynchronizedAsync(IGrpcCall? previousCall) + { + CompatibilityExtensions.Assert(NewActiveCallTcs != null); + + var call = await NewActiveCallTcs.Task.ConfigureAwait(false); + if (call == null) + { + call = await CommitedCallTask.ConfigureAwait(false); + } + + // Avoid infinite loop. + if (call == previousCall) + { + return null; + } + + return call; + } + + protected virtual void Dispose(bool disposing) + { + if (Disposed) + { + return; + } + + Disposed = true; + + if (disposing) + { + _ctsRegistration?.Dispose(); + CancellationTokenSource.Cancel(); + + if (CommitedCallTask.IsCompletedSuccessfully()) + { + CommitedCallTask.Result.Dispose(); + } + + ClearRetryBuffer(); + } + } + + internal bool TryAddToRetryBuffer(ReadOnlyMemory message) + { + lock (Lock) + { + var messageSize = message.Length; + if (CurrentCallBufferSize + messageSize > Channel.MaxRetryBufferPerCallSize) + { + return false; + } + if (!Channel.TryAddToRetryBuffer(messageSize)) + { + return false; + } + + CurrentCallBufferSize += messageSize; + BufferedMessages.Add(message); + return true; + } + } + + internal void ClearRetryBuffer() + { + lock (Lock) + { + if (BufferedMessages.Count > 0) + { + BufferedMessages.Clear(); + Channel.RemoveFromRetryBuffer(CurrentCallBufferSize); + CurrentCallBufferSize = 0; + } + } + } + + protected StatusGrpcCall CreateStatusCall(Status status) + { + return new StatusGrpcCall(status); + } + + protected void HandleUnexpectedError(Exception ex) + { + IGrpcCall resolvedCall; + CommitReason commitReason; + + // Cancellation token triggered by dispose could throw here. + if (ex is OperationCanceledException && CancellationTokenSource.IsCancellationRequested) + { + // Cancellation could have been caused by an exceeded deadline. + if (IsDeadlineExceeded()) + { + commitReason = CommitReason.DeadlineExceeded; + // An exceeded deadline inbetween calls means there is no active call. + // Create a fake call that returns exceeded deadline status to the app. + resolvedCall = CreateStatusCall(GrpcProtocolConstants.DeadlineExceededStatus); + } + else + { + commitReason = CommitReason.Canceled; + resolvedCall = CreateStatusCall(Disposed ? GrpcProtocolConstants.DisposeCanceledStatus : GrpcProtocolConstants.ClientCanceledStatus); + } + } + else + { + commitReason = CommitReason.UnexpectedError; + resolvedCall = CreateStatusCall(GrpcProtocolHelpers.CreateStatusFromException("Unexpected error during retry.", ex)); + + // Only log unexpected errors. + Log.ErrorRetryingCall(Logger, ex); + } + + CommitCall(resolvedCall, commitReason); + } + } +} diff --git a/src/Grpc.Net.Client/Internal/Retry/RetryCallBaseClientStreamReader.cs b/src/Grpc.Net.Client/Internal/Retry/RetryCallBaseClientStreamReader.cs new file mode 100644 index 000000000..46885f80f --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Retry/RetryCallBaseClientStreamReader.cs @@ -0,0 +1,46 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System.Threading; +using System.Threading.Tasks; +using Grpc.Core; + +namespace Grpc.Net.Client.Internal.Retry +{ + internal class RetryCallBaseClientStreamReader : IAsyncStreamReader + where TRequest : class + where TResponse : class + { + private readonly RetryCallBase _retryCallBase; + + public RetryCallBaseClientStreamReader(RetryCallBase retryCallBase) + { + _retryCallBase = retryCallBase; + } + + public TResponse Current => _retryCallBase.CommitedCallTask.IsCompletedSuccessfully() + ? _retryCallBase.CommitedCallTask.Result.ClientStreamReader!.Current + : default!; + + public async Task MoveNext(CancellationToken cancellationToken) + { + var call = await _retryCallBase.CommitedCallTask.ConfigureAwait(false); + return await call.ClientStreamReader!.MoveNext(cancellationToken).ConfigureAwait(false); + } + } +} diff --git a/src/Grpc.Net.Client/Internal/Retry/RetryCallBaseClientStreamWriter.cs b/src/Grpc.Net.Client/Internal/Retry/RetryCallBaseClientStreamWriter.cs new file mode 100644 index 000000000..1052f6cfb --- /dev/null +++ b/src/Grpc.Net.Client/Internal/Retry/RetryCallBaseClientStreamWriter.cs @@ -0,0 +1,51 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System.Threading.Tasks; +using Grpc.Core; + +namespace Grpc.Net.Client.Internal.Retry +{ + internal class RetryCallBaseClientStreamWriter : IClientStreamWriter + where TRequest : class + where TResponse : class + { + private readonly RetryCallBase _retryCallBase; + + public RetryCallBaseClientStreamWriter(RetryCallBase retryCallBase) + { + _retryCallBase = retryCallBase; + } + + public WriteOptions? WriteOptions + { + get => _retryCallBase.ClientStreamWriteOptions; + set => _retryCallBase.ClientStreamWriteOptions = value; + } + + public Task CompleteAsync() + { + return _retryCallBase.ClientStreamCompleteAsync(); + } + + public Task WriteAsync(TRequest message) + { + return _retryCallBase.ClientStreamWriteAsync(message); + } + } +} diff --git a/src/Grpc.Net.Client/Internal/StreamExtensions.cs b/src/Grpc.Net.Client/Internal/StreamExtensions.cs index cb5279abd..8444b9291 100644 --- a/src/Grpc.Net.Client/Internal/StreamExtensions.cs +++ b/src/Grpc.Net.Client/Internal/StreamExtensions.cs @@ -319,16 +319,15 @@ public static async ValueTask WriteMessageAsync( this Stream stream, GrpcCall call, ReadOnlyMemory data, - CallOptions callOptions) + CancellationToken cancellationToken) { - // Sync relevant changes here with other WriteMessageAsync try { GrpcCallLog.SendingMessage(call.Logger); // Sending the header+content in a single WriteAsync call has significant performance benefits // https://github.com/dotnet/runtime/issues/35184#issuecomment-626304981 - await stream.WriteAsync(data, callOptions.CancellationToken).ConfigureAwait(false); + await stream.WriteAsync(data, cancellationToken).ConfigureAwait(false); GrpcCallLog.MessageSent(call.Logger); } diff --git a/src/Shared/CommonGrpcProtocolHelpers.cs b/src/Shared/CommonGrpcProtocolHelpers.cs index 787e17f1c..b5cce6059 100644 --- a/src/Shared/CommonGrpcProtocolHelpers.cs +++ b/src/Shared/CommonGrpcProtocolHelpers.cs @@ -31,7 +31,7 @@ internal static class CommonGrpcProtocolHelpers // - The timer is rescheduled to run in 0.5ms. // - The deadline callback is raised again and there is now 0.4ms until deadline. // - The timer is rescheduled to run in 0.4ms, etc. - private static readonly int TimerEpsilonMilliseconds = 4; + private static readonly int TimerEpsilonMilliseconds = 7; public static long GetTimerDueTime(TimeSpan timeout, long maxTimerDueTime) { @@ -41,7 +41,7 @@ public static long GetTimerDueTime(TimeSpan timeout, long maxTimerDueTime) // Add epislon to take into account Timer precision. // This will avoid rescheduling the timer multiple times, but means deadline - // might run for some extra milliseconds + // might run slightly longer than requested. dueTimeMilliseconds += TimerEpsilonMilliseconds; dueTimeMilliseconds = Math.Min(dueTimeMilliseconds, maxTimerDueTime); diff --git a/test/FunctionalTests/Client/CancellationTests.cs b/test/FunctionalTests/Client/CancellationTests.cs index 5f49663e5..ceb4319f8 100644 --- a/test/FunctionalTests/Client/CancellationTests.cs +++ b/test/FunctionalTests/Client/CancellationTests.cs @@ -107,7 +107,7 @@ await TestHelpers.RunParallel(tasks, async taskIndex => { try { - for (int i = 0; i < interations; i++) + for (var i = 0; i < interations; i++) { Logger.LogInformation($"Staring {taskIndex}-{i}"); diff --git a/test/FunctionalTests/Client/EventSourceTests.cs b/test/FunctionalTests/Client/EventSourceTests.cs index 6127e6e53..c40f197ab 100644 --- a/test/FunctionalTests/Client/EventSourceTests.cs +++ b/test/FunctionalTests/Client/EventSourceTests.cs @@ -172,7 +172,7 @@ async Task UnaryError(HelloRequest request, ServerCallContext contex public async Task UnaryMethod_DeadlineExceededCall_PollingCountersUpdatedCorrectly() { // Loop to ensure test is resilent across multiple runs - for (int i = 1; i < 3; i++) + for (var i = 1; i < 3; i++) { var syncPoint = new SyncPoint(); @@ -248,7 +248,7 @@ async Task UnaryDeadlineExceeded(HelloRequest request, ServerCallCon public async Task UnaryMethod_CancelCall_PollingCountersUpdatedCorrectly() { // Loop to ensure test is resilent across multiple runs - for (int i = 1; i < 3; i++) + for (var i = 1; i < 3; i++) { var syncPoint = new SyncPoint(); var cts = new CancellationTokenSource(); diff --git a/test/FunctionalTests/Client/HedgingTests.cs b/test/FunctionalTests/Client/HedgingTests.cs new file mode 100644 index 000000000..ac96014e5 --- /dev/null +++ b/test/FunctionalTests/Client/HedgingTests.cs @@ -0,0 +1,548 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; +using Grpc.AspNetCore.FunctionalTests.Infrastructure; +using Grpc.Core; +using Grpc.Net.Client; +using Grpc.Net.Client.Configuration; +using Grpc.Tests.Shared; +using Microsoft.Extensions.Logging; +using NUnit.Framework; +using Streaming; + +namespace Grpc.AspNetCore.FunctionalTests.Client +{ + [TestFixture] + public class HedgingTests : FunctionalTestBase + { + [TestCase(0)] + [TestCase(20)] + public async Task Unary_ExceedAttempts_Failure(int hedgingDelay) + { + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + return Task.FromException(new RpcException(new Status(StatusCode.Unavailable, ""))); + } + + // Ignore errors + SetExpectedErrorsFilter(writeContext => + { + return true; + }); + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: 5, hedgingDelay: TimeSpan.FromMilliseconds(hedgingDelay))); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage()); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: ExceededAttemptCount"); + } + + [Test] + public async Task Duplex_ManyParallelRequests_MessageRoundTripped() + { + const string ImportantMessage = +@" _____ _____ _____ + | __ \| __ \ / ____| + __ _| |__) | |__) | | + / _` | _ /| ___/| | + | (_| | | \ \| | | |____ + \__, |_| \_\_| \_____| + __/ | + |___/ + _ + (_) + _ ___ + | / __| + | \__ \ _ + |_|___/ | | + ___ ___ ___ | | + / __/ _ \ / _ \| | + | (_| (_) | (_) | | + \___\___/ \___/|_| + + "; + + var attempts = 100; + var allUploads = new List(); + var allCompletedTasks = new List(); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + async Task MessageUpload( + IAsyncStreamReader requestStream, + IServerStreamWriter responseStream, + ServerCallContext context) + { + // Receive chunks + var chunks = new List(); + await foreach (var chunk in requestStream.ReadAllAsync()) + { + chunks.Add(chunk.Value); + } + + Task completeTask; + lock (allUploads) + { + allUploads.Add(string.Join(Environment.NewLine, chunks)); + if (allUploads.Count < attempts) + { + // Check that unused calls are canceled. + completeTask = Task.Run(async () => + { + await tcs.Task; + + var cancellationTcs = new TaskCompletionSource(); + context.CancellationToken.Register(s => ((TaskCompletionSource)s!).SetResult(true), cancellationTcs); + await cancellationTcs.Task; + }); + } + else + { + // Write response in used call. + completeTask = Task.Run(async () => + { + // Write chunks + foreach (var chunk in chunks) + { + await responseStream.WriteAsync(new StringValue + { + Value = chunk + }); + } + }); + } + } + + await completeTask; + } + + var method = Fixture.DynamicGrpc.AddDuplexStreamingMethod(MessageUpload); + + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: 100, hedgingDelay: TimeSpan.Zero), maxRetryAttempts: 100); + + var client = TestClientFactory.Create(channel, method); + + using var call = client.DuplexStreamingCall(); + + var lines = ImportantMessage.Split(Environment.NewLine); + for (var i = 0; i < lines.Length; i++) + { + await call.RequestStream.WriteAsync(new StringValue { Value = lines[i] }).DefaultTimeout(); + await Task.Delay(TimeSpan.FromSeconds(0.01)).DefaultTimeout(); + } + await call.RequestStream.CompleteAsync().DefaultTimeout(); + + await TestHelpers.AssertIsTrueRetryAsync(() => allUploads.Count == 100, "Wait for all calls to reach server.").DefaultTimeout(); + tcs.SetResult(null); + + var receivedLines = new List(); + await foreach (var line in call.ResponseStream.ReadAllAsync().DefaultTimeout()) + { + receivedLines.Add(line.Value); + } + + Assert.AreEqual(ImportantMessage, string.Join(Environment.NewLine, receivedLines)); + + foreach (var upload in allUploads) + { + Assert.AreEqual(ImportantMessage, upload); + } + + await Task.WhenAll(allCompletedTasks).DefaultTimeout(); + } + + [TestCase(1)] + [TestCase(2)] + public async Task Unary_DeadlineExceedAfterServerCall_Failure(int exceptedServerCallCount) + { + var callCount = 0; + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + callCount++; + + if (callCount < exceptedServerCallCount) + { + return Task.FromException(new RpcException(new Status(StatusCode.DeadlineExceeded, ""))); + } + + return tcs.Task; + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(nonFatalStatusCodes: new List { StatusCode.DeadlineExceeded }); + var channel = CreateChannel(serviceConfig: serviceConfig); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage(), new CallOptions(deadline: DateTime.UtcNow.AddMilliseconds(200))); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.DeadlineExceeded, ex.StatusCode); + Assert.AreEqual(StatusCode.DeadlineExceeded, call.GetStatus().StatusCode); + + Assert.IsFalse(Logs.Any(l => l.EventId.Name == "DeadlineTimerRescheduled")); + + tcs.SetResult(new DataMessage()); + } + + [Test] + public async Task Unary_DeadlineExceedDuringDelay_Failure() + { + var callCount = 0; + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + callCount++; + + return Task.FromException(new RpcException(new Status(StatusCode.DeadlineExceeded, ""))); + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig( + hedgingDelay: TimeSpan.FromSeconds(10), + nonFatalStatusCodes: new List { StatusCode.DeadlineExceeded }); + var channel = CreateChannel(serviceConfig: serviceConfig); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage(), new CallOptions(deadline: DateTime.UtcNow.AddMilliseconds(300))); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.DeadlineExceeded, ex.StatusCode); + Assert.AreEqual(StatusCode.DeadlineExceeded, call.GetStatus().StatusCode); + Assert.AreEqual(1, callCount); + + Assert.IsFalse(Logs.Any(l => l.EventId.Name == "DeadlineTimerRescheduled")); + + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: DeadlineExceeded"); + } + + [Test] + public async Task Duplex_DeadlineExceedDuringDelay_Failure() + { + var callCount = 0; + Task DuplexDeadlineExceeded(IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context) + { + callCount++; + + return Task.FromException(new RpcException(new Status(StatusCode.DeadlineExceeded, ""))); + } + + // Arrange + var method = Fixture.DynamicGrpc.AddDuplexStreamingMethod(DuplexDeadlineExceeded); + + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig( + hedgingDelay: TimeSpan.FromSeconds(10), + nonFatalStatusCodes: new List { StatusCode.DeadlineExceeded }); + var channel = CreateChannel(serviceConfig: serviceConfig); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.DuplexStreamingCall(new CallOptions(deadline: DateTime.UtcNow.AddMilliseconds(300))); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseStream.MoveNext(CancellationToken.None)).DefaultTimeout(); + Assert.AreEqual(StatusCode.DeadlineExceeded, ex.StatusCode); + + ex = await ExceptionAssert.ThrowsAsync(() => call.RequestStream.WriteAsync(new DataMessage())).DefaultTimeout(); + Assert.AreEqual(StatusCode.DeadlineExceeded, ex.StatusCode); + + Assert.AreEqual(StatusCode.DeadlineExceeded, call.GetStatus().StatusCode); + Assert.AreEqual(1, callCount); + + Assert.IsFalse(Logs.Any(l => l.EventId.Name == "DeadlineTimerRescheduled")); + } + + [Test] + public async Task Unary_DeadlineExceedBeforeServerCall_Failure() + { + var callCount = 0; + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + callCount++; + return tcs.Task; + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(nonFatalStatusCodes: new List { StatusCode.DeadlineExceeded }); + var channel = CreateChannel(serviceConfig: serviceConfig); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage(), new CallOptions(deadline: DateTime.UtcNow)); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.DeadlineExceeded, ex.StatusCode); + Assert.AreEqual(StatusCode.DeadlineExceeded, call.GetStatus().StatusCode); + Assert.AreEqual(0, callCount); + + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: DeadlineExceeded"); + + tcs.SetResult(new DataMessage()); + } + + [Test] + public async Task Unary_TriggerRetryThrottling_Failure() + { + var callCount = 0; + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + callCount++; + return Task.FromException(new RpcException(new Status(StatusCode.Unavailable, ""))); + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateHedgingServiceConfig( + hedgingDelay: TimeSpan.FromMilliseconds(100), + retryThrottling: new RetryThrottlingPolicy + { + MaxTokens = 5, + TokenRatio = 0.1 + })); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage()); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + Assert.AreEqual(3, callCount); + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: Throttled"); + } + + [TestCase(0)] + [TestCase(100)] + public async Task Unary_RetryThrottlingAlreadyActive_Failure(int hedgingDelayMilliseconds) + { + var callCount = 0; + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + callCount++; + return Task.FromException(new RpcException(new Status(StatusCode.Unavailable, ""))); + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateHedgingServiceConfig( + hedgingDelay: TimeSpan.FromMilliseconds(hedgingDelayMilliseconds), + retryThrottling: new RetryThrottlingPolicy + { + MaxTokens = 5, + TokenRatio = 0.1 + })); + + // Manually trigger retry throttling + Debug.Assert(channel.RetryThrottling != null); + channel.RetryThrottling.CallFailure(); + channel.RetryThrottling.CallFailure(); + channel.RetryThrottling.CallFailure(); + Debug.Assert(channel.RetryThrottling.IsRetryThrottlingActive()); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage()); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + Assert.AreEqual(1, callCount); + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: Throttled"); + } + + [Test] + public async Task Unary_RetryThrottlingBecomesActive_HasDelay_Failure() + { + var callCount = 0; + var syncPoint = new SyncPoint(runContinuationsAsynchronously: true); + async Task UnaryFailure(DataMessage request, ServerCallContext context) + { + Interlocked.Increment(ref callCount); + await syncPoint.WaitToContinue(); + return request; + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateHedgingServiceConfig( + hedgingDelay: TimeSpan.FromMilliseconds(100), + retryThrottling: new RetryThrottlingPolicy + { + MaxTokens = 5, + TokenRatio = 0.1 + })); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage()); + + await syncPoint.WaitForSyncPoint().DefaultTimeout(); + + // Manually trigger retry throttling + Debug.Assert(channel.RetryThrottling != null); + channel.RetryThrottling.CallFailure(); + channel.RetryThrottling.CallFailure(); + channel.RetryThrottling.CallFailure(); + Debug.Assert(channel.RetryThrottling.IsRetryThrottlingActive()); + + // Assert + await TestHelpers.AssertIsTrueRetryAsync(() => HasLog(LogLevel.Debug, "AdditionalCallsBlockedByRetryThrottling", "Additional calls blocked by retry throttling."), "Check for expected log."); + + Assert.AreEqual(1, callCount); + syncPoint.Continue(); + + await call.ResponseAsync.DefaultTimeout(); + Assert.AreEqual(StatusCode.OK, call.GetStatus().StatusCode); + + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: ResponseHeadersReceived"); + } + + [TestCase(0)] + [TestCase(20)] + public async Task Unary_AttemptsGreaterThanDefaultClientLimit_LimitedAttemptsMade(int hedgingDelay) + { + var callCount = 0; + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + Interlocked.Increment(ref callCount); + return Task.FromException(new RpcException(new Status(StatusCode.Unavailable, ""))); + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: 10, hedgingDelay: TimeSpan.FromMilliseconds(hedgingDelay))); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage()); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + Assert.AreEqual(5, callCount); + + AssertHasLog(LogLevel.Debug, "MaxAttemptsLimited", "The method has 10 attempts specified in the service config. The number of attempts has been limited by channel configuration to 5."); + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: ExceededAttemptCount"); + } + + [TestCase(0, false, 0)] + [TestCase(0, false, 1)] + [TestCase(GrpcChannel.DefaultMaxRetryBufferPerCallSize - 10, false, 0)] // Final message size is bigger because of header + Protobuf field + [TestCase(GrpcChannel.DefaultMaxRetryBufferPerCallSize - 10, false, 1)] // Final message size is bigger because of header + Protobuf field + [TestCase(GrpcChannel.DefaultMaxRetryBufferPerCallSize + 10, true, 0)] + [TestCase(GrpcChannel.DefaultMaxRetryBufferPerCallSize + 10, true, 1)] + public async Task Unary_LargeMessages_ExceedPerCallBufferSize(long payloadSize, bool exceedBufferLimit, int hedgingDelayMilliseconds) + { + var callCount = 0; + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + Interlocked.Increment(ref callCount); + return Task.FromException(new RpcException(new Status(StatusCode.Unavailable, ""))); + } + + // Ignore errors + SetExpectedErrorsFilter(writeContext => + { + if (writeContext.EventId.Name == "ErrorSendingMessage" || + writeContext.EventId.Name == "ErrorExecutingServiceMethod") + { + return true; + } + + return false; + }); + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateHedgingServiceConfig(hedgingDelay: TimeSpan.FromMilliseconds(hedgingDelayMilliseconds))); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage + { + Data = ByteString.CopyFrom(new byte[payloadSize]) + }); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + if (!exceedBufferLimit) + { + Assert.AreEqual(5, callCount); + } + else + { + Assert.AreEqual(1, callCount); + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: BufferExceeded"); + } + + Assert.AreEqual(0, channel.CurrentRetryBufferSize); + } + } +} diff --git a/test/FunctionalTests/Client/RetryTests.cs b/test/FunctionalTests/Client/RetryTests.cs new file mode 100644 index 000000000..e7ef76bfa --- /dev/null +++ b/test/FunctionalTests/Client/RetryTests.cs @@ -0,0 +1,562 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Google.Protobuf; +using Grpc.AspNetCore.FunctionalTests.Infrastructure; +using Grpc.Core; +using Grpc.Net.Client; +using Grpc.Net.Client.Configuration; +using Grpc.Tests.Shared; +using Microsoft.Extensions.Logging; +using NUnit.Framework; +using Streaming; + +namespace Grpc.AspNetCore.FunctionalTests.Client +{ + [TestFixture] + public class RetryTests : FunctionalTestBase + { + [Test] + public async Task ClientStreaming_MultipleWritesAndRetries_Failure() + { + var nextFailure = 1; + + async Task ClientStreamingWithReadFailures(IAsyncStreamReader requestStream, ServerCallContext context) + { + List bytes = new List(); + await foreach (var message in requestStream.ReadAllAsync()) + { + if (bytes.Count >= nextFailure) + { + nextFailure = nextFailure * 2; + throw new RpcException(new Status(StatusCode.Unavailable, "")); + } + + bytes.Add(message.Data[0]); + } + + return new DataMessage + { + Data = ByteString.CopyFrom(bytes.ToArray()) + }; + } + + SetExpectedErrorsFilter(writeContext => + { + return true; + }); + + // Arrange + var method = Fixture.DynamicGrpc.AddClientStreamingMethod(ClientStreamingWithReadFailures); + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateRetryServiceConfig(maxAttempts: 10), maxRetryAttempts: 10); + var client = TestClientFactory.Create(channel, method); + var sentData = new List(); + + // Act + var call = client.ClientStreamingCall(); + + for (var i = 0; i < 20; i++) + { + sentData.Add((byte)i); + + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { (byte)i }) }).DefaultTimeout(); + await Task.Delay(1); + } + + await call.RequestStream.CompleteAsync().DefaultTimeout(); + + var result = await call.ResponseAsync.DefaultTimeout(); + + // Assert + Assert.IsTrue(result.Data.Span.SequenceEqual(sentData.ToArray())); + } + + [Test] + public async Task Unary_ExceedRetryAttempts_Failure() + { + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + var metadata = new Metadata(); + metadata.Add("grpc-retry-pushback-ms", "5"); + + return Task.FromException(new RpcException(new Status(StatusCode.Unavailable, ""), metadata)); + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateRetryServiceConfig()); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage()); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + AssertHasLog(LogLevel.Debug, "RetryPushbackReceived", "Retry pushback of '5' received from the failed gRPC call."); + AssertHasLog(LogLevel.Debug, "RetryEvaluated", "Evaluated retry for failed gRPC call. Status code: 'Unavailable', Attempt: 1, Retry: True"); + AssertHasLog(LogLevel.Trace, "StartingRetryDelay", "Starting retry delay of 00:00:00.0050000."); + AssertHasLog(LogLevel.Debug, "RetryEvaluated", "Evaluated retry for failed gRPC call. Status code: 'Unavailable', Attempt: 5, Retry: False"); + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: ExceededAttemptCount"); + } + + [Test] + public async Task Unary_TriggerRetryThrottling_Failure() + { + var callCount = 0; + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + callCount++; + return Task.FromException(new RpcException(new Status(StatusCode.Unavailable, ""))); + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateRetryServiceConfig( + retryThrottling: new RetryThrottlingPolicy + { + MaxTokens = 5, + TokenRatio = 0.1 + })); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage()); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + AssertHasLog(LogLevel.Debug, "RetryEvaluated", "Evaluated retry for failed gRPC call. Status code: 'Unavailable', Attempt: 3, Retry: False"); + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: Throttled"); + } + + [TestCase(1)] + [TestCase(2)] + public async Task Unary_DeadlineExceedAfterServerCall_Failure(int exceptedServerCallCount) + { + var callCount = 0; + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + callCount++; + + if (callCount < exceptedServerCallCount) + { + return Task.FromException(new RpcException(new Status(StatusCode.DeadlineExceeded, ""))); + } + + return tcs.Task; + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(retryableStatusCodes: new List { StatusCode.DeadlineExceeded }); + var channel = CreateChannel(serviceConfig: serviceConfig); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage(), new CallOptions(deadline: DateTime.UtcNow.AddMilliseconds(200))); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.DeadlineExceeded, ex.StatusCode); + Assert.AreEqual(StatusCode.DeadlineExceeded, call.GetStatus().StatusCode); + Assert.AreEqual(exceptedServerCallCount, callCount); + + Assert.IsFalse(Logs.Any(l => l.EventId.Name == "DeadlineTimerRescheduled")); + + AssertHasLog(LogLevel.Debug, "RetryEvaluated", $"Evaluated retry for failed gRPC call. Status code: 'DeadlineExceeded', Attempt: {exceptedServerCallCount}, Retry: False"); + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: DeadlineExceeded"); + + tcs.SetResult(new DataMessage()); + } + + [Test] + public async Task Unary_DeadlineExceedDuringBackoff_Failure() + { + var callCount = 0; + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + callCount++; + + return Task.FromException(new RpcException(new Status(StatusCode.Unavailable, ""), new Metadata + { + new Metadata.Entry("grpc-retry-pushback-ms", TimeSpan.FromSeconds(10).TotalMilliseconds.ToString()) + })); + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig( + initialBackoff: TimeSpan.FromSeconds(10), + maxBackoff: TimeSpan.FromSeconds(10), + retryableStatusCodes: new List { StatusCode.Unavailable }); + var channel = CreateChannel(serviceConfig: serviceConfig); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage(), new CallOptions(deadline: DateTime.UtcNow.AddMilliseconds(500))); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.DeadlineExceeded, ex.StatusCode); + Assert.AreEqual(StatusCode.DeadlineExceeded, call.GetStatus().StatusCode); + Assert.AreEqual(1, callCount); + + Assert.IsFalse(Logs.Any(l => l.EventId.Name == "DeadlineTimerRescheduled")); + } + + [Test] + public async Task Duplex_DeadlineExceedDuringBackoff_Failure() + { + var callCount = 0; + Task DuplexDeadlineExceeded(IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context) + { + callCount++; + + return Task.FromException(new RpcException(new Status(StatusCode.Unavailable, ""), new Metadata + { + new Metadata.Entry("grpc-retry-pushback-ms", TimeSpan.FromSeconds(10).TotalMilliseconds.ToString()) + })); + } + + // Arrange + var method = Fixture.DynamicGrpc.AddDuplexStreamingMethod(DuplexDeadlineExceeded); + + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig( + initialBackoff: TimeSpan.FromSeconds(10), + maxBackoff: TimeSpan.FromSeconds(10), + retryableStatusCodes: new List { StatusCode.Unavailable }); + var channel = CreateChannel(serviceConfig: serviceConfig); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.DuplexStreamingCall(new CallOptions(deadline: DateTime.UtcNow.AddMilliseconds(300))); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseStream.MoveNext(CancellationToken.None)).DefaultTimeout(); + Assert.AreEqual(StatusCode.DeadlineExceeded, ex.StatusCode); + + ex = await ExceptionAssert.ThrowsAsync(() => call.RequestStream.WriteAsync(new DataMessage())).DefaultTimeout(); + Assert.AreEqual(StatusCode.DeadlineExceeded, ex.StatusCode); + + Assert.AreEqual(StatusCode.DeadlineExceeded, call.GetStatus().StatusCode); + Assert.AreEqual(1, callCount); + + Assert.IsFalse(Logs.Any(l => l.EventId.Name == "DeadlineTimerRescheduled")); + } + + [Test] + public async Task Unary_DeadlineExceedBeforeServerCall_Failure() + { + var callCount = 0; + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + callCount++; + return tcs.Task; + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(retryableStatusCodes: new List { StatusCode.DeadlineExceeded }); + var channel = CreateChannel(serviceConfig: serviceConfig); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage(), new CallOptions(deadline: DateTime.UtcNow)); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.DeadlineExceeded, ex.StatusCode); + Assert.AreEqual(StatusCode.DeadlineExceeded, call.GetStatus().StatusCode); + Assert.AreEqual(0, callCount); + + AssertHasLog(LogLevel.Debug, "RetryEvaluated", "Evaluated retry for failed gRPC call. Status code: 'DeadlineExceeded', Attempt: 1, Retry: False"); + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: DeadlineExceeded"); + + tcs.SetResult(new DataMessage()); + } + + [TestCase(0)] + [TestCase(20)] + public async Task Unary_AttemptsGreaterThanDefaultClientLimit_LimitedAttemptsMade(int hedgingDelay) + { + var callCount = 0; + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + Interlocked.Increment(ref callCount); + return Task.FromException(new RpcException(new Status(StatusCode.Unavailable, ""))); + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateRetryServiceConfig(maxAttempts: 10, initialBackoff: TimeSpan.FromMilliseconds(hedgingDelay))); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage()); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + Assert.AreEqual(5, callCount); + + AssertHasLog(LogLevel.Debug, "MaxAttemptsLimited", "The method has 10 attempts specified in the service config. The number of attempts has been limited by channel configuration to 5."); + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: ExceededAttemptCount"); + } + + [TestCase(0, false)] + [TestCase(GrpcChannel.DefaultMaxRetryBufferPerCallSize - 10, false)] // Final message size is bigger because of header + Protobuf field + [TestCase(GrpcChannel.DefaultMaxRetryBufferPerCallSize + 10, true)] + public async Task Unary_LargeMessages_ExceedPerCallBufferSize(long payloadSize, bool exceedBufferLimit) + { + var callCount = 0; + Task UnaryFailure(DataMessage request, ServerCallContext context) + { + Interlocked.Increment(ref callCount); + return Task.FromException(new RpcException(new Status(StatusCode.Unavailable, ""))); + } + + // Arrange + var method = Fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var channel = CreateChannel(serviceConfig: ServiceConfigHelpers.CreateRetryServiceConfig()); + + var client = TestClientFactory.Create(channel, method); + + // Act + var call = client.UnaryCall(new DataMessage + { + Data = ByteString.CopyFrom(new byte[payloadSize]) + }); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + if (!exceedBufferLimit) + { + Assert.AreEqual(5, callCount); + } + else + { + Assert.AreEqual(1, callCount); + AssertHasLog(LogLevel.Debug, "CallCommited", "Call commited. Reason: BufferExceeded"); + } + + Assert.AreEqual(0, channel.CurrentRetryBufferSize); + } + + [Test] + public async Task Unary_MultipleLargeMessages_ExceedChannelMaxBufferSize() + { + // Arrange + var sp1 = new SyncPoint(runContinuationsAsynchronously: true); + var sp2 = new SyncPoint(runContinuationsAsynchronously: true); + var sp3 = new SyncPoint(runContinuationsAsynchronously: true); + var channel = CreateChannel( + serviceConfig: ServiceConfigHelpers.CreateRetryServiceConfig(), + maxRetryBufferSize: 200, + maxRetryBufferPerCallSize: 100); + + var request = new DataMessage { Data = ByteString.CopyFrom(new byte[90]) }; + + // Act + var call1Task = MakeCall(Fixture, channel, request, sp1); + await sp1.WaitForSyncPoint(); + + var call2Task = MakeCall(Fixture, channel, request, sp2); + await sp2.WaitForSyncPoint(); + + // Will exceed channel buffer limit and won't retry + var call3Task = MakeCall(Fixture, channel, request, sp3); + await sp3.WaitForSyncPoint(); + + // Assert + Assert.AreEqual(194, channel.CurrentRetryBufferSize); + + sp1.Continue(); + sp2.Continue(); + sp3.Continue(); + + var response = await call1Task.DefaultTimeout(); + Assert.AreEqual(90, response.Data.Length); + + response = await call2Task.DefaultTimeout(); + Assert.AreEqual(90, response.Data.Length); + + // Can't retry because buffer size exceeded. + var ex = await ExceptionAssert.ThrowsAsync(() => call3Task).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + + Assert.AreEqual(0, channel.CurrentRetryBufferSize); + + static Task MakeCall(GrpcTestFixture fixture, GrpcChannel channel, DataMessage request, SyncPoint syncPoint) + { + var callCount = 0; + async Task UnaryFailure(DataMessage request, ServerCallContext context) + { + Interlocked.Increment(ref callCount); + if (callCount == 1) + { + await syncPoint.WaitToContinue(); + throw new RpcException(new Status(StatusCode.Unavailable, "")); + } + else + { + return request; + } + } + + // Arrange + var method = fixture.DynamicGrpc.AddUnaryMethod(UnaryFailure); + + var client = TestClientFactory.Create(channel, method); + + var call = client.UnaryCall(request); + + return call.ResponseAsync; + } + } + + [Test] + public async Task ClientStreaming_MultipleWritesExceedPerCallLimit_Failure() + { + var nextFailure = 2; + var callCount = 0; + var syncPoint = new SyncPoint(runContinuationsAsynchronously: true); + + async Task ClientStreamingWithReadFailures(IAsyncStreamReader requestStream, ServerCallContext context) + { + Interlocked.Increment(ref callCount); + + List bytes = new List(); + await foreach (var message in requestStream.ReadAllAsync()) + { + bytes.Add(message.Data[0]); + + Logger.LogInformation($"Current count: {bytes.Count}, next failure: {nextFailure}."); + + if (bytes.Count >= nextFailure) + { + await syncPoint.WaitToContinue(); + throw new RpcException(new Status(StatusCode.Unavailable, "")); + } + } + + return new DataMessage + { + Data = ByteString.CopyFrom(bytes.ToArray()) + }; + } + + SetExpectedErrorsFilter(writeContext => + { + return true; + }); + + // Arrange + var method = Fixture.DynamicGrpc.AddClientStreamingMethod(ClientStreamingWithReadFailures); + var channel = CreateChannel( + serviceConfig: ServiceConfigHelpers.CreateRetryServiceConfig(maxAttempts: 10), + maxRetryAttempts: 10, + maxRetryBufferPerCallSize: 100); + var client = TestClientFactory.Create(channel, method); + var sentData = new List(); + + // Act + var call = client.ClientStreamingCall(); + + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + + await syncPoint.WaitForSyncPoint(); + + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + + var s = syncPoint; + syncPoint = new SyncPoint(runContinuationsAsynchronously: true); + nextFailure = 15; + s.Continue(); + + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + + Assert.AreEqual(96, channel.CurrentRetryBufferSize); + + await TestHelpers.AssertIsTrueRetryAsync(() => callCount == 2, "Wait for server to have second call.").DefaultTimeout(); + + // This message exceeds the buffer size. Call is commited here. + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + Assert.AreEqual(0, channel.CurrentRetryBufferSize); + + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + + await syncPoint.WaitForSyncPoint(); + + await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(new byte[] { 1 }) }).DefaultTimeout(); + + s = syncPoint; + syncPoint = new SyncPoint(runContinuationsAsynchronously: true); + nextFailure = int.MaxValue; + s.Continue(); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + Assert.AreEqual(2, callCount); + Assert.AreEqual(0, channel.CurrentRetryBufferSize); + } + } +} diff --git a/test/FunctionalTests/FunctionalTestBase.cs b/test/FunctionalTests/FunctionalTestBase.cs index ebe17a2e9..88c5e189e 100644 --- a/test/FunctionalTests/FunctionalTestBase.cs +++ b/test/FunctionalTests/FunctionalTestBase.cs @@ -22,6 +22,7 @@ using Grpc.AspNetCore.FunctionalTests.Infrastructure; using Grpc.Core; using Grpc.Net.Client; +using Grpc.Net.Client.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using NUnit.Framework; @@ -41,12 +42,26 @@ public class FunctionalTestBase protected GrpcChannel Channel => _channel ??= CreateChannel(); - protected GrpcChannel CreateChannel(bool useHandler = false) + protected GrpcChannel CreateChannel(bool useHandler = false, ServiceConfig? serviceConfig = null, int? maxRetryAttempts = null, long? maxRetryBufferSize = null, long? maxRetryBufferPerCallSize = null) { var options = new GrpcChannelOptions { - LoggerFactory = LoggerFactory + LoggerFactory = LoggerFactory, + ServiceConfig = serviceConfig }; + // Don't overwrite defaults + if (maxRetryAttempts != null) + { + options.MaxRetryAttempts = maxRetryAttempts; + } + if (maxRetryBufferSize != null) + { + options.MaxRetryBufferSize = maxRetryBufferSize; + } + if (maxRetryBufferPerCallSize != null) + { + options.MaxRetryBufferPerCallSize = maxRetryBufferPerCallSize; + } if (useHandler) { options.HttpHandler = Fixture.Handler; diff --git a/test/FunctionalTests/Grpc.AspNetCore.FunctionalTests.csproj b/test/FunctionalTests/Grpc.AspNetCore.FunctionalTests.csproj index dab69c71f..78326bf26 100644 --- a/test/FunctionalTests/Grpc.AspNetCore.FunctionalTests.csproj +++ b/test/FunctionalTests/Grpc.AspNetCore.FunctionalTests.csproj @@ -7,6 +7,7 @@ + diff --git a/test/FunctionalTests/Server/ClientStreamingMethodTests.cs b/test/FunctionalTests/Server/ClientStreamingMethodTests.cs index 43ff8a5f6..3e6b2fdbc 100644 --- a/test/FunctionalTests/Server/ClientStreamingMethodTests.cs +++ b/test/FunctionalTests/Server/ClientStreamingMethodTests.cs @@ -201,7 +201,7 @@ static async Task AccumulateCount(IAsyncStreamReader { - for (int i = 0; i < 10; i++) + for (var i = 0; i < 10; i++) { await s.WriteAsync(ms.ToArray()).AsTask().DefaultTimeout(); await s.FlushAsync().DefaultTimeout(); diff --git a/test/FunctionalTests/Server/DeadlineTests.cs b/test/FunctionalTests/Server/DeadlineTests.cs index feffdfa16..4ea56b347 100644 --- a/test/FunctionalTests/Server/DeadlineTests.cs +++ b/test/FunctionalTests/Server/DeadlineTests.cs @@ -251,15 +251,18 @@ public async Task WriteMessageAfterDeadline() { static async Task WriteUntilError(HelloRequest request, IServerStreamWriter responseStream, ServerCallContext context) { - var i = 0; - while (true) + for (var i = 0; i < 5; i++) { var message = $"How are you {request.Name}? {i}"; await responseStream.WriteAsync(new HelloReply { Message = message }).DefaultTimeout(); - i++; - await Task.Delay(10); } + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + context.CancellationToken.Register(s => ((TaskCompletionSource)s!).SetResult(true), tcs); + await tcs.Task; + + await responseStream.WriteAsync(new HelloReply { Message = "Write after deadline" }).DefaultTimeout(); } // Arrange diff --git a/test/Grpc.AspNetCore.Server.Tests/Web/Base64PipeReaderTests.cs b/test/Grpc.AspNetCore.Server.Tests/Web/Base64PipeReaderTests.cs index 9674f3ac0..4cd7d7c95 100644 --- a/test/Grpc.AspNetCore.Server.Tests/Web/Base64PipeReaderTests.cs +++ b/test/Grpc.AspNetCore.Server.Tests/Web/Base64PipeReaderTests.cs @@ -130,7 +130,7 @@ public async Task ReadAsync_ByteAtATime_Success() Assert.IsFalse(resultTask.IsCompleted); - for (int i = 0; i < base64Data.Length; i++) + for (var i = 0; i < base64Data.Length; i++) { await testPipe.Writer.WriteAsync(base64Data.AsMemory(i, 1)); await Task.Delay(10); diff --git a/test/Grpc.Net.Client.Tests/AsyncUnaryCallTests.cs b/test/Grpc.Net.Client.Tests/AsyncUnaryCallTests.cs index 342cc6c09..a1b72feb7 100644 --- a/test/Grpc.Net.Client.Tests/AsyncUnaryCallTests.cs +++ b/test/Grpc.Net.Client.Tests/AsyncUnaryCallTests.cs @@ -43,10 +43,12 @@ public async Task AsyncUnaryCall_Success_HttpRequestMessagePopulated() { // Arrange HttpRequestMessage? httpRequestMessage = null; + long? requestContentLength = null; var httpClient = ClientTestHelpers.CreateTestClient(async request => { httpRequestMessage = request; + requestContentLength = httpRequestMessage!.Content!.Headers!.ContentLength; HelloReply reply = new HelloReply { @@ -72,6 +74,7 @@ public async Task AsyncUnaryCall_Success_HttpRequestMessagePopulated() Assert.AreEqual(new MediaTypeHeaderValue("application/grpc"), httpRequestMessage.Content?.Headers?.ContentType); Assert.AreEqual(GrpcProtocolConstants.TEHeaderValue, httpRequestMessage.Headers.TE.Single().Value); Assert.AreEqual("identity,gzip", httpRequestMessage.Headers.GetValues(GrpcProtocolConstants.MessageAcceptEncodingHeader).Single()); + Assert.AreEqual(null, requestContentLength); var userAgent = httpRequestMessage.Headers.UserAgent.Single()!; Assert.AreEqual("grpc-dotnet", userAgent.Product?.Name); @@ -83,6 +86,41 @@ public async Task AsyncUnaryCall_Success_HttpRequestMessagePopulated() Assert.IsTrue(userAgent.Product!.Version!.Length <= 10); } + [Test] + public async Task AsyncUnaryCall_HasWinHttpHandler_ContentLengthOnHttpRequestMessagePopulated() + { + // Arrange + HttpRequestMessage? httpRequestMessage = null; + long? requestContentLength = null; + + var handler = TestHttpMessageHandler.Create(async request => + { + httpRequestMessage = request; + requestContentLength = httpRequestMessage!.Content!.Headers!.ContentLength; + + HelloReply reply = new HelloReply + { + Message = "Hello world" + }; + + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + // Just need to have a type called WinHttpHandler to activate new behavior. + var winHttpHandler = new WinHttpHandler(handler); + var invoker = HttpClientCallInvokerFactory.Create(winHttpHandler, "https://localhost"); + + // Act + var rs = await invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "Hello world" }); + + // Assert + Assert.AreEqual("Hello world", rs.Message); + + Assert.IsNotNull(httpRequestMessage); + Assert.AreEqual(18, requestContentLength); + } + [Test] public async Task AsyncUnaryCall_Success_RequestContentSent() { @@ -126,7 +164,7 @@ public async Task AsyncUnaryCall_Success_RequestContentSent() } [Test] - public async Task AsyncUnaryCall_NonOkStatusTrailer_AccessResponse_ReturnHeaders() + public async Task AsyncUnaryCall_NonOkStatusTrailer_AccessResponse_ThrowRpcError() { // Arrange var httpClient = ClientTestHelpers.CreateTestClient(request => @@ -144,7 +182,7 @@ public async Task AsyncUnaryCall_NonOkStatusTrailer_AccessResponse_ReturnHeaders } [Test] - public async Task AsyncUnaryCall_NonOkStatusTrailer_AccessHeaders_ThrowRpcError() + public async Task AsyncUnaryCall_NonOkStatusTrailer_AccessHeaders_ReturnHeaders() { // Arrange var httpClient = ClientTestHelpers.CreateTestClient(request => diff --git a/test/Grpc.Net.Client.Tests/Grpc.Net.Client.Tests.csproj b/test/Grpc.Net.Client.Tests/Grpc.Net.Client.Tests.csproj index 7cc960a6e..91082a74a 100644 --- a/test/Grpc.Net.Client.Tests/Grpc.Net.Client.Tests.csproj +++ b/test/Grpc.Net.Client.Tests/Grpc.Net.Client.Tests.csproj @@ -12,12 +12,14 @@ + + diff --git a/test/Grpc.Net.Client.Tests/GrpcChannelTests.cs b/test/Grpc.Net.Client.Tests/GrpcChannelTests.cs index a952a5333..bc357d3fa 100644 --- a/test/Grpc.Net.Client.Tests/GrpcChannelTests.cs +++ b/test/Grpc.Net.Client.Tests/GrpcChannelTests.cs @@ -23,6 +23,7 @@ using Greet; using Grpc.Core; using Grpc.Net.Client.Tests.Infrastructure; +using Grpc.Net.Client.Configuration; using Grpc.Tests.Shared; using NUnit.Framework; @@ -168,6 +169,26 @@ public void Build_NoHttpProviderOnNetFx_Throw() } #endif + [Test] + public void Build_ServiceConfigDuplicateMethodConfigNames_Error() + { + // Arrange + var options = CreateGrpcChannelOptions(o => o.ServiceConfig = new ServiceConfig + { + MethodConfigs = + { + new MethodConfig { Names = { MethodName.Default } }, + new MethodConfig { Names = { MethodName.Default } } + } + }); + + // Act + var ex = Assert.Throws(() => GrpcChannel.ForAddress("https://localhost", options)); + + // Assert + Assert.AreEqual("Duplicate method config found. Service: '', method: ''.", ex.Message); + } + [Test] public void Dispose_NotCalled_NotDisposed() { diff --git a/test/Grpc.Net.Client.Tests/HttpContentClientStreamReaderTests.cs b/test/Grpc.Net.Client.Tests/HttpContentClientStreamReaderTests.cs index 20fe713d2..91bc70c42 100644 --- a/test/Grpc.Net.Client.Tests/HttpContentClientStreamReaderTests.cs +++ b/test/Grpc.Net.Client.Tests/HttpContentClientStreamReaderTests.cs @@ -234,9 +234,10 @@ private static GrpcCall CreateGrpcCall(GrpcChannel cha return new GrpcCall( ClientTestHelpers.ServiceMethod, - new GrpcMethodInfo(new GrpcCallScope(ClientTestHelpers.ServiceMethod.Type, uri), uri), + new GrpcMethodInfo(new GrpcCallScope(ClientTestHelpers.ServiceMethod.Type, uri), uri, methodConfig: null), new CallOptions(), - channel); + channel, + previousAttempts: 0); } private static GrpcChannel CreateChannel(HttpClient httpClient, ILoggerFactory? loggerFactory = null, bool? throwOperationCanceledOnCancellation = null) diff --git a/test/Grpc.Net.Client.Tests/Infrastructure/HttpClientCallInvokerFactory.cs b/test/Grpc.Net.Client.Tests/Infrastructure/HttpClientCallInvokerFactory.cs index 8276fd622..4808e7bbf 100644 --- a/test/Grpc.Net.Client.Tests/Infrastructure/HttpClientCallInvokerFactory.cs +++ b/test/Grpc.Net.Client.Tests/Infrastructure/HttpClientCallInvokerFactory.cs @@ -18,6 +18,7 @@ using System; using System.Net.Http; +using Grpc.Net.Client.Configuration; using Grpc.Net.Client.Internal; using Microsoft.Extensions.Logging; @@ -32,12 +33,14 @@ public static HttpClientCallInvoker Create( Action? configure = null, bool? disableClientDeadline = null, long? maxTimerPeriod = null, - IOperatingSystem? operatingSystem = null) + IOperatingSystem? operatingSystem = null, + ServiceConfig? serviceConfig = null) { var channelOptions = new GrpcChannelOptions { LoggerFactory = loggerFactory, - HttpClient = httpClient + HttpClient = httpClient, + ServiceConfig = serviceConfig }; configure?.Invoke(channelOptions); diff --git a/test/Grpc.Net.Client.Tests/Infrastructure/WinHttpHandler.cs b/test/Grpc.Net.Client.Tests/Infrastructure/WinHttpHandler.cs new file mode 100644 index 000000000..06f6a1c66 --- /dev/null +++ b/test/Grpc.Net.Client.Tests/Infrastructure/WinHttpHandler.cs @@ -0,0 +1,28 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +// Namespace and class name needs to resolve to System.Net.Http.WinHttpHandler. +namespace System.Net.Http +{ + public class WinHttpHandler : DelegatingHandler + { + public WinHttpHandler(HttpMessageHandler innerHandler) : base(innerHandler) + { + } + } +} diff --git a/test/Grpc.Net.Client.Tests/Retry/ChannelRetryThrottlingTests.cs b/test/Grpc.Net.Client.Tests/Retry/ChannelRetryThrottlingTests.cs new file mode 100644 index 000000000..caeb7b62f --- /dev/null +++ b/test/Grpc.Net.Client.Tests/Retry/ChannelRetryThrottlingTests.cs @@ -0,0 +1,49 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using Grpc.Net.Client.Configuration; +using Grpc.Net.Client.Internal.Retry; +using NUnit.Framework; + +namespace Grpc.Net.Client.Tests.Retry +{ + [TestFixture] + public class ChannelRetryThrottlingTests + { + [Test] + public void IsRetryThrottlingActive_FailedAndSuccessCalls_ActivatedChanges() + { + var channelRetryThrottling = new ChannelRetryThrottling(new RetryThrottlingPolicy + { + MaxTokens = 3, + TokenRatio = 1 + }); + + Assert.AreEqual(false, channelRetryThrottling.IsRetryThrottlingActive()); + + channelRetryThrottling.CallFailure(); + Assert.AreEqual(false, channelRetryThrottling.IsRetryThrottlingActive()); + + channelRetryThrottling.CallFailure(); + Assert.AreEqual(true, channelRetryThrottling.IsRetryThrottlingActive()); + + channelRetryThrottling.CallSuccess(); + Assert.AreEqual(false, channelRetryThrottling.IsRetryThrottlingActive()); + } + } +} diff --git a/test/Grpc.Net.Client.Tests/Retry/HedgingCallTests.cs b/test/Grpc.Net.Client.Tests/Retry/HedgingCallTests.cs new file mode 100644 index 000000000..8b0a55adf --- /dev/null +++ b/test/Grpc.Net.Client.Tests/Retry/HedgingCallTests.cs @@ -0,0 +1,371 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Net; +using System.Threading; +using System.Threading.Tasks; +using Greet; +using Grpc.Core; +using Grpc.Net.Client.Configuration; +using Grpc.Net.Client.Internal; +using Grpc.Net.Client.Internal.Http; +using Grpc.Net.Client.Internal.Retry; +using Grpc.Net.Client.Tests.Infrastructure; +using Grpc.Tests.Shared; +using NUnit.Framework; + +namespace Grpc.Net.Client.Tests.Retry +{ + [TestFixture] + public class HedgingCallTests + { + [Test] + public async Task Dispose_ActiveCalls_CleansUpActiveCalls() + { + // Arrange + var allCallsOnServerTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var waitUntilFinishedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + // All calls are in-progress at once. + Interlocked.Increment(ref callCount); + if (callCount == 5) + { + allCallsOnServerTcs.SetResult(null); + } + await waitUntilFinishedTcs.Task; + + await request.Content!.CopyToAsync(new MemoryStream()); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: 5, hedgingDelay: TimeSpan.FromMilliseconds(20)); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + var hedgingCall = new HedgingCall(serviceConfig.MethodConfigs[0].HedgingPolicy!, invoker.Channel, ClientTestHelpers.ServiceMethod, new CallOptions()); + + // Act + hedgingCall.StartUnary(new HelloRequest { Name = "World" }); + Assert.IsFalse(hedgingCall.CreateHedgingCallsTask!.IsCompleted); + + // Assert + Assert.AreEqual(1, hedgingCall._activeCalls.Count); + + await allCallsOnServerTcs.Task.DefaultTimeout(); + + Assert.AreEqual(5, callCount); + Assert.AreEqual(5, hedgingCall._activeCalls.Count); + + hedgingCall.Dispose(); + Assert.AreEqual(0, hedgingCall._activeCalls.Count); + await hedgingCall.CreateHedgingCallsTask!.DefaultTimeout(); + + waitUntilFinishedTcs.SetResult(null); + } + + [Test] + public async Task ActiveCalls_FatalStatusCode_CleansUpActiveCalls() + { + // Arrange + var allCallsOnServerSyncPoint = new SyncPoint(runContinuationsAsynchronously: true); + var waitUntilFinishedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var callLock = new object(); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + await request.Content!.CopyToAsync(new MemoryStream()); + + // All calls are in-progress at once. + bool allCallsOnServer = false; + lock (callLock) + { + callCount++; + if (callCount == 5) + { + allCallsOnServer = true; + } + } + if (allCallsOnServer) + { + await allCallsOnServerSyncPoint.WaitToContinue(); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.InvalidArgument); + } + await waitUntilFinishedTcs.Task; + + throw new InvalidOperationException("Should never reach here."); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: 5, hedgingDelay: TimeSpan.FromMilliseconds(20)); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + var hedgingCall = new HedgingCall(serviceConfig.MethodConfigs[0].HedgingPolicy!, invoker.Channel, ClientTestHelpers.ServiceMethod, new CallOptions()); + + // Act + hedgingCall.StartUnary(new HelloRequest { Name = "World" }); + + // Assert + Assert.AreEqual(1, hedgingCall._activeCalls.Count); + Assert.IsFalse(hedgingCall.CreateHedgingCallsTask!.IsCompleted); + + await allCallsOnServerSyncPoint.WaitForSyncPoint().DefaultTimeout(); + + Assert.AreEqual(5, callCount); + Assert.AreEqual(5, hedgingCall._activeCalls.Count); + + allCallsOnServerSyncPoint.Continue(); + + var ex = await ExceptionAssert.ThrowsAsync(() => hedgingCall.GetResponseAsync()).DefaultTimeout(); + Assert.AreEqual(StatusCode.InvalidArgument, ex.StatusCode); + + // Fatal status code will cancel other calls + Assert.AreEqual(0, hedgingCall._activeCalls.Count); + await hedgingCall.CreateHedgingCallsTask!.DefaultTimeout(); + + waitUntilFinishedTcs.SetResult(null); + } + + [Test] + public async Task ClientStreamWriteAsync_NoActiveCalls_WaitsForNextCall() + { + // Arrange + var allCallsOnServerSyncPoint = new SyncPoint(runContinuationsAsynchronously: true); + var callLock = new object(); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + var content = (PushStreamContent)request.Content!; + _ = content.ReadAsStreamAsync(); + + // All calls are in-progress at once. + bool firstCallsOnServer = false; + lock (callLock) + { + callCount++; + if (callCount == 1) + { + firstCallsOnServer = true; + } + } + if (firstCallsOnServer) + { + await allCallsOnServerSyncPoint.WaitToContinue(); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + } + + await content.PushComplete.DefaultTimeout(); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: 5, hedgingDelay: TimeSpan.FromMilliseconds(200)); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + var hedgingCall = new HedgingCall(serviceConfig.MethodConfigs[0].HedgingPolicy!, invoker.Channel, ClientTestHelpers.GetServiceMethod(MethodType.ClientStreaming), new CallOptions()); + + // Act + hedgingCall.StartClientStreaming(); + await hedgingCall.ClientStreamWriter!.WriteAsync(new HelloRequest { Name = "Name 1" }).DefaultTimeout(); + + // Assert + Assert.AreEqual(1, hedgingCall._activeCalls.Count); + Assert.IsFalse(hedgingCall.CreateHedgingCallsTask!.IsCompleted); + + await allCallsOnServerSyncPoint.WaitForSyncPoint().DefaultTimeout(); + allCallsOnServerSyncPoint.Continue(); + + await TestHelpers.AssertIsTrueRetryAsync(() => hedgingCall._activeCalls.Count == 0, "Call should finish and then wait until next call."); + + // This call will wait until next hedging call starts + await hedgingCall.ClientStreamWriter!.WriteAsync(new HelloRequest { Name = "Name 2" }).DefaultTimeout(); + Assert.AreEqual(1, hedgingCall._activeCalls.Count); + + await hedgingCall.ClientStreamWriter!.CompleteAsync().DefaultTimeout(); + + var responseMessage = await hedgingCall.GetResponseAsync().DefaultTimeout(); + Assert.AreEqual("Hello world", responseMessage.Message); + + Assert.AreEqual(0, hedgingCall._activeCalls.Count); + await hedgingCall.CreateHedgingCallsTask!.DefaultTimeout(); + } + + [Test] + public async Task ResponseAsync_PushbackStop_SuccessAfterPushbackStop() + { + // Arrange + var allCallsOnServerTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var returnSuccessTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + // All calls are in-progress at once. + Interlocked.Increment(ref callCount); + if (callCount == 2) + { + allCallsOnServerTcs.TrySetResult(null); + } + await allCallsOnServerTcs.Task; + + await request.Content!.CopyToAsync(new MemoryStream()); + + if (request.Headers.TryGetValues(GrpcProtocolConstants.RetryPreviousAttemptsHeader, out var headerValues) && + headerValues.Single() == "1") + { + await returnSuccessTcs.Task; + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + } + else + { + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable, customHeaders: new Dictionary + { + [GrpcProtocolConstants.RetryPushbackHeader] = "-1" + }); + } + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: 2); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + var hedgingCall = new HedgingCall(serviceConfig.MethodConfigs[0].HedgingPolicy!, invoker.Channel, ClientTestHelpers.ServiceMethod, new CallOptions()); + + // Act + hedgingCall.StartUnary(new HelloRequest { Name = "World" }); + + // Wait for both calls to be on the server + await allCallsOnServerTcs.Task; + + // Assert + await TestHelpers.AssertIsTrueRetryAsync(() => hedgingCall._activeCalls.Count == 1, "Wait for pushback to be returned."); + returnSuccessTcs.SetResult(null); + + var rs = await hedgingCall.GetResponseAsync().DefaultTimeout(); + Assert.AreEqual("Hello world", rs.Message); + Assert.AreEqual(StatusCode.OK, hedgingCall.GetStatus().StatusCode); + Assert.AreEqual(2, callCount); + Assert.AreEqual(0, hedgingCall._activeCalls.Count); + } + + [Test] + public async Task RetryThrottling_BecomesActiveDuringDelay_CancelFailure() + { + // Arrange + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + Interlocked.Increment(ref callCount); + + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig( + hedgingDelay: TimeSpan.FromMilliseconds(200), + retryThrottling: new RetryThrottlingPolicy + { + MaxTokens = 5, + TokenRatio = 0.1 + }); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + var hedgingCall = new HedgingCall(serviceConfig.MethodConfigs[0].HedgingPolicy!, invoker.Channel, ClientTestHelpers.ServiceMethod, new CallOptions()); + + // Act + hedgingCall.StartUnary(new HelloRequest()); + + // Assert + await TestHelpers.AssertIsTrueRetryAsync(() => hedgingCall._activeCalls.Count == 0, "Wait for all calls to fail.").DefaultTimeout(); + + CompatibilityExtensions.Assert(invoker.Channel.RetryThrottling != null); + invoker.Channel.RetryThrottling.CallFailure(); + invoker.Channel.RetryThrottling.CallFailure(); + CompatibilityExtensions.Assert(invoker.Channel.RetryThrottling.IsRetryThrottlingActive()); + + var ex = await ExceptionAssert.ThrowsAsync(() => hedgingCall.GetResponseAsync()).DefaultTimeout(); + Assert.AreEqual(1, callCount); + Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode); + Assert.AreEqual(StatusCode.Cancelled, hedgingCall.GetStatus().StatusCode); + Assert.AreEqual("Retries stopped because retry throttling is active.", hedgingCall.GetStatus().Detail); + } + + [Test] + public async Task AsyncUnaryCall_CancellationDuringBackoff_CanceledStatus() + { + // Arrange + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + Interlocked.Increment(ref callCount); + + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable, retryPushbackHeader: TimeSpan.FromSeconds(10).TotalMilliseconds.ToString()); + }); + var cts = new CancellationTokenSource(); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(hedgingDelay: TimeSpan.FromSeconds(10)); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + var hedgingCall = new HedgingCall(serviceConfig.MethodConfigs[0].HedgingPolicy!, invoker.Channel, ClientTestHelpers.ServiceMethod, new CallOptions(cancellationToken: cts.Token)); + + // Act + hedgingCall.StartUnary(new HelloRequest()); + + // Assert + await TestHelpers.AssertIsTrueRetryAsync(() => hedgingCall._activeCalls.Count == 0, "Wait for all calls to fail.").DefaultTimeout(); + + cts.Cancel(); + + var ex = await ExceptionAssert.ThrowsAsync(() => hedgingCall.GetResponseAsync()).DefaultTimeout(); + Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode); + Assert.AreEqual("Call canceled by the client.", ex.Status.Detail); + } + + [Test] + public async Task AsyncUnaryCall_DisposeDuringBackoff_CanceledStatus() + { + // Arrange + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + Interlocked.Increment(ref callCount); + + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable, retryPushbackHeader: TimeSpan.FromSeconds(10).TotalMilliseconds.ToString()); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(hedgingDelay: TimeSpan.FromSeconds(10)); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + var hedgingCall = new HedgingCall(serviceConfig.MethodConfigs[0].HedgingPolicy!, invoker.Channel, ClientTestHelpers.ServiceMethod, new CallOptions()); + + // Act + hedgingCall.StartUnary(new HelloRequest()); + + // Assert + await TestHelpers.AssertIsTrueRetryAsync(() => hedgingCall._activeCalls.Count == 0, "Wait for all calls to fail.").DefaultTimeout(); + + hedgingCall.Dispose(); + + var ex = await ExceptionAssert.ThrowsAsync(() => hedgingCall.GetResponseAsync()).DefaultTimeout(); + Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode); + Assert.AreEqual("gRPC call disposed.", ex.Status.Detail); + } + } +} diff --git a/test/Grpc.Net.Client.Tests/Retry/HedgingTests.cs b/test/Grpc.Net.Client.Tests/Retry/HedgingTests.cs new file mode 100644 index 000000000..118409bae --- /dev/null +++ b/test/Grpc.Net.Client.Tests/Retry/HedgingTests.cs @@ -0,0 +1,654 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Google.Protobuf; +using Greet; +using Grpc.Core; +using Grpc.Net.Client.Internal; +using Grpc.Net.Client.Internal.Http; +using Grpc.Net.Client.Tests.Infrastructure; +using Grpc.Tests.Shared; +using NUnit.Framework; + +namespace Grpc.Net.Client.Tests.Retry +{ + [TestFixture] + public class HedgingTests + { + [TestCase(1)] + [TestCase(10)] + [TestCase(100)] + public async Task AsyncUnaryCall_OneAttempt_Success(int maxAttempts) + { + // Arrange + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + Interlocked.Increment(ref callCount); + + await tcs.Task; + + await request.Content!.CopyToAsync(new MemoryStream()); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: maxAttempts); + var invoker = HttpClientCallInvokerFactory.Create( + httpClient, + serviceConfig: serviceConfig, + configure: o => o.MaxRetryAttempts = maxAttempts); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + + // Assert + await TestHelpers.AssertIsTrueRetryAsync(() => callCount == maxAttempts, "All calls made at once."); + tcs.SetResult(null); + + var rs = await call.ResponseAsync.DefaultTimeout(); + Assert.AreEqual("Hello world", rs.Message); + Assert.AreEqual(StatusCode.OK, call.GetStatus().StatusCode); + } + + [Test] + public async Task AsyncClientStreamingCall_ManyParallelCalls_ReadDirectlyToRequestStream() + { + // Arrange + var requestStreams = new List(); + var attempts = 100; + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + WriterTestStream writerTestStream; + lock (requestStreams) + { + Interlocked.Increment(ref callCount); + writerTestStream = new WriterTestStream(); + requestStreams.Add(writerTestStream); + } + await request.Content!.CopyToAsync(writerTestStream); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: attempts); + var invoker = HttpClientCallInvokerFactory.Create( + httpClient, + serviceConfig: serviceConfig, + configure: o => o.MaxRetryAttempts = attempts); + + // Act + var call = invoker.AsyncClientStreamingCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions()); + var writeAsyncTask = call.RequestStream.WriteAsync(new HelloRequest { Name = "World" }); + + // Assert + await TestHelpers.AssertIsTrueRetryAsync(() => callCount == attempts, "All calls made at once."); + + var firstMessages = await Task.WhenAll(requestStreams.Select(s => s.WaitForDataAsync())).DefaultTimeout(); + await writeAsyncTask.DefaultTimeout(); + + foreach (var message in firstMessages) + { + Assert.IsTrue(firstMessages[0].Span.SequenceEqual(message.Span)); + } + + writeAsyncTask = call.RequestStream.WriteAsync(new HelloRequest { Name = "World 2" }); + var secondMessages = await Task.WhenAll(requestStreams.Select(s => s.WaitForDataAsync())).DefaultTimeout(); + await writeAsyncTask.DefaultTimeout(); + + foreach (var message in secondMessages) + { + Assert.IsTrue(secondMessages[0].Span.SequenceEqual(message.Span)); + } + + await call.RequestStream.CompleteAsync().DefaultTimeout(); + + var rs = await call.ResponseAsync.DefaultTimeout(); + Assert.AreEqual("Hello world", rs.Message); + Assert.AreEqual(StatusCode.OK, call.GetStatus().StatusCode); + } + + private class WriterTestStream : Stream + { + public TaskCompletionSource> WriteAsyncTcs = new TaskCompletionSource>(TaskCreationOptions.RunContinuationsAsynchronously); + + public override bool CanRead { get; } + public override bool CanSeek { get; } + public override bool CanWrite { get; } + public override long Length { get; } + public override long Position { get; set; } + + private SyncPoint _syncPoint; + private Func _awaiter; + private ReadOnlyMemory _currentWriteData; + + public WriterTestStream() + { + _awaiter = SyncPoint.Create(out _syncPoint, runContinuationsAsynchronously: true); + } + + public override void Flush() + { + } + + public override int Read(byte[] buffer, int offset, int count) + { + throw new NotImplementedException(); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotImplementedException(); + } + + public override void SetLength(long value) + { + throw new NotImplementedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotImplementedException(); + } + +#if NET472 + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) +#else + public override async ValueTask WriteAsync(ReadOnlyMemory data, CancellationToken cancellationToken = default) +#endif + { +#if NET472 + var data = buffer.AsMemory(offset, count); +#endif + _currentWriteData = data.ToArray(); + + await _awaiter(); + // Wait until data is read by WaitForDataAsync + //await _syncPoint.WaitForSyncPoint(); + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + return Task.CompletedTask; + } + + public async Task> WaitForDataAsync() + { + await _syncPoint.WaitForSyncPoint(); + + ResetSyncPointAndContinuePrevious(); + + //await _awaiter(); + return _currentWriteData; + } + + private void ResetSyncPointAndContinuePrevious() + { + // We have read all data + // Signal AddDataAndWait to continue + // Reset sync point for next read + var syncPoint = _syncPoint; + + ResetSyncPoint(); + + syncPoint.Continue(); + } + + private void ResetSyncPoint() + { + _awaiter = SyncPoint.Create(out _syncPoint, runContinuationsAsynchronously: true); + } + } + + [Test] + public async Task AsyncUnaryCall_ExceedAttempts_Failure() + { + // Arrange + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var requestMessages = new List(); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + // All calls are in-progress at once. + Interlocked.Increment(ref callCount); + if (callCount == 5) + { + tcs.TrySetResult(null); + } + await tcs.Task; + + var requestContent = await request.Content!.ReadAsStreamAsync(); + var requestMessage = await ReadRequestMessage(requestContent); + lock (requestMessages) + { + requestMessages.Add(requestMessage!); + } + + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + + // Assert + Assert.AreEqual(5, callCount); + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + Assert.AreEqual(5, requestMessages.Count); + foreach (var requestMessage in requestMessages) + { + Assert.AreEqual("World", requestMessage.Name); + } + } + + [Test] + public async Task AsyncUnaryCall_ExceedDeadlineWithActiveCalls_Failure() + { + // Arrange + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(request => + { + Interlocked.Increment(ref callCount); + return tcs.Task; + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(hedgingDelay: TimeSpan.FromMilliseconds(200)); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(deadline: DateTime.UtcNow.AddMilliseconds(100)), new HelloRequest { Name = "World" }); + + // Assert + Assert.AreEqual(1, callCount); + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.DeadlineExceeded, ex.StatusCode); + Assert.AreEqual(StatusCode.DeadlineExceeded, call.GetStatus().StatusCode); + } + + [Test] + public async Task AsyncUnaryCall_ManyAttemptsNoDelay_MarshallerCalledOnce() + { + // Arrange + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + Interlocked.Increment(ref callCount); + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + var marshallerCount = 0; + var requestMarshaller = Marshallers.Create( + r => + { + Interlocked.Increment(ref marshallerCount); + return r.ToByteArray(); + }, + data => HelloRequest.Parser.ParseFrom(data)); + var method = ClientTestHelpers.GetServiceMethod(requestMarshaller: requestMarshaller); + + // Act + var call = invoker.AsyncUnaryCall(method, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + Assert.AreEqual(5, callCount); + Assert.AreEqual(1, marshallerCount); + } + + [Test] + public async Task AsyncUnaryCall_ExceedAttempts_HedgeDelay_Failure() + { + // Arrange + var stopwatch = new Stopwatch(); + var callIntervals = new List(); + var hedgeDelay = TimeSpan.FromMilliseconds(100); + const int timerResolutionMs = 15 * 2; // Timer has a precision of about 15ms. Double it, just to be safe + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callIntervals.Add(stopwatch.ElapsedMilliseconds); + stopwatch.Restart(); + Interlocked.Increment(ref callCount); + + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: 2, hedgingDelay: hedgeDelay); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + stopwatch.Start(); + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(2, callCount); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + + // First call should happen immediately + Assert.LessOrEqual(callIntervals[0], hedgeDelay.TotalMilliseconds); + + // Second call should happen after delay + Console.WriteLine(callIntervals[0]); + Console.WriteLine(callIntervals[1]); + Assert.GreaterOrEqual(callIntervals[1], hedgeDelay.TotalMilliseconds - timerResolutionMs); + } + + [Test] + public async Task AsyncUnaryCall_PushbackDelay_PushbackDelayUpdatesNextCallDelay() + { + // Arrange + var stopwatch = new Stopwatch(); + var callIntervals = new List(); + var hedgingDelay = TimeSpan.FromSeconds(10); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callIntervals.Add(stopwatch.ElapsedMilliseconds); + stopwatch.Restart(); + Interlocked.Increment(ref callCount); + + await request.Content!.CopyToAsync(new MemoryStream()); + string? hedgingPushback = null; + if (callCount == 1) + { + hedgingPushback = "0"; + } + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable, retryPushbackHeader: hedgingPushback); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: 5, hedgingDelay: hedgingDelay); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + stopwatch.Start(); + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + + // Assert + await TestHelpers.AssertIsTrueRetryAsync(() => callIntervals.Count == 2, "Only two calls should be made.").DefaultTimeout(); + + // First call should happen immediately + Assert.LessOrEqual(callIntervals[0], 100); + + // Second call should happen after delay + Assert.LessOrEqual(callIntervals[1], hedgingDelay.TotalMilliseconds); + } + + [Test] + public async Task AsyncUnaryCall_FatalStatusCode_HedgeDelay_Failure() + { + // Arrange + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + Interlocked.Increment(ref callCount); + + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, (callCount == 1) ? StatusCode.Unavailable : StatusCode.InvalidArgument); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(hedgingDelay: TimeSpan.FromMilliseconds(50)); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.InvalidArgument, ex.StatusCode); + Assert.AreEqual(StatusCode.InvalidArgument, call.GetStatus().StatusCode); + Assert.AreEqual(2, callCount); + } + + [Test] + public async Task AsyncServerStreamingCall_SuccessAfterRetry_RequestContentSent() + { + // Arrange + var syncPoint = new SyncPoint(runContinuationsAsynchronously: true); + MemoryStream? requestContent = null; + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + Interlocked.Increment(ref callCount); + + var s = await request.Content!.ReadAsStreamAsync(); + var ms = new MemoryStream(); + await s.CopyToAsync(ms); + + if (callCount == 1) + { + await syncPoint.WaitForSyncPoint(); + + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + } + + syncPoint.Continue(); + + requestContent = ms; + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: 2, hedgingDelay: TimeSpan.FromMilliseconds(50)); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncServerStreamingCall(ClientTestHelpers.GetServiceMethod(MethodType.ServerStreaming), string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + var moveNextTask = call.ResponseStream.MoveNext(CancellationToken.None); + + // Wait until the first call has failed and the second is on the server + await syncPoint.WaitToContinue().DefaultTimeout(); + + // Assert + Assert.IsTrue(await moveNextTask); + Assert.AreEqual("Hello world", call.ResponseStream.Current.Message); + + requestContent!.Seek(0, SeekOrigin.Begin); + var requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout(); + Assert.AreEqual("World", requestMessage!.Name); + } + + [TestCase(0)] + [TestCase(1)] + [TestCase(100)] + public async Task AsyncClientStreamingCall_SuccessAfterRetry_RequestContentSent(int hedgingDelayMS) + { + // Arrange + var callLock = new object(); + var requestContent = new MemoryStream(); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + var firstCall = false; + lock (callLock) + { + callCount++; + if (callCount == 1) + { + firstCall = true; + } + } + if (firstCall) + { + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + } + + var content = (PushStreamContent)request.Content!; + await content.PushComplete.DefaultTimeout(); + + await request.Content!.CopyToAsync(requestContent); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(maxAttempts: 2, hedgingDelay: TimeSpan.FromMilliseconds(hedgingDelayMS)); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncClientStreamingCall(ClientTestHelpers.GetServiceMethod(MethodType.ClientStreaming), string.Empty, new CallOptions()); + + // Assert + Assert.IsNotNull(call); + + var responseTask = call.ResponseAsync; + Assert.IsFalse(responseTask.IsCompleted, "Response not returned until client stream is complete."); + + + await call.RequestStream.WriteAsync(new HelloRequest { Name = "1" }).DefaultTimeout(); + await call.RequestStream.WriteAsync(new HelloRequest { Name = "2" }).DefaultTimeout(); + + await call.RequestStream.CompleteAsync().DefaultTimeout(); + + var responseMessage = await responseTask.DefaultTimeout(); + Assert.AreEqual("Hello world", responseMessage.Message); + + requestContent.Seek(0, SeekOrigin.Begin); + + var requests = new List(); + while (true) + { + var requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout(); + if (requestMessage == null) + { + break; + } + + requests.Add(requestMessage); + } + + Assert.AreEqual(2, requests.Count); + Assert.AreEqual("1", requests[0].Name); + Assert.AreEqual("2", requests[1].Name); + } + + [Test] + public async Task AsyncClientStreamingCall_CompleteAndWriteAfterResult_Error() + { + // Arrange + var requestContent = new MemoryStream(); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + Interlocked.Increment(ref callCount); + + _ = request.Content!.ReadAsStreamAsync(); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncClientStreamingCall(ClientTestHelpers.GetServiceMethod(MethodType.ClientStreaming), string.Empty, new CallOptions()); + + // Assert + var responseMessage = await call.ResponseAsync.DefaultTimeout(); + Assert.AreEqual("Hello world", responseMessage.Message); + + requestContent.Seek(0, SeekOrigin.Begin); + + await call.RequestStream.CompleteAsync().DefaultTimeout(); + + var ex = await ExceptionAssert.ThrowsAsync(() => call.RequestStream.WriteAsync(new HelloRequest { Name = "1" })).DefaultTimeout(); + Assert.AreEqual("Request stream has already been completed.", ex.Message); + } + + [Test] + public async Task AsyncClientStreamingCall_WriteAfterResult_Error() + { + // Arrange + var requestContent = new MemoryStream(); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + Interlocked.Increment(ref callCount); + + _ = request.Content!.ReadAsStreamAsync(); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncClientStreamingCall(ClientTestHelpers.GetServiceMethod(MethodType.ClientStreaming), string.Empty, new CallOptions()); + + // Assert + var responseMessage = await call.ResponseAsync.DefaultTimeout(); + Assert.AreEqual("Hello world", responseMessage.Message); + + var ex = await ExceptionAssert.ThrowsAsync(() => call.RequestStream.WriteAsync(new HelloRequest { Name = "1" })).DefaultTimeout(); + Assert.AreEqual(StatusCode.OK, ex.StatusCode); + } + + private static Task ReadRequestMessage(Stream requestContent) + { + return StreamSerializationHelper.ReadMessageAsync( + requestContent, + ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer, + GrpcProtocolConstants.IdentityGrpcEncoding, + maximumMessageSize: null, + GrpcProtocolConstants.DefaultCompressionProviders, + singleMessage: false, + CancellationToken.None); + } + } +} diff --git a/test/Grpc.Net.Client.Tests/Retry/RetryTests.cs b/test/Grpc.Net.Client.Tests/Retry/RetryTests.cs new file mode 100644 index 000000000..6eea619d7 --- /dev/null +++ b/test/Grpc.Net.Client.Tests/Retry/RetryTests.cs @@ -0,0 +1,781 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Greet; +using Grpc.Core; +using Grpc.Net.Client.Internal; +using Grpc.Net.Client.Internal.Http; +using Grpc.Net.Client.Tests.Infrastructure; +using Grpc.Tests.Shared; +using Microsoft.Extensions.Logging; +using NUnit.Framework; + +namespace Grpc.Net.Client.Tests.Retry +{ + [TestFixture] + public class RetryTests + { + [Test] + public async Task AsyncUnaryCall_SuccessAfterRetry_RequestContentSent() + { + // Arrange + HttpContent? content = null; + + bool? firstRequestPreviousAttemptsHeader = null; + string? secondRequestPreviousAttemptsHeaderValue = null; + var requestContent = new MemoryStream(); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + + content = request.Content!; + await content.CopyToAsync(requestContent); + requestContent.Seek(0, SeekOrigin.Begin); + + if (callCount == 1) + { + firstRequestPreviousAttemptsHeader = request.Headers.TryGetValues(GrpcProtocolConstants.RetryPreviousAttemptsHeader, out _); + + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + } + + if (request.Headers.TryGetValues(GrpcProtocolConstants.RetryPreviousAttemptsHeader, out var retryAttemptCountValue)) + { + secondRequestPreviousAttemptsHeaderValue = retryAttemptCountValue.Single(); + } + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent, customTrailers: new Dictionary + { + ["custom-trailer"] = "Value!" + }); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + + // Assert + Assert.AreEqual(2, callCount); + Assert.AreEqual("Hello world", (await call.ResponseAsync.DefaultTimeout()).Message); + Assert.AreEqual("1", (await call.ResponseHeadersAsync.DefaultTimeout()).GetValue(GrpcProtocolConstants.RetryPreviousAttemptsHeader)); + + Assert.IsNotNull(content); + + var requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout(); + + Assert.AreEqual("World", requestMessage!.Name); + + Assert.IsFalse(firstRequestPreviousAttemptsHeader); + Assert.AreEqual("1", secondRequestPreviousAttemptsHeaderValue); + + var trailers = call.GetTrailers(); + Assert.AreEqual("Value!", trailers.GetValue("custom-trailer")); + } + + [Test] + public async Task AsyncUnaryCall_SuccessAfterRetry_AccessResponseHeaders_SuccessfullyResponseHeadersReturned() + { + // Arrange + HttpContent? content = null; + var syncPoint = new SyncPoint(runContinuationsAsynchronously: true); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + content = request.Content!; + + if (callCount == 1) + { + await content.CopyToAsync(new MemoryStream()); + + await syncPoint.WaitForSyncPoint(); + + return ResponseUtils.CreateHeadersOnlyResponse( + HttpStatusCode.OK, + StatusCode.Unavailable, + customHeaders: new Dictionary { ["call-count"] = callCount.ToString() }); + } + + syncPoint.Continue(); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse( + HttpStatusCode.OK, + streamContent, + customHeaders: new Dictionary { ["call-count"] = callCount.ToString() }); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + var headersTask = call.ResponseHeadersAsync; + + // Wait until the first call has failed and the second is on the server + await syncPoint.WaitToContinue().DefaultTimeout(); + + // Assert + Assert.AreEqual(2, callCount); + Assert.AreEqual("Hello world", (await call.ResponseAsync.DefaultTimeout()).Message); + + var headers = await headersTask.DefaultTimeout(); + Assert.AreEqual("2", headers.GetValue("call-count")); + } + + [Test] + public async Task AsyncUnaryCall_ExceedRetryAttempts_Failure() + { + // Arrange + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(maxAttempts: 3); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + + // Assert + Assert.AreEqual(3, callCount); + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + } + + [Test] + public async Task AsyncUnaryCall_FailureWithLongDelay_Dispose_CallImmediatelyDisposed() + { + // Arrange + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + }); + // Very long delay + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(initialBackoff: TimeSpan.FromSeconds(30), maxBackoff: TimeSpan.FromSeconds(30)); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + var resultTask = call.ResponseAsync; + + // Test will timeout if dispose doesn't kill the timer. + call.Dispose(); + + // Assert + var ex = await ExceptionAssert.ThrowsAsync(() => resultTask).DefaultTimeout(); + Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode); + Assert.AreEqual("gRPC call disposed.", ex.Status.Detail); + } + + [TestCase("")] + [TestCase("-1")] + [TestCase("stop")] + public async Task AsyncUnaryCall_PushbackStop_Failure(string header) + { + // Arrange + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateResponse(HttpStatusCode.OK, new StringContent(""), StatusCode.Unavailable, retryPushbackHeader: header); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + + // Assert + Assert.AreEqual(1, callCount); + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + Assert.AreEqual(StatusCode.Unavailable, call.GetStatus().StatusCode); + } + + [Test] + public async Task AsyncUnaryCall_PushbackExpicitDelay_DelayForSpecifiedDuration() + { + // Arrange + Task? delayTask = null; + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + if (callCount == 1) + { + await request.Content!.CopyToAsync(new MemoryStream()); + delayTask = Task.Delay(100); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable, retryPushbackHeader: "200"); + } + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(backoffMultiplier: 1); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + + // Delay of 100ms will finish before second record which has a pushback delay of 200ms + var completedTask = await Task.WhenAny(call.ResponseAsync, delayTask!).DefaultTimeout(); + var rs = await call.ResponseAsync.DefaultTimeout(); + + // Assert + Assert.AreEqual(delayTask, completedTask); // Response task should finish after + Assert.AreEqual(2, callCount); + Assert.AreEqual("Hello world", rs.Message); + } + + [Test] + public async Task AsyncUnaryCall_CancellationDuringBackoff_CanceledStatus() + { + // Arrange + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable, retryPushbackHeader: TimeSpan.FromSeconds(10).TotalMilliseconds.ToString()); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + var cts = new CancellationTokenSource(); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(cancellationToken: cts.Token), new HelloRequest { Name = "World" }); + + var delayTask = Task.Delay(100); + var completedTask = await Task.WhenAny(call.ResponseAsync, delayTask); + + // Assert + Assert.AreEqual(delayTask, completedTask); // Ensure that we're waiting for retry + + cts.Cancel(); + + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode); + Assert.AreEqual("Call canceled by the client.", ex.Status.Detail); + } + + [Test] + public async Task AsyncUnaryCall_DisposeDuringBackoff_CanceledStatus() + { + // Arrange + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable, retryPushbackHeader: TimeSpan.FromSeconds(10).TotalMilliseconds.ToString()); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + var cts = new CancellationTokenSource(); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(cancellationToken: cts.Token), new HelloRequest { Name = "World" }); + + var delayTask = Task.Delay(100); + var completedTask = await Task.WhenAny(call.ResponseAsync, delayTask); + + // Assert + Assert.AreEqual(delayTask, completedTask); // Ensure that we're waiting for retry + + call.Dispose(); + + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode); + Assert.AreEqual("gRPC call disposed.", ex.Status.Detail); + } + + [Test] + public async Task AsyncUnaryCall_PushbackExplicitDelayExceedAttempts_Failure() + { + // Arrange + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable, retryPushbackHeader: "0"); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(maxAttempts: 5); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + + // Assert + Assert.AreEqual(5, callCount); + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode); + } + + [Test] + public async Task AsyncUnaryCall_UnsupportedStatusCode_Failure() + { + // Arrange + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + await request.Content!.CopyToAsync(new MemoryStream()); + return ResponseUtils.CreateResponse(HttpStatusCode.OK, new StringContent(""), StatusCode.InvalidArgument); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + + // Assert + Assert.AreEqual(1, callCount); + var ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.InvalidArgument, ex.StatusCode); + Assert.AreEqual(StatusCode.InvalidArgument, call.GetStatus().StatusCode); + } + + [Test] + public async Task AsyncUnaryCall_Success_RequestContentSent() + { + // Arrange + HttpContent? content = null; + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + content = request.Content; + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncUnaryCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + + // Assert + Assert.AreEqual(1, callCount); + Assert.AreEqual("Hello world", (await call.ResponseAsync.DefaultTimeout()).Message); + } + + [Test] + public async Task AsyncClientStreamingCall_SuccessAfterRetry_RequestContentSent() + { + // Arrange + var requestContent = new MemoryStream(); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + Interlocked.Increment(ref callCount); + + var currentContent = new MemoryStream(); + await request.Content!.CopyToAsync(currentContent); + + if (callCount == 1) + { + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + } + + currentContent.Seek(0, SeekOrigin.Begin); + await currentContent.CopyToAsync(requestContent); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncClientStreamingCall(ClientTestHelpers.GetServiceMethod(MethodType.ClientStreaming), string.Empty, new CallOptions()); + + // Assert + Assert.IsNotNull(call); + + var responseTask = call.ResponseAsync; + Assert.IsFalse(responseTask.IsCompleted, "Response not returned until client stream is complete."); + + await call.RequestStream.WriteAsync(new HelloRequest { Name = "1" }).DefaultTimeout(); + await call.RequestStream.WriteAsync(new HelloRequest { Name = "2" }).DefaultTimeout(); + + await call.RequestStream.CompleteAsync().DefaultTimeout(); + + var responseMessage = await responseTask.DefaultTimeout(); + Assert.AreEqual("Hello world", responseMessage.Message); + + requestContent.Seek(0, SeekOrigin.Begin); + + var requests = new List(); + while (true) + { + var requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout(); + if (requestMessage == null) + { + break; + } + + requests.Add(requestMessage); + } + + Assert.AreEqual(2, requests.Count); + Assert.AreEqual("1", requests[0].Name); + Assert.AreEqual("2", requests[1].Name); + + call.Dispose(); + } + + [Test] + public async Task AsyncClientStreamingCall_CompleteAndWriteAfterResult_Error() + { + // Arrange + var requestContent = new MemoryStream(); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + Interlocked.Increment(ref callCount); + + _ = request.Content!.ReadAsStreamAsync(); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncClientStreamingCall(ClientTestHelpers.GetServiceMethod(MethodType.ClientStreaming), string.Empty, new CallOptions()); + + // Assert + var responseMessage = await call.ResponseAsync.DefaultTimeout(); + Assert.AreEqual("Hello world", responseMessage.Message); + + requestContent.Seek(0, SeekOrigin.Begin); + + await call.RequestStream.CompleteAsync().DefaultTimeout(); + + var ex = await ExceptionAssert.ThrowsAsync(() => call.RequestStream.WriteAsync(new HelloRequest { Name = "1" })).DefaultTimeout(); + Assert.AreEqual("Request stream has already been completed.", ex.Message); + } + + [Test] + public async Task AsyncClientStreamingCall_WriteAfterResult_Error() + { + // Arrange + var requestContent = new MemoryStream(); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + Interlocked.Increment(ref callCount); + + _ = request.Content!.ReadAsStreamAsync(); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncClientStreamingCall(ClientTestHelpers.GetServiceMethod(MethodType.ClientStreaming), string.Empty, new CallOptions()); + + // Assert + var responseMessage = await call.ResponseAsync.DefaultTimeout(); + Assert.AreEqual("Hello world", responseMessage.Message); + + var ex = await ExceptionAssert.ThrowsAsync(() => call.RequestStream.WriteAsync(new HelloRequest { Name = "1" })).DefaultTimeout(); + Assert.AreEqual(StatusCode.OK, ex.StatusCode); + } + + [Test] + public async Task AsyncClientStreamingCall_OneMessageSentThenRetryThenAnotherMessage_RequestContentSent() + { + // Arrange + var requestContent = new MemoryStream(); + var syncPoint = new SyncPoint(runContinuationsAsynchronously: true); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + var content = (PushStreamContent)request.Content!; + + if (callCount == 1) + { + _ = content.CopyToAsync(new MemoryStream()); + + await syncPoint.WaitForSyncPoint(); + + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + } + + syncPoint.Continue(); + + await content.PushComplete.DefaultTimeout(); + await content.CopyToAsync(requestContent); + requestContent.Seek(0, SeekOrigin.Begin); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncClientStreamingCall(ClientTestHelpers.GetServiceMethod(MethodType.ClientStreaming), string.Empty, new CallOptions()); + + // Assert + Assert.IsNotNull(call); + + var responseTask = call.ResponseAsync; + Assert.IsFalse(responseTask.IsCompleted, "Response not returned until client stream is complete."); + + await call.RequestStream.WriteAsync(new HelloRequest { Name = "1" }).DefaultTimeout(); + + // Wait until the first call has failed and the second is on the server + await syncPoint.WaitToContinue().DefaultTimeout(); + + await call.RequestStream.WriteAsync(new HelloRequest { Name = "2" }).DefaultTimeout(); + + await call.RequestStream.CompleteAsync().DefaultTimeout(); + + var responseMessage = await responseTask.DefaultTimeout(); + Assert.AreEqual("Hello world", responseMessage.Message); + + var requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout(); + Assert.AreEqual("1", requestMessage!.Name); + requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout(); + Assert.AreEqual("2", requestMessage!.Name); + requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout(); + Assert.IsNull(requestMessage); + } + + [Test] + public async Task AsyncServerStreamingCall_SuccessAfterRetry_RequestContentSent() + { + // Arrange + var syncPoint = new SyncPoint(runContinuationsAsynchronously: true); + var requestContent = new MemoryStream(); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + + var content = request.Content!; + await content.CopyToAsync(requestContent); + requestContent.Seek(0, SeekOrigin.Begin); + + if (callCount == 1) + { + await syncPoint.WaitForSyncPoint(); + + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + } + + syncPoint.Continue(); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncServerStreamingCall(ClientTestHelpers.GetServiceMethod(MethodType.ServerStreaming), string.Empty, new CallOptions(), new HelloRequest { Name = "World" }); + var moveNextTask = call.ResponseStream.MoveNext(CancellationToken.None); + + // Wait until the first call has failed and the second is on the server + await syncPoint.WaitToContinue().DefaultTimeout(); + + // Assert + Assert.IsTrue(await moveNextTask); + Assert.AreEqual("Hello world", call.ResponseStream.Current.Message); + + var requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout(); + Assert.AreEqual("World", requestMessage!.Name); + requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout(); + Assert.IsNull(requestMessage); + } + + [Test] + public async Task AsyncServerStreamingCall_FailureAfterReadingResponseMessage_Failure() + { + // Arrange + var streamContent = new SyncPointMemoryStream(); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(request => + { + callCount++; + return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.OK, new StreamContent(streamContent))); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncServerStreamingCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest()); + + var responseStream = call.ResponseStream; + + // Assert + Assert.IsNull(responseStream.Current); + + var moveNextTask1 = responseStream.MoveNext(CancellationToken.None); + Assert.IsFalse(moveNextTask1.IsCompleted); + + await streamContent.AddDataAndWait(await ClientTestHelpers.GetResponseDataAsync(new HelloReply + { + Message = "Hello world 1" + }).DefaultTimeout()).DefaultTimeout(); + + Assert.IsTrue(await moveNextTask1.DefaultTimeout()); + Assert.IsNotNull(responseStream.Current); + Assert.AreEqual("Hello world 1", responseStream.Current.Message); + + var moveNextTask2 = responseStream.MoveNext(CancellationToken.None); + Assert.IsFalse(moveNextTask2.IsCompleted); + + await streamContent.AddExceptionAndWait(new Exception("Exception!")).DefaultTimeout(); + + var ex = await ExceptionAssert.ThrowsAsync(() => moveNextTask2).DefaultTimeout(); + Assert.AreEqual(StatusCode.Internal, ex.StatusCode); + Assert.AreEqual(StatusCode.Internal, call.GetStatus().StatusCode); + Assert.AreEqual("Error reading next message. Exception: Exception!", call.GetStatus().Detail); + } + + [Test] + public async Task AsyncDuplexStreamingCall_SuccessAfterRetry_RequestContentSent() + { + // Arrange + var requestContent = new MemoryStream(); + var syncPoint = new SyncPoint(runContinuationsAsynchronously: true); + + var callCount = 0; + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + callCount++; + var content = (PushStreamContent)request.Content!; + + if (callCount == 1) + { + _ = content.CopyToAsync(new MemoryStream()); + + await syncPoint.WaitForSyncPoint(); + + return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable); + } + + syncPoint.Continue(); + + await content.PushComplete.DefaultTimeout(); + await content.CopyToAsync(requestContent); + requestContent.Seek(0, SeekOrigin.Begin); + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + + // Act + var call = invoker.AsyncDuplexStreamingCall(ClientTestHelpers.GetServiceMethod(MethodType.DuplexStreaming), string.Empty, new CallOptions()); + var moveNextTask = call.ResponseStream.MoveNext(CancellationToken.None); + + await call.RequestStream.WriteAsync(new HelloRequest { Name = "1" }).DefaultTimeout(); + + // Wait until the first call has failed and the second is on the server + await syncPoint.WaitToContinue().DefaultTimeout(); + + await call.RequestStream.WriteAsync(new HelloRequest { Name = "2" }).DefaultTimeout(); + + await call.RequestStream.CompleteAsync().DefaultTimeout(); + + // Assert + Assert.IsTrue(await moveNextTask.DefaultTimeout()); + Assert.AreEqual("Hello world", call.ResponseStream.Current.Message); + + var requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout(); + Assert.AreEqual("1", requestMessage!.Name); + requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout(); + Assert.AreEqual("2", requestMessage!.Name); + requestMessage = await ReadRequestMessage(requestContent).DefaultTimeout(); + Assert.IsNull(requestMessage); + } + + private static Task ReadRequestMessage(Stream requestContent) + { + return StreamSerializationHelper.ReadMessageAsync( + requestContent, + ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer, + GrpcProtocolConstants.IdentityGrpcEncoding, + maximumMessageSize: null, + GrpcProtocolConstants.DefaultCompressionProviders, + singleMessage: false, + CancellationToken.None); + } + } +} diff --git a/test/Grpc.Net.Client.Tests/ServiceConfigTests.cs b/test/Grpc.Net.Client.Tests/ServiceConfigTests.cs new file mode 100644 index 000000000..bf5317a03 --- /dev/null +++ b/test/Grpc.Net.Client.Tests/ServiceConfigTests.cs @@ -0,0 +1,129 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using Grpc.Core; +using Grpc.Net.Client.Configuration; +using NUnit.Framework; + +namespace Grpc.Net.Client.Tests +{ + [TestFixture] + public class ServiceConfigTests + { + [Test] + public void ServiceConfig_CreateUnderlyingConfig() + { + // Arrange & Act + var serviceConfig = new ServiceConfig + { + MethodConfigs = + { + new MethodConfig + { + Names = { new MethodName() }, + RetryPolicy = new RetryPolicy + { + MaxAttempts = 5, + InitialBackoff = TimeSpan.FromSeconds(1), + RetryableStatusCodes = { StatusCode.Unavailable, StatusCode.Aborted } + } + } + } + }; + + // Assert + Assert.AreEqual(1, serviceConfig.MethodConfigs.Count); + Assert.AreEqual(1, serviceConfig.MethodConfigs[0].Names.Count); + Assert.AreEqual(5, serviceConfig.MethodConfigs[0].RetryPolicy!.MaxAttempts); + Assert.AreEqual(TimeSpan.FromSeconds(1), serviceConfig.MethodConfigs[0].RetryPolicy!.InitialBackoff); + Assert.AreEqual(StatusCode.Unavailable, serviceConfig.MethodConfigs[0].RetryPolicy!.RetryableStatusCodes[0]); + Assert.AreEqual(StatusCode.Aborted, serviceConfig.MethodConfigs[0].RetryPolicy!.RetryableStatusCodes[1]); + + var inner = serviceConfig.Inner; + var methodConfigs = (IList)inner["method_config"]; + var allServices = (IDictionary)methodConfigs[0]; + + Assert.AreEqual(5, (int)((IDictionary)allServices["retryPolicy"])["maxAttempts"]); + Assert.AreEqual("1s", (string)((IDictionary)allServices["retryPolicy"])["initialBackoff"]); + Assert.AreEqual("UNAVAILABLE", (string)((IList)((IDictionary)allServices["retryPolicy"])["retryableStatusCodes"])[0]); + Assert.AreEqual("ABORTED", (string)((IList)((IDictionary)allServices["retryPolicy"])["retryableStatusCodes"])[1]); + } + + [Test] + public void ServiceConfig_ReadUnderlyingConfig() + { + // Arrange + var inner = new Dictionary + { + ["method_config"] = new List + { + new Dictionary + { + ["name"] = new List { new Dictionary() }, + ["retryPolicy"] = new Dictionary + { + ["maxAttempts"] = 5, + ["initialBackoff"] = "1s", + ["retryableStatusCodes"] = new List { "UNAVAILABLE", "ABORTED" } + } + } + } + }; + + // Act + var serviceConfig = new ServiceConfig(inner); + + // Assert + Assert.AreEqual(1, serviceConfig.MethodConfigs.Count); + Assert.AreEqual(1, serviceConfig.MethodConfigs[0].Names.Count); + Assert.AreEqual(5, serviceConfig.MethodConfigs[0].RetryPolicy!.MaxAttempts); + Assert.AreEqual(TimeSpan.FromSeconds(1), serviceConfig.MethodConfigs[0].RetryPolicy!.InitialBackoff); + Assert.AreEqual(StatusCode.Unavailable, serviceConfig.MethodConfigs[0].RetryPolicy!.RetryableStatusCodes[0]); + Assert.AreEqual(StatusCode.Aborted, serviceConfig.MethodConfigs[0].RetryPolicy!.RetryableStatusCodes[1]); + } + + [Test] + public void RetryThrottlingPolicy_ReadUnderlyingConfig_Success() + { + // Arrange + var inner = new Dictionary + { + ["initialBackoff"] = "1.1s", + ["retryableStatusCodes"] = new List { "UNAVAILABLE", "Aborted", 1 } + }; + + // Act + var retryPolicy = new RetryPolicy(inner); + + // Assert + Assert.AreEqual(TimeSpan.FromSeconds(1.1), retryPolicy.InitialBackoff); + Assert.AreEqual(StatusCode.Unavailable, retryPolicy.RetryableStatusCodes[0]); + Assert.AreEqual(StatusCode.Aborted, retryPolicy.RetryableStatusCodes[1]); + Assert.AreEqual(StatusCode.Cancelled, retryPolicy.RetryableStatusCodes[2]); + } + + [Test] + public void MethodName_Default_ErrorOnChange() + { + // Arrange & Act & Assert + Assert.Throws(() => MethodName.Default.Method = "This will break"); + } + } +} diff --git a/test/Grpc.Net.Client.Web.Tests/Base64ResponseStreamTests.cs b/test/Grpc.Net.Client.Web.Tests/Base64ResponseStreamTests.cs index e8a56e658..0e87b91f1 100644 --- a/test/Grpc.Net.Client.Web.Tests/Base64ResponseStreamTests.cs +++ b/test/Grpc.Net.Client.Web.Tests/Base64ResponseStreamTests.cs @@ -48,7 +48,7 @@ public async Task ReadAsync_ReadLargeData_Success() var messageCount = 3; var streamContent = new List(); - for (int i = 0; i < messageCount; i++) + for (var i = 0; i < messageCount; i++) { streamContent.AddRange(messageContent); } @@ -56,7 +56,7 @@ public async Task ReadAsync_ReadLargeData_Success() var ms = new LimitedReadMemoryStream(streamContent.ToArray(), 3); var base64Stream = new Base64ResponseStream(ms); - for (int i = 0; i < messageCount; i++) + for (var i = 0; i < messageCount; i++) { // Assert 1 var resolvedHeaderData = await ReadContent(base64Stream, 5, CancellationToken.None); diff --git a/test/Shared/ClientTestHelpers.cs b/test/Shared/ClientTestHelpers.cs index 89f5fb182..81c904d18 100644 --- a/test/Shared/ClientTestHelpers.cs +++ b/test/Shared/ClientTestHelpers.cs @@ -17,6 +17,7 @@ #endregion using System; +using System.Collections.Generic; using System.IO; using System.Net; using System.Net.Http; @@ -26,6 +27,7 @@ using Google.Protobuf; using Greet; using Grpc.Core; +using Grpc.Net.Client.Configuration; using Grpc.Net.Compression; namespace Grpc.Tests.Shared @@ -35,7 +37,12 @@ internal static class ClientTestHelpers public static readonly Marshaller HelloRequestMarshaller = Marshallers.Create(r => r.ToByteArray(), data => HelloRequest.Parser.ParseFrom(data)); public static readonly Marshaller HelloReplyMarshaller = Marshallers.Create(r => r.ToByteArray(), data => HelloReply.Parser.ParseFrom(data)); - public static readonly Method ServiceMethod = new Method(MethodType.Unary, "ServiceName", "MethodName", HelloRequestMarshaller, HelloReplyMarshaller); + public static readonly Method ServiceMethod = GetServiceMethod(MethodType.Unary); + + public static Method GetServiceMethod(MethodType? methodType = null, Marshaller? requestMarshaller = null) + { + return new Method(methodType ?? MethodType.Unary, "ServiceName", "MethodName", requestMarshaller ?? HelloRequestMarshaller, HelloReplyMarshaller); + } public static TestHttpMessageHandler CreateTestMessageHandler(HelloReply reply) { diff --git a/test/Shared/ExceptionAssert.cs b/test/Shared/ExceptionAssert.cs index 0317c2908..0a0684e12 100644 --- a/test/Shared/ExceptionAssert.cs +++ b/test/Shared/ExceptionAssert.cs @@ -27,6 +27,11 @@ public static class ExceptionAssert public static async Task ThrowsAsync(Func action, params string[] possibleMessages) where TException : Exception { + if (action == null) + { + throw new ArgumentNullException(nameof(action)); + } + try { await action(); diff --git a/test/Shared/ServiceConfigHelpers.cs b/test/Shared/ServiceConfigHelpers.cs new file mode 100644 index 000000000..713b63c54 --- /dev/null +++ b/test/Shared/ServiceConfigHelpers.cs @@ -0,0 +1,108 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using Grpc.Core; +using Grpc.Net.Client.Configuration; + +namespace Grpc.Tests.Shared +{ + internal static class ServiceConfigHelpers + { + public static ServiceConfig CreateRetryServiceConfig( + int? maxAttempts = null, + TimeSpan? initialBackoff = null, + TimeSpan? maxBackoff = null, + double? backoffMultiplier = null, + IList? retryableStatusCodes = null, + RetryThrottlingPolicy? retryThrottling = null) + { + var retryPolicy = new RetryPolicy + { + MaxAttempts = maxAttempts ?? 5, + InitialBackoff = initialBackoff ?? TimeSpan.Zero, + MaxBackoff = maxBackoff ?? TimeSpan.Zero, + BackoffMultiplier = backoffMultiplier ?? 1 + }; + + if (retryableStatusCodes != null) + { + foreach (var statusCode in retryableStatusCodes) + { + retryPolicy.RetryableStatusCodes.Add(statusCode); + } + } + else + { + retryPolicy.RetryableStatusCodes.Add(StatusCode.Unavailable); + } + + return new ServiceConfig + { + MethodConfigs = + { + new MethodConfig + { + Names = { MethodName.Default }, + RetryPolicy = retryPolicy + } + }, + RetryThrottling = retryThrottling + }; + } + + public static ServiceConfig CreateHedgingServiceConfig( + int? maxAttempts = null, + TimeSpan? hedgingDelay = null, + IList? nonFatalStatusCodes = null, + RetryThrottlingPolicy? retryThrottling = null) + { + var hedgingPolicy = new HedgingPolicy + { + MaxAttempts = maxAttempts ?? 5, + HedgingDelay = hedgingDelay ?? TimeSpan.Zero + }; + + if (nonFatalStatusCodes != null) + { + foreach (var statusCode in nonFatalStatusCodes) + { + hedgingPolicy.NonFatalStatusCodes.Add(statusCode); + } + } + else + { + hedgingPolicy.NonFatalStatusCodes.Add(StatusCode.Unavailable); + } + + return new ServiceConfig + { + MethodConfigs = + { + new MethodConfig + { + Names = { MethodName.Default }, + HedgingPolicy = hedgingPolicy + } + }, + RetryThrottling = retryThrottling + }; + } + } +} diff --git a/test/Shared/TestHelpers.cs b/test/Shared/TestHelpers.cs index a80f6bcbb..f49d086d3 100644 --- a/test/Shared/TestHelpers.cs +++ b/test/Shared/TestHelpers.cs @@ -35,7 +35,7 @@ public static async Task AssertIsTrueRetryAsync(Func assert, string messag { const int Retrys = 10; - for (int i = 0; i < Retrys; i++) + for (var i = 0; i < Retrys; i++) { if (i > 0) { @@ -54,7 +54,7 @@ public static async Task AssertIsTrueRetryAsync(Func assert, string messag public static async Task RunParallel(int count, Func action) { var actionTasks = new Task[count]; - for (int i = 0; i < actionTasks.Length; i++) + for (var i = 0; i < actionTasks.Length; i++) { actionTasks[i] = action(i); } diff --git a/testassets/InteropTestsWebsite/TestServiceImpl.cs b/testassets/InteropTestsWebsite/TestServiceImpl.cs index 86123d4b2..dbe7bddf0 100644 --- a/testassets/InteropTestsWebsite/TestServiceImpl.cs +++ b/testassets/InteropTestsWebsite/TestServiceImpl.cs @@ -71,6 +71,7 @@ await requestStream.ForEachAsync(request => sum += request.Payload.Body.Length; return Task.CompletedTask; }); + return new StreamingInputCallResponse { AggregatedPayloadSize = sum }; }