diff --git a/include/llvm/ADT/APSInt.h b/include/llvm/ADT/APSInt.h index a6693f7992c..91ccda22f2f 100644 --- a/include/llvm/ADT/APSInt.h +++ b/include/llvm/ADT/APSInt.h @@ -62,6 +62,12 @@ public: } using APInt::toString; + /// \brief Get the correctly-extended \c int64_t value. + int64_t getExtValue() const { + assert(getMinSignedBits() <= 64 && "Too many bits for int64_t"); + return isSigned() ? getSExtValue() : getZExtValue(); + } + APSInt LLVM_ATTRIBUTE_UNUSED_RESULT trunc(uint32_t width) const { return APSInt(APInt::trunc(width), IsUnsigned); } @@ -133,14 +139,27 @@ public: assert(IsUnsigned == RHS.IsUnsigned && "Signedness mismatch!"); return eq(RHS); } - inline bool operator==(int64_t RHS) const { - return isSameValue(*this, APSInt(APInt(64, RHS), true)); - } inline bool operator!=(const APSInt& RHS) const { return !((*this) == RHS); } - inline bool operator!=(int64_t RHS) const { - return !((*this) == RHS); + + bool operator==(int64_t RHS) const { + return compareValues(*this, get(RHS)) == 0; + } + bool operator!=(int64_t RHS) const { + return compareValues(*this, get(RHS)) != 0; + } + bool operator<=(int64_t RHS) const { + return compareValues(*this, get(RHS)) <= 0; + } + bool operator>=(int64_t RHS) const { + return compareValues(*this, get(RHS)) >= 0; + } + bool operator<(int64_t RHS) const { + return compareValues(*this, get(RHS)) < 0; + } + bool operator>(int64_t RHS) const { + return compareValues(*this, get(RHS)) > 0; } // The remaining operators just wrap the logic of APInt, but retain the @@ -260,37 +279,49 @@ public: /// \brief Determine if two APSInts have the same value, zero- or /// sign-extending as needed. static bool isSameValue(const APSInt &I1, const APSInt &I2) { + return !compareValues(I1, I2); + } + + /// \brief Compare underlying values of two numbers. + static int compareValues(const APSInt &I1, const APSInt &I2) { if (I1.getBitWidth() == I2.getBitWidth() && I1.isSigned() == I2.isSigned()) - return I1 == I2; + return I1 == I2 ? 0 : I1 > I2 ? 1 : -1; // Check for a bit-width mismatch. if (I1.getBitWidth() > I2.getBitWidth()) - return isSameValue(I1, I2.extend(I1.getBitWidth())); + return compareValues(I1, I2.extend(I1.getBitWidth())); else if (I2.getBitWidth() > I1.getBitWidth()) - return isSameValue(I1.extend(I2.getBitWidth()), I2); - - assert(I1.isSigned() != I2.isSigned()); + return compareValues(I1.extend(I2.getBitWidth()), I2); // We have a signedness mismatch. Check for negative values and do an - // unsigned compare if signs match. - if ((I1.isSigned() && I1.isNegative()) || - (!I1.isSigned() && I2.isNegative())) - return false; + // unsigned compare if both are positive. + if (I1.isSigned()) { + assert(!I2.isSigned() && "Expected signed mismatch"); + if (I1.isNegative()) + return -1; + } else { + assert(I2.isSigned() && "Expected signed mismatch"); + if (I2.isNegative()) + return 1; + } - return I1.eq(I2); + return I1.eq(I2) ? 0 : I1.ugt(I2) ? 1 : -1; } + static APSInt get(int64_t X) { return APSInt(APInt(64, X), false); } + static APSInt getUnsigned(uint64_t X) { return APSInt(APInt(64, X), true); } + /// Profile - Used to insert APSInt objects, or objects that contain APSInt /// objects, into FoldingSets. void Profile(FoldingSetNodeID& ID) const; }; -inline bool operator==(int64_t V1, const APSInt& V2) { - return V2 == V1; -} -inline bool operator!=(int64_t V1, const APSInt& V2) { - return V2 != V1; -} +inline bool operator==(int64_t V1, const APSInt &V2) { return V2 == V1; } +inline bool operator!=(int64_t V1, const APSInt &V2) { return V2 != V1; } +inline bool operator<=(int64_t V1, const APSInt &V2) { return V2 >= V1; } +inline bool operator>=(int64_t V1, const APSInt &V2) { return V2 <= V1; } +inline bool operator<(int64_t V1, const APSInt &V2) { return V2 > V1; } +inline bool operator>(int64_t V1, const APSInt &V2) { return V2 < V1; } inline raw_ostream &operator<<(raw_ostream &OS, const APSInt &I) { I.print(OS, I.isSigned()); diff --git a/unittests/ADT/APSIntTest.cpp b/unittests/ADT/APSIntTest.cpp index eef9c8a90af..5e4e87440bd 100644 --- a/unittests/ADT/APSIntTest.cpp +++ b/unittests/ADT/APSIntTest.cpp @@ -41,4 +41,106 @@ TEST(APSIntTest, MoveTest) { EXPECT_EQ(Bits, A.getRawData()); // Verify that "Wide" was really moved. } +TEST(APSIntTest, get) { + EXPECT_TRUE(APSInt::get(7).isSigned()); + EXPECT_EQ(64u, APSInt::get(7).getBitWidth()); + EXPECT_EQ(7u, APSInt::get(7).getZExtValue()); + EXPECT_EQ(7, APSInt::get(7).getSExtValue()); + EXPECT_TRUE(APSInt::get(-7).isSigned()); + EXPECT_EQ(64u, APSInt::get(-7).getBitWidth()); + EXPECT_EQ(-7, APSInt::get(-7).getSExtValue()); + EXPECT_EQ(UINT64_C(0) - 7, APSInt::get(-7).getZExtValue()); +} + +TEST(APSIntTest, getUnsigned) { + EXPECT_TRUE(APSInt::getUnsigned(7).isUnsigned()); + EXPECT_EQ(64u, APSInt::getUnsigned(7).getBitWidth()); + EXPECT_EQ(7u, APSInt::getUnsigned(7).getZExtValue()); + EXPECT_EQ(7, APSInt::getUnsigned(7).getSExtValue()); + EXPECT_TRUE(APSInt::getUnsigned(-7).isUnsigned()); + EXPECT_EQ(64u, APSInt::getUnsigned(-7).getBitWidth()); + EXPECT_EQ(-7, APSInt::getUnsigned(-7).getSExtValue()); + EXPECT_EQ(UINT64_C(0) - 7, APSInt::getUnsigned(-7).getZExtValue()); +} + +TEST(APSIntTest, getExtValue) { + EXPECT_TRUE(APSInt(APInt(3, 7), true).isUnsigned()); + EXPECT_TRUE(APSInt(APInt(3, 7), false).isSigned()); + EXPECT_TRUE(APSInt(APInt(4, 7), true).isUnsigned()); + EXPECT_TRUE(APSInt(APInt(4, 7), false).isSigned()); + EXPECT_TRUE(APSInt(APInt(4, -7), true).isUnsigned()); + EXPECT_TRUE(APSInt(APInt(4, -7), false).isSigned()); + EXPECT_EQ(7, APSInt(APInt(3, 7), true).getExtValue()); + EXPECT_EQ(-1, APSInt(APInt(3, 7), false).getExtValue()); + EXPECT_EQ(7, APSInt(APInt(4, 7), true).getExtValue()); + EXPECT_EQ(7, APSInt(APInt(4, 7), false).getExtValue()); + EXPECT_EQ(9, APSInt(APInt(4, -7), true).getExtValue()); + EXPECT_EQ(-7, APSInt(APInt(4, -7), false).getExtValue()); +} + +TEST(APSIntTest, compareValues) { + auto U = [](uint64_t V) { return APSInt::getUnsigned(V); }; + auto S = [](int64_t V) { return APSInt::get(V); }; + + // Bit-width matches and is-signed. + EXPECT_TRUE(APSInt::compareValues(S(7), S(8)) < 0); + EXPECT_TRUE(APSInt::compareValues(S(8), S(7)) > 0); + EXPECT_TRUE(APSInt::compareValues(S(7), S(7)) == 0); + EXPECT_TRUE(APSInt::compareValues(S(-7), S(8)) < 0); + EXPECT_TRUE(APSInt::compareValues(S(8), S(-7)) > 0); + EXPECT_TRUE(APSInt::compareValues(S(-7), S(-7)) == 0); + EXPECT_TRUE(APSInt::compareValues(S(-7), S(-8)) > 0); + EXPECT_TRUE(APSInt::compareValues(S(-8), S(-7)) < 0); + EXPECT_TRUE(APSInt::compareValues(S(-7), S(-7)) == 0); + + // Bit-width matches and not is-signed. + EXPECT_TRUE(APSInt::compareValues(U(7), U(8)) < 0); + EXPECT_TRUE(APSInt::compareValues(U(8), U(7)) > 0); + EXPECT_TRUE(APSInt::compareValues(U(7), U(7)) == 0); + + // Bit-width matches and mixed signs. + EXPECT_TRUE(APSInt::compareValues(U(7), S(8)) < 0); + EXPECT_TRUE(APSInt::compareValues(U(8), S(7)) > 0); + EXPECT_TRUE(APSInt::compareValues(U(7), S(7)) == 0); + EXPECT_TRUE(APSInt::compareValues(U(8), S(-7)) > 0); + + // Bit-width mismatch and is-signed. + EXPECT_TRUE(APSInt::compareValues(S(7).trunc(32), S(8)) < 0); + EXPECT_TRUE(APSInt::compareValues(S(8).trunc(32), S(7)) > 0); + EXPECT_TRUE(APSInt::compareValues(S(7).trunc(32), S(7)) == 0); + EXPECT_TRUE(APSInt::compareValues(S(-7).trunc(32), S(8)) < 0); + EXPECT_TRUE(APSInt::compareValues(S(8).trunc(32), S(-7)) > 0); + EXPECT_TRUE(APSInt::compareValues(S(-7).trunc(32), S(-7)) == 0); + EXPECT_TRUE(APSInt::compareValues(S(-7).trunc(32), S(-8)) > 0); + EXPECT_TRUE(APSInt::compareValues(S(-8).trunc(32), S(-7)) < 0); + EXPECT_TRUE(APSInt::compareValues(S(-7).trunc(32), S(-7)) == 0); + EXPECT_TRUE(APSInt::compareValues(S(7), S(8).trunc(32)) < 0); + EXPECT_TRUE(APSInt::compareValues(S(8), S(7).trunc(32)) > 0); + EXPECT_TRUE(APSInt::compareValues(S(7), S(7).trunc(32)) == 0); + EXPECT_TRUE(APSInt::compareValues(S(-7), S(8).trunc(32)) < 0); + EXPECT_TRUE(APSInt::compareValues(S(8), S(-7).trunc(32)) > 0); + EXPECT_TRUE(APSInt::compareValues(S(-7), S(-7).trunc(32)) == 0); + EXPECT_TRUE(APSInt::compareValues(S(-7), S(-8).trunc(32)) > 0); + EXPECT_TRUE(APSInt::compareValues(S(-8), S(-7).trunc(32)) < 0); + EXPECT_TRUE(APSInt::compareValues(S(-7), S(-7).trunc(32)) == 0); + + // Bit-width mismatch and not is-signed. + EXPECT_TRUE(APSInt::compareValues(U(7), U(8).trunc(32)) < 0); + EXPECT_TRUE(APSInt::compareValues(U(8), U(7).trunc(32)) > 0); + EXPECT_TRUE(APSInt::compareValues(U(7), U(7).trunc(32)) == 0); + EXPECT_TRUE(APSInt::compareValues(U(7).trunc(32), U(8)) < 0); + EXPECT_TRUE(APSInt::compareValues(U(8).trunc(32), U(7)) > 0); + EXPECT_TRUE(APSInt::compareValues(U(7).trunc(32), U(7)) == 0); + + // Bit-width mismatch and mixed signs. + EXPECT_TRUE(APSInt::compareValues(U(7).trunc(32), S(8)) < 0); + EXPECT_TRUE(APSInt::compareValues(U(8).trunc(32), S(7)) > 0); + EXPECT_TRUE(APSInt::compareValues(U(7).trunc(32), S(7)) == 0); + EXPECT_TRUE(APSInt::compareValues(U(8).trunc(32), S(-7)) > 0); + EXPECT_TRUE(APSInt::compareValues(U(7), S(8).trunc(32)) < 0); + EXPECT_TRUE(APSInt::compareValues(U(8), S(7).trunc(32)) > 0); + EXPECT_TRUE(APSInt::compareValues(U(7), S(7).trunc(32)) == 0); + EXPECT_TRUE(APSInt::compareValues(U(8), S(-7).trunc(32)) > 0); +} + }