diff --git a/lib/Transforms/Scalar/InstructionCombining.cpp b/lib/Transforms/Scalar/InstructionCombining.cpp index 083dd520b4a..8ebd2cca612 100644 --- a/lib/Transforms/Scalar/InstructionCombining.cpp +++ b/lib/Transforms/Scalar/InstructionCombining.cpp @@ -236,6 +236,8 @@ namespace { Instruction *visitCallSite(CallSite CS); bool transformConstExprCastCall(CallSite CS); Instruction *transformCallThroughTrampoline(CallSite CS); + Instruction *transformZExtICmp(ICmpInst *ICI, Instruction &CI, + bool DoXform = true); public: // InsertNewInstBefore - insert an instruction New before instruction Old @@ -4363,18 +4365,22 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (CastInst *Op0C = dyn_cast(Op0)) { if (CastInst *Op1C = dyn_cast(Op1)) if (Op0C->getOpcode() == Op1C->getOpcode()) {// same cast kind ? - const Type *SrcTy = Op0C->getOperand(0)->getType(); - if (SrcTy == Op1C->getOperand(0)->getType() && SrcTy->isInteger() && - // Only do this if the casts both really cause code to be generated. - ValueRequiresCast(Op0C->getOpcode(), Op0C->getOperand(0), - I.getType(), TD) && - ValueRequiresCast(Op1C->getOpcode(), Op1C->getOperand(0), - I.getType(), TD)) { - Instruction *NewOp = BinaryOperator::createOr(Op0C->getOperand(0), - Op1C->getOperand(0), - I.getName()); - InsertNewInstBefore(NewOp, I); - return CastInst::create(Op0C->getOpcode(), NewOp, I.getType()); + if (!isa(Op0C->getOperand(0)) || + !isa(Op1C->getOperand(0))) { + const Type *SrcTy = Op0C->getOperand(0)->getType(); + if (SrcTy == Op1C->getOperand(0)->getType() && SrcTy->isInteger() && + // Only do this if the casts both really cause code to be + // generated. + ValueRequiresCast(Op0C->getOpcode(), Op0C->getOperand(0), + I.getType(), TD) && + ValueRequiresCast(Op1C->getOpcode(), Op1C->getOperand(0), + I.getType(), TD)) { + Instruction *NewOp = BinaryOperator::createOr(Op0C->getOperand(0), + Op1C->getOperand(0), + I.getName()); + InsertNewInstBefore(NewOp, I); + return CastInst::create(Op0C->getOpcode(), NewOp, I.getType()); + } } } } @@ -7188,6 +7194,101 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { return 0; } +/// transformZExtICmp - Transform (zext icmp) to bitwise / integer operations +/// in order to eliminate the icmp. +Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, + bool DoXform) { + // If we are just checking for a icmp eq of a single bit and zext'ing it + // to an integer, then shift the bit to the appropriate place and then + // cast to integer to avoid the comparison. + if (ConstantInt *Op1C = dyn_cast(ICI->getOperand(1))) { + const APInt &Op1CV = Op1C->getValue(); + + // zext (x x>>u31 true if signbit set. + // zext (x >s -1) to i32 --> (x>>u31)^1 true if signbit clear. + if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV == 0) || + (ICI->getPredicate() == ICmpInst::ICMP_SGT &&Op1CV.isAllOnesValue())) { + if (!DoXform) return ICI; + + Value *In = ICI->getOperand(0); + Value *Sh = ConstantInt::get(In->getType(), + In->getType()->getPrimitiveSizeInBits()-1); + In = InsertNewInstBefore(BinaryOperator::createLShr(In, Sh, + In->getName()+".lobit"), + CI); + if (In->getType() != CI.getType()) + In = CastInst::createIntegerCast(In, CI.getType(), + false/*ZExt*/, "tmp", &CI); + + if (ICI->getPredicate() == ICmpInst::ICMP_SGT) { + Constant *One = ConstantInt::get(In->getType(), 1); + In = InsertNewInstBefore(BinaryOperator::createXor(In, One, + In->getName()+".not"), + CI); + } + + return ReplaceInstUsesWith(CI, In); + } + + + + // zext (X == 0) to i32 --> X^1 iff X has only the low bit set. + // zext (X == 0) to i32 --> (X>>1)^1 iff X has only the 2nd bit set. + // zext (X == 1) to i32 --> X iff X has only the low bit set. + // zext (X == 2) to i32 --> X>>1 iff X has only the 2nd bit set. + // zext (X != 0) to i32 --> X iff X has only the low bit set. + // zext (X != 0) to i32 --> X>>1 iff X has only the 2nd bit set. + // zext (X != 1) to i32 --> X^1 iff X has only the low bit set. + // zext (X != 2) to i32 --> (X>>1)^1 iff X has only the 2nd bit set. + if ((Op1CV == 0 || Op1CV.isPowerOf2()) && + // This only works for EQ and NE + ICI->isEquality()) { + // If Op1C some other power of two, convert: + uint32_t BitWidth = Op1C->getType()->getBitWidth(); + APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); + APInt TypeMask(APInt::getAllOnesValue(BitWidth)); + ComputeMaskedBits(ICI->getOperand(0), TypeMask, KnownZero, KnownOne); + + APInt KnownZeroMask(~KnownZero); + if (KnownZeroMask.isPowerOf2()) { // Exactly 1 possible 1? + if (!DoXform) return ICI; + + bool isNE = ICI->getPredicate() == ICmpInst::ICMP_NE; + if (Op1CV != 0 && (Op1CV != KnownZeroMask)) { + // (X&4) == 2 --> false + // (X&4) != 2 --> true + Constant *Res = ConstantInt::get(Type::Int1Ty, isNE); + Res = ConstantExpr::getZExt(Res, CI.getType()); + return ReplaceInstUsesWith(CI, Res); + } + + uint32_t ShiftAmt = KnownZeroMask.logBase2(); + Value *In = ICI->getOperand(0); + if (ShiftAmt) { + // Perform a logical shr by shiftamt. + // Insert the shift to put the result in the low bit. + In = InsertNewInstBefore(BinaryOperator::createLShr(In, + ConstantInt::get(In->getType(), ShiftAmt), + In->getName()+".lobit"), CI); + } + + if ((Op1CV != 0) == isNE) { // Toggle the low bit. + Constant *One = ConstantInt::get(In->getType(), 1); + In = BinaryOperator::createXor(In, One, "tmp"); + InsertNewInstBefore(cast(In), CI); + } + + if (CI.getType() == In->getType()) + return ReplaceInstUsesWith(CI, In); + else + return CastInst::createIntegerCast(In, CI.getType(), false/*ZExt*/); + } + } + } + + return 0; +} + Instruction *InstCombiner::visitZExt(ZExtInst &CI) { // If one of the common conversion will work .. if (Instruction *Result = commonIntCastTransforms(CI)) @@ -7224,92 +7325,24 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { } } - if (ICmpInst *ICI = dyn_cast(Src)) { - // If we are just checking for a icmp eq of a single bit and zext'ing it - // to an integer, then shift the bit to the appropriate place and then - // cast to integer to avoid the comparison. - if (ConstantInt *Op1C = dyn_cast(ICI->getOperand(1))) { - const APInt &Op1CV = Op1C->getValue(); - - // zext (x x>>u31 true if signbit set. - // zext (x >s -1) to i32 --> (x>>u31)^1 true if signbit clear. - if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV == 0) || - (ICI->getPredicate() == ICmpInst::ICMP_SGT &&Op1CV.isAllOnesValue())){ - Value *In = ICI->getOperand(0); - Value *Sh = ConstantInt::get(In->getType(), - In->getType()->getPrimitiveSizeInBits()-1); - In = InsertNewInstBefore(BinaryOperator::createLShr(In, Sh, - In->getName()+".lobit"), - CI); - if (In->getType() != CI.getType()) - In = CastInst::createIntegerCast(In, CI.getType(), - false/*ZExt*/, "tmp", &CI); + if (ICmpInst *ICI = dyn_cast(Src)) + return transformZExtICmp(ICI, CI); - if (ICI->getPredicate() == ICmpInst::ICMP_SGT) { - Constant *One = ConstantInt::get(In->getType(), 1); - In = InsertNewInstBefore(BinaryOperator::createXor(In, One, - In->getName()+".not"), - CI); - } - - return ReplaceInstUsesWith(CI, In); - } - - - - // zext (X == 0) to i32 --> X^1 iff X has only the low bit set. - // zext (X == 0) to i32 --> (X>>1)^1 iff X has only the 2nd bit set. - // zext (X == 1) to i32 --> X iff X has only the low bit set. - // zext (X == 2) to i32 --> X>>1 iff X has only the 2nd bit set. - // zext (X != 0) to i32 --> X iff X has only the low bit set. - // zext (X != 0) to i32 --> X>>1 iff X has only the 2nd bit set. - // zext (X != 1) to i32 --> X^1 iff X has only the low bit set. - // zext (X != 2) to i32 --> (X>>1)^1 iff X has only the 2nd bit set. - if ((Op1CV == 0 || Op1CV.isPowerOf2()) && - // This only works for EQ and NE - ICI->isEquality()) { - // If Op1C some other power of two, convert: - uint32_t BitWidth = Op1C->getType()->getBitWidth(); - APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - APInt TypeMask(APInt::getAllOnesValue(BitWidth)); - ComputeMaskedBits(ICI->getOperand(0), TypeMask, KnownZero, KnownOne); - - APInt KnownZeroMask(~KnownZero); - if (KnownZeroMask.isPowerOf2()) { // Exactly 1 possible 1? - bool isNE = ICI->getPredicate() == ICmpInst::ICMP_NE; - if (Op1CV != 0 && (Op1CV != KnownZeroMask)) { - // (X&4) == 2 --> false - // (X&4) != 2 --> true - Constant *Res = ConstantInt::get(Type::Int1Ty, isNE); - Res = ConstantExpr::getZExt(Res, CI.getType()); - return ReplaceInstUsesWith(CI, Res); - } - - uint32_t ShiftAmt = KnownZeroMask.logBase2(); - Value *In = ICI->getOperand(0); - if (ShiftAmt) { - // Perform a logical shr by shiftamt. - // Insert the shift to put the result in the low bit. - In = InsertNewInstBefore( - BinaryOperator::createLShr(In, - ConstantInt::get(In->getType(), ShiftAmt), - In->getName()+".lobit"), CI); - } - - if ((Op1CV != 0) == isNE) { // Toggle the low bit. - Constant *One = ConstantInt::get(In->getType(), 1); - In = BinaryOperator::createXor(In, One, "tmp"); - InsertNewInstBefore(cast(In), CI); - } - - if (CI.getType() == In->getType()) - return ReplaceInstUsesWith(CI, In); - else - return CastInst::createIntegerCast(In, CI.getType(), false/*ZExt*/); - } - } + BinaryOperator *SrcI = dyn_cast(Src); + if (SrcI && SrcI->getOpcode() == Instruction::Or) { + // zext (or icmp, icmp) --> or (zext icmp), (zext icmp) if at least one + // of the (zext icmp) will be transformed. + ICmpInst *LHS = dyn_cast(SrcI->getOperand(0)); + ICmpInst *RHS = dyn_cast(SrcI->getOperand(1)); + if (LHS && RHS && LHS->hasOneUse() && RHS->hasOneUse() && + (transformZExtICmp(LHS, CI, false) || + transformZExtICmp(RHS, CI, false))) { + Value *LCast = InsertCastBefore(Instruction::ZExt, LHS, CI.getType(), CI); + Value *RCast = InsertCastBefore(Instruction::ZExt, RHS, CI.getType(), CI); + return BinaryOperator::create(Instruction::Or, LCast, RCast); } - } + } + return 0; } diff --git a/test/Transforms/InstCombine/zext-or-icmp.ll b/test/Transforms/InstCombine/zext-or-icmp.ll new file mode 100644 index 00000000000..35c7c0a6be6 --- /dev/null +++ b/test/Transforms/InstCombine/zext-or-icmp.ll @@ -0,0 +1,35 @@ +; RUN: llvm-as < %s | opt -instcombine | llvm-dis | grep icmp | count 1 + + %struct.FooBar = type <{ i8, i8, [2 x i8], i8, i8, i8, i8, i16, i16, [4 x i8], [8 x %struct.Rock] }> + %struct.Rock = type { i16, i16 } +@some_idx = internal constant [4 x i8] c"\0A\0B\0E\0F" ; <[4 x i8]*> [#uses=1] + +define i8 @t(%struct.FooBar* %up, i8 zeroext %intra_flag, i32 %blk_i) zeroext nounwind { +entry: + %tmp2 = lshr i32 %blk_i, 1 ; [#uses=1] + %tmp3 = and i32 %tmp2, 2 ; [#uses=1] + %tmp5 = and i32 %blk_i, 1 ; [#uses=1] + %tmp6 = or i32 %tmp3, %tmp5 ; [#uses=1] + %tmp8 = getelementptr %struct.FooBar* %up, i32 0, i32 7 ; [#uses=1] + %tmp9 = load i16* %tmp8, align 1 ; [#uses=1] + %tmp910 = zext i16 %tmp9 to i32 ; [#uses=1] + %tmp12 = getelementptr [4 x i8]* @some_idx, i32 0, i32 %tmp6 ; [#uses=1] + %tmp13 = load i8* %tmp12, align 1 ; [#uses=1] + %tmp1314 = zext i8 %tmp13 to i32 ; [#uses=1] + %tmp151 = lshr i32 %tmp910, %tmp1314 ; [#uses=1] + %tmp1516 = trunc i32 %tmp151 to i8 ; [#uses=1] + %tmp18 = getelementptr %struct.FooBar* %up, i32 0, i32 0 ; [#uses=1] + %tmp19 = load i8* %tmp18, align 1 ; [#uses=1] + %tmp22 = and i8 %tmp1516, %tmp19 ; [#uses=1] + %tmp24 = getelementptr %struct.FooBar* %up, i32 0, i32 0 ; [#uses=1] + %tmp25 = load i8* %tmp24, align 1 ; [#uses=1] + %tmp26.mask = and i8 %tmp25, 1 ; [#uses=1] + %toBool = icmp eq i8 %tmp26.mask, 0 ; [#uses=1] + %toBool.not = xor i1 %toBool, true ; [#uses=1] + %toBool33 = icmp eq i8 %intra_flag, 0 ; [#uses=1] + %bothcond = or i1 %toBool.not, %toBool33 ; [#uses=1] + %iftmp.1.0 = select i1 %bothcond, i8 0, i8 1 ; [#uses=1] + %tmp40 = or i8 %tmp22, %iftmp.1.0 ; [#uses=1] + %tmp432 = and i8 %tmp40, 1 ; [#uses=1] + ret i8 %tmp432 +}