Enhance the "compare with shift" and "compare with div"

optimizations to be much more aggressive in the face of
exact/nsw/nuw div and shifts.  For example, these (which
are the same except the first is 'exact' sdiv:

define i1 @sdiv_icmp4_exact(i64 %X) nounwind {
  %A = sdiv exact i64 %X, -5   ; X/-5 == 0 --> x == 0
  %B = icmp eq i64 %A, 0
  ret i1 %B
}

define i1 @sdiv_icmp4(i64 %X) nounwind {
  %A = sdiv i64 %X, -5   ; X/-5 == 0 --> x == 0
  %B = icmp eq i64 %A, 0
  ret i1 %B
}

compile down to:

define i1 @sdiv_icmp4_exact(i64 %X) nounwind {
  %1 = icmp eq i64 %X, 0
  ret i1 %1
}

define i1 @sdiv_icmp4(i64 %X) nounwind {
  %X.off = add i64 %X, 4
  %1 = icmp ult i64 %X.off, 9
  ret i1 %1
}

This happens when you do something like:
  (ptr1-ptr2) == 42

where the pointers are pointers to non-unit types.



git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@125266 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Chris Lattner
2011-02-10 05:23:05 +00:00
parent 44cc997d42
commit b20c0b5092
3 changed files with 151 additions and 74 deletions

View File

@@ -22,13 +22,17 @@
using namespace llvm;
using namespace PatternMatch;
static ConstantInt *getOne(Constant *C) {
return ConstantInt::get(cast<IntegerType>(C->getType()), 1);
}
/// AddOne - Add one to a ConstantInt
static Constant *AddOne(Constant *C) {
return ConstantExpr::getAdd(C, ConstantInt::get(C->getType(), 1));
}
/// SubOne - Subtract one from a ConstantInt
static Constant *SubOne(ConstantInt *C) {
return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1));
static Constant *SubOne(Constant *C) {
return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1));
}
static ConstantInt *ExtractElement(Constant *V, Constant *Idx) {
@@ -782,7 +786,7 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI,
// results than (x /s C1) <u C2 or (x /u C1) <s C2 or even
// (x /u C1) <u C2. Simply casting the operands and result won't
// work. :( The if statement below tests that condition and bails
// if it finds it.
// if it finds it.
bool DivIsSigned = DivI->getOpcode() == Instruction::SDiv;
if (!ICI.isEquality() && DivIsSigned != ICI.isSigned())
return 0;
@@ -809,6 +813,10 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI,
// Get the ICmp opcode
ICmpInst::Predicate Pred = ICI.getPredicate();
/// If the division is known to be exact, then there is no remainder from the
/// divide, so the covered range size is unit, otherwise it is the divisor.
ConstantInt *RangeSize = DivI->isExact() ? getOne(Prod) : DivRHS;
// Figure out the interval that is being checked. For example, a comparison
// like "X /u 5 == 0" is really checking that X is in the interval [0, 5).
// Compute this interval based on the constants involved and the signedness of
@@ -818,38 +826,43 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI,
// -1 if overflowed off the bottom end, or +1 if overflowed off the top end.
int LoOverflow = 0, HiOverflow = 0;
Constant *LoBound = 0, *HiBound = 0;
if (!DivIsSigned) { // udiv
// e.g. X/5 op 3 --> [15, 20)
LoBound = Prod;
HiOverflow = LoOverflow = ProdOV;
if (!HiOverflow)
HiOverflow = AddWithOverflow(HiBound, LoBound, DivRHS, false);
if (!HiOverflow) {
// If this is not an exact divide, then many values in the range collapse
// to the same result value.
HiOverflow = AddWithOverflow(HiBound, LoBound, RangeSize, false);
}
} else if (DivRHS->getValue().isStrictlyPositive()) { // Divisor is > 0.
if (CmpRHSV == 0) { // (X / pos) op 0
// Can't overflow. e.g. X/2 op 0 --> [-1, 2)
LoBound = cast<ConstantInt>(ConstantExpr::getNeg(SubOne(DivRHS)));
HiBound = DivRHS;
LoBound = ConstantExpr::getNeg(SubOne(RangeSize));
HiBound = RangeSize;
} else if (CmpRHSV.isStrictlyPositive()) { // (X / pos) op pos
LoBound = Prod; // e.g. X/5 op 3 --> [15, 20)
HiOverflow = LoOverflow = ProdOV;
if (!HiOverflow)
HiOverflow = AddWithOverflow(HiBound, Prod, DivRHS, true);
HiOverflow = AddWithOverflow(HiBound, Prod, RangeSize, true);
} else { // (X / pos) op neg
// e.g. X/5 op -3 --> [-15-4, -15+1) --> [-19, -14)
HiBound = AddOne(Prod);
LoOverflow = HiOverflow = ProdOV ? -1 : 0;
if (!LoOverflow) {
ConstantInt* DivNeg =
cast<ConstantInt>(ConstantExpr::getNeg(DivRHS));
ConstantInt *DivNeg =cast<ConstantInt>(ConstantExpr::getNeg(RangeSize));
LoOverflow = AddWithOverflow(LoBound, HiBound, DivNeg, true) ? -1 : 0;
}
}
}
} else if (DivRHS->getValue().isNegative()) { // Divisor is < 0.
if (DivI->isExact())
RangeSize = cast<ConstantInt>(ConstantExpr::getNeg(RangeSize));
if (CmpRHSV == 0) { // (X / neg) op 0
// e.g. X/-5 op 0 --> [-4, 5)
LoBound = AddOne(DivRHS);
HiBound = cast<ConstantInt>(ConstantExpr::getNeg(DivRHS));
LoBound = AddOne(RangeSize);
HiBound = cast<ConstantInt>(ConstantExpr::getNeg(RangeSize));
if (HiBound == DivRHS) { // -INTMIN = INTMIN
HiOverflow = 1; // [INTMIN+1, overflow)
HiBound = 0; // e.g. X/INTMIN = 0 --> X > INTMIN
@@ -859,12 +872,12 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI,
HiBound = AddOne(Prod);
HiOverflow = LoOverflow = ProdOV ? -1 : 0;
if (!LoOverflow)
LoOverflow = AddWithOverflow(LoBound, HiBound, DivRHS, true) ? -1 : 0;
LoOverflow = AddWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0;
} else { // (X / neg) op neg
LoBound = Prod; // e.g. X/-5 op -3 --> [15, 20)
LoOverflow = HiOverflow = ProdOV;
if (!HiOverflow)
HiOverflow = SubWithOverflow(HiBound, Prod, DivRHS, true);
HiOverflow = SubWithOverflow(HiBound, Prod, RangeSize, true);
}
// Dividing by a negative swaps the condition. LT <-> GT
@@ -883,9 +896,8 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI,
if (LoOverflow)
return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT :
ICmpInst::ICMP_ULT, X, HiBound);
return ReplaceInstUsesWith(ICI,
InsertRangeTest(X, LoBound, HiBound, DivIsSigned,
true));
return ReplaceInstUsesWith(ICI, InsertRangeTest(X, LoBound, HiBound,
DivIsSigned, true));
case ICmpInst::ICMP_NE:
if (LoOverflow && HiOverflow)
return ReplaceInstUsesWith(ICI, ConstantInt::getTrue(ICI.getContext()));
@@ -908,12 +920,11 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI,
case ICmpInst::ICMP_SGT:
if (HiOverflow == +1) // High bound greater than input range.
return ReplaceInstUsesWith(ICI, ConstantInt::getFalse(ICI.getContext()));
else if (HiOverflow == -1) // High bound less than input range.
if (HiOverflow == -1) // High bound less than input range.
return ReplaceInstUsesWith(ICI, ConstantInt::getTrue(ICI.getContext()));
if (Pred == ICmpInst::ICMP_UGT)
return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound);
else
return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound);
return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound);
}
}
@@ -1182,6 +1193,12 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
return ReplaceInstUsesWith(ICI, Cst);
}
// If the shift is NUW, then it is just shifting out zeros, no need for an
// AND.
if (cast<BinaryOperator>(LHSI)->hasNoUnsignedWrap())
return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0),
ConstantExpr::getLShr(RHS, ShAmt));
if (LHSI->hasOneUse()) {
// Otherwise strength reduce the shift into an and.
uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits);
@@ -1192,8 +1209,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
Value *And =
Builder->CreateAnd(LHSI->getOperand(0),Mask, LHSI->getName()+".mask");
return new ICmpInst(ICI.getPredicate(), And,
ConstantInt::get(ICI.getContext(),
RHSV.lshr(ShAmtVal)));
ConstantExpr::getLShr(RHS, ShAmt));
}
}
@@ -1222,10 +1238,9 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
// undefined shifts. When the shift is visited it will be
// simplified.
uint32_t TypeBits = RHSV.getBitWidth();
if (ShAmt->uge(TypeBits))
break;
uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits);
if (ShAmtVal >= TypeBits)
break;
// If we are comparing against bits always shifted out, the
// comparison cannot succeed.
@@ -1245,13 +1260,10 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
// Otherwise, check to see if the bits shifted out are known to be zero.
// If so, we can compare against the unshifted value:
// (X & 4) >> 1 == 2 --> (X & 4) == 4.
if (LHSI->hasOneUse() &&
MaskedValueIsZero(LHSI->getOperand(0),
APInt::getLowBitsSet(Comp.getBitWidth(), ShAmtVal))) {
if (LHSI->hasOneUse() && cast<BinaryOperator>(LHSI)->isExact())
return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0),
ConstantExpr::getShl(RHS, ShAmt));
}
if (LHSI->hasOneUse()) {
// Otherwise strength reduce the shift into an and.
APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal));
@@ -1911,14 +1923,12 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
// If the LHS is 8 >>u x, and we know the result is a power of 2 like 1,
// then turn "((8 >>u x)&1) == 0" into "x != 3".
ConstantInt *CI = 0;
const APInt *CI;
if (Op0KnownZeroInverted == 1 &&
match(LHS, m_LShr(m_ConstantInt(CI), m_Value(X))) &&
CI->getValue().isPowerOf2()) {
unsigned CmpVal = CI->getValue().countTrailingZeros();
match(LHS, m_LShr(m_Power2(CI), m_Value(X))))
return new ICmpInst(ICmpInst::ICMP_NE, X,
ConstantInt::get(X->getType(), CmpVal));
}
ConstantInt::get(X->getType(),
CI->countTrailingZeros()));
}
break;
@@ -1950,14 +1960,12 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
// If the LHS is 8 >>u x, and we know the result is a power of 2 like 1,
// then turn "((8 >>u x)&1) != 0" into "x == 3".
ConstantInt *CI = 0;
const APInt *CI;
if (Op0KnownZeroInverted == 1 &&
match(LHS, m_LShr(m_ConstantInt(CI), m_Value(X))) &&
CI->getValue().isPowerOf2()) {
unsigned CmpVal = CI->getValue().countTrailingZeros();
match(LHS, m_LShr(m_Power2(CI), m_Value(X))))
return new ICmpInst(ICmpInst::ICMP_EQ, X,
ConstantInt::get(X->getType(), CmpVal));
}
ConstantInt::get(X->getType(),
CI->countTrailingZeros()));
}
break;