Skip to content

Commit

Permalink
[AArch64] Add lowering for @llvm.experimental.vector.compress (llvm#101015)
Browse files Browse the repository at this point in the history

This is a follow-up to llvm#92289 that adds custom lowering of the new
`@llvm.experimental.vector.compress` intrinsic on AArch64 with SVE
instructions.

Some vector types can be lowered directly to the SVE `compact` instruction.
  • Loading branch information
lawben authored and bwendling committed Aug 15, 2024
1 parent 49e58d0 commit 1b7a3f7
Show file tree
Hide file tree
Showing 4 changed files with 453 additions and 6 deletions.
65 changes: 59 additions & 6 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2412,11 +2412,64 @@ void DAGTypeLegalizer::SplitVecRes_VECTOR_COMPRESS(SDNode *N, SDValue &Lo,
SDValue &Hi) {
// This is not "trivial", as there is a dependency between the two subvectors.
// Depending on the number of 1s in the mask, the elements from the Hi vector
// need to be moved to the Lo vector. So we just perform this as one "big"
// operation and then extract the Lo and Hi vectors from that. This gets rid
// of VECTOR_COMPRESS and all other operands can be legalized later.
SDValue Compressed = TLI.expandVECTOR_COMPRESS(N, DAG);
std::tie(Lo, Hi) = DAG.SplitVector(Compressed, SDLoc(N));
// need to be moved to the Lo vector. Passthru values make this even harder.
// We try to use VECTOR_COMPRESS if the target has custom lowering with
// smaller types and passthru is undef, as it is most likely faster than the
// fully expand path. Otherwise, just do the full expansion as one "big"
// operation and then extract the Lo and Hi vectors from that. This gets
// rid of VECTOR_COMPRESS and all other operands can be legalized later.
SDLoc DL(N);
EVT VecVT = N->getValueType(0);

auto [LoVT, HiVT] = DAG.GetSplitDestVTs(VecVT);
bool HasCustomLowering = false;
EVT CheckVT = LoVT;
while (CheckVT.getVectorMinNumElements() > 1) {
// TLI.isOperationLegalOrCustom requires a legal type, but we could have a
// custom lowering for illegal types. So we do the checks separately.
if (TLI.isOperationLegal(ISD::VECTOR_COMPRESS, CheckVT) ||
TLI.isOperationCustom(ISD::VECTOR_COMPRESS, CheckVT)) {
HasCustomLowering = true;
break;
}
CheckVT = CheckVT.getHalfNumVectorElementsVT(*DAG.getContext());
}

SDValue Passthru = N->getOperand(2);
if (!HasCustomLowering || !Passthru.isUndef()) {
SDValue Compressed = TLI.expandVECTOR_COMPRESS(N, DAG);
std::tie(Lo, Hi) = DAG.SplitVector(Compressed, DL, LoVT, HiVT);
return;
}

// Try to VECTOR_COMPRESS smaller vectors and combine via a stack store+load.
SDValue LoMask, HiMask;
std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 0);
std::tie(LoMask, HiMask) = SplitMask(N->getOperand(1));

SDValue UndefPassthru = DAG.getUNDEF(LoVT);
Lo = DAG.getNode(ISD::VECTOR_COMPRESS, DL, LoVT, Lo, LoMask, UndefPassthru);
Hi = DAG.getNode(ISD::VECTOR_COMPRESS, DL, HiVT, Hi, HiMask, UndefPassthru);

SDValue StackPtr = DAG.CreateStackTemporary(
VecVT.getStoreSize(), DAG.getReducedAlign(VecVT, /*UseABI=*/false));
MachineFunction &MF = DAG.getMachineFunction();
MachinePointerInfo PtrInfo = MachinePointerInfo::getFixedStack(
MF, cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex());

// We store LoVec and then insert HiVec starting at offset=|1s| in LoMask.
SDValue WideMask =
DAG.getNode(ISD::ZERO_EXTEND, DL, LoMask.getValueType(), LoMask);
SDValue Offset = DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, WideMask);
Offset = TLI.getVectorElementPointer(DAG, StackPtr, VecVT, Offset);

SDValue Chain = DAG.getEntryNode();
Chain = DAG.getStore(Chain, DL, Lo, StackPtr, PtrInfo);
Chain = DAG.getStore(Chain, DL, Hi, Offset,
MachinePointerInfo::getUnknownStack(MF));

SDValue Compressed = DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
std::tie(Lo, Hi) = DAG.SplitVector(Compressed, DL);
}

void DAGTypeLegalizer::SplitVecRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi) {
Expand Down Expand Up @@ -5790,7 +5843,7 @@ SDValue DAGTypeLegalizer::WidenVecRes_VECTOR_COMPRESS(SDNode *N) {
TLI.getTypeToTransformTo(*DAG.getContext(), Vec.getValueType());
EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(),
Mask.getValueType().getVectorElementType(),
WideVecVT.getVectorNumElements());
WideVecVT.getVectorElementCount());

SDValue WideVec = ModifyToType(Vec, WideVecVT);
SDValue WideMask = ModifyToType(Mask, WideMaskVT, /*FillWithZeroes=*/true);
Expand Down
116 changes: 116 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1775,6 +1775,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
MVT::v2f32, MVT::v4f32, MVT::v2f64})
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);

// We can lower types that have <vscale x {2|4}> elements to compact.
for (auto VT :
{MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32,
MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);

// If we have SVE, we can use SVE logic for legal (or smaller than legal)
// NEON vectors in the lowest bits of the SVE register.
for (auto VT : {MVT::v2i8, MVT::v2i16, MVT::v2i32, MVT::v2i64, MVT::v2f32,
MVT::v2f64, MVT::v4i8, MVT::v4i16, MVT::v4i32, MVT::v4f32})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);

// Histcnt is SVE2 only
if (Subtarget->hasSVE2()) {
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::Other,
Expand Down Expand Up @@ -6619,6 +6631,104 @@ SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
return DAG.getMergeValues({Ext, Chain}, DL);
}

// Lower ISD::VECTOR_COMPRESS to the SVE `compact` instruction (via the
// aarch64_sve_compact intrinsic). `compact` packs the active elements of a
// vector to its low lanes and zero-fills the rest, so a non-zero passthru
// additionally requires a whilelo/cntp-built mask and a VSELECT to merge the
// passthru into the inactive tail. Fixed-length (NEON-sized) vectors are
// handled by operating on them in the low bits of an SVE register.
// Returns SDValue() to fall back to generic expansion when this target
// path does not apply.
SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
                                                    SelectionDAG &DAG) const {
  SDLoc DL(Op);
  SDValue Vec = Op.getOperand(0);
  SDValue Mask = Op.getOperand(1);
  SDValue Passthru = Op.getOperand(2);
  EVT VecVT = Vec.getValueType();
  EVT MaskVT = Mask.getValueType();
  EVT ElmtVT = VecVT.getVectorElementType();
  const bool IsFixedLength = VecVT.isFixedLengthVector();
  const bool HasPassthru = !Passthru.isUndef();
  unsigned MinElmts = VecVT.getVectorElementCount().getKnownMinValue();
  // Fixed-length equivalent of the (possibly scalable) input type; used when
  // extracting the result back out of the SVE register at the end.
  EVT FixedVecVT = MVT::getVectorVT(ElmtVT.getSimpleVT(), MinElmts);

  assert(VecVT.isVector() && "Input to VECTOR_COMPRESS must be vector.");

  // `compact` is an SVE instruction; without SVE, use the default expansion.
  if (!Subtarget->isSVEAvailable())
    return SDValue();

  // Fixed-length vectors wider than a NEON register (128 bits) do not fit in
  // the low bits of a single SVE register, so bail out.
  if (IsFixedLength && VecVT.getSizeInBits().getFixedValue() > 128)
    return SDValue();

  // Only <vscale x {4|2} x {i32|i64}> supported for compact, i.e. element
  // counts of 4 or 2 (smaller element types are widened to i32/i64 below).
  if (MinElmts != 2 && MinElmts != 4)
    return SDValue();

  // We can use the SVE register containing the NEON vector in its lowest bits.
  // Insert vector, mask, and passthru at index 0 of undef scalable vectors,
  // then continue as if the input had been scalable all along.
  if (IsFixedLength) {
    EVT ScalableVecVT =
        MVT::getScalableVectorVT(ElmtVT.getSimpleVT(), MinElmts);
    EVT ScalableMaskVT = MVT::getScalableVectorVT(
        MaskVT.getVectorElementType().getSimpleVT(), MinElmts);

    Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
                      DAG.getUNDEF(ScalableVecVT), Vec,
                      DAG.getConstant(0, DL, MVT::i64));
    Mask = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableMaskVT,
                       DAG.getUNDEF(ScalableMaskVT), Mask,
                       DAG.getConstant(0, DL, MVT::i64));
    // The fixed-length mask has integer elements; SVE predicates are i1.
    Mask = DAG.getNode(ISD::TRUNCATE, DL,
                       ScalableMaskVT.changeVectorElementType(MVT::i1), Mask);
    Passthru = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
                           DAG.getUNDEF(ScalableVecVT), Passthru,
                           DAG.getConstant(0, DL, MVT::i64));

    // From here on, work with the scalable types.
    VecVT = Vec.getValueType();
    MaskVT = Mask.getValueType();
  }

  // Get legal type for compact instruction
  EVT ContainerVT = getSVEContainerType(VecVT);
  EVT CastVT = VecVT.changeVectorElementTypeToInteger();

  // Convert to i32 or i64 for smaller types, as these are the only supported
  // sizes for compact. (Bitcast first so e.g. float elements extend as
  // integers; the extension is undone by the TRUNCATE/bitcast at the end.)
  if (ContainerVT != VecVT) {
    Vec = DAG.getBitcast(CastVT, Vec);
    Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
  }

  // compact: pack elements selected by Mask into the low lanes, zeroing the
  // remaining high lanes.
  SDValue Compressed = DAG.getNode(
      ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
      DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask, Vec);

  // compact fills with 0s, so if our passthru is all 0s, do nothing here.
  if (HasPassthru && !ISD::isConstantSplatVectorAllZeros(Passthru.getNode())) {
    // Offset = number of active mask lanes = number of compressed elements.
    SDValue Offset = DAG.getNode(
        ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64,
        DAG.getConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64), Mask, Mask);

    // IndexMask = predicate that is true for lanes [0, Offset), i.e. exactly
    // the lanes that received compressed elements.
    SDValue IndexMask = DAG.getNode(
        ISD::INTRINSIC_WO_CHAIN, DL, MaskVT,
        DAG.getConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64),
        DAG.getConstant(0, DL, MVT::i64), Offset);

    // Keep the compressed prefix, take the passthru for the tail.
    Compressed =
        DAG.getNode(ISD::VSELECT, DL, VecVT, IndexMask, Compressed, Passthru);
  }

  // Extracting from a legal SVE type before truncating produces better code.
  if (IsFixedLength) {
    Compressed = DAG.getNode(
        ISD::EXTRACT_SUBVECTOR, DL,
        FixedVecVT.changeVectorElementType(ContainerVT.getVectorElementType()),
        Compressed, DAG.getConstant(0, DL, MVT::i64));
    // Retarget the final truncate/bitcast below at the fixed-length types.
    // Note: after this, ContainerVT != VecVT always holds for the fixed-length
    // path (ContainerVT is scalable), so the conversion below always runs.
    CastVT = FixedVecVT.changeVectorElementTypeToInteger();
    VecVT = FixedVecVT;
  }

  // If we changed the element type before, we need to convert it back.
  if (ContainerVT != VecVT) {
    Compressed = DAG.getNode(ISD::TRUNCATE, DL, CastVT, Compressed);
    Compressed = DAG.getBitcast(VecVT, Compressed);
  }

  return Compressed;
}

// Generate SUBS and CSEL for integer abs.
SDValue AArch64TargetLowering::LowerABS(SDValue Op, SelectionDAG &DAG) const {
MVT VT = Op.getSimpleValueType();
Expand Down Expand Up @@ -6999,6 +7109,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
return LowerDYNAMIC_STACKALLOC(Op, DAG);
case ISD::VSCALE:
return LowerVSCALE(Op, DAG);
case ISD::VECTOR_COMPRESS:
return LowerVECTOR_COMPRESS(Op, DAG);
case ISD::ANY_EXTEND:
case ISD::SIGN_EXTEND:
case ISD::ZERO_EXTEND:
Expand Down Expand Up @@ -26563,6 +26675,10 @@ void AArch64TargetLowering::ReplaceNodeResults(
case ISD::VECREDUCE_UMIN:
Results.push_back(LowerVECREDUCE(SDValue(N, 0), DAG));
return;
case ISD::VECTOR_COMPRESS:
if (SDValue Res = LowerVECTOR_COMPRESS(SDValue(N, 0), DAG))
Results.push_back(Res);
return;
case ISD::ADD:
case ISD::FADD:
ReplaceAddWithADDP(N, Results, DAG, Subtarget);
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,8 @@ class AArch64TargetLowering : public TargetLowering {

SDValue LowerMLOAD(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerVECTOR_COMPRESS(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerINTRINSIC_VOID(SDValue Op, SelectionDAG &DAG) const;
Expand Down
Loading

0 comments on commit 1b7a3f7

Please sign in to comment.