Skip to content

Commit

Permalink
[Codegen][GPU] Move conversion to multi_mma to PackToIntrinsics (#18141)
Browse files Browse the repository at this point in the history
Now that `iree_gpu.multi_mma` has a tiling interface implementation, the
conversion from linalg to it can happen before other levels of tiling.
This allows for reshaping the inner dimensions freely before reduction
tiling and then propagating the reshapes to nearby ops without needing
to hoist them out of tiling contructs.

Additionally this is closer to the required flow for data tiling where
we need to generate the `iree_gpu.multi_mma` op before any tiling.
  • Loading branch information
qedawkins committed Aug 8, 2024
1 parent 643f719 commit 7ab66ff
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,6 @@ struct DistributeMmaToLanesPass final
};
} // namespace

struct ConvertToMultiMma final : OpInterfaceRewritePattern<linalg::LinalgOp> {
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
auto loweringConfig =
getLoweringConfig<IREE::GPU::LoweringConfigAttr>(linalgOp);
if (!loweringConfig) {
return failure();
}
IREE::GPU::MmaInterfaceAttr kind = loweringConfig.getMmaKind();
if (!kind) {
return failure();
}
if (failed(convertContractionToMultiMma(rewriter, linalgOp, kind))) {
return failure();
}
return success();
}
};

LogicalResult fuseProducersGreedily(RewriterBase &rewriter,
scf::ForallOp laneForall) {

Expand Down Expand Up @@ -100,17 +80,7 @@ void DistributeMmaToLanesPass::runOnOperation() {
MLIRContext *context = &getContext();
auto funcOp = getOperation();

// Step 1. Convert configured linalg ops to multi_mma.
{
RewritePatternSet patterns(context);
patterns.add<ConvertToMultiMma>(context);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
funcOp.emitError() << "failed to convert linalg to multi_mma";
return signalPassFailure();
}
}

// Step 2. Distribute multi_mma ops to lanes and greedily fuse producers.
// Distribute multi_mma ops to lanes and greedily fuse producers.
SmallVector<IREE::GPU::MultiMmaOp> mmaOps;
funcOp.walk([&](IREE::GPU::MultiMmaOp mmaOp) { mmaOps.push_back(mmaOp); });
IRRewriter rewriter(funcOp);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,31 @@ LogicalResult packToIntrinsic(linalg::LinalgOp linalgOp,
return success();
}

struct ConvertToMultiMma final : OpInterfaceRewritePattern<linalg::LinalgOp> {
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
auto loweringConfig =
getLoweringConfig<IREE::GPU::LoweringConfigAttr>(linalgOp);
if (!loweringConfig) {
return failure();
}
IREE::GPU::MmaInterfaceAttr kind = loweringConfig.getMmaKind();
if (!kind) {
return failure();
}
if (failed(convertContractionToMultiMma(rewriter, linalgOp, kind))) {
return failure();
}
return success();
}
};

void PackToIntrinsicsPass::runOnOperation() {
MLIRContext *context = &getContext();
auto funcOp = getOperation();

// Step 1. Pack candidate linalg ops to specified shapes.
IRRewriter rewriter(funcOp);
SmallVector<linalg::LinalgOp> packingCandidates;
funcOp->walk([&](linalg::LinalgOp linalgOp) {
Expand All @@ -95,7 +117,18 @@ void PackToIntrinsicsPass::runOnOperation() {
}
}

// Run layout propagation patterns to pull in adjacent un-configured ops.
// Step 2. Convert configured linalg ops to multi_mma.
{
RewritePatternSet patterns(context);
patterns.add<ConvertToMultiMma>(context);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
funcOp.emitError() << "failed to convert linalg to multi_mma";
return signalPassFailure();
}
}

// Step 3. Run layout propagation patterns to pull in adjacent un-configured
// ops.
RewritePatternSet patterns(context);
linalg::ControlPropagationFn control = [](OpOperand *opOperand) -> bool {
Operation *producer = opOperand->get().getDefiningOp();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@ include "mlir/Pass/PassBase.td"

def DistributeMmaToLanesPass :
InterfacePass<"iree-gpu-distribute-mma-to-lanes", "mlir::FunctionOpInterface"> {
let summary = "Converts and distributes linalg ops with mma kinds to lanes";
let summary = "Distributes iree_gpu.multi_mma ops to lanes";
let dependentDialects = [
"::mlir::arith::ArithDialect",
"::mlir::affine::AffineDialect",
"::mlir::scf::SCFDialect",
"::mlir::iree_compiler::IREE::GPU::IREEGPUDialect",
];
}

Expand Down Expand Up @@ -58,7 +57,7 @@ def LowerIREEGPUOpsPass :

def PackToIntrinsicsPass :
InterfacePass<"iree-gpu-pack-to-intrinsics", "mlir::FunctionOpInterface"> {
let summary = "Packs matmul like operations to specified intrinsic shapes";
let summary = "Packs matmul like operations and converts to iree_gpu.multi_mma";
let dependentDialects = [
"::mlir::tensor::TensorDialect",
"::mlir::iree_compiler::IREE::GPU::IREEGPUDialect"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"

#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
#include "llvm/ADT/ArrayRef.h"
Expand Down Expand Up @@ -442,10 +443,16 @@ convertContractionToMultiMma(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
accPerm = accInnerPerm;
}

IREE::Codegen::LoweringConfigAttrInterface maybeLoweringConfig =
getLoweringConfig(linalgOp);

auto newMmaOp = rewriter.replaceOpWithNewOp<IREE::GPU::MultiMmaOp>(
linalgOp, inputs[0], inputs[1], inputs[2],
ArrayRef<AffineMap>{outerLhsMap, outerRhsMap, outerAccMap}, iteratorTypes,
mmaKind, lhsPerm, rhsPerm, accPerm);
if (maybeLoweringConfig) {
setLoweringConfig(newMmaOp, maybeLoweringConfig);
}
return newMmaOp;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,20 @@
// RUN: iree-opt %s --pass-pipeline='builtin.module(func.func(iree-gpu-distribute-mma-to-lanes, canonicalize, cse))' --split-input-file | FileCheck %s

#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
#contraction_accesses = [
affine_map<(i, j, k) -> (i, k)>,
affine_map<(i, j, k) -> (k, j)>,
affine_map<(i, j, k) -> (i, j)>
]
module {
func.func @matmul_16x16x16(%arg0: tensor<8x2x16x16xf16>, %arg1: tensor<8x2x16x16xf16>, %arg2: tensor<2x2x16x16xf32>) -> tensor<2x2x16x16xf32> {
%empty = tensor.empty() : tensor<2x8x16x16xf16>
%lhs_transpose = linalg.transpose ins(%arg0: tensor<8x2x16x16xf16>) outs(%empty: tensor<2x8x16x16xf16>) permutation = [1, 0, 2, 3]
%mm = linalg.generic {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
ins(%lhs_transpose, %arg1 : tensor<2x8x16x16xf16>, tensor<8x2x16x16xf16>)
outs(%arg2 : tensor<2x2x16x16xf32>)
attrs = {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}>} {
^bb0(%in: f16, %in_2: f16, %out: f32):
%4 = arith.extf %in : f16 to f32
%5 = arith.extf %in_2 : f16 to f32
%6 = arith.mulf %4, %5 : f32
%7 = arith.addf %out, %6 : f32
linalg.yield %7 : f32
} -> tensor<2x2x16x16xf32>
%mm = iree_gpu.multi_mma %lhs_transpose, %arg1, %arg2 {
indexing_maps = #contraction_accesses,
iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
rhs_permutation = array<i64: 1, 0>
} : tensor<2x8x16x16xf16>, tensor<8x2x16x16xf16> into tensor<2x2x16x16xf32>
return %mm : tensor<2x2x16x16xf32>
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt %s --pass-pipeline='builtin.module(func.func(iree-gpu-pack-to-intrinsics, canonicalize, cse))' --split-input-file | FileCheck %s
// RUN: iree-opt %s --mlir-print-local-scope --pass-pipeline='builtin.module(func.func(iree-gpu-pack-to-intrinsics, canonicalize, cse))' --split-input-file | FileCheck %s

#config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>}>
module {
Expand All @@ -15,10 +15,15 @@ module {
// CHECK-DAG: %[[A_PACK:.+]] = tensor.pack %[[A]] inner_dims_pos = [0, 1] inner_tiles = [32, 8]
// CHECK-DAG: %[[B_PACK:.+]] = tensor.pack %[[B]] inner_dims_pos = [1, 0] inner_tiles = [32, 8]
// CHECK-DAG: %[[C_PACK:.+]] = tensor.pack %[[C]] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
// CHECK: %[[PACKED_MM:.+]] = linalg.generic
// CHECK-SAME: ins(%[[A_PACK]], %[[B_PACK]] : tensor<2x8x32x8xf16>, tensor<8x2x32x8xf16>)
// CHECK-SAME: outs(%[[C_PACK]] : tensor<2x2x32x32xf32>)
// CHECK: iree_gpu.multi_mma %[[A_PACK]], %[[B_PACK]], %[[C_PACK]]
// CHECK-SAME: indexing_maps =
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-SAME: iterator_types = {{.*}}parallel{{.*}}parallel{{.*}}reduction
// CHECK-SAME: kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>
// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>}>
// CHECK-SAME: rhs_permutation = array<i64: 1, 0>

// -----

Expand All @@ -45,13 +50,11 @@ module {
}
}

// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d3, d4, d5, d7)>
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d0, d3, d4, d6, d7)>
// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d5, d6)>

// CHECK-LABEL: func.func @matmul_16x16x16
// CHECK: %[[PACKED_MM:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: ins({{.*}} : tensor<?x?x?x16x16xf16>, tensor<?x?x?x?x16x16xf16>)
// CHECK-SAME: outs({{.*}} : tensor<?x?x?x16x16xf32>)
// CHECK: iree_gpu.multi_mma
// CHECK-SAME: indexing_maps =
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d2, d0, d3, d4)>
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}>
// CHECK-SAME: : tensor<?x?x?x16x16xf16>, tensor<?x?x?x?x16x16xf16> into tensor<?x?x?x16x16xf32>

0 comments on commit 7ab66ff

Please sign in to comment.