From fad333e5b0caaa88b4641758a12682a26d7a2278 Mon Sep 17 00:00:00 2001 From: James Newton-King Date: Tue, 27 Jul 2021 13:09:17 +1200 Subject: [PATCH] Fix using gRPC reflection with services that have a base type --- .../GrpcReflectionServiceExtensions.cs | 59 ++++++++++++- .../Proto/greet.proto | 5 ++ .../ReflectionGrpcServiceActivatorTests.cs | 82 +++++++++++++++++-- 3 files changed, 136 insertions(+), 10 deletions(-) diff --git a/src/Grpc.AspNetCore.Server.Reflection/GrpcReflectionServiceExtensions.cs b/src/Grpc.AspNetCore.Server.Reflection/GrpcReflectionServiceExtensions.cs index 54d95cc66..9bbd3df64 100644 --- a/src/Grpc.AspNetCore.Server.Reflection/GrpcReflectionServiceExtensions.cs +++ b/src/Grpc.AspNetCore.Server.Reflection/GrpcReflectionServiceExtensions.cs @@ -21,6 +21,7 @@ using System.Linq; using System.Reflection; using Grpc.AspNetCore.Server; +using Grpc.Core; using Grpc.Reflection; using Microsoft.AspNetCore.Routing; using Microsoft.Extensions.DependencyInjection.Extensions; @@ -63,10 +64,7 @@ public static IServiceCollection AddGrpcReflection(this IServiceCollection servi foreach (var serviceType in serviceTypes) { - var baseType = GetServiceBaseType(serviceType); - var definitionType = baseType?.DeclaringType; - - var descriptorPropertyInfo = definitionType?.GetProperty("Descriptor", BindingFlags.Public | BindingFlags.Static); + var descriptorPropertyInfo = GetDescriptorProperty(serviceType); if (descriptorPropertyInfo != null) { if (descriptorPropertyInfo.GetValue(null) is Google.Protobuf.Reflection.ServiceDescriptor serviceDescriptor) @@ -85,6 +83,59 @@ public static IServiceCollection AddGrpcReflection(this IServiceCollection servi return services; } + private static PropertyInfo? GetDescriptorProperty(Type serviceType) + { + // Prefer finding the descriptor property using attribute on the generated service + var descriptorPropertyInfo = GetDescriptorPropertyUsingAttribute(serviceType); + + if (descriptorPropertyInfo == null) + { + // Fallback to searching for descriptor property using known type hierarchy that Grpc.Tools generates + descriptorPropertyInfo = GetDescriptorPropertyFallback(serviceType); + } + + return descriptorPropertyInfo; + } + + private static PropertyInfo? GetDescriptorPropertyUsingAttribute(Type serviceType) + { + Type? currentServiceType = serviceType; + BindServiceMethodAttribute? bindServiceMethod; + do + { + // Search through base types for bind service attribute. + bindServiceMethod = currentServiceType.GetCustomAttribute(); + if (bindServiceMethod != null) + { + // Descriptor property will be public and static and return ServiceDescriptor. + return bindServiceMethod.BindType.GetProperty( + "Descriptor", + BindingFlags.Public | BindingFlags.Static, + binder: null, + typeof(Google.Protobuf.Reflection.ServiceDescriptor), + Type.EmptyTypes, + Array.Empty()); + } + } while ((currentServiceType = currentServiceType.BaseType) != null); + + return null; + } + + private static PropertyInfo? GetDescriptorPropertyFallback(Type serviceType) + { + // Search for the generated service base class + var baseType = GetServiceBaseType(serviceType); + var definitionType = baseType?.DeclaringType; + + return definitionType?.GetProperty( + "Descriptor", + BindingFlags.Public | BindingFlags.Static, + binder: null, + typeof(Google.Protobuf.Reflection.ServiceDescriptor), + Type.EmptyTypes, + Array.Empty()); + } + private static Type? GetServiceBaseType(Type serviceImplementation) { // TService is an implementation of the gRPC service. It ultimately derives from Foo.TServiceBase base class. diff --git a/test/Grpc.AspNetCore.Server.Tests/Proto/greet.proto b/test/Grpc.AspNetCore.Server.Tests/Proto/greet.proto index 7945286d3..d0d90ea47 100644 --- a/test/Grpc.AspNetCore.Server.Tests/Proto/greet.proto +++ b/test/Grpc.AspNetCore.Server.Tests/Proto/greet.proto @@ -26,6 +26,11 @@ service SecondGreeter { rpc SayHellos (HelloRequest) returns (stream HelloReply); } +service ThirdGreeterWithBaseType { + rpc SayHello (HelloRequest) returns (HelloReply); + rpc SayHellos (HelloRequest) returns (stream HelloReply); +} + message HelloRequest { string name = 1; } diff --git a/test/Grpc.AspNetCore.Server.Tests/Reflection/ReflectionGrpcServiceActivatorTests.cs b/test/Grpc.AspNetCore.Server.Tests/Reflection/ReflectionGrpcServiceActivatorTests.cs index a7d0f2a2d..127323417 100644 --- a/test/Grpc.AspNetCore.Server.Tests/Reflection/ReflectionGrpcServiceActivatorTests.cs +++ b/test/Grpc.AspNetCore.Server.Tests/Reflection/ReflectionGrpcServiceActivatorTests.cs @@ -41,6 +41,56 @@ public class ReflectionGrpcServiceActivatorTests { [Test] public async Task Create_ConfiguredGrpcEndpoint_EndpointReturnedFromReflectionService() + { + // Arrange and act + TestServerStreamWriter writer = await ConfigureReflectionServerAndCallAsync(builder => + { + builder.MapGrpcService(); + }); + + // Assert + Assert.AreEqual(1, writer.Responses.Count); + Assert.AreEqual(1, writer.Responses[0].ListServicesResponse.Service.Count); + + var serviceResponse = writer.Responses[0].ListServicesResponse.Service[0]; + Assert.AreEqual("greet.Greeter", serviceResponse.Name); + } + + [Test] + public async Task Create_ConfiguredGrpcEndpointWithMultipleInheritenceLevel_EndpointReturnedFromReflectionService() + { + // Arrange and act + TestServerStreamWriter writer = await ConfigureReflectionServerAndCallAsync(builder => + { + builder.MapGrpcService(); + }); + + // Assert + Assert.AreEqual(1, writer.Responses.Count); + Assert.AreEqual(1, writer.Responses[0].ListServicesResponse.Service.Count); + + var serviceResponse = writer.Responses[0].ListServicesResponse.Service[0]; + Assert.AreEqual("greet.Greeter", serviceResponse.Name); + } + + [Test] + public async Task Create_ConfiguredGrpcEndpointWithBaseType_EndpointReturnedFromReflectionService() + { + // Arrange and act + TestServerStreamWriter writer = await ConfigureReflectionServerAndCallAsync(builder => + { + builder.MapGrpcService(); + }); + + // Assert + Assert.AreEqual(1, writer.Responses.Count); + Assert.AreEqual(1, writer.Responses[0].ListServicesResponse.Service.Count); + + var serviceResponse = writer.Responses[0].ListServicesResponse.Service[0]; + Assert.AreEqual("greet.ThirdGreeterWithBaseType", serviceResponse.Name); + } + + private static async Task> ConfigureReflectionServerAndCallAsync(Action action) { // Arrange var endpointRouteBuilder = new TestEndpointRouteBuilder(); @@ -56,7 +106,8 @@ public async Task Create_ConfiguredGrpcEndpoint_EndpointReturnedFromReflectionSe var serviceProvider = services.BuildServiceProvider(validateScopes: true); endpointRouteBuilder.ServiceProvider = serviceProvider; - endpointRouteBuilder.MapGrpcService(); + + action(endpointRouteBuilder); // Act var service = serviceProvider.GetRequiredService(); @@ -73,12 +124,16 @@ public async Task Create_ConfiguredGrpcEndpoint_EndpointReturnedFromReflectionSe await service.ServerReflectionInfo(reader, writer, context); - // Assert - Assert.AreEqual(1, writer.Responses.Count); - Assert.AreEqual(1, writer.Responses[0].ListServicesResponse.Service.Count); + return writer; + } + + private class InheritGreeterService : GreeterService + { + } + + private class GreeterServiceWithBaseType : ThirdGreeterWithBaseType.ThirdGreeterWithBaseTypeBase + { - var serviceResponse = writer.Responses[0].ListServicesResponse.Service[0]; - Assert.AreEqual("greet.Greeter", serviceResponse.Name); } private class GreeterService : Greeter.GreeterBase @@ -124,3 +179,18 @@ public IApplicationBuilder CreateApplicationBuilder() } } } + +namespace Greet +{ + public class ThirdGreeterBaseType + { + + } + + public static partial class ThirdGreeterWithBaseType + { + public partial class ThirdGreeterWithBaseTypeBase : ThirdGreeterBaseType + { + } + } +}