diff --git a/lib/Transforms/Scalar/InstructionCombining.cpp b/lib/Transforms/Scalar/InstructionCombining.cpp index e9ffb730305..549c57633a4 100644 --- a/lib/Transforms/Scalar/InstructionCombining.cpp +++ b/lib/Transforms/Scalar/InstructionCombining.cpp @@ -2781,20 +2781,18 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { // If one of the operands of the multiply is a cast from a boolean value, then // we know the bool is either zero or one, so this is a 'masking' multiply. - // See if we can simplify things based on how the boolean was originally - // formed. - { - Value *BoolCast = 0, *OtherOp = 0; - if (ZExtInst *CI = dyn_cast(Op0)) - if (CI->getOperand(0)->getType() == Type::getInt1Ty(*Context)) - BoolCast = CI, OtherOp = I.getOperand(1); - if (!BoolCast) - if (ZExtInst *CI = dyn_cast(I.getOperand(1))) - if (CI->getOperand(0)->getType() == Type::getInt1Ty(*Context)) - BoolCast = CI, OtherOp = Op0; + // X * Y (where Y is 0 or 1) -> X & (0-Y) + if (!isa(I.getType())) { + // -2 is "-1 << 1" so it is all bits set except the low one. + APInt Negative2(I.getType()->getPrimitiveSizeInBits(), -2, true); + Value *BoolCast = 0, *OtherOp = 0; + if (MaskedValueIsZero(Op0, Negative2)) + BoolCast = Op0, OtherOp = I.getOperand(1); + else if (MaskedValueIsZero(I.getOperand(1), Negative2)) + BoolCast = I.getOperand(1), OtherOp = Op0; + if (BoolCast) { - // X * Y (where Y is 0 or 1) -> X & (0-Y) Value *V = Builder->CreateSub(Constant::getNullValue(I.getType()), BoolCast, "tmp"); return BinaryOperator::CreateAnd(V, OtherOp); diff --git a/test/Transforms/InstCombine/mul.ll b/test/Transforms/InstCombine/mul.ll index d8e623c4834..53a56434aed 100644 --- a/test/Transforms/InstCombine/mul.ll +++ b/test/Transforms/InstCombine/mul.ll @@ -105,5 +105,12 @@ define i32 @test16(i32 %b, i1 %c) { ret i32 %e } +; X * Y (when Y is 0 or 1) --> x & (0-Y) +define i32 @test17(i32 %a, i32 %b) { + %a.lobit = lshr i32 %a, 31 + %e = mul i32 %a.lobit, %b + ret i32 %e +} +