Always compute all the bits in ComputeMaskedBits.

This allows us to keep passing reduced masks to SimplifyDemandedBits, but
know about all the bits if SimplifyDemandedBits fails. This allows instcombine
to simplify cases like the one in the included testcase.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@154011 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Rafael Espindola
2012-04-04 12:51:34 +00:00
parent 00b73a5e44
commit 26c8dcc692
35 changed files with 265 additions and 417 deletions
+2 -2
View File
@@ -291,9 +291,9 @@ public:
return 0; // Don't do anything with FI
}
void ComputeMaskedBits(Value *V, const APInt &Mask, APInt &KnownZero,
void ComputeMaskedBits(Value *V, APInt &KnownZero,
APInt &KnownOne, unsigned Depth = 0) const {
return llvm::ComputeMaskedBits(V, Mask, KnownZero, KnownOne, TD, Depth);
return llvm::ComputeMaskedBits(V, KnownZero, KnownOne, TD, Depth);
}
bool MaskedValueIsZero(Value *V, const APInt &Mask,
@@ -141,10 +141,9 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
// a sub and fuse this add with it.
if (LHS->hasOneUse() && (XorRHS->getValue()+1).isPowerOf2()) {
IntegerType *IT = cast<IntegerType>(I.getType());
APInt Mask = APInt::getAllOnesValue(IT->getBitWidth());
APInt LHSKnownOne(IT->getBitWidth(), 0);
APInt LHSKnownZero(IT->getBitWidth(), 0);
ComputeMaskedBits(XorLHS, Mask, LHSKnownZero, LHSKnownOne);
ComputeMaskedBits(XorLHS, LHSKnownZero, LHSKnownOne);
if ((XorRHS->getValue() | LHSKnownZero).isAllOnesValue())
return BinaryOperator::CreateSub(ConstantExpr::getAdd(XorRHS, CI),
XorLHS);
@@ -202,14 +201,13 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
// A+B --> A|B iff A and B have no bits set in common.
if (IntegerType *IT = dyn_cast<IntegerType>(I.getType())) {
APInt Mask = APInt::getAllOnesValue(IT->getBitWidth());
APInt LHSKnownOne(IT->getBitWidth(), 0);
APInt LHSKnownZero(IT->getBitWidth(), 0);
ComputeMaskedBits(LHS, Mask, LHSKnownZero, LHSKnownOne);
ComputeMaskedBits(LHS, LHSKnownZero, LHSKnownOne);
if (LHSKnownZero != 0) {
APInt RHSKnownOne(IT->getBitWidth(), 0);
APInt RHSKnownZero(IT->getBitWidth(), 0);
ComputeMaskedBits(RHS, Mask, RHSKnownZero, RHSKnownOne);
ComputeMaskedBits(RHS, RHSKnownZero, RHSKnownOne);
// No bits in common -> bitwise or.
if ((LHSKnownZero|RHSKnownZero).isAllOnesValue())
@@ -361,8 +361,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
uint32_t BitWidth = IT->getBitWidth();
APInt KnownZero(BitWidth, 0);
APInt KnownOne(BitWidth, 0);
ComputeMaskedBits(II->getArgOperand(0), APInt::getAllOnesValue(BitWidth),
KnownZero, KnownOne);
ComputeMaskedBits(II->getArgOperand(0), KnownZero, KnownOne);
unsigned TrailingZeros = KnownOne.countTrailingZeros();
APInt Mask(APInt::getLowBitsSet(BitWidth, TrailingZeros));
if ((Mask & KnownZero) == Mask)
@@ -380,8 +379,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
uint32_t BitWidth = IT->getBitWidth();
APInt KnownZero(BitWidth, 0);
APInt KnownOne(BitWidth, 0);
ComputeMaskedBits(II->getArgOperand(0), APInt::getAllOnesValue(BitWidth),
KnownZero, KnownOne);
ComputeMaskedBits(II->getArgOperand(0), KnownZero, KnownOne);
unsigned LeadingZeros = KnownOne.countLeadingZeros();
APInt Mask(APInt::getHighBitsSet(BitWidth, LeadingZeros));
if ((Mask & KnownZero) == Mask)
@@ -394,17 +392,16 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1);
IntegerType *IT = cast<IntegerType>(II->getArgOperand(0)->getType());
uint32_t BitWidth = IT->getBitWidth();
APInt Mask = APInt::getSignBit(BitWidth);
APInt LHSKnownZero(BitWidth, 0);
APInt LHSKnownOne(BitWidth, 0);
ComputeMaskedBits(LHS, Mask, LHSKnownZero, LHSKnownOne);
ComputeMaskedBits(LHS, LHSKnownZero, LHSKnownOne);
bool LHSKnownNegative = LHSKnownOne[BitWidth - 1];
bool LHSKnownPositive = LHSKnownZero[BitWidth - 1];
if (LHSKnownNegative || LHSKnownPositive) {
APInt RHSKnownZero(BitWidth, 0);
APInt RHSKnownOne(BitWidth, 0);
ComputeMaskedBits(RHS, Mask, RHSKnownZero, RHSKnownOne);
ComputeMaskedBits(RHS, RHSKnownZero, RHSKnownOne);
bool RHSKnownNegative = RHSKnownOne[BitWidth - 1];
bool RHSKnownPositive = RHSKnownZero[BitWidth - 1];
if (LHSKnownNegative && RHSKnownNegative) {
@@ -488,14 +485,13 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
case Intrinsic::umul_with_overflow: {
Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1);
unsigned BitWidth = cast<IntegerType>(LHS->getType())->getBitWidth();
APInt Mask = APInt::getAllOnesValue(BitWidth);
APInt LHSKnownZero(BitWidth, 0);
APInt LHSKnownOne(BitWidth, 0);
ComputeMaskedBits(LHS, Mask, LHSKnownZero, LHSKnownOne);
ComputeMaskedBits(LHS, LHSKnownZero, LHSKnownOne);
APInt RHSKnownZero(BitWidth, 0);
APInt RHSKnownOne(BitWidth, 0);
ComputeMaskedBits(RHS, Mask, RHSKnownZero, RHSKnownOne);
ComputeMaskedBits(RHS, RHSKnownZero, RHSKnownOne);
// Get the largest possible values for each operand.
APInt LHSMax = ~LHSKnownZero;
@@ -541,8 +541,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI,
// If Op1C some other power of two, convert:
uint32_t BitWidth = Op1C->getType()->getBitWidth();
APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
APInt TypeMask(APInt::getAllOnesValue(BitWidth));
ComputeMaskedBits(ICI->getOperand(0), TypeMask, KnownZero, KnownOne);
ComputeMaskedBits(ICI->getOperand(0), KnownZero, KnownOne);
APInt KnownZeroMask(~KnownZero);
if (KnownZeroMask.isPowerOf2()) { // Exactly 1 possible 1?
@@ -590,9 +589,8 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI,
APInt KnownZeroLHS(BitWidth, 0), KnownOneLHS(BitWidth, 0);
APInt KnownZeroRHS(BitWidth, 0), KnownOneRHS(BitWidth, 0);
APInt TypeMask(APInt::getAllOnesValue(BitWidth));
ComputeMaskedBits(LHS, TypeMask, KnownZeroLHS, KnownOneLHS);
ComputeMaskedBits(RHS, TypeMask, KnownZeroRHS, KnownOneRHS);
ComputeMaskedBits(LHS, KnownZeroLHS, KnownOneLHS);
ComputeMaskedBits(RHS, KnownZeroRHS, KnownOneRHS);
if (KnownZeroLHS == KnownZeroRHS && KnownOneLHS == KnownOneRHS) {
APInt KnownBits = KnownZeroLHS | KnownOneLHS;
@@ -911,8 +909,7 @@ Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) {
ICI->isEquality() && (Op1C->isZero() || Op1C->getValue().isPowerOf2())){
unsigned BitWidth = Op1C->getType()->getBitWidth();
APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
APInt TypeMask(APInt::getAllOnesValue(BitWidth));
ComputeMaskedBits(Op0, TypeMask, KnownZero, KnownOne);
ComputeMaskedBits(Op0, KnownZero, KnownOne);
APInt KnownZeroMask(~KnownZero);
if (KnownZeroMask.isPowerOf2()) {
@@ -1028,9 +1028,8 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
// of the high bits truncated out of x are known.
unsigned DstBits = LHSI->getType()->getPrimitiveSizeInBits(),
SrcBits = LHSI->getOperand(0)->getType()->getPrimitiveSizeInBits();
APInt Mask(APInt::getHighBitsSet(SrcBits, SrcBits-DstBits));
APInt KnownZero(SrcBits, 0), KnownOne(SrcBits, 0);
ComputeMaskedBits(LHSI->getOperand(0), Mask, KnownZero, KnownOne);
ComputeMaskedBits(LHSI->getOperand(0), KnownZero, KnownOne);
// If all the high bits are known, we can do this xform.
if ((KnownZero|KnownOne).countLeadingOnes() >= SrcBits-DstBits) {
@@ -142,7 +142,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
Instruction *I = dyn_cast<Instruction>(V);
if (!I) {
ComputeMaskedBits(V, DemandedMask, KnownZero, KnownOne, Depth);
ComputeMaskedBits(V, KnownZero, KnownOne, Depth);
return 0; // Only analyze instructions.
}
@@ -156,10 +156,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// this instruction has a simpler value in that context.
if (I->getOpcode() == Instruction::And) {
// If either the LHS or the RHS are Zero, the result is zero.
ComputeMaskedBits(I->getOperand(1), DemandedMask,
RHSKnownZero, RHSKnownOne, Depth+1);
ComputeMaskedBits(I->getOperand(0), DemandedMask & ~RHSKnownZero,
LHSKnownZero, LHSKnownOne, Depth+1);
ComputeMaskedBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth+1);
ComputeMaskedBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1);
// If all of the demanded bits are known 1 on one side, return the other.
// These bits cannot contribute to the result of the 'and' in this
@@ -180,10 +178,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// only bits from X or Y are demanded.
// If either the LHS or the RHS are One, the result is One.
ComputeMaskedBits(I->getOperand(1), DemandedMask,
RHSKnownZero, RHSKnownOne, Depth+1);
ComputeMaskedBits(I->getOperand(0), DemandedMask & ~RHSKnownOne,
LHSKnownZero, LHSKnownOne, Depth+1);
ComputeMaskedBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth+1);
ComputeMaskedBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1);
// If all of the demanded bits are known zero on one side, return the
// other. These bits cannot contribute to the result of the 'or' in this
@@ -206,7 +202,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
}
// Compute the KnownZero/KnownOne bits to simplify things downstream.
ComputeMaskedBits(I, DemandedMask, KnownZero, KnownOne, Depth);
ComputeMaskedBits(I, KnownZero, KnownOne, Depth);
return 0;
}
@@ -219,7 +215,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
switch (I->getOpcode()) {
default:
ComputeMaskedBits(I, DemandedMask, KnownZero, KnownOne, Depth);
ComputeMaskedBits(I, KnownZero, KnownOne, Depth);
break;
case Instruction::And:
// If either the LHS or the RHS are Zero, the result is zero.
@@ -570,7 +566,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// Otherwise just hand the sub off to ComputeMaskedBits to fill in
// the known zeros and ones.
ComputeMaskedBits(V, DemandedMask, KnownZero, KnownOne, Depth);
ComputeMaskedBits(V, KnownZero, KnownOne, Depth);
// Turn this into a xor if LHS is 2^n-1 and the remaining bits are known
// zero.
@@ -729,10 +725,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// The sign bit is the LHS's sign bit, except when the result of the
// remainder is zero.
if (DemandedMask.isNegative() && KnownZero.isNonNegative()) {
APInt Mask2 = APInt::getSignBit(BitWidth);
APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0);
ComputeMaskedBits(I->getOperand(0), Mask2, LHSKnownZero, LHSKnownOne,
Depth+1);
ComputeMaskedBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1);
// If it's known zero, our sign bit is also zero.
if (LHSKnownZero.isNegative())
KnownZero |= LHSKnownZero;
@@ -795,7 +789,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
return 0;
}
}
ComputeMaskedBits(V, DemandedMask, KnownZero, KnownOne, Depth);
ComputeMaskedBits(V, KnownZero, KnownOne, Depth);
break;
}
+1 -2
View File
@@ -762,9 +762,8 @@ unsigned llvm::getOrEnforceKnownAlignment(Value *V, unsigned PrefAlign,
assert(V->getType()->isPointerTy() &&
"getOrEnforceKnownAlignment expects a pointer!");
unsigned BitWidth = TD ? TD->getPointerSizeInBits() : 64;
APInt Mask = APInt::getAllOnesValue(BitWidth);
APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
ComputeMaskedBits(V, Mask, KnownZero, KnownOne, TD);
ComputeMaskedBits(V, KnownZero, KnownOne, TD);
unsigned TrailZ = KnownZero.countTrailingOnes();
// Avoid trouble with rediculously large TrailZ values, such as
+1 -1
View File
@@ -2562,7 +2562,7 @@ static bool EliminateDeadSwitchCases(SwitchInst *SI) {
Value *Cond = SI->getCondition();
unsigned Bits = cast<IntegerType>(Cond->getType())->getBitWidth();
APInt KnownZero(Bits, 0), KnownOne(Bits, 0);
ComputeMaskedBits(Cond, APInt::getAllOnesValue(Bits), KnownZero, KnownOne);
ComputeMaskedBits(Cond, KnownZero, KnownOne);
// Gather dead cases.
SmallVector<ConstantInt*, 8> DeadCases;