Skip to content

Commit 5d36708

Browse files
authored
[MLIR][Vector] Fix bug in ExtractStrideSlicesOp canonicalization (#147591)
The pattern would produce an invalid slice when some dimensions were both sliced and broadcast.
1 parent db2eb4d commit 5d36708

File tree

2 files changed

+38
-16
lines changed

2 files changed

+38
-16
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4237,28 +4237,35 @@ class StridedSliceBroadcast final
42374237
auto dstVecType = llvm::cast<VectorType>(op.getType());
42384238
unsigned dstRank = dstVecType.getRank();
42394239
unsigned rankDiff = dstRank - srcRank;
4240-
// Check if the most inner dimensions of the source of the broadcast are the
4241-
// same as the destination of the extract. If this is the case we can just
4242-
// use a broadcast as the original dimensions are untouched.
4243-
bool lowerDimMatch = true;
4240+
// Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
4241+
// (n -> m with n > m). If they are originally both broadcasted *and*
4242+
// sliced, this can be simplified to just broadcasting.
4243+
bool needsSlice = false;
42444244
for (unsigned i = 0; i < srcRank; i++) {
4245-
if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4246-
lowerDimMatch = false;
4245+
if (srcVecType.getDimSize(i) != 1 &&
4246+
srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4247+
needsSlice = true;
42474248
break;
42484249
}
42494250
}
42504251
Value source = broadcast.getSource();
4251-
// If the inner dimensions don't match, it means we need to extract from the
4252-
// source of the orignal broadcast and then broadcast the extracted value.
4253-
// We also need to handle degenerated cases where the source is effectively
4254-
// just a single scalar.
4255-
bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
4256-
if (!lowerDimMatch && !isScalarSrc) {
4252+
if (needsSlice) {
4253+
SmallVector<int64_t> offsets =
4254+
getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
4255+
SmallVector<int64_t> sizes =
4256+
getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
4257+
for (unsigned i = 0; i < srcRank; i++) {
4258+
if (srcVecType.getDimSize(i) == 1) {
4259+
// In case this dimension was broadcasted *and* sliced, the offset
4260+
// and size need to be updated now that there is no broadcast before
4261+
// the slice.
4262+
offsets[i] = 0;
4263+
sizes[i] = 1;
4264+
}
4265+
}
42574266
source = rewriter.create<ExtractStridedSliceOp>(
4258-
op->getLoc(), source,
4259-
getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
4260-
getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
4261-
getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
4267+
op->getLoc(), source, offsets, sizes,
4268+
getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
42624269
}
42634270
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
42644271
return success();

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1379,6 +1379,21 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
13791379

13801380
// -----
13811381

1382+
// Check the case where the same dimension is both broadcasted and sliced
1383+
// CHECK-LABEL: func @extract_strided_broadcast5
1384+
// CHECK-SAME: (%[[ARG:.+]]: vector<2x1xf32>)
1385+
// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<2x1xf32> to vector<2x4xf32>
1386+
// CHECK: return %[[V]]
1387+
func.func @extract_strided_broadcast5(%arg0: vector<2x1xf32>) -> vector<2x4xf32> {
1388+
%0 = vector.broadcast %arg0 : vector<2x1xf32> to vector<2x8xf32>
1389+
%1 = vector.extract_strided_slice %0
1390+
{offsets = [0, 4], sizes = [2, 4], strides = [1, 1]}
1391+
: vector<2x8xf32> to vector<2x4xf32>
1392+
return %1 : vector<2x4xf32>
1393+
}
1394+
1395+
// -----
1396+
13821397
// CHECK-LABEL: consecutive_shape_cast
13831398
// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16>
13841399
// CHECK-NEXT: return %[[C]] : vector<4x4xf16>

0 commit comments

Comments
 (0)