Skip to content

Commit

Permalink
Support for ExecuteDelete/Update
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Jan 5, 2023
1 parent c32304b commit 13bf019
Show file tree
Hide file tree
Showing 8 changed files with 853 additions and 743 deletions.
151 changes: 83 additions & 68 deletions src/EFCore.Design/Query/Internal/CSharpToLinqTranslator.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
Expand Down Expand Up @@ -427,13 +428,38 @@ public override Expression VisitInvocationExpression(InvocationExpressionSyntax
originalDefinition = originalDefinition.ReducedFrom;
}

// var originalDefParamTypes = originalDefinition.Parameters.Select(p => ResolveType(p.Type)).ToArray();
// var paramTypes = reducedMethodSymbol.Parameters.Select(p => ResolveType(p.Type)).ToArray();
// To accurately find the right open generic method definition based on the Roslyn symbol, we need to create a mapping between
// generic type parameter names (based on the Roslyn side) and .NET reflection Types representing those type parameters.
// This includes both type parameters immediately on the generic method, as well as type parameters from the method's
// containing type (and recursively, its containing types)
var typeTypeParameterMap = new Dictionary<string, Type>(Foo(methodSymbol.ContainingType));

// TODO: Populate with generic type parameters from the containing type (and nested containing types, methods?), not just method
// below
// TODO: We match Roslyn type parameters by name, not sure that's right
var genericParameterMap = new Dictionary<string, Type>();
IEnumerable<KeyValuePair<string, Type>> Foo(INamedTypeSymbol typeSymbol)
{
// TODO: We match Roslyn type parameters by name, not sure that's right; also for the method's generic type parameters

if (typeSymbol.ContainingType is INamedTypeSymbol containingTypeSymbol)
{
foreach (var kvp in Foo(containingTypeSymbol))
{
yield return kvp;
}
}

var type = ResolveType(typeSymbol);
var genericArguments = type.GetGenericArguments();

Check.DebugAssert(
genericArguments.Length == typeSymbol.TypeParameters.Length,
"genericArguments.Length == typeSymbol.TypeParameters.Length");

foreach (var (typeParamSymbol, typeParamType) in typeSymbol.TypeParameters.Zip(genericArguments))
{
// Check.DebugAssert(typeParamSymbol.Name == typeParamType.Name, "typeParamSymbol.Name == type.Name");

yield return new KeyValuePair<string, Type>(typeParamSymbol.Name, typeParamType);
}
}

var definitionMethodInfos = declaringType.GetMethods()
.Where(m =>
Expand All @@ -445,23 +471,24 @@ public override Expression VisitInvocationExpression(InvocationExpressionSyntax
&& m.GetParameters() is var candidateParams
&& candidateParams.Length == originalDefinition.Parameters.Length)
{
var methodTypeParameterMap = new Dictionary<string, Type>(typeTypeParameterMap);
// Prepare a dictionary that will be used to resolve generic type parameters (ITypeParameterSymbol) to the
// corresponding reflection Type. This is needed to correctly (and recursively) resolve the type of parameters
// below.
genericParameterMap.Clear();
foreach (var (symbol, type) in methodSymbol.TypeParameters.Zip(candidateGenericArguments))
{
if (symbol.Name != type.Name)
{
return false;
}
genericParameterMap[symbol.Name] = type;
methodTypeParameterMap[symbol.Name] = type;
}
for (var i = 0; i < candidateParams.Length; i++)
{
var translatedParamType = ResolveType(originalDefinition.Parameters[i].Type, genericParameterMap);
var translatedParamType = ResolveType(originalDefinition.Parameters[i].Type, methodTypeParameterMap);
if (translatedParamType != candidateParams[i].ParameterType)
{
return false;
Expand Down Expand Up @@ -931,45 +958,6 @@ private Type ResolveType(ITypeSymbol typeSymbol, Dictionary<string, Type>? gener
{
switch (typeSymbol)
{
case ITypeParameterSymbol typeParameterSymbol:
return genericParameterMap?.TryGetValue(typeParameterSymbol.Name, out var type) == true
? type
: throw new InvalidOperationException($"Unknown generic type parameter symbol {typeParameterSymbol}");

case INamedTypeSymbol { IsGenericType: true } genericTypeSymbol:
{
var genericTypeName =
genericTypeSymbol.OriginalDefinition.ToDisplayString(QualifiedTypeNameSymbolDisplayFormat)
+ '`' + genericTypeSymbol.Arity;

var definition = GetClrType(genericTypeName);
var typeArguments = genericTypeSymbol.TypeArguments.Select(a => ResolveType(a, genericParameterMap)).ToArray();
return definition.MakeGenericType(typeArguments);
}

// // Open generic type
// case INamedTypeSymbol { IsGenericType: true } genericTypeSymbol
// // TODO: Hacky... Detect open type, to avoid trying MakeGenericType on it
// when genericTypeSymbol.TypeArguments.Any(a => a is ITypeParameterSymbol):
// {
// var genericTypeName = genericTypeSymbol.ToDisplayString(QualifiedTypeNameSymbolDisplayFormat)
// + '`' + genericTypeSymbol.Arity;
//
// return GetClrType(genericTypeName);
// }
//
// // Closed generic type
// case INamedTypeSymbol { IsGenericType: true } genericTypeSymbol:
// {
// var genericTypeName =
// genericTypeSymbol.OriginalDefinition.ToDisplayString(QualifiedTypeNameSymbolDisplayFormat)
// + '`' + genericTypeSymbol.Arity;
//
// var definition = GetClrType(genericTypeName);
// var typeArguments = genericTypeSymbol.TypeArguments.Select(ResolveType).ToArray();
// return definition.MakeGenericType(typeArguments);
// }

case INamedTypeSymbol { IsAnonymousType: true } anonymousTypeSymbol:
_anonymousTypeDefinitions ??= LoadAnonymousTypes();
var properties = anonymousTypeSymbol.GetMembers().OfType<IPropertySymbol>().ToArray();
Expand All @@ -989,34 +977,61 @@ private Type ResolveType(ITypeSymbol typeSymbol, Dictionary<string, Type>? gener

// TODO: Cache closed anonymous types

return anonymousTypeGenericDefinition!.MakeGenericType(genericTypeArguments);
return anonymousTypeGenericDefinition.MakeGenericType(genericTypeArguments);

case INamedTypeSymbol namedTypeSymbol:
if (typeSymbol.ContainingType is null)
{
goto default;
}
case INamedTypeSymbol { IsDefinition: true } genericTypeSymbol:
return GetClrType(genericTypeSymbol);

var containingType = ResolveType(namedTypeSymbol.ContainingType);
case INamedTypeSymbol { IsGenericType: true } genericTypeSymbol:
{
var definition = GetClrType(genericTypeSymbol.OriginalDefinition);
var typeArguments = genericTypeSymbol.TypeArguments.Select(a => ResolveType(a, genericParameterMap)).ToArray();
return definition.MakeGenericType(typeArguments);
}

var nestedType =
containingType.GetNestedType(namedTypeSymbol.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat));
if (nestedType is null)
{
throw new InvalidOperationException(
$"Couldn't find nested type '{namedTypeSymbol.Name}' on containing type '{containingType.Name}'");
}
case ITypeParameterSymbol typeParameterSymbol:
return genericParameterMap?.TryGetValue(typeParameterSymbol.Name, out var type) == true
? type
: throw new InvalidOperationException($"Unknown generic type parameter symbol {typeParameterSymbol}");

return nestedType;
case INamedTypeSymbol namedTypeSymbol:
return GetClrType(namedTypeSymbol);

default:
return GetClrType(typeSymbol.ToDisplayString(QualifiedTypeNameSymbolDisplayFormat));
return GetClrTypeFromAssembly(
typeSymbol.ContainingAssembly,
typeSymbol.ToDisplayString(QualifiedTypeNameSymbolDisplayFormat));
}

Type GetClrType(INamedTypeSymbol symbol)
{
var name = symbol.ContainingType is null
? typeSymbol.ToDisplayString(QualifiedTypeNameSymbolDisplayFormat)
: typeSymbol.Name;

if (symbol.IsGenericType)
{
name += '`' + symbol.Arity.ToString();
}

if (symbol.ContainingType is not null)
{
var containingType = ResolveType(symbol.ContainingType);

return containingType.GetNestedType(name)
?? throw new InvalidOperationException(
$"Couldn't find nested type '{name}' on containing type '{containingType.Name}'");
}

return GetClrTypeFromAssembly(typeSymbol.ContainingAssembly, name);
}

Type GetClrType(string name)
=> typeSymbol.ContainingAssembly is null
? Type.GetType(name)!
: Type.GetType($"{name}, {typeSymbol.ContainingAssembly.Name}")!;
static Type GetClrTypeFromAssembly(IAssemblySymbol? assemblySymbol, string name)
=> (assemblySymbol is null
? Type.GetType(name)!
: Type.GetType($"{name}, {assemblySymbol.Name}"))
?? throw new InvalidOperationException(
$"Couldn't resolve CLR type '{name}' in assembly '{assemblySymbol?.Name}'");

Dictionary<string[], Type> LoadAnonymousTypes()
{
Expand Down
27 changes: 20 additions & 7 deletions src/EFCore.Design/Query/Internal/PrecompiledQueryCodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,25 @@ public async Task GeneratePrecompiledQueries(string projectDir, DbContext contex
Console.Error.WriteLine("Loading project...");
using var workspace = MSBuildWorkspace.Create();

// var project = await workspace.OpenProjectAsync("/home/roji/projects/test/EFTest/EFTest.csproj")
var project = await workspace.OpenProjectAsync(projectDir, cancellationToken: cancellationToken)
.ConfigureAwait(false);

Console.WriteLine("Compiling project...");
var compilation = await project.GetCompilationAsync(cancellationToken)
.ConfigureAwait(false);

var errorDiagnostics = compilation.GetDiagnostics(cancellationToken).Where(d => d.Severity == DiagnosticSeverity.Error).ToArray();
if (errorDiagnostics.Any())
{
Console.Error.WriteLine("Compilation failed with errors:");
Console.Error.WriteLine();
foreach (var diagnostic in errorDiagnostics)
{
Console.WriteLine(diagnostic);
}
Environment.Exit(1);
}

Console.WriteLine($"Compiled assembly {compilation.Assembly.Name}");

// TODO: check reference to EF, bail early if not found?
Expand Down Expand Up @@ -181,6 +193,7 @@ public async Task GeneratePrecompiledQueries(string projectDir, DbContext contex
.Append("System.Collections.Generic")
.Append("Microsoft.EntityFrameworkCore")
.Append("Microsoft.EntityFrameworkCore.ChangeTracking.Internal")
.Append("Microsoft.EntityFrameworkCore.Diagnostics")
.Append("Microsoft.EntityFrameworkCore.Infrastructure")
.Append("Microsoft.EntityFrameworkCore.Infrastructure.Internal")
.Append("Microsoft.EntityFrameworkCore.Metadata")
Expand Down Expand Up @@ -425,13 +438,13 @@ SyntaxNode GenerateExecutorFactory(Expression queryExecutorExpression, HashSet<s
var sqlTreeVariable = "sqlTree" + (++sqlTreeCounter);

if (variableValue is NewExpression newRelationalCommandCacheExpression
&& newRelationalCommandCacheExpression.Arguments.FirstOrDefault(a => a.Type == typeof(SelectExpression)) is
ConstantExpression { Value: SelectExpression selectExpression })
&& newRelationalCommandCacheExpression.Arguments.FirstOrDefault(a => a.Type == typeof(Expression)) is
ConstantExpression { Value: Expression queryExpression })
{
// Render out the SQL tree, preceded by an ExpressionPrinter dump of it in a comment for easier debugging.
// Note that since the SQL tree is a graph (columns reference their SelectExpression's tables), rendering happens
// in multiple statements.
var sqlTreeBlock = _sqlTreeQuoter.Quote(selectExpression, sqlTreeVariable, variableNames);
var sqlTreeBlock = _sqlTreeQuoter.Quote(queryExpression, sqlTreeVariable, variableNames);
var sqlTreeSyntaxStatements =
((BlockSyntax)linqToCSharpTranslator.TranslateStatement(sqlTreeBlock, namespaces)).Statements
.ToArray();
Expand All @@ -440,7 +453,7 @@ SyntaxNode GenerateExecutorFactory(Expression queryExecutorExpression, HashSet<s
stringBuilder
.Clear()
.AppendLine("/*")
.AppendLine(sqlExpressionPrinter.PrintExpression(selectExpression))
.AppendLine(sqlExpressionPrinter.PrintExpression(queryExpression))
.AppendLine("*/")
.ToString()));

Expand All @@ -449,8 +462,8 @@ SyntaxNode GenerateExecutorFactory(Expression queryExecutorExpression, HashSet<s
// We've rendered the SQL tree, assigning it to variable "sqlTree". Update the RelationalCommandCache to point
// to it
variableValue = newRelationalCommandCacheExpression.Update(newRelationalCommandCacheExpression.Arguments
.Select(a => a.Type == typeof(SelectExpression)
? Expression.Parameter(typeof(SelectExpression), sqlTreeVariable)
.Select(a => a.Type == typeof(Expression)
? Expression.Parameter(typeof(Expression), sqlTreeVariable)
: a));
}
else
Expand Down
31 changes: 25 additions & 6 deletions src/EFCore.Design/Query/Internal/QueryLocator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public class QueryLocator : CSharpSyntaxRewriter, IQueryLocator

#pragma warning disable CS8618 // Uninitialized non-nullable fields. We check _compilation to make sure LoadCompilation was invoked.
private ITypeSymbol _genericIQueryableSymbol, _nonGenericIQueryableSymbol, _dbSetSymbol;
private ITypeSymbol _efQueryableExtensionsSymbol, _enumerableSymbol, _queryableSymbol;
private ITypeSymbol _enumerableSymbol, _queryableSymbol, _efQueryableExtensionsSymbol, _efRelationalQueryableExtensionsSymbol;
private ITypeSymbol _cancellationTokenSymbol;
#pragma warning restore CS8618

Expand Down Expand Up @@ -51,10 +51,10 @@ public void LoadCompilation(Compilation compilation)
_nonGenericIQueryableSymbol = GetTypeSymbolOrThrow("System.Linq.IQueryable");
_dbSetSymbol = GetTypeSymbolOrThrow("Microsoft.EntityFrameworkCore.DbSet`1");

_efQueryableExtensionsSymbol =
GetTypeSymbolOrThrow("Microsoft.EntityFrameworkCore.EntityFrameworkQueryableExtensions");
_enumerableSymbol = GetTypeSymbolOrThrow("System.Linq.Enumerable");
_queryableSymbol = GetTypeSymbolOrThrow("System.Linq.Queryable");
_efQueryableExtensionsSymbol = GetTypeSymbolOrThrow("Microsoft.EntityFrameworkCore.EntityFrameworkQueryableExtensions");
_efRelationalQueryableExtensionsSymbol = GetTypeSymbolOrThrow("Microsoft.EntityFrameworkCore.RelationalQueryableExtensions");
_cancellationTokenSymbol = GetTypeSymbolOrThrow("System.Threading.CancellationToken");

_syntaxTreesWithQueryCandidates.Clear();
Expand Down Expand Up @@ -187,23 +187,42 @@ public override SyntaxNode VisitInvocationExpression(InvocationExpressionSyntax
case nameof(EntityFrameworkQueryableExtensions.SingleAsync):
case nameof(EntityFrameworkQueryableExtensions.SingleOrDefaultAsync):
case nameof(EntityFrameworkQueryableExtensions.SumAsync):
{
return IsOnEfQueryableExtensions() && TryRewriteInvocationToSync(out var rewrittenSyncInvocation)
? CheckAndAddQuery(rewrittenSyncInvocation, async: true)
: invocation;
}

case nameof(RelationalQueryableExtensions.ExecuteDelete):
case nameof(RelationalQueryableExtensions.ExecuteUpdate):
return IsOnEfRelationalQueryableExtensions()
? CheckAndAddQuery(invocation, async: false)
: invocation;

case nameof(RelationalQueryableExtensions.ExecuteDeleteAsync):
case nameof(RelationalQueryableExtensions.ExecuteUpdateAsync):
{
return IsOnEfRelationalQueryableExtensions() && TryRewriteInvocationToSync(out var rewrittenSyncInvocation)
? CheckAndAddQuery(rewrittenSyncInvocation, async: true)
: invocation;
}

default:
return base.VisitInvocationExpression(invocation)!;
}

bool IsOnEfQueryableExtensions()
=> IsOnTypeSymbol(_efQueryableExtensionsSymbol);

bool IsOnEnumerable()
=> IsOnTypeSymbol(_enumerableSymbol);

bool IsOnQueryable()
=> IsOnTypeSymbol(_queryableSymbol);

bool IsOnEfQueryableExtensions()
=> IsOnTypeSymbol(_efQueryableExtensionsSymbol);

bool IsOnEfRelationalQueryableExtensions()
=> IsOnTypeSymbol(_efRelationalQueryableExtensionsSymbol);

bool IsOnTypeSymbol(ITypeSymbol typeSymbol)
{
if (GetSymbol(invocation) is not IMethodSymbol methodSymbol)
Expand Down
Loading

0 comments on commit 13bf019

Please sign in to comment.