diff --git a/lib/Transforms/InstCombine/InstCombine.h b/lib/Transforms/InstCombine/InstCombine.h index 450e7f0a398..9c2969c7ab2 100644 --- a/lib/Transforms/InstCombine/InstCombine.h +++ b/lib/Transforms/InstCombine/InstCombine.h @@ -145,6 +145,8 @@ public: ConstantInt *RHS); Instruction *FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, ConstantInt *DivRHS); + Instruction *FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *DivI, + ConstantInt *DivRHS); Instruction *FoldICmpAddOpCst(ICmpInst &ICI, Value *X, ConstantInt *CI, ICmpInst::Predicate Pred, Value *TheAdd); Instruction *FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index c34e698e3e7..f3c35330b9c 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -928,6 +928,55 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, } } +/// FoldICmpShrCst - Handle "icmp(([al]shr X, cst1), cst2)". +Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr, + ConstantInt *ShAmt) { + if (!ICI.isEquality()) return 0; + + const APInt &CmpRHSV = cast(ICI.getOperand(1))->getValue(); + + // Check that the shift amount is in range. If not, don't perform + // undefined shifts. When the shift is visited it will be + // simplified. + uint32_t TypeBits = CmpRHSV.getBitWidth(); + uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); + if (ShAmtVal >= TypeBits) + return 0; + + // If we are comparing against bits always shifted out, the + // comparison cannot succeed. + APInt Comp = CmpRHSV << ShAmtVal; + ConstantInt *ShiftedCmpRHS = ConstantInt::get(ICI.getContext(), Comp); + if (Shr->getOpcode() == Instruction::LShr) + Comp = Comp.lshr(ShAmtVal); + else + Comp = Comp.ashr(ShAmtVal); + + if (Comp != CmpRHSV) { // Comparing against a bit that we know is zero. + bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; + Constant *Cst = ConstantInt::get(Type::getInt1Ty(ICI.getContext()), + IsICMP_NE); + return ReplaceInstUsesWith(ICI, Cst); + } + + // Otherwise, check to see if the bits shifted out are known to be zero. + // If so, we can compare against the unshifted value: + // (X & 4) >> 1 == 2 --> (X & 4) == 4. + if (Shr->hasOneUse() && cast(Shr)->isExact()) + return new ICmpInst(ICI.getPredicate(), Shr->getOperand(0), ShiftedCmpRHS); + + if (Shr->hasOneUse()) { + // Otherwise strength reduce the shift into an and. + APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal)); + Constant *Mask = ConstantInt::get(ICI.getContext(), Val); + + Value *And = Builder->CreateAnd(Shr->getOperand(0), + Mask, Shr->getName()+".mask"); + return new ICmpInst(ICI.getPredicate(), And, ShiftedCmpRHS); + } + return 0; +} + /// visitICmpInstWithInstAndIntCst - Handle "icmp (instr, intcst)". /// @@ -1153,7 +1202,6 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, if (match(LHSI, m_Or(m_PtrToInt(m_Value(P)), m_PtrToInt(m_Value(Q))))) { // Simplify icmp eq (or (ptrtoint P), (ptrtoint Q)), 0 // -> and (icmp eq P, null), (icmp eq Q, null). - Value *ICIP = Builder->CreateICmp(ICI.getPredicate(), P, Constant::getNullValue(P->getType())); Value *ICIQ = Builder->CreateICmp(ICI.getPredicate(), Q, @@ -1229,53 +1277,13 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, } case Instruction::LShr: // (icmp pred (shr X, ShAmt), CI) - case Instruction::AShr: { + case Instruction::AShr: // Only handle equality comparisons of shift-by-constant. - ConstantInt *ShAmt = dyn_cast(LHSI->getOperand(1)); - if (!ShAmt || !ICI.isEquality()) break; - - // Check that the shift amount is in range. If not, don't perform - // undefined shifts. When the shift is visited it will be - // simplified. - uint32_t TypeBits = RHSV.getBitWidth(); - uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); - if (ShAmtVal >= TypeBits) - break; - - // If we are comparing against bits always shifted out, the - // comparison cannot succeed. - APInt Comp = RHSV << ShAmtVal; - if (LHSI->getOpcode() == Instruction::LShr) - Comp = Comp.lshr(ShAmtVal); - else - Comp = Comp.ashr(ShAmtVal); - - if (Comp != RHSV) { // Comparing against a bit that we know is zero. - bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; - Constant *Cst = ConstantInt::get(Type::getInt1Ty(ICI.getContext()), - IsICMP_NE); - return ReplaceInstUsesWith(ICI, Cst); - } - - // Otherwise, check to see if the bits shifted out are known to be zero. - // If so, we can compare against the unshifted value: - // (X & 4) >> 1 == 2 --> (X & 4) == 4. - if (LHSI->hasOneUse() && cast(LHSI)->isExact()) - return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), - ConstantExpr::getShl(RHS, ShAmt)); - - if (LHSI->hasOneUse()) { - // Otherwise strength reduce the shift into an and. - APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal)); - Constant *Mask = ConstantInt::get(ICI.getContext(), Val); - - Value *And = Builder->CreateAnd(LHSI->getOperand(0), - Mask, LHSI->getName()+".mask"); - return new ICmpInst(ICI.getPredicate(), And, - ConstantExpr::getShl(RHS, ShAmt)); - } + if (ConstantInt *ShAmt = dyn_cast(LHSI->getOperand(1))) + if (Instruction *Res = FoldICmpShrCst(ICI, cast(LHSI), + ShAmt)) + return Res; break; - } case Instruction::SDiv: case Instruction::UDiv: