From dba1477f34b5ffeeea1549893a2e85131aa555b3 Mon Sep 17 00:00:00 2001 From: Bruno Cardoso Lopes Date: Tue, 22 Dec 2020 17:11:15 -0800 Subject: [PATCH] [PatternMatching] Fix pattern guards for pat id elements in structured bindings Decouple binding creation from running sema for the structured bindings by acting early on the pattern list. Since we now create the variables before parsing the condition, their name is scope, available to be used from the guard condition. --- clang/include/clang/Sema/Sema.h | 5 ++++- clang/lib/Parse/ParseStmt.cpp | 9 ++++++-- clang/lib/Sema/SemaStmt.cpp | 28 +++++++++++++++++------- clang/test/AST/ast-dump-inspect-stmt.cpp | 23 +++++++++++++------ clang/test/SemaCXX/inspect.cpp | 1 + 5 files changed, 48 insertions(+), 18 deletions(-) diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h index ed7daeb8398847..e5fc7d8c253d2f 100644 --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -4248,7 +4248,10 @@ class Sema final { StmtResult ActOnStructuredBindingPattern( SourceLocation ColonLoc, SourceLocation LLoc, SourceLocation RLoc, SmallVectorImpl &PatList, Stmt *SubStmt, - Expr *Guard, bool ExcludedFromTypeDeduction); + Expr *Guard, Stmt *DecompStmt, bool ExcludedFromTypeDeduction); + StmtResult + ActOnPatternList(SmallVectorImpl &PatList, + SourceLocation LLoc); ExprResult CheckPatternConstantExpr(Expr *MatchExpr, SourceLocation MatchExprLoc); diff --git a/clang/lib/Parse/ParseStmt.cpp b/clang/lib/Parse/ParseStmt.cpp index e1c2aaa299cb8e..0c2c22111b8ae1 100644 --- a/clang/lib/Parse/ParseStmt.cpp +++ b/clang/lib/Parse/ParseStmt.cpp @@ -875,7 +875,12 @@ StmtResult Parser::ParseStructuralBindingPattern(ParsedStmtContext StmtCtx) { SourceLocation IfLoc; ParseScope PatternScope(this, Scope::PatternScope | Scope::DeclScope, true); - + StmtResult DecompDS; + if (ValidPatList) { + DecompDS = Actions.ActOnPatternList(PatList, LSquare); + if (DecompDS.isInvalid()) + ValidPatList = false; + } // FIXME: retrieve constexpr information from InspectExpr if (Tok.is(tok::kw_if)) if (!ParsePatternGuard(Cond, IfLoc, false /*IsConstexprIf*/)) @@ -908,7 +913,7 @@ StmtResult Parser::ParseStructuralBindingPattern(ParsedStmtContext StmtCtx) { if (ValidPatList) Res = Actions.ActOnStructuredBindingPattern( ArrowLoc, LSquare, RSquare, PatList, nullptr, Cond.get().second, - ExclaimLoc.isValid()); + DecompDS.get(), ExclaimLoc.isValid()); // Parse the statement // diff --git a/clang/lib/Sema/SemaStmt.cpp b/clang/lib/Sema/SemaStmt.cpp index 95437fb7164ec4..1870a806542783 100644 --- a/clang/lib/Sema/SemaStmt.cpp +++ b/clang/lib/Sema/SemaStmt.cpp @@ -681,10 +681,9 @@ StmtResult Sema::ActOnExpressionPattern(SourceLocation MatchExprLoc, return EPS; } -StmtResult Sema::ActOnStructuredBindingPattern( - SourceLocation ColonLoc, SourceLocation LLoc, SourceLocation RLoc, - SmallVectorImpl &PatList, Stmt *SubStmt, - Expr *Guard, bool ExcludedFromTypeDeduction) { +StmtResult +Sema::ActOnPatternList(SmallVectorImpl &PatList, + SourceLocation LLoc) { if (PatList.empty()) { Diag(LLoc, diag::err_empty_stbind_pattern); return StmtError(); @@ -748,7 +747,8 @@ StmtResult Sema::ActOnStructuredBindingPattern( // Deduce the type of the inspect condition. QualType DeducedType = deduceVarTypeFromInitializer( - /*VarDecl*/ DecompCond, DeclarationName(), DeductType, TSI, SourceRange(LLoc), + /*VarDecl*/ DecompCond, DeclarationName(), DeductType, TSI, + SourceRange(LLoc), /*IsDirectInit*/ false, MatchSource); if (DeducedType.isNull()) // deduceVarTypeFromInitializer already emits diags return StmtError(); @@ -777,6 +777,18 @@ StmtResult Sema::ActOnStructuredBindingPattern( DecompCond->getBeginLoc(), DecompCond->getEndLoc()); if (DecompDS.isInvalid()) return StmtError(); + return DecompDS; +} + +StmtResult Sema::ActOnStructuredBindingPattern( + SourceLocation ColonLoc, SourceLocation LLoc, SourceLocation RLoc, + SmallVectorImpl &PatList, Stmt *SubStmt, + Expr *Guard, Stmt *DecompStmt, bool ExcludedFromTypeDeduction) { + auto *DS = static_cast(DecompStmt); + auto *DecompCond = cast(DS->getSingleDecl()); + if (getCurFunction()->InspectStack.empty()) + return StmtError(); + InspectExpr *Inspect = getCurFunction()->InspectStack.back().getPointer(); // Now that we got all bindings populated with the proper type, for each // element in the pattern list try to ==/match() with the equivalent element @@ -792,7 +804,7 @@ StmtResult Sema::ActOnStructuredBindingPattern( case ParsedPatEltAction::Match: { ExprResult M = ActOnMatchBinOp(NewBindings[I]->getBinding(), - cast(PatList[I].Elt), MatchSourceLoc); + cast(PatList[I].Elt), PatList[I].Loc); if (M.isInvalid()) continue; if (!PatCond) { @@ -817,8 +829,8 @@ StmtResult Sema::ActOnStructuredBindingPattern( } auto *SBP = StructuredBindingPatternStmt::Create( - Context, LLoc, ColonLoc, LLoc, RLoc, DecompDS.get(), SubStmt, Guard, - PatCond, ExcludedFromTypeDeduction); + Context, LLoc, ColonLoc, LLoc, RLoc, DecompStmt, SubStmt, Guard, PatCond, + ExcludedFromTypeDeduction); Inspect->addPattern(SBP); return SBP; diff --git a/clang/test/AST/ast-dump-inspect-stmt.cpp b/clang/test/AST/ast-dump-inspect-stmt.cpp index 85c520a4a9d25c..841f67dd724383 100644 --- a/clang/test/AST/ast-dump-inspect-stmt.cpp +++ b/clang/test/AST/ast-dump-inspect-stmt.cpp @@ -162,19 +162,28 @@ void TestInspect(int a, int b) { }; insn_type insn; inspect(insn) { - [o, i] => { o++; }; + [o, i] if (o+i < 12) => { o++; }; }; // CHECK: InspectExpr // CHECK: StructuredBindingPatternStmt - // CHECK: |-CompoundStmt {{.*}} - // CHECK: | `-UnaryOperator {{.*}} 'unsigned int' postfix '++' - // CHECK: | `-DeclRefExpr {{.*}} 'unsigned int' lvalue bitfield Binding {{.*}} 'o' 'unsigned int' - // CHECK: `-DeclStmt + // CHECK: |-CompoundStmt {{.*}} + // CHECK: | `-UnaryOperator {{.*}} 'unsigned int' postfix '++' + // CHECK: | `-DeclRefExpr {{.*}} 'unsigned int' lvalue bitfield Binding {{.*}} 'o' 'unsigned int' + // CHECK: |-DeclStmt // CHECK: `-DecompositionDecl {{.*}} used 'insn_type &' cinit // CHECK: |-BindingDecl {{.*}} col:6 referenced o 'unsigned int' // CHECK: | `-MemberExpr {{.*}} 'unsigned int' lvalue bitfield .opc - // CHECK: `-BindingDecl {{.*}} col:9 i 'unsigned int' - // CHECK: `-MemberExpr {{.*}} 'unsigned int' lvalue bitfield .imm + // CHECK: `-BindingDecl {{.*}} col:9 referenced i 'unsigned int' + // CHECK: `-MemberExpr {{.*}} 'unsigned int' lvalue bitfield .imm + // CHECK: `-BinaryOperator {{.*}} 'bool' '<' + // CHECK: |-BinaryOperator {{.*}} 'int' '+' + // CHECK: | |-ImplicitCastExpr {{.*}} 'int' + // CHECK: | | `-ImplicitCastExpr {{.*}} 'unsigned int' + // CHECK: | | `-DeclRefExpr {{.*}} 'unsigned int' lvalue bitfield Binding {{.*}} 'o' 'unsigned int' + // CHECK: | `-ImplicitCastExpr {{.*}} 'int' + // CHECK: | `-ImplicitCastExpr {{.*}} 'unsigned int' + // CHECK: | `-DeclRefExpr {{.*}} 'unsigned int' lvalue bitfield Binding {{.*}} 'i' 'unsigned int' + // CHECK: `-IntegerLiteral {{.*}} 'int' 12 } using size_t = decltype(sizeof(0)); diff --git a/clang/test/SemaCXX/inspect.cpp b/clang/test/SemaCXX/inspect.cpp index 320d1ef1d91ac1..c2fbd84e8c57d9 100644 --- a/clang/test/SemaCXX/inspect.cpp +++ b/clang/test/SemaCXX/inspect.cpp @@ -199,6 +199,7 @@ void stbind0(int x) { int array[2] = {2,1}; inspect (array) { [1,2] =>; + [id0, id1] if (id0+id1 < 10) =>; }; using FourUInts = unsigned __attribute__((__vector_size__(16)));