diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h index 9193b5f8994e0..b6a659dde1b15 100644 --- a/llvm/include/llvm/ADT/APInt.h +++ b/llvm/include/llvm/ADT/APInt.h @@ -549,16 +549,22 @@ class [[nodiscard]] APInt { /// \returns the low "numBits" bits of this APInt. LLVM_ABI APInt getLoBits(unsigned numBits) const; - /// Determine if two APInts have the same value, after zero-extending - /// one of them (if needed!) to ensure that the bit-widths match. - static bool isSameValue(const APInt &I1, const APInt &I2) { + /// Determine if two APInts have the same value, after zero-extending or + /// sign-extending (if \p SignedCompare) one of them (if needed!) to ensure + /// that the bit-widths match. + static bool isSameValue(const APInt &I1, const APInt &I2, + bool SignedCompare = false) { if (I1.getBitWidth() == I2.getBitWidth()) return I1 == I2; + auto ZExtOrSExt = [SignedCompare](const APInt &I, unsigned BitWidth) { + return SignedCompare ? I.sext(BitWidth) : I.zext(BitWidth); + }; + if (I1.getBitWidth() > I2.getBitWidth()) - return I1 == I2.zext(I1.getBitWidth()); + return I1 == ZExtOrSExt(I2, I1.getBitWidth()); - return I1.zext(I2.getBitWidth()) == I2; + return ZExtOrSExt(I1, I2.getBitWidth()) == I2; } /// Overload to compute a hash_code for an APInt value. diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp index 271d17cb29905..9bd283953f733 100644 --- a/llvm/unittests/ADT/APIntTest.cpp +++ b/llvm/unittests/ADT/APIntTest.cpp @@ -14,7 +14,6 @@ #include "llvm/Support/Alignment.h" #include "gtest/gtest.h" #include -#include #include #include @@ -29,6 +28,28 @@ TEST(APIntTest, ValueInit) { EXPECT_TRUE(!Zero.sext(64)); } +TEST(APIntTest, IsSameValue) { + APInt One8(8, 1, /*isSigned=*/false); + APInt Three4(4, 3, /*isSigned=*/false); + EXPECT_FALSE(APInt::isSameValue(One8, Three4, /*SignedCompare=*/false)); + EXPECT_FALSE(APInt::isSameValue(One8, Three4, /*SignedCompare=*/true)); + + APInt Two8(8, 2, /*isSigned=*/false); + APInt Two4(4, 2, /*isSigned=*/false); + EXPECT_TRUE(APInt::isSameValue(Two8, Two4, /*SignedCompare=*/false)); + EXPECT_TRUE(APInt::isSameValue(Two8, Two4, /*SignedCompare=*/true)); + + APInt Seven8(8, 7, /*isSigned=*/false); + APInt Seven3(3, 7, /*isSigned=*/false); + EXPECT_TRUE(APInt::isSameValue(Seven8, Seven3, /*SignedCompare=*/false)); + EXPECT_FALSE(APInt::isSameValue(Seven8, Seven3, /*SignedCompare=*/true)); + + APInt Ones8 = APInt::getAllOnes(8); + APInt Ones4 = APInt::getAllOnes(4); + EXPECT_FALSE(APInt::isSameValue(Ones8, Ones4, /*SignedCompare=*/false)); + EXPECT_TRUE(APInt::isSameValue(Ones8, Ones4, /*SignedCompare=*/true)); +} + // Test that 0^5 == 0 TEST(APIntTest, PowZeroTo5) { APInt Zero = APInt::getZero(32);