Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue SplitVecOp_VP_REDUCE(SDNode *N, unsigned OpNo);
SDValue SplitVecOp_UnaryOp(SDNode *N);
SDValue SplitVecOp_TruncateHelper(SDNode *N);
SDValue SplitVecOp_VECTOR_COMPRESS(SDNode *N, unsigned OpNo);

SDValue SplitVecOp_BITCAST(SDNode *N);
SDValue SplitVecOp_INSERT_SUBVECTOR(SDNode *N, unsigned OpNo);
Expand Down
26 changes: 24 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2436,16 +2436,17 @@ void DAGTypeLegalizer::SplitVecRes_VECTOR_COMPRESS(SDNode *N, SDValue &Lo,
}

SDValue Passthru = N->getOperand(2);
if (!HasCustomLowering || !Passthru.isUndef()) {
if (!HasCustomLowering) {
SDValue Compressed = TLI.expandVECTOR_COMPRESS(N, DAG);
std::tie(Lo, Hi) = DAG.SplitVector(Compressed, DL, LoVT, HiVT);
return;
}

// Try to VECTOR_COMPRESS smaller vectors and combine via a stack store+load.
SDValue Mask = N->getOperand(1);
SDValue LoMask, HiMask;
std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 0);
std::tie(LoMask, HiMask) = SplitMask(N->getOperand(1));
std::tie(LoMask, HiMask) = SplitMask(Mask);

SDValue UndefPassthru = DAG.getUNDEF(LoVT);
Lo = DAG.getNode(ISD::VECTOR_COMPRESS, DL, LoVT, Lo, LoMask, UndefPassthru);
Expand All @@ -2469,6 +2470,10 @@ void DAGTypeLegalizer::SplitVecRes_VECTOR_COMPRESS(SDNode *N, SDValue &Lo,
MachinePointerInfo::getUnknownStack(MF));

SDValue Compressed = DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
if (!Passthru.isUndef()) {
Compressed =
DAG.getNode(ISD::VSELECT, DL, VecVT, Mask, Compressed, Passthru);
}
std::tie(Lo, Hi) = DAG.SplitVector(Compressed, DL);
}

Expand Down Expand Up @@ -3226,6 +3231,9 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::VSELECT:
Res = SplitVecOp_VSELECT(N, OpNo);
break;
case ISD::VECTOR_COMPRESS:
Res = SplitVecOp_VECTOR_COMPRESS(N, OpNo);
break;
case ISD::STRICT_SINT_TO_FP:
case ISD::STRICT_UINT_TO_FP:
case ISD::SINT_TO_FP:
Expand Down Expand Up @@ -3372,6 +3380,20 @@ SDValue DAGTypeLegalizer::SplitVecOp_VSELECT(SDNode *N, unsigned OpNo) {
return DAG.getNode(ISD::CONCAT_VECTORS, DL, Src0VT, LoSelect, HiSelect);
}

/// Legalize a VECTOR_COMPRESS node whose operand \p OpNo needs splitting.
///
/// Only the mask (operand 1) may be the illegal operand here; an illegal
/// vector operand would already have been dealt with while legalizing the
/// result type. The node is split through the result-splitting path and the
/// two halves are concatenated back into a value of the node's result type.
SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_COMPRESS(SDNode *N, unsigned OpNo) {
  assert(OpNo == 1 && "Illegal operand must be mask");

  // Splitting the mask implies splitting the result as well, so delegate to
  // the existing result-splitting routine instead of duplicating it.
  SDValue LoHalf, HiHalf;
  SplitVecRes_VECTOR_COMPRESS(N, LoHalf, HiHalf);

  // Glue both halves together so the replacement has the original type.
  const SDLoc DL(N);
  return DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0), LoHalf,
                     HiHalf);
}

SDValue DAGTypeLegalizer::SplitVecOp_VECREDUCE(SDNode *N, unsigned OpNo) {
EVT ResVT = N->getValueType(0);
SDValue Lo, Hi;
Expand Down
14 changes: 9 additions & 5 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11582,11 +11582,13 @@ SDValue TargetLowering::expandVECTOR_COMPRESS(SDNode *Node,
// ... if it is not a splat vector, we need to get the passthru value at
// position = popcount(mask) and re-load it from the stack before it is
// overwritten in the loop below.
EVT PopcountVT = ScalarVT.changeTypeToInteger();
SDValue Popcount = DAG.getNode(
ISD::TRUNCATE, DL, MaskVT.changeVectorElementType(MVT::i1), Mask);
Popcount = DAG.getNode(ISD::ZERO_EXTEND, DL,
MaskVT.changeVectorElementType(ScalarVT), Popcount);
Popcount = DAG.getNode(ISD::VECREDUCE_ADD, DL, ScalarVT, Popcount);
Popcount =
DAG.getNode(ISD::ZERO_EXTEND, DL,
MaskVT.changeVectorElementType(PopcountVT), Popcount);
Popcount = DAG.getNode(ISD::VECREDUCE_ADD, DL, PopcountVT, Popcount);
SDValue LastElmtPtr =
getVectorElementPointer(DAG, StackPtr, VecVT, Popcount);
LastWriteVal = DAG.getLoad(
Expand Down Expand Up @@ -11625,8 +11627,10 @@ SDValue TargetLowering::expandVECTOR_COMPRESS(SDNode *Node,

// Re-write the last ValI if all lanes were selected. Otherwise,
// overwrite the last write with the passthru value.
LastWriteVal =
DAG.getSelect(DL, ScalarVT, AllLanesSelected, ValI, LastWriteVal);
SDNodeFlags Flags{};
Flags.setUnpredictable(true);
LastWriteVal = DAG.getSelect(DL, ScalarVT, AllLanesSelected, ValI,
LastWriteVal, Flags);
Chain = DAG.getStore(
Chain, DL, LastWriteVal, OutPtr,
MachinePointerInfo::getUnknownStack(DAG.getMachineFunction()));
Expand Down
95 changes: 95 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2125,6 +2125,35 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64 })
setOperationAction(ISD::CTPOP, VT, Legal);
}

// We can try to convert vectors to different sizes to leverage legal
// `vpcompress` cases. So we mark these supported vector sizes as Custom and
// then specialize to Legal below.
for (MVT VT : {MVT::v8i32, MVT::v8f32, MVT::v4i32, MVT::v4f32, MVT::v4i64,
MVT::v4f64, MVT::v2i64, MVT::v2f64, MVT::v16i8, MVT::v8i16,
MVT::v16i16, MVT::v8i8})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);

// Legal vpcompress depends on various AVX512 extensions.
// Legal in AVX512F
for (MVT VT : {MVT::v16i32, MVT::v16f32, MVT::v8i64, MVT::v8f64})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Legal);

// Legal in AVX512F + AVX512VL
if (Subtarget.hasVLX())
for (MVT VT : {MVT::v8i32, MVT::v8f32, MVT::v4i32, MVT::v4f32, MVT::v4i64,
MVT::v4f64, MVT::v2i64, MVT::v2f64})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Legal);

// Legal in AVX512F + AVX512VBMI2
if (Subtarget.hasVBMI2())
for (MVT VT : {MVT::v32i16, MVT::v64i8})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Legal);

// Legal in AVX512F + AVX512VL + AVX512VBMI2
if (Subtarget.hasVBMI2() && Subtarget.hasVLX())
for (MVT VT : {MVT::v16i8, MVT::v8i16, MVT::v32i8, MVT::v16i16})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Legal);
}

// This block controls legalization of v32i1/v64i1 which are available with
Expand Down Expand Up @@ -17755,6 +17784,71 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, const X86Subtarget &Subtarget,
llvm_unreachable("Unimplemented!");
}

/// Custom lowering for ISD::VECTOR_COMPRESS on vector types that are not
/// directly legal for `vpcompress`.
///
/// As legal vpcompress instructions depend on various AVX512 extensions, try
/// to convert illegal vector sizes to legal ones to avoid expansion. A 128- or
/// 256-bit input is rewritten as a VECTOR_COMPRESS on a 512-bit vector and the
/// result is narrowed back to the original type:
///   * 32-/64-bit elements: the vector (and passthru) are inserted at index 0
///     of an undef 512-bit vector and the mask is inserted into an all-zero
///     mask, so the padding lanes are inactive.
///   * v8i16 / v16i8 / v16i16: every element is any-extended so that the
///     vector spans 512 bits with the same element count; the mask is reused
///     unchanged and the result is truncated back.
///
/// \param Op        The VECTOR_COMPRESS node (operands: vector, mask, passthru).
/// \param Subtarget Must have AVX512 (asserted).
/// \param DAG       DAG in which the replacement nodes are created.
/// \returns The replacement value, or an empty SDValue if none of the
///          conversions above applies.
static SDValue lowerVECTOR_COMPRESS(SDValue Op, const X86Subtarget &Subtarget,
                                    SelectionDAG &DAG) {
  assert(Subtarget.hasAVX512() &&
         "Need AVX512 for custom VECTOR_COMPRESS lowering.");

  SDLoc DL(Op);
  SDValue Vec = Op.getOperand(0);
  SDValue Mask = Op.getOperand(1);
  SDValue Passthru = Op.getOperand(2);

  EVT VecVT = Vec.getValueType();
  EVT ElementVT = VecVT.getVectorElementType();
  unsigned NumElements = VecVT.getVectorNumElements();
  unsigned NumVecBits = VecVT.getFixedSizeInBits();
  unsigned NumElementBits = ElementVT.getFixedSizeInBits();

  // 128- and 256-bit vectors with <= 16 elements can be converted to and
  // compressed as 512-bit vectors in AVX512F.
  if (NumVecBits != 128 && NumVecBits != 256)
    return SDValue();

  if (NumElementBits == 32 || NumElementBits == 64) {
    // Keep the element type and pad the vector out to 512 bits.
    unsigned NumLargeElements = 512 / NumElementBits;
    EVT LargeVecVT =
        MVT::getVectorVT(ElementVT.getSimpleVT(), NumLargeElements);
    EVT LargeMaskVT = MVT::getVectorVT(MVT::i1, NumLargeElements);

    // Index 0 is used both to insert the narrow operands and to extract the
    // narrow result again.
    SDValue InsertPos = DAG.getConstant(0, DL, MVT::i64);
    Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, LargeVecVT,
                      DAG.getUNDEF(LargeVecVT), Vec, InsertPos);
    // The additional mask lanes are zero so the undef padding of Vec is not
    // selected.
    Mask = DAG.getNode(
        ISD::INSERT_SUBVECTOR, DL, LargeMaskVT,
        DAG.getSplatVector(LargeMaskVT, DL, DAG.getConstant(0, DL, MVT::i1)),
        Mask, InsertPos);
    Passthru = Passthru.isUndef()
                   ? DAG.getUNDEF(LargeVecVT)
                   : DAG.getNode(ISD::INSERT_SUBVECTOR, DL, LargeVecVT,
                                 DAG.getUNDEF(LargeVecVT), Passthru, InsertPos);

    SDValue Compressed =
        DAG.getNode(ISD::VECTOR_COMPRESS, DL, LargeVecVT, Vec, Mask, Passthru);
    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VecVT, Compressed,
                       InsertPos);
  }

  // Small-element vectors: keep the element count and widen each element so
  // that the whole vector is 512 bits. v8i8 is intentionally not listed: it is
  // only 64 bits wide and has therefore already been rejected by the size
  // guard above.
  if (VecVT == MVT::v8i16 || VecVT == MVT::v16i8 || VecVT == MVT::v16i16) {
    MVT LargeElementVT = MVT::getIntegerVT(512 / NumElements);
    EVT LargeVecVT = MVT::getVectorVT(LargeElementVT, NumElements);

    Vec = DAG.getNode(ISD::ANY_EXTEND, DL, LargeVecVT, Vec);
    Passthru = Passthru.isUndef()
                   ? DAG.getUNDEF(LargeVecVT)
                   : DAG.getNode(ISD::ANY_EXTEND, DL, LargeVecVT, Passthru);

    // The element count is unchanged, so the original mask applies as is.
    SDValue Compressed =
        DAG.getNode(ISD::VECTOR_COMPRESS, DL, LargeVecVT, Vec, Mask, Passthru);
    return DAG.getNode(ISD::TRUNCATE, DL, VecVT, Compressed);
  }

  return SDValue();
}

/// Try to lower a VSELECT instruction to a vector shuffle.
static SDValue lowerVSELECTtoVectorShuffle(SDValue Op,
const X86Subtarget &Subtarget,
Expand Down Expand Up @@ -32374,6 +32468,7 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::BUILD_VECTOR: return LowerBUILD_VECTOR(Op, DAG);
case ISD::CONCAT_VECTORS: return LowerCONCAT_VECTORS(Op, Subtarget, DAG);
case ISD::VECTOR_SHUFFLE: return lowerVECTOR_SHUFFLE(Op, Subtarget, DAG);
case ISD::VECTOR_COMPRESS: return lowerVECTOR_COMPRESS(Op, Subtarget, DAG);
case ISD::VSELECT: return LowerVSELECT(Op, DAG);
case ISD::EXTRACT_VECTOR_ELT: return LowerEXTRACT_VECTOR_ELT(Op, DAG);
case ISD::INSERT_VECTOR_ELT: return LowerINSERT_VECTOR_ELT(Op, DAG);
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/X86/X86InstrAVX512.td
Original file line number Diff line number Diff line change
Expand Up @@ -10543,6 +10543,12 @@ multiclass compress_by_vec_width_lowering<X86VectorVTInfo _, string Name> {
def : Pat<(X86compress (_.VT _.RC:$src), _.ImmAllZerosV, _.KRCWM:$mask),
(!cast<Instruction>(Name#_.ZSuffix#rrkz)
_.KRCWM:$mask, _.RC:$src)>;
def : Pat<(_.VT (vector_compress _.RC:$src, _.KRCWM:$mask, undef)),
(!cast<Instruction>(Name#_.ZSuffix#rrkz)
_.KRCWM:$mask, _.RC:$src)>;
def : Pat<(_.VT (vector_compress _.RC:$src, _.KRCWM:$mask, _.RC:$passthru)),
(!cast<Instruction>(Name#_.ZSuffix#rrk)
_.RC:$passthru, _.KRCWM:$mask, _.RC:$src)>;
}

multiclass compress_by_elt_width<bits<8> opc, string OpcodeStr,
Expand Down
Loading