Skip to content
Merged
38 changes: 19 additions & 19 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -612,10 +612,10 @@ bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
SDValue Vector = N->getOperand(0);

// We only care about f16x2 as it's the only real vector type we
// We only care about v2x16 as it's the only real vector type we
// need to deal with.
MVT VT = Vector.getSimpleValueType();
if (!(VT == MVT::v2f16 || VT == MVT::v2bf16))
if (!Isv2x16VT(VT))
return false;
// Find and record all uses of this vector that extract element 0 or 1.
SmallVector<SDNode *, 4> E0, E1;
Expand Down Expand Up @@ -828,6 +828,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
return Opcode_i16;
case MVT::v2f16:
case MVT::v2bf16:
case MVT::v2i16:
return Opcode_i32;
case MVT::f32:
return Opcode_f32;
Expand Down Expand Up @@ -909,9 +910,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
// Vector Setting
unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
if (SimpleVT.isVector()) {
assert((LoadedVT == MVT::v2f16 || LoadedVT == MVT::v2bf16) &&
"Unexpected vector type");
// v2f16/v2bf16 is loaded using ld.b32
assert(Isv2x16VT(LoadedVT) && "Unexpected vector type");
// v2f16/v2bf16/v2i16 is loaded using ld.b32
fromTypeWidth = 32;
}

Expand Down Expand Up @@ -1061,10 +1061,10 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {

EVT EltVT = N->getValueType(0);

// v8f16 is a special case. PTX doesn't have ld.v8.f16
// instruction. Instead, we split the vector into v2f16 chunks and
// v8x16 is a special case. PTX doesn't have ld.v8.x16
// instruction. Instead, we split the vector into v2x16 chunks and
// load them with ld.v4.b32.
if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16) {
if (Isv2x16VT(EltVT)) {
assert(N->getOpcode() == NVPTXISD::LoadV4 && "Unexpected load opcode.");
EltVT = MVT::i32;
FromType = NVPTX::PTXLdStInstCode::Untyped;
Expand Down Expand Up @@ -1260,12 +1260,13 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
if (EltVT.isVector()) {
NumElts = EltVT.getVectorNumElements();
EltVT = EltVT.getVectorElementType();
// vectors of f16 are loaded/stored as multiples of v2f16 elements.
// Vectors of 16-bit types are loaded/stored as multiples of v2x16 elements.
if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) ||
(EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16)) {
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
EltVT = N->getValueType(0);
NumElts /= 2;
(EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) ||
(EltVT == MVT::i16 && N->getValueType(0) == MVT::v2i16)) {
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
EltVT = N->getValueType(0);
NumElts /= 2;
}
}

Expand Down Expand Up @@ -1678,9 +1679,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
MVT ScalarVT = SimpleVT.getScalarType();
unsigned toTypeWidth = ScalarVT.getSizeInBits();
if (SimpleVT.isVector()) {
assert((StoreVT == MVT::v2f16 || StoreVT == MVT::v2bf16) &&
"Unexpected vector type");
// v2f16 is stored using st.b32
assert(Isv2x16VT(StoreVT) && "Unexpected vector type");
// v2x16 is stored using st.b32
toTypeWidth = 32;
}

Expand Down Expand Up @@ -1844,10 +1844,10 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
return false;
}

// v8f16 is a special case. PTX doesn't have st.v8.f16
// instruction. Instead, we split the vector into v2f16 chunks and
// v8x16 is a special case. PTX doesn't have st.v8.x16
// instruction. Instead, we split the vector into v2x16 chunks and
// store them with st.v4.b32.
if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16) {
if (Isv2x16VT(EltVT)) {
assert(N->getOpcode() == NVPTXISD::StoreV4 && "Unexpected load opcode.");
EltVT = MVT::i32;
ToType = NVPTX::PTXLdStInstCode::Untyped;
Expand Down
Loading