From 5832181077857697769aa143b6a7aad0fc4ddbe6 Mon Sep 17 00:00:00 2001 From: Sima Nerush <2002ssn@gmail.com> Date: Fri, 28 Jul 2023 10:43:29 -0700 Subject: [PATCH 01/10] Sema --- include/swift/AST/Decl.h | 8 + include/swift/Sema/SyntacticElementTarget.h | 20 +- lib/AST/ASTVerifier.cpp | 22 ++ lib/AST/Decl.cpp | 5 + lib/AST/Stmt.cpp | 3 + lib/Sema/CSApply.cpp | 246 +++++++++++++------- lib/Sema/CSGen.cpp | 136 ++++++++--- test/stmt/pack_iteration.swift | 11 + 8 files changed, 332 insertions(+), 119 deletions(-) create mode 100644 test/stmt/pack_iteration.swift diff --git a/include/swift/AST/Decl.h b/include/swift/AST/Decl.h index f90c3912f80c4..465efd85535fd 100644 --- a/include/swift/AST/Decl.h +++ b/include/swift/AST/Decl.h @@ -5855,6 +5855,7 @@ enum class PropertyWrapperSynthesizedPropertyKind { class VarDecl : public AbstractStorageDecl { friend class NamingPatternRequest; NamedPattern *NamingPattern = nullptr; + GenericEnvironment *OpenedElementEnvironment = nullptr; public: enum class Introducer : uint8_t { @@ -5982,6 +5983,13 @@ class VarDecl : public AbstractStorageDecl { NamedPattern *getNamingPattern() const; void setNamingPattern(NamedPattern *Pat); + GenericEnvironment *getOpenedElementEnvironment() const { + return OpenedElementEnvironment; + } + void setOpenedElementEnvironment(GenericEnvironment *Env) { + OpenedElementEnvironment = Env; + } + /// If this is a VarDecl that does not belong to a CaseLabelItem's pattern, /// return this. Otherwise, this VarDecl must belong to a CaseStmt's /// CaseLabelItem. In that case, return the first case label item of the first diff --git a/include/swift/Sema/SyntacticElementTarget.h b/include/swift/Sema/SyntacticElementTarget.h index 2da4367f6edde..eafc13f824c9f 100644 --- a/include/swift/Sema/SyntacticElementTarget.h +++ b/include/swift/Sema/SyntacticElementTarget.h @@ -22,15 +22,16 @@ #include "swift/AST/Pattern.h" #include "swift/AST/Stmt.h" #include "swift/AST/TypeLoc.h" +#include "swift/Basic/TaggedUnion.h" #include "swift/Sema/ConstraintLocator.h" #include "swift/Sema/ContextualTypeInfo.h" namespace swift { namespace constraints { -/// Describes information about a for-each loop that needs to be tracked -/// within the constraint system. -struct ForEachStmtInfo { +/// Describes information specific to a sequence +/// in a for-each loop. +struct SequenceIterationInfo { /// The type of the sequence. Type sequenceType; @@ -47,6 +48,19 @@ struct ForEachStmtInfo { Expr *nextCall; }; +/// Describes information specific to a pack expansion expression +/// in a for-each loop. +struct PackIterationInfo { + /// The type of the pattern that matches the elements. + Type patternType; +}; + +/// Describes information about a for-each loop that needs to be tracked +/// within the constraint system. +struct ForEachStmtInfo : TaggedUnion { + using TaggedUnion::TaggedUnion; +}; + /// Describes the target to which a constraint system's solution can be /// applied. class SyntacticElementTarget { diff --git a/lib/AST/ASTVerifier.cpp b/lib/AST/ASTVerifier.cpp index 931a69b61197b..dad5742a6ea66 100644 --- a/lib/AST/ASTVerifier.cpp +++ b/lib/AST/ASTVerifier.cpp @@ -795,6 +795,13 @@ class Verifier : public ASTWalker { if (!shouldVerify(cast(S))) return false; + if (auto *expansion = + dyn_cast(S->getParsedSequence())) { + if (!shouldVerify(expansion)) { + return false; + } + } + if (!S->getElementExpr()) return true; @@ -804,6 +811,11 @@ class Verifier : public ASTWalker { } void cleanup(ForEachStmt *S) { + if (auto *expansion = + dyn_cast(S->getParsedSequence())) { + cleanup(expansion); + } + if (!S->getElementExpr()) return; @@ -2605,6 +2617,16 @@ class Verifier : public ASTWalker { abort(); } + // If we are performing pack iteration, variables have to carry the + // generic environment. Catching the missing environment here will prevent + // the code from being lowered. + if (var->getTypeInContext()->is()) { + Out << "VarDecl is missing a Generic Environment: "; + var->getInterfaceType().print(Out); + Out << "\n"; + abort(); + } + // The fact that this is *directly* be a reference storage type // cuts the code down quite a bit in getTypeOfReference. if (var->getAttrs().hasAttribute() != diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index 2d14c3f42b6f9..8168baf9e4da6 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -7135,6 +7135,11 @@ VarDecl::VarDecl(DeclKind kind, bool isStatic, VarDecl::Introducer introducer, } Type VarDecl::getTypeInContext() const { + // If we are performing pack iteration, use the generic environment of the + // pack expansion expression to get the right context of a local variable. + if (auto *env = getOpenedElementEnvironment()) + return GenericEnvironment::mapTypeIntoContext(env, getInterfaceType()); + return getDeclContext()->mapTypeIntoContext(getInterfaceType()); } diff --git a/lib/AST/Stmt.cpp b/lib/AST/Stmt.cpp index a61483fb8d664..50b647a218f6f 100644 --- a/lib/AST/Stmt.cpp +++ b/lib/AST/Stmt.cpp @@ -446,6 +446,9 @@ void ForEachStmt::setPattern(Pattern *p) { } Expr *ForEachStmt::getTypeCheckedSequence() const { + if (auto *expansion = dyn_cast(getParsedSequence())) + return expansion; + return iteratorVar ? iteratorVar->getInit(/*index=*/0) : nullptr; } diff --git a/lib/Sema/CSApply.cpp b/lib/Sema/CSApply.cpp index 64462ed173ddd..0260a8eeee7ab 100644 --- a/lib/Sema/CSApply.cpp +++ b/lib/Sema/CSApply.cpp @@ -9109,53 +9109,42 @@ applySolutionToInitialization(Solution &solution, SyntacticElementTarget target, return resultTarget; } -/// Apply the given solution to the for-each statement target. -/// -/// \returns the resulting initialization expression. -static llvm::Optional applySolutionToForEachStmt( - Solution &solution, SyntacticElementTarget target, +static llvm::Optional applySolutionToForEachStmt( + Solution &solution, ForEachStmt *stmt, SequenceIterationInfo info, + DeclContext *dc, llvm::function_ref< llvm::Optional(SyntacticElementTarget)> rewriteTarget) { - auto resultTarget = target; - auto &forEachStmtInfo = resultTarget.getForEachStmtInfo(); - auto *stmt = target.getAsForEachStmt(); + auto &cs = solution.getConstraintSystem(); + auto *parsedSequence = stmt->getParsedSequence(); bool isAsync = stmt->getAwaitLoc().isValid(); // Simplify the various types. - forEachStmtInfo.sequenceType = - solution.simplifyType(forEachStmtInfo.sequenceType); - forEachStmtInfo.elementType = - solution.simplifyType(forEachStmtInfo.elementType); - forEachStmtInfo.initType = - solution.simplifyType(forEachStmtInfo.initType); - - auto &cs = solution.getConstraintSystem(); - auto *dc = target.getDeclContext(); + info.sequenceType = solution.simplifyType(info.sequenceType); + info.elementType = solution.simplifyType(info.elementType); + info.initType = solution.simplifyType(info.initType); - // First, let's apply the solution to the sequence expression. - { - auto *makeIteratorVar = forEachStmtInfo.makeIteratorVar; - - auto makeIteratorTarget = *cs.getTargetFor({makeIteratorVar, /*index=*/0}); + // First, let's apply the solution to the expression. + auto *makeIteratorVar = info.makeIteratorVar; - auto rewrittenTarget = rewriteTarget(makeIteratorTarget); - if (!rewrittenTarget) - return llvm::None; + auto makeIteratorTarget = *cs.getTargetFor({makeIteratorVar, /*index=*/0}); - // Set type-checked initializer and mark it as such. - { - makeIteratorVar->setInit(/*index=*/0, rewrittenTarget->getAsExpr()); - makeIteratorVar->setInitializerChecked(/*index=*/0); - } + auto rewrittenTarget = rewriteTarget(makeIteratorTarget); + if (!rewrittenTarget) + return llvm::None; - stmt->setIteratorVar(makeIteratorVar); + // Set type-checked initializer and mark it as such. + { + makeIteratorVar->setInit(/*index=*/0, rewrittenTarget->getAsExpr()); + makeIteratorVar->setInitializerChecked(/*index=*/0); } + stmt->setIteratorVar(makeIteratorVar); + // Now, `$iterator.next()` call. { - auto nextTarget = *cs.getTargetFor(forEachStmtInfo.nextCall); + auto nextTarget = *cs.getTargetFor(info.nextCall); auto rewrittenTarget = rewriteTarget(nextTarget); if (!rewrittenTarget) @@ -9189,7 +9178,7 @@ static llvm::Optional applySolutionToForEachStmt( ShouldStop = true; auto nextRefType = - S.getResolvedType(call->getFn())->castTo(); + S.getResolvedType(call->getFn())->castTo(); // If the inferred witness is throwing, we need to wrap the call // into `try` expression. @@ -9212,26 +9201,43 @@ static llvm::Optional applySolutionToForEachStmt( stmt->setNextCall(nextCall); } - // Coerce the pattern to the element type. - { - TypeResolutionOptions options(TypeResolverContext::ForEachStmt); - options |= TypeResolutionFlags::OverrideType; - - auto tryRewritePattern = [&](Pattern *EP, Type ty) { - return ::tryRewriteExprPattern(EP, solution, ty, rewriteTarget); - }; - - // Apply the solution to the pattern as well. - auto contextualPattern = target.getContextualPattern(); - auto coercedPattern = TypeChecker::coercePatternToType( - contextualPattern, forEachStmtInfo.initType, options, - tryRewritePattern); - if (!coercedPattern) + // Convert that llvm::Optional value to the type of the pattern. + auto optPatternType = OptionalType::get(info.initType); + Type nextResultType = OptionalType::get(info.elementType); + if (!optPatternType->isEqual(nextResultType)) { + ASTContext &ctx = cs.getASTContext(); + OpaqueValueExpr *elementExpr = new (ctx) OpaqueValueExpr( + stmt->getInLoc(), nextResultType->getOptionalObjectType(), + /*isPlaceholder=*/true); + Expr *convertElementExpr = elementExpr; + if (TypeChecker::typeCheckExpression(convertElementExpr, dc, + /*contextualInfo=*/ + {info.initType, CTP_CoerceOperand}) + .isNull()) { return llvm::None; + } + elementExpr->setIsPlaceholder(false); + stmt->setElementExpr(elementExpr); + stmt->setConvertElementExpr(convertElementExpr); + } - stmt->setPattern(coercedPattern); - resultTarget.setPattern(coercedPattern); + // Get the conformance of the sequence type to the Sequence protocol. + auto sequenceProto = TypeChecker::getProtocol( + cs.getASTContext(), stmt->getForLoc(), + stmt->getAwaitLoc().isValid() ? KnownProtocolKind::AsyncSequence + : KnownProtocolKind::Sequence); + + auto type = info.sequenceType->getRValueType(); + if (type->isExistentialType()) { + auto *contextualLoc = solution.getConstraintLocator( + parsedSequence, LocatorPathElt::ContextualType(CTP_ForEachSequence)); + type = Type(solution.OpenedExistentialTypes[contextualLoc]); } + auto sequenceConformance = TypeChecker::conformsToProtocol( + type, sequenceProto, dc->getParentModule()); + assert(!sequenceConformance.isInvalid() && + "Couldn't find sequence conformance"); + stmt->setSequenceConformance(sequenceConformance); // Apply the solution to the filtering condition, if there is one. if (auto *whereExpr = stmt->getWhere()) { @@ -9244,44 +9250,118 @@ static llvm::Optional applySolutionToForEachStmt( stmt->setWhere(rewrittenTarget->getAsExpr()); } - // Convert that llvm::Optional value to the type of the pattern. - auto optPatternType = OptionalType::get(forEachStmtInfo.initType); - Type nextResultType = OptionalType::get(forEachStmtInfo.elementType); - if (!optPatternType->isEqual(nextResultType)) { - ASTContext &ctx = cs.getASTContext(); - OpaqueValueExpr *elementExpr = new (ctx) OpaqueValueExpr( - stmt->getInLoc(), nextResultType->getOptionalObjectType(), - /*isPlaceholder=*/true); - Expr *convertElementExpr = elementExpr; - if (TypeChecker::typeCheckExpression( - convertElementExpr, dc, - /*contextualInfo=*/{forEachStmtInfo.initType, CTP_CoerceOperand}) - .isNull()) { + return info; +} + +static llvm::Optional applySolutionToForEachStmt( + Solution &solution, ForEachStmt *stmt, PackIterationInfo info, + llvm::function_ref< + llvm::Optional(SyntacticElementTarget)> + rewriteTarget) { + + // A special walker to record opened element environment for var decls in a + // for-each loop. + class Walker : public ASTWalker { + GenericEnvironment *Environment; + + public: + Walker(GenericEnvironment *Environment) { this->Environment = Environment; } + + PreWalkResult walkToStmtPre(Stmt *S) override { + if (isa(S)) { + return Action::SkipChildren(S); + } + return Action::Continue(S); + } + + PreWalkAction walkToDeclPre(Decl *D) override { + if (auto *decl = dyn_cast(D)) { + decl->setOpenedElementEnvironment(Environment); + } + if (isa(D)) { + return Action::SkipChildren(); + } + if (isa(D)) { + return Action::SkipChildren(); + } + return Action::Continue(); + } + }; + + auto &cs = solution.getConstraintSystem(); + auto *sequenceExpr = stmt->getParsedSequence(); + PackExpansionExpr *expansion = cast(sequenceExpr); + + // First, let's apply the solution to the pack expansion. + auto makeExpansionTarget = *cs.getTargetFor(expansion); + auto rewrittenTarget = rewriteTarget(makeExpansionTarget); + if (!rewrittenTarget) + return llvm::None; + + // Simplify the pattern type of the pack expansion. + info.patternType = solution.simplifyType(info.patternType); + + // Record the opened element environment for the VarDecls inside the loop + Walker forEachWalker(expansion->getGenericEnvironment()); + stmt->getPattern()->walk(forEachWalker); + stmt->getBody()->walk(forEachWalker); + + return info; +} + +/// Apply the given solution to the for-each statement target. +/// +/// \returns the resulting initialization expression. +static llvm::Optional applySolutionToForEachStmt( + Solution &solution, SyntacticElementTarget target, + llvm::function_ref< + llvm::Optional(SyntacticElementTarget)> + rewriteTarget) { + auto resultTarget = target; + auto &forEachStmtInfo = resultTarget.getForEachStmtInfo(); + auto *stmt = target.getAsForEachStmt(); + + Type rewrittenPatternType; + + if (auto *info = forEachStmtInfo.dyn_cast()) { + auto resultInfo = applySolutionToForEachStmt( + solution, stmt, *info, target.getDeclContext(), rewriteTarget); + if (!resultInfo) { return llvm::None; } - elementExpr->setIsPlaceholder(false); - stmt->setElementExpr(elementExpr); - stmt->setConvertElementExpr(convertElementExpr); + + forEachStmtInfo = *resultInfo; + rewrittenPatternType = resultInfo->initType; + } else { + auto resultInfo = applySolutionToForEachStmt( + solution, stmt, forEachStmtInfo.get(), + rewriteTarget); + if (!resultInfo) { + return llvm::None; + } + + forEachStmtInfo = *resultInfo; + rewrittenPatternType = resultInfo->patternType; } - // Get the conformance of the sequence type to the Sequence protocol. + // Coerce the pattern to the element type. { - auto sequenceProto = TypeChecker::getProtocol( - cs.getASTContext(), stmt->getForLoc(), - stmt->getAwaitLoc().isValid() ? KnownProtocolKind::AsyncSequence - : KnownProtocolKind::Sequence); - - auto type = forEachStmtInfo.sequenceType->getRValueType(); - if (type->isExistentialType()) { - auto *contextualLoc = solution.getConstraintLocator( - parsedSequence, LocatorPathElt::ContextualType(CTP_ForEachSequence)); - type = Type(solution.OpenedExistentialTypes[contextualLoc]); - } - auto sequenceConformance = TypeChecker::conformsToProtocol( - type, sequenceProto, dc->getParentModule()); - assert(!sequenceConformance.isInvalid() && - "Couldn't find sequence conformance"); - stmt->setSequenceConformance(sequenceConformance); + TypeResolutionOptions options(TypeResolverContext::ForEachStmt); + options |= TypeResolutionFlags::OverrideType; + + auto tryRewritePattern = [&](Pattern *EP, Type ty) { + return ::tryRewriteExprPattern(EP, solution, ty, rewriteTarget); + }; + + // Apply the solution to the pattern as well. + auto contextualPattern = target.getContextualPattern(); + auto coercedPattern = TypeChecker::coercePatternToType( + contextualPattern, rewrittenPatternType, options, tryRewritePattern); + if (!coercedPattern) + return llvm::None; + + stmt->setPattern(coercedPattern); + resultTarget.setPattern(coercedPattern); } return resultTarget; diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index 19ecfc5f81d90..3da7d149a5868 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -4438,17 +4438,51 @@ static bool generateInitPatternConstraints(ConstraintSystem &cs, return false; } -static llvm::Optional -generateForEachStmtConstraints(ConstraintSystem &cs, - SyntacticElementTarget target) { +/// Generate constraints for a for-in statement preamble, expecting a +/// `PackExpansionExpr`. +static llvm::Optional +generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc, + PackExpansionExpr *expansion, Type patternType) { + auto packIterationInfo = PackIterationInfo(); + auto elementLocator = cs.getConstraintLocator( + expansion, ConstraintLocator::SequenceElementType); + + { + SyntacticElementTarget target(expansion, dc, CTP_Unused, + /*contextualType=*/Type(), + /*isDiscarded=*/false); + + if (cs.generateConstraints(target)) + return llvm::None; + + cs.setTargetFor(expansion, target); + } + + auto elementType = cs.getType(expansion->getPatternExpr()); + + cs.addConstraint(ConstraintKind::Conversion, elementType, patternType, + elementLocator); + + packIterationInfo.patternType = patternType; + return packIterationInfo; +} + +/// Generate constraints for a for-in statement preamble, expecting an +/// expression that conforms to `Swift.Sequence`. +static llvm::Optional +generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc, + ForEachStmt *stmt, Pattern *typeCheckedPattern, + bool shouldBindPatternVarsOneWay, + bool ignoreForEachWhereClause) { ASTContext &ctx = cs.getASTContext(); - auto forEachStmtInfo = target.getForEachStmtInfo(); - ForEachStmt *stmt = target.getAsForEachStmt(); bool isAsync = stmt->getAwaitLoc().isValid(); auto *sequenceExpr = stmt->getParsedSequence(); - auto *dc = target.getDeclContext(); auto contextualLocator = cs.getConstraintLocator( sequenceExpr, LocatorPathElt::ContextualType(CTP_ForEachSequence)); + auto elementLocator = cs.getConstraintLocator( + sequenceExpr, ConstraintLocator::SequenceElementType); + + auto sequenceIterationInfo = SequenceIterationInfo(); // The expression type must conform to the Sequence protocol. auto sequenceProto = TypeChecker::getProtocol( @@ -4481,7 +4515,6 @@ generateForEachStmtConstraints(ConstraintSystem &cs, auto *makeIteratorCall = CallExpr::createImplicitEmpty(ctx, makeIteratorRef); - Pattern *pattern = NamedPattern::createImplicit(ctx, makeIteratorVar); auto *PB = PatternBindingDecl::createImplicit( ctx, StaticSpellingKind::None, pattern, makeIteratorCall, dc); @@ -4497,8 +4530,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs, if (cs.generateConstraints(makeIteratorTarget)) return llvm::None; - forEachStmtInfo.makeIteratorVar = PB; - + sequenceIterationInfo.makeIteratorVar = PB; // Type of sequence expression has to conform to Sequence protocol. // // Note that the following emulates having `$generator` separately @@ -4517,7 +4549,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs, sequenceProto->getDeclaredInterfaceType(), contextualLocator); - forEachStmtInfo.sequenceType = cs.getType(sequenceExpr); + sequenceIterationInfo.sequenceType = cs.getType(sequenceExpr); } cs.setTargetFor({PB, /*index=*/0}, makeIteratorTarget); @@ -4561,27 +4593,14 @@ generateForEachStmtConstraints(ConstraintSystem &cs, if (cs.generateConstraints(nextTarget, FreeTypeVariableBinding::Disallow)) return llvm::None; - forEachStmtInfo.nextCall = nextTarget.getAsExpr(); - cs.setTargetFor(forEachStmtInfo.nextCall, nextTarget); + sequenceIterationInfo.nextCall = nextTarget.getAsExpr(); + cs.setTargetFor(sequenceIterationInfo.nextCall, nextTarget); } - Pattern *pattern = TypeChecker::resolvePattern(stmt->getPattern(), dc, - /*isStmtCondition*/ false); - if (!pattern) - return llvm::None; - - auto contextualPattern = ContextualPattern::forRawPattern(pattern, dc); - Type patternType = TypeChecker::typeCheckPattern(contextualPattern); - if (patternType->hasError()) { - return llvm::None; - } - - // Collect constraints from the element pattern. - auto elementLocator = cs.getConstraintLocator( - sequenceExpr, ConstraintLocator::SequenceElementType); + // Generate constraints for the pattern Type initType = - cs.generateConstraints(pattern, elementLocator, - target.shouldBindPatternVarsOneWay(), nullptr, 0); + cs.generateConstraints(typeCheckedPattern, elementLocator, + shouldBindPatternVarsOneWay, nullptr, 0); if (!initType) return llvm::None; @@ -4592,7 +4611,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs, auto elementType = cs.createTypeVariable(elementTypeLoc, /*flags=*/0); { - auto nextType = cs.getType(forEachStmtInfo.nextCall); + auto nextType = cs.getType(sequenceIterationInfo.nextCall); cs.addConstraint(ConstraintKind::OptionalObject, nextType, elementType, elementTypeLoc); cs.addConstraint(ConstraintKind::Conversion, elementType, initType, @@ -4601,7 +4620,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs, // Generate constraints for the "where" expression, if there is one. auto *whereExpr = stmt->getWhere(); - if (whereExpr && !target.ignoreForEachWhereClause()) { + if (whereExpr && !ignoreForEachWhereClause) { Type boolType = dc->getASTContext().getBoolType(); if (!boolType) return llvm::None; @@ -4618,10 +4637,61 @@ generateForEachStmtConstraints(ConstraintSystem &cs, } // Populate all of the information for a for-each loop. - forEachStmtInfo.elementType = elementType; - forEachStmtInfo.initType = initType; + sequenceIterationInfo.elementType = elementType; + sequenceIterationInfo.initType = initType; + + return sequenceIterationInfo; +} + +static llvm::Optional +generateForEachStmtConstraints(ConstraintSystem &cs, + SyntacticElementTarget target) { + ForEachStmt *stmt = target.getAsForEachStmt(); + auto *sequenceExpr = stmt->getParsedSequence(); + auto *dc = target.getDeclContext(); + + auto elementLocator = cs.getConstraintLocator( + sequenceExpr, ConstraintLocator::SequenceElementType); + + Pattern *pattern = TypeChecker::resolvePattern(stmt->getPattern(), dc, + /*isStmtCondition*/ false); + if (!pattern) + return llvm::None; target.setPattern(pattern); - target.getForEachStmtInfo() = forEachStmtInfo; + + auto contextualPattern = ContextualPattern::forRawPattern(pattern, dc); + + if (TypeChecker::typeCheckPattern(contextualPattern)->hasError()) { + return llvm::None; + } + + if (isa(sequenceExpr)) { + auto *expansion = cast(sequenceExpr); + + // Generate constraints for the pattern + Type patternType = cs.generateConstraints( + pattern, elementLocator, target.shouldBindPatternVarsOneWay(), nullptr, + 0); + if (!patternType) + return llvm::None; + + auto packIterationInfo = + generateForEachStmtConstraints(cs, dc, expansion, patternType); + if (!packIterationInfo) { + return llvm::None; + } + + target.getForEachStmtInfo() = *packIterationInfo; + } else { + auto sequenceIterationInfo = generateForEachStmtConstraints( + cs, dc, stmt, pattern, target.shouldBindPatternVarsOneWay(), + target.ignoreForEachWhereClause()); + if (!sequenceIterationInfo) { + return llvm::None; + } + + target.getForEachStmtInfo() = *sequenceIterationInfo; + } return target; } diff --git a/test/stmt/pack_iteration.swift b/test/stmt/pack_iteration.swift new file mode 100644 index 0000000000000..7591b1abb2cad --- /dev/null +++ b/test/stmt/pack_iteration.swift @@ -0,0 +1,11 @@ +// RUN: %target-typecheck-verify-swift + +// Test the AST Verifier assertion for the VarDecl getting its Generic +// Environment. +func variadic(ts: repeat each T) { + for t in repeat each ts { + func inner() {} + let y = t + // expected-warning@-1{{initialization of immutable value 'y' was never used; consider replacing with assignment to '_' or removing it}} + } +} From 42a06668dff59272887cef382e1d9908165113f0 Mon Sep 17 00:00:00 2001 From: Sima Nerush <2002ssn@gmail.com> Date: Sat, 2 Dec 2023 15:26:54 -0800 Subject: [PATCH 02/10] CSGen: Rename `sequenceExpr` --- lib/Sema/CSGen.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index 3da7d149a5868..48a9aec8c8e25 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -4647,11 +4647,11 @@ static llvm::Optional generateForEachStmtConstraints(ConstraintSystem &cs, SyntacticElementTarget target) { ForEachStmt *stmt = target.getAsForEachStmt(); - auto *sequenceExpr = stmt->getParsedSequence(); + auto *forEachExpr = stmt->getParsedSequence(); auto *dc = target.getDeclContext(); auto elementLocator = cs.getConstraintLocator( - sequenceExpr, ConstraintLocator::SequenceElementType); + forEachExpr, ConstraintLocator::SequenceElementType); Pattern *pattern = TypeChecker::resolvePattern(stmt->getPattern(), dc, /*isStmtCondition*/ false); @@ -4665,8 +4665,8 @@ generateForEachStmtConstraints(ConstraintSystem &cs, return llvm::None; } - if (isa(sequenceExpr)) { - auto *expansion = cast(sequenceExpr); + if (isa(forEachExpr)) { + auto *expansion = cast(forEachExpr); // Generate constraints for the pattern Type patternType = cs.generateConstraints( From 6ab831043d98e19fffba86273566788692d2d338 Mon Sep 17 00:00:00 2001 From: Sima Nerush <2002ssn@gmail.com> Date: Tue, 10 Oct 2023 11:42:34 -0700 Subject: [PATCH 03/10] SIL --- lib/SILGen/SILGenStmt.cpp | 32 ++++++++++-- test/SILGen/pack_iteration.swift | 85 ++++++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+), 3 deletions(-) create mode 100644 test/SILGen/pack_iteration.swift diff --git a/lib/SILGen/SILGenStmt.cpp b/lib/SILGen/SILGenStmt.cpp index 24cbec3c29015..bb6948fed028f 100644 --- a/lib/SILGen/SILGenStmt.cpp +++ b/lib/SILGen/SILGenStmt.cpp @@ -1248,11 +1248,37 @@ void StmtEmitter::visitRepeatWhileStmt(RepeatWhileStmt *S) { void StmtEmitter::visitForEachStmt(ForEachStmt *S) { // Emit the 'iterator' variable that we'll be using for iteration. LexicalScope OuterForScope(SGF, CleanupLocation(S)); - { - SGF.emitPatternBinding(S->getIteratorVar(), - /*index=*/0, /*debuginfo*/ true); + + if (auto *expansion = + dyn_cast(S->getTypeCheckedSequence())) { + auto formalPackType = dyn_cast( + PackType::get(SGF.getASTContext(), expansion->getType()) + ->getCanonicalType()); + + // Create a new basic block and jump into it. + JumpDest loopDest = createJumpDest(S->getBody()); + SGF.B.emitBlock(loopDest.getBlock(), S); + + SGF.emitDynamicPackLoop( + SILLocation(expansion), formalPackType, 0, + expansion->getGenericEnvironment(), + [&](SILValue indexWithinComponent, SILValue packExpansionIndex, + SILValue packIndex) { + Scope innerForScope(SGF.Cleanups, CleanupLocation(S->getBody())); + auto letValueInit = + SGF.emitPatternBindingInitialization(S->getPattern(), loopDest); + + SGF.emitExprInto(expansion->getPatternExpr(), letValueInit.get()); + visit(S->getBody()); + return; + }); + + return; } + SGF.emitPatternBinding(S->getIteratorVar(), + /*index=*/0, /*debuginfo*/ true); + // If we ever reach an unreachable point, stop emitting statements. // This will need revision if we ever add goto. if (!SGF.B.hasValidInsertionPoint()) return; diff --git a/test/SILGen/pack_iteration.swift b/test/SILGen/pack_iteration.swift new file mode 100644 index 0000000000000..7f7c44938e725 --- /dev/null +++ b/test/SILGen/pack_iteration.swift @@ -0,0 +1,85 @@ + +// RUN: %target-swift-emit-silgen -module-name pack_iteration %s | %FileCheck %s + +////////////////// +// Declarations // +////////////////// +@_silgen_name("loopBodyEnd") +func loopBodyEnd() -> () + +@_silgen_name("funcEnd") +func funcEnd() -> () + +enum E { + case one(T) + case two +} + +////////////// +// Tests // +/////////// + +// CHECK-LABEL: sil hidden [ossa] @$s14pack_iteration14iterateTrivial4overyxxQp_tRvzlF : $@convention(thin) (@pack_guaranteed Pack{repeat each Element}) -> () { +// CHECK: bb0([[PACK:%.*]] : $*Pack{repeat each Element}): +// CHECK: [[IDX1:%.*]] = integer_literal $Builtin.Word, 0 +// CHECK: [[IDX2:%.*]] = integer_literal $Builtin.Word, 1 +// CHECK: [[PACK_LENGTH:%.*]] = pack_length $Pack{repeat each Element} +// CHECK: br [[LOOP_DEST:bb[0-9]+]]([[IDX1]] : $Builtin.Word) +// +// CHECK: [[LOOP_DEST]]([[IDX3:%.*]] : $Builtin.Word): +// CHECK: [[COND:%.*]] = builtin "cmp_eq_Word"([[IDX3]] : $Builtin.Word, [[PACK_LENGTH]] : $Builtin.Word) : $Builtin.Int1 +// CHECK: cond_br [[COND]], [[NONE_BB:bb[0-9]+]], [[SOME_BB:bb[0-9]+]] +// +// CHECK: [[SOME_BB]]: +// CHECK: [[DYN_PACK_IDX:%.*]] = dynamic_pack_index [[IDX3]] of $Pack{repeat each Element} +// CHECK: open_pack_element [[DYN_PACK_IDX]] of at , shape $each Element, uuid "[[UUID:.*]]" +// CHECK: [[STACK:%.*]] = alloc_stack [lexical] $@pack_element("[[UUID]]") each Element, let, name "element" +// CHECK: [[PACK_ELT_GET:%.*]] = pack_element_get [[DYN_PACK_IDX]] of [[PACK]] : $*Pack{repeat each Element} as $*@pack_element("[[UUID]]") each Element +// CHECK: copy_addr [[PACK_ELT_GET]] to [init] [[STACK]] : $*@pack_element("[[UUID]]") each Element +// CHECK: [[LOOP_END_FUNC:%.*]] = function_ref @loopBodyEnd : $@convention(thin) () -> () +// CHECK: apply [[LOOP_END_FUNC]]() : $@convention(thin) () -> () +// CHECK: destroy_addr [[STACK]] : $*@pack_element("[[UUID]]") each Element +// CHECK: dealloc_stack [[STACK]] : $*@pack_element("[[UUID]]") each Element +// CHECK: [[IDX4:%.*]] = builtin "add_Word"([[IDX3]] : $Builtin.Word, [[IDX2]] : $Builtin.Word) : $Builtin.Word +// CHECK: br [[LOOP_DEST]]([[IDX4]] : $Builtin.Word) +// +// CHECK: [[NONE_BB]]: +// CHECK: [[FUNC_END_FUNC:%.*]] = function_ref @funcEnd : $@convention(thin) () -> () +// CHECK: apply [[FUNC_END_FUNC]]() : $@convention(thin) () -> () +// CHECK: } // end sil function '$s14pack_iteration14iterateTrivial4overyxxQp_tRvzlF' +func iterateTrivial(over elements: repeat each Element) { + for element in repeat each elements { + loopBodyEnd() + } + funcEnd() +} + +// TODO: Write this test +func equalTuples(lhs: (repeat each Element), rhs: (repeat each Element)) -> Bool { + +// %12 = dynamic_pack_index %9 of $Pack{repeat (each Element, each Element)} // users: %19, %17, %14, %13 +// %13 = open_pack_element %12 of at , shape $each Element, uuid "E53D635E-3D89-11EE-82A2-7AABAFDC7DCA" // users: %19, %17, %14 +// %14 = tuple_pack_element_addr %12 of %4 : $*(repeat (each Element, each Element)) as $*(@pack_element("E53D635E-3D89-11EE-82A2-7AABAFDC7DCA") each Element, @pack_element("E53D635E-3D89-11EE-82A2-7AABAFDC7DCA") each Element) // users: %16, %15 +// %15 = tuple_element_addr %14 : $*(@pack_element("E53D635E-3D89-11EE-82A2-7AABAFDC7DCA") each Element, @pack_element("E53D635E-3D89-11EE-82A2-7AABAFDC7DCA") each Element), 0 // user: %18 +// %16 = tuple_element_addr %14 : $*(@pack_element("E53D635E-3D89-11EE-82A2-7AABAFDC7DCA") each Element, @pack_element("E53D635E-3D89-11EE-82A2-7AABAFDC7DCA") each Element), 1 // user: %20 +// %17 = pack_element_get %12 of %0 : $*Pack{repeat each Element} as $*@pack_element("E53D635E-3D89-11EE-82A2-7AABAFDC7DCA") each Element // user: %18 +// copy_addr %17 to [init] %15 : $*@pack_element("E53D635E-3D89-11EE-82A2-7AABAFDC7DCA") each Element // id: %18 +// %19 = pack_element_get %12 of %1 : $*Pack{repeat each Element} as $*@pack_element("E53D635E-3D89-11EE-82A2-7AABAFDC7DCA") each Element // user: %20 +// copy_addr %19 to [init] %16 : $*@pack_element("E53D635E-3D89-11EE-82A2-7AABAFDC7DCA") each Element // id: %20 +// %21 = builtin "add_Word"(%9 : $Builtin.Word, %6 : $Builtin.Word) : $Builtin.Word // user: %22 +// br bb1(%21 : $Builtin.Word) + + for (left, right) in repeat (each lhs, each rhs) { + guard left == right else { return false } + } + return true +} + +// TODO: Write this test +func iteratePatternMatch(over element: repeat E) { + for case .one(let value) in repeat each element { + print(value) + } +} + + From af56beb60eb25835c141bb9f8c67785d280aaab9 Mon Sep 17 00:00:00 2001 From: Sima Nerush <2002ssn@gmail.com> Date: Sun, 3 Sep 2023 20:34:46 -0700 Subject: [PATCH 04/10] Add `for case` support and tests --- lib/SILGen/SILGenFunction.h | 32 +++++------ lib/SILGen/SILGenPack.cpp | 43 ++++++++------ lib/SILGen/SILGenStmt.cpp | 5 +- test/SILGen/pack_iteration.swift | 96 +++++++++++++++++++++++++------- 4 files changed, 119 insertions(+), 57 deletions(-) diff --git a/lib/SILGen/SILGenFunction.h b/lib/SILGen/SILGenFunction.h index 35b57c6587bbe..5ee1649356e84 100644 --- a/lib/SILGen/SILGenFunction.h +++ b/lib/SILGen/SILGenFunction.h @@ -2724,26 +2724,24 @@ class LLVM_LIBRARY_VISIBILITY SILGenFunction /// /// This function will be called within a cleanups scope and with /// InnermostPackExpansion set up properly for the context. - void emitDynamicPackLoop(SILLocation loc, - CanPackType formalPackType, - unsigned componentIndex, - SILValue startingAfterIndexWithinComponent, - SILValue limitWithinComponent, - GenericEnvironment *openedElementEnv, - bool reverse, - llvm::function_ref emitBody); + void emitDynamicPackLoop( + SILLocation loc, CanPackType formalPackType, unsigned componentIndex, + SILValue startingAfterIndexWithinComponent, SILValue limitWithinComponent, + GenericEnvironment *openedElementEnv, bool reverse, + llvm::function_ref + emitBody, + SILBasicBlock *loopLatch = nullptr); /// A convenience version of dynamic pack loop that visits an entire /// pack expansion component in forward order. - void emitDynamicPackLoop(SILLocation loc, - CanPackType formalPackType, - unsigned componentIndex, - GenericEnvironment *openedElementEnv, - llvm::function_ref emitBody); + void emitDynamicPackLoop( + SILLocation loc, CanPackType formalPackType, unsigned componentIndex, + GenericEnvironment *openedElementEnv, + llvm::function_ref + emitBody, + SILBasicBlock *loopLatch = nullptr); /// Emit a transform on each element of a pack-expansion component /// of a pack, write the result into a pack-expansion component of diff --git a/lib/SILGen/SILGenPack.cpp b/lib/SILGen/SILGenPack.cpp index 5bcdd172bd6cd..de9efac5facb2 100644 --- a/lib/SILGen/SILGenPack.cpp +++ b/lib/SILGen/SILGenPack.cpp @@ -662,28 +662,27 @@ void SILGenFunction::projectTupleElementsToPack(SILLocation loc, }); } -void SILGenFunction::emitDynamicPackLoop(SILLocation loc, - CanPackType formalPackType, - unsigned componentIndex, - GenericEnvironment *openedElementEnv, - llvm::function_ref emitBody) { +void SILGenFunction::emitDynamicPackLoop( + SILLocation loc, CanPackType formalPackType, unsigned componentIndex, + GenericEnvironment *openedElementEnv, + llvm::function_ref + emitBody, + SILBasicBlock *loopLatch) { return emitDynamicPackLoop(loc, formalPackType, componentIndex, /*startAfter*/ SILValue(), /*limit*/ SILValue(), - openedElementEnv, /*reverse*/false, emitBody); + openedElementEnv, /*reverse*/ false, emitBody, + loopLatch); } -void SILGenFunction::emitDynamicPackLoop(SILLocation loc, - CanPackType formalPackType, - unsigned componentIndex, - SILValue startingAfterIndexInComponent, - SILValue limitWithinComponent, - GenericEnvironment *openedElementEnv, - bool reverse, - llvm::function_ref emitBody) { +void SILGenFunction::emitDynamicPackLoop( + SILLocation loc, CanPackType formalPackType, unsigned componentIndex, + SILValue startingAfterIndexInComponent, SILValue limitWithinComponent, + GenericEnvironment *openedElementEnv, bool reverse, + llvm::function_ref + emitBody, + SILBasicBlock *loopLatch) { assert(isa(formalPackType.getElementType(componentIndex))); assert((!startingAfterIndexInComponent || !reverse) && "cannot reverse with a starting index"); @@ -764,6 +763,7 @@ void SILGenFunction::emitDynamicPackLoop(SILLocation loc, // the incoming index - 1 if reverse) SILValue curIndex = incomingIndex; if (reverse) { + assert(!loopLatch && "Only forward iteration supported with loop latch"); curIndex = B.createBuiltinBinaryFunction(loc, "sub", wordTy, wordTy, { incomingIndex, one }); } @@ -791,6 +791,13 @@ void SILGenFunction::emitDynamicPackLoop(SILLocation loc, { FullExpr scope(Cleanups, CleanupLocation(loc)); emitBody(curIndex, packExpansionIndex, packIndex); + if (loopLatch) { + B.createBranch(loc, loopLatch); + } + } + + if (loopLatch) { + B.emitBlock(loopLatch); } // The index to pass to the loop condition block (the current index + 1 diff --git a/lib/SILGen/SILGenStmt.cpp b/lib/SILGen/SILGenStmt.cpp index bb6948fed028f..f9820e13a4adb 100644 --- a/lib/SILGen/SILGenStmt.cpp +++ b/lib/SILGen/SILGenStmt.cpp @@ -1255,9 +1255,7 @@ void StmtEmitter::visitForEachStmt(ForEachStmt *S) { PackType::get(SGF.getASTContext(), expansion->getType()) ->getCanonicalType()); - // Create a new basic block and jump into it. JumpDest loopDest = createJumpDest(S->getBody()); - SGF.B.emitBlock(loopDest.getBlock(), S); SGF.emitDynamicPackLoop( SILLocation(expansion), formalPackType, 0, @@ -1271,7 +1269,8 @@ void StmtEmitter::visitForEachStmt(ForEachStmt *S) { SGF.emitExprInto(expansion->getPatternExpr(), letValueInit.get()); visit(S->getBody()); return; - }); + }, + loopDest.getBlock()); return; } diff --git a/test/SILGen/pack_iteration.swift b/test/SILGen/pack_iteration.swift index 7f7c44938e725..734fd908e9a24 100644 --- a/test/SILGen/pack_iteration.swift +++ b/test/SILGen/pack_iteration.swift @@ -30,6 +30,10 @@ enum E { // CHECK: [[COND:%.*]] = builtin "cmp_eq_Word"([[IDX3]] : $Builtin.Word, [[PACK_LENGTH]] : $Builtin.Word) : $Builtin.Int1 // CHECK: cond_br [[COND]], [[NONE_BB:bb[0-9]+]], [[SOME_BB:bb[0-9]+]] // +// CHECK: [[NONE_BB]]: +// CHECK: [[FUNC_END_FUNC:%.*]] = function_ref @funcEnd : $@convention(thin) () -> () +// CHECK: apply [[FUNC_END_FUNC]]() : $@convention(thin) () -> () +// // CHECK: [[SOME_BB]]: // CHECK: [[DYN_PACK_IDX:%.*]] = dynamic_pack_index [[IDX3]] of $Pack{repeat each Element} // CHECK: open_pack_element [[DYN_PACK_IDX]] of at , shape $each Element, uuid "[[UUID:.*]]" @@ -43,9 +47,6 @@ enum E { // CHECK: [[IDX4:%.*]] = builtin "add_Word"([[IDX3]] : $Builtin.Word, [[IDX2]] : $Builtin.Word) : $Builtin.Word // CHECK: br [[LOOP_DEST]]([[IDX4]] : $Builtin.Word) // -// CHECK: [[NONE_BB]]: -// CHECK: [[FUNC_END_FUNC:%.*]] = function_ref @funcEnd : $@convention(thin) () -> () -// CHECK: apply [[FUNC_END_FUNC]]() : $@convention(thin) () -> () // CHECK: } // end sil function '$s14pack_iteration14iterateTrivial4overyxxQp_tRvzlF' func iterateTrivial(over elements: repeat each Element) { for element in repeat each elements { @@ -54,32 +55,89 @@ func iterateTrivial(over elements: repeat each Element) { funcEnd() } -// TODO: Write this test +// CHECK-LABEL: sil hidden [ossa] @$s14pack_iteration11equalTuples3lhs3rhsSbxxQp_t_xxQp_ttRvzSQRzlF : $@convention(thin) (@pack_guaranteed Pack{repeat each Element}, @pack_guaranteed Pack{repeat each Element}) -> Bool { +// CHECK: bb6: +// CHECK: [[STACK1:%.*]] = alloc_stack $(repeat each Element) +// CHECK: [[STACK2:%.*]] = alloc_stack $(repeat each Element) +// CHECK: [[IDX1:%.*]] = integer_literal $Builtin.Word, 0 +// CHECK: [[IDX2:%.*]] = integer_literal $Builtin.Word, 1 +// CHECK: [[PACK_LENGTH:%.*]] = pack_length $Pack{repeat each Element} +// CHECK: br [[LOOP_DEST:bb[0-9]+]]([[IDX1]] : $Builtin.Word) +// +// CHECK: [[LOOP_DEST]]([[IDX3:%.*]] : $Builtin.Word): +// CHECK: [[COND:%.*]] = builtin "cmp_eq_Word"([[IDX3]] : $Builtin.Word, [[PACK_LENGTH]] : $Builtin.Word) : $Builtin.Int1 +// CHECK: cond_br [[COND]], [[NONE_BB:bb[0-9]+]], [[SOME_BB:bb[0-9]+]] +// +// CHECK: [[SOME_BB]]: +// CHECK: [[DYN_PACK_IDX:%.*]] = dynamic_pack_index [[IDX3]] of $Pack{repeat (each Element, each Element)} +// CHECK: [[OPEN_PACK_ELT:%.*]] = open_pack_element [[DYN_PACK_IDX]] of at , shape $each Element, uuid "[[UUID:.*]]" +// CHECK: [[STACK_LEFT:%.*]] = alloc_stack [lexical] $@pack_element("[[UUID]]") each Element, let, name "left" +// CHECK: [[STACK_RIGHT:%.*]] = alloc_stack [lexical] $@pack_element("[[UUID]]") each Element, let, name "right" +// CHECK: tuple_pack_element_addr [[DYN_PACK_IDX]] of [[STACK1]] : $*(repeat each Element) as $*@pack_element("[[UUID]]") each Element +// CHECK: tuple_pack_element_addr [[DYN_PACK_IDX]] of [[STACK2]] : $*(repeat each Element) as $*@pack_element("[[UUID]]") each Element +// CHECK: [[METATYPE:%.*]] = metatype $@thick (@pack_element("[[UUID]]") each Element).Type +// CHECK: [[WITNESS_METHOD:%.*]] = witness_method $@pack_element("[[UUID]]") each Element, #Equatable."==" : (Self.Type) -> (Self, Self) -> Bool, [[OPEN_PACK_ELT]] : $Builtin.SILToken : $@convention(witness_method: Equatable) <τ_0_0 where τ_0_0 : Equatable> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> Bool +// CHECK: apply [[WITNESS_METHOD]]<@pack_element("[[UUID]]") each Element>([[STACK_LEFT]], [[STACK_RIGHT]], [[METATYPE]]) : $@convention(witness_method: Equatable) <τ_0_0 where τ_0_0 : Equatable> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> Bool +// +// CHECK: } // end sil function '$s14pack_iteration11equalTuples3lhs3rhsSbxxQp_t_xxQp_ttRvzSQRzlF' func equalTuples(lhs: (repeat each Element), rhs: (repeat each Element)) -> Bool { - -// %12 = dynamic_pack_index %9 of $Pack{repeat (each Element, each Element)} // users: %19, %17, %14, %13 -// %13 = open_pack_element %12 of at , shape $each Element, uuid "E53D635E-3D89-11EE-82A2-7AABAFDC7DCA" // users: %19, %17, %14 -// %14 = tuple_pack_element_addr %12 of %4 : $*(repeat (each Element, each Element)) as $*(@pack_element("E53D635E-3D89-11EE-82A2-7AABAFDC7DCA") each Element, @pack_element("E53D635E-3D89-11EE-82A2-7AABAFDC7DCA") each Element) // users: %16, %15 -// %15 = tuple_element_addr %14 : $*(@pack_element("E53D635E-3D89-11EE-82A2-7AABAFDC7DCA") each Element, @pack_element("E53D635E-3D89-11EE-82A2-7AABAFDC7DCA") each Element), 0 // user: %18 -// %16 = tuple_element_addr %14 : $*(@pack_element("E53D635E-3D89-11EE-82A2-7AABAFDC7DCA") each Element, @pack_element("E53D635E-3D89-11EE-82A2-7AABAFDC7DCA") each Element), 1 // user: %20 -// %17 = pack_element_get %12 of %0 : $*Pack{repeat each Element} as $*@pack_element("E53D635E-3D89-11EE-82A2-7AABAFDC7DCA") each Element // user: %18 -// copy_addr %17 to [init] %15 : $*@pack_element("E53D635E-3D89-11EE-82A2-7AABAFDC7DCA") each Element // id: %18 -// %19 = pack_element_get %12 of %1 : $*Pack{repeat each Element} as $*@pack_element("E53D635E-3D89-11EE-82A2-7AABAFDC7DCA") each Element // user: %20 -// copy_addr %19 to [init] %16 : $*@pack_element("E53D635E-3D89-11EE-82A2-7AABAFDC7DCA") each Element // id: %20 -// %21 = builtin "add_Word"(%9 : $Builtin.Word, %6 : $Builtin.Word) : $Builtin.Word // user: %22 -// br bb1(%21 : $Builtin.Word) - + for (left, right) in repeat (each lhs, each rhs) { guard left == right else { return false } } + return true } -// TODO: Write this test +// CHECK-LABEL: sil hidden [ossa] @$s14pack_iteration19iteratePatternMatch4overyAA1EOyxGxQp_tRvzlF : $@convention(thin) (@pack_guaranteed Pack{repeat E}) -> () { +// CHECK: bb0([[PACK:%.*]] : $*Pack{repeat E}): +// CHECK: [[IDX1:%.*]] = integer_literal $Builtin.Word, 0 +// CHECK: [[IDX2:%.*]] = integer_literal $Builtin.Word, 1 +// CHECK: [[PACK_LENGTH:%.*]] = pack_length $Pack{repeat each Element} +// CHECK: br [[LOOP_DEST:bb[0-9]+]]([[IDX1]] : $Builtin.Word) +// +// CHECK: [[LOOP_DEST]]([[IDX3:%.*]] : $Builtin.Word): +// CHECK: [[COND:%.*]] = builtin "cmp_eq_Word"([[IDX3]] : $Builtin.Word, [[PACK_LENGTH]] : $Builtin.Word) : $Builtin.Int1 +// CHECK: cond_br [[COND]], [[NONE_BB:bb[0-9]+]], [[SOME_BB:bb[0-9]+]] +// +// CHECK: [[SOME_BB]]: +// CHECK: [[DYN_PACK_IDX:%.*]] = dynamic_pack_index [[IDX3]] of $Pack{repeat E} +// CHECK: open_pack_element [[DYN_PACK_IDX]] of at , shape $each Element, uuid "[[UUID:.*]]" +// CHECK: [[STACK:%.*]] = alloc_stack [lexical] $@pack_element("[[UUID]]") each Element, let, name "value" +// CHECK: [[PACK_ELT_GET:%.*]] = pack_element_get [[DYN_PACK_IDX]] of [[PACK]] : $*Pack{repeat E} as $*E<@pack_element("[[UUID]]") each Element> +// CHECK: [[ENUM_STACK:%.*]] = alloc_stack $E<@pack_element("[[UUID]]") each Element> +// CHECK: copy_addr [[PACK_ELT_GET]] to [init] [[ENUM_STACK]] : $*E<@pack_element("[[UUID]]") each Element> +// CHECK: switch_enum_addr [[ENUM_STACK]] : $*E<@pack_element("[[UUID]]") each Element>, case #E.one!enumelt: [[ENUM_MATCH_BB:bb[0-9]+]], case #E.two!enumelt: [[CONTINUE_BB:bb[0-9]+]] +// +// CHECK: [[CONTINUE_BB]]: +// CHECK: destroy_addr [[ENUM_STACK]] +// CHECK: dealloc_stack [[ENUM_STACK]] +// CHECK: dealloc_stack [[STACK]] +// CHECK: br [[LATCH_BB:bb[0-9]+]] +// +// CHECK: [[ENUM_MATCH_BB]]: +// CHECK: [[ENUM_DATA_ADDR:%.*]] = unchecked_take_enum_data_addr %13 : $*E<@pack_element("[[UUID]]") each Element>, #E.one!enumelt +// CHECK: copy_addr [take] [[ENUM_DATA_ADDR]] to [init] [[STACK]] +// CHECK: [[LOOP_END_FUNC:%.*]] = function_ref @loopBodyEnd : $@convention(thin) () -> () +// CHECK: apply [[LOOP_END_FUNC]]() : $@convention(thin) () -> () +// CHECK: dealloc_stack [[ENUM_STACK]] +// CHECK: destroy_addr [[STACK]] +// CHECK: dealloc_stack [[STACK]] +// CHECK: br [[LATCH:bb[0-9]+]] +// +// CHECK: [[NONE_BB]]: +// CHECK: [[FUNC_END_FUNC:%.*]] = function_ref @funcEnd : $@convention(thin) () -> () +// CHECK: apply [[FUNC_END_FUNC]]() : $@convention(thin) () -> () +// +// CHECK: [[LATCH_BB]]: +// CHECK: [[ADD_WORD:%.*]] = builtin "add_Word"([[IDX3]] : $Builtin.Word, [[IDX2]] : $Builtin.Word) : $Builtin.Word +// CHECK: br [[LOOP_DEST]]([[ADD_WORD]] : $Builtin.Word) +// CHECK: } // end sil function '$s14pack_iteration19iteratePatternMatch4overyAA1EOyxGxQp_tRvzlF' func iteratePatternMatch(over element: repeat E) { for case .one(let value) in repeat each element { - print(value) + loopBodyEnd() } + funcEnd() } From be212badbd35c79f2dbff98659bbdb54f52dcaa9 Mon Sep 17 00:00:00 2001 From: Sima Nerush <2002ssn@gmail.com> Date: Tue, 5 Sep 2023 10:00:30 -0700 Subject: [PATCH 05/10] Add `break`/`continue` support --- lib/SILGen/SILGenStmt.cpp | 9 ++++ test/SILGen/pack_iteration.swift | 91 ++++++++++++++++++++++++++++++-- 2 files changed, 97 insertions(+), 3 deletions(-) diff --git a/lib/SILGen/SILGenStmt.cpp b/lib/SILGen/SILGenStmt.cpp index f9820e13a4adb..f1a0b1f190ee6 100644 --- a/lib/SILGen/SILGenStmt.cpp +++ b/lib/SILGen/SILGenStmt.cpp @@ -1257,6 +1257,9 @@ void StmtEmitter::visitForEachStmt(ForEachStmt *S) { JumpDest loopDest = createJumpDest(S->getBody()); + // Set the destinations for 'break' and 'continue'. + JumpDest endDest = createJumpDest(S->getBody()); + SGF.emitDynamicPackLoop( SILLocation(expansion), formalPackType, 0, expansion->getGenericEnvironment(), @@ -1267,11 +1270,17 @@ void StmtEmitter::visitForEachStmt(ForEachStmt *S) { SGF.emitPatternBindingInitialization(S->getPattern(), loopDest); SGF.emitExprInto(expansion->getPatternExpr(), letValueInit.get()); + + SGF.BreakContinueDestStack.push_back({S, endDest, loopDest}); visit(S->getBody()); + SGF.BreakContinueDestStack.pop_back(); + return; }, loopDest.getBlock()); + emitOrDeleteBlock(SGF, endDest, S); + return; } diff --git a/test/SILGen/pack_iteration.swift b/test/SILGen/pack_iteration.swift index 734fd908e9a24..78d73462804ea 100644 --- a/test/SILGen/pack_iteration.swift +++ b/test/SILGen/pack_iteration.swift @@ -10,6 +10,15 @@ func loopBodyEnd() -> () @_silgen_name("funcEnd") func funcEnd() -> () +@_silgen_name("condition") +func condition() -> Bool + +@_silgen_name("loopContinueEnd") +func loopContinueEnd() -> () + +@_silgen_name("loopBreakEnd") +func loopBreakEnd() -> () + enum E { case one(T) case two @@ -37,7 +46,7 @@ enum E { // CHECK: [[SOME_BB]]: // CHECK: [[DYN_PACK_IDX:%.*]] = dynamic_pack_index [[IDX3]] of $Pack{repeat each Element} // CHECK: open_pack_element [[DYN_PACK_IDX]] of at , shape $each Element, uuid "[[UUID:.*]]" -// CHECK: [[STACK:%.*]] = alloc_stack [lexical] $@pack_element("[[UUID]]") each Element, let, name "element" +// CHECK: [[STACK:%.*]] = alloc_stack [lexical] $@pack_element("[[UUID]]") each Element, let, name "el" // CHECK: [[PACK_ELT_GET:%.*]] = pack_element_get [[DYN_PACK_IDX]] of [[PACK]] : $*Pack{repeat each Element} as $*@pack_element("[[UUID]]") each Element // CHECK: copy_addr [[PACK_ELT_GET]] to [init] [[STACK]] : $*@pack_element("[[UUID]]") each Element // CHECK: [[LOOP_END_FUNC:%.*]] = function_ref @loopBodyEnd : $@convention(thin) () -> () @@ -48,8 +57,8 @@ enum E { // CHECK: br [[LOOP_DEST]]([[IDX4]] : $Builtin.Word) // // CHECK: } // end sil function '$s14pack_iteration14iterateTrivial4overyxxQp_tRvzlF' -func iterateTrivial(over elements: repeat each Element) { - for element in repeat each elements { +func iterateTrivial(over element: repeat each Element) { + for el in repeat each element { loopBodyEnd() } funcEnd() @@ -140,4 +149,80 @@ func iteratePatternMatch(over element: repeat E) { funcEnd() } +// CHECK-LABEL: sil hidden [ossa] @$s14pack_iteration20iterateContinueBreak4overyxxQp_tRvzlF : $@convention(thin) (@pack_guaranteed Pack{repeat each Element}) -> () { +// CHECK: bb0([[PACK:%.*]] : $*Pack{repeat each Element}): +// CHECK: [[IDX1:%.*]] = integer_literal $Builtin.Word, 0 +// CHECK: [[IDX2:%.*]] = integer_literal $Builtin.Word, 1 +// CHECK: [[PACK_LENGTH:%.*]] = pack_length $Pack{repeat each Element} +// CHECK: br [[LOOP_DEST:bb[0-9]+]]([[IDX1]] : $Builtin.Word) +// +// CHECK: [[LOOP_DEST]]([[IDX3:%.*]] : $Builtin.Word): +// CHECK: [[COND:%.*]] = builtin "cmp_eq_Word"([[IDX3]] : $Builtin.Word, [[PACK_LENGTH]] : $Builtin.Word) : $Builtin.Int1 +// CHECK: cond_br [[COND]], [[NONE_BB:bb[0-9]+]], [[SOME_BB:bb[0-9]+]] +// +// CHECK: [[SOME_BB]]: +// CHECK: [[DYN_PACK_IDX:%.*]] = dynamic_pack_index [[IDX3]] of $Pack{repeat each Element} +// CHECK: open_pack_element [[DYN_PACK_IDX]] of at , shape $each Element, uuid "[[UUID:.*]]" +// CHECK: [[STACK:%.*]] = alloc_stack [lexical] $@pack_element("[[UUID]]") each Element, let, name "el" +// CHECK: [[PACK_ELT_GET:%.*]] = pack_element_get [[DYN_PACK_IDX]] of [[PACK]] : $*Pack{repeat each Element} as $*@pack_element("[[UUID]]") each Element +// CHECK: copy_addr [[PACK_ELT_GET]] to [init] [[STACK]] : $*@pack_element("[[UUID]]") each Element +// CHECK: [[COND_FUNC:%.*]] = function_ref @condition : $@convention(thin) () -> Bool +// CHECK: [[BOOL:%.*]] = apply [[COND_FUNC]]() : $@convention(thin) () -> Bool +// CHECK: [[IF:%.*]] = struct_extract [[BOOL]] : $Bool, #Bool._value +// CHECK: cond_br [[IF]], [[LOOP_BREAK:bb[0-9]+]], [[LOOP_CONDITION:bb[0-9]+]] +// +// CHECK: [[LOOP_BREAK]]: +// CHECK: [[LOOP_BREAK_FUNC:%.*]] = function_ref @loopBreakEnd : $@convention(thin) () -> () +// CHECK: apply [[LOOP_BREAK_FUNC]]() : $@convention(thin) () -> () +// CHECK: destroy_addr [[STACK]] : $*@pack_element("[[UUID]]") each Element +// CHECK: dealloc_stack [[STACK]] : $*@pack_element("[[UUID]]") each Element +// br [[FUNC_END:bb[0-9]+]] +// +// CHECK: [[LOOP_CONDITION]]: +// CHECK: [[LOOP_CONDITION_FUNC:%.*]] = function_ref @condition : $@convention(thin) () -> Bool +// CHECK: [[BOOL:%.*]] = apply [[LOOP_CONDITION_FUNC]]() : $@convention(thin) () -> Bool +// CHECK: [[IF:%.*]] = struct_extract [[BOOL]] : $Bool, #Bool._value +// CHECK: cond_br [[IF]], [[LOOP_CONTINUE:bb[0-9]+]], [[LOOP_BODY_END:bb[0-9]+]] +// +// CHECK: [[LOOP_CONTINUE]]: +// CHECK: [[LOOP_CONTINUE_FUNC:%.*]] = function_ref @loopContinueEnd : $@convention(thin) () -> () +// CHECK: apply [[LOOP_CONTINUE_FUNC]]() : $@convention(thin) () -> () +// CHECK: destroy_addr [[STACK]] : $*@pack_element("[[UUID]]") each Element +// CHECK: dealloc_stack [[STACK]] : $*@pack_element("[[UUID]]") each Element +// CHECK: br [[LATCH:bb[0-9]+]] +// +// CHECK: [[LOOP_BODY_END]]: +// CHECK: [[LOOP_BODY_END_FUNC:%.*]] = function_ref @loopBodyEnd : $@convention(thin) () -> () +// CHECK: apply [[LOOP_BODY_END_FUNC]]() : $@convention(thin) () -> () +// CHECK: destroy_addr [[STACK]] : $*@pack_element("[[UUID]]") each Element +// CHECK: dealloc_stack [[STACK]] : $*@pack_element("[[UUID]]") each Element +// CHECK: br [[LATCH]] +// +// CHECK: [[NONE_BB]]: +// CHECK: br [[FUNC_END_BB:bb[0-9]+]] +// +// CHECK: [[FUNC_END_BB]] +// CHECK: [[FUNC_END_FUNC:%.*]] = function_ref @funcEnd : $@convention(thin) () -> () +// CHECK: apply [[FUNC_END_FUNC]]() : $@convention(thin) () -> () +// +// CHECK: [[LATCH]]: +// CHECK: [[ADD_WORD:%.*]] = builtin "add_Word"([[IDX3]] : $Builtin.Word, [[IDX2]] : $Builtin.Word) : $Builtin.Word +// CHECK: br [[LOOP_DEST]]([[ADD_WORD]] : $Builtin.Word) +// CHECK: } // end sil function '$s14pack_iteration20iterateContinueBreak4overyxxQp_tRvzlF' +func iterateContinueBreak(over element: repeat each Element) { + for el in repeat each element { + if (condition()) { + loopBreakEnd() + break + } + + if (condition()) { + loopContinueEnd() + continue + } + loopBodyEnd() + } + + funcEnd() +} From f276bd20798339092c6c5e4aff13d29a40b47b37 Mon Sep 17 00:00:00 2001 From: Sima Nerush <2002ssn@gmail.com> Date: Tue, 10 Oct 2023 11:58:24 -0700 Subject: [PATCH 06/10] Support pack iteration in closures --- lib/AST/ASTWalker.cpp | 8 ++--- test/SILGen/pack_iteration.swift | 59 ++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 5 deletions(-) diff --git a/lib/AST/ASTWalker.cpp b/lib/AST/ASTWalker.cpp index 3888b8a6755e1..08ac1f3eab6e7 100644 --- a/lib/AST/ASTWalker.cpp +++ b/lib/AST/ASTWalker.cpp @@ -1905,11 +1905,9 @@ Stmt *Traversal::visitForEachStmt(ForEachStmt *S) { // // If for-in is already type-checked, the type-checked version // of the sequence is going to be visited as part of `iteratorVar`. - if (S->getTypeCheckedSequence()) { - if (auto IteratorVar = S->getIteratorVar()) { - if (doIt(IteratorVar)) - return nullptr; - } + if (auto IteratorVar = S->getIteratorVar()) { + if (doIt(IteratorVar)) + return nullptr; if (auto NextCall = S->getNextCall()) { if ((NextCall = doIt(NextCall))) diff --git a/test/SILGen/pack_iteration.swift b/test/SILGen/pack_iteration.swift index 78d73462804ea..dd3a32844a5c8 100644 --- a/test/SILGen/pack_iteration.swift +++ b/test/SILGen/pack_iteration.swift @@ -226,3 +226,62 @@ func iterateContinueBreak(over element: repeat each Element) { funcEnd() } +// CHECK-LABEL: sil hidden [ossa] @$s14pack_iteration14iterateClosure4overyxxQp_tRvzlF : $@convention(thin) (@pack_guaranteed Pack{repeat each Element}) -> () { +// +// CHECK-LABEL: sil private [ossa] @$s14pack_iteration14iterateClosure4overyxxQp_tRvzlFyycfU_ : $@convention(thin) (@in_guaranteed (repeat each Element)) -> () { +// +// CHECK: bb0([[PACK:%.*]] : @closureCapture $*(repeat each Element)): +// CHECK: [[ALLOC_PACK:%.*]] = alloc_pack $Pack{repeat each Element} +// CHECK: [[IDX1:%.*]] = integer_literal $Builtin.Word, 0 +// CHECK: [[IDX2:%.*]] = integer_literal $Builtin.Word, 1 +// CHECK: [[PACK_LENGTH:%.*]] = pack_length $Pack{repeat each Element} +// CHECK: br [[LOOP_DEST:bb[0-9]+]]([[IDX1]] : $Builtin.Word) +// +// CHECK: [[LOOP_DEST]]([[IDX3:%.*]] : $Builtin.Word): +// CHECK: [[COND:%.*]] = builtin "cmp_eq_Word"([[IDX3]] : $Builtin.Word, [[PACK_LENGTH]] : $Builtin.Word) : $Builtin.Int1 +// CHECK: cond_br [[COND]], [[SETUP_BB:bb[0-9]+]], [[ITER_BB:bb[0-9]+]] +// +// CHECK: [[ITER_BB]]: +// CHECK: [[PACK_IDX:%.*]] = dynamic_pack_index [[IDX3]] of $Pack{repeat each Element} +// CHECK: [[OPEN_ELT:%.*]] = open_pack_element [[PACK_IDX]] of at , shape $each Element, uuid "[[UUID:.*]]" +// CHECK: [[TUPLE_ADDR:%.*]] = tuple_pack_element_addr [[PACK_IDX]] of [[PACK]] : $*(repeat each Element) as $*@pack_element("[[UUID]]") each Element +// CHECK: pack_element_set [[TUPLE_ADDR]] : $*@pack_element("[[UUID]]") each Element into [[PACK_IDX]] of [[ALLOC_PACK]] : $*Pack{repeat each Element} +// CHECK: [[ADD_WORD:%.*]] = builtin "add_Word"([[IDX3]] : $Builtin.Word, [[IDX2]] : $Builtin.Word) : $Builtin.Word +// CHECK: br [[LOOP_DEST]]([[ADD_WORD]] : $Builtin.Word) +// +// CHECK: [[SETUP_BB]] +// CHECK: debug_value [[ALLOC_PACK]] : $*Pack{repeat each Element}, let, name "element", argno 1, expr op_deref +// CHECK: [[SETUP_0:%.*]] = integer_literal $Builtin.Word, 0 +// CHECK: [[SETUP_1:%.*]] = integer_literal $Builtin.Word, 1 +// CHECK: [[SETUP_LENGTH:%.*]] = pack_length $Pack{repeat each Element} +// CHECK: br [[LOOP_DEST:bb[0-9]+]]([[SETUP_0]] : $Builtin.Word) +// +// CHECK: [[LOOP_DEST]]([[IDX4:%.*]] : $Builtin.Word): +// CHECK: [[COND:%.*]] = builtin "cmp_eq_Word"([[IDX4]] : $Builtin.Word, [[SETUP_LENGTH]] : $Builtin.Word) : $Builtin.Int1 +// CHECK: cond_br [[COND]], [[NONE_BB:bb[0-9]+]], [[SOME_BB:bb[0-9]+]] +// +// CHECK: [[SOME_BB]]: +// CHECK: [[DYN_PACK_IDX:%.*]] = dynamic_pack_index [[IDX4]] of $Pack{repeat each Element} +// CHECK: open_pack_element [[DYN_PACK_IDX]] of at , shape $each Element, uuid "[[UUID:.*]]" +// CHECK: [[STACK:%.*]] = alloc_stack [lexical] $@pack_element("[[UUID]]") each Element, let, name "el" +// CHECK: [[PACK_ELT_GET:%.*]] = pack_element_get [[DYN_PACK_IDX]] of [[ALLOC_PACK]] : $*Pack{repeat each Element} as $*@pack_element("[[UUID]]") each Element +// CHECK: copy_addr [[PACK_ELT_GET]] to [init] [[STACK]] : $*@pack_element("[[UUID]]") each Element +// CHECK: [[LOOP_END_FUNC:%.*]] = function_ref @loopBodyEnd : $@convention(thin) () -> () +// CHECK: apply [[LOOP_END_FUNC]]() : $@convention(thin) () -> () +// CHECK: destroy_addr [[STACK]] : $*@pack_element("[[UUID]]") each Element +// CHECK: dealloc_stack [[STACK]] : $*@pack_element("[[UUID]]") each Element +// CHECK: br [[LATCH_BB:bb[0-9]+]] +// +// CHECK: [[NONE_BB]]: +// CHECK: dealloc_pack [[ALLOC_PACK]] : $*Pack{repeat each Element} +// +// CHECK: [[LATCH_BB]]: +// CHECK: [[ADD_WORD:%.*]] = builtin "add_Word"([[IDX4]] : $Builtin.Word, [[SETUP_1]] : $Builtin.Word) : $Builtin.Word +// CHECK: br [[LOOP_DEST]]([[ADD_WORD]] : $Builtin.Word) +func iterateClosure(over element: repeat each Element) { + let _ = { () -> Void in + for el in repeat each element { + loopBodyEnd() + } + } +} From b0af0f762d35a557e0bf63d4e993d5121e071442 Mon Sep 17 00:00:00 2001 From: Sima Nerush <2002ssn@gmail.com> Date: Tue, 10 Oct 2023 11:15:15 -0700 Subject: [PATCH 07/10] Update diagnostics --- include/swift/AST/DiagnosticsSema.def | 4 ++-- test/Constraints/pack-expansion-expressions.swift | 6 +++--- test/Constraints/variadic_generic_functions.swift | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index e45f317063cc4..587c5629b90f2 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -5834,8 +5834,8 @@ ERROR(expansion_not_allowed,none, "pack expansion %0 can only appear in a function parameter list, " "tuple element, or generic argument of a variadic type", (Type)) ERROR(expansion_expr_not_allowed,none, - "value pack expansion can only appear inside a function argument list " - "or tuple element", ()) + "value pack expansion can only appear inside a function argument list, " + "tuple element, or as the expression of a for-in loop", ()) ERROR(invalid_expansion_argument,none, "cannot pass value pack expansion to non-pack parameter of type %0", (Type)) diff --git a/test/Constraints/pack-expansion-expressions.swift b/test/Constraints/pack-expansion-expressions.swift index 907e8231ee48c..ea6e52b3c33e7 100644 --- a/test/Constraints/pack-expansion-expressions.swift +++ b/test/Constraints/pack-expansion-expressions.swift @@ -290,10 +290,10 @@ func concrete(_: Int) {} func invalidRepeat(t: repeat each T) { _ = repeat each t - // expected-error@-1 {{value pack expansion can only appear inside a function argument list or tuple element}} + // expected-error@-1 {{value pack expansion can only appear inside a function argument list, tuple element, or as the expression of a for-in loop}} let _: Int = repeat each t - // expected-error@-1 {{value pack expansion can only appear inside a function argument list or tuple element}} + // expected-error@-1 {{value pack expansion can only appear inside a function argument list, tuple element, or as the expression of a for-in loop}} identity(identity(repeat each t)) // expected-error@-1 {{cannot pass value pack expansion to non-pack parameter of type 'T'}} @@ -302,7 +302,7 @@ func invalidRepeat(t: repeat each T) { // expected-error@-1 {{cannot pass value pack expansion to non-pack parameter of type 'Int'}} _ = [repeat each t] - // expected-error@-1 {{value pack expansion can only appear inside a function argument list or tuple element}} + // expected-error@-1 {{value pack expansion can only appear inside a function argument list, tuple element, or as the expression of a for-in loop}} } // Make sure that single parameter initializers are handled correctly because diff --git a/test/Constraints/variadic_generic_functions.swift b/test/Constraints/variadic_generic_functions.swift index 24542565d0b7e..c9217803fe9b2 100644 --- a/test/Constraints/variadic_generic_functions.swift +++ b/test/Constraints/variadic_generic_functions.swift @@ -74,7 +74,7 @@ func contextualTyping() { do { func foo(_: repeat each T = bar().element) {} // expected-note {{in call to function 'foo'}} // expected-error@-1 {{variadic parameter cannot have a default value}} - // expected-error@-2 {{value pack expansion can only appear inside a function argument list or tuple element}} + // expected-error@-2 {{value pack expansion can only appear inside a function argument list, tuple element, or as the expression of a for-in loop}} // expected-error@-3 {{generic parameter 'each T' could not be inferred}} func bar() -> (repeat each T) {} From a687032925ecb84b330b3d2a6a83dc7dc83fa562 Mon Sep 17 00:00:00 2001 From: Sima Nerush <2002ssn@gmail.com> Date: Wed, 20 Sep 2023 12:58:54 -0700 Subject: [PATCH 08/10] Add `repeat` code completion --- lib/IDE/CodeCompletion.cpp | 2 ++ test/IDE/complete_loop.swift | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/lib/IDE/CodeCompletion.cpp b/lib/IDE/CodeCompletion.cpp index 0a78898ea20b2..2143701ad5e5d 100644 --- a/lib/IDE/CodeCompletion.cpp +++ b/lib/IDE/CodeCompletion.cpp @@ -1062,6 +1062,8 @@ void CodeCompletionCallbacksImpl::addKeywords(CodeCompletionResultSink &Sink, addSuperKeyword(Sink, CurDeclContext); addExprKeywords(Sink, CurDeclContext); addAnyTypeKeyword(Sink, CurDeclContext->getASTContext().TheAnyType); + if (Kind == CompletionKind::ForEachSequence) + addKeyword(Sink, "repeat", CodeCompletionKeywordKind::kw_repeat); break; case CompletionKind::CallArg: diff --git a/test/IDE/complete_loop.swift b/test/IDE/complete_loop.swift index 805c29aa28f02..5ec4b45325b49 100644 --- a/test/IDE/complete_loop.swift +++ b/test/IDE/complete_loop.swift @@ -7,6 +7,7 @@ // RUN: %target-swift-ide-test -code-completion -source-filename %s -code-completion-token=LOOP_5 | %FileCheck %s -check-prefix=LOOP_5 // RUN: %target-swift-ide-test -code-completion -source-filename %s -code-completion-token=LOOP_6 | %FileCheck %s -check-prefix=LOOP_6 // RUN: %target-swift-ide-test -code-completion -source-filename %s -code-completion-token=LOOP_7 | %FileCheck %s -check-prefix=LOOP_6 +// RUN: %target-swift-ide-test -code-completion -source-filename %s -code-completion-token=LOOP_8 | %FileCheck %s -check-prefix=LOOP_8 class Gen { func IntGen() -> Int { return 0 } @@ -80,3 +81,9 @@ do { } // LOOP_6: Begin completions, 1 items // LOOP_6-CHECK-NEXT: Keyword[in]/None: in; name=in + +// Pack Iteration +do { + for t in #^LOOP_8^# {} +} +// LOOP_8-DAG: Keyword[repeat]/None: repeat; name=repeat From 23485990e562d2ce0c80e57793e193c83309d402 Mon Sep 17 00:00:00 2001 From: Sima Nerush <2002ssn@gmail.com> Date: Mon, 27 Nov 2023 20:19:59 -0800 Subject: [PATCH 09/10] Diagnose that `where` clause is not supported --- include/swift/AST/DiagnosticsSema.def | 7 +++++++ include/swift/Sema/CSFix.h | 23 +++++++++++++++++++++++ lib/Sema/CSDiagnostics.cpp | 5 +++++ lib/Sema/CSDiagnostics.h | 10 ++++++++++ lib/Sema/CSFix.cpp | 12 ++++++++++++ lib/Sema/CSGen.cpp | 5 +++++ lib/Sema/CSSimplify.cpp | 1 + test/stmt/foreach.swift | 15 +++++++++++++++ 8 files changed, 78 insertions(+) diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 587c5629b90f2..887bee9a7841c 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -7715,5 +7715,12 @@ ERROR(referencebindings_binding_must_be_to_lvalue,none, ERROR(result_depends_on_no_result,none, "Incorrect use of %0 with no result", (StringRef)) +//------------------------------------------------------------------------------ +// MARK: Pack Iteration Diagnostics +//------------------------------------------------------------------------------ + +ERROR(pack_iteration_where_clause_not_supported, none, + "'where' clause in pack iteration is not supported", ()) + #define UNDEFINE_DIAGNOSTIC_MACROS #include "DefineDiagnosticMacros.h" diff --git a/include/swift/Sema/CSFix.h b/include/swift/Sema/CSFix.h index 4bd301e649f8c..5f3deeb6a6049 100644 --- a/include/swift/Sema/CSFix.h +++ b/include/swift/Sema/CSFix.h @@ -440,6 +440,9 @@ enum class FixKind : uint8_t { /// Allow pack expansion expressions in a context that does not support them. AllowInvalidPackExpansion, + /// Ignore `where` clause in a for-in loop with a pack expansion expression. + IgnoreWhereClauseInPackIteration, + /// Allow a pack expansion parameter of N elements to be matched /// with a single tuple literal argument of the same arity. DestructureTupleToMatchPackExpansionParameter, @@ -2223,6 +2226,26 @@ class AllowInvalidPackExpansion final : public ConstraintFix { } }; +class IgnoreWhereClauseInPackIteration final : public ConstraintFix { + IgnoreWhereClauseInPackIteration(ConstraintSystem &cs, + ConstraintLocator *locator) + : ConstraintFix(cs, FixKind::IgnoreWhereClauseInPackIteration, locator) {} + +public: + std::string getName() const override { + return "ignore where clause in pack iteration"; + } + + bool diagnose(const Solution &solution, bool asNote = false) const override; + + static IgnoreWhereClauseInPackIteration *create(ConstraintSystem &cs, + ConstraintLocator *locator); + + static bool classof(const ConstraintFix *fix) { + return fix->getKind() == FixKind::IgnoreWhereClauseInPackIteration; + } +}; + class CollectionElementContextualMismatch final : public ContextualMismatch, private llvm::TrailingObjectsgetWhere()) { + cs.recordFix(IgnoreWhereClauseInPackIteration::create( + cs, cs.getConstraintLocator(whereClause))); + } + auto packIterationInfo = generateForEachStmtConstraints(cs, dc, expansion, patternType); if (!packIterationInfo) { diff --git a/lib/Sema/CSSimplify.cpp b/lib/Sema/CSSimplify.cpp index f8e22a617f56e..ce956faac1124 100644 --- a/lib/Sema/CSSimplify.cpp +++ b/lib/Sema/CSSimplify.cpp @@ -14895,6 +14895,7 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyFixConstraint( case FixKind::AllowInvalidPackElement: case FixKind::AllowInvalidPackReference: case FixKind::AllowInvalidPackExpansion: + case FixKind::IgnoreWhereClauseInPackIteration: case FixKind::MacroMissingPound: case FixKind::AllowGlobalActorMismatch: case FixKind::AllowAssociatedValueMismatch: diff --git a/test/stmt/foreach.swift b/test/stmt/foreach.swift index 233116f46b08f..77caccc3d57f1 100644 --- a/test/stmt/foreach.swift +++ b/test/stmt/foreach.swift @@ -316,3 +316,18 @@ do { } } } + +// SE-0408 +do { + func variadic(ts: repeat each T) { + for t in repeat each ts where !ts.isEmpty {} + // expected-error@-1 {{'where' clause in pack iteration is not supported}} + + func test(_: () -> Void) {} + + test { + for t in repeat each ts where !ts.isEmpty {} + // expected-error@-1 {{'where' clause in pack iteration is not supported}} + } + } +} From 48cb3309bd5ddeb01a4e92250f1107085ffd0dd3 Mon Sep 17 00:00:00 2001 From: Holly Borla Date: Thu, 31 Aug 2023 19:42:49 -0700 Subject: [PATCH 10/10] [Constraint System] Fix the shape class and context substitiutions for opened element generic environments containing same-element requirements. --- include/swift/Sema/ConstraintSystem.h | 5 +++++ lib/AST/GenericSignature.cpp | 10 ++++++++- lib/Sema/CSApply.cpp | 5 +++++ lib/Sema/CSGen.cpp | 11 +--------- lib/Sema/CSSimplify.cpp | 22 ++++++++++++++----- .../pack-expansion-expressions.swift | 10 +++++++++ 6 files changed, 47 insertions(+), 16 deletions(-) diff --git a/include/swift/Sema/ConstraintSystem.h b/include/swift/Sema/ConstraintSystem.h index b6ddbf3eb8db5..6a037654493dd 100644 --- a/include/swift/Sema/ConstraintSystem.h +++ b/include/swift/Sema/ConstraintSystem.h @@ -3759,6 +3759,11 @@ class ConstraintSystem { RememberChoice_t rememberChoice, ConstraintLocatorBuilder locator, ConstraintFix *compatFix = nullptr); + + /// Add a materialize constraint for a pack expansion. + TypeVariableType * + addMaterializePackExpansionConstraint(Type patternType, + ConstraintLocatorBuilder locator); /// Add a disjunction constraint. void diff --git a/lib/AST/GenericSignature.cpp b/lib/AST/GenericSignature.cpp index 1fa30884d5c25..3d22e4ceff37a 100644 --- a/lib/AST/GenericSignature.cpp +++ b/lib/AST/GenericSignature.cpp @@ -105,8 +105,16 @@ void GenericSignatureImpl::forEachParam( for (auto req : getRequirements()) { GenericTypeParamType *gp; + bool isCanonical = false; switch (req.getKind()) { case RequirementKind::SameType: { + if (req.getSecondType()->isParameterPack() != + req.getFirstType()->isParameterPack()) { + // This is a same-element requirement, which does not make + // type parameters non-canonical. + isCanonical = true; + } + if (auto secondGP = req.getSecondType()->getAs()) { // If two generic parameters are same-typed, then the right-hand one // is non-canonical. @@ -136,7 +144,7 @@ void GenericSignatureImpl::forEachParam( } unsigned index = GenericParamKey(gp).findIndexIn(genericParams); - genericParamsAreCanonical[index] = false; + genericParamsAreCanonical[index] = isCanonical; } // Call the callback with each parameter and the result of the above analysis. diff --git a/lib/Sema/CSApply.cpp b/lib/Sema/CSApply.cpp index 0260a8eeee7ab..981b848fdbcdf 100644 --- a/lib/Sema/CSApply.cpp +++ b/lib/Sema/CSApply.cpp @@ -3843,6 +3843,11 @@ namespace { auto *locator = cs.getConstraintLocator(expr); auto *environment = cs.getPackElementEnvironment(locator, expansionTy->getCountType()->getCanonicalType()); + + // Assert that we have an opened element environment, otherwise we'll get + // an ASTVerifier crash when pack archetypes or element archetypes appear + // inside the pack expansion expression. + assert(environment); expr->setGenericEnvironment(environment); return expr; diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index b0dff78b4a548..69ecb52951d34 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -3198,18 +3198,9 @@ namespace { auto expansionType = CS.getType(packEnvironment)->castTo(); CS.addConstraint(ConstraintKind::ShapeOf, expansionType->getCountType(), - elementType, + packType, CS.getConstraintLocator(packEnvironment, ConstraintLocator::PackShape)); - auto *elementShape = CS.createTypeVariable( - CS.getConstraintLocator(expr, ConstraintLocator::PackShape), - TVO_CanBindToPack); - CS.addConstraint( - ConstraintKind::ShapeOf, elementShape, elementType, - CS.getConstraintLocator(expr, ConstraintLocator::PackShape)); - CS.addConstraint( - ConstraintKind::Equal, elementShape, expansionType->getCountType(), - CS.getConstraintLocator(expr, ConstraintLocator::PackShape)); } else { CS.recordFix(AllowInvalidPackReference::create( CS, packType, CS.getConstraintLocator(expr->getPackRefExpr()))); diff --git a/lib/Sema/CSSimplify.cpp b/lib/Sema/CSSimplify.cpp index ce956faac1124..5dc2de76357c0 100644 --- a/lib/Sema/CSSimplify.cpp +++ b/lib/Sema/CSSimplify.cpp @@ -9313,11 +9313,7 @@ ConstraintSystem::simplifyPackElementOfConstraint(Type first, Type second, } if (isSingleUnlabeledPackExpansionTuple(patternType)) { - auto *packVar = - createTypeVariable(getConstraintLocator(locator), TVO_CanBindToPack); - addConstraint(ConstraintKind::MaterializePackExpansion, patternType, - packVar, - getConstraintLocator(locator, {ConstraintLocator::Member})); + auto *packVar = addMaterializePackExpansionConstraint(patternType, locator); addConstraint(ConstraintKind::PackElementOf, elementType, packVar, locator); return SolutionKind::Solved; } @@ -13440,6 +13436,12 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyShapeOfConstraint( return SolutionKind::Solved; } + if (isSingleUnlabeledPackExpansionTuple(packTy)) { + auto *packVar = addMaterializePackExpansionConstraint(packTy, locator); + addConstraint(ConstraintKind::ShapeOf, shapeTy, packVar, locator); + return SolutionKind::Solved; + } + // Map element archetypes to the pack context to check for equality. if (packTy->hasElementArchetype()) { auto *packEnv = DC->getGenericEnvironmentOfContext(); @@ -15623,6 +15625,16 @@ void ConstraintSystem::addExplicitConversionConstraint( addDisjunctionConstraint(constraints, locator, rememberChoice); } +TypeVariableType *ConstraintSystem::addMaterializePackExpansionConstraint( + Type patternType, ConstraintLocatorBuilder locator) { + assert(isSingleUnlabeledPackExpansionTuple(patternType)); + TypeVariableType *packVar = + createTypeVariable(getConstraintLocator(locator), TVO_CanBindToPack); + addConstraint(ConstraintKind::MaterializePackExpansion, patternType, packVar, + getConstraintLocator(locator, {ConstraintLocator::Member})); + return packVar; +} + ConstraintSystem::SolutionKind ConstraintSystem::simplifyConstraint(const Constraint &constraint) { auto matchKind = constraint.getKind(); diff --git a/test/Constraints/pack-expansion-expressions.swift b/test/Constraints/pack-expansion-expressions.swift index ea6e52b3c33e7..fdc6e6a60ba6b 100644 --- a/test/Constraints/pack-expansion-expressions.swift +++ b/test/Constraints/pack-expansion-expressions.swift @@ -734,3 +734,13 @@ do { } } } + +// Pack Iteration +do { + func test(_ t: repeat each T) { + func nested() -> (repeat (Int, each T)) {} + for (x, y) in repeat each nested() {} + // expected-warning@-1 {{immutable value 'x' was never used; consider replacing with '_' or removing it}} + // expected-warning@-2 {{immutable value 'y' was never used; consider replacing with '_' or removing it}} + } +}