Skip to content

Commit

Permalink
[LinalgExt] Fix retire linalg_ext.reverse by linalg.index and tensor.…
Browse files Browse the repository at this point in the history
…extract.(iree-org#16060)
  • Loading branch information
harrisonGPU committed Jan 22, 2024
1 parent fe43684 commit 4f9e270
Showing 1 changed file with 23 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -431,39 +431,41 @@ struct ReverseOpConversion final
LogicalResult
matchAndRewrite(mlir::stablehlo::ReverseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.getOperands()[0];
auto inputType = dyn_cast<RankedTensorType>(adaptor.getOperands()[0].getType());
if (!inputType)
auto operandType = dyn_cast<RankedTensorType>(adaptor.getOperands()[0].getType());
if (!operandType)
return failure();
auto shape = inputType.getShape();
auto rank = inputType.getRank();
auto reverseAxes = op.getDimensions();

Location loc = op.getLoc();
auto operand = adaptor.getOperands()[0];
auto operandShape = operandType.getShape();
auto operandRank = operandType.getRank();
auto reverseAxes = op.getDimensions();
SmallVector<OpFoldResult> mixedSizes =
tensor::getMixedSizes(rewriter, loc, adaptor.getOperands()[0]);
tensor::getMixedSizes(rewriter, loc, operand);
Value output =
rewriter.create<tensor::EmptyOp>(loc, mixedSizes, inputType.getElementType());

SmallVector<Value, 4> lowerBounds(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
SmallVector<Value, 4> upperBounds;
for (auto dimSize : shape) {
rewriter.create<tensor::EmptyOp>(loc, mixedSizes, operandType.getElementType());
SmallVector<Value> lowerBounds(operandRank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
SmallVector<Value> upperBounds;
for (auto dimSize : operandShape) {
upperBounds.push_back(rewriter.create<arith::ConstantIndexOp>(loc, dimSize));
}
SmallVector<Value> steps(operandRank, rewriter.create<arith::ConstantIndexOp>(loc, 1));

SmallVector<Value, 4> steps(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
SmallVector<AffineMap, 2> affineMaps = {
rewriter.getMultiDimIdentityMap(rank)};
SmallVector<AffineMap> affineMaps = {
rewriter.getMultiDimIdentityMap(operandRank)};

rewriter.create<linalg::GenericOp>(
loc, ArrayRef<Type>{inputType}, ValueRange{input}, ValueRange{output},
affineMaps, getNParallelLoopsAttrs(rank),
loc, operandType,
operand,
/*outputs=*/output,
affineMaps,
getNParallelLoopsAttrs(operandRank),
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
SmallVector<Value, 4> indices;
for (unsigned int i = 0; i < rank; ++i) {
SmallVector<Value> indices;
for (unsigned int i = 0; i < operandRank; ++i) {
Value currentIndex = nestedBuilder.create<linalg::IndexOp>(nestedLoc, i);

if (llvm::is_contained(reverseAxes, i)) {
Value dimSize = nestedBuilder.create<arith::ConstantIndexOp>(nestedLoc, shape[i]);
Value dimSize = nestedBuilder.create<arith::ConstantIndexOp>(nestedLoc, operandShape[i]);
Value one = nestedBuilder.create<arith::ConstantIndexOp>(nestedLoc, 1);
currentIndex = nestedBuilder.create<arith::SubIOp>(nestedLoc, nestedBuilder.create<arith::SubIOp>(nestedLoc, dimSize, one), currentIndex);
}
Expand Down

0 comments on commit 4f9e270

Please sign in to comment.