Skip to content

Commit

Permalink
[release/8.0] Fix support of FromKeyedServicesAttribute in ActivatorU…
Browse files Browse the repository at this point in the history
…tilities.CreateFactory (#92144)

* Fix support of FromKeyedServicesAttribute in ActivatorUtilities.CreateFactory

* Addressing comment and adding a test

---------

Co-authored-by: Benjamin Petit <bpetit@microsoft.com>
  • Loading branch information
github-actions[bot] and benjaminpetit committed Sep 16, 2023
1 parent 7b32406 commit c7252a3
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public static class ActivatorUtilities
#endif

private static readonly MethodInfo GetServiceInfo =
GetMethodInfo<Func<IServiceProvider, Type, Type, bool, object?>>((sp, t, r, c) => GetService(sp, t, r, c));
GetMethodInfo<Func<IServiceProvider, Type, Type, bool, object?, object?>>((sp, t, r, c, k) => GetService(sp, t, r, c, k));

/// <summary>
/// Instantiate a type with constructor arguments provided directly and/or from an <see cref="IServiceProvider"/>.
Expand Down Expand Up @@ -324,9 +324,9 @@ private static MethodInfo GetMethodInfo<T>(Expression<T> expr)
return mc.Method;
}

private static object? GetService(IServiceProvider sp, Type type, Type requiredBy, bool hasDefaultValue)
private static object? GetService(IServiceProvider sp, Type type, Type requiredBy, bool hasDefaultValue, object? key)
{
object? service = sp.GetService(type);
object? service = key == null ? sp.GetService(type) : GetKeyedService(sp, type, key);
if (service is null && !hasDefaultValue)
{
ThrowHelperUnableToResolveService(type, requiredBy);
Expand Down Expand Up @@ -361,10 +361,12 @@ private static BlockExpression BuildFactoryExpression(
}
else
{
var keyAttribute = (FromKeyedServicesAttribute?) Attribute.GetCustomAttribute(constructorParameter, typeof(FromKeyedServicesAttribute), inherit: false);
var parameterTypeExpression = new Expression[] { serviceProvider,
Expression.Constant(parameterType, typeof(Type)),
Expression.Constant(constructor.DeclaringType, typeof(Type)),
Expression.Constant(hasDefaultValue) };
Expression.Constant(hasDefaultValue),
Expression.Constant(keyAttribute?.Key) };
constructorArguments[i] = Expression.Call(GetServiceInfo, parameterTypeExpression);
}

Expand Down Expand Up @@ -435,10 +437,10 @@ private static ObjectFactory CreateFactoryReflection(
if (matchedArgCount == 0)
{
// All injected; use a fast path.
Type[] types = GetParameterTypes();
FactoryParameterContext[] parameters = GetFactoryParameterContext();
return useFixedValues ?
(serviceProvider, arguments) => ReflectionFactoryServiceOnlyFixed(invoker, types, declaringType, serviceProvider) :
(serviceProvider, arguments) => ReflectionFactoryServiceOnlySpan(invoker, types, declaringType, serviceProvider);
(serviceProvider, arguments) => ReflectionFactoryServiceOnlyFixed(invoker, parameters, declaringType, serviceProvider) :
(serviceProvider, arguments) => ReflectionFactoryServiceOnlySpan(invoker, parameters, declaringType, serviceProvider);
}

if (matchedArgCount == constructorParameters.Length)
Expand All @@ -456,16 +458,6 @@ ObjectFactory InvokeCanonical()
(serviceProvider, arguments) => ReflectionFactoryCanonicalFixed(invoker, parameters, declaringType, serviceProvider, arguments) :
(serviceProvider, arguments) => ReflectionFactoryCanonicalSpan(invoker, parameters, declaringType, serviceProvider, arguments);
}

Type[] GetParameterTypes()
{
Type[] types = new Type[constructorParameters.Length];
for (int i = 0; i < constructorParameters.Length; i++)
{
types[i] = constructorParameters[i].ParameterType;
}
return types;
}
#else
ParameterInfo[] constructorParameters = constructor.GetParameters();
if (constructorParameters.Length == 0)
Expand All @@ -484,8 +476,15 @@ FactoryParameterContext[] GetFactoryParameterContext()
for (int i = 0; i < constructorParameters.Length; i++)
{
ParameterInfo constructorParameter = constructorParameters[i];
FromKeyedServicesAttribute? attr = (FromKeyedServicesAttribute?)
Attribute.GetCustomAttribute(constructorParameter, typeof(FromKeyedServicesAttribute), inherit: false);
bool hasDefaultValue = ParameterDefaultValue.TryGetDefaultValue(constructorParameter, out object? defaultValue);
parameters[i] = new FactoryParameterContext(constructorParameter.ParameterType, hasDefaultValue, defaultValue, parameterMap[i] ?? -1);
parameters[i] = new FactoryParameterContext(
constructorParameter.ParameterType,
hasDefaultValue,
defaultValue,
parameterMap[i] ?? -1,
attr?.Key);
}

return parameters;
Expand All @@ -495,18 +494,20 @@ FactoryParameterContext[] GetFactoryParameterContext()

private readonly struct FactoryParameterContext
{
public FactoryParameterContext(Type parameterType, bool hasDefaultValue, object? defaultValue, int argumentIndex)
public FactoryParameterContext(Type parameterType, bool hasDefaultValue, object? defaultValue, int argumentIndex, object? serviceKey)
{
ParameterType = parameterType;
HasDefaultValue = hasDefaultValue;
DefaultValue = defaultValue;
ArgumentIndex = argumentIndex;
ServiceKey = serviceKey;
}

public Type ParameterType { get; }
public bool HasDefaultValue { get; }
public object? DefaultValue { get; }
public int ArgumentIndex { get; }
public object? ServiceKey { get; }
}

private static void FindApplicableConstructor(
Expand Down Expand Up @@ -825,57 +826,57 @@ private static void ThrowMarkedCtorDoesNotTakeAllProvidedArguments()
#if NET8_0_OR_GREATER // Use the faster ConstructorInvoker which also has alloc-free APIs when <= 4 parameters.
private static object ReflectionFactoryServiceOnlyFixed(
ConstructorInvoker invoker,
Type[] parameterTypes,
FactoryParameterContext[] parameters,
Type declaringType,
IServiceProvider serviceProvider)
{
Debug.Assert(parameterTypes.Length >= 1 && parameterTypes.Length <= FixedArgumentThreshold);
Debug.Assert(parameters.Length >= 1 && parameters.Length <= FixedArgumentThreshold);
Debug.Assert(FixedArgumentThreshold == 4);

if (serviceProvider is null)
ThrowHelperArgumentNullExceptionServiceProvider();

switch (parameterTypes.Length)
switch (parameters.Length)
{
case 1:
return invoker.Invoke(
GetService(serviceProvider, parameterTypes[0], declaringType, false));
GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey));

case 2:
return invoker.Invoke(
GetService(serviceProvider, parameterTypes[0], declaringType, false),
GetService(serviceProvider, parameterTypes[1], declaringType, false));
GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey),
GetService(serviceProvider, parameters[1].ParameterType, declaringType, false, parameters[1].ServiceKey));

case 3:
return invoker.Invoke(
GetService(serviceProvider, parameterTypes[0], declaringType, false),
GetService(serviceProvider, parameterTypes[1], declaringType, false),
GetService(serviceProvider, parameterTypes[2], declaringType, false));
GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey),
GetService(serviceProvider, parameters[1].ParameterType, declaringType, false, parameters[1].ServiceKey),
GetService(serviceProvider, parameters[2].ParameterType, declaringType, false, parameters[2].ServiceKey));

case 4:
return invoker.Invoke(
GetService(serviceProvider, parameterTypes[0], declaringType, false),
GetService(serviceProvider, parameterTypes[1], declaringType, false),
GetService(serviceProvider, parameterTypes[2], declaringType, false),
GetService(serviceProvider, parameterTypes[3], declaringType, false));
GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey),
GetService(serviceProvider, parameters[1].ParameterType, declaringType, false, parameters[1].ServiceKey),
GetService(serviceProvider, parameters[2].ParameterType, declaringType, false, parameters[2].ServiceKey),
GetService(serviceProvider, parameters[3].ParameterType, declaringType, false, parameters[3].ServiceKey));
}

return null!;
}

private static object ReflectionFactoryServiceOnlySpan(
ConstructorInvoker invoker,
Type[] parameterTypes,
FactoryParameterContext[] parameters,
Type declaringType,
IServiceProvider serviceProvider)
{
if (serviceProvider is null)
ThrowHelperArgumentNullExceptionServiceProvider();

object?[] arguments = new object?[parameterTypes.Length];
for (int i = 0; i < parameterTypes.Length; i++)
object?[] arguments = new object?[parameters.Length];
for (int i = 0; i < parameters.Length; i++)
{
arguments[i] = GetService(serviceProvider, parameterTypes[i], declaringType, false);
arguments[i] = GetService(serviceProvider, parameters[i].ParameterType, declaringType, false, parameters[i].ServiceKey);
}

return invoker.Invoke(arguments.AsSpan());
Expand Down Expand Up @@ -907,7 +908,8 @@ private static object ReflectionFactoryCanonicalFixed(
serviceProvider,
parameter1.ParameterType,
declaringType,
parameter1.HasDefaultValue)) ?? parameter1.DefaultValue);
parameter1.HasDefaultValue,
parameter1.ServiceKey)) ?? parameter1.DefaultValue);
case 2:
{
ref FactoryParameterContext parameter2 = ref parameters[1];
Expand All @@ -920,15 +922,17 @@ private static object ReflectionFactoryCanonicalFixed(
serviceProvider,
parameter1.ParameterType,
declaringType,
parameter1.HasDefaultValue)) ?? parameter1.DefaultValue,
parameter1.HasDefaultValue,
parameter1.ServiceKey)) ?? parameter1.DefaultValue,
((parameter2.ArgumentIndex != -1)
// Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
? arguments![parameter2.ArgumentIndex]
: GetService(
serviceProvider,
parameter2.ParameterType,
declaringType,
parameter2.HasDefaultValue)) ?? parameter2.DefaultValue);
parameter2.HasDefaultValue,
parameter2.ServiceKey)) ?? parameter2.DefaultValue);
}
case 3:
{
Expand All @@ -943,23 +947,26 @@ private static object ReflectionFactoryCanonicalFixed(
serviceProvider,
parameter1.ParameterType,
declaringType,
parameter1.HasDefaultValue)) ?? parameter1.DefaultValue,
parameter1.HasDefaultValue,
parameter1.ServiceKey)) ?? parameter1.DefaultValue,
((parameter2.ArgumentIndex != -1)
// Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
? arguments![parameter2.ArgumentIndex]
: GetService(
serviceProvider,
parameter2.ParameterType,
declaringType,
parameter2.HasDefaultValue)) ?? parameter2.DefaultValue,
parameter2.HasDefaultValue,
parameter2.ServiceKey)) ?? parameter2.DefaultValue,
((parameter3.ArgumentIndex != -1)
// Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
? arguments![parameter3.ArgumentIndex]
: GetService(
serviceProvider,
parameter3.ParameterType,
declaringType,
parameter3.HasDefaultValue)) ?? parameter3.DefaultValue);
parameter3.HasDefaultValue,
parameter3.ServiceKey)) ?? parameter3.DefaultValue);
}
case 4:
{
Expand All @@ -975,31 +982,35 @@ private static object ReflectionFactoryCanonicalFixed(
serviceProvider,
parameter1.ParameterType,
declaringType,
parameter1.HasDefaultValue)) ?? parameter1.DefaultValue,
parameter1.HasDefaultValue,
parameter1.ServiceKey)) ?? parameter1.DefaultValue,
((parameter2.ArgumentIndex != -1)
// Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
? arguments![parameter2.ArgumentIndex]
: GetService(
serviceProvider,
parameter2.ParameterType,
declaringType,
parameter2.HasDefaultValue)) ?? parameter2.DefaultValue,
parameter2.HasDefaultValue,
parameter2.ServiceKey)) ?? parameter2.DefaultValue,
((parameter3.ArgumentIndex != -1)
// Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
? arguments![parameter3.ArgumentIndex]
: GetService(
serviceProvider,
parameter3.ParameterType,
declaringType,
parameter3.HasDefaultValue)) ?? parameter3.DefaultValue,
parameter3.HasDefaultValue,
parameter3.ServiceKey)) ?? parameter3.DefaultValue,
((parameter4.ArgumentIndex != -1)
// Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
? arguments![parameter4.ArgumentIndex]
: GetService(
serviceProvider,
parameter4.ParameterType,
declaringType,
parameter4.HasDefaultValue)) ?? parameter4.DefaultValue);
parameter4.HasDefaultValue,
parameter4.ServiceKey)) ?? parameter4.DefaultValue);
}

}
Expand Down Expand Up @@ -1028,7 +1039,8 @@ private static object ReflectionFactoryCanonicalSpan(
serviceProvider,
parameter.ParameterType,
declaringType,
parameter.HasDefaultValue)) ?? parameter.DefaultValue;
parameter.HasDefaultValue,
parameter.ServiceKey)) ?? parameter.DefaultValue;
}

return invoker.Invoke(constructorArguments.AsSpan());
Expand Down Expand Up @@ -1078,7 +1090,8 @@ private static object ReflectionFactoryCanonical(
serviceProvider,
parameter.ParameterType,
declaringType,
parameter.HasDefaultValue)) ?? parameter.DefaultValue;
parameter.HasDefaultValue,
parameter.ServiceKey)) ?? parameter.DefaultValue;
}

return constructor.Invoke(BindingFlags.DoNotWrapExceptions, binder: null, constructorArguments, culture: null);
Expand All @@ -1099,5 +1112,17 @@ public static void ClearCache(Type[]? _)
}
}
#endif

private static object? GetKeyedService(IServiceProvider provider, Type type, object? serviceKey)
{
ThrowHelper.ThrowIfNull(provider);

if (provider is IKeyedServiceProvider keyedServiceProvider)
{
return keyedServiceProvider.GetKeyedService(type, serviceKey);
}

throw new InvalidOperationException(SR.KeyedServicesNotSupported);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -476,5 +476,41 @@ public ServiceProviderAccessor(IServiceProvider serviceProvider)

public IServiceProvider ServiceProvider { get; }
}

[Fact]
public void SimpleServiceKeyedResolution()
{
// Arrange
var services = new ServiceCollection();
services.AddKeyedTransient<ISimpleService, SimpleService>("simple");
services.AddKeyedTransient<ISimpleService, AnotherSimpleService>("another");
services.AddTransient<SimpleParentWithDynamicKeyedService>();
var provider = CreateServiceProvider(services);
var sut = provider.GetService<SimpleParentWithDynamicKeyedService>();

// Act
var result = sut!.GetService("simple");

// Assert
Assert.True(result.GetType() == typeof(SimpleService));
}

public class SimpleParentWithDynamicKeyedService
{
private readonly IServiceProvider _serviceProvider;

public SimpleParentWithDynamicKeyedService(IServiceProvider serviceProvider)
{
_serviceProvider = serviceProvider;
}

public ISimpleService GetService(string name) => _serviceProvider.GetKeyedService<ISimpleService>(name)!;
}

public interface ISimpleService { }

public class SimpleService : ISimpleService { }

public class AnotherSimpleService : ISimpleService { }
}
}
Loading

0 comments on commit c7252a3

Please sign in to comment.