Skip to content

Commit 70adaf7

Browse files
- Return a concat_vector created from the results of whilelo_x2 from performActiveLaneMaskCombine
- Add tests for the 4 extracts case which will use ptest & reinterpret_cast - Remove changes to canRemovePTestInstr
1 parent 64b6d61 commit 70adaf7

3 files changed

Lines changed: 42 additions & 40 deletions

File tree

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18774,7 +18774,7 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
1877418774
static SDValue
1877518775
performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
1877618776
const AArch64Subtarget *ST) {
18777-
if (DCI.isBeforeLegalize())
18777+
if (DCI.isBeforeLegalize() && !!DCI.isBeforeLegalizeOps())
1877818778
return SDValue();
1877918779

1878018780
if (SDValue While = optimizeIncrementingWhile(N, DCI.DAG, /*IsSigned=*/false,
@@ -18793,7 +18793,7 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
1879318793
});
1879418794

1879518795
auto MaskEC = N->getValueType(0).getVectorElementCount();
18796-
if (!MaskEC.isKnownMultipleOf(NumExts))
18796+
if (NumExts == 0 || !MaskEC.isKnownMultipleOf(NumExts))
1879718797
return SDValue();
1879818798

1879918799
ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumExts);
@@ -18802,12 +18802,8 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
1880218802

1880318803
SmallVector<SDNode *> Extracts(NumExts, nullptr);
1880418804
for (SDNode *Use : N->users()) {
18805-
if (Use->getOpcode() == AArch64ISD::PTEST_FIRST ||
18806-
Use->getOpcode() == AArch64ISD::REINTERPRET_CAST)
18807-
continue;
18808-
1880918805
if (Use->getOpcode() != ISD::EXTRACT_SUBVECTOR)
18810-
return SDValue();
18806+
continue;
1881118807

1881218808
// Ensure the extract type is correct (e.g. if NumExts is 4 and
1881318809
// the mask return type is nxv8i1, each extract should be nxv2i1.
@@ -18842,11 +18838,10 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
1884218838
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
1884318839
DCI.CombineTo(Extracts[0], R.getValue(0));
1884418840
DCI.CombineTo(Extracts[1], R.getValue(1));
18841+
SmallVector<SDValue> Results = {R.getValue(0), R.getValue(1)};
1884518842

18846-
if (NumExts == 2) {
18847-
DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), R.getValue(0));
18848-
return SDValue(SDValue(N, 0));
18849-
}
18843+
if (NumExts == 2)
18844+
return DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0), Results);
1885018845

1885118846
auto Elts = DAG.getElementCount(DL, OpVT, ExtVT.getVectorElementCount() * 2);
1885218847
for (unsigned I = 2; I < NumExts; I += 2) {
@@ -18855,10 +18850,11 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
1885518850
R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
1885618851
DCI.CombineTo(Extracts[I], R.getValue(0));
1885718852
DCI.CombineTo(Extracts[I + 1], R.getValue(1));
18853+
Results.push_back(R.getValue(0));
18854+
Results.push_back(R.getValue(1));
1885818855
}
1885918856

18860-
DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), R.getValue(0));
18861-
return SDValue(N, 0);
18857+
return DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0), Results);
1886218858
}
1886318859

1886418860
// Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce

llvm/lib/Target/AArch64/AArch64InstrInfo.cpp

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,19 +1495,13 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
14951495
if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
14961496
return PredOpcode;
14971497

1498-
if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31) {
1499-
auto PTestOp = MRI->getUniqueVRegDef(PTest->getOperand(1).getReg());
1500-
if (PTest->getOpcode() == AArch64::PTEST_PP_FIRST && PTestOp->isCopy() &&
1501-
PTestOp->getOperand(1).getSubReg() == AArch64::psub0)
1502-
return PredOpcode;
1503-
1504-
// For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
1505-
// redundant since WHILE performs an implicit PTEST with an all active
1506-
// mask.
1507-
if (getElementSizeForOpcode(MaskOpcode) ==
1508-
getElementSizeForOpcode(PredOpcode))
1509-
return PredOpcode;
1510-
}
1498+
// For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
1499+
// redundant since WHILE performs an implicit PTEST with an all active
1500+
// mask.
1501+
if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
1502+
getElementSizeForOpcode(MaskOpcode) ==
1503+
getElementSizeForOpcode(PredOpcode))
1504+
return PredOpcode;
15111505

15121506
return {};
15131507
}

llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,9 @@ define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
327327
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
328328
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
329329
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.h, p1.h }, x0, x1
330+
; CHECK-SVE2p1-SME2-NEXT: ptrue p2.b
331+
; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.b, p0.b, p1.b
332+
; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b
330333
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB11_2
331334
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
332335
; CHECK-SVE2p1-SME2-NEXT: b use
@@ -365,6 +368,9 @@ define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n
365368
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
366369
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
367370
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1
371+
; CHECK-SVE2p1-SME2-NEXT: ptrue p2.h
372+
; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.h, p0.h, p1.h
373+
; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b
368374
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB12_2
369375
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
370376
; CHECK-SVE2p1-SME2-NEXT: b use
@@ -403,15 +409,18 @@ define void @test_4x4bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
403409
;
404410
; CHECK-SVE2p1-SME2-LABEL: test_4x4bit_mask_with_extracts_and_ptest:
405411
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
406-
; CHECK-SVE2p1-SME2-NEXT: whilelo p0.b, x0, x1
412+
; CHECK-SVE2p1-SME2-NEXT: cnth x8
413+
; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8
414+
; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo
415+
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1
416+
; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.s, p3.s }, x8, x1
417+
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p0.h, p1.h
418+
; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.h, p2.h, p3.h
419+
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.b, p4.b, p5.b
420+
; CHECK-SVE2p1-SME2-NEXT: ptrue p5.b
421+
; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b
407422
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB13_2
408423
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
409-
; CHECK-SVE2p1-SME2-NEXT: punpklo p1.h, p0.b
410-
; CHECK-SVE2p1-SME2-NEXT: punpkhi p3.h, p0.b
411-
; CHECK-SVE2p1-SME2-NEXT: punpklo p0.h, p1.b
412-
; CHECK-SVE2p1-SME2-NEXT: punpkhi p1.h, p1.b
413-
; CHECK-SVE2p1-SME2-NEXT: punpklo p2.h, p3.b
414-
; CHECK-SVE2p1-SME2-NEXT: punpkhi p3.h, p3.b
415424
; CHECK-SVE2p1-SME2-NEXT: b use
416425
; CHECK-SVE2p1-SME2-NEXT: .LBB13_2: // %if.end
417426
; CHECK-SVE2p1-SME2-NEXT: ret
@@ -450,15 +459,18 @@ define void @test_4x2bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n
450459
;
451460
; CHECK-SVE2p1-SME2-LABEL: test_4x2bit_mask_with_extracts_and_reinterpret_casts:
452461
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
453-
; CHECK-SVE2p1-SME2-NEXT: whilelo p0.h, x0, x1
462+
; CHECK-SVE2p1-SME2-NEXT: cntw x8
463+
; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8
464+
; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo
465+
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.d, p1.d }, x0, x1
466+
; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.d, p3.d }, x8, x1
467+
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.s, p0.s, p1.s
468+
; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.s, p2.s, p3.s
469+
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p4.h, p5.h
470+
; CHECK-SVE2p1-SME2-NEXT: ptrue p5.h
471+
; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b
454472
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB14_2
455473
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
456-
; CHECK-SVE2p1-SME2-NEXT: punpklo p1.h, p0.b
457-
; CHECK-SVE2p1-SME2-NEXT: punpkhi p3.h, p0.b
458-
; CHECK-SVE2p1-SME2-NEXT: punpklo p0.h, p1.b
459-
; CHECK-SVE2p1-SME2-NEXT: punpkhi p1.h, p1.b
460-
; CHECK-SVE2p1-SME2-NEXT: punpklo p2.h, p3.b
461-
; CHECK-SVE2p1-SME2-NEXT: punpkhi p3.h, p3.b
462474
; CHECK-SVE2p1-SME2-NEXT: b use
463475
; CHECK-SVE2p1-SME2-NEXT: .LBB14_2: // %if.end
464476
; CHECK-SVE2p1-SME2-NEXT: ret

0 commit comments

Comments
 (0)