From 8f0909c3d794248ad82eafb5edcd7c3b57782476 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Thu, 1 Aug 2024 17:39:15 +0100 Subject: [PATCH] [VectorDistribution] Split layout configuration and distribution (#18065) This patch splits the LLVMGPUVectorDistribution pass into two separate passes, one that sets the layouts and one that distributes. This improves the debugging experience and the failing IR can be checked for the anchors. --- .../iree/compiler/Codegen/LLVMGPU/BUILD.bazel | 1 + .../compiler/Codegen/LLVMGPU/CMakeLists.txt | 1 + .../LLVMGPU/LLVMGPUConfigureVectorLayouts.cpp | 369 ++++++++++++++++++ .../LLVMGPU/LLVMGPUVectorDistribute.cpp | 362 +---------------- .../iree/compiler/Codegen/LLVMGPU/Passes.cpp | 1 + .../iree/compiler/Codegen/LLVMGPU/Passes.h | 4 + .../iree/compiler/Codegen/LLVMGPU/Passes.td | 11 +- .../compiler/Codegen/LLVMGPU/test/BUILD.bazel | 2 +- .../Codegen/LLVMGPU/test/CMakeLists.txt | 2 +- ...yout.mlir => configure_vector_layout.mlir} | 323 +++++++-------- .../test/vector_distribute_conversion.mlir | 2 +- 11 files changed, 540 insertions(+), 538 deletions(-) create mode 100644 compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureVectorLayouts.cpp rename compiler/src/iree/compiler/Codegen/LLVMGPU/test/{vector_distribute_layout.mlir => configure_vector_layout.mlir} (58%) diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel index 70cd714262c3..784a945cc0cf 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel @@ -95,6 +95,7 @@ iree_compiler_cc_library( "KernelConfig.cpp", "LLVMGPUCastAddressSpaceFunction.cpp", "LLVMGPUCastTypeToFitMMA.cpp", + "LLVMGPUConfigureVectorLayouts.cpp", "LLVMGPULowerExecutableTarget.cpp", "LLVMGPUPackSharedMemoryAlloc.cpp", "LLVMGPUPrefetching.cpp", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt index 361a8a90624b..a552962f0fee 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt @@ -80,6 +80,7 @@ iree_cc_library( "KernelConfig.cpp" "LLVMGPUCastAddressSpaceFunction.cpp" "LLVMGPUCastTypeToFitMMA.cpp" + "LLVMGPUConfigureVectorLayouts.cpp" "LLVMGPULowerExecutableTarget.cpp" "LLVMGPUPackSharedMemoryAlloc.cpp" "LLVMGPUPrefetching.cpp" diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureVectorLayouts.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureVectorLayouts.cpp new file mode 100644 index 000000000000..e576102c9db7 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureVectorLayouts.cpp @@ -0,0 +1,369 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include + +#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" +#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" +#include "iree/compiler/Codegen/LLVMGPU/PassDetail.h" +#include "iree/compiler/Codegen/LLVMGPU/Passes.h" +#include "iree/compiler/Codegen/Utils/GPUUtils.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" + +#define DEBUG_TYPE "iree-llvmgpu-configure-vector-layouts" + +namespace mlir::iree_compiler { + +namespace { + +// Sets an anchoring layout for the given contraction op. Looks for a +// supported mma type from the cached list of mma types and populates the +// necessary distribution pattern for those contractions. +LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule, + RewriterBase &rewriter, + vector::ContractionOp contract) { + // TODO: Add SIMT fallback. + if (!schedule) { + return contract->emitError("missing mma schedule for contraction"); + } + + auto layouts = schedule.getContractionLayout(contract); + if (failed(layouts)) { + return contract->emitError("cannot get concrete layout for contraction"); + } + + auto [aLayout, bLayout, cLayout] = *layouts; + Location loc = contract.getLoc(); + + // Set layouts for lhs, rhs and acc. + rewriter.setInsertionPoint(contract); + Value layoutedLhs = rewriter.create( + loc, contract.getLhsType(), contract.getLhs(), aLayout); + Value layoutedRhs = rewriter.create( + loc, contract.getRhsType(), contract.getRhs(), bLayout); + Value layoutedAcc = rewriter.create( + loc, contract.getAccType(), contract.getAcc(), cLayout); + contract->setOperand(0, layoutedLhs); + contract->setOperand(1, layoutedRhs); + contract->setOperand(2, layoutedAcc); + + // Set layout for result. + rewriter.setInsertionPointAfter(contract); + auto toLayout = rewriter.create( + loc, contract.getResultType(), contract.getResult(), cLayout); + rewriter.replaceAllUsesExcept(contract, toLayout.getResult(), toLayout); + + // Set intrinsic kind. + contract->setAttr("iree.amdgpu.mma", schedule.getIntrinsic()); + + LLVM_DEBUG({ + llvm::dbgs() << "chosen a layout: " << aLayout << "\n"; + llvm::dbgs() << "chosen b layout: " << bLayout << "\n"; + llvm::dbgs() << "chosen c layout: " << cLayout << "\n"; + llvm::dbgs() << "anchor set on contract: " << contract << "\n"; + }); + + return success(); +} + +// Sets a layout anchor for reads from global memory. +// The layout this generates is approximately the following: +// +// #layout = #iree_vector_ext.nested_layout< +// subgroups_per_workgroup = [1, ..., 1] +// batches_per_subgroup = [] +// outers_per_batch = [1, ..., 1] +// threads_per_outer = [] +// elements_per_thread = [1, ..., 128/element_bitwidth, ..., 1] +// innermost_memref_dimension ^^^^^^ +// +// (All orders are the same) +// *_order = [, ]> +// +// So for the following transfer_read with 64 threads: +// vector.transfer_read ... 
: memref<16x256xf16>, vector<16x32xf16> +// +// We use the following layout: +// #layout = #iree_vector_ext.nested_layout< +// subgroups_per_workgroup = [1, 1] +// batches_per_subgroup = [1, 1] +// outers_per_batch = [1, 1] +// threads_per_outer = [16, 4] +// elements_per_thread = [1, 8] +// +// *_order = [0, 1]> +LogicalResult setTransferReadAnchor(ArrayRef workgroupSize, + RewriterBase &rewriter, + vector::TransferReadOp transfer) { + MLIRContext *context = rewriter.getContext(); + + // Get the forward slice of the transfer to approximate whether it will take + // the layout of a contraction instead. Transfer_read ops used directly by a + // contraction (i.e. without a copy to shared memory in between) should take + // the layout of the contraction op. This is common for cases where the + // initial values of the accumulator in a linalg.matmul is read from memory + // instead of just being a zerofill. + ForwardSliceOptions forwardOptions; + forwardOptions.filter = [&](Operation *op) -> bool { + return llvm::any_of(op->getResultTypes(), llvm::IsaPred); + }; + BackwardSliceOptions backwardOptions; + backwardOptions.filter = [&](Operation *op) -> bool { + return llvm::any_of(op->getOperandTypes(), llvm::IsaPred); + }; + SetVector slice = + getSlice(transfer, backwardOptions, forwardOptions); + + if (llvm::any_of(slice, llvm::IsaPred)) { + return success(); + } + + // Shared memory loads are expected to take the layout of the contraction. + auto sourceMemRefType = dyn_cast(transfer.getSource().getType()); + if (!sourceMemRefType || hasSharedMemoryAddressSpace(sourceMemRefType)) { + return success(); + } + + // Take on layout of broadcast. + if (transfer->hasOneUse() && + dyn_cast(*transfer->getUsers().begin())) { + return success(); + } + + // TODO: Support masking. + if (transfer.getMask()) { + transfer->emitOpError( + "Anchoring on transfer_read with masks is not yet implemented."); + return failure(); + } + + int64_t bitWidth = IREE::Util::getTypeBitWidth( + getElementTypeOrSelf(transfer.getVectorType())); + if (!llvm::isPowerOf2_64(bitWidth) || bitWidth > 128) { + transfer->emitOpError( + "Anchoring on transfer_read with element type of bitwidth " + + std::to_string(bitWidth) + " is not yet implemented"); + return failure(); + } + int64_t numElementsPerThread = 128 / bitWidth; + int64_t flatNumElements = + ShapedType::getNumElements(transfer.getVectorType().getShape()); + int64_t flatNumThreads = ShapedType::getNumElements(workgroupSize); + if (flatNumElements % flatNumThreads != 0) { + transfer->emitOpError() + << "Anchoring on transfer_read with unsupported number of elements " + "(not divisible by workgroup size)" + << ", number of elements: " << flatNumElements + << ", workgroup size: " << flatNumThreads; + return failure(); + } + numElementsPerThread = + std::min(numElementsPerThread, flatNumElements / flatNumThreads); + + AffineMap transferMap = transfer.getPermutationMap(); + if (transferMap.getNumDims() == 0) { + transfer->emitOpError("Anchoring on transfer_read with zero-rank " + "permutation map is not supported."); + return failure(); + } + + // Select the innermost dim of the memref as the contiguous dim to load + // from. + int64_t transferRank = transfer.getVectorType().getRank(); + std::optional maybeDim = transferMap.getResultPosition( + getAffineDimExpr(transferMap.getNumDims() - 1, context)); + int64_t distXDim = maybeDim ? 
*maybeDim : transferRank - 1; + + ArrayRef vectorShape = transfer.getVectorType().getShape(); + + // Limit the maximum inner vector read width to the innermost contiguous + // dimension. We could try to be clever and extend this to adjacent + // dimensions in cases where the innermost read vector dimension is small, + // but that requires comparing memref strides and is uncommon. For now + // prioritize warp contiguity over 128-bit read granularity. + numElementsPerThread = std::min(numElementsPerThread, vectorShape[distXDim]); + + llvm::SetVector vectorDimDistributionOrder; + // Get the order in which to distribute vector dimensions to threads, going + // from innermost to outermost memref dimension. It's important to note + // that this heuristic only applies to matrix multiplication cases where + // we are promoting the operands of a contraction to shared memory and we + // have no producers fused with the matmul. In general there is no universal + // way to set an anchoring layout for reads without doing an analysis of how + // the read values are used. + for (int i = transferMap.getNumDims() - 1; i >= 0; --i) { + std::optional maybeDim = + transferMap.getResultPosition(getAffineDimExpr(i, context)); + if (maybeDim) { + vectorDimDistributionOrder.insert(*maybeDim); + } + } + // Add all remaining (broadcasted) dimensions + for (auto dim : llvm::seq(static_cast(0), transferRank)) { + if (!vectorDimDistributionOrder.contains(dim)) + vectorDimDistributionOrder.insert(dim); + } + + int64_t residualThreads = flatNumThreads; + int64_t residualElements = numElementsPerThread; + + SmallVector order(vectorDimDistributionOrder.rbegin(), + vectorDimDistributionOrder.rend()); + + // Distribute all threads in the workgroup to the "threads" dimension, + // meaning subgroup counts is unit here, even though the read is being + // distributed to multiple subgroups. This is in an attempt to do a + // workgroup contiguous load. + SmallVector subgroupCounts(transferRank, 1); + SmallVector batchSizes(transferRank, 1); + SmallVector outerSizes(transferRank, 1); + SmallVector threadCounts(transferRank, 1); + SmallVector elementSizes(transferRank, 1); + + SmallVector subgroupStrides(transferRank, 1); + SmallVector threadStrides(transferRank, 1); + + int64_t currStrides = 1; + for (auto dim : llvm::reverse(order)) { + int64_t vectorSize = vectorShape[dim]; + // Set the element count for the innermost vector dimension. 
+ if (residualElements != 1) { + elementSizes[dim] = residualElements; + vectorSize /= residualElements; + residualElements = 1; + } + + assert((residualThreads % vectorSize == 0 || + vectorSize % residualThreads == 0) && + "dividing threads to incompatible vector"); + if (residualThreads <= vectorSize) { + vectorSize /= residualThreads; + threadCounts[dim] = residualThreads; + threadStrides[dim] = currStrides; + currStrides *= residualThreads; + residualThreads = 1; + } else { + residualThreads /= vectorSize; + threadCounts[dim] = vectorSize; + threadStrides[dim] = currStrides; + currStrides *= vectorSize; + vectorSize = 1; + } + + batchSizes[dim] = vectorSize; + } + + auto layout = IREE::VectorExt::NestedLayoutAttr::get( + context, subgroupCounts, batchSizes, outerSizes, threadCounts, + elementSizes, subgroupStrides, threadStrides); + + Location loc = transfer.getLoc(); + rewriter.setInsertionPointAfter(transfer); + auto toLayout = rewriter.create( + loc, transfer.getResult().getType(), transfer.getResult(), layout); + rewriter.replaceAllUsesExcept(transfer, toLayout.getResult(), toLayout); + + return success(); +} + +struct LLVMGPUConfigureVectorLayoutsPass + : public LLVMGPUConfigureVectorLayoutsBase< + LLVMGPUConfigureVectorLayoutsPass> { +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } + + void runOnOperation() override { + auto func = getOperation(); + + std::array workgroupSize; + if (func->hasAttr("workgroup_size")) { + auto tmpSizes = + llvm::cast(func->getAttr("workgroup_size")).getValue(); + for (auto [i, size] : llvm::enumerate(tmpSizes)) { + workgroupSize[i] = llvm::cast(size).getInt(); + } + } else { + std::optional> maybeWorkgroupSize = + getWorkgroupSize(func); + if (!maybeWorkgroupSize) { + func->emitOpError() + << "unable to query workgroup_size information from entry point"; + return signalPassFailure(); + } + for (auto [index, value] : llvm::enumerate(maybeWorkgroupSize.value())) { + workgroupSize[index] = value; + } + for (auto index : llvm::seq(maybeWorkgroupSize->size(), 3)) { + workgroupSize[index] = 1; + } + } + + llvm::StringLiteral scheduleAttrName = + IREE::GPU::MMAScheduleAttr::getMnemonic(); + auto scheduleAttr = + func->getAttrOfType(scheduleAttrName); + if (!scheduleAttr) { + DictionaryAttr configDict = getTranslationInfo(func).getConfiguration(); + scheduleAttr = dyn_cast_or_null( + configDict.get(scheduleAttrName)); + } + + // Vector layout option setter aimed at contractions. Currently this only + // sets anchors for two types of operations; vector.contract and + // vector.transfer_read from non-shared memory. The assumption in this case + // is that all IR input to this pass has a leaf rooted on a transfer_read or + // includes a contraction in the program slice, meaning all operations + // should receive layouts. Layout setting for other problems like reductions + // is TODO. 
+ SmallVector reads; + SmallVector contracts; + + func->walk([&](Operation *op) { + llvm::TypeSwitch(op) + .Case([&](vector::TransferReadOp transfer) { + reads.push_back(transfer); + }) + .Case([&](vector::ContractionOp contract) { + contracts.push_back(contract); + }); + }); + + IRRewriter rewriter(func); + + for (vector::TransferReadOp read : reads) { + if (failed(setTransferReadAnchor(workgroupSize, rewriter, read))) { + return signalPassFailure(); + } + } + + for (vector::ContractionOp contract : contracts) { + if (failed(setContractionAnchor(scheduleAttr, rewriter, contract))) { + return signalPassFailure(); + } + } + } +}; +} // namespace + +std::unique_ptr> +createLLVMGPUConfigureVectorLayouts() { + return std::make_unique(); +} + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp index 25adf1f4491c..d9d513819bac 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp @@ -4,26 +4,15 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include - #include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h" #include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h" #include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h" -#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" -#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" #include "iree/compiler/Codegen/LLVMGPU/PassDetail.h" #include "iree/compiler/Codegen/LLVMGPU/Passes.h" -#include "iree/compiler/Codegen/Utils/GPUUtils.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/MathExtras.h" -#include "llvm/Support/raw_ostream.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" @@ -33,347 +22,25 @@ #define DEBUG_TYPE "iree-llvmgpu-vector-distribute" -using LayoutDimension = mlir::iree_compiler::IREE::VectorExt::LayoutDimension; -using LayoutDimensionAttr = - mlir::iree_compiler::IREE::VectorExt::LayoutDimensionAttr; -using VectorLayoutInterface = - mlir::iree_compiler::IREE::VectorExt::VectorLayoutInterface; -using PerDimLayoutAttr = mlir::iree_compiler::IREE::VectorExt::PerDimLayoutAttr; -using LayoutAttr = mlir::iree_compiler::IREE::VectorExt::LayoutAttr; - namespace mlir::iree_compiler { namespace { -// Vector layout option setter aimed at contractions. Currently this only sets -// anchors for two types of operations; vector.contract and vector.transfer_read -// from non-shared memory. The assumption in this case is that all IR input to -// this pass has a leaf rooted on a transfer_read or includes a contraction in -// the program slice, meaning all operations should receive layouts. Layout -// setting for other problems like reductions is TODO. 
class ContractionVectorLayoutOptions : public VectorLayoutOptions { public: - ContractionVectorLayoutOptions(Operation *root, - ArrayRef workgroupSize, - IREE::GPU::MMAScheduleAttr schedule, - Value laneId, int64_t subgroupSize, - bool printLayout) - : VectorLayoutOptions(root, /*fullConversion=*/!printLayout), - workgroupSize(workgroupSize), schedule(schedule), - printLayout(printLayout), patterns(root->getContext()) { + ContractionVectorLayoutOptions(Operation *root, Value laneId, + int64_t subgroupSize) + : VectorLayoutOptions(root), patterns(root->getContext()) { populateGPUDistributionPatterns(patterns); populateGPUDistributionLayoutAttrPatterns(laneId, patterns); populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId, subgroupSize); - } - - LogicalResult setAnchorOps(RewriterBase &rewriter) { - MLIRContext *context = root->getContext(); - SmallVector reads; - SmallVector contracts; - - root->walk([&](Operation *op) { - llvm::TypeSwitch(op) - .Case([&](vector::TransferReadOp transfer) { - reads.push_back(transfer); - }) - .Case([&](vector::ContractionOp contract) { - contracts.push_back(contract); - }); - }); - - for (vector::TransferReadOp read : reads) { - if (failed(setTransferReadAnchor(context, rewriter, read))) { - return failure(); - } - } - - for (vector::ContractionOp contract : contracts) { - if (failed(setContractionAnchor(context, rewriter, contract))) { - return failure(); - } - } - - return success(); + populateGPUDistributeNestedLayoutContractAMDGPUPatterns(patterns); } RewritePatternSet &getPatterns() { return patterns; } private: - // Sets an anchoring layout for the given contraction op. Looks for a - // supported mma type from the cached list of mma types and populates the - // necessary distribution pattern for those contractions. - LogicalResult setContractionAnchor(MLIRContext *context, - RewriterBase &rewriter, - vector::ContractionOp contract) { - // TODO: Add SIMT fallback. - if (!schedule) { - return contract->emitError("missing mma schedule for contraction"); - } - - auto layouts = schedule.getContractionLayout(contract); - if (failed(layouts)) { - return contract->emitError("cannot get concrete layout for contraction"); - } - - auto [aLayout, bLayout, cLayout] = *layouts; - Location loc = contract.getLoc(); - - // Set layouts for lhs, rhs and acc. - rewriter.setInsertionPoint(contract); - Value layoutedLhs = rewriter.create( - loc, contract.getLhsType(), contract.getLhs(), aLayout); - Value layoutedRhs = rewriter.create( - loc, contract.getRhsType(), contract.getRhs(), bLayout); - Value layoutedAcc = rewriter.create( - loc, contract.getAccType(), contract.getAcc(), cLayout); - contract->setOperand(0, layoutedLhs); - contract->setOperand(1, layoutedRhs); - contract->setOperand(2, layoutedAcc); - - // Set layout for result. - rewriter.setInsertionPointAfter(contract); - auto toLayout = rewriter.create( - loc, contract.getResultType(), contract.getResult(), cLayout); - rewriter.replaceAllUsesExcept(contract, toLayout.getResult(), toLayout); - - // Set intrinsic kind. 
- contract->setAttr("iree.amdgpu.mma", schedule.getIntrinsic()); - - if (printLayout) { - llvm::outs() << "contract A vector layout: " << aLayout << "\n"; - llvm::outs() << "contract B vector layout: " << bLayout << "\n"; - llvm::outs() << "contract C vector layout: " << cLayout << "\n"; - } - LLVM_DEBUG({ - llvm::dbgs() << "chosen a layout: " << aLayout << "\n"; - llvm::dbgs() << "chosen b layout: " << bLayout << "\n"; - llvm::dbgs() << "chosen c layout: " << cLayout << "\n"; - llvm::dbgs() << "anchor set on contract: " << contract << "\n"; - }); - - if (isa(schedule.getIntrinsic())) { - if (!populatedMma) { - populateGPUDistributeNestedLayoutContractAMDGPUPatterns(patterns); - populatedMma = true; - } - } else { - llvm_unreachable("Unsupported mma type"); - } - return success(); - } - - // Sets a layout anchor for reads from global memory. - // The layout this generates is approximately the following: - // - // #layout = #iree_vector_ext.nested_layout< - // subgroups_per_workgroup = [1, ..., 1] - // batches_per_subgroup = [] - // outers_per_batch = [1, ..., 1] - // threads_per_outer = [] - // elements_per_thread = [1, ..., 128/element_bitwidth, ..., 1] - // innermost_memref_dimension ^^^^^^ - // - // (All orders are the same) - // *_order = [, ]> - // - // So for the following transfer_read with 64 threads: - // vector.transfer_read ... : memref<16x256xf16>, vector<16x32xf16> - // - // We use the following layout: - // #layout = #iree_vector_ext.nested_layout< - // subgroups_per_workgroup = [1, 1] - // batches_per_subgroup = [1, 1] - // outers_per_batch = [1, 1] - // threads_per_outer = [16, 4] - // elements_per_thread = [1, 8] - // - // *_order = [0, 1]> - LogicalResult setTransferReadAnchor(MLIRContext *context, - RewriterBase &rewriter, - vector::TransferReadOp transfer) { - - // Get the forward slice of the transfer to approximate whether it will take - // the layout of a contraction instead. Transfer_read ops used directly by a - // contraction (i.e. without a copy to shared memory in between) should take - // the layout of the contraction op. This is common for cases where the - // initial values of the accumulator in a linalg.matmul is read from memory - // instead of just being a zerofill. - ForwardSliceOptions forwardOptions; - forwardOptions.filter = [&](Operation *op) -> bool { - return llvm::any_of(op->getResultTypes(), llvm::IsaPred); - }; - BackwardSliceOptions backwardOptions; - backwardOptions.filter = [&](Operation *op) -> bool { - return llvm::any_of(op->getOperandTypes(), llvm::IsaPred); - }; - SetVector slice = - getSlice(transfer, backwardOptions, forwardOptions); - - if (llvm::any_of(slice, llvm::IsaPred)) { - return success(); - } - - // Shared memory loads are expected to take the layout of the contraction. - auto sourceMemRefType = - dyn_cast(transfer.getSource().getType()); - if (!sourceMemRefType || hasSharedMemoryAddressSpace(sourceMemRefType)) { - return success(); - } - - // Take on layout of broadcast. - if (transfer->hasOneUse() && - dyn_cast(*transfer->getUsers().begin())) { - return success(); - } - - // TODO: Support masking. 
- if (transfer.getMask()) { - transfer->emitOpError( - "Anchoring on transfer_read with masks is not yet implemented."); - return failure(); - } - - int64_t bitWidth = IREE::Util::getTypeBitWidth( - getElementTypeOrSelf(transfer.getVectorType())); - if (!llvm::isPowerOf2_64(bitWidth) || bitWidth > 128) { - transfer->emitOpError( - "Anchoring on transfer_read with element type of bitwidth " + - std::to_string(bitWidth) + " is not yet implemented"); - return failure(); - } - int64_t numElementsPerThread = 128 / bitWidth; - int64_t flatNumElements = - ShapedType::getNumElements(transfer.getVectorType().getShape()); - int64_t flatNumThreads = ShapedType::getNumElements(workgroupSize); - if (flatNumElements % flatNumThreads != 0) { - transfer->emitOpError() - << "Anchoring on transfer_read with unsupported number of elements " - "(not divisible by workgroup size)" - << ", number of elements: " << flatNumElements - << ", workgroup size: " << flatNumThreads; - return failure(); - } - numElementsPerThread = - std::min(numElementsPerThread, flatNumElements / flatNumThreads); - - AffineMap transferMap = transfer.getPermutationMap(); - if (transferMap.getNumDims() == 0) { - transfer->emitOpError("Anchoring on transfer_read with zero-rank " - "permutation map is not supported."); - return failure(); - } - - // Select the innermost dim of the memref as the contiguous dim to load - // from. - int64_t transferRank = transfer.getVectorType().getRank(); - std::optional maybeDim = transferMap.getResultPosition( - getAffineDimExpr(transferMap.getNumDims() - 1, context)); - int64_t distXDim = maybeDim ? *maybeDim : transferRank - 1; - - ArrayRef vectorShape = transfer.getVectorType().getShape(); - - // Limit the maximum inner vector read width to the innermost contiguous - // dimension. We could try to be clever and extend this to adjacent - // dimensions in cases where the innermost read vector dimension is small, - // but that requires comparing memref strides and is uncommon. For now - // prioritize warp contiguity over 128-bit read granularity. - numElementsPerThread = - std::min(numElementsPerThread, vectorShape[distXDim]); - - llvm::SetVector vectorDimDistributionOrder; - // Get the order in which to distribute vector dimensions to threads, going - // from innermost to outermost memref dimension. It's important to note - // that this heuristic only applies to matrix multiplication cases where - // we are promoting the operands of a contraction to shared memory and we - // have no producers fused with the matmul. In general there is no universal - // way to set an anchoring layout for reads without doing an analysis of how - // the read values are used. - for (int i = transferMap.getNumDims() - 1; i >= 0; --i) { - std::optional maybeDim = - transferMap.getResultPosition(getAffineDimExpr(i, context)); - if (maybeDim) { - vectorDimDistributionOrder.insert(*maybeDim); - } - } - // Add all remaining (broadcasted) dimensions - for (auto dim : llvm::seq(static_cast(0), transferRank)) { - if (!vectorDimDistributionOrder.contains(dim)) - vectorDimDistributionOrder.insert(dim); - } - - int64_t residualThreads = flatNumThreads; - int64_t residualElements = numElementsPerThread; - - SmallVector order(vectorDimDistributionOrder.rbegin(), - vectorDimDistributionOrder.rend()); - - // Distribute all threads in the workgroup to the "threads" dimension, - // meaning subgroup counts is unit here, even though the read is being - // distributed to multiple subgroups. 
This is in an attempt to do a - // workgroup contiguous load. - SmallVector subgroupCounts(transferRank, 1); - SmallVector batchSizes(transferRank, 1); - SmallVector outerSizes(transferRank, 1); - SmallVector threadCounts(transferRank, 1); - SmallVector elementSizes(transferRank, 1); - - SmallVector subgroupStrides(transferRank, 1); - SmallVector threadStrides(transferRank, 1); - - int64_t currStrides = 1; - for (auto dim : llvm::reverse(order)) { - int64_t vectorSize = vectorShape[dim]; - // Set the element count for the innermost vector dimension. - if (residualElements != 1) { - elementSizes[dim] = residualElements; - vectorSize /= residualElements; - residualElements = 1; - } - - assert((residualThreads % vectorSize == 0 || - vectorSize % residualThreads == 0) && - "dividing threads to incompatible vector"); - if (residualThreads <= vectorSize) { - vectorSize /= residualThreads; - threadCounts[dim] = residualThreads; - threadStrides[dim] = currStrides; - currStrides *= residualThreads; - residualThreads = 1; - } else { - residualThreads /= vectorSize; - threadCounts[dim] = vectorSize; - threadStrides[dim] = currStrides; - currStrides *= vectorSize; - vectorSize = 1; - } - - batchSizes[dim] = vectorSize; - } - - auto layout = IREE::VectorExt::NestedLayoutAttr::get( - context, subgroupCounts, batchSizes, outerSizes, threadCounts, - elementSizes, subgroupStrides, threadStrides); - - Location loc = transfer.getLoc(); - rewriter.setInsertionPointAfter(transfer); - auto toLayout = rewriter.create( - loc, transfer.getResult().getType(), transfer.getResult(), layout); - rewriter.replaceAllUsesExcept(transfer, toLayout.getResult(), toLayout); - - if (printLayout) { - llvm::outs() << "transfer '" << transfer << "' vector layout: " << layout - << "\n"; - } - return success(); - } - - SmallVector workgroupSize; - IREE::GPU::MMAScheduleAttr schedule; - // Whether to print the chosen layout for testing purposes - bool printLayout; - - bool populatedMma = false; RewritePatternSet patterns; }; @@ -413,16 +80,6 @@ struct LLVMGPUVectorDistributePass } } - llvm::StringLiteral scheduleAttrName = - IREE::GPU::MMAScheduleAttr::getMnemonic(); - auto scheduleAttr = - func->getAttrOfType(scheduleAttrName); - if (!scheduleAttr) { - DictionaryAttr configDict = getTranslationInfo(func).getConfiguration(); - scheduleAttr = dyn_cast_or_null( - configDict.get(scheduleAttrName)); - } - AffineExpr x, y, z; bindSymbols(func.getContext(), x, y, z); // Construct the expression for linearizing the thread indices. @@ -449,15 +106,8 @@ struct LLVMGPUVectorDistributePass return signalPassFailure(); } - ContractionVectorLayoutOptions options(func, workgroupSize, scheduleAttr, - linearThreadIdVal, - subgroupSize.value(), testLayout); - - // Set anchor layouts. 
- if (failed(options.setAnchorOps(rewriter))) { - func->emitError() << "failed to set anchors"; - return signalPassFailure(); - } + ContractionVectorLayoutOptions options(func, linearThreadIdVal, + subgroupSize.value()); if (failed(distributeVectorOps(func, options.getPatterns(), options))) { func->emitOpError() << "failed to distribute"; diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 8918fb2cc748..f2a975965c4a 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -789,6 +789,7 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager, funcPassManager.addPass(createAMDGPUPrepareForChainedMatmulPass()); // Vector SIMD -> Vector SIMT + funcPassManager.addPass(createLLVMGPUConfigureVectorLayouts()); funcPassManager.addPass(createLLVMGPUVectorDistribute()); funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h index 488705fc38e1..fb8427502278 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h @@ -152,6 +152,10 @@ std::unique_ptr> createLLVMGPUPromoteMatmulToFitMMAPass( LLVMGPUMatmulPadOption option = LLVMGPUMatmulPadOption::ParallelDims); +// Pass to set layouts for vector distribution. +std::unique_ptr> +createLLVMGPUConfigureVectorLayouts(); + enum class GPUTensorCoreType { WMMA = 0, MMA_SYNC = 1, diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td index b4176a534286..8ea7da497989 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td @@ -60,6 +60,12 @@ def LLVMGPUCastTypeToFitMMA : InterfacePass<"iree-llvmgpu-cast-type-to-fit-mma", let constructor = "mlir::iree_compiler::createLLVMGPUCastTypeToFitMMAPass()"; } +def LLVMGPUConfigureVectorLayouts : + InterfacePass<"iree-llvmgpu-configure-vector-layouts", "mlir::FunctionOpInterface"> { + let summary = "Pass to set layouts for vector distribution"; + let constructor = "mlir::iree_compiler::createLLVMGPUConfigureVectorLayouts()"; +} + def LLVMGPULowerExecutableTarget : InterfacePass<"iree-llvmgpu-lower-executable-target", "mlir::FunctionOpInterface"> { let summary = "Perform lowering of executable target using one of the IREE::HAL::DispatchLoweringPassPipeline"; @@ -125,11 +131,6 @@ def LLVMGPUVectorDistribute : InterfacePass<"iree-llvmgpu-vector-distribute", "mlir::FunctionOpInterface"> { let summary = "Pass to distribute vectorized functions."; let constructor = "mlir::iree_compiler::createLLVMGPUVectorDistribute()"; - let options = [ - Option<"testLayout", "test-layout", "bool", /*default=*/"false", - "Annotate vector ops with deduced layouts without real conversion " - "for testing purposes"> - ]; } def LLVMGPUVectorLowering : diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel index 1757b5ce48f4..057ff1c7e2ed 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel @@ -74,7 +74,7 @@ iree_lit_test_suite( "transpose_pipeline_test.mlir", "ukernel_pipeline_transform.mlir", "vector_distribute_conversion.mlir", - "vector_distribute_layout.mlir", + 
"configure_vector_layout.mlir", "vector_lowering.mlir", "vector_to_gpu.mlir", "winograd_pipeline_test.mlir", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt index 2ff84aa75ea2..5f603cfbf14e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt @@ -23,6 +23,7 @@ iree_lit_test_suite( "cast_type_to_fit_mma.mlir" "config_matvec.mlir" "config_winograd.mlir" + "configure_vector_layout.mlir" "conv_pipeline_test_cuda.mlir" "conv_pipeline_test_rocm.mlir" "convert_to_nvvm.mlir" @@ -70,7 +71,6 @@ iree_lit_test_suite( "transpose_pipeline_test.mlir" "ukernel_pipeline_transform.mlir" "vector_distribute_conversion.mlir" - "vector_distribute_layout.mlir" "vector_lowering.mlir" "vector_to_gpu.mlir" "winograd_pipeline_test.mlir" diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_vector_layout.mlir similarity index 58% rename from compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir rename to compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_vector_layout.mlir index a2519abd0046..e67c011cdba3 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_vector_layout.mlir @@ -1,11 +1,23 @@ -// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-llvmgpu-vector-distribute{test-layout}, canonicalize, cse))' %s | FileCheck %s +// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-llvmgpu-configure-vector-layouts, canonicalize, cse))' %s | FileCheck %s #translation = #iree_codegen.translation_info, subgroup_m_count = 1, subgroup_n_count = 1>}> +// Since CHECK-SAME doesnt work with CHECK-DAG, we cannot have prettier tests. 
+ +// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout +// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout +// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout + +// CHECK-LABEL: func.func @mfma_matmul_96x64x16_mm func.func @mfma_matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf32>) -> vector<96x64xf32> attributes { translation_info = #translation } { + // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED]] + // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED1]] + // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED2]] + // CHECK: vector.contract + // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]] %0 = vector.contract { indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} @@ -13,16 +25,6 @@ func.func @mfma_matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf return %0 : vector<96x64xf32> } -// CHECK: contract A vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 32]> -// CHECK: contract B vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [32, 1]> -// CHECK: contract C vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [4, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [32, 1]> - // ----- #translation = #iree_codegen.translation_info, %rhs: vector<16x64xf subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 1, subgroup_n_count = 1>}> +// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout +// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout +// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout + +// CHECK-LABEL: func.func @mfma_matmul_96x64x16_mmt func.func @mfma_matmul_96x64x16_mmt(%lhs: vector<96x16xf16>, %rhs: vector<64x16xf16>, %init: vector<96x64xf32>) -> vector<96x64xf32> attributes { translation_info = #translation } { + // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED]] + // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED1]] + // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED2]] + // CHECK: vector.contract + // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]] %0 = vector.contract { indexing_maps = [affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (n, k)>, affine_map<(m, n, d2) -> (m, n)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} @@ -38,16 +50,6 @@ func.func @mfma_matmul_96x64x16_mmt(%lhs: vector<96x16xf16>, %rhs: vector<64x16x return %0 : vector<96x64xf32> } -// CHECK: contract A vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], 
elements_per_thread = [1, 4], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 32]> -// CHECK: contract B vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 32]> -// CHECK: contract C vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [4, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [32, 1]> - // ----- #translation = #iree_codegen.translation_info, %rhs: vector<64x16x subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 1, subgroup_n_count = 1>}> +// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout +// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout +// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout + +// CHECK-LABEL: func.func @mfma_matmul_96x64x16_mmtt func.func @mfma_matmul_96x64x16_mmtt(%lhs: vector<96x16xf16>, %rhs: vector<64x16xf16>, %init: vector<64x96xf32>) -> vector<64x96xf32> attributes { translation_info = #translation } { + // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED]] + // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED1]] + // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED2]] + // CHECK: vector.contract + // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]] %0 = vector.contract { indexing_maps = [affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (n, k)>, affine_map<(m, n, k) -> (n, m)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} @@ -63,16 +75,6 @@ func.func @mfma_matmul_96x64x16_mmtt(%lhs: vector<96x16xf16>, %rhs: vector<64x16 return %0 : vector<64x96xf32> } -// CHECK: contract A vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 32] -// CHECK: contract B vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 32] -// CHECK: contract C vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 3], outers_per_batch = [1, 4], threads_per_outer = [32, 2], elements_per_thread = [1, 4], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 32] - // ----- #translation = #iree_codegen.translation_info, %rhs: vector<64x16 subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 1>}> +// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout, %rhs: vector<16x64xf16>, %init: vector<192x64xf32>) -> vector<192x64xf32> attributes { translation_info = #translation } { + // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED]] + // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED1]] + // CHECK-DAG: %[[ACC:.+]] = 
iree_vector_ext.to_layout %{{.*}} to #[[$NESTED2]] + // CHECK: vector.contract + // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]] %0 = vector.contract { indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} @@ -88,10 +100,6 @@ func.func @matmul_192x64x16_mmt_multisubgroup(%lhs: vector<192x16xf16>, %rhs: ve return %0 : vector<192x64xf32> } -// CHECK: contract A vector layout: #iree_vector_ext.nested_layout, %rhs: ve subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule, subgroup_m_count = 1, subgroup_n_count = 1>}> +// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout +// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout + +// CHECK-LABEL: func.func @matmul_16x16x256_read func.func @matmul_16x16x256_read(%lhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, %rhs: memref<256x16xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, %out: memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type>) @@ -113,12 +125,21 @@ func.func @matmul_16x16x256_read(%lhs: memref<16x256xf16, strided<[256, 1], offs %5 = scf.for %arg0 = %c0 to %c256 step %c32 iter_args(%arg1 = %cst_1) -> (vector<16x16xf32>) { %6 = vector.transfer_read %lhs[%c0, %arg0], %cst {in_bounds = [true, true]} : memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, vector<16x32xf16> %7 = vector.transfer_read %rhs[%arg0, %c0], %cst {in_bounds = [true, true]} : memref<256x16xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, vector<32x16xf16> + // CHECK: %[[READ0:.+]] = vector.transfer_read + // CHECK: to_layout %[[READ0]] to #[[$NESTED]] + // CHECK: %[[READ1:.+]] = vector.transfer_read + // CHECK: to_layout %[[READ1]] to #[[$NESTED1]] vector.transfer_write %6, %alloc_0[%c0, %c0] {in_bounds = [true, true]} : vector<16x32xf16>, memref<16x32xf16, #gpu.address_space> gpu.barrier vector.transfer_write %7, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<32x16xf16>, memref<32x16xf16, #gpu.address_space> gpu.barrier %8 = vector.transfer_read %alloc_0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x32xf16, #gpu.address_space>, vector<16x32xf16> %9 = vector.transfer_read %alloc[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x16xf16, #gpu.address_space>, vector<32x16xf16> + // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout + // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout + // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout + // CHECK: vector.contract + // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]] %10 = vector.contract { indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} @@ -131,23 +152,6 @@ func.func @matmul_16x16x256_read(%lhs: memref<16x256xf16, strided<[256, 1], offs return } -// CHECK: transfer '{{.+}} memref<16x256xf16{{.+}}>, vector<16x32xf16>' vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [16, 4], elements_per_thread = [1, 8], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [4, 1]> -// CHECK: transfer '{{.+}} memref<256x16xf16{{.+}}>, vector<32x16xf16>' vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], 
outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 8], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [2, 1]> - -// CHECK: contract A vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 4], elements_per_thread = [1, 4], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 16]> -// CHECK: contract B vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [16, 1]> -// CHECK: contract C vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [16, 1]> - // ----- #translation = #iree_codegen.translation_info, subgroup_m_count = 1, subgroup_n_count = 1>}> +// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout +// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout +// CHECK-LABEL: func.func @matmul_16x16x256_read_permute func.func @matmul_16x16x256_read_permute(%lhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, %rhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, %out: memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type>) @@ -169,15 +176,25 @@ func.func @matmul_16x16x256_read_permute(%lhs: memref<16x256xf16, strided<[256, %c0 = arith.constant 0 : index %init_acc = vector.transfer_read %out[%c0, %c0], %cst_f32 {in_bounds = [true, true]} : memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type>, vector<16x16xf32> + // CHECK: scf.for %5 = scf.for %arg0 = %c0 to %c256 step %c32 iter_args(%arg1 = %init_acc) -> (vector<16x16xf32>) { %6 = vector.transfer_read %lhs[%c0, %arg0], %cst {in_bounds = [true, true]} : memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, vector<16x32xf16> %7 = vector.transfer_read %rhs[%arg0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, vector<32x16xf16> + // CHECK: %[[READ0:.+]] = vector.transfer_read + // CHECK: to_layout %[[READ0]] to #[[$NESTED]] + // CHECK: %[[READ1:.+]] = vector.transfer_read + // CHECK: to_layout %[[READ1]] to #[[$NESTED1]] vector.transfer_write %6, %alloc_0[%c0, %c0] {in_bounds = [true, true]} : vector<16x32xf16>, memref<16x32xf16, #gpu.address_space> gpu.barrier vector.transfer_write %7, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<32x16xf16>, memref<32x16xf16, #gpu.address_space> gpu.barrier %8 = vector.transfer_read %alloc_0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x32xf16, #gpu.address_space>, vector<16x32xf16> %9 = vector.transfer_read %alloc[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x16xf16, #gpu.address_space>, vector<32x16xf16> + // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout + // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout + // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout + // CHECK: vector.contract + // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]] %10 = vector.contract { indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> 
(d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} @@ -190,24 +207,6 @@ func.func @matmul_16x16x256_read_permute(%lhs: memref<16x256xf16, strided<[256, return } -// CHECK-NOT: transfer '{{.+}} memref<16x16xf16{{.+}}>, vector<16x16xf16>' vector layout -// CHECK: transfer '{{.+}} memref<16x256xf16{{.+}}>, vector<16x32xf16>' vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [16, 4], elements_per_thread = [1, 8], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [4, 1]> -// CHECK: transfer '{{.+}} memref<16x256xf16{{.+}}storage_buffer>>, vector<32x16xf16>' vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [8, 1], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 4]> - -// CHECK: contract A vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 4], elements_per_thread = [1, 4], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 16]> -// CHECK: contract B vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [16, 1]> -// CHECK: contract C vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [16, 1]> - // ----- #translation = #iree_codegen.translation_info, subgroup_m_count = 1, subgroup_n_count = 1>}> +// We don't really care what layout we assign here, just that the only anchor +// we set is on the contraction. 
+ +// CHECK-LABEL: func.func @matmul_16x16x256_fused func.func @matmul_16x16x256_fused(%lhs: memref<16x32xf16>, %rhs: memref<32x16xf16>, %bias: memref<16x16xf32>, @@ -228,6 +231,18 @@ func.func @matmul_16x16x256_fused(%lhs: memref<16x32xf16>, %acc = vector.transfer_read %out[%c0, %c0], %cst_f32 {in_bounds = [true, true]} : memref<16x16xf32>, vector<16x16xf32> %8 = vector.transfer_read %lhs[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x32xf16>, vector<16x32xf16> %9 = vector.transfer_read %rhs[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x16xf16>, vector<32x16xf16> + // CHECK-DAG: %[[READA:.+]] = vector.transfer_read + // CHECK-DAG: %[[READB:.+]] = vector.transfer_read + // CHECK-DAG: %[[READC:.+]] = vector.transfer_read + // CHECK-NOT: to_layout %[[READA]] + // CHECK-NOT: to_layout %[[READB]] + // CHECK-NOT: to_layout %[[READC]] + + // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout + // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout + // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout + // CHECK: vector.contract + // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]] %10 = vector.contract { indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} @@ -238,16 +253,6 @@ func.func @matmul_16x16x256_fused(%lhs: memref<16x32xf16>, return } -// We don't really care what layout we assign here, just that the only anchor -// we set is on the contraction. -// CHECK-NOT: transfer {{.*}} vector layout -// CHECK: contract A vector layout -// CHECK-NOT: transfer {{.*}} vector layout -// CHECK: contract B vector layout -// CHECK-NOT: transfer {{.*}} vector layout -// CHECK: contract C vector layout -// CHECK-NOT: transfer {{.*}} vector layout - // ----- #translation = #iree_codegen.translation_info, subgroup_size = 32, {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 1, subgroup_n_count = 1>}> +// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout +// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout +// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout + +// CHECK-LABEL: func.func @wmma_matmul_48x32x32_mm func.func @wmma_matmul_48x32x32_mm(%lhs: vector<48x32xf16>, %rhs: vector<32x32xf16>, %init: vector<48x32xf32>) -> vector<48x32xf32> attributes { translation_info = #translation } { + // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED]] + // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED1]] + // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED2]] + // CHECK: vector.contract + // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]] %0 = vector.contract { indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} @@ -263,16 +278,6 @@ func.func @wmma_matmul_48x32x32_mm(%lhs: vector<48x32xf16>, %rhs: vector<32x32xf return %0 : vector<48x32xf32> } -// CHECK: contract A vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 1], elements_per_thread = [1, 16], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 0]> -// CHECK: contract B vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 
1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [1, 16], elements_per_thread = [16, 1], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [0, 1]> -// CHECK: contract C vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [8, 1], threads_per_outer = [2, 16], elements_per_thread = [1, 1], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [16, 1]> - // ----- #translation = #iree_codegen.translation_info, %rhs: vector<32x32xf subgroup_size = 32, {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 1, subgroup_n_count = 1>}> +// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout +// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout +// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout + +// CHECK-LABEL: func.func @wmma_matmul_48x32x32_mmt func.func @wmma_matmul_48x32x32_mmt(%lhs: vector<48x32xf16>, %rhs: vector<32x32xf16>, %init: vector<48x32xf32>) -> vector<48x32xf32> attributes { translation_info = #translation } { + // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED]] + // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED1]] + // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED2]] + // CHECK: vector.contract + // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]] %0 = vector.contract { indexing_maps = [affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (n, k)>, affine_map<(m, n, d2) -> (m, n)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} @@ -288,16 +303,6 @@ func.func @wmma_matmul_48x32x32_mmt(%lhs: vector<48x32xf16>, %rhs: vector<32x32x return %0 : vector<48x32xf32> } -// CHECK: contract A vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 1], elements_per_thread = [1, 16], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 0]> -// CHECK: contract B vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 1], elements_per_thread = [1, 16], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [1, 0]> -// CHECK: contract C vector layout: #iree_vector_ext.nested_layout< -// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [8, 1], threads_per_outer = [2, 16], elements_per_thread = [1, 1], -// CHECK-SAME: subgroup_strides = [0, 0], thread_strides = [16, 1]> - // ----- #translation = #iree_codegen.translation_info, %rhs: vector<32x32x subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule< intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 1>}> +// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout +// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout +// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout +// CHECK-LABEL: func.func @matmul_192x64x16_mmt_multi_m func.func @matmul_192x64x16_mmt_multi_m(%lhs: vector<2x64x16xf16>, %rhs: vector<16x64xf16>, %init: vector<2x64x64xf32>) -> vector<2x64x64xf32> attributes { translation_info = #translation } { + // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED]] + // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED1]] + // 
+  // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED2]]
+  // CHECK: vector.contract
+  // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
   %0 = vector.contract {
     indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
     iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind}
@@ -314,31 +328,6 @@ func.func @matmul_192x64x16_mmt_multi_m(%lhs: vector<2x64x16xf16>, %rhs: vector<
   return %0 : vector<2x64x64xf32>
 }
-// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [2, 1, 1],
-// CHECK-SAME: batches_per_subgroup = [1, 4, 1],
-// CHECK-SAME: outers_per_batch = [1, 1, 1],
-// CHECK-SAME: threads_per_outer = [1, 16, 4],
-// CHECK-SAME: elements_per_thread = [1, 1, 4],
-// CHECK-SAME: subgroup_strides = [1, 0, 0],
-// CHECK-SAME: thread_strides = [0, 1, 16]>
-// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1],
-// CHECK-SAME: batches_per_subgroup = [1, 4],
-// CHECK-SAME: outers_per_batch = [1, 1],
-// CHECK-SAME: threads_per_outer = [4, 16],
-// CHECK-SAME: elements_per_thread = [4, 1],
-// CHECK-SAME: subgroup_strides = [0, 0],
-// CHECK-SAME: thread_strides = [16, 1]>
-// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [2, 1, 1],
-// CHECK-SAME: batches_per_subgroup = [1, 4, 4],
-// CHECK-SAME: outers_per_batch = [1, 1, 1],
-// CHECK-SAME: threads_per_outer = [1, 4, 16],
-// CHECK-SAME: elements_per_thread = [1, 4, 1],
-// CHECK-SAME: subgroup_strides = [1, 0, 0],
-// CHECK-SAME: thread_strides = [0, 16, 1]>
-
 // -----
 #translation = #iree_codegen.translation_info, %rhs: vector<
   subgroup_size = 64,
     {mma_schedule = #iree_gpu.mma_schedule<
       intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 4, subgroup_n_count = 1>}>
+// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout
+// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout
+// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout
+
+// CHECK-LABEL: func.func @matmul_192x64x16_mmt_multi_split_m
 func.func @matmul_192x64x16_mmt_multi_split_m(%lhs: vector<2x64x16xf16>, %rhs: vector<16x64xf16>, %init: vector<2x64x64xf32>) -> vector<2x64x64xf32> attributes { translation_info = #translation } {
+  // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED]]
+  // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED1]]
+  // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED2]]
+  // CHECK: vector.contract
+  // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
   %0 = vector.contract {
     indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
     iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind}
@@ -354,15 +353,6 @@ func.func @matmul_192x64x16_mmt_multi_split_m(%lhs: vector<2x64x16xf16>, %rhs: v
   return %0 : vector<2x64x64xf32>
 }
-// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [2, 2, 1],
-// CHECK-SAME: batches_per_subgroup = [1, 2, 1],
-// CHECK-SAME: subgroup_strides = [2, 1, 0],
-// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [2, 2, 1],
-// CHECK-SAME: batches_per_subgroup = [1, 2, 4],
-// CHECK-SAME: subgroup_strides = [2, 1, 0],
-
 // -----
 #translation = #iree_codegen.translation_info, %rhs: v
   subgroup_size = 64,
     {mma_schedule = #iree_gpu.mma_schedule<
       intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2>, workgroup_size = [128, 2, 1]}>
+// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout
 func.func @matmul_192x64x16_mmt_multi_m_and_n(%lhs: vector<4x64x16xf16>, %rhs: vector<2x16x64xf16>, %init: vector<4x2x64x64xf32>) -> vector<4x2x64x64xf32> attributes { translation_info = #translation } {
+  // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED]]
+  // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED1]]
+  // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED2]]
+  // CHECK: vector.contract
+  // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
   %0 = vector.contract {
     indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
     iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], kind = #vector.kind}
@@ -378,19 +378,6 @@ func.func @matmul_192x64x16_mmt_multi_m_and_n(%lhs: vector<4x64x16xf16>, %rhs: v
   return %0 : vector<4x2x64x64xf32>
 }
-// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [2, 1, 1],
-// CHECK-SAME: batches_per_subgroup = [2, 4, 1],
-// CHECK-SAME: subgroup_strides = [2, 0, 0],
-// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [2, 1, 1],
-// CHECK-SAME: batches_per_subgroup = [1, 1, 4],
-// CHECK-SAME: subgroup_strides = [1, 0, 0],
-// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [2, 2, 1, 1],
-// CHECK-SAME: batches_per_subgroup = [2, 1, 4, 4],
-// CHECK-SAME: subgroup_strides = [2, 1, 0, 0],
-
 // -----
 #translation = #iree_codegen.translation_info, %rhs: v
   subgroup_size = 32,
     {mma_schedule = #iree_gpu.mma_schedule, subgroup_m_count = 1, subgroup_n_count = 4>}>
+// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout
+
+// CHECK-LABEL: func.func @dequant_anchors_on_quant_only
 func.func @dequant_anchors_on_quant_only(%quant: memref<128x128xi4, strided<[4096, 1], offset: ?>, #hal.descriptor_type>, %scale: memref<128xf16, strided<[32], offset: ?>, #hal.descriptor_type>, %zp: memref<128xf16, strided<[32], offset: ?>, #hal.descriptor_type>)
@@ -410,6 +400,8 @@ func.func @dequant_anchors_on_quant_only(%quant: memref<128x128xi4, strided<[409
   %c0_i4 = arith.constant 0 : i4
   %c0 = arith.constant 0 : index
   %0 = vector.transfer_read %quant[%c0, %c0], %c0_i4 {in_bounds = [true, true]} : memref<128x128xi4, strided<[4096, 1], offset: ?>, #hal.descriptor_type>, vector<128x128xi4>
+  // CHECK: %[[READ:.+]] = vector.transfer_read
+  // CHECK: to_layout %[[READ]] to #[[$NESTED]]
   %1 = vector.transfer_read %scale[%c0], %cst {in_bounds = [true]} : memref<128xf16, strided<[32], offset: ?>, #hal.descriptor_type>, vector<128xf16>
   %2 = vector.broadcast %1 : vector<128xf16> to vector<128x128xf16>
   %3 = vector.transpose %2, [1, 0] : vector<128x128xf16> to vector<128x128xf16>
@@ -423,12 +415,13 @@ func.func @dequant_anchors_on_quant_only(%quant: memref<128x128xi4, strided<[409
   vector.transfer_write %10, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<128x128xf16>, memref<128x128xf16, #gpu.address_space>
   return
 }
-// CHECK: transfer '{{.+}} memref<128x128xi4{{.+}}>, vector<128x128xi4>' vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [4, 1], outers_per_batch = [1, 1], threads_per_outer = [32, 4], elements_per_thread = [1, 32], subgroup_strides = [0, 0], thread_strides = [4, 1]>
-// CHECK-NOT: transfer '{{.+}} memref<128xf16{{.+}}>, vector<128xf16>' vector layout
 // -----
+// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout
+// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout
+// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout
+
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
@@ -436,7 +429,13 @@ func.func @dequant_anchors_on_quant_only(%quant: memref<128x128xi4, strided<[409
   workgroup_size = [128, 2, 1]
   subgroup_size = 64,
     {mma_schedule = #iree_gpu.mma_schedule, subgroup_m_count = 2, subgroup_n_count = 2>}>
+// CHECK-LABEL: func.func @batch_matmul_unit_batch
 func.func @batch_matmul_unit_batch(%arg0: vector<1x64x64xf16>, %arg1: vector<1x64x128xf16>, %arg2: vector<1x64x128xf32>) -> vector<1x64x128xf32> attributes {translation_info = #translation} {
+  // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED]]
+  // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED1]]
+  // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED2]]
+  // CHECK: vector.contract
+  // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
   %0 = vector.contract {
     indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "parallel", "reduction"],
@@ -444,27 +443,3 @@ func.func @batch_matmul_unit_batch(%arg0: vector<1x64x64xf16>, %arg1: vector<1x6
     %arg0, %arg1, %arg2 : vector<1x64x64xf16>, vector<1x64x128xf16> into vector<1x64x128xf32>
   return %0 : vector<1x64x128xf32>
 }
-// CHECK: contract A vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 2, 1],
-// CHECK-SAME: batches_per_subgroup = [1, 2, 4],
-// CHECK-SAME: outers_per_batch = [1, 1, 1]
-// CHECK-SAME: threads_per_outer = [1, 16, 4]
-// CHECK-SAME: elements_per_thread = [1, 1, 4]
-// CHECK-SAME: subgroup_strides = [0, 2, 0],
-// CHECK-SAME: thread_strides = [0, 1, 16]>
-// CHECK: contract B vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 1, 2]
-// CHECK-SAME: batches_per_subgroup = [1, 4, 4]
-// CHECK-SAME: outers_per_batch = [1, 1, 1]
-// CHECK-SAME: threads_per_outer = [1, 4, 16]
-// CHECK-SAME: elements_per_thread = [1, 4, 1]
-// CHECK-SAME: subgroup_strides = [0, 0, 1],
-// CHECK-SAME: thread_strides = [0, 16, 1]>
-// CHECK: contract C vector layout: #iree_vector_ext.nested_layout<
-// CHECK-SAME: subgroups_per_workgroup = [1, 2, 2]
-// CHECK-SAME: batches_per_subgroup = [1, 2, 4]
-// CHECK-SAME: outers_per_batch = [1, 1, 1]
-// CHECK-SAME: threads_per_outer = [1, 4, 16]
-// CHECK-SAME: elements_per_thread = [1, 4, 1]
-// CHECK-SAME: subgroup_strides = [0, 2, 1],
-// CHECK-SAME: thread_strides = [0, 16, 1]>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_conversion.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_conversion.mlir
index 0818851c5a50..9c9eab06e665 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_conversion.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_conversion.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline='builtin.module(func.func(iree-llvmgpu-vector-distribute, canonicalize, cse))' -split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline='builtin.module(func.func(iree-llvmgpu-configure-vector-layouts, iree-llvmgpu-vector-distribute, canonicalize, cse))' -split-input-file %s | FileCheck %s
 #translation = #iree_codegen.translation_info