[Codegen][GPU] Add pass to unroll to native mma widths (#18101)
Left as a TODO is extending this pass to also unroll consumers and producers.
To eventually connect everything, we should add a pattern that unrolls the
iter args of `scf.for` loops.

Note that the unrolling pattern is already covered by its own test, so simply
updating the pipeline test is sufficient here. Once this pass is extended to
do more unrolling, a dedicated test can be added.
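To illustrate how the preset unroll entry point added here is meant to be used, below is a minimal driver sketch. It is hypothetical code, not part of this commit: the function name and surrounding plumbing are assumptions, and it mirrors what the new `UnrollToIntrinsicsPass` does before folding the resulting unit dimensions.

// Hypothetical driver, shown only to illustrate the preset unroll entry point
// added in Transforms.cpp by this commit; not part of the change itself.
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult
unrollMultiMmaToIntrinsics(mlir::FunctionOpInterface funcOp) {
  mlir::RewritePatternSet patterns(funcOp->getContext());
  // Preset configuration: unit outer shape per iterator and a traversal order
  // that reuses the LHS operand.
  mlir::iree_compiler::IREE::GPU::populateIREEGPUVectorUnrollPatterns(patterns);
  return mlir::applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}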
qedawkins committed Aug 5, 2024
1 parent 09dc003 commit 4a1f619
Showing 9 changed files with 175 additions and 97 deletions.
@@ -66,64 +66,9 @@ void transform_dialect::ApplyLowerValueBarrierOp::populatePatterns(
// ApplyUnrollMultiMmaOp
//===---------------------------------------------------------------------===//

static bool isReductionIterator(Attribute attr) {
return cast<IREE::GPU::IteratorTypeAttr>(attr).getValue() ==
utils::IteratorType::reduction;
}
static bool isParallelIterator(Attribute attr) {
return cast<IREE::GPU::IteratorTypeAttr>(attr).getValue() ==
utils::IteratorType::parallel;
}

/// Pick an unrolling order that reuses the LHS register.
static std::optional<SmallVector<int64_t>>
gpuMultiMmaUnrollOrder(Operation *op) {
IREE::GPU::MultiMmaOp mmaOp = dyn_cast<IREE::GPU::MultiMmaOp>(op);
if (!mmaOp) {
return std::nullopt;
}
SmallVector<int64_t> order;
// First make reduction the outer dimensions.
for (auto [index, iter] : llvm::enumerate(mmaOp.getIteratorTypes())) {
if (isReductionIterator(iter)) {
order.push_back(index);
}
}

llvm::SmallDenseSet<int64_t> dims;
for (AffineExpr expr : mmaOp.getIndexingMapsArray()[0].getResults()) {
dims.insert(cast<AffineDimExpr>(expr).getPosition());
}
// Then parallel dimensions that are part of Lhs as we want to re-use Lhs.
for (auto [index, iter] : llvm::enumerate(mmaOp.getIteratorTypes())) {
if (isParallelIterator(iter) && dims.count(index)) {
order.push_back(index);
}
}
// Then the remaining parallel loops.
for (auto [index, iter] : llvm::enumerate(mmaOp.getIteratorTypes())) {
if (isParallelIterator(iter) && !dims.count(index)) {
order.push_back(index);
}
}
return order;
}

static std::optional<SmallVector<int64_t>> getMultiMmaUnitShape(Operation *op) {
IREE::GPU::MultiMmaOp mmaOp = dyn_cast<IREE::GPU::MultiMmaOp>(op);
if (!mmaOp) {
return std::nullopt;
}
SmallVector<int64_t> targetOuterShape(mmaOp.getIteratorTypes().size(), 1);
return targetOuterShape;
}

void transform_dialect::ApplyUnrollMultiMmaOp::populatePatterns(
RewritePatternSet &patterns) {
GPU::populateIREEGPUVectorUnrollPatterns(
patterns, vector::UnrollVectorOptions()
.setNativeShapeFn(getMultiMmaUnitShape)
.setUnrollTraversalOrderFn(gpuMultiMmaUnrollOrder));
GPU::populateIREEGPUVectorUnrollPatterns(patterns);
}

//===---------------------------------------------------------------------===//
@@ -57,6 +57,7 @@ iree_compiler_cc_library(
"PackToIntrinsics.cpp",
"Passes.cpp",
"Transforms.cpp",
"UnrollToIntrinsics.cpp",
"VectorizeIREEGPUOps.cpp",
],
hdrs = [
@@ -51,6 +51,7 @@ iree_cc_library(
"PackToIntrinsics.cpp"
"Passes.cpp"
"Transforms.cpp"
"UnrollToIntrinsics.cpp"
"VectorizeIREEGPUOps.cpp"
DEPS
::PassesIncGen
@@ -30,6 +30,13 @@ def FuseAndHoistParallelLoopsPass :
];
}

def LowerIREEGPUOpsPass :
InterfacePass<"iree-gpu-lower-ops", "mlir::FunctionOpInterface"> {
let summary = "Post bufferization lowerings of iree_gpu ops before late lowerings";
let dependentDialects = [
"::mlir::gpu::GPUDialect",
];
}

def PackToIntrinsicsPass :
InterfacePass<"iree-gpu-pack-to-intrinsics", "mlir::FunctionOpInterface"> {
@@ -40,6 +47,15 @@ def PackToIntrinsicsPass :
];
}

def UnrollToIntrinsicsPass :
InterfacePass<"iree-gpu-unroll-to-intrinsics", "mlir::FunctionOpInterface"> {
let summary = "Unrolls iree_gpu.multi_mma ops to their inner vector size.";
let dependentDialects = [
"::mlir::arith::ArithDialect",
"::mlir::vector::VectorDialect",
];
}

def VectorizeIREEGPUOpsPass :
InterfacePass<"iree-gpu-vectorize-ops", "mlir::FunctionOpInterface"> {
let summary = "Vectorizes then lowers a few iree_gpu ops before vectorization.";
@@ -50,12 +66,4 @@ def VectorizeIREEGPUOpsPass :
];
}

def LowerIREEGPUOpsPass :
InterfacePass<"iree-gpu-lower-ops", "mlir::FunctionOpInterface"> {
let summary = "Post bufferization lowerings of iree_gpu ops before late lowerings";
let dependentDialects = [
"::mlir::gpu::GPUDialect",
];
}

#endif // IREE_CODEGEN_DIALECt_GPU_TRANSFORMS_PASSES
@@ -796,6 +796,65 @@ void populateIREEGPUVectorUnrollPatterns(
patterns.add<UnrollMultiMmaPattern>(patterns.getContext(), options);
}

static bool isReductionIterator(Attribute attr) {
return cast<IREE::GPU::IteratorTypeAttr>(attr).getValue() ==
utils::IteratorType::reduction;
}
static bool isParallelIterator(Attribute attr) {
return cast<IREE::GPU::IteratorTypeAttr>(attr).getValue() ==
utils::IteratorType::parallel;
}

/// Pick an unrolling order that reuses the LHS register.
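/// For example (illustrative): given iterator types [parallel, parallel,
/// reduction] and an LHS indexing map over (d0, d2), the returned order is
/// [2, 0, 1]: the reduction dimension first, then the parallel dimension
/// used by the LHS, then the remaining parallel dimension.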
static std::optional<SmallVector<int64_t>>
gpuMultiMmaUnrollOrder(Operation *op) {
IREE::GPU::MultiMmaOp mmaOp = dyn_cast<IREE::GPU::MultiMmaOp>(op);
if (!mmaOp) {
return std::nullopt;
}
SmallVector<int64_t> order;
// First make reduction the outer dimensions.
for (auto [index, iter] : llvm::enumerate(mmaOp.getIteratorTypes())) {
if (isReductionIterator(iter)) {
order.push_back(index);
}
}

llvm::SmallDenseSet<int64_t> dimsInLhs;
for (AffineExpr expr : mmaOp.getIndexingMapsArray()[0].getResults()) {
dimsInLhs.insert(cast<AffineDimExpr>(expr).getPosition());
}
// Then parallel dimensions that are part of Lhs as we want to re-use Lhs.
for (auto [index, iter] : llvm::enumerate(mmaOp.getIteratorTypes())) {
if (isParallelIterator(iter) && dimsInLhs.count(index)) {
order.push_back(index);
}
}
// Then the remaining parallel loops.
for (auto [index, iter] : llvm::enumerate(mmaOp.getIteratorTypes())) {
if (isParallelIterator(iter) && !dimsInLhs.count(index)) {
order.push_back(index);
}
}
return order;
}

static std::optional<SmallVector<int64_t>> getMultiMmaUnitShape(Operation *op) {
IREE::GPU::MultiMmaOp mmaOp = dyn_cast<IREE::GPU::MultiMmaOp>(op);
if (!mmaOp) {
return std::nullopt;
}
SmallVector<int64_t> targetOuterShape(mmaOp.getIteratorTypes().size(), 1);
return targetOuterShape;
}

void populateIREEGPUVectorUnrollPatterns(RewritePatternSet &patterns) {
populateIREEGPUVectorUnrollPatterns(
patterns, vector::UnrollVectorOptions()
.setNativeShapeFn(getMultiMmaUnitShape)
.setUnrollTraversalOrderFn(gpuMultiMmaUnrollOrder));
}

//===---------------------------------------------------------------------===//
// Resolving lane mapped forall ops
//===---------------------------------------------------------------------===//
@@ -63,6 +63,8 @@ void populateIREEGPULowerShuffleTensorPatterns(RewritePatternSet &patterns);
void populateIREEGPULowerValueBarrierPatterns(RewritePatternSet &patterns);
void populateIREEGPUVectorUnrollPatterns(
RewritePatternSet &patterns, const vector::UnrollVectorOptions &options);
// Variant of the above that uses a preset configuration: unit outer shapes
// and an unroll order that reuses the LHS.
void populateIREEGPUVectorUnrollPatterns(RewritePatternSet &patterns);
void populateIREEGPUVectorizationPatterns(RewritePatternSet &patterns);

} // namespace mlir::iree_compiler::IREE::GPU
@@ -0,0 +1,58 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler::IREE::GPU {

#define GEN_PASS_DEF_UNROLLTOINTRINSICSPASS
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc"

namespace {
struct UnrollToIntrinsicsPass final
: impl::UnrollToIntrinsicsPassBase<UnrollToIntrinsicsPass> {
void runOnOperation() override;
};
} // namespace

void UnrollToIntrinsicsPass::runOnOperation() {
MLIRContext *context = &getContext();

{
RewritePatternSet patterns(context);
GPU::populateIREEGPUVectorUnrollPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
}
}

// Post unrolling unit dim folding patterns in preparation for later
// lowerings.
{
RewritePatternSet patterns(context);
GPU::populateIREEGPUDropUnitDimsPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
}
}
}

} // namespace mlir::iree_compiler::IREE::GPU
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -369,6 +369,11 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager) {
// Vectorize copies that came out of vectorization.
funcPassManager.addPass(createVectorizeMemrefCopyPass());

// Step 7. Unroll operations to native intrinsic widths.
funcPassManager.addPass(IREE::GPU::createUnrollToIntrinsicsPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

// Step 8. Remaining post-bufferization optimizations/lowerings.
funcPassManager.addPass(IREE::GPU::createLowerIREEGPUOpsPass());
funcPassManager.addPass(createLoopInvariantCodeMotionPass());
@@ -115,15 +115,15 @@ hal.executable public @main {
// CHECK: gpu.barrier
// CHECK: %[[LHS_MM:.+]] = vector.transfer_read {{.*}} vector<2x1x2x4xf16>
// CHECK: gpu.barrier
// CHECK: %[[LHS_T:.+]] = vector.transpose %[[LHS_MM]], [0, 2, 1, 3] : vector<2x1x2x4xf16>
// CHECK: vector.transpose %[[LHS_MM]], [0, 2, 1, 3] : vector<2x1x2x4xf16>
// CHECK: %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<8xf16>
// CHECK: vector.transfer_write %[[RHS_RD]]
// CHECK: gpu.barrier
// CHECK: %[[RHS_MM:.+]] = vector.transfer_read {{.*}} vector<2x1x2x4xf16>
// CHECK: gpu.barrier
// CHECK: %[[RHS_T:.+]] = vector.transpose %[[RHS_MM]], [0, 2, 1, 3] : vector<2x1x2x4xf16>
// CHECK: %[[MM:.+]] = iree_gpu.multi_mma %[[LHS_T]], %[[RHS_T]]
// CHECK: scf.yield %[[MM]]
// CHECK: vector.transpose %[[RHS_MM]], [0, 2, 1, 3] : vector<2x1x2x4xf16>
// CHECK-COUNT-4: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32
// CHECK: scf.yield
// CHECK: %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 2, 1, 3] : vector<2x2x4x1xf32> to vector<2x4x2x1xf32>
// CHECK: vector.transfer_write %[[LOOP_T]], %[[B2]]

@@ -186,32 +186,31 @@ hal.executable private @main {
}
}

// CHECK-LABEL: func @conv_igemm_im2col
// CHECK-DAG: %[[B0:.+]] = hal.interface.binding.subspan layout({{.+}}) set(0) binding(0)
// CHECK-DAG: %[[B1:.+]] = hal.interface.binding.subspan layout({{.+}}) set(0) binding(1)
// CHECK-DAG: %[[B2:.+]] = hal.interface.binding.subspan layout({{.+}}) set(0) binding(2)
// CHECK-DAG: memref.alloc() : memref<1x64x32xf16, #gpu.address_space<workgroup>>
// CHECK-DAG: memref.alloc() : memref<32x64xf16, #gpu.address_space<workgroup>>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C720:.+]] = arith.constant 720 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[C720]] step %[[C2]] {{.*}} -> (vector<1x2x2x4x1xf32>)
// CHECK: gpu.barrier
// CHECK: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<8xf16>
// CHECK: vector.transfer_write %[[LHS_RD]]
// CHECK: gpu.barrier
// CHECK: %[[LHS_MM0:.+]] = vector.transfer_read {{.*}} vector<2x1x2x4xf16>
// CHECK: %[[LHS_MM1:.+]] = vector.broadcast {{.*}} vector<2x1x2x4xf16> to vector<1x2x1x2x4xf16>
// CHECK: gpu.barrier
// CHECK: %[[LHS_T:.+]] = vector.transpose %[[LHS_MM1]], [0, 1, 3, 2, 4] : vector<1x2x1x2x4xf16> to vector<1x2x2x1x4xf16>
// CHECK: %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<8xf16>
// CHECK: vector.transfer_write %[[RHS_RD]]
// CHECK: gpu.barrier
// CHECK: %[[RHS_MM:.+]] = vector.transfer_read {{.*}} vector<2x4x2x1xf16>
// CHECK: gpu.barrier
// CHECK: %[[RHS_T:.+]] = vector.transpose %[[RHS_MM]], [0, 2, 3, 1] : vector<2x4x2x1xf16> to vector<2x2x1x4xf16>
// CHECK: %[[MM:.+]] = iree_gpu.multi_mma %[[LHS_T]], %[[RHS_T]]
// CHECK: scf.yield %[[MM]]
// CHECK: %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 1, 3, 2, 4] : vector<1x2x2x4x1xf32> to vector<1x2x4x2x1xf32>
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[LOOP_T]][0] : vector<2x4x2x1xf32> from vector<1x2x4x2x1xf32>
// CHECK: vector.transfer_write %[[EXTRACT]], %[[B2]]
// CHECK-LABEL: func @conv_igemm_im2col
// CHECK-DAG: %[[B0:.+]] = hal.interface.binding.subspan layout({{.+}}) set(0) binding(0)
// CHECK-DAG: %[[B1:.+]] = hal.interface.binding.subspan layout({{.+}}) set(0) binding(1)
// CHECK-DAG: %[[B2:.+]] = hal.interface.binding.subspan layout({{.+}}) set(0) binding(2)
// CHECK-DAG: memref.alloc() : memref<1x64x32xf16, #gpu.address_space<workgroup>>
// CHECK-DAG: memref.alloc() : memref<32x64xf16, #gpu.address_space<workgroup>>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C720:.+]] = arith.constant 720 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[C720]] step %[[C2]] {{.*}} -> (vector<1x2x2x4x1xf32>)
// CHECK: gpu.barrier
// CHECK: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<8xf16>
// CHECK: vector.transfer_write %[[LHS_RD]]
// CHECK: gpu.barrier
// CHECK: %[[LHS_MM0:.+]] = vector.transfer_read {{.*}} vector<2x1x2x4xf16>
// CHECK: %[[LHS_MM1:.+]] = vector.broadcast {{.*}} vector<2x1x2x4xf16> to vector<1x2x1x2x4xf16>
// CHECK: gpu.barrier
// CHECK: vector.transpose %[[LHS_MM1]], [0, 1, 3, 2, 4] : vector<1x2x1x2x4xf16> to vector<1x2x2x1x4xf16>
// CHECK: %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<8xf16>
// CHECK: vector.transfer_write %[[RHS_RD]]
// CHECK: gpu.barrier
// CHECK: %[[RHS_MM:.+]] = vector.transfer_read {{.*}} vector<2x4x2x1xf16>
// CHECK: gpu.barrier
// CHECK: vector.transpose %[[RHS_MM]], [0, 2, 3, 1] : vector<2x4x2x1xf16> to vector<2x2x1x4xf16>
// CHECK-COUNT-4: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32
// CHECK: %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 1, 3, 2, 4] : vector<1x2x2x4x1xf32> to vector<1x2x4x2x1xf32>
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[LOOP_T]][0] : vector<2x4x2x1xf32> from vector<1x2x4x2x1xf32>
// CHECK: vector.transfer_write %[[EXTRACT]], %[[B2]]
