diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
index bf3724bdd466..4c6b5514889f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
@@ -66,64 +66,9 @@ void transform_dialect::ApplyLowerValueBarrierOp::populatePatterns(
 // ApplyUnrollMultiMmaOp
 //===---------------------------------------------------------------------===//
 
-static bool isReductionIterator(Attribute attr) {
-  return cast<IREE::GPU::IteratorTypeAttr>(attr).getValue() ==
-         utils::IteratorType::reduction;
-}
-static bool isParallelIterator(Attribute attr) {
-  return cast<IREE::GPU::IteratorTypeAttr>(attr).getValue() ==
-         utils::IteratorType::parallel;
-}
-
-/// Pick an unrolling order that reuses the LHS register.
-static std::optional<SmallVector<int64_t>>
-gpuMultiMmaUnrollOrder(Operation *op) {
-  IREE::GPU::MultiMmaOp mmaOp = dyn_cast<IREE::GPU::MultiMmaOp>(op);
-  if (!mmaOp) {
-    return std::nullopt;
-  }
-  SmallVector<int64_t> order;
-  // First make reduction the outer dimensions.
-  for (auto [index, iter] : llvm::enumerate(mmaOp.getIteratorTypes())) {
-    if (isReductionIterator(iter)) {
-      order.push_back(index);
-    }
-  }
-
-  llvm::SmallDenseSet<int64_t> dims;
-  for (AffineExpr expr : mmaOp.getIndexingMapsArray()[0].getResults()) {
-    dims.insert(cast<AffineDimExpr>(expr).getPosition());
-  }
-  // Then parallel dimensions that are part of Lhs as we want to re-use Lhs.
-  for (auto [index, iter] : llvm::enumerate(mmaOp.getIteratorTypes())) {
-    if (isParallelIterator(iter) && dims.count(index)) {
-      order.push_back(index);
-    }
-  }
-  // Then the remaining parallel loops.
-  for (auto [index, iter] : llvm::enumerate(mmaOp.getIteratorTypes())) {
-    if (isParallelIterator(iter) && !dims.count(index)) {
-      order.push_back(index);
-    }
-  }
-  return order;
-}
-
-static std::optional<SmallVector<int64_t>> getMultiMmaUnitShape(Operation *op) {
-  IREE::GPU::MultiMmaOp mmaOp = dyn_cast<IREE::GPU::MultiMmaOp>(op);
-  if (!mmaOp) {
-    return std::nullopt;
-  }
-  SmallVector<int64_t> targetOuterShape(mmaOp.getIteratorTypes().size(), 1);
-  return targetOuterShape;
-}
-
 void transform_dialect::ApplyUnrollMultiMmaOp::populatePatterns(
     RewritePatternSet &patterns) {
-  GPU::populateIREEGPUVectorUnrollPatterns(
-      patterns, vector::UnrollVectorOptions()
-                    .setNativeShapeFn(getMultiMmaUnitShape)
-                    .setUnrollTraversalOrderFn(gpuMultiMmaUnrollOrder));
+  GPU::populateIREEGPUVectorUnrollPatterns(patterns);
 }
 
 //===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
index 9a4574deb7db..00af941c2721 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
@@ -57,6 +57,7 @@ iree_compiler_cc_library(
         "PackToIntrinsics.cpp",
         "Passes.cpp",
         "Transforms.cpp",
+        "UnrollToIntrinsics.cpp",
         "VectorizeIREEGPUOps.cpp",
     ],
     hdrs = [
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
index 1f0a472dfd89..7d49b1df7e72 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
@@ -51,6 +51,7 @@ iree_cc_library(
     "PackToIntrinsics.cpp"
    "Passes.cpp"
     "Transforms.cpp"
+    "UnrollToIntrinsics.cpp"
"VectorizeIREEGPUOps.cpp" DEPS ::PassesIncGen diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td index 5bfb0d4f5e0f..9f64563ef0ad 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td @@ -30,6 +30,13 @@ def FuseAndHoistParallelLoopsPass : ]; } +def LowerIREEGPUOpsPass : + InterfacePass<"iree-gpu-lower-ops", "mlir::FunctionOpInterface"> { + let summary = "Post bufferization lowerings of iree_gpu ops before late lowerings"; + let dependentDialects = [ + "::mlir::gpu::GPUDialect", + ]; +} def PackToIntrinsicsPass : InterfacePass<"iree-gpu-pack-to-intrinsics", "mlir::FunctionOpInterface"> { @@ -40,6 +47,15 @@ def PackToIntrinsicsPass : ]; } +def UnrollToIntrinsicsPass : + InterfacePass<"iree-gpu-unroll-to-intrinsics", "mlir::FunctionOpInterface"> { + let summary = "Unrolls iree_gpu.multi_mma ops to their inner vector size."; + let dependentDialects = [ + "::mlir::arith::ArithDialect", + "::mlir::vector::VectorDialect", + ]; +} + def VectorizeIREEGPUOpsPass : InterfacePass<"iree-gpu-vectorize-ops", "mlir::FunctionOpInterface"> { let summary = "Vectorizes then lowers a few iree_gpu ops before vectorization."; @@ -50,12 +66,4 @@ def VectorizeIREEGPUOpsPass : ]; } -def LowerIREEGPUOpsPass : - InterfacePass<"iree-gpu-lower-ops", "mlir::FunctionOpInterface"> { - let summary = "Post bufferization lowerings of iree_gpu ops before late lowerings"; - let dependentDialects = [ - "::mlir::gpu::GPUDialect", - ]; -} - #endif // IREE_CODEGEN_DIALECt_GPU_TRANSFORMS_PASSES diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp index c2c5c03f5437..389ffc3a76e0 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp @@ -796,6 +796,65 @@ void populateIREEGPUVectorUnrollPatterns( patterns.add(patterns.getContext(), options); } +static bool isReductionIterator(Attribute attr) { + return cast(attr).getValue() == + utils::IteratorType::reduction; +} +static bool isParallelIterator(Attribute attr) { + return cast(attr).getValue() == + utils::IteratorType::parallel; +} + +/// Pick an unrolling order that reuses the LHS register. +static std::optional> +gpuMultiMmaUnrollOrder(Operation *op) { + IREE::GPU::MultiMmaOp mmaOp = dyn_cast(op); + if (!mmaOp) { + return std::nullopt; + } + SmallVector order; + // First make reduction the outer dimensions. + for (auto [index, iter] : llvm::enumerate(mmaOp.getIteratorTypes())) { + if (isReductionIterator(iter)) { + order.push_back(index); + } + } + + llvm::SmallDenseSet dimsInLhs; + for (AffineExpr expr : mmaOp.getIndexingMapsArray()[0].getResults()) { + dimsInLhs.insert(cast(expr).getPosition()); + } + // Then parallel dimensions that are part of Lhs as we want to re-use Lhs. + for (auto [index, iter] : llvm::enumerate(mmaOp.getIteratorTypes())) { + if (isParallelIterator(iter) && dimsInLhs.count(index)) { + order.push_back(index); + } + } + // Then the remaining parallel loops. 
+  for (auto [index, iter] : llvm::enumerate(mmaOp.getIteratorTypes())) {
+    if (isParallelIterator(iter) && !dimsInLhs.count(index)) {
+      order.push_back(index);
+    }
+  }
+  return order;
+}
+
+static std::optional<SmallVector<int64_t>> getMultiMmaUnitShape(Operation *op) {
+  IREE::GPU::MultiMmaOp mmaOp = dyn_cast<IREE::GPU::MultiMmaOp>(op);
+  if (!mmaOp) {
+    return std::nullopt;
+  }
+  SmallVector<int64_t> targetOuterShape(mmaOp.getIteratorTypes().size(), 1);
+  return targetOuterShape;
+}
+
+void populateIREEGPUVectorUnrollPatterns(RewritePatternSet &patterns) {
+  populateIREEGPUVectorUnrollPatterns(
+      patterns, vector::UnrollVectorOptions()
+                    .setNativeShapeFn(getMultiMmaUnitShape)
+                    .setUnrollTraversalOrderFn(gpuMultiMmaUnrollOrder));
+}
+
 //===---------------------------------------------------------------------===//
 // Resolving lane mapped forall ops
 //===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
index 3f502156e3b8..ca597cf234cf 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
@@ -63,6 +63,8 @@ void populateIREEGPULowerShuffleTensorPatterns(RewritePatternSet &patterns);
 void populateIREEGPULowerValueBarrierPatterns(RewritePatternSet &patterns);
 void populateIREEGPUVectorUnrollPatterns(
     RewritePatternSet &patterns, const vector::UnrollVectorOptions &options);
+// Version of unrolling with a preset configuration.
+void populateIREEGPUVectorUnrollPatterns(RewritePatternSet &patterns);
 void populateIREEGPUVectorizationPatterns(RewritePatternSet &patterns);
 
 } // namespace mlir::iree_compiler::IREE::GPU
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/UnrollToIntrinsics.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/UnrollToIntrinsics.cpp
new file mode 100644
index 000000000000..8ff72f90c5e2
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/UnrollToIntrinsics.cpp
@@ -0,0 +1,58 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler::IREE::GPU {
+
+#define GEN_PASS_DEF_UNROLLTOINTRINSICSPASS
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc"
+
+namespace {
+struct UnrollToIntrinsicsPass final
+    : impl::UnrollToIntrinsicsPassBase<UnrollToIntrinsicsPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void UnrollToIntrinsicsPass::runOnOperation() {
+  MLIRContext *context = &getContext();
+
+  {
+    RewritePatternSet patterns(context);
+    GPU::populateIREEGPUVectorUnrollPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+
+  // Post unrolling unit dim folding patterns in preparation for later
+  // lowerings.
+  {
+    RewritePatternSet patterns(context);
+    GPU::populateIREEGPUDropUnitDimsPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+}
+
+} // namespace mlir::iree_compiler::IREE::GPU
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 47d137e3d344..957813aa8ce9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -369,6 +369,11 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager) {
   // Vectorize copies that came out of vectorization.
   funcPassManager.addPass(createVectorizeMemrefCopyPass());
 
-  // Step 7. Remaining post-bufferization optimizations/lowerings.
+  // Step 7. Unroll operations to native intrinsic widths.
+  funcPassManager.addPass(IREE::GPU::createUnrollToIntrinsicsPass());
+  funcPassManager.addPass(createCanonicalizerPass());
+  funcPassManager.addPass(createCSEPass());
+
+  // Step 8. Remaining post-bufferization optimizations/lowerings.
   funcPassManager.addPass(IREE::GPU::createLowerIREEGPUOpsPass());
   funcPassManager.addPass(createLoopInvariantCodeMotionPass());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
index b09a6495fa3f..f999765af78d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
@@ -115,15 +115,15 @@ hal.executable public @main {
 // CHECK: gpu.barrier
 // CHECK: %[[LHS_MM:.+]] = vector.transfer_read {{.*}} vector<2x1x2x4xf16>
 // CHECK: gpu.barrier
-// CHECK: %[[LHS_T:.+]] = vector.transpose %[[LHS_MM]], [0, 2, 1, 3] : vector<2x1x2x4xf16>
+// CHECK: vector.transpose %[[LHS_MM]], [0, 2, 1, 3] : vector<2x1x2x4xf16>
 // CHECK: %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<8xf16>
 // CHECK: vector.transfer_write %[[RHS_RD]]
 // CHECK: gpu.barrier
 // CHECK: %[[RHS_MM:.+]] = vector.transfer_read {{.*}} vector<2x1x2x4xf16>
 // CHECK: gpu.barrier
-// CHECK: %[[RHS_T:.+]] = vector.transpose %[[RHS_MM]], [0, 2, 1, 3] : vector<2x1x2x4xf16>
-// CHECK: %[[MM:.+]] = iree_gpu.multi_mma %[[LHS_T]], %[[RHS_T]]
-// CHECK: scf.yield %[[MM]]
+// CHECK: vector.transpose %[[RHS_MM]], [0, 2, 1, 3] : vector<2x1x2x4xf16>
+// CHECK-COUNT-4: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32
+// CHECK: scf.yield
 // CHECK: %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 2, 1, 3] : vector<2x2x4x1xf32> to vector<2x4x2x1xf32>
 // CHECK: vector.transfer_write %[[LOOP_T]], %[[B2]]
 
@@ -186,32 +186,31 @@ hal.executable private @main {
     }
   }
 }
-// CHECK-LABEL: func @conv_igemm_im2col
-// CHECK-DAG: %[[B0:.+]] = hal.interface.binding.subspan layout({{.+}}) set(0) binding(0)
-// CHECK-DAG: %[[B1:.+]] = hal.interface.binding.subspan layout({{.+}}) set(0) binding(1)
-// CHECK-DAG: %[[B2:.+]] = hal.interface.binding.subspan layout({{.+}}) set(0) binding(2)
-// CHECK-DAG: memref.alloc() : memref<1x64x32xf16, #gpu.address_space<workgroup>>
-// CHECK-DAG: memref.alloc() : memref<32x64xf16, #gpu.address_space<workgroup>>
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C720:.+]] = arith.constant 720 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[C720]] step %[[C2]] {{.*}} -> (vector<1x2x2x4x1xf32>)
-// CHECK: gpu.barrier
-// CHECK: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<8xf16>
-// CHECK: vector.transfer_write %[[LHS_RD]]
-// CHECK: gpu.barrier
-// CHECK: %[[LHS_MM0:.+]] = vector.transfer_read {{.*}} vector<2x1x2x4xf16>
-// CHECK: %[[LHS_MM1:.+]] = vector.broadcast {{.*}} vector<2x1x2x4xf16> to vector<1x2x1x2x4xf16>
-// CHECK: gpu.barrier
-// CHECK: %[[LHS_T:.+]] = vector.transpose %[[LHS_MM1]], [0, 1, 3, 2, 4] : vector<1x2x1x2x4xf16> to vector<1x2x2x1x4xf16>
-// CHECK: %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<8xf16>
-// CHECK: vector.transfer_write %[[RHS_RD]]
-// CHECK: gpu.barrier
-// CHECK: %[[RHS_MM:.+]] = vector.transfer_read {{.*}} vector<2x4x2x1xf16>
-// CHECK: gpu.barrier
-// CHECK: %[[RHS_T:.+]] = vector.transpose %[[RHS_MM]], [0, 2, 3, 1] : vector<2x4x2x1xf16> to vector<2x2x1x4xf16>
-// CHECK: %[[MM:.+]] = iree_gpu.multi_mma %[[LHS_T]], %[[RHS_T]]
-// CHECK: scf.yield %[[MM]]
-// CHECK: %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 1, 3, 2, 4] : vector<1x2x2x4x1xf32> to vector<1x2x4x2x1xf32>
-// CHECK: %[[EXTRACT:.+]] = vector.extract %[[LOOP_T]][0] : vector<2x4x2x1xf32> from vector<1x2x4x2x1xf32>
-// CHECK: vector.transfer_write %[[EXTRACT]], %[[B2]]
+// CHECK-LABEL: func @conv_igemm_im2col
+// CHECK-DAG: %[[B0:.+]] = hal.interface.binding.subspan layout({{.+}}) set(0) binding(0)
+// CHECK-DAG: %[[B1:.+]] = hal.interface.binding.subspan layout({{.+}}) set(0) binding(1)
+// CHECK-DAG: %[[B2:.+]] = hal.interface.binding.subspan layout({{.+}}) set(0) binding(2)
+// CHECK-DAG: memref.alloc() : memref<1x64x32xf16, #gpu.address_space<workgroup>>
+// CHECK-DAG: memref.alloc() : memref<32x64xf16, #gpu.address_space<workgroup>>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C720:.+]] = arith.constant 720 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[C720]] step %[[C2]] {{.*}} -> (vector<1x2x2x4x1xf32>)
+// CHECK: gpu.barrier
+// CHECK: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<8xf16>
+// CHECK: vector.transfer_write %[[LHS_RD]]
+// CHECK: gpu.barrier
+// CHECK: %[[LHS_MM0:.+]] = vector.transfer_read {{.*}} vector<2x1x2x4xf16>
+// CHECK: %[[LHS_MM1:.+]] = vector.broadcast {{.*}} vector<2x1x2x4xf16> to vector<1x2x1x2x4xf16>
+// CHECK: gpu.barrier
+// CHECK: vector.transpose %[[LHS_MM1]], [0, 1, 3, 2, 4] : vector<1x2x1x2x4xf16> to vector<1x2x2x1x4xf16>
+// CHECK: %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<8xf16>
+// CHECK: vector.transfer_write %[[RHS_RD]]
+// CHECK: gpu.barrier
+// CHECK: %[[RHS_MM:.+]] = vector.transfer_read {{.*}} vector<2x4x2x1xf16>
+// CHECK: gpu.barrier
+// CHECK: vector.transpose %[[RHS_MM]], [0, 2, 3, 1] : vector<2x4x2x1xf16> to vector<2x2x1x4xf16>
+// CHECK-COUNT-4: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32
+// CHECK: %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 1, 3, 2, 4] : vector<1x2x2x4x1xf32> to vector<1x2x4x2x1xf32>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[LOOP_T]][0] : vector<2x4x2x1xf32> from vector<1x2x4x2x1xf32>
+// CHECK: vector.transfer_write %[[EXTRACT]], %[[B2]]