InstCombe: Infer nsw for multiplies

We already utilize this logic for reducing overflow intrinsics, it makes
sense to reuse it for normal multiplies as well.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@224847 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
David Majnemer 2014-12-26 09:10:14 +00:00
parent 654a66dbd3
commit 998ae69abe
5 changed files with 100 additions and 86 deletions

View File

@ -382,6 +382,11 @@ public:
Instruction *CxtI = nullptr) const {
return llvm::ComputeNumSignBits(Op, DL, Depth, AT, CxtI, DT);
}
void ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne,
unsigned Depth = 0, Instruction *CxtI = nullptr) const {
return llvm::ComputeSignBit(V, KnownZero, KnownOne, DL, Depth, AT, CxtI,
DT);
}
private:
/// SimplifyAssociativeOrCommutative - This performs a few simplifications for

View File

@ -890,7 +890,6 @@ static bool checkRippleForAdd(const APInt &Op0KnownZero,
/// (sext (add LHS, RHS)) === (add (sext LHS), (sext RHS))
/// This basically requires proving that the add in the original type would not
/// overflow to change the sign bit or have a carry out.
/// TODO: Handle this for Vectors.
bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS,
Instruction *CxtI) {
// There are different heuristics we can use for this. Here are some simple
@ -914,28 +913,27 @@ bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS,
ComputeNumSignBits(RHS, 0, CxtI) > 1)
return true;
if (IntegerType *IT = dyn_cast<IntegerType>(LHS->getType())) {
int BitWidth = IT->getBitWidth();
APInt LHSKnownZero(BitWidth, 0);
APInt LHSKnownOne(BitWidth, 0);
computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, CxtI);
unsigned BitWidth = LHS->getType()->getScalarSizeInBits();
APInt LHSKnownZero(BitWidth, 0);
APInt LHSKnownOne(BitWidth, 0);
computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, CxtI);
APInt RHSKnownZero(BitWidth, 0);
APInt RHSKnownOne(BitWidth, 0);
computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, CxtI);
APInt RHSKnownZero(BitWidth, 0);
APInt RHSKnownOne(BitWidth, 0);
computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, CxtI);
// Addition of two 2's compliment numbers having opposite signs will never
// overflow.
if ((LHSKnownOne[BitWidth - 1] && RHSKnownZero[BitWidth - 1]) ||
(LHSKnownZero[BitWidth - 1] && RHSKnownOne[BitWidth - 1]))
return true;
// Addition of two 2's compliment numbers having opposite signs will never
// overflow.
if ((LHSKnownOne[BitWidth - 1] && RHSKnownZero[BitWidth - 1]) ||
(LHSKnownZero[BitWidth - 1] && RHSKnownOne[BitWidth - 1]))
return true;
// Check if carry bit of addition will not cause overflow.
if (checkRippleForAdd(LHSKnownZero, RHSKnownZero))
return true;
if (checkRippleForAdd(RHSKnownZero, LHSKnownZero))
return true;
// Check if carry bit of addition will not cause overflow.
if (checkRippleForAdd(LHSKnownZero, RHSKnownZero))
return true;
if (checkRippleForAdd(RHSKnownZero, LHSKnownZero))
return true;
}
return false;
}
@ -947,8 +945,8 @@ bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS,
// If the sign bit of LHS and that of RHS are both zero, no unsigned wrap.
bool LHSKnownNonNegative, LHSKnownNegative;
bool RHSKnownNonNegative, RHSKnownNegative;
ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, 0, AT, CxtI, DT);
ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, 0, AT, CxtI, DT);
ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, /*Depth=*/0, CxtI);
ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, /*Depth=*/0, CxtI);
if (LHSKnownNonNegative && RHSKnownNonNegative)
return true;
@ -968,24 +966,22 @@ bool InstCombiner::WillNotOverflowSignedSub(Value *LHS, Value *RHS,
ComputeNumSignBits(RHS, 0, CxtI) > 1)
return true;
if (IntegerType *IT = dyn_cast<IntegerType>(LHS->getType())) {
unsigned BitWidth = IT->getBitWidth();
APInt LHSKnownZero(BitWidth, 0);
APInt LHSKnownOne(BitWidth, 0);
computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, CxtI);
unsigned BitWidth = LHS->getType()->getScalarSizeInBits();
APInt LHSKnownZero(BitWidth, 0);
APInt LHSKnownOne(BitWidth, 0);
computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, CxtI);
APInt RHSKnownZero(BitWidth, 0);
APInt RHSKnownOne(BitWidth, 0);
computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, CxtI);
APInt RHSKnownZero(BitWidth, 0);
APInt RHSKnownOne(BitWidth, 0);
computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, CxtI);
// Subtraction of two 2's compliment numbers having identical signs will
// never overflow.
if ((LHSKnownOne[BitWidth - 1] && RHSKnownOne[BitWidth - 1]) ||
(LHSKnownZero[BitWidth - 1] && RHSKnownZero[BitWidth - 1]))
return true;
// Subtraction of two 2's compliment numbers having identical signs will
// never overflow.
if ((LHSKnownOne[BitWidth - 1] && RHSKnownOne[BitWidth - 1]) ||
(LHSKnownZero[BitWidth - 1] && RHSKnownZero[BitWidth - 1]))
return true;
// TODO: implement logic similar to checkRippleForAdd
}
// TODO: implement logic similar to checkRippleForAdd
return false;
}
@ -996,59 +992,14 @@ bool InstCombiner::WillNotOverflowUnsignedSub(Value *LHS, Value *RHS,
// If the LHS is negative and the RHS is non-negative, no unsigned wrap.
bool LHSKnownNonNegative, LHSKnownNegative;
bool RHSKnownNonNegative, RHSKnownNegative;
ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, 0, AT, CxtI, DT);
ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, 0, AT, CxtI, DT);
ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, /*Depth=*/0, CxtI);
ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, /*Depth=*/0, CxtI);
if (LHSKnownNegative && RHSKnownNonNegative)
return true;
return false;
}
/// \brief Return true if we can prove that:
/// (mul LHS, RHS) === (mul nsw LHS, RHS)
bool InstCombiner::WillNotOverflowSignedMul(Value *LHS, Value *RHS,
Instruction *CxtI) {
if (IntegerType *IT = dyn_cast<IntegerType>(LHS->getType())) {
// Multiplying n * m significant bits yields a result of n + m significant
// bits. If the total number of significant bits does not exceed the
// result bit width (minus 1), there is no overflow.
// This means if we have enough leading sign bits in the operands
// we can guarantee that the result does not overflow.
// Ref: "Hacker's Delight" by Henry Warren
unsigned BitWidth = IT->getBitWidth();
// Note that underestimating the number of sign bits gives a more
// conservative answer.
unsigned SignBits = ComputeNumSignBits(LHS, 0, CxtI) +
ComputeNumSignBits(RHS, 0, CxtI);
// First handle the easy case: if we have enough sign bits there's
// definitely no overflow.
if (SignBits > BitWidth + 1)
return true;
// There are two ambiguous cases where there can be no overflow:
// SignBits == BitWidth + 1 and
// SignBits == BitWidth
// The second case is difficult to check, therefore we only handle the
// first case.
if (SignBits == BitWidth + 1) {
// It overflows only when both arguments are negative and the true
// product is exactly the minimum negative number.
// E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000
// For simplicity we just check if at least one side is not negative.
bool LHSNonNegative, LHSNegative;
bool RHSNonNegative, RHSNegative;
ComputeSignBit(LHS, LHSNonNegative, LHSNegative, DL, 0, AT, CxtI, DT);
ComputeSignBit(RHS, RHSNonNegative, RHSNegative, DL, 0, AT, CxtI, DT);
if (LHSNonNegative || RHSNonNegative)
return true;
}
}
return false;
}
// Checks if any operand is negative and we can convert add to sub.
// This function checks for following negative patterns
// ADD(XOR(OR(Z, NOT(C)), C)), 1) == NEG(AND(Z, C))

View File

@ -886,8 +886,7 @@ Value *InstCombiner::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1,
// This simplification is only valid if the upper range is not negative.
bool IsNegative, IsNotNegative;
ComputeSignBit(RangeEnd, IsNotNegative, IsNegative, DL, 0, AT,
Cmp1, DT);
ComputeSignBit(RangeEnd, IsNotNegative, IsNegative, /*Depth=*/0, Cmp1);
if (!IsNotNegative)
return nullptr;

View File

@ -123,6 +123,48 @@ static Constant *getLogBase2Vector(ConstantDataVector *CV) {
return ConstantVector::get(Elts);
}
/// \brief Return true if we can prove that:
/// (mul LHS, RHS) === (mul nsw LHS, RHS)
bool InstCombiner::WillNotOverflowSignedMul(Value *LHS, Value *RHS,
Instruction *CxtI) {
// Multiplying n * m significant bits yields a result of n + m significant
// bits. If the total number of significant bits does not exceed the
// result bit width (minus 1), there is no overflow.
// This means if we have enough leading sign bits in the operands
// we can guarantee that the result does not overflow.
// Ref: "Hacker's Delight" by Henry Warren
unsigned BitWidth = LHS->getType()->getScalarSizeInBits();
// Note that underestimating the number of sign bits gives a more
// conservative answer.
unsigned SignBits = ComputeNumSignBits(LHS, 0, CxtI) +
ComputeNumSignBits(RHS, 0, CxtI);
// First handle the easy case: if we have enough sign bits there's
// definitely no overflow.
if (SignBits > BitWidth + 1)
return true;
// There are two ambiguous cases where there can be no overflow:
// SignBits == BitWidth + 1 and
// SignBits == BitWidth
// The second case is difficult to check, therefore we only handle the
// first case.
if (SignBits == BitWidth + 1) {
// It overflows only when both arguments are negative and the true
// product is exactly the minimum negative number.
// E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000
// For simplicity we just check if at least one side is not negative.
bool LHSNonNegative, LHSNegative;
bool RHSNonNegative, RHSNegative;
ComputeSignBit(LHS, LHSNonNegative, LHSNegative, /*Depth=*/0, CxtI);
ComputeSignBit(RHS, RHSNonNegative, RHSNegative, /*Depth=*/0, CxtI);
if (LHSNonNegative || RHSNonNegative)
return true;
}
return false;
}
Instruction *InstCombiner::visitMul(BinaryOperator &I) {
bool Changed = SimplifyAssociativeOrCommutative(I);
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
@ -333,6 +375,11 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
}
}
if (!I.hasNoSignedWrap() && WillNotOverflowSignedMul(Op0, Op1, &I)) {
Changed = true;
I.setHasNoSignedWrap(true);
}
return Changed ? &I : nullptr;
}

View File

@ -255,3 +255,15 @@ define i32 @test28(i32 %A) {
; CHECK-NEXT: %[[shl2:.*]] = shl i32 %[[shl1]], %A
; CHECK-NEXT: ret i32 %[[shl2]]
}
define i64 @test29(i31 %A, i31 %B) {
; CHECK-LABEL: @test29(
%C = zext i31 %A to i64
%D = zext i31 %B to i64
%E = mul i64 %C, %D
ret i64 %E
; CHECK: %[[zext1:.*]] = zext i31 %A to i64
; CHECK-NEXT: %[[zext2:.*]] = zext i31 %B to i64
; CHECK-NEXT: %[[mul:.*]] = mul nsw i64 %[[zext1]], %[[zext2]]
; CHECK-NEXT: ret i64 %[[mul]]
}