diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index e5c26404245..487bec6a39a 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -2051,12 +2051,13 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, ++MaxShiftAmt; IntegerType *ExtTy = IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt); - // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded. if (const SCEVAddRecExpr *AR = dyn_cast(LHS)) if (const SCEVConstant *Step = - dyn_cast(AR->getStepRecurrence(*this))) - if (!Step->getValue()->getValue() - .urem(RHSC->getValue()->getValue()) && + dyn_cast(AR->getStepRecurrence(*this))) { + // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded. + const APInt &StepInt = Step->getValue()->getValue(); + const APInt &DivInt = RHSC->getValue()->getValue(); + if (!StepInt.urem(DivInt) && getZeroExtendExpr(AR, ExtTy) == getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy), getZeroExtendExpr(Step, ExtTy), @@ -2067,6 +2068,22 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW); } + // Get a canonical UDivExpr for a recurrence. + // {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0. + // We can currently only fold X%N if X is constant. + const SCEVConstant *StartC = dyn_cast(AR->getStart()); + if (StartC && !DivInt.urem(StepInt) && + getZeroExtendExpr(AR, ExtTy) == + getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy), + getZeroExtendExpr(Step, ExtTy), + AR->getLoop(), SCEV::FlagAnyWrap)) { + const APInt &StartInt = StartC->getValue()->getValue(); + const APInt &StartRem = StartInt.urem(StepInt); + if (StartRem != 0) + LHS = getAddRecExpr(getConstant(StartInt - StartRem), Step, + AR->getLoop(), SCEV::FlagNW); + } + } // (A*B)/C --> A*(B/C) if safe and B/C can be folded. 
if (const SCEVMulExpr *M = dyn_cast(LHS)) { SmallVector Operands; diff --git a/lib/Transforms/Scalar/IndVarSimplify.cpp b/lib/Transforms/Scalar/IndVarSimplify.cpp index f2c69a258cb..e40d72979ee 100644 --- a/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -70,6 +70,7 @@ STATISTIC(NumInserted , "Number of canonical indvars added"); STATISTIC(NumReplaced , "Number of exit values replaced"); STATISTIC(NumLFTR , "Number of loop exit tests replaced"); STATISTIC(NumElimIdentity, "Number of IV identities eliminated"); +STATISTIC(NumElimOperand, "Number of IV operands folded into a use"); STATISTIC(NumElimExt , "Number of IV sign/zero extends eliminated"); STATISTIC(NumElimRem , "Number of IV remainder operations eliminated"); STATISTIC(NumElimCmp , "Number of IV comparisons eliminated"); @@ -142,6 +143,8 @@ namespace { Value *IVOperand, bool IsSigned); + bool FoldIVUser(Instruction *UseInst, Instruction *IVOperand); + void SimplifyCongruentIVs(Loop *L); void RewriteIVExpressions(Loop *L, SCEVExpander &Rewriter); @@ -1298,6 +1301,66 @@ bool IndVarSimplify::EliminateIVUser(Instruction *UseInst, return true; } +/// FoldIVUser - Fold an IV operand into its use. This removes increments of an +/// aligned IV when used by an instruction that ignores the low bits. +bool IndVarSimplify::FoldIVUser(Instruction *UseInst, Instruction *IVOperand) { + Value *IVSrc = 0; + unsigned OperIdx = 0; + const SCEV *FoldedExpr = 0; + switch (UseInst->getOpcode()) { + default: + return false; + case Instruction::UDiv: + case Instruction::LShr: + // We're only interested in the case where we know something about + // the numerator and have a constant denominator. + if (IVOperand != UseInst->getOperand(OperIdx) || + !isa(UseInst->getOperand(1))) + return false; + + // Attempt to fold a binary operator with constant operand. + // e.g. 
((I + 1) >> 2) => I >> 2 + if (IVOperand->getNumOperands() != 2 || + !isa(IVOperand->getOperand(1))) + return false; + + IVSrc = IVOperand->getOperand(0); + // IVSrc must be the (SCEVable) IV, since the other operand is const. + assert(SE->isSCEVable(IVSrc->getType()) && "Expect SCEVable IV operand"); + + ConstantInt *D = cast(UseInst->getOperand(1)); + if (UseInst->getOpcode() == Instruction::LShr) { + // Get a constant for the divisor. See createSCEV. + uint32_t BitWidth = cast(UseInst->getType())->getBitWidth(); + if (D->getValue().uge(BitWidth)) + return false; + + D = ConstantInt::get(UseInst->getContext(), + APInt(BitWidth, 1).shl(D->getZExtValue())); + } + FoldedExpr = SE->getUDivExpr(SE->getSCEV(IVSrc), SE->getSCEV(D)); + } + // We have something that might fold its operand. Compare SCEVs. + if (!SE->isSCEVable(UseInst->getType())) + return false; + + // Bypass the operand if SCEV can prove it has no effect. + if (SE->getSCEV(UseInst) != FoldedExpr) + return false; + + DEBUG(dbgs() << "INDVARS: Eliminated IV operand: " << *IVOperand + << " -> " << *UseInst << '\n'); + + UseInst->setOperand(OperIdx, IVSrc); + assert(SE->getSCEV(UseInst) == FoldedExpr && "bad SCEV with folded oper"); + + ++NumElimOperand; + Changed = true; + if (IVOperand->use_empty()) + DeadInsts.push_back(IVOperand); + return true; +} + /// pushIVUsers - Add all uses of Def to the current IV's worklist. /// static void pushIVUsers( @@ -1394,6 +1457,8 @@ void IndVarSimplify::SimplifyIVUsersNoRewrite(Loop *L, SCEVExpander &Rewriter) { // Bypass back edges to avoid extra work. 
if (UseOper.first == CurrIV) continue; + FoldIVUser(UseOper.first, UseOper.second); + if (EliminateIVUser(UseOper.first, UseOper.second)) { pushIVUsers(UseOper.second, Simplified, SimpleIVUsers); continue; diff --git a/test/Transforms/IndVarSimplify/iv-fold.ll b/test/Transforms/IndVarSimplify/iv-fold.ll new file mode 100644 index 00000000000..7e11cdf098b --- /dev/null +++ b/test/Transforms/IndVarSimplify/iv-fold.ll @@ -0,0 +1,56 @@ +; RUN: opt < %s -indvars -disable-iv-rewrite -S | FileCheck %s + +target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n:32:64" + +; Indvars should be able to fold IV increments into shr when low bits are zero. +; +; CHECK: @foldIncShr +; CHECK: shr.1 = lshr i32 %0, 5 +define i32 @foldIncShr(i32* %bitmap, i32 %bit_addr, i32 %nbits) nounwind { +entry: + br label %while.body + +while.body: + %0 = phi i32 [ 0, %entry ], [ %inc.2, %while.body ] + %shr = lshr i32 %0, 5 + %arrayidx = getelementptr inbounds i32* %bitmap, i32 %shr + %tmp6 = load i32* %arrayidx, align 4 + %inc.1 = add i32 %0, 1 + %shr.1 = lshr i32 %inc.1, 5 + %arrayidx.1 = getelementptr inbounds i32* %bitmap, i32 %shr.1 + %tmp6.1 = load i32* %arrayidx.1, align 4 + %inc.2 = add i32 %inc.1, 1 + %exitcond.3 = icmp eq i32 %inc.2, 128 + br i1 %exitcond.3, label %while.end, label %while.body + +while.end: + %r = add i32 %tmp6, %tmp6.1 + ret i32 %r +} + +; Indvars should not fold an increment into shr unless 2^shiftBits is +; a multiple of the recurrence step. 
+; +; CHECK: @noFoldIncShr +; CHECK: shr.1 = lshr i32 %inc.1, 5 +define i32 @noFoldIncShr(i32* %bitmap, i32 %bit_addr, i32 %nbits) nounwind { +entry: + br label %while.body + +while.body: + %0 = phi i32 [ 0, %entry ], [ %inc.3, %while.body ] + %shr = lshr i32 %0, 5 + %arrayidx = getelementptr inbounds i32* %bitmap, i32 %shr + %tmp6 = load i32* %arrayidx, align 4 + %inc.1 = add i32 %0, 1 + %shr.1 = lshr i32 %inc.1, 5 + %arrayidx.1 = getelementptr inbounds i32* %bitmap, i32 %shr.1 + %tmp6.1 = load i32* %arrayidx.1, align 4 + %inc.3 = add i32 %inc.1, 2 + %exitcond.3 = icmp eq i32 %inc.3, 96 + br i1 %exitcond.3, label %while.end, label %while.body + +while.end: + %r = add i32 %tmp6, %tmp6.1 + ret i32 %r +}