diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp index a20f872c7b8ff..cc7fa2bd2b57c 100644 --- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp @@ -7,6 +7,7 @@ // Implements IREE-specific logic for lowering StableHLO/CHLO dialects to // LinalgExt dialect. +#include #include #include #include @@ -427,7 +428,6 @@ struct FftOpConversion final : OpConversionPattern { struct ReverseOpConversion final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(mlir::stablehlo::ReverseOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -435,14 +435,45 @@ struct ReverseOpConversion final if (!ty) return failure(); + Value input = op.getOperand(); + auto inputTy = cast(input.getType()); + auto resultTy = cast(op.getType()); + ArrayRef dims = op.getDimensions(); Location loc = op.getLoc(); - SmallVector mixedSizes = - tensor::getMixedSizes(rewriter, loc, adaptor.getOperands()[0]); - Value emptyTensor = - rewriter.create(loc, mixedSizes, ty.getElementType()); - rewriter.replaceOpWithNewOp( - op, typeConverter->convertType(op.getType()), adaptor.getOperands(), - emptyTensor, rewriter.getI64TensorAttr(op.getDimensions())); + int64_t inputTyRank = inputTy.getRank(); + + // First fill the output buffer with the init value. + SmallVector inputMixedSizes = + tensor::getMixedSizes(rewriter, loc, input); + auto emptyTensor = rewriter.create( + loc, inputMixedSizes, inputTy.getElementType()); + SmallVector affineMaps = { + rewriter.getMultiDimIdentityMap(resultTy.getRank())}; + + rewriter.replaceOpWithNewOp( + op, resultTy, ArrayRef({}), ValueRange{emptyTensor}, affineMaps, + getNParallelLoopsAttrs(resultTy.getRank()), + [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { + llvm::SmallVector indices; + for (unsigned int i = 0; i < inputTyRank; i++) { + Value index = + rewriter.create(nestedLoc, i).getResult(); + if (std::find(dims.begin(), dims.end(), i) != dims.end()) { + auto one = rewriter.create(nestedLoc, 1); + Value axisDimSize = rewriter.create(loc, input, i); + auto sizeMinusOne = + rewriter.create(nestedLoc, axisDimSize, one); + index = rewriter.create(nestedLoc, sizeMinusOne, + index); + } + indices.push_back(index); + } + + auto extract = nestedBuilder.create( + nestedLoc, input, indices); + nestedBuilder.create(op.getLoc(), + extract.getResult()); + }); return success(); } }; diff --git a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir index 917f2f8b83b50..09b2bd4d87bd3 100644 --- a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir +++ b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir @@ -495,12 +495,17 @@ func.func @reverse_dim1(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> { return %0 : tensor<3x5xi32> } // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<3x5xi32> -// CHECK: %[[REV:.+]] = iree_linalg_ext.reverse -// CHECK-SAME: dimensions(dense<1> : tensor<1xi64>) -// CHECK-SAME: ins(%[[IN]] : tensor<3x5xi32>) -// CHECK-SAME: outs(%[[INIT]] : tensor<3x5xi32>) : tensor<3x5xi32> -// CHECK: return %[[REV]] - +// CHECK: %[[GEN:.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<3x5xi32>) { +// CHECK: %[[SAME_DIM:.+]] = linalg.index 0 : index +// CHECK: %[[REV_DIM:.+]] = linalg.index 1 : index +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[C1_0:.+]] = arith.constant 1 : index +// CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C1_0]] : tensor<3x5xi32> +// CHECK: %[[DIMSUB1:.+]] = arith.subi %[[DIM]], %[[C1]] : index +// CHECK: %[[REV_IDX:.+]] = arith.subi %[[DIMSUB1]], %[[REV_DIM]] : index +// CHECK: %[[EXTRACTED:.+]] = tensor.extract %arg0[%[[SAME_DIM]], %[[REV_IDX]]] : tensor<3x5xi32> +// CHECK: linalg.yield %[[EXTRACTED]] : i32 +// CHECK: return %[[GEN]] // ----- func.func @reverse_unsigned(%arg0: tensor<3x5xui32>) -> tensor<3x5xui32> { @@ -512,13 +517,18 @@ func.func @reverse_unsigned(%arg0: tensor<3x5xui32>) -> tensor<3x5xui32> { // CHECK-LABEL: func.func @reverse_unsigned // CHECK-SAME: %[[IN:[a-zA-Z0-9]+]] // CHECK: %[[BITCAST:.+]] = builtin.unrealized_conversion_cast %[[IN]] : tensor<3x5xui32> to tensor<3x5xi32> -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<3x5xi32> -// CHECK: %[[REV:.+]] = iree_linalg_ext.reverse -// CHECK-SAME: dimensions(dense<1> : tensor<1xi64>) -// CHECK-SAME: ins(%[[BITCAST]] : tensor<3x5xi32>) -// CHECK-SAME: outs(%[[INIT]] : tensor<3x5xi32>) : tensor<3x5xi32> -// CHECK: %[[BITCAST:.+]] = builtin.unrealized_conversion_cast %[[REV]] : tensor<3x5xi32> to tensor<3x5xui32> -// CHECK: return %[[BITCAST]] +// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<3x5xui32> +// CHECK: %[[GEN:.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<3x5xui32>) +// CHECK: %[[SAME_DIM:.+]] = linalg.index 0 : index +// CHECK: %[[REV_DIM:.+]] = linalg.index 1 : index +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[C1_0:.+]] = arith.constant 1 : index +// CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C1_0]] : tensor<3x5xui32> +// CHECK: %[[DIMSUB1:.+]] = arith.subi %[[DIM]], %[[C1]] : index +// CHECK: %[[REV_IDX:.+]] = arith.subi %[[DIMSUB1]], %[[REV_DIM]] : index +// CHECK: %[[EXTRACTED:.+]] = tensor.extract %arg0[%[[SAME_DIM]], %[[REV_IDX]]] : tensor<3x5xui32> +// CHECK: linalg.yield %[[EXTRACTED]] : ui32 +// CHECK: return %[[GEN]] // ----- @@ -530,16 +540,32 @@ func.func @reverse_multi_dim(%arg0: tensor) -> tensor { } : (tensor) -> tensor return %0 : tensor } -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[IN]], %[[C0]] -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[IN]], %[[C1]] -// CHECK: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]]) : tensor -// CHECK: %[[REV:.+]] = iree_linalg_ext.reverse -// CHECK-SAME: dimensions(dense<[0, 1]> : tensor<2xi64>) -// CHECK-SAME: ins(%[[IN]] : tensor) -// CHECK-SAME: outs(%[[INIT]] : tensor) : tensor -// CHECK: return %[[REV]] +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[D:.+]] = tensor.dim %[[IN]], %[[C0]] : tensor +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[D0:.+]] = tensor.dim %[[IN]], %[[C1]] : tensor +// CHECK: %[[INIT:.+]] = tensor.empty(%[[D]], %[[D0]]) : tensor +// CHECK: %[[GEN:.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor) { + +// First reverse dimension +// CHECK: %[[IDX0:.+]] = linalg.index 0 : index +// CHECK: %[[C1_1:.+]] = arith.constant 1 : index +// CHECK: %[[C0_2:.+]] = arith.constant 0 : index +// CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %[[C0_2]] : tensor +// CHECK: %[[DIM0SUB1:.+]] = arith.subi %[[DIM0]], %[[C1_1]] : index +// CHECK: %[[REV_IDX0:.+]] = arith.subi %[[DIM0SUB1]], %[[IDX0]] : index + +// Second reverse dimension +// CHECK: %[[IDX1:.+]] = linalg.index 1 : index +// CHECK: %[[C1_4:.+]] = arith.constant 1 : index +// CHECK: %[[C1_5:.+]] = arith.constant 1 : index +// CHECK: %[[DIM1:.+]] = tensor.dim %arg0, %[[C1_5]] : tensor +// CHECK: %[[DIM1SUB1:.+]] = arith.subi %[[DIM1]], %[[C1_4]] : index +// CHECK: %[[REV_IDX1:.+]] = arith.subi %[[DIM1SUB1]], %[[IDX1]] : index + +// CHECK: %[[EXTRACTED:.+]] = tensor.extract %arg0[%[[REV_IDX0]], %[[REV_IDX1]]] : tensor +// CHECK: linalg.yield %[[EXTRACTED]] : i32 +// CHECK: return %[[GEN]] // ----- diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir index 68c0873fd81f2..c075e09767f7c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir @@ -526,46 +526,6 @@ func.func @reduce_window_max_4x6xf32() { // ----- -func.func @linalg_ext_reverse_dim0() { - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %c0 = arith.constant 0 : index - %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %workgroup_id_x = hal.interface.workgroup.id[0] : index - %workgroup_count_x = hal.interface.workgroup.count[0] : index - %workgroup_id_y = hal.interface.workgroup.id[1] : index - %workgroup_count_y = hal.interface.workgroup.count[1] : index - %2 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_y] - %3 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_y] - scf.for %arg0 = %2 to %c2 step %3 { - %4 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x] - %5 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_x] - scf.for %arg1 = %4 to %c3 step %5 { - %6 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1], sizes = [2, 3], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<2x3xf32> - %7 = tensor.empty() : tensor<2x3xf32> - %8 = iree_linalg_ext.reverse dimensions(dense<0> : tensor<1xi64>) ins(%6 : tensor<2x3xf32>) outs(%7 : tensor<2x3xf32>) : tensor<2x3xf32> - %9 = affine.apply affine_map<()[s0] -> (-s0)>()[%arg0] - flow.dispatch.tensor.store %8, %1, offsets = [%9, %arg1], sizes = [2, 3], strides = [%c1, %c1] : tensor<2x3xf32> -> !flow.dispatch.tensor> - } - } - return -} -// CHECK: func.func @linalg_ext_reverse_dim0() -// CHECK-DAG: %[[IN:.+]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) -// CHECK-DAG: %[[OUT:.+]] = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) -// CHECK: scf.for %[[IV0:.+]] = -// CHECK: scf.for %[[IV1:.+]] = -// CHECK-DAG: %[[IN_TILE:.+]] = flow.dispatch.tensor.load %[[IN]] -// CHECK-DAG: %[[OUT_TILE:.+]] = flow.dispatch.tensor.load %[[OUT]] -// CHECK: %[[REV_TILE:.+]] = iree_linalg_ext.reverse -// CHECK-SAME: ins(%[[IN_TILE]] : tensor<2x3xf32>) -// CHECK-SAME: outs(%[[OUT_TILE]] : tensor<2x3xf32>) -// CHECK: flow.dispatch.tensor.store %[[REV_TILE]], %[[OUT]] - -// ----- - func.func @sort1D() { %c0 = arith.constant 0 : index %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> diff --git a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir index 911a44df4f7f5..f360aad2a3a42 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir @@ -2170,30 +2170,6 @@ func.func @rank_reducing_no_op_subview() { // ----- -// CHECK-LABEL: func.func @reverse_dim( -// CHECK-DAG: %[[alloc:.*]] = memref.alloc() -// CHECK-DAG: %[[cst:.*]] = bufferization.to_memref -// CHECK: iree_linalg_ext.reverse dimensions(dense<0> : tensor<1xi64>) -// CHECK-SAME: ins(%[[cst]] : -// CHECK-SAME: outs(%[[alloc]] : -// CHECK: %[[load:.*]] = memref.load %[[alloc]] -// CHECK: return %[[load]] -func.func @reverse_dim(%pos: index) -> f32 { - %input = arith.constant dense<[[1.0, 2.0, 3.0], - [4.0, 5.0, 6.0]]> : tensor<2x3xf32> - - %init = bufferization.alloc_tensor() : tensor<2x3xf32> - %0 = iree_linalg_ext.reverse - dimensions(dense<0> : tensor<1xi64>) - ins(%input : tensor<2x3xf32>) - outs(%init : tensor<2x3xf32>) : tensor<2x3xf32> - - %1 = tensor.extract %0[%pos, %pos] : tensor<2x3xf32> - return %1 : f32 -} - -// ----- - // CHECK-LABEL: func.func @fft_tensor( // CHECK: memref.alloc // CHECK: memref.alloc diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp index 0422f16f27a8c..8e52db076525b 100644 --- a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp +++ b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp @@ -350,9 +350,9 @@ struct LinalgExtOpInterface bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - // TODO: Revisit this for Scatter/ReverseOp. We can then get rid of + // TODO: Revisit this for ScatterOp. We can then get rid of // `bufferizesToMemoryRead` completely. - return !isa(op); + return !isa(op); } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, @@ -630,8 +630,6 @@ void registerBufferizationInterfaces(DialectRegistry ®istry) { LinalgExtOpInterface>(*ctx); IREE::LinalgExt::UnPackOp::attachInterface< LinalgExtOpInterface>(*ctx); - IREE::LinalgExt::ReverseOp::attachInterface< - LinalgExtOpInterface>(*ctx); IREE::LinalgExt::ScanOp::attachInterface< LinalgExtOpInterface>(*ctx); IREE::LinalgExt::ScatterOp::attachInterface< diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp index cd289a50c47e5..b695e1e82197a 100644 --- a/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp +++ b/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp @@ -241,8 +241,6 @@ void registerPartitionableLoopsInterfaceModels(DialectRegistry ®istry) { OuterParallelAsPartitionableLoops>(*ctx); IREE::LinalgExt::SortOp::attachInterface< AllParallelAsPartitionableLoops>(*ctx); - IREE::LinalgExt::ReverseOp::attachInterface< - OuterParallelAsPartitionableLoops>(*ctx); IREE::LinalgExt::TopkOp::attachInterface< AllParallelAsPartitionableLoops>(*ctx); IREE::LinalgExt::WinogradInputTransformOp::attachInterface< diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir index 6621451169ee7..4cca860d76479 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir @@ -46,41 +46,6 @@ util.func public @linalgext_scatter_fusion() -> tensor<8192x16x8x128xf32> { // CHECK: %[[GEN2:.+]] = linalg.generic // CHECK-SAME: ins(%[[INPUT:.+]] : tensor<8192x16x8x128xf32>) - - -// ----- - - -#map = affine_map<(d0, d1) -> (d0, d1)> -util.func public @linalgext_reverse_fusion() -> tensor<10x10xi32> { - %0 = tensor.empty() : tensor<10x10xi64> - %1 = tensor.empty() : tensor<10x10xi32> - %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<10x10xi64>) outs(%1 : tensor<10x10xi32>) { - ^bb0(%in: i64, %out: i32): - %7 = arith.trunci %in : i64 to i32 - linalg.yield %7 : i32 - } -> tensor<10x10xi32> - %3 = tensor.empty() : tensor<10x10xi32> - %4 = iree_linalg_ext.reverse dimensions(dense<0> : tensor<1xi64>) ins(%2 : tensor<10x10xi32>) outs(%3 : tensor<10x10xi32>) : tensor<10x10xi32> - - // dont fuse with with reverse's consumer - %5 = tensor.empty() : tensor<10x10xi32> - %6 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<10x10xi32>) outs(%5 : tensor<10x10xi32>) { - ^bb0(%in: i32, %out: i32): - %7 = arith.addi %in, %out : i32 - linalg.yield %7 : i32 - } -> tensor<10x10xi32> - util.return %6 : tensor<10x10xi32> -} - -// CHECK: util.func public @linalgext_reverse_fusion -// CHECK: flow.dispatch.workgroups -// CHECK: %[[SHRUNK:.+]] = linalg.generic -// CHECK: %[[REVERSED:.+]] = iree_linalg_ext.reverse -// CHECK: ins(%[[SHRUNK]] : tensor<10x10xi32>) -// CHECK: flow.dispatch.workgroups -// CHECK: %[[GEN:.+]] = linalg.generic - // ----- #map = affine_map<(d0, d1) -> (d0, d1)> diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp index ce40d002b8ee3..4df93b39e535b 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp @@ -439,71 +439,6 @@ ScanOp::reifyResultShapes(OpBuilder &b, .reifyResultShapes(b, reifiedReturnShapes); } -//===----------------------------------------------------------------------===// -// ReverseOp -//===----------------------------------------------------------------------===// - -LogicalResult ReverseOp::verify() { - Operation *op = getOperation(); - if (getNumDpsInputs() != 1) { - return op->emitOpError("expected exactly one input"); - } - if (getNumDpsInits() != 1) { - return op->emitOpError("expected exactly one output"); - } - auto inputType = cast(getInput().getType()); - auto outputType = cast(getOutput().getType()); - if (inputType.getElementType() != outputType.getElementType()) { - return op->emitOpError( - "expected input/output element types to be identical"); - } - ArrayRef inputShapes = inputType.getShape(); - ArrayRef outputShapes = outputType.getShape(); - if (inputShapes.size() != outputShapes.size()) { - return op->emitOpError("expexted input/output to have identical ranks"); - } - if (llvm::any_of(llvm::zip_equal(inputShapes, outputShapes), - [](std::tuple s) { - return !ShapedType::isDynamic(std::get<0>(s)) && - !ShapedType::isDynamic(std::get<1>(s)) && - std::get<0>(s) != std::get<1>(s); - })) { - return op->emitOpError("incompatible input/output shapes"); - } - - int64_t rank = getOperandRank(); - llvm::SmallSetVector s; - for (auto dim : getDimensionsArray()) { - if (dim < 0 || dim >= rank) { - return op->emitOpError("all the dimensions must be within [0, ") - << rank << ")"; - } - if (s.contains(dim)) { - return op->emitOpError("expected dimensions numbers are all unique"); - } - s.insert(dim); - } - - return success(); -} - -LogicalResult -ReverseOp::reifyResultShapes(OpBuilder &b, - ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return cast(getOperation()) - .reifyResultShapes(b, reifiedReturnShapes); -} - -SmallVector ReverseOp::getIndexingMapsForOperands() { - Builder builder(getContext()); - return {builder.getMultiDimIdentityMap(getOperandRank()), - /*output=*/AffineMap(nullptr)}; -} - -SmallVector ReverseOp::getIndexingMapsForResults() { - return {AffineMap(nullptr)}; -} - //===----------------------------------------------------------------------===// // TopkOp //===----------------------------------------------------------------------===// @@ -1583,7 +1518,6 @@ Im2colOp::reifyResultShapes(OpBuilder &b, DEFINE_OP_GET_EFFECTS(ScatterOp) DEFINE_OP_GET_EFFECTS(SortOp) DEFINE_OP_GET_EFFECTS(FftOp) -DEFINE_OP_GET_EFFECTS(ReverseOp) DEFINE_OP_GET_EFFECTS(ScanOp) DEFINE_OP_GET_EFFECTS(TopkOp) DEFINE_OP_GET_EFFECTS(PackOp) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td index 993ab3bbb0d6c..fe8693ab3a476 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td @@ -369,67 +369,6 @@ def IREELinalgExt_ScanOp : IREELinalgExt_Op<"scan", }]; } -def IREELinalgExt_ReverseOp : IREELinalgExt_Op<"reverse", [ - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods< - TilingInterface, - ["generateScalarImplementation", - "getIterationDomain", - "getLoopIteratorTypes", - "getResultTilePosition", - "getTiledImplementation"]>, - DeclareOpInterfaceMethods]> { - let summary = "Reverse operator"; - let description = [{ - A temporary solution for lowering reverse ops into IREE, allowing IREE to - tile and distribute them. - } - }]; - - let arguments = (ins Variadic:$inputs, - Variadic:$outputs, - I64ElementsAttr:$dimensions - ); - let results = (outs Variadic:$results); - let assemblyFormat = [{ - attr-dict `dimensions` `(` $dimensions `)` - (`ins` `(` $inputs^ `:` type($inputs) `)`)? - (`outs` `(` $outputs^ `:` type($outputs) `)`)? - (`:` type($results)^)? - }]; - let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ - Value getInput() { - return getDpsInputOperand(0)->get(); - } - Value getOutput() { - return getDpsInitOperand(0)->get(); - } - ShapedType getOperandType() { - return cast(getInput().getType()); - } - int64_t getOperandRank() { - return getOperandType().getRank(); - } - ArrayRef getOprerandShape() { - return getOperandType().getShape(); - } - SmallVector getDimensionsArray() { - SmallVector ret; - for (const APInt& elem : getDimensions()) { - ret.push_back(elem.getLimitedValue()); - } - return ret; - } - - // Method to implement for specifying output range for - // DestinationStyleOpInterface - MutableOperandRange getDpsInitsMutable() { - return getOutputsMutable(); - } - }]; -} - def IREELinalgExt_TopkOp : IREELinalgExt_Op<"topk",[ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/canonicalize.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/canonicalize.mlir index a2f89e8e6b4e6..87ee42901a239 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/canonicalize.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/canonicalize.mlir @@ -1,27 +1,5 @@ // RUN: iree-opt --canonicalize --split-input-file %s | FileCheck %s -func.func @tensor_cast(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> { - %init = tensor.empty() : tensor<3x5xi32> - - %casted_arg0 = tensor.cast %arg0 : tensor<3x5xi32> to tensor - %casted_init = tensor.cast %init : tensor<3x5xi32> to tensor - - %0 = iree_linalg_ext.reverse - dimensions(dense<0> : tensor<1xi64>) - ins(%casted_arg0 : tensor) - outs(%casted_init : tensor) : tensor - - %1 = tensor.cast %0 : tensor to tensor<3x5xi32> - - return %1: tensor<3x5xi32> -} -// CHECK-LABEL: func.func @tensor_cast( -// CHECK: iree_linalg_ext.reverse -// CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<3x5xi32>) -// CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<3x5xi32>) - -// ----- - func.func @pack_canonicalize(%arg0 : tensor, %arg1 : tensor<1x2x3x3xi32>) -> tensor<1x?x3x3xi32> { %c0_i32 = arith.constant 0 : i32 diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir index 3363f1bf0cb47..1d0280bc75eee 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir @@ -338,42 +338,6 @@ func.func @scatter_original_rank_mismatch( // ----- -func.func @reverse_diff_element_type(%arg0: tensor<3x5xi32>) -> tensor<3x5xf32> { - %init = tensor.empty() : tensor<3x5xf32> - // expected-error @+1 {{expected input/output element types to be identical}} - %0 = iree_linalg_ext.reverse - dimensions(dense<0> : tensor<1xi64>) - ins(%arg0 : tensor<3x5xi32>) - outs(%init : tensor<3x5xf32>) : tensor<3x5xf32> - return %0 : tensor<3x5xf32> -} - -// ----- - -func.func @reverse_diff_shape(%arg0: tensor<3x5xi32>) -> tensor<3x6xi32> { - %init = tensor.empty() : tensor<3x6xi32> - // expected-error @+1 {{incompatible input/output shapes}} - %0 = iree_linalg_ext.reverse - dimensions(dense<0> : tensor<1xi64>) - ins(%arg0 : tensor<3x5xi32>) - outs(%init : tensor<3x6xi32>) : tensor<3x6xi32> - return %0 : tensor<3x6xi32> -} - -// ----- - -func.func @reverse_dup_dims(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> { - %init = tensor.empty() : tensor<3x5xi32> - // expected-error @+1 {{expected dimensions numbers are all unique}} - %0 = iree_linalg_ext.reverse - dimensions(dense<[0, 0]> : tensor<2xi64>) - ins(%arg0 : tensor<3x5xi32>) - outs(%init : tensor<3x5xi32>) : tensor<3x5xi32> - return %0 : tensor<3x5xi32> -} - -// ----- - func.func @topk_invalid(%input_values: tensor<2x10xf32>, %input_indices: tensor<2x10xi32>, %out_values : tensor<2x3xf32>, %out_indices: tensor<2x3xi32>) -> (tensor<2x3xf32>, tensor<2x3xi32>) { // expected-error@+1 {{expected one or two input operands}} %0:2 = iree_linalg_ext.topk diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir index 88a1f2d522671..eddb2c7558d42 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir @@ -484,111 +484,6 @@ func.func @fft_tensor_coef_stage_5(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf // ----- -func.func @reverse_tensor(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> { - %init = tensor.empty() : tensor<3x5xi32> - %0 = iree_linalg_ext.reverse - dimensions(dense<0> : tensor<1xi64>) - ins(%arg0 : tensor<3x5xi32>) - outs(%init : tensor<3x5xi32>) : tensor<3x5xi32> - return %0 : tensor<3x5xi32> -} -// CHECK-LABEL: func.func @reverse_tensor -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<3x5xi32> -// CHECK: %[[INIT:.+]] = tensor.empty() -// CHECK: %[[RESULT:.+]] = iree_linalg_ext.reverse -// CHECK-SAME: dimensions(dense<0> : tensor<1xi64>) -// CHECK-SAME: ins(%[[ARG0]] -// CHECK-SAME: outs(%[[INIT]] - -// ----- - -func.func @reverse_memref(%arg0: memref<3x5xi32>, %arg1: memref<3x5xi32>) { - iree_linalg_ext.reverse - dimensions(dense<0> : tensor<1xi64>) - ins(%arg0 : memref<3x5xi32>) - outs(%arg1 : memref<3x5xi32>) - return -} -// CHECK-LABEL: func.func @reverse_memref -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<3x5xi32> -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<3x5xi32> -// CHECK: iree_linalg_ext.reverse -// CHECK-SAME: dimensions(dense<0> : tensor<1xi64>) -// CHECK-SAME: ins(%[[ARG0]] -// CHECK-SAME: outs(%[[ARG1]] - -// ----- - -func.func @reverse_dynamic_tensor(%arg0: tensor) -> tensor { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %d0 = tensor.dim %arg0, %c0 : tensor - %d1 = tensor.dim %arg0, %c1 : tensor - %init = tensor.empty(%d0, %d1) : tensor - %0 = iree_linalg_ext.reverse - dimensions(dense<1> : tensor<1xi64>) - ins(%arg0 : tensor) - outs(%init : tensor) : tensor - return %0 : tensor -} -// CHECK-LABEL: func.func @reverse_dynamic_tensor -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]]) -// CHECK: %[[RESULT:.+]] = iree_linalg_ext.reverse -// CHECK-SAME: dimensions(dense<1> : tensor<1xi64>) -// CHECK-SAME: ins(%[[ARG0]] -// CHECK-SAME: outs(%[[INIT]] - -// ----- - -func.func @reverse_static_dynamic_tensor(%arg0: tensor<3x5xi32>) -> tensor { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %d0 = tensor.dim %arg0, %c0 : tensor<3x5xi32> - %d1 = tensor.dim %arg0, %c1 : tensor<3x5xi32> - %init = tensor.empty(%d0, %d1) : tensor - %0 = iree_linalg_ext.reverse - dimensions(dense<1> : tensor<1xi64>) - ins(%arg0 : tensor<3x5xi32>) - outs(%init : tensor) : tensor - return %0 : tensor -} -// CHECK-LABEL: func.func @reverse_static_dynamic_tensor -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<3x5xi32> -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]]) -// CHECK: %[[RESULT:.+]] = iree_linalg_ext.reverse -// CHECK-SAME: dimensions(dense<1> : tensor<1xi64>) -// CHECK-SAME: ins(%[[ARG0]] -// CHECK-SAME: outs(%[[INIT]] - -// ----- - -func.func @reverse_multi_dims(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> { - %init = tensor.empty() : tensor<3x5xi32> - %0 = iree_linalg_ext.reverse - dimensions(dense<[0, 1]> : tensor<2xi64>) - ins(%arg0 : tensor<3x5xi32>) - outs(%init : tensor<3x5xi32>) : tensor<3x5xi32> - return %0 : tensor<3x5xi32> -} -// CHECK-LABEL: func.func @reverse_multi_dims -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<3x5xi32> -// CHECK: %[[INIT:.+]] = tensor.empty() -// CHECK: %[[RESULT:.+]] = iree_linalg_ext.reverse -// CHECK-SAME: dimensions(dense<[0, 1]> : tensor<2xi64>) -// CHECK-SAME: ins(%[[ARG0]] -// CHECK-SAME: outs(%[[INIT]] - -// ----- - func.func @topk_tensor(%input_values: tensor<20x10x8x4xf32>, %input_indices: tensor<20x10x8x4xi32>) -> (tensor<20x10x3x4xf32>, tensor<20x10x3x4xi32>) { %out_values = tensor.empty() : tensor<20x10x3x4xf32> %out_indices = tensor.empty() : tensor<20x10x3x4xi32> diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp index 231f694c07917..2565df8d64007 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp @@ -729,101 +729,6 @@ LogicalResult ScanOp::getResultTilePosition( return failure(); } -//===----------------------------------------------------------------------===// -// ReverseOp -//===----------------------------------------------------------------------===// - -SmallVector ReverseOp::getLoopIteratorTypes() { - SmallVector iteratorTypes(getOperandRank(), - utils::IteratorType::parallel); - return iteratorTypes; -} - -SmallVector ReverseOp::getIterationDomain(OpBuilder &builder) { - Location loc = getLoc(); - Value zero = builder.create(loc, 0); - Value one = builder.create(loc, 1); - SmallVector ranges; - for (auto dim : llvm::seq(0, getOperandRank())) { - Value ub = getDimValue(builder, loc, getInput(), dim); - ranges.emplace_back(Range{zero, ub, one}); - } - return ranges; -} - -LogicalResult ReverseOp::generateScalarImplementation(OpBuilder &b, - Location loc, - ValueRange ivs) { - SmallVector mirrorIndices(ivs.begin(), ivs.end()); - for (auto dim : getDimensionsArray()) { - auto size = getDimValue(b, loc, getInput(), dim); - size = b.create(loc, size, - b.create(loc, 1)); - mirrorIndices[dim] = b.create(loc, size, mirrorIndices[dim]); - } - Value val = b.create(loc, getInput(), ivs); - b.create(loc, val, getOutput(), mirrorIndices); - return success(); -} - -FailureOr -ReverseOp::getTiledImplementation(OpBuilder &builder, - ArrayRef offsets, - ArrayRef sizes) { - int64_t rank = getOperandRank(); - SmallVector strides(rank, builder.getI64IntegerAttr(1)); - Location loc = getLoc(); - SmallVector mirrorOffsets, mirrorSizes; - if (failed(getResultTilePosition(builder, 0, offsets, sizes, mirrorOffsets, - mirrorSizes))) { - return {}; - } - - SmallVector tiledOperands; - tiledOperands.emplace_back( - getSlice(builder, loc, getInput(), offsets, sizes, strides)); - - SmallVector resultTypes; - if (hasPureTensorSemantics()) { - tiledOperands.emplace_back( - getSlice(builder, loc, getOutput(), mirrorOffsets, sizes, strides)); - resultTypes.push_back(tiledOperands[1].getType()); - } else { - tiledOperands.emplace_back( - getSlice(builder, loc, getOutput(), mirrorOffsets, sizes, strides)); - } - - Operation *tiledRevOp = - mlir::clone(builder, getOperation(), resultTypes, tiledOperands); - - return TilingResult{{tiledRevOp}, - SmallVector(tiledRevOp->getResults())}; -} - -LogicalResult ReverseOp::getResultTilePosition( - OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, - ArrayRef sizes, SmallVector &resultOffsets, - SmallVector &resultSizes) { - AffineExpr sym0, sym1, sym2; - bindSymbols(builder.getContext(), sym0, sym1, sym2); - AffineMap map = - AffineMap::get(/*dimCount=*/0, /*symbolCount=*/3, {sym0 - sym1 - sym2}); - resultOffsets.assign(offsets.begin(), offsets.end()); - Location loc = getLoc(); - for (auto dim : getDimensionsArray()) { - Value size = getDimValue(builder, loc, getInput(), dim); - Value offset = - getValueOrCreateConstantIndexOp(builder, loc, resultOffsets[dim]); - Value tileSize = getValueOrCreateConstantIndexOp(builder, loc, sizes[dim]); - resultOffsets[dim] = builder - .create( - loc, map, ValueRange{size, offset, tileSize}) - .getResult(); - } - resultSizes.assign(sizes.begin(), sizes.end()); - return success(); -} - //===----------------------------------------------------------------------===// // TopkOp //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir index 3623426bbebc3..f136ab5113423 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir @@ -507,28 +507,6 @@ func.func @fft_2D_coef_buf(%real: memref, %imag: memref, // ----- -func.func @reverse_dim_0(%arg0: memref, %arg1: memref) { - iree_linalg_ext.reverse - dimensions(dense<0> : tensor<1xi64>) - ins(%arg0 : memref) - outs(%arg1 : memref) - return -} -// CHECK-LABEL: func.func @reverse_dim_0 -// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = memref.dim %arg0, %c0 : memref -// CHECK-DAG: %[[D1:.+]] = memref.dim %arg0, %c1 : memref -// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C1]] -// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[D1]] step %[[C1]] -// CHECK: %[[T0:.+]] = memref.dim %[[IN]], %[[C0]] -// CHECK: %[[T1:.+]] = arith.subi %[[T0]], %[[C1]] : index -// CHECK: %[[T2:.+]] = arith.subi %[[T1]], %[[I]] : index -// CHECK: %[[V0:.+]] = memref.load %[[IN]][%[[I]], %[[J]]] -// CHECK: memref.store %[[V0]], %[[OUT]][%[[T2]], %[[J]]] : memref - func.func @scan_1d_inclusive(%0: memref<128xi32>, %1: memref<128xi32>) { %c0 = memref.alloc() : memref iree_linalg_ext.scan dimension(0) inclusive(true) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir index 42e619326eddf..67a189b7d24b2 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir @@ -494,95 +494,6 @@ module attributes { transform.with_named_sequence } { // ----- -func.func @reverse_memref(%arg0: memref, %arg1: memref) { - iree_linalg_ext.reverse - dimensions(dense<0> : tensor<1xi64>) - ins(%arg0: memref) - outs(%arg1: memref) - return -} -module attributes { transform.with_named_sequence } { - transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["iree_linalg_ext.reverse"]} in %module_op : (!transform.any_op) -> !transform.any_op - %1, %loops = transform.structured.tile_using_for %0 tile_sizes [10] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield - } -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10) -// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> (s0 - s1 - s2)> -// CHECK: func.func @reverse_memref( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index -// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C0]] : memref -// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C10]] { -// CHECK-DAG: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]])[%[[D0]]] -// CHECK-DAG: %[[IDX:.+]] = affine.apply #[[MAP2]]()[%[[D0]], %[[I]], %[[SIZE]]] -// CHECK-DAG: %[[SUB_IN:.+]] = memref.subview %[[ARG0]][%[[I]]] [%[[SIZE]]] [1] -// CHECK-DAG: %[[SUB_OUT:.+]] = memref.subview %[[ARG1]][%[[IDX]]] [%[[SIZE]]] [1] -// CHECK: iree_linalg_ext.reverse -// CHECK-SAME: dimensions(dense<0> : tensor<1xi64>) -// CHECK-SAME: ins(%[[SUB_IN]] -// CHECK-SAME: outs(%[[SUB_OUT]] - -// ----- - -func.func @reverse_tensor_multi_dim(%arg0: tensor) -> tensor { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %d0 = tensor.dim %arg0, %c0 : tensor - %d1 = tensor.dim %arg0, %c1 : tensor - %init = tensor.empty(%d0, %d1) : tensor - %0 = iree_linalg_ext.reverse - dimensions(dense<[0, 1]> : tensor<2xi64>) - ins(%arg0: tensor) - outs(%init: tensor) : tensor - return %0 : tensor -} -module attributes { transform.with_named_sequence } { - transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["iree_linalg_ext.reverse"]} in %module_op : (!transform.any_op) -> !transform.any_op - %1, %loops:2 = transform.structured.tile_using_for %0 tile_sizes [10, 20] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - transform.yield - } -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 20)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> (s0 - s1 - s2)> -// CHECK: func.func @reverse_tensor_multi_dim( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index -// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor -// CHECK: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]]) : tensor -// CHECK: %[[RES:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C10]] -// CHECK-SAME: iter_args(%[[INIT2:.+]] = %[[INIT]]) -> (tensor) { -// CHECK: %[[RES2:.+]] = scf.for %[[J:.+]] = %[[C0]] to %[[D1]] step %[[C20]] -// CHECK-SAME: iter_args(%[[INIT3:.+]] = %[[INIT2]]) -> (tensor) { -// CHECK-DAG: %[[SIZE_I:.+]] = affine.min #[[MAP0]](%[[I]])[%[[D0]]] -// CHECK-DAG: %[[SIZE_J:.+]] = affine.min #[[MAP1]](%[[J]])[%[[D1]]] -// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP2]]()[%[[D0]], %[[I]], %[[SIZE_I]]] -// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP2]]()[%[[D1]], %[[J]], %[[SIZE_J]]] -// CHECK: %[[SUB_IN:.+]] = tensor.extract_slice -// CHECK-SAME: %[[ARG0]][%[[I]], %[[J]]] [%[[SIZE_I]], %[[SIZE_J]]] [1, 1] -// CHECK: %[[SUB_INIT:.+]] = tensor.extract_slice -// CHECK-SAME: %[[INIT3]][%[[IDX0]], %[[IDX1]]] [%[[SIZE_I]], %[[SIZE_J]]] [1, 1] -// CHECK: %[[REV:.+]] = iree_linalg_ext.reverse -// CHECK-SAME: dimensions(dense<[0, 1]> : tensor<2xi64>) -// CHECK-SAME: ins(%[[SUB_IN]] -// CHECK-SAME: outs(%[[SUB_INIT]] -// CHECK: %[[RES3:.+]] = tensor.insert_slice %[[REV]] into -// CHECK-SAME: %[[INIT3]][%[[IDX0]], %[[IDX1]]] [%[[SIZE_I]], %[[SIZE_J]]] [1, 1] -// CHECK: scf.yield %[[RES3]] -// CHECK: scf.yield %[[RES2]] -// CHECK: return %[[RES]] - -// ----- - func.func @scan_1d(%0: tensor<128xi32>) -> tensor<128xi32> { %c0 = tensor.empty() : tensor %1 = tensor.empty() : tensor<128xi32> diff --git a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp index 0a9a1b542a621..681b3369d165d 100644 --- a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp +++ b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp @@ -313,8 +313,6 @@ void registerUtilExternalModels(DialectRegistry ®istry) { LinalgOpTiedOpInterface>(*context); IREE::LinalgExt::ScanOp::attachInterface< LinalgOpTiedOpInterface>(*context); - IREE::LinalgExt::ReverseOp::attachInterface< - LinalgOpTiedOpInterface>(*context); IREE::LinalgExt::TopkOp::attachInterface< LinalgOpTiedOpInterface>(*context); IREE::LinalgExt::WinogradInputTransformOp::attachInterface< diff --git a/tests/e2e/linalg_ext_ops/BUILD.bazel b/tests/e2e/linalg_ext_ops/BUILD.bazel index a2210f32b8489..468bd99fa672a 100644 --- a/tests/e2e/linalg_ext_ops/BUILD.bazel +++ b/tests/e2e/linalg_ext_ops/BUILD.bazel @@ -16,7 +16,6 @@ ALL_SRCS = enforce_glob( # keep sorted [ "attention.mlir", - "reverse.mlir", "scan.mlir", "scatter.mlir", "sort.mlir", @@ -42,7 +41,6 @@ iree_check_single_backend_test_suite( VMVX_SRCS = enforce_glob( # keep sorted [ - "reverse.mlir", "scan.mlir", "scatter.mlir", "sort.mlir", @@ -66,7 +64,6 @@ iree_check_single_backend_test_suite( LLVM_GPU_SRCS = enforce_glob( # keep sorted [ - "reverse.mlir", "scan.mlir", "scatter.mlir", "sort.mlir", @@ -107,7 +104,6 @@ iree_check_single_backend_test_suite( srcs = enforce_glob( # keep sorted [ - "reverse.mlir", "scan.mlir", "scatter.mlir", "sort.mlir", @@ -138,7 +134,6 @@ iree_check_single_backend_test_suite( include = ["*.mlir"], exclude = [ "attention.mlir", - "reverse.mlir", #TODO(#12415): disabled due to miscompilation on Pixel 6. "top-k.mlir", ], ), diff --git a/tests/e2e/linalg_ext_ops/CMakeLists.txt b/tests/e2e/linalg_ext_ops/CMakeLists.txt index bc4f2ae73082f..b9208c39c642b 100644 --- a/tests/e2e/linalg_ext_ops/CMakeLists.txt +++ b/tests/e2e/linalg_ext_ops/CMakeLists.txt @@ -15,7 +15,6 @@ iree_check_single_backend_test_suite( check_llvm-cpu_local-task SRCS "attention.mlir" - "reverse.mlir" "scan.mlir" "scatter.mlir" "sort.mlir" @@ -34,7 +33,6 @@ iree_check_single_backend_test_suite( NAME check_vmvx_local-task SRCS - "reverse.mlir" "scan.mlir" "scatter.mlir" "sort.mlir" @@ -51,7 +49,6 @@ iree_check_single_backend_test_suite( NAME check_cuda SRCS - "reverse.mlir" "scan.mlir" "scatter.mlir" "sort.mlir" @@ -75,7 +72,6 @@ iree_check_single_backend_test_suite( NAME check_rocm_hip SRCS - "reverse.mlir" "scan.mlir" "scatter.mlir" "sort.mlir" @@ -92,7 +88,6 @@ iree_check_single_backend_test_suite( NAME check_metal-spirv_vulkan SRCS - "reverse.mlir" "scan.mlir" "scatter.mlir" "sort.mlir" diff --git a/tests/e2e/linalg_ext_ops/reverse.mlir b/tests/e2e/linalg_ext_ops/reverse.mlir deleted file mode 100644 index db1610be97ea3..0000000000000 --- a/tests/e2e/linalg_ext_ops/reverse.mlir +++ /dev/null @@ -1,53 +0,0 @@ -func.func @reverse_dim0() { - %input = util.unfoldable_constant dense<[[1.0, 2.0, 3.0], - [4.0, 5.0, 6.0]]> : tensor<2x3xf32> - - %init = tensor.empty() : tensor<2x3xf32> - %0 = iree_linalg_ext.reverse - dimensions(dense<0> : tensor<1xi64>) - ins(%input : tensor<2x3xf32>) - outs(%init : tensor<2x3xf32>) : tensor<2x3xf32> - - check.expect_almost_eq_const( - %0, - dense<[[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]> : tensor<2x3xf32> - ) : tensor<2x3xf32> - - return -} - -func.func @reverse_dim1() { - %input = util.unfoldable_constant dense<[[1, 2, 3], - [4, 5, 6]]> : tensor<2x3xi32> - - %init = tensor.empty() : tensor<2x3xi32> - %0 = iree_linalg_ext.reverse - dimensions(dense<1> : tensor<1xi64>) - ins(%input : tensor<2x3xi32>) - outs(%init : tensor<2x3xi32>) : tensor<2x3xi32> - - check.expect_eq_const( - %0, - dense<[[3, 2, 1], [6, 5, 4]]> : tensor<2x3xi32> - ) : tensor<2x3xi32> - - return -} - -func.func @reverse_multi_dims() { - %input = util.unfoldable_constant dense<[[1, 2, 3], - [4, 5, 6]]> : tensor<2x3xi32> - - %init = tensor.empty() : tensor<2x3xi32> - %0 = iree_linalg_ext.reverse - dimensions(dense<[0, 1]> : tensor<2xi64>) - ins(%input : tensor<2x3xi32>) - outs(%init : tensor<2x3xi32>) : tensor<2x3xi32> - - check.expect_eq_const( - %0, - dense<[[6, 5, 4], [3, 2, 1]]> : tensor<2x3xi32> - ) : tensor<2x3xi32> - - return -}