diff --git a/lib/Transforms/Scalar/InstructionCombining.cpp b/lib/Transforms/Scalar/InstructionCombining.cpp index bf68d540e1f..5f9c54be4e7 100644 --- a/lib/Transforms/Scalar/InstructionCombining.cpp +++ b/lib/Transforms/Scalar/InstructionCombining.cpp @@ -2271,26 +2271,21 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return 0; } -/// isSignBitCheck - Given an exploded icmp instruction, return true if it -/// really just returns true if the most significant (sign) bit is set. -static bool isSignBitCheck(ICmpInst::Predicate pred, ConstantInt *RHS) { +/// isSignBitCheck - Given an exploded icmp instruction, return true if the +/// comparison only checks the sign bit. If it only checks the sign bit, set +/// TrueIfSigned if the result of the comparison is true when the input value is +/// signed. +static bool isSignBitCheck(ICmpInst::Predicate pred, ConstantInt *RHS, + bool &TrueIfSigned) { switch (pred) { - case ICmpInst::ICMP_SLT: - // True if LHS s< RHS and RHS == 0 - return RHS->isZero(); - case ICmpInst::ICMP_SLE: - // True if LHS s<= RHS and RHS == -1 - return RHS->isAllOnesValue(); - case ICmpInst::ICMP_UGE: - // True if LHS u>= RHS and RHS == high-bit-mask (2^7, 2^15, 2^31, etc) - return RHS->getValue() == - APInt::getSignBit(RHS->getType()->getPrimitiveSizeInBits()); - case ICmpInst::ICMP_UGT: - // True if LHS u> RHS and RHS == high-bit-mask - 1 - return RHS->getValue() == - APInt::getSignedMaxValue(RHS->getType()->getPrimitiveSizeInBits()); - default: - return false; + case ICmpInst::ICMP_SLT: // True if LHS s< 0 + TrueIfSigned = true; + return RHS->isZero(); + case ICmpInst::ICMP_SGT: // True if LHS s> -1 + TrueIfSigned = false; + return RHS->isAllOnesValue(); + default: + return false; } } @@ -2377,11 +2372,13 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (ICmpInst *SCI = dyn_cast(BoolCast->getOperand(0))) { Value *SCIOp0 = SCI->getOperand(0), *SCIOp1 = SCI->getOperand(1); const Type *SCOpTy = SCIOp0->getType(); - + bool TIS = false; + // If the icmp is true iff the sign bit of X is set, then convert this // multiply into a shift/and combination. if (isa(SCIOp1) && - isSignBitCheck(SCI->getPredicate(), cast(SCIOp1))) { + isSignBitCheck(SCI->getPredicate(), cast(SCIOp1), TIS) && + TIS) { // Shift the X value right to turn it into "all signbits". Constant *Amt = ConstantInt::get(SCIOp0->getType(), SCOpTy->getPrimitiveSizeInBits()-1); @@ -2805,23 +2802,19 @@ Instruction *InstCombiner::visitFRem(BinaryOperator &I) { // isMaxValueMinusOne - return true if this is Max-1 static bool isMaxValueMinusOne(const ConstantInt *C, bool isSigned) { uint32_t TypeBits = C->getType()->getPrimitiveSizeInBits(); - if (isSigned) { - // Calculate 0111111111..11111 - APInt Val(APInt::getSignedMaxValue(TypeBits)); - return C->getValue() == Val-1; - } - return C->getValue() == APInt::getAllOnesValue(TypeBits) - 1; + if (!isSigned) + return C->getValue() == APInt::getAllOnesValue(TypeBits) - 1; + return C->getValue() == APInt::getSignedMaxValue(TypeBits)-1; } // isMinValuePlusOne - return true if this is Min+1 static bool isMinValuePlusOne(const ConstantInt *C, bool isSigned) { - if (isSigned) { - // Calculate 1111111111000000000000 - uint32_t TypeBits = C->getType()->getPrimitiveSizeInBits(); - APInt Val(APInt::getSignedMinValue(TypeBits)); - return C->getValue() == Val+1; - } - return C->getValue() == 1; // unsigned + if (!isSigned) + return C->getValue() == 1; // unsigned + + // Calculate 1111111111000000000000 + uint32_t TypeBits = C->getType()->getPrimitiveSizeInBits(); + return C->getValue() == APInt::getSignedMinValue(TypeBits)+1; } // isOneBitSet - Return true if there is exactly one bit set in the specified @@ -5415,85 +5408,105 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, } break; - case Instruction::Shl: // (icmp pred (shl X, ShAmt), CI) - if (ConstantInt *ShAmt = dyn_cast(LHSI->getOperand(1))) { - if (ICI.isEquality()) { - uint32_t TypeBits = RHSV.getBitWidth(); + case Instruction::Shl: { // (icmp pred (shl X, ShAmt), CI) + ConstantInt *ShAmt = dyn_cast(LHSI->getOperand(1)); + if (!ShAmt) break; + + uint32_t TypeBits = RHSV.getBitWidth(); + + // Check that the shift amount is in range. If not, don't perform + // undefined shifts. When the shift is visited it will be + // simplified. + if (ShAmt->uge(TypeBits)) + break; + + if (ICI.isEquality()) { + // If we are comparing against bits always shifted out, the + // comparison cannot succeed. + Constant *Comp = + ConstantExpr::getShl(ConstantExpr::getLShr(RHS, ShAmt), ShAmt); + if (Comp != RHS) {// Comparing against a bit that we know is zero. + bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; + Constant *Cst = ConstantInt::get(Type::Int1Ty, IsICMP_NE); + return ReplaceInstUsesWith(ICI, Cst); + } + + if (LHSI->hasOneUse()) { + // Otherwise strength reduce the shift into an and. + uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); + Constant *Mask = + ConstantInt::get(APInt::getLowBitsSet(TypeBits, TypeBits-ShAmtVal)); - // Check that the shift amount is in range. If not, don't perform - // undefined shifts. When the shift is visited it will be - // simplified. - if (ShAmt->uge(TypeBits)) - break; - - // If we are comparing against bits always shifted out, the - // comparison cannot succeed. - Constant *Comp = - ConstantExpr::getShl(ConstantExpr::getLShr(RHS, ShAmt), ShAmt); - if (Comp != RHS) {// Comparing against a bit that we know is zero. - bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; - Constant *Cst = ConstantInt::get(Type::Int1Ty, IsICMP_NE); - return ReplaceInstUsesWith(ICI, Cst); - } - - if (LHSI->hasOneUse()) { - // Otherwise strength reduce the shift into an and. - uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); - Constant *Mask = - ConstantInt::get(APInt::getLowBitsSet(TypeBits, TypeBits-ShAmtVal)); - - Instruction *AndI = - BinaryOperator::createAnd(LHSI->getOperand(0), - Mask, LHSI->getName()+".mask"); - Value *And = InsertNewInstBefore(AndI, ICI); - return new ICmpInst(ICI.getPredicate(), And, - ConstantInt::get(RHSV.lshr(ShAmtVal))); - } + Instruction *AndI = + BinaryOperator::createAnd(LHSI->getOperand(0), + Mask, LHSI->getName()+".mask"); + Value *And = InsertNewInstBefore(AndI, ICI); + return new ICmpInst(ICI.getPredicate(), And, + ConstantInt::get(RHSV.lshr(ShAmtVal))); } } + + // Otherwise, if this is a comparison of the sign bit, simplify to and/test. + bool TrueIfSigned = false; + if (LHSI->hasOneUse() && + isSignBitCheck(ICI.getPredicate(), RHS, TrueIfSigned)) { + // (X << 31) (X&1) != 0 + Constant *Mask = ConstantInt::get(APInt(TypeBits, 1) << + (TypeBits-ShAmt->getZExtValue()-1)); + Instruction *AndI = + BinaryOperator::createAnd(LHSI->getOperand(0), + Mask, LHSI->getName()+".mask"); + Value *And = InsertNewInstBefore(AndI, ICI); + + return new ICmpInst(TrueIfSigned ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ, + And, Constant::getNullValue(And->getType())); + } break; + } case Instruction::LShr: // (icmp pred (shr X, ShAmt), CI) - case Instruction::AShr: - if (ConstantInt *ShAmt = dyn_cast(LHSI->getOperand(1))) { - if (ICI.isEquality()) { - // Check that the shift amount is in range. If not, don't perform - // undefined shifts. When the shift is visited it will be - // simplified. - uint32_t TypeBits = RHSV.getBitWidth(); - if (ShAmt->uge(TypeBits)) - break; - uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); + case Instruction::AShr: { + ConstantInt *ShAmt = dyn_cast(LHSI->getOperand(1)); + if (!ShAmt) break; + + if (ICI.isEquality()) { + // Check that the shift amount is in range. If not, don't perform + // undefined shifts. When the shift is visited it will be + // simplified. + uint32_t TypeBits = RHSV.getBitWidth(); + if (ShAmt->uge(TypeBits)) + break; + uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); + + // If we are comparing against bits always shifted out, the + // comparison cannot succeed. + APInt Comp = RHSV << ShAmtVal; + if (LHSI->getOpcode() == Instruction::LShr) + Comp = Comp.lshr(ShAmtVal); + else + Comp = Comp.ashr(ShAmtVal); + + if (Comp != RHSV) { // Comparing against a bit that we know is zero. + bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; + Constant *Cst = ConstantInt::get(Type::Int1Ty, IsICMP_NE); + return ReplaceInstUsesWith(ICI, Cst); + } + + if (LHSI->hasOneUse() || RHSV == 0) { + // Otherwise strength reduce the shift into an and. + APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal)); + Constant *Mask = ConstantInt::get(Val); - // If we are comparing against bits always shifted out, the - // comparison cannot succeed. - APInt Comp = RHSV << ShAmtVal; - if (LHSI->getOpcode() == Instruction::LShr) - Comp = Comp.lshr(ShAmtVal); - else - Comp = Comp.ashr(ShAmtVal); - - if (Comp != RHSV) { // Comparing against a bit that we know is zero. - bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; - Constant *Cst = ConstantInt::get(Type::Int1Ty, IsICMP_NE); - return ReplaceInstUsesWith(ICI, Cst); - } - - if (LHSI->hasOneUse() || RHSV == 0) { - // Otherwise strength reduce the shift into an and. - APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal)); - Constant *Mask = ConstantInt::get(Val); - - Instruction *AndI = - BinaryOperator::createAnd(LHSI->getOperand(0), - Mask, LHSI->getName()+".mask"); - Value *And = InsertNewInstBefore(AndI, ICI); - return new ICmpInst(ICI.getPredicate(), And, - ConstantExpr::getShl(RHS, ShAmt)); - } + Instruction *AndI = + BinaryOperator::createAnd(LHSI->getOperand(0), + Mask, LHSI->getName()+".mask"); + Value *And = InsertNewInstBefore(AndI, ICI); + return new ICmpInst(ICI.getPredicate(), And, + ConstantExpr::getShl(RHS, ShAmt)); } } break; + } case Instruction::SDiv: case Instruction::UDiv: diff --git a/test/Transforms/InstCombine/shift-simplify.ll b/test/Transforms/InstCombine/shift-simplify.ll index 639c3fae5c6..4c846127482 100644 --- a/test/Transforms/InstCombine/shift-simplify.ll +++ b/test/Transforms/InstCombine/shift-simplify.ll @@ -21,3 +21,10 @@ define i32 @test2(i32 %A, i32 %B, i32 %C) { %Z = xor i32 %X, %Y ret i32 %Z } + +define i1 @test3(i32 %X) { + %tmp1 = shl i32 %X, 7 + %tmp2 = icmp slt i32 %tmp1, 0 + ret i1 %tmp2 +} +