diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp
index cebf6183590..270f489f682 100644
--- a/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -115,7 +115,7 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift,
     return CanEvaluateShifted(I->getOperand(0), NumBits, isLeftShift, IC) &&
            CanEvaluateShifted(I->getOperand(1), NumBits, isLeftShift, IC);
 
-  case Instruction::Shl:
+  case Instruction::Shl: {
     // We can often fold the shift into shifts-by-a-constant.
     CI = dyn_cast<ConstantInt>(I->getOperand(1));
     if (CI == 0) return false;
@@ -125,10 +125,24 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift,
 
     // We can always turn shl(c)+shr(c) -> and(c2).
     if (CI->getValue() == NumBits) return true;
-    // We can always turn shl(c1)+shr(c2) -> shl(c3)+and(c4), but it isn't
+
+    unsigned TypeWidth = I->getType()->getScalarSizeInBits();
+
+    // We can turn shl(c1)+shr(c2) -> shl(c3)+and(c4), but it isn't
     // profitable unless we know the and'd out bits are already zero.
+    // With c1 > c2, "X << (c1-c2)" equals "(X << c1) >> c2" only if the c2
+    // bits of X starting at bit (TypeWidth - c1) are zero; those are the bits
+    // the original shl discards but the narrower shl would keep.
+    if (CI->getValue().ult(TypeWidth) && CI->getZExtValue() > NumBits) {
+      unsigned LowBits = TypeWidth - CI->getZExtValue();
+      if (MaskedValueIsZero(I->getOperand(0),
+                       APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits))
+        return true;
+    }
+
     return false;
-  case Instruction::LShr:
+  }
+  case Instruction::LShr: {
     // We can often fold the shift into shifts-by-a-constant.
     CI = dyn_cast<ConstantInt>(I->getOperand(1));
     if (CI == 0) return false;
@@ -139,10 +153,22 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift,
     // We can always turn lshr(c)+shl(c) -> and(c2).
     if (CI->getValue() == NumBits) return true;
 
+    unsigned TypeWidth = I->getType()->getScalarSizeInBits();
+
     // We can always turn lshr(c1)+shl(c2) -> lshr(c3)+and(c4), but it isn't
     // profitable unless we know the and'd out bits are already zero.
-    return false;
+    // With c1 > c2, "X >> (c1-c2)" equals "(X >> c1) << c2" only if the c2
+    // bits of X starting at bit (c1-c2) are zero; those are the bits the
+    // original lshr discards but the narrower lshr would keep.
+    if (CI->getValue().ult(TypeWidth) && CI->getZExtValue() > NumBits) {
+      unsigned LowBits = CI->getZExtValue() - NumBits;
+      if (MaskedValueIsZero(I->getOperand(0),
+                       APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits))
+        return true;
+    }
 
+    return false;
+  }
   case Instruction::Select: {
     SelectInst *SI = cast<SelectInst>(I);
     return CanEvaluateShifted(SI->getTrueValue(), NumBits, isLeftShift, IC) &&
@@ -209,16 +235,23 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
 
     // We turn shl(c)+lshr(c) -> and(c2) if the input doesn't already have
     // zeros.
-    assert(CI->getValue() == NumBits);
-
-    APInt Mask(APInt::getLowBitsSet(TypeWidth, TypeWidth - NumBits));
-    V = IC.Builder->CreateAnd(I->getOperand(0),
-                              ConstantInt::get(I->getContext(), Mask));
-    if (Instruction *VI = dyn_cast<Instruction>(V)) {
-      VI->moveBefore(I);
-      VI->takeName(I);
+    if (CI->getValue() == NumBits) {
+      APInt Mask(APInt::getLowBitsSet(TypeWidth, TypeWidth - NumBits));
+      V = IC.Builder->CreateAnd(I->getOperand(0),
+                                ConstantInt::get(I->getContext(), Mask));
+      if (Instruction *VI = dyn_cast<Instruction>(V)) {
+        VI->moveBefore(I);
+        VI->takeName(I);
+      }
+      return V;
     }
-    return V;
+
+    // We turn shl(c1)+shr(c2) -> shl(c3)+and(c4), but only when we know that
+    // the and won't be needed.
+    assert(CI->getZExtValue() > NumBits);
+    I->setOperand(1, ConstantInt::get(I->getType(),
+                                      CI->getZExtValue() - NumBits));
+    return I;
   }
   case Instruction::LShr: {
     unsigned TypeWidth = I->getType()->getScalarSizeInBits();
@@ -238,16 +271,23 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
 
     // We turn lshr(c)+shl(c) -> and(c2) if the input doesn't already have
     // zeros.
-    assert(CI->getValue() == NumBits);
-
-    APInt Mask(APInt::getHighBitsSet(TypeWidth, TypeWidth - NumBits));
-    V = IC.Builder->CreateAnd(I->getOperand(0),
-                              ConstantInt::get(I->getContext(), Mask));
-    if (Instruction *VI = dyn_cast<Instruction>(V)) {
-      VI->moveBefore(I);
-      VI->takeName(I);
+    if (CI->getValue() == NumBits) {
+      APInt Mask(APInt::getHighBitsSet(TypeWidth, TypeWidth - NumBits));
+      V = IC.Builder->CreateAnd(I->getOperand(0),
+                                ConstantInt::get(I->getContext(), Mask));
+      if (Instruction *VI = dyn_cast<Instruction>(V)) {
+        VI->moveBefore(I);
+        VI->takeName(I);
+      }
+      return V;
     }
-    return V;
+
+    // We turn lshr(c1)+shl(c2) -> lshr(c3)+and(c4), but only when we know that
+    // the and won't be needed.
+    assert(CI->getZExtValue() > NumBits);
+    I->setOperand(1, ConstantInt::get(I->getType(),
+                                      CI->getZExtValue() - NumBits));
+    return I;
   }
 
   case Instruction::Select:
diff --git a/test/Transforms/InstCombine/shift.ll b/test/Transforms/InstCombine/shift.ll
index 91a8ed7dda0..34e1835d74a 100644
--- a/test/Transforms/InstCombine/shift.ll
+++ b/test/Transforms/InstCombine/shift.ll
@@ -425,3 +425,18 @@ entry:
 ; CHECK: ret i128 %ins
 }
 
+define i64 @test37(i128 %A, i32 %B) {
+entry:
+  %tmp27 = shl i128 %A, 64
+  %tmp22 = zext i32 %B to i128
+  %tmp23 = shl i128 %tmp22, 96
+  %ins = or i128 %tmp23, %tmp27
+  %tmp45 = lshr i128 %ins, 64
+  %tmp46 = trunc i128 %tmp45 to i64
+  ret i64 %tmp46
+
+; CHECK:  %tmp23 = shl i128 %tmp22, 32
+; CHECK:  %ins = or i128 %tmp23, %A
+; CHECK:  %tmp46 = trunc i128 %ins to i64
+}
+