Generalize the LSR code that optimizes loops to count down towards zero.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@86715 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Evan Cheng 2009-11-10 21:14:05 +00:00
parent 14bbbf3168
commit 81ebdcf7dd
2 changed files with 168 additions and 80 deletions

View File

@ -51,6 +51,7 @@ STATISTIC(NumEliminated, "Number of strides eliminated");
STATISTIC(NumShadow, "Number of Shadow IVs optimized"); STATISTIC(NumShadow, "Number of Shadow IVs optimized");
STATISTIC(NumImmSunk, "Number of common expr immediates sunk into uses"); STATISTIC(NumImmSunk, "Number of common expr immediates sunk into uses");
STATISTIC(NumLoopCond, "Number of loop terminating conds optimized"); STATISTIC(NumLoopCond, "Number of loop terminating conds optimized");
STATISTIC(NumCountZero, "Number of count iv optimized to count toward zero");
static cl::opt<bool> EnableFullLSRMode("enable-full-lsr", static cl::opt<bool> EnableFullLSRMode("enable-full-lsr",
cl::init(false), cl::init(false),
@ -136,7 +137,8 @@ namespace {
const SCEV *const * &CondStride); const SCEV *const * &CondStride);
void OptimizeIndvars(Loop *L); void OptimizeIndvars(Loop *L);
void OptimizeLoopCountIV(Loop *L); void OptimizeLoopCountIV(const SCEV *Stride,
IVUsersOfOneStride &Uses, Loop *L);
void OptimizeLoopTermCond(Loop *L); void OptimizeLoopTermCond(Loop *L);
/// OptimizeShadowIV - If IV is used in a int-to-float cast /// OptimizeShadowIV - If IV is used in a int-to-float cast
@ -1519,8 +1521,8 @@ void LoopStrengthReduce::StrengthReduceStridedIVUsers(const SCEV *const &Stride,
// have the full access expression to rewrite the use. // have the full access expression to rewrite the use.
std::vector<BasedUser> UsersToProcess; std::vector<BasedUser> UsersToProcess;
const SCEV *CommonExprs = CollectIVUsers(Stride, Uses, L, AllUsesAreAddresses, const SCEV *CommonExprs = CollectIVUsers(Stride, Uses, L, AllUsesAreAddresses,
AllUsesAreOutsideLoop, AllUsesAreOutsideLoop,
UsersToProcess); UsersToProcess);
// Sort the UsersToProcess array so that users with common bases are // Sort the UsersToProcess array so that users with common bases are
// next to each other. // next to each other.
@ -1593,8 +1595,8 @@ void LoopStrengthReduce::StrengthReduceStridedIVUsers(const SCEV *const &Stride,
Type::getInt32Ty(Preheader->getContext())), Type::getInt32Ty(Preheader->getContext())),
0); 0);
/// Choose a strength-reduction strategy and prepare for it by creating // Choose a strength-reduction strategy and prepare for it by creating
/// the necessary PHIs and adjusting the bookkeeping. // the necessary PHIs and adjusting the bookkeeping.
if (ShouldUseFullStrengthReductionMode(UsersToProcess, L, if (ShouldUseFullStrengthReductionMode(UsersToProcess, L,
AllUsesAreAddresses, Stride)) { AllUsesAreAddresses, Stride)) {
PrepareToStrengthReduceFully(UsersToProcess, Stride, CommonExprs, L, PrepareToStrengthReduceFully(UsersToProcess, Stride, CommonExprs, L,
@ -2418,107 +2420,142 @@ void LoopStrengthReduce::OptimizeLoopTermCond(Loop *L) {
++NumLoopCond; ++NumLoopCond;
} }
/// isUsedByExitBranch - Return true if the icmp Cond feeds the conditional
/// branch that terminates its own loop-exiting block, either directly or after
/// being combined with other conditions through a chain of single-use and / or
/// instructions that all live in that same block.
///
/// Only the first use of Cond is followed, so this is meaningful when the
/// compare has exactly one use. Returns false when Cond's block does not exit
/// the loop L, when that block does not end in a conditional branch, or when
/// the use chain does not lead to that branch.
static bool isUsedByExitBranch(ICmpInst *Cond, Loop *L) {
  BasicBlock *CondBB = Cond->getParent();
  if (!L->isLoopExiting(CondBB))
    return false;

  // An exiting block is not required to end in a branch (it may end in a
  // switch, an invoke, ...). dyn_cast yields null in that case, so check it
  // before asking whether the branch is conditional.
  BranchInst *TermBr = dyn_cast<BranchInst>(CondBB->getTerminator());
  if (!TermBr || !TermBr->isConditional())
    return false;

  // A compare without any use cannot control the branch, and use_begin() must
  // not be dereferenced on an empty use list.
  if (Cond->use_empty())
    return false;

  Value *User = *Cond->use_begin();
  Instruction *UserInst = dyn_cast<Instruction>(User);
  // Walk through and / or instructions that merge this compare with other
  // conditions. Each link must have a single use and stay in the exiting
  // block, otherwise the compare does not exclusively feed the branch.
  while (UserInst &&
         (UserInst->getOpcode() == Instruction::And ||
          UserInst->getOpcode() == Instruction::Or)) {
    if (!UserInst->hasOneUse() || UserInst->getParent() != CondBB)
      return false;
    User = *User->use_begin();
    UserInst = dyn_cast<Instruction>(User);
  }
  return User == TermBr;
}
/// OptimizeLoopCountIV - If, after all sharing of IVs, the IV used for deciding /// OptimizeLoopCountIV - If, after all sharing of IVs, the IV used for deciding
/// when to exit the loop is used only for that purpose, try to rearrange things /// when to exit the loop is used only for that purpose, try to rearrange things
/// so it counts down to a test against zero. /// so it counts down to a test against zero which tends to be cheaper.
void LoopStrengthReduce::OptimizeLoopCountIV(Loop *L) { void LoopStrengthReduce::OptimizeLoopCountIV(const SCEV *Stride,
IVUsersOfOneStride &Uses,
// If the number of times the loop is executed isn't computable, give up. Loop *L) {
const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); if (Uses.Users.size() != 1)
if (isa<SCEVCouldNotCompute>(BackedgeTakenCount))
return; return;
// Get the terminating condition for the loop if possible (this isn't // If the only use is an icmp of an loop exiting conditional branch, then
// necessarily in the latch, or a block that's a predecessor of the header). // attempts the optimization.
if (!L->getExitBlock()) BasedUser User = BasedUser(*Uses.Users.begin(), SE);
return; // More than one loop exit blocks. Instruction *Inst = User.Inst;
if (!L->contains(Inst->getParent()))
// Okay, there is one exit block. Try to find the condition that causes the
// loop to be exited.
BasicBlock *ExitingBlock = L->getExitingBlock();
if (!ExitingBlock)
return; // More than one block exiting!
// Okay, we've computed the exiting block. See what condition causes us to
// exit.
//
// FIXME: we should be able to handle switch instructions (with a single exit)
BranchInst *TermBr = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
if (TermBr == 0) return;
assert(TermBr->isConditional() && "If unconditional, it can't be in loop!");
if (!isa<ICmpInst>(TermBr->getCondition()))
return; return;
ICmpInst *Cond = cast<ICmpInst>(TermBr->getCondition());
// Handle only tests for equality for the moment, and only stride 1. ICmpInst *Cond = dyn_cast<ICmpInst>(Inst);
if (Cond->getPredicate() != CmpInst::ICMP_EQ) if (!Cond)
return; return;
const SCEV *IV = SE->getSCEV(Cond->getOperand(0)); // Handle only tests for equality for the moment.
if (Cond->getPredicate() != CmpInst::ICMP_EQ || !Cond->hasOneUse())
return;
if (!isUsedByExitBranch(Cond, L))
return;
Value *CondOp0 = Cond->getOperand(0);
const SCEV *IV = SE->getSCEV(CondOp0);
const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(IV); const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(IV);
const SCEV *One = SE->getIntegerSCEV(1, BackedgeTakenCount->getType()); if (!AR || !AR->isAffine())
if (!AR || !AR->isAffine() || AR->getStepRecurrence(*SE) != One) return;
const SCEVConstant *SC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE));
if (!SC || SC->getValue()->getSExtValue() < 0)
// If it's already counting down, don't do anything.
return;
// If the RHS of the comparison is not an loop invariant, the rewrite
// cannot be done. Also bail out if it's already comparing against a zero.
Value *RHS = Cond->getOperand(1);
if (!L->isLoopInvariant(RHS) ||
(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
return; return;
// If the RHS of the comparison is defined inside the loop, the rewrite
// cannot be done.
if (Instruction *CR = dyn_cast<Instruction>(Cond->getOperand(1)))
if (L->contains(CR->getParent()))
return;
// Make sure the IV is only used for counting. Value may be preinc or // Make sure the IV is only used for counting. Value may be preinc or
// postinc; 2 uses in either case. // postinc; 2 uses in either case.
if (!Cond->getOperand(0)->hasNUses(2)) if (!CondOp0->hasNUses(2))
return; return;
PHINode *phi = dyn_cast<PHINode>(Cond->getOperand(0)); PHINode *PHIExpr;
Instruction *incr; Instruction *Incr;
if (phi && phi->getParent()==L->getHeader()) { if (User.isUseOfPostIncrementedValue) {
// value tested is preinc. Find the increment. // Value tested is postinc. Find the phi node.
// A CmpInst is not a BinaryOperator; we depend on this. Incr = dyn_cast<BinaryOperator>(CondOp0);
Instruction::use_iterator UI = phi->use_begin(); if (!Incr || Incr->getOpcode() != Instruction::Add)
incr = dyn_cast<BinaryOperator>(UI);
if (!incr)
incr = dyn_cast<BinaryOperator>(++UI);
// 1 use for postinc value, the phi. Unnecessarily conservative?
if (!incr || !incr->hasOneUse() || incr->getOpcode()!=Instruction::Add)
return;
} else {
// Value tested is postinc. Find the phi node.
incr = dyn_cast<BinaryOperator>(Cond->getOperand(0));
if (!incr || incr->getOpcode()!=Instruction::Add)
return; return;
Instruction::use_iterator UI = Cond->getOperand(0)->use_begin(); Instruction::use_iterator UI = CondOp0->use_begin();
phi = dyn_cast<PHINode>(UI); PHIExpr = dyn_cast<PHINode>(UI);
if (!phi) if (!PHIExpr)
phi = dyn_cast<PHINode>(++UI); PHIExpr = dyn_cast<PHINode>(++UI);
// 1 use for preinc value, the increment. // 1 use for preinc value, the increment.
if (!phi || phi->getParent()!=L->getHeader() || !phi->hasOneUse()) if (!PHIExpr || !PHIExpr->hasOneUse())
return;
} else {
assert(isa<PHINode>(CondOp0) &&
"Unexpected loop exiting counting instruction sequence!");
PHIExpr = cast<PHINode>(CondOp0);
// Value tested is preinc. Find the increment.
// A CmpInst is not a BinaryOperator; we depend on this.
Instruction::use_iterator UI = PHIExpr->use_begin();
Incr = dyn_cast<BinaryOperator>(UI);
if (!Incr)
Incr = dyn_cast<BinaryOperator>(++UI);
// One use for postinc value, the phi. Unnecessarily conservative?
if (!Incr || !Incr->hasOneUse() || Incr->getOpcode() != Instruction::Add)
return; return;
} }
// Replace the increment with a decrement. // Replace the increment with a decrement.
BinaryOperator *decr = DEBUG(errs() << " Examining ");
BinaryOperator::Create(Instruction::Sub, incr->getOperand(0), if (User.isUseOfPostIncrementedValue)
incr->getOperand(1), "tmp", incr); DEBUG(errs() << "postinc");
incr->replaceAllUsesWith(decr); else
incr->eraseFromParent(); DEBUG(errs() << "preinc");
DEBUG(errs() << " use ");
DEBUG(WriteAsOperand(errs(), CondOp0, /*PrintType=*/false));
DEBUG(errs() << " in Inst: " << *Inst << '\n');
BinaryOperator *Decr = BinaryOperator::Create(Instruction::Sub,
Incr->getOperand(0), Incr->getOperand(1), "tmp", Incr);
Incr->replaceAllUsesWith(Decr);
Incr->eraseFromParent();
// Substitute endval-startval for the original startval, and 0 for the // Substitute endval-startval for the original startval, and 0 for the
// original endval. Since we're only testing for equality this is OK even // original endval. Since we're only testing for equality this is OK even
// if the computation wraps around. // if the computation wraps around.
BasicBlock *Preheader = L->getLoopPreheader(); BasicBlock *Preheader = L->getLoopPreheader();
Instruction *PreInsertPt = Preheader->getTerminator(); Instruction *PreInsertPt = Preheader->getTerminator();
int inBlock = L->contains(phi->getIncomingBlock(0)) ? 1 : 0; unsigned InBlock = L->contains(PHIExpr->getIncomingBlock(0)) ? 1 : 0;
Value *startVal = phi->getIncomingValue(inBlock); Value *StartVal = PHIExpr->getIncomingValue(InBlock);
Value *endVal = Cond->getOperand(1); Value *EndVal = Cond->getOperand(1);
// FIXME check for case where both are constant DEBUG(errs() << " Optimize loop counting iv to count down ["
<< *EndVal << " .. " << *StartVal << "]\n");
// FIXME: check for case where both are constant.
Constant* Zero = ConstantInt::get(Cond->getOperand(1)->getType(), 0); Constant* Zero = ConstantInt::get(Cond->getOperand(1)->getType(), 0);
BinaryOperator *NewStartVal = BinaryOperator *NewStartVal = BinaryOperator::Create(Instruction::Sub,
BinaryOperator::Create(Instruction::Sub, endVal, startVal, EndVal, StartVal, "tmp", PreInsertPt);
"tmp", PreInsertPt); PHIExpr->setIncomingValue(InBlock, NewStartVal);
phi->setIncomingValue(inBlock, NewStartVal);
Cond->setOperand(1, Zero); Cond->setOperand(1, Zero);
DEBUG(errs() << " New icmp: " << *Cond << "\n");
Changed = true; Changed = true;
++NumCountZero;
} }
bool LoopStrengthReduce::runOnLoop(Loop *L, LPPassManager &LPM) { bool LoopStrengthReduce::runOnLoop(Loop *L, LPPassManager &LPM) {
@ -2581,11 +2618,20 @@ bool LoopStrengthReduce::runOnLoop(Loop *L, LPPassManager &LPM) {
continue; continue;
StrengthReduceStridedIVUsers(SI->first, *SI->second, L); StrengthReduceStridedIVUsers(SI->first, *SI->second, L);
} }
}
// After all sharing is done, see if we can adjust the loop to test against // After all sharing is done, see if we can adjust the loop to test against
// zero instead of counting up to a maximum. This is usually faster. // zero instead of counting up to a maximum. This is usually faster.
OptimizeLoopCountIV(L); for (unsigned Stride = 0, e = IU->StrideOrder.size();
Stride != e; ++Stride) {
std::map<const SCEV *, IVUsersOfOneStride *>::iterator SI =
IU->IVUsesByStride.find(IU->StrideOrder[Stride]);
assert(SI != IU->IVUsesByStride.end() && "Stride doesn't exist!");
// FIXME: Generalize to non-affine IV's.
if (!SI->first->isLoopInvariant(L))
continue;
OptimizeLoopCountIV(SI->first, *SI->second, L);
}
}
// We're done analyzing this loop; release all the state we built up for it. // We're done analyzing this loop; release all the state we built up for it.
IVsByStride.clear(); IVsByStride.clear();

View File

@ -0,0 +1,42 @@
; RUN: opt < %s -loop-reduce -S | FileCheck %s
; rdar://7382068
; @t contains a loop (bb6 -> bb1 -> bb3 -> bb6) whose counter %indvar starts at
; 0, is incremented by 1 in bb3, and is otherwise used only by the exit compare
; against 9999 in bb6. That compare is or'ed with a second condition before it
; reaches the exiting branch. The FileCheck directives in the body expect the
; counter to be rewritten so that it is decremented by 1 and compared with 0.
define void @t(i32 %c) nounwind optsize {
entry:
br label %bb6
bb1: ; preds = %bb6
; Second way out of the loop: leave through bb2 when %c_addr.1 equals 20.
%tmp = icmp eq i32 %c_addr.1, 20 ; <i1> [#uses=1]
br i1 %tmp, label %bb2, label %bb3
bb2: ; preds = %bb1
%tmp1 = tail call i32 @f20(i32 %c_addr.1) nounwind ; <i32> [#uses=1]
br label %bb7
bb3: ; preds = %bb1
; Back-edge block: step %c_addr.1 up or down by one depending on whether it is
; below 10, and advance the loop counter.
%tmp2 = icmp slt i32 %c_addr.1, 10 ; <i1> [#uses=1]
%tmp3 = add nsw i32 %c_addr.1, 1 ; <i32> [#uses=1]
%tmp4 = add i32 %c_addr.1, -1 ; <i32> [#uses=1]
%c_addr.1.be = select i1 %tmp2, i32 %tmp3, i32 %tmp4 ; <i32> [#uses=1]
%indvar.next = add i32 %indvar, 1 ; <i32> [#uses=1]
; Expected after the pass: the counter increment becomes a decrement.
; CHECK: sub i32 %lsr.iv, 1
br label %bb6
bb6: ; preds = %bb3, %entry
; Loop header. The counter compare (%tmp5) is combined with %tmp6 through an
; 'or' before it is used by the conditional branch that exits to bb7.
%indvar = phi i32 [ %indvar.next, %bb3 ], [ 0, %entry ] ; <i32> [#uses=2]
%c_addr.1 = phi i32 [ %c_addr.1.be, %bb3 ], [ %c, %entry ] ; <i32> [#uses=7]
%tmp5 = icmp eq i32 %indvar, 9999 ; <i1> [#uses=1]
; Expected after the pass: the counter is tested against zero.
; CHECK: icmp eq i32 %lsr.iv, 0
%tmp6 = icmp eq i32 %c_addr.1, 100 ; <i1> [#uses=1]
%or.cond = or i1 %tmp5, %tmp6 ; <i1> [#uses=1]
br i1 %or.cond, label %bb7, label %bb1
bb7: ; preds = %bb6, %bb2
; Common exit: pass on either the result of @f20 or the last %c_addr.1.
%c_addr.0 = phi i32 [ %tmp1, %bb2 ], [ %c_addr.1, %bb6 ] ; <i32> [#uses=1]
tail call void @bar(i32 %c_addr.0) nounwind
ret void
}
; External functions called from @t; only their declarations are provided here.
declare i32 @f20(i32)
declare void @bar(i32)