diff --git a/lib/Transforms/InstCombine/InstCombine.h b/lib/Transforms/InstCombine/InstCombine.h index 6c0d4e74a7a..c569839855d 100644 --- a/lib/Transforms/InstCombine/InstCombine.h +++ b/lib/Transforms/InstCombine/InstCombine.h @@ -191,6 +191,8 @@ public: ConstantInt *DivRHS); Instruction *FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A, ConstantInt *CI1, ConstantInt *CI2); + Instruction *FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A, + ConstantInt *CI1, ConstantInt *CI2); Instruction *FoldICmpAddOpCst(Instruction &ICI, Value *X, ConstantInt *CI, ICmpInst::Predicate Pred); Instruction *FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index 2918f031b3f..fec51aa5ec6 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -1119,6 +1119,49 @@ Instruction *InstCombiner::FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A, return getConstant(false); } +/// FoldICmpCstShlCst - Handle "(icmp eq/ne (shl const2, A), const1)" -> +/// (icmp eq/ne A, TrailingZeros(const1) - TrailingZeros(const2)). +Instruction *InstCombiner::FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A, + ConstantInt *CI1, + ConstantInt *CI2) { + assert(I.isEquality() && "Cannot fold icmp gt/lt"); + + auto getConstant = [&I, this](bool IsTrue) { + if (I.getPredicate() == I.ICMP_NE) + IsTrue = !IsTrue; + return ReplaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); + }; + + auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { + if (I.getPredicate() == I.ICMP_NE) + Pred = CmpInst::getInversePredicate(Pred); + return new ICmpInst(Pred, LHS, RHS); + }; + + APInt AP1 = CI1->getValue(); + APInt AP2 = CI2->getValue(); + + assert(AP2 != 0 && "Handled in InstSimplify"); + + unsigned AP2TrailingZeros = AP2.countTrailingZeros(); + + if (!AP1 && AP2TrailingZeros != 0) + return getICmp(I.ICMP_UGE, A, + ConstantInt::get(A->getType(), AP2.getBitWidth() - AP2TrailingZeros)); + + if (AP1 == AP2) + return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); + + // Get the distance between the lowest bits that are set. + int Shift = AP1.countTrailingZeros() - AP2TrailingZeros; + + if (Shift > 0 && AP2.shl(Shift) == AP1) + return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + + // Shifting const2 will never be equal to const1. + return getConstant(false); +} + /// visitICmpInstWithInstAndIntCst - Handle "icmp (instr, intcst)". /// Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, @@ -2575,13 +2618,17 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Builder->getInt(CI->getValue()-1)); } - // (icmp eq/ne (ashr/lshr const2, A), const1) if (I.isEquality()) { ConstantInt *CI2; if (match(Op0, m_AShr(m_ConstantInt(CI2), m_Value(A))) || match(Op0, m_LShr(m_ConstantInt(CI2), m_Value(A)))) { + // (icmp eq/ne (ashr/lshr const2, A), const1) return FoldICmpCstShrCst(I, Op0, A, CI, CI2); } + if (match(Op0, m_Shl(m_ConstantInt(CI2), m_Value(A)))) { + // (icmp eq/ne (shl const2, A), const1) + return FoldICmpCstShlCst(I, Op0, A, CI, CI2); + } } // If this comparison is a normal comparison, it demands all diff --git a/test/Transforms/InstCombine/icmp.ll b/test/Transforms/InstCombine/icmp.ll index 1713960ae57..343f80cfeb9 100644 --- a/test/Transforms/InstCombine/icmp.ll +++ b/test/Transforms/InstCombine/icmp.ll @@ -1418,3 +1418,64 @@ define i1 @icmp_and_or_lshr_cst(i32 %x) { %ret = icmp ne i32 %and, 0 ret i1 %ret } + +; CHECK-LABEL: @shl_ap1_zero_ap2_non_zero_2 +; CHECK-NEXT: %cmp = icmp ugt i32 %a, 29 +; CHECK-NEXT: ret i1 %cmp +define i1 @shl_ap1_zero_ap2_non_zero_2(i32 %a) { + %shl = shl i32 4, %a + %cmp = icmp eq i32 %shl, 0 + ret i1 %cmp +} + +; CHECK-LABEL: @shl_ap1_zero_ap2_non_zero_4 +; CHECK-NEXT: %cmp = icmp ugt i32 %a, 30 +; CHECK-NEXT: ret i1 %cmp +define i1 @shl_ap1_zero_ap2_non_zero_4(i32 %a) { + %shl = shl i32 -2, %a + %cmp = icmp eq i32 %shl, 0 + ret i1 %cmp +} + +; CHECK-LABEL: @shl_ap1_non_zero_ap2_non_zero_both_positive +; CHECK-NEXT: %cmp = icmp eq i32 %a, 0 +; CHECK-NEXT: ret i1 %cmp +define i1 @shl_ap1_non_zero_ap2_non_zero_both_positive(i32 %a) { + %shl = shl i32 50, %a + %cmp = icmp eq i32 %shl, 50 + ret i1 %cmp +} + +; CHECK-LABEL: @shl_ap1_non_zero_ap2_non_zero_both_negative +; CHECK-NEXT: %cmp = icmp eq i32 %a, 0 +; CHECK-NEXT: ret i1 %cmp +define i1 @shl_ap1_non_zero_ap2_non_zero_both_negative(i32 %a) { + %shl = shl i32 -50, %a + %cmp = icmp eq i32 %shl, -50 + ret i1 %cmp +} + +; CHECK-LABEL: @shl_ap1_non_zero_ap2_non_zero_ap1_1 +; CHECK-NEXT: ret i1 false +define i1 @shl_ap1_non_zero_ap2_non_zero_ap1_1(i32 %a) { + %shl = shl i32 50, %a + %cmp = icmp eq i32 %shl, 25 + ret i1 %cmp +} + +; CHECK-LABEL: @shl_ap1_non_zero_ap2_non_zero_ap1_2 +; CHECK-NEXT: %cmp = icmp eq i32 %a, 1 +; CHECK-NEXT: ret i1 %cmp +define i1 @shl_ap1_non_zero_ap2_non_zero_ap1_2(i32 %a) { + %shl = shl i32 25, %a + %cmp = icmp eq i32 %shl, 50 + ret i1 %cmp +} + +; CHECK-LABEL: @shl_ap1_non_zero_ap2_non_zero_ap1_3 +; CHECK-NEXT: ret i1 false +define i1 @shl_ap1_non_zero_ap2_non_zero_ap1_3(i32 %a) { + %shl = shl i32 26, %a + %cmp = icmp eq i32 %shl, 50 + ret i1 %cmp +}