From 998ae69abe6f1651ddca6046e51ff987b55e79c5 Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Fri, 26 Dec 2014 09:10:14 +0000 Subject: [PATCH] InstCombe: Infer nsw for multiplies We already utilize this logic for reducing overflow intrinsics, it makes sense to reuse it for normal multiplies as well. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@224847 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Transforms/InstCombine/InstCombine.h | 5 + .../InstCombine/InstCombineAddSub.cpp | 119 ++++++------------ .../InstCombine/InstCombineAndOrXor.cpp | 3 +- .../InstCombine/InstCombineMulDivRem.cpp | 47 +++++++ test/Transforms/InstCombine/mul.ll | 12 ++ 5 files changed, 100 insertions(+), 86 deletions(-) diff --git a/lib/Transforms/InstCombine/InstCombine.h b/lib/Transforms/InstCombine/InstCombine.h index 326bf8f726d..d6eb6d42d57 100644 --- a/lib/Transforms/InstCombine/InstCombine.h +++ b/lib/Transforms/InstCombine/InstCombine.h @@ -382,6 +382,11 @@ public: Instruction *CxtI = nullptr) const { return llvm::ComputeNumSignBits(Op, DL, Depth, AT, CxtI, DT); } + void ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, + unsigned Depth = 0, Instruction *CxtI = nullptr) const { + return llvm::ComputeSignBit(V, KnownZero, KnownOne, DL, Depth, AT, CxtI, + DT); + } private: /// SimplifyAssociativeOrCommutative - This performs a few simplifications for diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp index c2d2eec3fc3..9ea4bc57cf2 100644 --- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -890,7 +890,6 @@ static bool checkRippleForAdd(const APInt &Op0KnownZero, /// (sext (add LHS, RHS)) === (add (sext LHS), (sext RHS)) /// This basically requires proving that the add in the original type would not /// overflow to change the sign bit or have a carry out. -/// TODO: Handle this for Vectors. bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS, Instruction *CxtI) { // There are different heuristics we can use for this. Here are some simple @@ -914,28 +913,27 @@ bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS, ComputeNumSignBits(RHS, 0, CxtI) > 1) return true; - if (IntegerType *IT = dyn_cast(LHS->getType())) { - int BitWidth = IT->getBitWidth(); - APInt LHSKnownZero(BitWidth, 0); - APInt LHSKnownOne(BitWidth, 0); - computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, CxtI); + unsigned BitWidth = LHS->getType()->getScalarSizeInBits(); + APInt LHSKnownZero(BitWidth, 0); + APInt LHSKnownOne(BitWidth, 0); + computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, CxtI); - APInt RHSKnownZero(BitWidth, 0); - APInt RHSKnownOne(BitWidth, 0); - computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, CxtI); + APInt RHSKnownZero(BitWidth, 0); + APInt RHSKnownOne(BitWidth, 0); + computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, CxtI); - // Addition of two 2's compliment numbers having opposite signs will never - // overflow. - if ((LHSKnownOne[BitWidth - 1] && RHSKnownZero[BitWidth - 1]) || - (LHSKnownZero[BitWidth - 1] && RHSKnownOne[BitWidth - 1])) - return true; + // Addition of two 2's compliment numbers having opposite signs will never + // overflow. + if ((LHSKnownOne[BitWidth - 1] && RHSKnownZero[BitWidth - 1]) || + (LHSKnownZero[BitWidth - 1] && RHSKnownOne[BitWidth - 1])) + return true; + + // Check if carry bit of addition will not cause overflow. + if (checkRippleForAdd(LHSKnownZero, RHSKnownZero)) + return true; + if (checkRippleForAdd(RHSKnownZero, LHSKnownZero)) + return true; - // Check if carry bit of addition will not cause overflow. - if (checkRippleForAdd(LHSKnownZero, RHSKnownZero)) - return true; - if (checkRippleForAdd(RHSKnownZero, LHSKnownZero)) - return true; - } return false; } @@ -947,8 +945,8 @@ bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS, // If the sign bit of LHS and that of RHS are both zero, no unsigned wrap. bool LHSKnownNonNegative, LHSKnownNegative; bool RHSKnownNonNegative, RHSKnownNegative; - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, 0, AT, CxtI, DT); - ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, 0, AT, CxtI, DT); + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, /*Depth=*/0, CxtI); + ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, /*Depth=*/0, CxtI); if (LHSKnownNonNegative && RHSKnownNonNegative) return true; @@ -968,24 +966,22 @@ bool InstCombiner::WillNotOverflowSignedSub(Value *LHS, Value *RHS, ComputeNumSignBits(RHS, 0, CxtI) > 1) return true; - if (IntegerType *IT = dyn_cast(LHS->getType())) { - unsigned BitWidth = IT->getBitWidth(); - APInt LHSKnownZero(BitWidth, 0); - APInt LHSKnownOne(BitWidth, 0); - computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, CxtI); + unsigned BitWidth = LHS->getType()->getScalarSizeInBits(); + APInt LHSKnownZero(BitWidth, 0); + APInt LHSKnownOne(BitWidth, 0); + computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, CxtI); - APInt RHSKnownZero(BitWidth, 0); - APInt RHSKnownOne(BitWidth, 0); - computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, CxtI); + APInt RHSKnownZero(BitWidth, 0); + APInt RHSKnownOne(BitWidth, 0); + computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, CxtI); - // Subtraction of two 2's compliment numbers having identical signs will - // never overflow. - if ((LHSKnownOne[BitWidth - 1] && RHSKnownOne[BitWidth - 1]) || - (LHSKnownZero[BitWidth - 1] && RHSKnownZero[BitWidth - 1])) - return true; + // Subtraction of two 2's compliment numbers having identical signs will + // never overflow. + if ((LHSKnownOne[BitWidth - 1] && RHSKnownOne[BitWidth - 1]) || + (LHSKnownZero[BitWidth - 1] && RHSKnownZero[BitWidth - 1])) + return true; - // TODO: implement logic similar to checkRippleForAdd - } + // TODO: implement logic similar to checkRippleForAdd return false; } @@ -996,59 +992,14 @@ bool InstCombiner::WillNotOverflowUnsignedSub(Value *LHS, Value *RHS, // If the LHS is negative and the RHS is non-negative, no unsigned wrap. bool LHSKnownNonNegative, LHSKnownNegative; bool RHSKnownNonNegative, RHSKnownNegative; - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, 0, AT, CxtI, DT); - ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, 0, AT, CxtI, DT); + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, /*Depth=*/0, CxtI); + ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, /*Depth=*/0, CxtI); if (LHSKnownNegative && RHSKnownNonNegative) return true; return false; } -/// \brief Return true if we can prove that: -/// (mul LHS, RHS) === (mul nsw LHS, RHS) -bool InstCombiner::WillNotOverflowSignedMul(Value *LHS, Value *RHS, - Instruction *CxtI) { - if (IntegerType *IT = dyn_cast(LHS->getType())) { - - // Multiplying n * m significant bits yields a result of n + m significant - // bits. If the total number of significant bits does not exceed the - // result bit width (minus 1), there is no overflow. - // This means if we have enough leading sign bits in the operands - // we can guarantee that the result does not overflow. - // Ref: "Hacker's Delight" by Henry Warren - unsigned BitWidth = IT->getBitWidth(); - - // Note that underestimating the number of sign bits gives a more - // conservative answer. - unsigned SignBits = ComputeNumSignBits(LHS, 0, CxtI) + - ComputeNumSignBits(RHS, 0, CxtI); - - // First handle the easy case: if we have enough sign bits there's - // definitely no overflow. - if (SignBits > BitWidth + 1) - return true; - - // There are two ambiguous cases where there can be no overflow: - // SignBits == BitWidth + 1 and - // SignBits == BitWidth - // The second case is difficult to check, therefore we only handle the - // first case. - if (SignBits == BitWidth + 1) { - // It overflows only when both arguments are negative and the true - // product is exactly the minimum negative number. - // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000 - // For simplicity we just check if at least one side is not negative. - bool LHSNonNegative, LHSNegative; - bool RHSNonNegative, RHSNegative; - ComputeSignBit(LHS, LHSNonNegative, LHSNegative, DL, 0, AT, CxtI, DT); - ComputeSignBit(RHS, RHSNonNegative, RHSNegative, DL, 0, AT, CxtI, DT); - if (LHSNonNegative || RHSNonNegative) - return true; - } - } - return false; -} - // Checks if any operand is negative and we can convert add to sub. // This function checks for following negative patterns // ADD(XOR(OR(Z, NOT(C)), C)), 1) == NEG(AND(Z, C)) diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 6a4e721db46..a5dd89cc4af 100644 --- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -886,8 +886,7 @@ Value *InstCombiner::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, // This simplification is only valid if the upper range is not negative. bool IsNegative, IsNotNegative; - ComputeSignBit(RangeEnd, IsNotNegative, IsNegative, DL, 0, AT, - Cmp1, DT); + ComputeSignBit(RangeEnd, IsNotNegative, IsNegative, /*Depth=*/0, Cmp1); if (!IsNotNegative) return nullptr; diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 3956869e2d3..5beaf00c16e 100644 --- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -123,6 +123,48 @@ static Constant *getLogBase2Vector(ConstantDataVector *CV) { return ConstantVector::get(Elts); } +/// \brief Return true if we can prove that: +/// (mul LHS, RHS) === (mul nsw LHS, RHS) +bool InstCombiner::WillNotOverflowSignedMul(Value *LHS, Value *RHS, + Instruction *CxtI) { + // Multiplying n * m significant bits yields a result of n + m significant + // bits. If the total number of significant bits does not exceed the + // result bit width (minus 1), there is no overflow. + // This means if we have enough leading sign bits in the operands + // we can guarantee that the result does not overflow. + // Ref: "Hacker's Delight" by Henry Warren + unsigned BitWidth = LHS->getType()->getScalarSizeInBits(); + + // Note that underestimating the number of sign bits gives a more + // conservative answer. + unsigned SignBits = ComputeNumSignBits(LHS, 0, CxtI) + + ComputeNumSignBits(RHS, 0, CxtI); + + // First handle the easy case: if we have enough sign bits there's + // definitely no overflow. + if (SignBits > BitWidth + 1) + return true; + + // There are two ambiguous cases where there can be no overflow: + // SignBits == BitWidth + 1 and + // SignBits == BitWidth + // The second case is difficult to check, therefore we only handle the + // first case. + if (SignBits == BitWidth + 1) { + // It overflows only when both arguments are negative and the true + // product is exactly the minimum negative number. + // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000 + // For simplicity we just check if at least one side is not negative. + bool LHSNonNegative, LHSNegative; + bool RHSNonNegative, RHSNegative; + ComputeSignBit(LHS, LHSNonNegative, LHSNegative, /*Depth=*/0, CxtI); + ComputeSignBit(RHS, RHSNonNegative, RHSNegative, /*Depth=*/0, CxtI); + if (LHSNonNegative || RHSNonNegative) + return true; + } + return false; +} + Instruction *InstCombiner::visitMul(BinaryOperator &I) { bool Changed = SimplifyAssociativeOrCommutative(I); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -333,6 +375,11 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } } + if (!I.hasNoSignedWrap() && WillNotOverflowSignedMul(Op0, Op1, &I)) { + Changed = true; + I.setHasNoSignedWrap(true); + } + return Changed ? &I : nullptr; } diff --git a/test/Transforms/InstCombine/mul.ll b/test/Transforms/InstCombine/mul.ll index a52c31ab4af..d19338a3c8f 100644 --- a/test/Transforms/InstCombine/mul.ll +++ b/test/Transforms/InstCombine/mul.ll @@ -255,3 +255,15 @@ define i32 @test28(i32 %A) { ; CHECK-NEXT: %[[shl2:.*]] = shl i32 %[[shl1]], %A ; CHECK-NEXT: ret i32 %[[shl2]] } + +define i64 @test29(i31 %A, i31 %B) { +; CHECK-LABEL: @test29( + %C = zext i31 %A to i64 + %D = zext i31 %B to i64 + %E = mul i64 %C, %D + ret i64 %E +; CHECK: %[[zext1:.*]] = zext i31 %A to i64 +; CHECK-NEXT: %[[zext2:.*]] = zext i31 %B to i64 +; CHECK-NEXT: %[[mul:.*]] = mul nsw i64 %[[zext1]], %[[zext2]] +; CHECK-NEXT: ret i64 %[[mul]] +}