From 3048c6d54aaa3f410b45af4c0ae3ad26c9f9b114 Mon Sep 17 00:00:00 2001
From: Zach Goldthorpe
Date: Tue, 8 Jul 2025 17:52:34 -0500
Subject: [PATCH 1/2] Added patch to fold `lshr + zext + shl` patterns

---
 .../InstCombine/InstCombineShifts.cpp         | 49 ++++++++++++-
 .../Analysis/ValueTracking/numsignbits-shl.ll |  6 +-
 .../Transforms/InstCombine/iX-ext-split.ll    |  6 +-
 .../InstCombine/shifts-around-zext.ll         | 69 +++++++++++++++++++
 4 files changed, 122 insertions(+), 8 deletions(-)
 create mode 100644 llvm/test/Transforms/InstCombine/shifts-around-zext.ll

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 550f095b26ba4a..b0b1301cd25803 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -978,6 +978,47 @@ Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) {
   return new ZExtInst(Overflow, Ty);
 }
 
+/// If the operand of a zext-ed left shift \p V is a logically right-shifted
+/// value, try to fold the opposing shifts.
+static Instruction *foldShrThroughZExtedShl(Type *DestTy, Value *V,
+                                            unsigned ShlAmt,
+                                            InstCombinerImpl &IC,
+                                            const DataLayout &DL) {
+  auto *I = dyn_cast<Instruction>(V);
+  if (!I)
+    return nullptr;
+
+  // Dig through operations until the first shift.
+  while (!I->isShift())
+    if (!match(I, m_BinOp(m_OneUse(m_Instruction(I)), m_Constant())))
+      return nullptr;
+
+  // Fold only if the inner shift is a logical right-shift.
+  uint64_t InnerShrAmt;
+  if (!match(I, m_LShr(m_Value(), m_ConstantInt(InnerShrAmt))))
+    return nullptr;
+
+  if (InnerShrAmt >= ShlAmt) {
+    const uint64_t ReducedShrAmt = InnerShrAmt - ShlAmt;
+    if (!canEvaluateShifted(V, ReducedShrAmt, /*IsLeftShift=*/false, IC,
+                            nullptr))
+      return nullptr;
+    Value *NewInner =
+        getShiftedValue(V, ReducedShrAmt, /*isLeftShift=*/false, IC, DL);
+    return new ZExtInst(NewInner, DestTy);
+  }
+
+  if (!canEvaluateShifted(V, InnerShrAmt, /*IsLeftShift=*/true, IC, nullptr))
+    return nullptr;
+
+  const uint64_t ReducedShlAmt = ShlAmt - InnerShrAmt;
+  Value *NewInner =
+      getShiftedValue(V, InnerShrAmt, /*isLeftShift=*/true, IC, DL);
+  Value *NewZExt = IC.Builder.CreateZExt(NewInner, DestTy);
+  return BinaryOperator::CreateShl(NewZExt,
+                                   ConstantInt::get(DestTy, ReducedShlAmt));
+}
+
 // Try to set nuw/nsw flags on shl or exact flag on lshr/ashr using knownbits.
 static bool setShiftFlags(BinaryOperator &I, const SimplifyQuery &Q) {
   assert(I.isShift() && "Expected a shift as input");
@@ -1062,14 +1103,18 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
   if (match(Op1, m_APInt(C))) {
     unsigned ShAmtC = C->getZExtValue();
 
-    // shl (zext X), C --> zext (shl X, C)
-    // This is only valid if X would have zeros shifted out.
     Value *X;
     if (match(Op0, m_OneUse(m_ZExt(m_Value(X))))) {
+      // shl (zext X), C --> zext (shl X, C)
+      // This is only valid if X would have zeros shifted out.
       unsigned SrcWidth = X->getType()->getScalarSizeInBits();
       if (ShAmtC < SrcWidth &&
           MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmtC), &I))
         return new ZExtInst(Builder.CreateShl(X, ShAmtC), Ty);
+
+      // Otherwise, try to cancel the outer shl with a lshr inside the zext.
+      if (Instruction *V = foldShrThroughZExtedShl(Ty, X, ShAmtC, *this, DL))
+        return V;
     }
 
     // (X >> C) << C --> X & (-1 << C)
diff --git a/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll b/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll
index 5224d75a157d5b..8330fd09090c89 100644
--- a/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll
+++ b/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll
@@ -101,9 +101,9 @@ define void @numsignbits_shl_zext_extended_bits_remains(i8 %x) {
 define void @numsignbits_shl_zext_all_bits_shifted_out(i8 %x) {
 ; CHECK-LABEL: define void @numsignbits_shl_zext_all_bits_shifted_out(
 ; CHECK-SAME: i8 [[X:%.*]]) {
-; CHECK-NEXT: [[ASHR:%.*]] = lshr i8 [[X]], 5
-; CHECK-NEXT: [[ZEXT:%.*]] = zext nneg i8 [[ASHR]] to i16
-; CHECK-NEXT: [[NSB1:%.*]] = shl i16 [[ZEXT]], 14
+; CHECK-NEXT: [[ASHR:%.*]] = and i8 [[X]], 96
+; CHECK-NEXT: [[TMP1:%.*]] = zext nneg i8 [[ASHR]] to i16
+; CHECK-NEXT: [[NSB1:%.*]] = shl nuw i16 [[TMP1]], 9
 ; CHECK-NEXT: [[AND14:%.*]] = and i16 [[NSB1]], 16384
 ; CHECK-NEXT: [[ADD14:%.*]] = add i16 [[AND14]], [[NSB1]]
 ; CHECK-NEXT: call void @escape(i16 [[ADD14]])
diff --git a/llvm/test/Transforms/InstCombine/iX-ext-split.ll b/llvm/test/Transforms/InstCombine/iX-ext-split.ll
index fc804df0e4becb..b8e056725f1221 100644
--- a/llvm/test/Transforms/InstCombine/iX-ext-split.ll
+++ b/llvm/test/Transforms/InstCombine/iX-ext-split.ll
@@ -197,9 +197,9 @@ define i128 @i128_ext_split_neg4(i32 %x) {
 ; CHECK-NEXT: [[ENTRY:.*:]]
 ; CHECK-NEXT: [[LOWERSRC:%.*]] = sext i32 [[X]] to i64
 ; CHECK-NEXT: [[LO:%.*]] = zext i64 [[LOWERSRC]] to i128
-; CHECK-NEXT: [[SIGN:%.*]] = lshr i32 [[X]], 31
-; CHECK-NEXT: [[WIDEN:%.*]] = zext nneg i32 [[SIGN]] to i128
-; CHECK-NEXT: [[HI:%.*]] = shl nuw nsw i128 [[WIDEN]], 64
+; CHECK-NEXT: [[SIGN:%.*]] = and i32 [[X]], -2147483648
+; CHECK-NEXT: [[TMP0:%.*]] = zext i32 [[SIGN]] to i128
+; CHECK-NEXT: [[HI:%.*]] = shl nuw nsw i128 [[TMP0]], 33
 ; CHECK-NEXT: [[RES:%.*]] = or disjoint i128 [[HI]], [[LO]]
 ; CHECK-NEXT: ret i128 [[RES]]
 ;
diff --git a/llvm/test/Transforms/InstCombine/shifts-around-zext.ll b/llvm/test/Transforms/InstCombine/shifts-around-zext.ll
new file mode 100644
index 00000000000000..517783fcbcb5c1
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/shifts-around-zext.ll
@@ -0,0 +1,69 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=instcombine %s | FileCheck %s
+
+define i64 @simple(i32 %x) {
+; CHECK-LABEL: define i64 @simple(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT: [[LSHR:%.*]] = and i32 [[X]], -256
+; CHECK-NEXT: [[TMP1:%.*]] = zext i32 [[LSHR]] to i64
+; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i64 [[TMP1]], 24
+; CHECK-NEXT: ret i64 [[SHL]]
+;
+  %lshr = lshr i32 %x, 8
+  %zext = zext i32 %lshr to i64
+  %shl = shl i64 %zext, 32
+  ret i64 %shl
+}
+
+;; u0xff0 = 4080
+define i64 @masked(i32 %x) {
+; CHECK-LABEL: define i64 @masked(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT: [[MASK:%.*]] = and i32 [[X]], 4080
+; CHECK-NEXT: [[TMP1:%.*]] = zext nneg i32 [[MASK]] to i64
+; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i64 [[TMP1]], 44
+; CHECK-NEXT: ret i64 [[SHL]]
+;
+  %lshr = lshr i32 %x, 4
+  %mask = and i32 %lshr, u0xff
+  %zext = zext i32 %mask to i64
+  %shl = shl i64 %zext, 48
+  ret i64 %shl
+}
+
+define i64 @combine(i32 %lower, i32 %upper) {
+; CHECK-LABEL: define i64 @combine(
+; CHECK-SAME: i32 [[LOWER:%.*]], i32 [[UPPER:%.*]]) {
+; CHECK-NEXT: [[BASE:%.*]] = zext i32 [[LOWER]] to i64
+; CHECK-NEXT: [[TMP1:%.*]] = zext i32 [[UPPER]] to i64
+; CHECK-NEXT: [[TMP2:%.*]] = shl nuw i64 [[TMP1]], 32
+; CHECK-NEXT: [[O_3:%.*]] = or disjoint i64 [[TMP2]], [[BASE]]
+; CHECK-NEXT: ret i64 [[O_3]]
+;
+  %base = zext i32 %lower to i64
+
+  %u.0 = and i32 %upper, u0xff
+  %z.0 = zext i32 %u.0 to i64
+  %s.0 = shl i64 %z.0, 32
+  %o.0 = or i64 %base, %s.0
+
+  %r.1 = lshr i32 %upper, 8
+  %u.1 = and i32 %r.1, u0xff
+  %z.1 = zext i32 %u.1 to i64
+  %s.1 = shl i64 %z.1, 40
+  %o.1 = or i64 %o.0, %s.1
+
+  %r.2 = lshr i32 %upper, 16
+  %u.2 = and i32 %r.2, u0xff
+  %z.2 = zext i32 %u.2 to i64
+  %s.2 = shl i64 %z.2, 48
+  %o.2 = or i64 %o.1, %s.2
+
+  %r.3 = lshr i32 %upper, 24
+  %u.3 = and i32 %r.3, u0xff
+  %z.3 = zext i32 %u.3 to i64
+  %s.3 = shl i64 %z.3, 56
+  %o.3 = or i64 %o.2, %s.3
+
+  ret i64 %o.3
+}

From 8ea66689bca1005b09f842500d7ece5ad4881386 Mon Sep 17 00:00:00 2001
From: Zach Goldthorpe
Date: Fri, 11 Jul 2025 15:33:46 -0500
Subject: [PATCH 2/2] Incorporated straightforward reviewer feedback.

---
 .../InstCombine/InstCombineShifts.cpp         | 48 +++++++++++--------
 .../InstCombine/shifts-around-zext.ll         | 22 +++++++--
 2 files changed, 46 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index b0b1301cd25803..edcd963e3a7db7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -978,45 +978,53 @@ Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) {
   return new ZExtInst(Overflow, Ty);
 }
 
-/// If the operand of a zext-ed left shift \p V is a logically right-shifted
-/// value, try to fold the opposing shifts.
-static Instruction *foldShrThroughZExtedShl(Type *DestTy, Value *V,
+/// If the operand \p Op of a zext-ed left shift \p I is a logically
+/// right-shifted value, try to fold the opposing shifts.
+static Instruction *foldShrThroughZExtedShl(BinaryOperator &I, Value *Op,
                                             unsigned ShlAmt,
                                             InstCombinerImpl &IC,
                                             const DataLayout &DL) {
-  auto *I = dyn_cast<Instruction>(V);
-  if (!I)
+  Type *DestTy = I.getType();
+
+  auto *Inner = dyn_cast<Instruction>(Op);
+  if (!Inner)
     return nullptr;
 
   // Dig through operations until the first shift.
-  while (!I->isShift())
-    if (!match(I, m_BinOp(m_OneUse(m_Instruction(I)), m_Constant())))
+  while (!Inner->isShift())
+    if (!match(Inner, m_BinOp(m_OneUse(m_Instruction(Inner)), m_Constant())))
       return nullptr;
 
   // Fold only if the inner shift is a logical right-shift.
- uint64_t InnerShrAmt; - if (!match(I, m_LShr(m_Value(), m_ConstantInt(InnerShrAmt)))) + const APInt *InnerShrConst; + if (!match(Inner, m_LShr(m_Value(), m_APInt(InnerShrConst)))) return nullptr; + const uint64_t InnerShrAmt = InnerShrConst->getZExtValue(); if (InnerShrAmt >= ShlAmt) { const uint64_t ReducedShrAmt = InnerShrAmt - ShlAmt; - if (!canEvaluateShifted(V, ReducedShrAmt, /*IsLeftShift=*/false, IC, + if (!canEvaluateShifted(Op, ReducedShrAmt, /*IsLeftShift=*/false, IC, nullptr)) return nullptr; - Value *NewInner = - getShiftedValue(V, ReducedShrAmt, /*isLeftShift=*/false, IC, DL); - return new ZExtInst(NewInner, DestTy); + Value *NewOp = + getShiftedValue(Op, ReducedShrAmt, /*isLeftShift=*/false, IC, DL); + return new ZExtInst(NewOp, DestTy); } - if (!canEvaluateShifted(V, InnerShrAmt, /*IsLeftShift=*/true, IC, nullptr)) + if (!canEvaluateShifted(Op, InnerShrAmt, /*IsLeftShift=*/true, IC, nullptr)) return nullptr; const uint64_t ReducedShlAmt = ShlAmt - InnerShrAmt; - Value *NewInner = - getShiftedValue(V, InnerShrAmt, /*isLeftShift=*/true, IC, DL); - Value *NewZExt = IC.Builder.CreateZExt(NewInner, DestTy); - return BinaryOperator::CreateShl(NewZExt, - ConstantInt::get(DestTy, ReducedShlAmt)); + Value *NewOp = getShiftedValue(Op, InnerShrAmt, /*isLeftShift=*/true, IC, DL); + Value *NewZExt = IC.Builder.CreateZExt(NewOp, DestTy); + NewZExt->takeName(I.getOperand(0)); + auto *NewShl = BinaryOperator::CreateShl( + NewZExt, ConstantInt::get(DestTy, ReducedShlAmt)); + + // New shl inherits all flags from the original shl instruction. + NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); + NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + return NewShl; } // Try to set nuw/nsw flags on shl or exact flag on lshr/ashr using knownbits. @@ -1113,7 +1121,7 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { return new ZExtInst(Builder.CreateShl(X, ShAmtC), Ty); // Otherwise, try to cancel the outer shl with a lshr inside the zext. 
- if (Instruction *V = foldShrThroughZExtedShl(Ty, X, ShAmtC, *this, DL)) + if (Instruction *V = foldShrThroughZExtedShl(I, X, ShAmtC, *this, DL)) return V; } diff --git a/llvm/test/Transforms/InstCombine/shifts-around-zext.ll b/llvm/test/Transforms/InstCombine/shifts-around-zext.ll index 517783fcbcb5c1..818e7b0fc735c8 100644 --- a/llvm/test/Transforms/InstCombine/shifts-around-zext.ll +++ b/llvm/test/Transforms/InstCombine/shifts-around-zext.ll @@ -5,8 +5,8 @@ define i64 @simple(i32 %x) { ; CHECK-LABEL: define i64 @simple( ; CHECK-SAME: i32 [[X:%.*]]) { ; CHECK-NEXT: [[LSHR:%.*]] = and i32 [[X]], -256 -; CHECK-NEXT: [[TMP1:%.*]] = zext i32 [[LSHR]] to i64 -; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i64 [[TMP1]], 24 +; CHECK-NEXT: [[ZEXT:%.*]] = zext i32 [[LSHR]] to i64 +; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i64 [[ZEXT]], 24 ; CHECK-NEXT: ret i64 [[SHL]] ; %lshr = lshr i32 %x, 8 @@ -20,8 +20,8 @@ define i64 @masked(i32 %x) { ; CHECK-LABEL: define i64 @masked( ; CHECK-SAME: i32 [[X:%.*]]) { ; CHECK-NEXT: [[MASK:%.*]] = and i32 [[X]], 4080 -; CHECK-NEXT: [[TMP1:%.*]] = zext nneg i32 [[MASK]] to i64 -; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i64 [[TMP1]], 44 +; CHECK-NEXT: [[ZEXT:%.*]] = zext nneg i32 [[MASK]] to i64 +; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i64 [[ZEXT]], 44 ; CHECK-NEXT: ret i64 [[SHL]] ; %lshr = lshr i32 %x, 4 @@ -67,3 +67,17 @@ define i64 @combine(i32 %lower, i32 %upper) { ret i64 %o.3 } + +define <2 x i64> @simple.vec(<2 x i32> %v) { +; CHECK-LABEL: define <2 x i64> @simple.vec( +; CHECK-SAME: <2 x i32> [[V:%.*]]) { +; CHECK-NEXT: [[LSHR:%.*]] = and <2 x i32> [[V]], splat (i32 -256) +; CHECK-NEXT: [[ZEXT:%.*]] = zext <2 x i32> [[LSHR]] to <2 x i64> +; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw <2 x i64> [[ZEXT]], splat (i64 24) +; CHECK-NEXT: ret <2 x i64> [[SHL]] +; + %lshr = lshr <2 x i32> %v, splat(i32 8) + %zext = zext <2 x i32> %lshr to <2 x i64> + %shl = shl <2 x i64> %zext, splat(i64 32) + ret <2 x i64> %shl +}
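
A standalone sketch of the fold under review (illustrative only, not part of
either patch; the function and value names here are hypothetical, and the
"after" body simply restates the CHECK lines of the @simple test above): the
opposing shifts cancel up to their difference, so only the net left shift of
32 - 8 = 24 bits survives, and the inner lshr degenerates into the mask it
implied on the low bits of %x.

; Input shape: lshr, widen with zext, then shl in the wide type.
define i64 @before(i32 %x) {
  %lshr = lshr i32 %x, 8          ; discards the low 8 bits of %x
  %zext = zext i32 %lshr to i64
  %shl = shl i64 %zext, 32        ; net movement: 32 - 8 = 24 bits left
  ret i64 %shl
}

; Folded shape: the lshr/shl pair becomes a mask plus the reduced shift,
; with the nuw/nsw flags taken from the test's CHECK output.
define i64 @after(i32 %x) {
  %mask = and i32 %x, -256        ; (%x lshr 8) shl 8 == %x & -256
  %zext = zext i32 %mask to i64
  %shl = shl nuw nsw i64 %zext, 24
  ret i64 %shl
}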