Skip to content

[WASM] Constant fold SIMD wasm intrinsics: any/alltrue #148074

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion llvm/lib/Analysis/ConstantFolding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1655,6 +1655,8 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
case Intrinsic::arm_mve_vctp32:
case Intrinsic::arm_mve_vctp64:
case Intrinsic::aarch64_sve_convert_from_svbool:
case Intrinsic::wasm_alltrue:
case Intrinsic::wasm_anytrue:
// WebAssembly float semantics are always known
case Intrinsic::wasm_trunc_signed:
case Intrinsic::wasm_trunc_unsigned:
Expand Down Expand Up @@ -2832,7 +2834,8 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,

// Support ConstantVector in case we have an Undef in the top.
if (isa<ConstantVector>(Operands[0]) ||
isa<ConstantDataVector>(Operands[0])) {
isa<ConstantDataVector>(Operands[0]) ||
isa<ConstantAggregateZero>(Operands[0])) {
auto *Op = cast<Constant>(Operands[0]);
switch (IntrinsicID) {
default: break;
Expand All @@ -2856,6 +2859,20 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,
/*roundTowardZero=*/true, Ty,
/*IsSigned*/true);
break;

case Intrinsic::wasm_anytrue:
return Op->isZeroValue() ? ConstantInt::get(Ty, 0)
: ConstantInt::get(Ty, 1);

case Intrinsic::wasm_alltrue:
// Check each element individually
unsigned E = cast<FixedVectorType>(Op->getType())->getNumElements();
for (unsigned I = 0; I != E; ++I)
if (Constant *Elt = Op->getAggregateElement(I))
if (Elt->isZeroValue())
return ConstantInt::get(Ty, 0);

return ConstantInt::get(Ty, 1);
}
}

Expand Down
127 changes: 127 additions & 0 deletions llvm/test/CodeGen/WebAssembly/const_fold_simd_intrinsics.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5

; RUN: opt -passes=instcombine -S < %s | FileCheck %s
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test/CodeGen directory is for tests that run llc. Because this change is actually in the middle end and the test runs opt, it should be in llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/

I think we also only need to run instsimplify for this since we're not creating any new instructions:

Suggested change
; RUN: opt -passes=instcombine -S < %s | FileCheck %s
; RUN: opt -passes=instsimplify -S < %s | FileCheck %s


; Test that calls to the wasm SIMD any_true/all_true intrinsics are constant folded

; all_non_zero: a vector whose lanes are all non-zero
; not_all_non_zero: a vector whose lanes are all non-zero, except for a 0 in the first lane

; all_zero: a vector whose lanes are all zero
; not_all_zero: a vector whose lanes are all zero, except for a non-zero in the first lane

target triple = "wasm32-unknown-unknown"

; alltrue must fold to 0 when at least one lane is zero: every vector below has
; a zero in lane 0, so each call is expected to constant-fold to i32 0.
define void @all_true_splat_not_all_non_zero(ptr %ptr) {
; CHECK-LABEL: define void @all_true_splat_not_all_non_zero(
; CHECK-SAME: ptr [[PTR:%.*]]) {
; CHECK-NEXT: store volatile i32 0, ptr [[PTR]], align 4
; CHECK-NEXT: store volatile i32 0, ptr [[PTR]], align 4
; CHECK-NEXT: store volatile i32 0, ptr [[PTR]], align 4
; CHECK-NEXT: store volatile i32 0, ptr [[PTR]], align 4
; CHECK-NEXT: store volatile i32 0, ptr [[PTR]], align 4
; CHECK-NEXT: ret void
;
; i8 lanes: zero in lane 0, remaining lanes non-zero.
%a = call i32 @llvm.wasm.alltrue(<16 x i8> <i8 0, i8 1, i8 2, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>)
store volatile i32 %a, ptr %ptr

; i16 lanes: zero in lane 0.
%b = call i32 @llvm.wasm.alltrue(<8 x i16> <i16 0, i16 1, i16 2, i16 1, i16 1, i16 1, i16 1, i16 1>)
store volatile i32 %b, ptr %ptr

; i32 lanes: zero in lane 0.
%c = call i32 @llvm.wasm.alltrue(<4 x i32> <i32 0, i32 1, i32 1, i32 1>)
store volatile i32 %c, ptr %ptr

; i64 lanes: zero in lane 0.
%d = call i32 @llvm.wasm.alltrue(<2 x i64> <i64 0, i64 42>)
store volatile i32 %d, ptr %ptr

; NOTE(review): <4 x i64> is 256 bits — wider than wasm's 128-bit SIMD
; vectors; presumably here to show the middle-end fold is width-agnostic.
%e = call i32 @llvm.wasm.alltrue(<4 x i64> <i64 0, i64 1, i64 1, i64 1>)
store volatile i32 %e, ptr %ptr

ret void
}

; alltrue must fold to 1 when every lane is non-zero (not necessarily 1 or -1):
; each call below is expected to constant-fold to i32 1.
define void @all_true_splat_all_non_zero(ptr %ptr) {
; CHECK-LABEL: define void @all_true_splat_all_non_zero(
; CHECK-SAME: ptr [[PTR:%.*]]) {
; CHECK-NEXT: store volatile i32 1, ptr [[PTR]], align 4
; CHECK-NEXT: store volatile i32 1, ptr [[PTR]], align 4
; CHECK-NEXT: store volatile i32 1, ptr [[PTR]], align 4
; CHECK-NEXT: store volatile i32 1, ptr [[PTR]], align 4
; CHECK-NEXT: store volatile i32 1, ptr [[PTR]], align 4
; CHECK-NEXT: ret void
;
; i8 lanes: all non-zero (includes a 3, so not a splat of 1).
%a = call i32 @llvm.wasm.alltrue(<16 x i8> <i8 1, i8 3, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>)
store volatile i32 %a, ptr %ptr

; i16 lanes: all non-zero.
%b = call i32 @llvm.wasm.alltrue(<8 x i16> <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>)
store volatile i32 %b, ptr %ptr

; i32 lanes: all non-zero.
%c = call i32 @llvm.wasm.alltrue(<4 x i32> <i32 1, i32 1, i32 1, i32 1>)
store volatile i32 %c, ptr %ptr

; i64 lanes: all non-zero.
%d = call i32 @llvm.wasm.alltrue(<2 x i64> <i64 2, i64 2>)
store volatile i32 %d, ptr %ptr

; NOTE(review): <4 x i64> is wider than wasm's 128-bit SIMD vectors;
; presumably exercises the fold on non-standard widths.
%e = call i32 @llvm.wasm.alltrue(<4 x i64> <i64 1, i64 2, i64 1, i64 1>)
store volatile i32 %e, ptr %ptr

ret void
}


; anytrue must fold to 0 when every lane is zero: each call below is expected
; to constant-fold to i32 0.
define void @any_true_splat_all_zero(ptr %ptr) {
; CHECK-LABEL: define void @any_true_splat_all_zero(
; CHECK-SAME: ptr [[PTR:%.*]]) {
; CHECK-NEXT: store volatile i32 0, ptr [[PTR]], align 4
; CHECK-NEXT: store volatile i32 0, ptr [[PTR]], align 4
; CHECK-NEXT: store volatile i32 0, ptr [[PTR]], align 4
; CHECK-NEXT: store volatile i32 0, ptr [[PTR]], align 4
; CHECK-NEXT: store volatile i32 0, ptr [[PTR]], align 4
; CHECK-NEXT: ret void
;
; i8 lanes: all zero.
%a = call i32 @llvm.wasm.anytrue(<16 x i8> <i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0>)
store volatile i32 %a, ptr %ptr

; i16 lanes: all zero.
%b = call i32 @llvm.wasm.anytrue(<8 x i16> <i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0>)
store volatile i32 %b, ptr %ptr

; i32 lanes: all zero.
%c = call i32 @llvm.wasm.anytrue(<4 x i32> <i32 0, i32 0, i32 0, i32 0>)
store volatile i32 %c, ptr %ptr

; i64 lanes: all zero.
%d = call i32 @llvm.wasm.anytrue(<2 x i64> <i64 0, i64 0>)
store volatile i32 %d, ptr %ptr

; NOTE(review): <4 x i64> is wider than wasm's 128-bit SIMD vectors;
; presumably exercises the fold on non-standard widths.
%e = call i32 @llvm.wasm.anytrue(<4 x i64> <i64 0, i64 0, i64 0, i64 0>)
store volatile i32 %e, ptr %ptr

ret void
}


; anytrue must fold to 1 when at least one lane is non-zero: every vector
; below has a non-zero value in lane 0, so each call is expected to
; constant-fold to i32 1.
define void @any_true_splat_not_all_zero(ptr %ptr) {
; CHECK-LABEL: define void @any_true_splat_not_all_zero(
; CHECK-SAME: ptr [[PTR:%.*]]) {
; CHECK-NEXT: store volatile i32 1, ptr [[PTR]], align 4
; CHECK-NEXT: store volatile i32 1, ptr [[PTR]], align 4
; CHECK-NEXT: store volatile i32 1, ptr [[PTR]], align 4
; CHECK-NEXT: store volatile i32 1, ptr [[PTR]], align 4
; CHECK-NEXT: store volatile i32 1, ptr [[PTR]], align 4
; CHECK-NEXT: ret void
;
; i8 lanes: single non-zero in lane 0.
%a = call i32 @llvm.wasm.anytrue(<16 x i8> <i8 1, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0>)
store volatile i32 %a, ptr %ptr

; i16 lanes: single non-zero (3, not 1) in lane 0.
%b = call i32 @llvm.wasm.anytrue(<8 x i16> <i16 3, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0>)
store volatile i32 %b, ptr %ptr

; i32 lanes: single non-zero in lane 0.
%c = call i32 @llvm.wasm.anytrue(<4 x i32> <i32 1, i32 0, i32 0, i32 0>)
store volatile i32 %c, ptr %ptr

; i64 lanes: negative (all-ones) value still counts as non-zero.
%d = call i32 @llvm.wasm.anytrue(<2 x i64> <i64 -1, i64 0>)
store volatile i32 %d, ptr %ptr

; NOTE(review): <4 x i64> is wider than wasm's 128-bit SIMD vectors;
; presumably exercises the fold on non-standard widths.
%e = call i32 @llvm.wasm.anytrue(<4 x i64> <i64 2, i64 0, i64 0, i64 0>)
store volatile i32 %e, ptr %ptr

ret void
}