Skip to content

Commit

Permalink
[LLVMCPU] Tile root and fuse consumer producer pass (#17804)
Browse files Browse the repository at this point in the history
This is a draft patch to get helpful insights on consumer fusion.

-- This patch tiles the root op and does producer-consumer fusion
greedily.

I want to iterate on this patch and get helpful inputs because consumer
fusion modifies and replaces the consumers in place, unlike producer
fusion, where you can get the producers from the tiledops.

To upstream the `tileProducerAndFuseConsumerAPI,` we'll need both
original and tiled ops.
  • Loading branch information
pashu123 committed Aug 6, 2024
1 parent 4883368 commit 113fae8
Show file tree
Hide file tree
Showing 8 changed files with 321 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ iree_compiler_cc_library(
"LLVMCPUSynchronizeSymbolVisibility.cpp",
"LLVMCPUTile.cpp",
"LLVMCPUTileAndFuse.cpp",
"LLVMCPUTileRootAndFuseProducerConsumer.cpp",
"LLVMCPUUnfuseFMAOps.cpp",
"LLVMCPUVectorShapeCastLowering.cpp",
"LLVMCPUVectorTransferLowering.cpp",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ iree_cc_library(
"LLVMCPUSynchronizeSymbolVisibility.cpp"
"LLVMCPUTile.cpp"
"LLVMCPUTileAndFuse.cpp"
"LLVMCPUTileRootAndFuseProducerConsumer.cpp"
"LLVMCPUUnfuseFMAOps.cpp"
"LLVMCPUVectorShapeCastLowering.cpp"
"LLVMCPUVectorTransferLowering.cpp"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h"
#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "iree/compiler/Codegen/LLVMCPU/Utils.h"
#include "iree/compiler/Codegen/Utils/CPUUtils.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/Iterators.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-llvmcpu-tile-root-and-fuse-producers-consumers"

namespace mlir::iree_compiler {

namespace {

/// Implementation of tile root and fuse producers and consumers greedily.
static LogicalResult tileRootAndFuseProducerConsumerUsingSCF(
RewriterBase &rewriter, TilingInterface root,
const scf::SCFTileAndFuseOptions &options) {

// This transformation is only valid for ops that return values (i.e. not
// valid to use with operations that have memref operands).
if (!root->getNumResults()) {
return rewriter.notifyMatchFailure(
root, "invalid pattern for op with no results");
}

// 1. Tile root op and Fuse Producers.
FailureOr<scf::SCFTileAndFuseResult> tiledResults =
scf::tileConsumerAndFuseProducersUsingSCF(rewriter, root, options);

if (failed(tiledResults)) {
return rewriter.notifyMatchFailure(
root, "failed to tile root and fuse producers");
}

// 2. Replace the producers with the tiled verison.
SmallVector<Operation *> opsToReplace = {root};
llvm::append_range(opsToReplace, tiledResults->fusedProducers);
for (Operation *toReplace : opsToReplace) {
for (OpResult res : toReplace->getResults())
if (auto replacement = tiledResults->replacements.lookup(res)) {
rewriter.replaceAllUsesWith(res, replacement);
}

if (toReplace->use_empty()) {
rewriter.eraseOp(toReplace);
}
}

// 3. Typically, the consumers of the tiled operation are slices of the
// results of the tiled operation. These are expressed in IR using
// `tensor.insert_slice` operations, whose outputs are the operands of the
// untiled operation. Create a worklist of these `tensor.insert_siices`
// operations. If the consumers of the source of the `tensor.insert_slices`
// can be tiled such that the tiled value is generated in-place, that
// effectively tiles + fuses the operations.
auto addCandidateSlices = [](Operation *fusedOp,
std::queue<tensor::InsertSliceOp> &candidates) {
for (auto *userOp : fusedOp->getResults().getUsers()) {
if (auto sliceOp = llvm::dyn_cast<tensor::InsertSliceOp>(userOp)) {
candidates.push(sliceOp);
}
}
};

// Collect the candidate slices which can be potential consumers that can be
// fused.
std::queue<tensor::InsertSliceOp> candidates;
addCandidateSlices(tiledResults->tiledAndFusedOps.front(), candidates);

while (!candidates.empty()) {

// Traverse the slices in BFS fashion.
tensor::InsertSliceOp candidateSliceOp = candidates.front();
candidates.pop();

FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
mlir::scf::tileAndFuseConsumerOfSlice(rewriter, candidateSliceOp);
if (failed(fusedResult)) {
LLVM_DEBUG(llvm::dbgs() << "failed to fuse consumer of slice: "
<< candidateSliceOp << "\n");
continue;
}

// Replace the original consumer operation with the tiled implementation.
rewriter.replaceOp(fusedResult->origConsumerOperand->getOwner(),
fusedResult->tiledOps.front());

// The result of the fused conumers might themselved be slices of
// values produced by operations that implement the `TilingInterface`.
// Add these operations to the worklist.
addCandidateSlices(fusedResult->tiledAndFusedConsumerOperand->getOwner(),
candidates);
}
return success();
}

static LogicalResult tileRootAndFuseProducerConsumer(IRRewriter &rewriter,
TilingInterface rootOp,
int64_t tilingLevel) {

SmallVector<OpFoldResult> tileSizes =
getLoweringConfig(rootOp).getTilingLevelSizes(rewriter, tilingLevel,
rootOp);
int64_t numLoops = rootOp.getLoopIteratorTypes().size();
if (tileSizes.size() > numLoops)
return failure();

scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizes(tileSizes);

scf::SCFTileAndFuseOptions tileAndFuseOptions;
tileAndFuseOptions.setTilingOptions(tilingOptions);

return tileRootAndFuseProducerConsumerUsingSCF(rewriter, rootOp,
tileAndFuseOptions);
}

/// This pass starts with the first TilingInterface operation that has
/// lowering_config attribute, tiles the op and fuses its consumers and
/// producers recursively. The `tilingLevel` must be specified. It picks the
/// `tilingLevel`-th list as tiling sizes from lowering_config.
struct LLVMCPUTileRootAndFuseProducerConsumer
: LLVMCPUTileRootAndFuseProducerConsumerBase<
LLVMCPUTileRootAndFuseProducerConsumer> {
LLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel = -1) {
this->tilingLevel.setValue(tilingLevel);
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect, affine::AffineDialect,
linalg::LinalgDialect, scf::SCFDialect,
tensor::TensorDialect>();
}

void runOnOperation() override;
};

void LLVMCPUTileRootAndFuseProducerConsumer::runOnOperation() {
MLIRContext *context = &getContext();
auto funcOp = getOperation();

IRRewriter rewriter(funcOp);

SmallVector<Operation *> computeOps = getComputeOps(funcOp);
FailureOr<Operation *> rootOp = getRootOperation(computeOps);

if (failed(rootOp)) {
funcOp.emitError() << "not able to find the root operation\n";
return signalPassFailure();
}

IREE::Codegen::LoweringConfigAttrInterface loweringConfig =
getLoweringConfig(rootOp.value());
if (!loweringConfig) {
funcOp.emitError() << "not able to find the lowering config\n";
return signalPassFailure();
}

if (!loweringConfig.hasTilingLevel(tilingLevel)) {
funcOp.emitError()
<< "not able to find the lowering config with the tiling level "
<< tilingLevel.getValue() << "\n";
return signalPassFailure();
}

if (failed(tileRootAndFuseProducerConsumer(
rewriter, dyn_cast<TilingInterface>(rootOp.value()),
tilingLevel.getValue()))) {
funcOp.emitError() << "tiling of level " << tilingLevel.getValue()
<< " failed\n";
return signalPassFailure();
}

RewritePatternSet patterns =
linalg::getLinalgTilingCanonicalizationPatterns(context);
scf::populateSCFForLoopCanonicalizationPatterns(patterns);
tensor::populateFoldTensorEmptyPatterns(patterns);
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
// Pull in tensor dialect canonicalization patterns to fold tensor.cast
// into producers when possible.
context->getLoadedDialect<tensor::TensorDialect>()
->getCanonicalizationPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
LLVM_DEBUG(llvm::dbgs() << "----- cleanup failed -----\n");
return signalPassFailure();
}
}
} // namespace

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel) {
return std::make_unique<LLVMCPUTileRootAndFuseProducerConsumer>(tilingLevel);
}
} // namespace mlir::iree_compiler
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ createLLVMCPUSynchronizeSymbolVisibilityPass();
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMCPUTileAndFusePass(int64_t tilingLevel = -1);

// Pass to Tile the Root Op and Fuse Producer and Consumer.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel = -1);

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMCPU2DScalableTo1DScalablePass();

Expand Down
11 changes: 11 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,17 @@ def LLVMCPUTileAndFuse :
];
}

def LLVMCPUTileRootAndFuseProducerConsumer :
InterfacePass<"iree-llvmcpu-tile-root-and-fuse-producer-consumer", "mlir::FunctionOpInterface"> {
let summary = "Pass to tile root op and fuse with producer and consumer TilingInterface ops.";
let constructor =
"mlir::iree_compiler::createLLVMCPUTileRootAndFuseProducerConsumer()";
let options = [
Option<"tilingLevel", "tiling-level", "int64_t", /*default=*/"-1",
"Use default tiling level used to retrieve the configuration from lowering_config">
];
}

def LLVMCPUVerifyVectorSizeLegality :
InterfacePass<"iree-llvmcpu-verify-vector-size-legality", "mlir::FunctionOpInterface"> {
let summary =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ iree_lit_test_suite(
"split_reduction.mlir",
"synchronize_symbol_visibility.mlir",
"tile.mlir",
"tile-root-fuse-consumer-producer.mlir",
"tile_and_fuse.mlir",
"transform_dialect_bufferize.mlir",
"unfused_fma.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ iree_lit_test_suite(
"select_x86_64_lowering_strategy.mlir"
"split_reduction.mlir"
"synchronize_symbol_visibility.mlir"
"tile-root-fuse-consumer-producer.mlir"
"tile.mlir"
"tile_and_fuse.mlir"
"transform_dialect_bufferize.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-tile-root-and-fuse-producer-consumer{tiling-level=0}), canonicalize)" --split-input-file %s | FileCheck %s

#config1 = #iree_codegen.lowering_config<tile_sizes = [[1, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 0, 0, 16, 16, 0], [0, 0, 1, 0, 0, 1], [0, 0, 0, 0, 0, 0]]>
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
func.func @mmt4d_bias_relu(%arg0: tensor<?x?x16x1xf32>, %arg1: tensor<?x?x16x1xf32>, %arg2: tensor<?x16xf32>) -> tensor<?x?x16x16xf32> {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?x16x1xf32>
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x16x1xf32>
%0 = tensor.empty(%dim, %dim_0) : tensor<?x?x16x16xf32>
%1 = tensor.empty(%dim, %dim_0) : tensor<?x?x16x16xf32>
%2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?x?x16x16xf32>) -> tensor<?x?x16x16xf32>
%3 = linalg.mmt4d {lowering_config = #config1} ins(%arg0, %arg1 : tensor<?x?x16x1xf32>, tensor<?x?x16x1xf32>) outs(%2 : tensor<?x?x16x16xf32>) -> tensor<?x?x16x16xf32>
%4 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%3, %arg2 : tensor<?x?x16x16xf32>, tensor<?x16xf32>) outs(%1 : tensor<?x?x16x16xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%5 = arith.addf %in, %in_1 : f32
%6 = arith.maximumf %5, %cst : f32
linalg.yield %6 : f32
} -> tensor<?x?x16x16xf32>
return %4 : tensor<?x?x16x16xf32>
}
// CHECK: func.func @mmt4d_bias_relu(
// CHECK: scf.for
// CHECK-SAME: {
// CHECK: linalg.fill
// CHECK: linalg.mmt4d
// CHECK: linalg.generic
// CHECK: }

// -----

#config2 = #iree_codegen.lowering_config<tile_sizes = [[1, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 0, 0, 16, 16, 0], [0, 0, 1, 0, 0, 1], [0, 0, 0, 0, 0, 0]]>
func.func @quantized_matmul() {
%c2995200 = arith.constant 2995200 : index
%c2994688 = arith.constant 2994688 : index
%c2994176 = arith.constant 2994176 : index
%c176128 = arith.constant 176128 : index
%c88064 = arith.constant 88064 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c2995200) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x4x128x16x1xi8>>
%1 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c2994688) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x4x16xf32>>
%2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c2994176) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x4x16xf32>>
%3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c176128) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x688x128x16x1xi8>>
%4 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c88064) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x688x16xf32>>
%5 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x688x16xf32>>
%6 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x11008x64xf32>>
%7 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0, 0], sizes = [2, 4, 128, 16, 1], strides = [1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4x128x16x1xi8>> -> tensor<2x4x128x16x1xi8>
%8 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [2, 4, 16], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4x16xf32>> -> tensor<2x4x16xf32>
%9 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [2, 4, 16], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4x16xf32>> -> tensor<2x4x16xf32>
%10 = flow.dispatch.tensor.load %3, offsets = [0, 0, 0, 0, 0], sizes = [2, 688, 128, 16, 1], strides = [1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x688x128x16x1xi8>> -> tensor<2x688x128x16x1xi8>
%11 = flow.dispatch.tensor.load %4, offsets = [0, 0, 0], sizes = [2, 688, 16], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x688x16xf32>> -> tensor<2x688x16xf32>
%12 = flow.dispatch.tensor.load %5, offsets = [0, 0, 0], sizes = [2, 688, 16], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x688x16xf32>> -> tensor<2x688x16xf32>
%13 = tensor.empty() : tensor<2x4x128x16x1xf32>
%14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%7, %8, %9 : tensor<2x4x128x16x1xi8>, tensor<2x4x16xf32>, tensor<2x4x16xf32>) outs(%13 : tensor<2x4x128x16x1xf32>) {
^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32):
%21 = arith.extui %in : i8 to i32
%22 = arith.uitofp %21 : i32 to f32
%23 = arith.subf %22, %in_1 : f32
%24 = arith.mulf %23, %in_0 : f32
linalg.yield %24 : f32
} -> tensor<2x4x128x16x1xf32>
%15 = tensor.empty() : tensor<2x688x128x16x1xf32>
%16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%10, %11, %12 : tensor<2x688x128x16x1xi8>, tensor<2x688x16xf32>, tensor<2x688x16xf32>) outs(%15 : tensor<2x688x128x16x1xf32>) {
^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32):
%21 = arith.extui %in : i8 to i32
%22 = arith.uitofp %21 : i32 to f32
%23 = arith.subf %22, %in_1 : f32
%24 = arith.mulf %23, %in_0 : f32
linalg.yield %24 : f32
} -> tensor<2x688x128x16x1xf32>
%17 = tensor.empty() : tensor<2x4x688x16x16xf32>
%18 = linalg.fill ins(%cst : f32) outs(%17 : tensor<2x4x688x16x16xf32>) -> tensor<2x4x688x16x16xf32>
%19 = linalg.batch_mmt4d {lowering_config = #config2} ins(%14, %16 : tensor<2x4x128x16x1xf32>, tensor<2x688x128x16x1xf32>) outs(%18 : tensor<2x4x688x16x16xf32>) -> tensor<2x4x688x16x16xf32>
%20 = tensor.empty() : tensor<2x11008x64xf32>
%unpack = tensor.unpack %19 outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 16] into %20 : tensor<2x4x688x16x16xf32> -> tensor<2x11008x64xf32>
flow.dispatch.tensor.store %unpack, %6, offsets = [0, 0, 0], sizes = [2, 11008, 64], strides = [1, 1, 1] : tensor<2x11008x64xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x11008x64xf32>>
return
}
// CHECK: func.func @quantized_matmul(
// CHECK: scf.for
// CHECK-SAME: {
// CHECK: linalg.generic
// CHECK: linalg.generic
// CHECK: linalg.fill
// CHECK: linalg.batch_mmt4d
// CHECK: tensor.unpack
// CHECK: }

0 comments on commit 113fae8

Please sign in to comment.