@@ -4237,28 +4237,35 @@ class StridedSliceBroadcast final
4237
4237
auto dstVecType = llvm::cast<VectorType>(op.getType ());
4238
4238
unsigned dstRank = dstVecType.getRank ();
4239
4239
unsigned rankDiff = dstRank - srcRank;
4240
- // Check if the most inner dimensions of the source of the broadcast are the
4241
- // same as the destination of the extract . If this is the case we can just
4242
- // use a broadcast as the original dimensions are untouched .
4243
- bool lowerDimMatch = true ;
4240
+ // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
4241
+ // (n -> m with n > m) . If they are originally both broadcasted *and*
4242
+ // sliced, this can be simplified to just broadcasting .
4243
+ bool needsSlice = false ;
4244
4244
for (unsigned i = 0 ; i < srcRank; i++) {
4245
- if (srcVecType.getDimSize (i) != dstVecType.getDimSize (i + rankDiff)) {
4246
- lowerDimMatch = false ;
4245
+ if (srcVecType.getDimSize (i) != 1 &&
4246
+ srcVecType.getDimSize (i) != dstVecType.getDimSize (i + rankDiff)) {
4247
+ needsSlice = true ;
4247
4248
break ;
4248
4249
}
4249
4250
}
4250
4251
Value source = broadcast.getSource ();
4251
- // If the inner dimensions don't match, it means we need to extract from the
4252
- // source of the orignal broadcast and then broadcast the extracted value.
4253
- // We also need to handle degenerated cases where the source is effectively
4254
- // just a single scalar.
4255
- bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements () == 1 );
4256
- if (!lowerDimMatch && !isScalarSrc) {
4252
+ if (needsSlice) {
4253
+ SmallVector<int64_t > offsets =
4254
+ getI64SubArray (op.getOffsets (), /* dropFront=*/ rankDiff);
4255
+ SmallVector<int64_t > sizes =
4256
+ getI64SubArray (op.getSizes (), /* dropFront=*/ rankDiff);
4257
+ for (unsigned i = 0 ; i < srcRank; i++) {
4258
+ if (srcVecType.getDimSize (i) == 1 ) {
4259
+ // In case this dimension was broadcasted *and* sliced, the offset
4260
+ // and size need to be updated now that there is no broadcast before
4261
+ // the slice.
4262
+ offsets[i] = 0 ;
4263
+ sizes[i] = 1 ;
4264
+ }
4265
+ }
4257
4266
source = rewriter.create <ExtractStridedSliceOp>(
4258
- op->getLoc (), source,
4259
- getI64SubArray (op.getOffsets (), /* dropFront=*/ rankDiff),
4260
- getI64SubArray (op.getSizes (), /* dropFront=*/ rankDiff),
4261
- getI64SubArray (op.getStrides (), /* dropFront=*/ rankDiff));
4267
+ op->getLoc (), source, offsets, sizes,
4268
+ getI64SubArray (op.getStrides (), /* dropFront=*/ rankDiff));
4262
4269
}
4263
4270
rewriter.replaceOpWithNewOp <BroadcastOp>(op, op.getType (), source);
4264
4271
return success ();
0 commit comments