Skip to content

Commit

Permalink
[X86] Teach combineBitcastvxi1 to prefer movmsk on avx512 in more cases
Browse files Browse the repository at this point in the history
If the input to the bitcast is a sign-bit test, it makes sense to
directly use vpmovmskb or vmovmskps/pd. This removes the need to
copy the sign bits to a k-register and then to a GPR.

Fixes PR46200.

Differential Revision: https://reviews.llvm.org/D81327
  • Loading branch information
topperc committed Jun 13, 2020
1 parent 6b4b660 commit cb5072d
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 143 deletions.
63 changes: 58 additions & 5 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37668,14 +37668,26 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, EVT VT, SDValue Src,
// movmskb even with avx512. This will be better than truncating to vXi1 and
// using a kmov. This can especially help KNL if the input is a v16i8/v32i8
// vpcmpeqb/vpcmpgtb.
bool IsTruncated = Src.getOpcode() == ISD::TRUNCATE && Src.hasOneUse() &&
(Src.getOperand(0).getValueType() == MVT::v16i8 ||
Src.getOperand(0).getValueType() == MVT::v32i8 ||
Src.getOperand(0).getValueType() == MVT::v64i8);
bool PreferMovMsk = Src.getOpcode() == ISD::TRUNCATE && Src.hasOneUse() &&
(Src.getOperand(0).getValueType() == MVT::v16i8 ||
Src.getOperand(0).getValueType() == MVT::v32i8 ||
Src.getOperand(0).getValueType() == MVT::v64i8);

// Prefer movmsk for AVX512 for (bitcast (setlt X, 0)) which can be handled
// directly with vpmovmskb/vmovmskps/vmovmskpd.
if (Src.getOpcode() == ISD::SETCC && Src.hasOneUse() &&
cast<CondCodeSDNode>(Src.getOperand(2))->get() == ISD::SETLT &&
ISD::isBuildVectorAllZeros(Src.getOperand(1).getNode())) {
EVT CmpVT = Src.getOperand(0).getValueType();
EVT EltVT = CmpVT.getVectorElementType();
if (CmpVT.getSizeInBits() <= 256 &&
(EltVT == MVT::i8 || EltVT == MVT::i32 || EltVT == MVT::i64))
PreferMovMsk = true;
}

// With AVX512 vxi1 types are legal and we prefer using k-regs.
// MOVMSK is supported in SSE2 or later.
if (!Subtarget.hasSSE2() || (Subtarget.hasAVX512() && !IsTruncated))
if (!Subtarget.hasSSE2() || (Subtarget.hasAVX512() && !PreferMovMsk))
return SDValue();

// There are MOVMSK flavors for types v16i8, v32i8, v4f32, v8f32, v4f64 and
Expand Down Expand Up @@ -38169,6 +38181,47 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
return DAG.getConstant(0, SDLoc(N0), VT);
}

// Look for MOVMSK that is maybe truncated and then bitcasted to vXi1.
// Turn it into a sign bit compare that produces a k-register. This avoids
// a trip through a GPR.
if (Subtarget.hasAVX512() && SrcVT.isScalarInteger() &&
VT.isVector() && VT.getVectorElementType() == MVT::i1 &&
isPowerOf2_32(VT.getVectorNumElements())) {
unsigned NumElts = VT.getVectorNumElements();
SDValue Src = N0;

// Peek through truncate.
if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse())
Src = N0.getOperand(0);

if (Src.getOpcode() == X86ISD::MOVMSK && Src.hasOneUse()) {
SDValue MovmskIn = Src.getOperand(0);
MVT MovmskVT = MovmskIn.getSimpleValueType();
unsigned MovMskElts = MovmskVT.getVectorNumElements();

// We allow extra bits of the movmsk to be used since they are known zero.
// We can't convert a VPMOVMSKB without avx512bw.
if (MovMskElts <= NumElts &&
(Subtarget.hasBWI() || MovmskVT.getVectorElementType() != MVT::i8)) {
EVT IntVT = EVT(MovmskVT).changeVectorElementTypeToInteger();
MovmskIn = DAG.getBitcast(IntVT, MovmskIn);
SDLoc dl(N);
MVT CmpVT = MVT::getVectorVT(MVT::i1, MovMskElts);
SDValue Cmp = DAG.getSetCC(dl, CmpVT, MovmskIn,
DAG.getConstant(0, dl, IntVT), ISD::SETLT);
if (EVT(CmpVT) == VT)
return Cmp;

// Pad with zeroes up to original VT to replace the zeroes that were
// being used from the MOVMSK.
unsigned NumConcats = NumElts / MovMskElts;
SmallVector<SDValue, 4> Ops(NumConcats, DAG.getConstant(0, dl, CmpVT));
Ops[0] = Cmp;
return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, Ops);
}
}
}

// Try to remove bitcasts from input and output of mask arithmetic to
// remove GPR<->K-register crossings.
if (SDValue V = combineCastedMaskArithmetic(N, DAG, DCI, Subtarget))
Expand Down
6 changes: 2 additions & 4 deletions llvm/test/CodeGen/X86/avx512bwvl-intrinsics-upgrade.ll
Original file line number Diff line number Diff line change
Expand Up @@ -6334,8 +6334,7 @@ declare i16 @llvm.x86.avx512.cvtb2mask.128(<16 x i8>)
define i16@test_int_x86_avx512_cvtb2mask_128(<16 x i8> %x0) {
; CHECK-LABEL: test_int_x86_avx512_cvtb2mask_128:
; CHECK: # %bb.0:
; CHECK-NEXT: vpmovb2m %xmm0, %k0 # encoding: [0x62,0xf2,0x7e,0x08,0x29,0xc0]
; CHECK-NEXT: kmovd %k0, %eax # encoding: [0xc5,0xfb,0x93,0xc0]
; CHECK-NEXT: vpmovmskb %xmm0, %eax # encoding: [0xc5,0xf9,0xd7,0xc0]
; CHECK-NEXT: # kill: def $ax killed $ax killed $eax
; CHECK-NEXT: ret{{[l|q]}} # encoding: [0xc3]
%res = call i16 @llvm.x86.avx512.cvtb2mask.128(<16 x i8> %x0)
Expand All @@ -6347,8 +6346,7 @@ declare i32 @llvm.x86.avx512.cvtb2mask.256(<32 x i8>)
define i32@test_int_x86_avx512_cvtb2mask_256(<32 x i8> %x0) {
; CHECK-LABEL: test_int_x86_avx512_cvtb2mask_256:
; CHECK: # %bb.0:
; CHECK-NEXT: vpmovb2m %ymm0, %k0 # encoding: [0x62,0xf2,0x7e,0x28,0x29,0xc0]
; CHECK-NEXT: kmovd %k0, %eax # encoding: [0xc5,0xfb,0x93,0xc0]
; CHECK-NEXT: vpmovmskb %ymm0, %eax # encoding: [0xc5,0xfd,0xd7,0xc0]
; CHECK-NEXT: vzeroupper # encoding: [0xc5,0xf8,0x77]
; CHECK-NEXT: ret{{[l|q]}} # encoding: [0xc3]
%res = call i32 @llvm.x86.avx512.cvtb2mask.256(<32 x i8> %x0)
Expand Down
3 changes: 1 addition & 2 deletions llvm/test/CodeGen/X86/avx512dqvl-intrinsics-upgrade.ll
Original file line number Diff line number Diff line change
Expand Up @@ -2764,8 +2764,7 @@ declare i8 @llvm.x86.avx512.cvtd2mask.256(<8 x i32>)
define i8@test_int_x86_avx512_cvtd2mask_256(<8 x i32> %x0) {
; CHECK-LABEL: test_int_x86_avx512_cvtd2mask_256:
; CHECK: # %bb.0:
; CHECK-NEXT: vpmovd2m %ymm0, %k0 # encoding: [0x62,0xf2,0x7e,0x28,0x39,0xc0]
; CHECK-NEXT: kmovw %k0, %eax # encoding: [0xc5,0xf8,0x93,0xc0]
; CHECK-NEXT: vmovmskps %ymm0, %eax # encoding: [0xc5,0xfc,0x50,0xc0]
; CHECK-NEXT: # kill: def $al killed $al killed $eax
; CHECK-NEXT: vzeroupper # encoding: [0xc5,0xf8,0x77]
; CHECK-NEXT: ret{{[l|q]}} # encoding: [0xc3]
Expand Down
23 changes: 6 additions & 17 deletions llvm/test/CodeGen/X86/bitcast-setcc-256.ll
Original file line number Diff line number Diff line change
Expand Up @@ -420,23 +420,12 @@ define void @bitcast_8i32_store(i8* %p, <8 x i32> %a0) {
; AVX12-NEXT: vzeroupper
; AVX12-NEXT: retq
;
; AVX512F-LABEL: bitcast_8i32_store:
; AVX512F: # %bb.0:
; AVX512F-NEXT: vpxor %xmm1, %xmm1, %xmm1
; AVX512F-NEXT: vpcmpgtd %ymm0, %ymm1, %k0
; AVX512F-NEXT: kmovw %k0, %eax
; AVX512F-NEXT: movb %al, (%rdi)
; AVX512F-NEXT: vzeroupper
; AVX512F-NEXT: retq
;
; AVX512BW-LABEL: bitcast_8i32_store:
; AVX512BW: # %bb.0:
; AVX512BW-NEXT: vpxor %xmm1, %xmm1, %xmm1
; AVX512BW-NEXT: vpcmpgtd %ymm0, %ymm1, %k0
; AVX512BW-NEXT: kmovd %k0, %eax
; AVX512BW-NEXT: movb %al, (%rdi)
; AVX512BW-NEXT: vzeroupper
; AVX512BW-NEXT: retq
; AVX512-LABEL: bitcast_8i32_store:
; AVX512: # %bb.0:
; AVX512-NEXT: vmovmskps %ymm0, %eax
; AVX512-NEXT: movb %al, (%rdi)
; AVX512-NEXT: vzeroupper
; AVX512-NEXT: retq
%a1 = icmp slt <8 x i32> %a0, zeroinitializer
%a2 = bitcast <8 x i1> %a1 to i8
store i8 %a2, i8* %p
Expand Down
34 changes: 9 additions & 25 deletions llvm/test/CodeGen/X86/bitcast-vector-bool.ll
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ define i2 @bitcast_v4i32_to_v2i2(<4 x i32> %a0) nounwind {
;
; AVX512-LABEL: bitcast_v4i32_to_v2i2:
; AVX512: # %bb.0:
; AVX512-NEXT: vpxor %xmm1, %xmm1, %xmm1
; AVX512-NEXT: vpcmpgtd %xmm0, %xmm1, %k0
; AVX512-NEXT: kmovd %k0, %eax
; AVX512-NEXT: vmovmskps %xmm0, %eax
; AVX512-NEXT: movl %eax, %ecx
; AVX512-NEXT: shrb $2, %cl
; AVX512-NEXT: andb $3, %al
Expand Down Expand Up @@ -146,11 +144,9 @@ define i8 @bitcast_v16i8_to_v2i8(<16 x i8> %a0) nounwind {
;
; AVX512-LABEL: bitcast_v16i8_to_v2i8:
; AVX512: # %bb.0:
; AVX512-NEXT: vpmovb2m %xmm0, %k0
; AVX512-NEXT: kmovw %k0, -{{[0-9]+}}(%rsp)
; AVX512-NEXT: vmovdqa -{{[0-9]+}}(%rsp), %xmm0
; AVX512-NEXT: vmovd %xmm0, %ecx
; AVX512-NEXT: vpextrb $1, %xmm0, %eax
; AVX512-NEXT: vpmovmskb %xmm0, %ecx
; AVX512-NEXT: movl %ecx, %eax
; AVX512-NEXT: shrl $8, %eax
; AVX512-NEXT: addb %cl, %al
; AVX512-NEXT: # kill: def $al killed $al killed $eax
; AVX512-NEXT: retq
Expand Down Expand Up @@ -191,9 +187,7 @@ define i2 @bitcast_v4i64_to_v2i2(<4 x i64> %a0) nounwind {
;
; AVX512-LABEL: bitcast_v4i64_to_v2i2:
; AVX512: # %bb.0:
; AVX512-NEXT: vpxor %xmm1, %xmm1, %xmm1
; AVX512-NEXT: vpcmpgtq %ymm0, %ymm1, %k0
; AVX512-NEXT: kmovd %k0, %eax
; AVX512-NEXT: vmovmskpd %ymm0, %eax
; AVX512-NEXT: movl %eax, %ecx
; AVX512-NEXT: shrb $2, %cl
; AVX512-NEXT: andb $3, %al
Expand Down Expand Up @@ -235,9 +229,7 @@ define i4 @bitcast_v8i32_to_v2i4(<8 x i32> %a0) nounwind {
;
; AVX512-LABEL: bitcast_v8i32_to_v2i4:
; AVX512: # %bb.0:
; AVX512-NEXT: vpxor %xmm1, %xmm1, %xmm1
; AVX512-NEXT: vpcmpgtd %ymm0, %ymm1, %k0
; AVX512-NEXT: kmovd %k0, %eax
; AVX512-NEXT: vmovmskps %ymm0, %eax
; AVX512-NEXT: movl %eax, %ecx
; AVX512-NEXT: shrb $4, %cl
; AVX512-NEXT: andb $15, %al
Expand Down Expand Up @@ -338,19 +330,11 @@ define i16 @bitcast_v32i8_to_v2i16(<32 x i8> %a0) nounwind {
;
; AVX512-LABEL: bitcast_v32i8_to_v2i16:
; AVX512: # %bb.0:
; AVX512-NEXT: pushq %rbp
; AVX512-NEXT: movq %rsp, %rbp
; AVX512-NEXT: andq $-32, %rsp
; AVX512-NEXT: subq $32, %rsp
; AVX512-NEXT: vpmovb2m %ymm0, %k0
; AVX512-NEXT: kmovd %k0, (%rsp)
; AVX512-NEXT: vmovdqa (%rsp), %xmm0
; AVX512-NEXT: vmovd %xmm0, %ecx
; AVX512-NEXT: vpextrw $1, %xmm0, %eax
; AVX512-NEXT: vpmovmskb %ymm0, %ecx
; AVX512-NEXT: movl %ecx, %eax
; AVX512-NEXT: shrl $16, %eax
; AVX512-NEXT: addl %ecx, %eax
; AVX512-NEXT: # kill: def $ax killed $ax killed $eax
; AVX512-NEXT: movq %rbp, %rsp
; AVX512-NEXT: popq %rbp
; AVX512-NEXT: vzeroupper
; AVX512-NEXT: retq
%1 = icmp slt <32 x i8> %a0, zeroinitializer
Expand Down
Loading

0 comments on commit cb5072d

Please sign in to comment.