From 113fae80f86ed3d638a90e8eff9a7bf1223c1b30 Mon Sep 17 00:00:00 2001 From: Prashant Kumar Date: Tue, 6 Aug 2024 08:02:45 +0530 Subject: [PATCH] [LLVMCPU] Tile root and fuse consumer producer pass (#17804) This is a draft patch to get helpful insights on consumer fusion. -- This patch tiles the root op and does producer-consumer fusion greedily. I want to iterate on this patch and get helpful inputs because consumer fusion modifies and replaces the consumers in place, unlike producer fusion, where you can get the producers from the tiledops. To upstream the `tileProducerAndFuseConsumerAPI,` we'll need both original and tiled ops. --- .../iree/compiler/Codegen/LLVMCPU/BUILD.bazel | 1 + .../compiler/Codegen/LLVMCPU/CMakeLists.txt | 1 + ...LLVMCPUTileRootAndFuseProducerConsumer.cpp | 212 ++++++++++++++++++ .../iree/compiler/Codegen/LLVMCPU/Passes.h | 4 + .../iree/compiler/Codegen/LLVMCPU/Passes.td | 11 + .../compiler/Codegen/LLVMCPU/test/BUILD.bazel | 1 + .../Codegen/LLVMCPU/test/CMakeLists.txt | 1 + .../tile-root-fuse-consumer-producer.mlir | 90 ++++++++ 8 files changed, 321 insertions(+) create mode 100644 compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp create mode 100644 compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile-root-fuse-consumer-producer.mlir diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel index e3752f82eb01..35c5e81e4284 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel @@ -67,6 +67,7 @@ iree_compiler_cc_library( "LLVMCPUSynchronizeSymbolVisibility.cpp", "LLVMCPUTile.cpp", "LLVMCPUTileAndFuse.cpp", + "LLVMCPUTileRootAndFuseProducerConsumer.cpp", "LLVMCPUUnfuseFMAOps.cpp", "LLVMCPUVectorShapeCastLowering.cpp", "LLVMCPUVectorTransferLowering.cpp", diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt index 29a7df00a663..a25c8973bf23 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt @@ -68,6 +68,7 @@ iree_cc_library( "LLVMCPUSynchronizeSymbolVisibility.cpp" "LLVMCPUTile.cpp" "LLVMCPUTileAndFuse.cpp" + "LLVMCPUTileRootAndFuseProducerConsumer.cpp" "LLVMCPUUnfuseFMAOps.cpp" "LLVMCPUVectorShapeCastLowering.cpp" "LLVMCPUVectorTransferLowering.cpp" diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp new file mode 100644 index 000000000000..75d903fb7d02 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp @@ -0,0 +1,212 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h" +#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h" +#include "iree/compiler/Codegen/LLVMCPU/Passes.h" +#include "iree/compiler/Codegen/LLVMCPU/Utils.h" +#include "iree/compiler/Codegen/Utils/CPUUtils.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/IR/Iterators.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-llvmcpu-tile-root-and-fuse-producers-consumers" + +namespace mlir::iree_compiler { + +namespace { + +/// Implementation of tile root and fuse producers and consumers greedily. +static LogicalResult tileRootAndFuseProducerConsumerUsingSCF( + RewriterBase &rewriter, TilingInterface root, + const scf::SCFTileAndFuseOptions &options) { + + // This transformation is only valid for ops that return values (i.e. not + // valid to use with operations that have memref operands). + if (!root->getNumResults()) { + return rewriter.notifyMatchFailure( + root, "invalid pattern for op with no results"); + } + + // 1. Tile root op and Fuse Producers. + FailureOr tiledResults = + scf::tileConsumerAndFuseProducersUsingSCF(rewriter, root, options); + + if (failed(tiledResults)) { + return rewriter.notifyMatchFailure( + root, "failed to tile root and fuse producers"); + } + + // 2. Replace the producers with the tiled verison. + SmallVector opsToReplace = {root}; + llvm::append_range(opsToReplace, tiledResults->fusedProducers); + for (Operation *toReplace : opsToReplace) { + for (OpResult res : toReplace->getResults()) + if (auto replacement = tiledResults->replacements.lookup(res)) { + rewriter.replaceAllUsesWith(res, replacement); + } + + if (toReplace->use_empty()) { + rewriter.eraseOp(toReplace); + } + } + + // 3. Typically, the consumers of the tiled operation are slices of the + // results of the tiled operation. These are expressed in IR using + // `tensor.insert_slice` operations, whose outputs are the operands of the + // untiled operation. Create a worklist of these `tensor.insert_siices` + // operations. If the consumers of the source of the `tensor.insert_slices` + // can be tiled such that the tiled value is generated in-place, that + // effectively tiles + fuses the operations. + auto addCandidateSlices = [](Operation *fusedOp, + std::queue &candidates) { + for (auto *userOp : fusedOp->getResults().getUsers()) { + if (auto sliceOp = llvm::dyn_cast(userOp)) { + candidates.push(sliceOp); + } + } + }; + + // Collect the candidate slices which can be potential consumers that can be + // fused. + std::queue candidates; + addCandidateSlices(tiledResults->tiledAndFusedOps.front(), candidates); + + while (!candidates.empty()) { + + // Traverse the slices in BFS fashion. + tensor::InsertSliceOp candidateSliceOp = candidates.front(); + candidates.pop(); + + FailureOr fusedResult = + mlir::scf::tileAndFuseConsumerOfSlice(rewriter, candidateSliceOp); + if (failed(fusedResult)) { + LLVM_DEBUG(llvm::dbgs() << "failed to fuse consumer of slice: " + << candidateSliceOp << "\n"); + continue; + } + + // Replace the original consumer operation with the tiled implementation. + rewriter.replaceOp(fusedResult->origConsumerOperand->getOwner(), + fusedResult->tiledOps.front()); + + // The result of the fused conumers might themselved be slices of + // values produced by operations that implement the `TilingInterface`. + // Add these operations to the worklist. + addCandidateSlices(fusedResult->tiledAndFusedConsumerOperand->getOwner(), + candidates); + } + return success(); +} + +static LogicalResult tileRootAndFuseProducerConsumer(IRRewriter &rewriter, + TilingInterface rootOp, + int64_t tilingLevel) { + + SmallVector tileSizes = + getLoweringConfig(rootOp).getTilingLevelSizes(rewriter, tilingLevel, + rootOp); + int64_t numLoops = rootOp.getLoopIteratorTypes().size(); + if (tileSizes.size() > numLoops) + return failure(); + + scf::SCFTilingOptions tilingOptions; + tilingOptions.setTileSizes(tileSizes); + + scf::SCFTileAndFuseOptions tileAndFuseOptions; + tileAndFuseOptions.setTilingOptions(tilingOptions); + + return tileRootAndFuseProducerConsumerUsingSCF(rewriter, rootOp, + tileAndFuseOptions); +} + +/// This pass starts with the first TilingInterface operation that has +/// lowering_config attribute, tiles the op and fuses its consumers and +/// producers recursively. The `tilingLevel` must be specified. It picks the +/// `tilingLevel`-th list as tiling sizes from lowering_config. +struct LLVMCPUTileRootAndFuseProducerConsumer + : LLVMCPUTileRootAndFuseProducerConsumerBase< + LLVMCPUTileRootAndFuseProducerConsumer> { + LLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel = -1) { + this->tilingLevel.setValue(tilingLevel); + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override; +}; + +void LLVMCPUTileRootAndFuseProducerConsumer::runOnOperation() { + MLIRContext *context = &getContext(); + auto funcOp = getOperation(); + + IRRewriter rewriter(funcOp); + + SmallVector computeOps = getComputeOps(funcOp); + FailureOr rootOp = getRootOperation(computeOps); + + if (failed(rootOp)) { + funcOp.emitError() << "not able to find the root operation\n"; + return signalPassFailure(); + } + + IREE::Codegen::LoweringConfigAttrInterface loweringConfig = + getLoweringConfig(rootOp.value()); + if (!loweringConfig) { + funcOp.emitError() << "not able to find the lowering config\n"; + return signalPassFailure(); + } + + if (!loweringConfig.hasTilingLevel(tilingLevel)) { + funcOp.emitError() + << "not able to find the lowering config with the tiling level " + << tilingLevel.getValue() << "\n"; + return signalPassFailure(); + } + + if (failed(tileRootAndFuseProducerConsumer( + rewriter, dyn_cast(rootOp.value()), + tilingLevel.getValue()))) { + funcOp.emitError() << "tiling of level " << tilingLevel.getValue() + << " failed\n"; + return signalPassFailure(); + } + + RewritePatternSet patterns = + linalg::getLinalgTilingCanonicalizationPatterns(context); + scf::populateSCFForLoopCanonicalizationPatterns(patterns); + tensor::populateFoldTensorEmptyPatterns(patterns); + memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); + // Pull in tensor dialect canonicalization patterns to fold tensor.cast + // into producers when possible. + context->getLoadedDialect() + ->getCanonicalizationPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + LLVM_DEBUG(llvm::dbgs() << "----- cleanup failed -----\n"); + return signalPassFailure(); + } +} +} // namespace + +std::unique_ptr> +createLLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel) { + return std::make_unique(tilingLevel); +} +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h index 213629276130..735e3561f036 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h @@ -68,6 +68,10 @@ createLLVMCPUSynchronizeSymbolVisibilityPass(); std::unique_ptr> createLLVMCPUTileAndFusePass(int64_t tilingLevel = -1); +// Pass to Tile the Root Op and Fuse Producer and Consumer. +std::unique_ptr> +createLLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel = -1); + std::unique_ptr> createLLVMCPU2DScalableTo1DScalablePass(); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td index 1e9f9d93625e..4a84d1a36cba 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td @@ -142,6 +142,17 @@ def LLVMCPUTileAndFuse : ]; } +def LLVMCPUTileRootAndFuseProducerConsumer : + InterfacePass<"iree-llvmcpu-tile-root-and-fuse-producer-consumer", "mlir::FunctionOpInterface"> { + let summary = "Pass to tile root op and fuse with producer and consumer TilingInterface ops."; + let constructor = + "mlir::iree_compiler::createLLVMCPUTileRootAndFuseProducerConsumer()"; + let options = [ + Option<"tilingLevel", "tiling-level", "int64_t", /*default=*/"-1", + "Use default tiling level used to retrieve the configuration from lowering_config"> + ]; +} + def LLVMCPUVerifyVectorSizeLegality : InterfacePass<"iree-llvmcpu-verify-vector-size-legality", "mlir::FunctionOpInterface"> { let summary = diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel index cd9c4e923ec5..1cd905233264 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel @@ -57,6 +57,7 @@ iree_lit_test_suite( "split_reduction.mlir", "synchronize_symbol_visibility.mlir", "tile.mlir", + "tile-root-fuse-consumer-producer.mlir", "tile_and_fuse.mlir", "transform_dialect_bufferize.mlir", "unfused_fma.mlir", diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt index de22f621fd83..46d22db238ab 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt @@ -51,6 +51,7 @@ iree_lit_test_suite( "select_x86_64_lowering_strategy.mlir" "split_reduction.mlir" "synchronize_symbol_visibility.mlir" + "tile-root-fuse-consumer-producer.mlir" "tile.mlir" "tile_and_fuse.mlir" "transform_dialect_bufferize.mlir" diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile-root-fuse-consumer-producer.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile-root-fuse-consumer-producer.mlir new file mode 100644 index 000000000000..9f7a7653d8d6 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile-root-fuse-consumer-producer.mlir @@ -0,0 +1,90 @@ +// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-tile-root-and-fuse-producer-consumer{tiling-level=0}), canonicalize)" --split-input-file %s | FileCheck %s + +#config1 = #iree_codegen.lowering_config +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2)> +func.func @mmt4d_bias_relu(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim = tensor.dim %arg0, %c0 : tensor + %dim_0 = tensor.dim %arg1, %c1 : tensor + %0 = tensor.empty(%dim, %dim_0) : tensor + %1 = tensor.empty(%dim, %dim_0) : tensor + %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor) -> tensor + %3 = linalg.mmt4d {lowering_config = #config1} ins(%arg0, %arg1 : tensor, tensor) outs(%2 : tensor) -> tensor + %4 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%3, %arg2 : tensor, tensor) outs(%1 : tensor) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %5 = arith.addf %in, %in_1 : f32 + %6 = arith.maximumf %5, %cst : f32 + linalg.yield %6 : f32 + } -> tensor + return %4 : tensor +} +// CHECK: func.func @mmt4d_bias_relu( +// CHECK: scf.for +// CHECK-SAME: { +// CHECK: linalg.fill +// CHECK: linalg.mmt4d +// CHECK: linalg.generic +// CHECK: } + +// ----- + +#config2 = #iree_codegen.lowering_config +func.func @quantized_matmul() { + %c2995200 = arith.constant 2995200 : index + %c2994688 = arith.constant 2994688 : index + %c2994176 = arith.constant 2994176 : index + %c176128 = arith.constant 176128 : index + %c88064 = arith.constant 88064 : index + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c2995200) flags(ReadOnly) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c2994688) flags(ReadOnly) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c2994176) flags(ReadOnly) : !flow.dispatch.tensor> + %3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c176128) flags(ReadOnly) : !flow.dispatch.tensor> + %4 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c88064) flags(ReadOnly) : !flow.dispatch.tensor> + %5 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %6 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %7 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0, 0], sizes = [2, 4, 128, 16, 1], strides = [1, 1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<2x4x128x16x1xi8> + %8 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [2, 4, 16], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<2x4x16xf32> + %9 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [2, 4, 16], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<2x4x16xf32> + %10 = flow.dispatch.tensor.load %3, offsets = [0, 0, 0, 0, 0], sizes = [2, 688, 128, 16, 1], strides = [1, 1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<2x688x128x16x1xi8> + %11 = flow.dispatch.tensor.load %4, offsets = [0, 0, 0], sizes = [2, 688, 16], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<2x688x16xf32> + %12 = flow.dispatch.tensor.load %5, offsets = [0, 0, 0], sizes = [2, 688, 16], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<2x688x16xf32> + %13 = tensor.empty() : tensor<2x4x128x16x1xf32> + %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%7, %8, %9 : tensor<2x4x128x16x1xi8>, tensor<2x4x16xf32>, tensor<2x4x16xf32>) outs(%13 : tensor<2x4x128x16x1xf32>) { + ^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32): + %21 = arith.extui %in : i8 to i32 + %22 = arith.uitofp %21 : i32 to f32 + %23 = arith.subf %22, %in_1 : f32 + %24 = arith.mulf %23, %in_0 : f32 + linalg.yield %24 : f32 + } -> tensor<2x4x128x16x1xf32> + %15 = tensor.empty() : tensor<2x688x128x16x1xf32> + %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%10, %11, %12 : tensor<2x688x128x16x1xi8>, tensor<2x688x16xf32>, tensor<2x688x16xf32>) outs(%15 : tensor<2x688x128x16x1xf32>) { + ^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32): + %21 = arith.extui %in : i8 to i32 + %22 = arith.uitofp %21 : i32 to f32 + %23 = arith.subf %22, %in_1 : f32 + %24 = arith.mulf %23, %in_0 : f32 + linalg.yield %24 : f32 + } -> tensor<2x688x128x16x1xf32> + %17 = tensor.empty() : tensor<2x4x688x16x16xf32> + %18 = linalg.fill ins(%cst : f32) outs(%17 : tensor<2x4x688x16x16xf32>) -> tensor<2x4x688x16x16xf32> + %19 = linalg.batch_mmt4d {lowering_config = #config2} ins(%14, %16 : tensor<2x4x128x16x1xf32>, tensor<2x688x128x16x1xf32>) outs(%18 : tensor<2x4x688x16x16xf32>) -> tensor<2x4x688x16x16xf32> + %20 = tensor.empty() : tensor<2x11008x64xf32> + %unpack = tensor.unpack %19 outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 16] into %20 : tensor<2x4x688x16x16xf32> -> tensor<2x11008x64xf32> + flow.dispatch.tensor.store %unpack, %6, offsets = [0, 0, 0], sizes = [2, 11008, 64], strides = [1, 1, 1] : tensor<2x11008x64xf32> -> !flow.dispatch.tensor> + return +} +// CHECK: func.func @quantized_matmul( +// CHECK: scf.for +// CHECK-SAME: { +// CHECK: linalg.generic +// CHECK: linalg.generic +// CHECK: linalg.fill +// CHECK: linalg.batch_mmt4d +// CHECK: tensor.unpack +// CHECK: }