From 81ebdcf7dd3c182f5fca352b59b336f79d5da23e Mon Sep 17 00:00:00 2001 From: Evan Cheng Date: Tue, 10 Nov 2009 21:14:05 +0000 Subject: [PATCH] Generalize lsr code that optimize loop to count down towards zero. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@86715 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Transforms/Scalar/LoopStrengthReduce.cpp | 206 +++++++++++------- .../LoopStrengthReduce/count-to-zero.ll | 42 ++++ 2 files changed, 168 insertions(+), 80 deletions(-) create mode 100644 test/Transforms/LoopStrengthReduce/count-to-zero.ll diff --git a/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 4b82dbfd539..aad512448ca 100644 --- a/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -51,6 +51,7 @@ STATISTIC(NumEliminated, "Number of strides eliminated"); STATISTIC(NumShadow, "Number of Shadow IVs optimized"); STATISTIC(NumImmSunk, "Number of common expr immediates sunk into uses"); STATISTIC(NumLoopCond, "Number of loop terminating conds optimized"); +STATISTIC(NumCountZero, "Number of count iv optimized to count toward zero"); static cl::opt EnableFullLSRMode("enable-full-lsr", cl::init(false), @@ -136,7 +137,8 @@ namespace { const SCEV *const * &CondStride); void OptimizeIndvars(Loop *L); - void OptimizeLoopCountIV(Loop *L); + void OptimizeLoopCountIV(const SCEV *Stride, + IVUsersOfOneStride &Uses, Loop *L); void OptimizeLoopTermCond(Loop *L); /// OptimizeShadowIV - If IV is used in a int-to-float cast @@ -1519,8 +1521,8 @@ void LoopStrengthReduce::StrengthReduceStridedIVUsers(const SCEV *const &Stride, // have the full access expression to rewrite the use. std::vector UsersToProcess; const SCEV *CommonExprs = CollectIVUsers(Stride, Uses, L, AllUsesAreAddresses, - AllUsesAreOutsideLoop, - UsersToProcess); + AllUsesAreOutsideLoop, + UsersToProcess); // Sort the UsersToProcess array so that users with common bases are // next to each other. @@ -1593,8 +1595,8 @@ void LoopStrengthReduce::StrengthReduceStridedIVUsers(const SCEV *const &Stride, Type::getInt32Ty(Preheader->getContext())), 0); - /// Choose a strength-reduction strategy and prepare for it by creating - /// the necessary PHIs and adjusting the bookkeeping. + // Choose a strength-reduction strategy and prepare for it by creating + // the necessary PHIs and adjusting the bookkeeping. if (ShouldUseFullStrengthReductionMode(UsersToProcess, L, AllUsesAreAddresses, Stride)) { PrepareToStrengthReduceFully(UsersToProcess, Stride, CommonExprs, L, @@ -2418,107 +2420,142 @@ void LoopStrengthReduce::OptimizeLoopTermCond(Loop *L) { ++NumLoopCond; } +/// isUsedByExitBranch - Return true if icmp is used by a loop terminating +/// conditional branch or it's and / or with other conditions before being used +/// as the condition. +static bool isUsedByExitBranch(ICmpInst *Cond, Loop *L) { + BasicBlock *CondBB = Cond->getParent(); + if (!L->isLoopExiting(CondBB)) + return false; + BranchInst *TermBr = dyn_cast(CondBB->getTerminator()); + if (!TermBr->isConditional()) + return false; + + Value *User = *Cond->use_begin(); + Instruction *UserInst = dyn_cast(User); + while (UserInst && + (UserInst->getOpcode() == Instruction::And || + UserInst->getOpcode() == Instruction::Or)) { + if (!UserInst->hasOneUse() || UserInst->getParent() != CondBB) + return false; + User = *User->use_begin(); + UserInst = dyn_cast(User); + } + return User == TermBr; +} + /// OptimizeLoopCountIV - If, after all sharing of IVs, the IV used for deciding /// when to exit the loop is used only for that purpose, try to rearrange things -/// so it counts down to a test against zero. -void LoopStrengthReduce::OptimizeLoopCountIV(Loop *L) { - - // If the number of times the loop is executed isn't computable, give up. - const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); - if (isa(BackedgeTakenCount)) +/// so it counts down to a test against zero which tends to be cheaper. +void LoopStrengthReduce::OptimizeLoopCountIV(const SCEV *Stride, + IVUsersOfOneStride &Uses, + Loop *L) { + if (Uses.Users.size() != 1) return; - // Get the terminating condition for the loop if possible (this isn't - // necessarily in the latch, or a block that's a predecessor of the header). - if (!L->getExitBlock()) - return; // More than one loop exit blocks. - - // Okay, there is one exit block. Try to find the condition that causes the - // loop to be exited. - BasicBlock *ExitingBlock = L->getExitingBlock(); - if (!ExitingBlock) - return; // More than one block exiting! - - // Okay, we've computed the exiting block. See what condition causes us to - // exit. - // - // FIXME: we should be able to handle switch instructions (with a single exit) - BranchInst *TermBr = dyn_cast(ExitingBlock->getTerminator()); - if (TermBr == 0) return; - assert(TermBr->isConditional() && "If unconditional, it can't be in loop!"); - if (!isa(TermBr->getCondition())) + // If the only use is an icmp of an loop exiting conditional branch, then + // attempts the optimization. + BasedUser User = BasedUser(*Uses.Users.begin(), SE); + Instruction *Inst = User.Inst; + if (!L->contains(Inst->getParent())) return; - ICmpInst *Cond = cast(TermBr->getCondition()); - // Handle only tests for equality for the moment, and only stride 1. - if (Cond->getPredicate() != CmpInst::ICMP_EQ) + ICmpInst *Cond = dyn_cast(Inst); + if (!Cond) return; - const SCEV *IV = SE->getSCEV(Cond->getOperand(0)); + // Handle only tests for equality for the moment. + if (Cond->getPredicate() != CmpInst::ICMP_EQ || !Cond->hasOneUse()) + return; + if (!isUsedByExitBranch(Cond, L)) + return; + + Value *CondOp0 = Cond->getOperand(0); + const SCEV *IV = SE->getSCEV(CondOp0); const SCEVAddRecExpr *AR = dyn_cast(IV); - const SCEV *One = SE->getIntegerSCEV(1, BackedgeTakenCount->getType()); - if (!AR || !AR->isAffine() || AR->getStepRecurrence(*SE) != One) + if (!AR || !AR->isAffine()) + return; + + const SCEVConstant *SC = dyn_cast(AR->getStepRecurrence(*SE)); + if (!SC || SC->getValue()->getSExtValue() < 0) + // If it's already counting down, don't do anything. + return; + + // If the RHS of the comparison is not an loop invariant, the rewrite + // cannot be done. Also bail out if it's already comparing against a zero. + Value *RHS = Cond->getOperand(1); + if (!L->isLoopInvariant(RHS) || + (isa(RHS) && cast(RHS)->isZero())) return; - // If the RHS of the comparison is defined inside the loop, the rewrite - // cannot be done. - if (Instruction *CR = dyn_cast(Cond->getOperand(1))) - if (L->contains(CR->getParent())) - return; // Make sure the IV is only used for counting. Value may be preinc or // postinc; 2 uses in either case. - if (!Cond->getOperand(0)->hasNUses(2)) + if (!CondOp0->hasNUses(2)) return; - PHINode *phi = dyn_cast(Cond->getOperand(0)); - Instruction *incr; - if (phi && phi->getParent()==L->getHeader()) { - // value tested is preinc. Find the increment. - // A CmpInst is not a BinaryOperator; we depend on this. - Instruction::use_iterator UI = phi->use_begin(); - incr = dyn_cast(UI); - if (!incr) - incr = dyn_cast(++UI); - // 1 use for postinc value, the phi. Unnecessarily conservative? - if (!incr || !incr->hasOneUse() || incr->getOpcode()!=Instruction::Add) - return; - } else { - // Value tested is postinc. Find the phi node. - incr = dyn_cast(Cond->getOperand(0)); - if (!incr || incr->getOpcode()!=Instruction::Add) + PHINode *PHIExpr; + Instruction *Incr; + if (User.isUseOfPostIncrementedValue) { + // Value tested is postinc. Find the phi node. + Incr = dyn_cast(CondOp0); + if (!Incr || Incr->getOpcode() != Instruction::Add) return; - Instruction::use_iterator UI = Cond->getOperand(0)->use_begin(); - phi = dyn_cast(UI); - if (!phi) - phi = dyn_cast(++UI); + Instruction::use_iterator UI = CondOp0->use_begin(); + PHIExpr = dyn_cast(UI); + if (!PHIExpr) + PHIExpr = dyn_cast(++UI); // 1 use for preinc value, the increment. - if (!phi || phi->getParent()!=L->getHeader() || !phi->hasOneUse()) + if (!PHIExpr || !PHIExpr->hasOneUse()) + return; + } else { + assert(isa(CondOp0) && + "Unexpected loop exiting counting instruction sequence!"); + PHIExpr = cast(CondOp0); + // Value tested is preinc. Find the increment. + // A CmpInst is not a BinaryOperator; we depend on this. + Instruction::use_iterator UI = PHIExpr->use_begin(); + Incr = dyn_cast(UI); + if (!Incr) + Incr = dyn_cast(++UI); + // One use for postinc value, the phi. Unnecessarily conservative? + if (!Incr || !Incr->hasOneUse() || Incr->getOpcode() != Instruction::Add) return; } // Replace the increment with a decrement. - BinaryOperator *decr = - BinaryOperator::Create(Instruction::Sub, incr->getOperand(0), - incr->getOperand(1), "tmp", incr); - incr->replaceAllUsesWith(decr); - incr->eraseFromParent(); + DEBUG(errs() << " Examining "); + if (User.isUseOfPostIncrementedValue) + DEBUG(errs() << "postinc"); + else + DEBUG(errs() << "preinc"); + DEBUG(errs() << " use "); + DEBUG(WriteAsOperand(errs(), CondOp0, /*PrintType=*/false)); + DEBUG(errs() << " in Inst: " << *Inst << '\n'); + BinaryOperator *Decr = BinaryOperator::Create(Instruction::Sub, + Incr->getOperand(0), Incr->getOperand(1), "tmp", Incr); + Incr->replaceAllUsesWith(Decr); + Incr->eraseFromParent(); // Substitute endval-startval for the original startval, and 0 for the // original endval. Since we're only testing for equality this is OK even // if the computation wraps around. BasicBlock *Preheader = L->getLoopPreheader(); Instruction *PreInsertPt = Preheader->getTerminator(); - int inBlock = L->contains(phi->getIncomingBlock(0)) ? 1 : 0; - Value *startVal = phi->getIncomingValue(inBlock); - Value *endVal = Cond->getOperand(1); - // FIXME check for case where both are constant + unsigned InBlock = L->contains(PHIExpr->getIncomingBlock(0)) ? 1 : 0; + Value *StartVal = PHIExpr->getIncomingValue(InBlock); + Value *EndVal = Cond->getOperand(1); + DEBUG(errs() << " Optimize loop counting iv to count down [" + << *EndVal << " .. " << *StartVal << "]\n"); + + // FIXME: check for case where both are constant. Constant* Zero = ConstantInt::get(Cond->getOperand(1)->getType(), 0); - BinaryOperator *NewStartVal = - BinaryOperator::Create(Instruction::Sub, endVal, startVal, - "tmp", PreInsertPt); - phi->setIncomingValue(inBlock, NewStartVal); + BinaryOperator *NewStartVal = BinaryOperator::Create(Instruction::Sub, + EndVal, StartVal, "tmp", PreInsertPt); + PHIExpr->setIncomingValue(InBlock, NewStartVal); Cond->setOperand(1, Zero); + DEBUG(errs() << " New icmp: " << *Cond << "\n"); Changed = true; + ++NumCountZero; } bool LoopStrengthReduce::runOnLoop(Loop *L, LPPassManager &LPM) { @@ -2581,11 +2618,20 @@ bool LoopStrengthReduce::runOnLoop(Loop *L, LPPassManager &LPM) { continue; StrengthReduceStridedIVUsers(SI->first, *SI->second, L); } - } - // After all sharing is done, see if we can adjust the loop to test against - // zero instead of counting up to a maximum. This is usually faster. - OptimizeLoopCountIV(L); + // After all sharing is done, see if we can adjust the loop to test against + // zero instead of counting up to a maximum. This is usually faster. + for (unsigned Stride = 0, e = IU->StrideOrder.size(); + Stride != e; ++Stride) { + std::map::iterator SI = + IU->IVUsesByStride.find(IU->StrideOrder[Stride]); + assert(SI != IU->IVUsesByStride.end() && "Stride doesn't exist!"); + // FIXME: Generalize to non-affine IV's. + if (!SI->first->isLoopInvariant(L)) + continue; + OptimizeLoopCountIV(SI->first, *SI->second, L); + } + } // We're done analyzing this loop; release all the state we built up for it. IVsByStride.clear(); diff --git a/test/Transforms/LoopStrengthReduce/count-to-zero.ll b/test/Transforms/LoopStrengthReduce/count-to-zero.ll new file mode 100644 index 00000000000..8cc3b5c1034 --- /dev/null +++ b/test/Transforms/LoopStrengthReduce/count-to-zero.ll @@ -0,0 +1,42 @@ +; RUN: opt < %s -loop-reduce -S | FileCheck %s +; rdar://7382068 + +define void @t(i32 %c) nounwind optsize { +entry: + br label %bb6 + +bb1: ; preds = %bb6 + %tmp = icmp eq i32 %c_addr.1, 20 ; [#uses=1] + br i1 %tmp, label %bb2, label %bb3 + +bb2: ; preds = %bb1 + %tmp1 = tail call i32 @f20(i32 %c_addr.1) nounwind ; [#uses=1] + br label %bb7 + +bb3: ; preds = %bb1 + %tmp2 = icmp slt i32 %c_addr.1, 10 ; [#uses=1] + %tmp3 = add nsw i32 %c_addr.1, 1 ; [#uses=1] + %tmp4 = add i32 %c_addr.1, -1 ; [#uses=1] + %c_addr.1.be = select i1 %tmp2, i32 %tmp3, i32 %tmp4 ; [#uses=1] + %indvar.next = add i32 %indvar, 1 ; [#uses=1] +; CHECK: sub i32 %lsr.iv, 1 + br label %bb6 + +bb6: ; preds = %bb3, %entry + %indvar = phi i32 [ %indvar.next, %bb3 ], [ 0, %entry ] ; [#uses=2] + %c_addr.1 = phi i32 [ %c_addr.1.be, %bb3 ], [ %c, %entry ] ; [#uses=7] + %tmp5 = icmp eq i32 %indvar, 9999 ; [#uses=1] +; CHECK: icmp eq i32 %lsr.iv, 0 + %tmp6 = icmp eq i32 %c_addr.1, 100 ; [#uses=1] + %or.cond = or i1 %tmp5, %tmp6 ; [#uses=1] + br i1 %or.cond, label %bb7, label %bb1 + +bb7: ; preds = %bb6, %bb2 + %c_addr.0 = phi i32 [ %tmp1, %bb2 ], [ %c_addr.1, %bb6 ] ; [#uses=1] + tail call void @bar(i32 %c_addr.0) nounwind + ret void +} + +declare i32 @f20(i32) + +declare void @bar(i32)