diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp index 4d843354cef3..3e2df9002431 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp @@ -285,8 +285,6 @@ LogicalResult distributeVectorOps(Operation *root, // Run the analysis and determine the layouts. LLVM_DEBUG(llvm::dbgs() << "Running Layout Analysis\n"); VectorLayoutAnalysis analysis(root); - if (failed(options.setAnchorOps(analysis))) - return failure(); if (failed(analysis.run())) return failure(); LLVM_DEBUG(llvm::dbgs() << "Layout Analysis Succeded\n"); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h index dd973337f3d0..2ef2eb39fcd7 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h @@ -106,9 +106,6 @@ class VectorLayoutOptions { virtual ~VectorLayoutOptions() = default; - /// Set the anchor ops in the analysis rooted on the root operation. - virtual LogicalResult setAnchorOps(VectorLayoutAnalysis &analysis) = 0; - bool verifyConversion() const { return fullConversion; } protected: diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp index 93516c2e685b..47f12f006c51 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp @@ -1100,10 +1100,6 @@ class TestVectorLayoutOptions : public VectorLayoutOptions { public: TestVectorLayoutOptions(Operation *root) : VectorLayoutOptions(root, /*fullConversion=*/false) {} - - LogicalResult setAnchorOps(VectorLayoutAnalysis &analysis) override { - return success(); - } }; DiagnosedSilenceableFailure diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp index 6dceae765201..784838ff5a77 100644 --- a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp @@ -1023,17 +1023,6 @@ DistributionLayout *EnforceLayout::getLatticeElement(Value val) { /// VectorLayoutAnalysis /// ========================================================================== -LogicalResult VectorLayoutAnalysis::setAnchor(Value val, - VectorLayoutInterface layout) { - auto typedVal = dyn_cast>(val); - assert(typedVal && "expected value to be a vector type"); - if (layout.isValidLayout(typedVal).failed()) { - return failure(); - } - anchors[typedVal] = cast(layout); - return success(); -} - LogicalResult VectorLayoutAnalysis::run() { // The order of loading matters here, because propagateLayout does anchoring // initialization which needs the lattice to know both enforcement and diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.h b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.h index 5ede054347c0..0eb70b1a17e4 100644 --- a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.h +++ b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.h @@ -98,10 +98,6 @@ class VectorLayoutAnalysis { public: VectorLayoutAnalysis(Operation *root) : root(root) {} - /// Fix the layout for a specific value. Returns failure if the layout set is - /// invalid for the value. - LogicalResult setAnchor(Value val, VectorLayoutInterface layout); - /// Run the analysis. The analysis expects that the user has set some anchor /// points and is trying to infer the layout of other values. LogicalResult run(); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp index 7ad8dcfd350c..25adf1f4491c 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp @@ -67,22 +67,34 @@ class ContractionVectorLayoutOptions : public VectorLayoutOptions { subgroupSize); } - LogicalResult setAnchorOps(VectorLayoutAnalysis &analysis) override { + LogicalResult setAnchorOps(RewriterBase &rewriter) { MLIRContext *context = root->getContext(); - WalkResult walkResult = root->walk([&](Operation *op) { - LogicalResult setResult = - llvm::TypeSwitch(op) - .Case([&](vector::ContractionOp contract) { - return setContractionAnchor(context, analysis, contract); - }) - .Case([&](vector::TransferReadOp transfer) { - return setTransferReadAnchor(context, analysis, transfer); - }) - .Default([](Operation *) { return success(); }); - return failed(setResult) ? WalkResult::interrupt() - : WalkResult::advance(); + SmallVector reads; + SmallVector contracts; + + root->walk([&](Operation *op) { + llvm::TypeSwitch(op) + .Case([&](vector::TransferReadOp transfer) { + reads.push_back(transfer); + }) + .Case([&](vector::ContractionOp contract) { + contracts.push_back(contract); + }); }); - return failure(walkResult.wasInterrupted()); + + for (vector::TransferReadOp read : reads) { + if (failed(setTransferReadAnchor(context, rewriter, read))) { + return failure(); + } + } + + for (vector::ContractionOp contract : contracts) { + if (failed(setContractionAnchor(context, rewriter, contract))) { + return failure(); + } + } + + return success(); } RewritePatternSet &getPatterns() { return patterns; } @@ -92,7 +104,7 @@ class ContractionVectorLayoutOptions : public VectorLayoutOptions { // supported mma type from the cached list of mma types and populates the // necessary distribution pattern for those contractions. LogicalResult setContractionAnchor(MLIRContext *context, - VectorLayoutAnalysis &analysis, + RewriterBase &rewriter, vector::ContractionOp contract) { // TODO: Add SIMT fallback. if (!schedule) { @@ -105,19 +117,29 @@ class ContractionVectorLayoutOptions : public VectorLayoutOptions { } auto [aLayout, bLayout, cLayout] = *layouts; - if (analysis.setAnchor(contract.getLhs(), aLayout).failed()) { - return failure(); - } - if (analysis.setAnchor(contract.getRhs(), bLayout).failed()) { - return failure(); - } - if (analysis.setAnchor(contract.getAcc(), cLayout).failed()) { - return failure(); - } - if (analysis.setAnchor(contract.getResult(), cLayout).failed()) { - return failure(); - } + Location loc = contract.getLoc(); + + // Set layouts for lhs, rhs and acc. + rewriter.setInsertionPoint(contract); + Value layoutedLhs = rewriter.create( + loc, contract.getLhsType(), contract.getLhs(), aLayout); + Value layoutedRhs = rewriter.create( + loc, contract.getRhsType(), contract.getRhs(), bLayout); + Value layoutedAcc = rewriter.create( + loc, contract.getAccType(), contract.getAcc(), cLayout); + contract->setOperand(0, layoutedLhs); + contract->setOperand(1, layoutedRhs); + contract->setOperand(2, layoutedAcc); + + // Set layout for result. + rewriter.setInsertionPointAfter(contract); + auto toLayout = rewriter.create( + loc, contract.getResultType(), contract.getResult(), cLayout); + rewriter.replaceAllUsesExcept(contract, toLayout.getResult(), toLayout); + + // Set intrinsic kind. contract->setAttr("iree.amdgpu.mma", schedule.getIntrinsic()); + if (printLayout) { llvm::outs() << "contract A vector layout: " << aLayout << "\n"; llvm::outs() << "contract B vector layout: " << bLayout << "\n"; @@ -168,7 +190,7 @@ class ContractionVectorLayoutOptions : public VectorLayoutOptions { // // *_order = [0, 1]> LogicalResult setTransferReadAnchor(MLIRContext *context, - VectorLayoutAnalysis &analysis, + RewriterBase &rewriter, vector::TransferReadOp transfer) { // Get the forward slice of the transfer to approximate whether it will take @@ -332,9 +354,13 @@ class ContractionVectorLayoutOptions : public VectorLayoutOptions { auto layout = IREE::VectorExt::NestedLayoutAttr::get( context, subgroupCounts, batchSizes, outerSizes, threadCounts, elementSizes, subgroupStrides, threadStrides); - if (analysis.setAnchor(transfer.getResult(), layout).failed()) { - return failure(); - } + + Location loc = transfer.getLoc(); + rewriter.setInsertionPointAfter(transfer); + auto toLayout = rewriter.create( + loc, transfer.getResult().getType(), transfer.getResult(), layout); + rewriter.replaceAllUsesExcept(transfer, toLayout.getResult(), toLayout); + if (printLayout) { llvm::outs() << "transfer '" << transfer << "' vector layout: " << layout << "\n"; @@ -403,16 +429,18 @@ struct LLVMGPUVectorDistributePass AffineExpr linearId = x + workgroupSize[0] * y + workgroupSize[1] * workgroupSize[0] * z; - OpBuilder builder(func); - builder.setInsertionPointToStart(&func.getFunctionBody().front()); + IRRewriter rewriter(func); + rewriter.setInsertionPointToStart(&func.getFunctionBody().front()); SmallVector threadGrid = { - builder.createOrFold(func.getLoc(), gpu::Dimension::x), - builder.createOrFold(func.getLoc(), gpu::Dimension::y), - builder.createOrFold(func.getLoc(), - gpu::Dimension::z)}; + rewriter.createOrFold(func.getLoc(), + gpu::Dimension::x), + rewriter.createOrFold(func.getLoc(), + gpu::Dimension::y), + rewriter.createOrFold(func.getLoc(), + gpu::Dimension::z)}; Value linearThreadIdVal = affine::makeComposedAffineApply( - builder, func.getLoc(), linearId, threadGrid); + rewriter, func.getLoc(), linearId, threadGrid); std::optional subgroupSize = getSubgroupSize(func); if (!subgroupSize) { @@ -424,6 +452,13 @@ struct LLVMGPUVectorDistributePass ContractionVectorLayoutOptions options(func, workgroupSize, scheduleAttr, linearThreadIdVal, subgroupSize.value(), testLayout); + + // Set anchor layouts. + if (failed(options.setAnchorOps(rewriter))) { + func->emitError() << "failed to set anchors"; + return signalPassFailure(); + } + if (failed(distributeVectorOps(func, options.getPatterns(), options))) { func->emitOpError() << "failed to distribute"; return signalPassFailure(); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp index cc5f4a1d0851..f720f07d938a 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp @@ -1474,10 +1474,6 @@ class TransformVectorLayoutOptions : public VectorLayoutOptions { public: TransformVectorLayoutOptions(Operation *root, bool fullConversion) : VectorLayoutOptions(root, fullConversion) {} - - LogicalResult setAnchorOps(VectorLayoutAnalysis &analysis) override { - return success(); - } }; DiagnosedSilenceableFailure