diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 0dce0077bf158..db302d7e52684 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -4177,7 +4177,9 @@ static bool isMaskOrZero(const Value *V, bool Not, const SimplifyQuery &Q,
 /// a check for a lossy truncation.
 /// Folds:
 ///   icmp SrcPred (x & Mask), x    to    icmp DstPred x, Mask
+///   icmp SrcPred (x & ~Mask), ~Mask    to    icmp DstPred x, ~Mask
 ///   icmp eq/ne (x & ~Mask), 0     to    icmp DstPred x, Mask
+///   icmp eq/ne (~x | Mask), -1     to    icmp DstPred x, Mask
 /// Where Mask is some pattern that produces all-ones in low bits:
 ///    (-1 >> y)
 ///    ((-1 << y) >> y)     <- non-canonical, has extra uses
@@ -4189,82 +4191,126 @@ static bool isMaskOrZero(const Value *V, bool Not, const SimplifyQuery &Q,
 static Value *foldICmpWithLowBitMaskedVal(ICmpInst::Predicate Pred, Value *Op0,
                                           Value *Op1, const SimplifyQuery &Q,
                                           InstCombiner &IC) {
-  Value *X, *M;
-  bool NeedsNot = false;
-
-  auto CheckMask = [&](Value *V, bool Not) {
-    if (ICmpInst::isSigned(Pred) && !match(V, m_ImmConstant()))
-      return false;
-    return isMaskOrZero(V, Not, Q);
-  };
-
-  if (match(Op0, m_c_And(m_Specific(Op1), m_Value(M))) &&
-      CheckMask(M, /*Not*/ false)) {
-    X = Op1;
-  } else if (match(Op1, m_Zero()) && ICmpInst::isEquality(Pred) &&
-             match(Op0, m_OneUse(m_And(m_Value(X), m_Value(M))))) {
-    NeedsNot = true;
-    if (IC.isFreeToInvert(X, X->hasOneUse()) && CheckMask(X, /*Not*/ true))
-      std::swap(X, M);
-    else if (!IC.isFreeToInvert(M, M->hasOneUse()) ||
-             !CheckMask(M, /*Not*/ true))
-      return nullptr;
-  } else {
-    return nullptr;
-  }
 
   ICmpInst::Predicate DstPred;
   switch (Pred) {
   case ICmpInst::Predicate::ICMP_EQ:
-    //  x & (-1 >> y) == x    ->    x u<= (-1 >> y)
+    //  x & Mask == x
+    //  x & ~Mask == 0
+    //  ~x | Mask == -1
+    //    ->    x u<= Mask
+    //  x & ~Mask == ~Mask
+    //    ->    ~Mask u<= x
     DstPred = ICmpInst::Predicate::ICMP_ULE;
     break;
   case ICmpInst::Predicate::ICMP_NE:
-    //  x & (-1 >> y) != x    ->    x u> (-1 >> y)
+    //  x & Mask != x
+    //  x & ~Mask != 0
+    //  ~x | Mask != -1
+    //    ->    x u> Mask
+    //  x & ~Mask != ~Mask
+    //    ->    ~Mask u> x
     DstPred = ICmpInst::Predicate::ICMP_UGT;
     break;
   case ICmpInst::Predicate::ICMP_ULT:
-    //  x & (-1 >> y) u< x    ->    x u> (-1 >> y)
-    //  x u> x & (-1 >> y)    ->    x u> (-1 >> y)
+    //  x & Mask u< x
+    //    ->    x u> Mask
+    //  x & ~Mask u< ~Mask
+    //    ->    ~Mask u> x
     DstPred = ICmpInst::Predicate::ICMP_UGT;
     break;
   case ICmpInst::Predicate::ICMP_UGE:
-    //  x & (-1 >> y) u>= x    ->    x u<= (-1 >> y)
-    //  x u<= x & (-1 >> y)    ->    x u<= (-1 >> y)
+    //  x & Mask u>= x
+    //    ->    x u<= Mask
+    //  x & ~Mask u>= ~Mask
+    //    ->    ~Mask u<= x
     DstPred = ICmpInst::Predicate::ICMP_ULE;
     break;
   case ICmpInst::Predicate::ICMP_SLT:
-    //  x & (-1 >> y) s< x    ->    x s> (-1 >> y)
-    //  x s> x & (-1 >> y)    ->    x s> (-1 >> y)
-    if (!match(M, m_Constant())) // Can not do this fold with non-constant.
-      return nullptr;
-    if (!match(M, m_NonNegative())) // Must not have any -1 vector elements.
-      return nullptr;
+    //  x & Mask s< x [iff Mask s>= 0]
+    //    ->    x s> Mask
+    //  x & ~Mask s< ~Mask [iff ~Mask != 0]
+    //    ->    ~Mask s> x
     DstPred = ICmpInst::Predicate::ICMP_SGT;
     break;
   case ICmpInst::Predicate::ICMP_SGE:
-    //  x & (-1 >> y) s>= x    ->    x s<= (-1 >> y)
-    //  x s<= x & (-1 >> y)    ->    x s<= (-1 >> y)
-    if (!match(M, m_Constant())) // Can not do this fold with non-constant.
-      return nullptr;
-    if (!match(M, m_NonNegative())) // Must not have any -1 vector elements.
-      return nullptr;
+    //  x & Mask s>= x [iff Mask s>= 0]
+    //    ->    x s<= Mask
+    //  x & ~Mask s>= ~Mask [iff ~Mask != 0]
+    //    ->    ~Mask s<= x
     DstPred = ICmpInst::Predicate::ICMP_SLE;
     break;
-  case ICmpInst::Predicate::ICMP_SGT:
-  case ICmpInst::Predicate::ICMP_SLE:
-    return nullptr;
-  case ICmpInst::Predicate::ICMP_UGT:
-  case ICmpInst::Predicate::ICMP_ULE:
-    llvm_unreachable("Instsimplify took care of commut. variant");
-    break;
   default:
-    llvm_unreachable("All possible folds are handled.");
+    // We don't support sgt,sle
+    // ult/ugt are simplified to true/false respectively.
+    return nullptr;
   }
 
-  // The mask value may be a vector constant that has undefined elements. But it
-  // may not be safe to propagate those undefs into the new compare, so replace
-  // those elements by copying an existing, defined, and safe scalar constant.
+  Value *X, *M;
+  // Put search code in lambda for early positive returns.
+  auto IsLowBitMask = [&]() {
+    if (match(Op0, m_c_And(m_Specific(Op1), m_Value(M)))) {
+      X = Op1;
+      // Look for: x & Mask pred x
+      if (isMaskOrZero(M, /*Not=*/false, Q)) {
+        return !ICmpInst::isSigned(Pred) ||
+               (match(M, m_NonNegative()) || isKnownNonNegative(M, Q));
+      }
+
+      // Look for: x & ~Mask pred ~Mask
+      if (isMaskOrZero(X, /*Not=*/true, Q)) {
+        return !ICmpInst::isSigned(Pred) ||
+               isKnownNonZero(X, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);
+      }
+      return false;
+    }
+    if (ICmpInst::isEquality(Pred) && match(Op1, m_AllOnes()) &&
+        match(Op0, m_OneUse(m_Or(m_Value(X), m_Value(M))))) {
+
+      auto Check = [&]() {
+        // Look for: ~x | Mask == -1
+        if (isMaskOrZero(M, /*Not=*/false, Q)) {
+          if (Value *NotX =
+                  IC.getFreelyInverted(X, X->hasOneUse(), &IC.Builder)) {
+            X = NotX;
+            return true;
+          }
+        }
+        return false;
+      };
+      if (Check())
+        return true;
+      std::swap(X, M);
+      return Check();
+    }
+    if (ICmpInst::isEquality(Pred) && match(Op1, m_Zero()) &&
+        match(Op0, m_OneUse(m_And(m_Value(X), m_Value(M))))) {
+      auto Check = [&]() {
+        // Look for: x & ~Mask == 0
+        if (isMaskOrZero(M, /*Not=*/true, Q)) {
+          if (Value *NotM =
+                  IC.getFreelyInverted(M, M->hasOneUse(), &IC.Builder)) {
+            M = NotM;
+            return true;
+          }
+        }
+        return false;
+      };
+      if (Check())
+        return true;
+      std::swap(X, M);
+      return Check();
+    }
+    return false;
+  };
+
+  if (!IsLowBitMask())
+    return nullptr;
+
+  // The mask value may be a vector constant that has undefined elements. But
+  // it may not be safe to propagate those undefs into the new compare, so
+  // replace those elements by copying an existing, defined, and safe scalar
+  // constant.
   Type *OpTy = M->getType();
   auto *VecC = dyn_cast<Constant>(M);
   auto *OpVTy = dyn_cast<FixedVectorType>(OpTy);
@@ -4280,8 +4326,6 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst::Predicate Pred, Value *Op0,
     M = Constant::replaceUndefsWith(VecC, SafeReplacementConstant);
   }
 
-  if (NeedsNot)
-    M = IC.Builder.CreateNot(M);
   return IC.Builder.CreateICmp(DstPred, X, M);
 }
 
diff --git a/llvm/test/Transforms/InstCombine/icmp-and-lowbit-mask.ll b/llvm/test/Transforms/InstCombine/icmp-and-lowbit-mask.ll
index 410b6c29187b2..5de3e89d7027a 100644
--- a/llvm/test/Transforms/InstCombine/icmp-and-lowbit-mask.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-and-lowbit-mask.ll
@@ -470,6 +470,24 @@ define i1 @src_is_notmask_x_xor_neg_x(i8 %x_in, i8 %y, i1 %cond) {
   ret i1 %r
 }
 
+define i1 @src_is_notmask_x_xor_neg_x_inv(i8 %x_in, i8 %y, i1 %cond) {
+; CHECK-LABEL: @src_is_notmask_x_xor_neg_x_inv(
+; CHECK-NEXT:    [[X:%.*]] = xor i8 [[X_IN:%.*]], 123
+; CHECK-NEXT:    [[NEG_Y:%.*]] = add i8 [[Y:%.*]], -1
+; CHECK-NEXT:    [[NOTMASK0:%.*]] = xor i8 [[NEG_Y]], [[Y]]
+; CHECK-NEXT:    [[TMP3:%.*]] = select i1 [[COND:%.*]], i8 [[NOTMASK0]], i8 7
+; CHECK-NEXT:    [[R:%.*]] = icmp ule i8 [[X]], [[TMP3]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %x = xor i8 %x_in, 123
+  %neg_y = sub i8 0, %y
+  %nmask0 = xor i8 %y, %neg_y
+  %notmask = select i1 %cond, i8 %nmask0, i8 -8
+  %and = and i8 %notmask, %x
+  %r = icmp eq i8 %and, 0
+  ret i1 %r
+}
+
 define i1 @src_is_notmask_shl_fail_multiuse_invert(i8 %x_in, i8 %y, i1 %cond) {
 ; CHECK-LABEL: @src_is_notmask_shl_fail_multiuse_invert(
 ; CHECK-NEXT:    [[X:%.*]] = xor i8 [[X_IN:%.*]], 122
@@ -655,3 +673,238 @@ define i1 @src_is_mask_const_sge(i8 %x_in) {
   %r = icmp sge i8 %and, %x
   ret i1 %r
 }
+
+define i1 @src_x_and_mask_slt(i8 %x, i8 %y, i1 %cond) {
+; CHECK-LABEL: @src_x_and_mask_slt(
+; CHECK-NEXT:    [[MASK0:%.*]] = lshr i8 -1, [[Y:%.*]]
+; CHECK-NEXT:    [[MASK:%.*]] = select i1 [[COND:%.*]], i8 [[MASK0]], i8 0
+; CHECK-NEXT:    [[MASK_POS:%.*]] = icmp sgt i8 [[MASK]], -1
+; CHECK-NEXT:    call void @llvm.assume(i1 [[MASK_POS]])
+; CHECK-NEXT:    [[R:%.*]] = icmp slt i8 [[MASK]], [[X:%.*]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %mask0 = lshr i8 -1, %y
+  %mask = select i1 %cond, i8 %mask0, i8 0
+  %mask_pos = icmp sge i8 %mask, 0
+  call void @llvm.assume(i1 %mask_pos)
+  %and = and i8 %x, %mask
+  %r = icmp slt i8 %and, %x
+  ret i1 %r
+}
+
+define i1 @src_x_and_mask_sge(i8 %x, i8 %y, i1 %cond) {
+; CHECK-LABEL: @src_x_and_mask_sge(
+; CHECK-NEXT:    [[MASK0:%.*]] = lshr i8 -1, [[Y:%.*]]
+; CHECK-NEXT:    [[MASK:%.*]] = select i1 [[COND:%.*]], i8 [[MASK0]], i8 0
+; CHECK-NEXT:    [[MASK_POS:%.*]] = icmp sgt i8 [[MASK]], -1
+; CHECK-NEXT:    call void @llvm.assume(i1 [[MASK_POS]])
+; CHECK-NEXT:    [[R:%.*]] = icmp sge i8 [[MASK]], [[X:%.*]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %mask0 = lshr i8 -1, %y
+  %mask = select i1 %cond, i8 %mask0, i8 0
+  %mask_pos = icmp sge i8 %mask, 0
+  call void @llvm.assume(i1 %mask_pos)
+  %and = and i8 %x, %mask
+  %r = icmp sge i8 %and, %x
+  ret i1 %r
+}
+
+define i1 @src_x_and_mask_slt_fail_maybe_neg(i8 %x, i8 %y, i1 %cond) {
+; CHECK-LABEL: @src_x_and_mask_slt_fail_maybe_neg(
+; CHECK-NEXT:    [[MASK0:%.*]] = lshr i8 -1, [[Y:%.*]]
+; CHECK-NEXT:    [[MASK:%.*]] = select i1 [[COND:%.*]], i8 [[MASK0]], i8 0
+; CHECK-NEXT:    [[AND:%.*]] = and i8 [[MASK]], [[X:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = icmp slt i8 [[AND]], [[X]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %mask0 = lshr i8 -1, %y
+  %mask = select i1 %cond, i8 %mask0, i8 0
+  %and = and i8 %x, %mask
+  %r = icmp slt i8 %and, %x
+  ret i1 %r
+}
+
+define i1 @src_x_and_mask_sge_fail_maybe_neg(i8 %x, i8 %y, i1 %cond) {
+; CHECK-LABEL: @src_x_and_mask_sge_fail_maybe_neg(
+; CHECK-NEXT:    [[MASK0:%.*]] = lshr i8 -1, [[Y:%.*]]
+; CHECK-NEXT:    [[MASK:%.*]] = select i1 [[COND:%.*]], i8 [[MASK0]], i8 0
+; CHECK-NEXT:    [[AND:%.*]] = and i8 [[MASK]], [[X:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = icmp sge i8 [[AND]], [[X]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %mask0 = lshr i8 -1, %y
+  %mask = select i1 %cond, i8 %mask0, i8 0
+  %and = and i8 %x, %mask
+  %r = icmp sge i8 %and, %x
+  ret i1 %r
+}
+
+define i1 @src_x_and_nmask_eq(i8 %x, i8 %y, i1 %cond) {
+; CHECK-LABEL: @src_x_and_nmask_eq(
+; CHECK-NEXT:    [[NOT_MASK0:%.*]] = shl nsw i8 -1, [[Y:%.*]]
+; CHECK-NEXT:    [[R1:%.*]] = icmp ule i8 [[NOT_MASK0]], [[X:%.*]]
+; CHECK-NEXT:    [[NOT_COND:%.*]] = xor i1 [[COND:%.*]], true
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[NOT_COND]], i1 true, i1 [[R1]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %not_mask0 = shl i8 -1, %y
+  %not_mask = select i1 %cond, i8 %not_mask0, i8 0
+  %and = and i8 %x, %not_mask
+  %r = icmp eq i8 %not_mask, %and
+  ret i1 %r
+}
+
+define i1 @src_x_and_nmask_ne(i8 %x, i8 %y, i1 %cond) {
+; CHECK-LABEL: @src_x_and_nmask_ne(
+; CHECK-NEXT:    [[NOT_MASK0:%.*]] = shl nsw i8 -1, [[Y:%.*]]
+; CHECK-NEXT:    [[R1:%.*]] = icmp ugt i8 [[NOT_MASK0]], [[X:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[COND:%.*]], i1 [[R1]], i1 false
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %not_mask0 = shl i8 -1, %y
+  %not_mask = select i1 %cond, i8 %not_mask0, i8 0
+  %and = and i8 %x, %not_mask
+  %r = icmp ne i8 %and, %not_mask
+  ret i1 %r
+}
+
+define i1 @src_x_and_nmask_ult(i8 %x, i8 %y, i1 %cond) {
+; CHECK-LABEL: @src_x_and_nmask_ult(
+; CHECK-NEXT:    [[NOT_MASK0:%.*]] = shl nsw i8 -1, [[Y:%.*]]
+; CHECK-NEXT:    [[R1:%.*]] = icmp ugt i8 [[NOT_MASK0]], [[X:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[COND:%.*]], i1 [[R1]], i1 false
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %not_mask0 = shl i8 -1, %y
+  %not_mask = select i1 %cond, i8 %not_mask0, i8 0
+  %and = and i8 %x, %not_mask
+  %r = icmp ult i8 %and, %not_mask
+  ret i1 %r
+}
+
+define i1 @src_x_and_nmask_uge(i8 %x, i8 %y, i1 %cond) {
+; CHECK-LABEL: @src_x_and_nmask_uge(
+; CHECK-NEXT:    [[NOT_MASK0:%.*]] = shl nsw i8 -1, [[Y:%.*]]
+; CHECK-NEXT:    [[R1:%.*]] = icmp ule i8 [[NOT_MASK0]], [[X:%.*]]
+; CHECK-NEXT:    [[NOT_COND:%.*]] = xor i1 [[COND:%.*]], true
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[NOT_COND]], i1 true, i1 [[R1]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %not_mask0 = shl i8 -1, %y
+  %not_mask = select i1 %cond, i8 %not_mask0, i8 0
+  %and = and i8 %x, %not_mask
+  %r = icmp uge i8 %and, %not_mask
+  ret i1 %r
+}
+
+define i1 @src_x_and_nmask_slt(i8 %x, i8 %y) {
+; CHECK-LABEL: @src_x_and_nmask_slt(
+; CHECK-NEXT:    [[NOT_MASK:%.*]] = shl nsw i8 -1, [[Y:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = icmp sgt i8 [[NOT_MASK]], [[X:%.*]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %not_mask = shl i8 -1, %y
+  %and = and i8 %x, %not_mask
+  %r = icmp slt i8 %and, %not_mask
+  ret i1 %r
+}
+
+define i1 @src_x_and_nmask_sge(i8 %x, i8 %y) {
+; CHECK-LABEL: @src_x_and_nmask_sge(
+; CHECK-NEXT:    [[NOT_MASK:%.*]] = shl nsw i8 -1, [[Y:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = icmp sle i8 [[NOT_MASK]], [[X:%.*]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %not_mask = shl i8 -1, %y
+  %and = and i8 %x, %not_mask
+  %r = icmp sge i8 %and, %not_mask
+  ret i1 %r
+}
+
+define i1 @src_x_and_nmask_slt_fail_maybe_z(i8 %x, i8 %y, i1 %cond) {
+; CHECK-LABEL: @src_x_and_nmask_slt_fail_maybe_z(
+; CHECK-NEXT:    [[NOT_MASK0:%.*]] = shl nsw i8 -1, [[Y:%.*]]
+; CHECK-NEXT:    [[NOT_MASK:%.*]] = select i1 [[COND:%.*]], i8 [[NOT_MASK0]], i8 0
+; CHECK-NEXT:    [[AND:%.*]] = and i8 [[NOT_MASK]], [[X:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = icmp slt i8 [[AND]], [[NOT_MASK]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %not_mask0 = shl i8 -1, %y
+  %not_mask = select i1 %cond, i8 %not_mask0, i8 0
+  %and = and i8 %x, %not_mask
+  %r = icmp slt i8 %and, %not_mask
+  ret i1 %r
+}
+
+define i1 @src_x_and_nmask_sge_fail_maybe_z(i8 %x, i8 %y, i1 %cond) {
+; CHECK-LABEL: @src_x_and_nmask_sge_fail_maybe_z(
+; CHECK-NEXT:    [[NOT_MASK0:%.*]] = shl nsw i8 -1, [[Y:%.*]]
+; CHECK-NEXT:    [[NOT_MASK:%.*]] = select i1 [[COND:%.*]], i8 [[NOT_MASK0]], i8 0
+; CHECK-NEXT:    [[AND:%.*]] = and i8 [[NOT_MASK]], [[X:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = icmp sge i8 [[AND]], [[NOT_MASK]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %not_mask0 = shl i8 -1, %y
+  %not_mask = select i1 %cond, i8 %not_mask0, i8 0
+  %and = and i8 %x, %not_mask
+  %r = icmp sge i8 %and, %not_mask
+  ret i1 %r
+}
+
+define i1 @src_x_or_mask_eq(i8 %x, i8 %y, i8 %z, i1 %c2, i1 %cond) {
+; CHECK-LABEL: @src_x_or_mask_eq(
+; CHECK-NEXT:    [[MASK0:%.*]] = lshr i8 -1, [[Y:%.*]]
+; CHECK-NEXT:    [[MASK:%.*]] = select i1 [[COND:%.*]], i8 [[MASK0]], i8 0
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i8 [[X:%.*]], -124
+; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[C2:%.*]], i8 [[TMP1]], i8 -46
+; CHECK-NEXT:    [[TMP3:%.*]] = call i8 @llvm.umax.i8(i8 [[Z:%.*]], i8 [[TMP2]])
+; CHECK-NEXT:    [[TMP4:%.*]] = add i8 [[TMP3]], -12
+; CHECK-NEXT:    [[R:%.*]] = icmp ule i8 [[TMP4]], [[MASK]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %mask0 = lshr i8 -1, %y
+  %mask = select i1 %cond, i8 %mask0, i8 0
+  %nx = xor i8 %x, 123
+  %nx_c = select i1 %c2, i8 %nx, i8 45
+  %nz = xor i8 %z, -1
+  %nx_cc = call i8 @llvm.umin.i8(i8 %nz, i8 %nx_c)
+  %nx_ccc = add i8 %nx_cc, 12
+  %or = or i8 %nx_ccc, %mask
+  %r = icmp eq i8 %or, -1
+  ret i1 %r
+}
+
+define i1 @src_x_or_mask_ne(i8 %x, i8 %y, i1 %cond) {
+; CHECK-LABEL: @src_x_or_mask_ne(
+; CHECK-NEXT:    [[MASK0:%.*]] = lshr i8 -1, [[Y:%.*]]
+; CHECK-NEXT:    [[MASK:%.*]] = select i1 [[COND:%.*]], i8 [[MASK0]], i8 0
+; CHECK-NEXT:    [[R:%.*]] = icmp ult i8 [[MASK]], [[X:%.*]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %mask0 = lshr i8 -1, %y
+  %mask = select i1 %cond, i8 %mask0, i8 0
+  %nx = xor i8 %x, -1
+  %or = or i8 %mask, %nx
+  %r = icmp ne i8 %or, -1
+  ret i1 %r
+}
+
+define i1 @src_x_or_mask_ne_fail_multiuse(i8 %x, i8 %y, i1 %cond) {
+; CHECK-LABEL: @src_x_or_mask_ne_fail_multiuse(
+; CHECK-NEXT:    [[MASK0:%.*]] = lshr i8 -1, [[Y:%.*]]
+; CHECK-NEXT:    [[MASK:%.*]] = select i1 [[COND:%.*]], i8 [[MASK0]], i8 0
+; CHECK-NEXT:    [[NX:%.*]] = xor i8 [[X:%.*]], -1
+; CHECK-NEXT:    [[OR:%.*]] = or i8 [[MASK]], [[NX]]
+; CHECK-NEXT:    call void @use.i8(i8 [[OR]])
+; CHECK-NEXT:    [[R:%.*]] = icmp ne i8 [[OR]], -1
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %mask0 = lshr i8 -1, %y
+  %mask = select i1 %cond, i8 %mask0, i8 0
+  %nx = xor i8 %x, -1
+  %or = or i8 %mask, %nx
+  call void @use.i8(i8 %or)
+  %r = icmp ne i8 %or, -1
+  ret i1 %r
+}
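Note (illustration only, not part of the patch): a minimal standalone IR sketch of the new `~x | Mask == -1` pattern handled above; the function and value names here are hypothetical. Since `~x | Mask == -1` is equivalent to `x & ~Mask == 0`, the compare is expected to fold to an unsigned `x u<= Mask` check, roughly:

define i1 @example_not_x_or_mask(i8 %x, i8 %y) {
  %mask = lshr i8 -1, %y       ; low-bit mask (-1 >> y)
  %not_x = xor i8 %x, -1       ; ~x
  %or = or i8 %not_x, %mask    ; ~x | Mask
  %r = icmp eq i8 %or, -1      ; expected to fold to: icmp ule i8 %x, %mask
  ret i1 %r
}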