diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 349979843a5..2601277cd8c 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -1168,14 +1168,19 @@ namespace { /// HowFarToZero - Return the number of times a backedge comparing the /// specified value to zero will execute. If not computable, return - /// UnknownValue + /// UnknownValue. SCEVHandle HowFarToZero(SCEV *V, const Loop *L); /// HowFarToNonZero - Return the number of times a backedge checking the /// specified value for nonzero will execute. If not computable, return - /// UnknownValue + /// UnknownValue. SCEVHandle HowFarToNonZero(SCEV *V, const Loop *L); + /// HowManyLessThans - Return the number of times a backedge containing the + /// specified less-than comparison will execute. If not computable, return + /// UnknownValue. + SCEVHandle HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L); + /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is /// in the header of its containing loop, we know the loop executes a /// constant number of times, and the PHI node is just a recurrence @@ -1530,6 +1535,20 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { if (!isa(TC)) return TC; } break; + case Instruction::SetLT: + if (LHS->getType()->isInteger() && + ExitCond->getOperand(0)->getType()->isSigned()) { + SCEVHandle TC = HowManyLessThans(LHS, RHS, L); + if (!isa(TC)) return TC; + } + break; + case Instruction::SetGT: + if (LHS->getType()->isInteger() && + ExitCond->getOperand(0)->getType()->isSigned()) { + SCEVHandle TC = HowManyLessThans(RHS, LHS, L); + if (!isa(TC)) return TC; + } + break; default: #if 0 std::cerr << "ComputeIterationCount "; @@ -2169,6 +2188,95 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToNonZero(SCEV *V, const Loop *L) { return UnknownValue; } +/// HowManyLessThans - Return the number of times a backedge containing the +/// specified less-than comparison will execute. If not computable, return +/// UnknownValue. +SCEVHandle ScalarEvolutionsImpl:: +HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L) { + // Only handle: "ADDREC < LoopInvariant". + if (!RHS->isLoopInvariant(L)) return UnknownValue; + + SCEVAddRecExpr *AddRec = dyn_cast(LHS); + if (!AddRec || AddRec->getLoop() != L) + return UnknownValue; + + if (AddRec->isAffine()) { + // FORNOW: We only support unit strides. + SCEVHandle One = SCEVUnknown::getIntegerSCEV(1, RHS->getType()); + if (AddRec->getOperand(1) != One) + return UnknownValue; + + // The number of iterations for "[n,+,1] < m", is m-n. However, we don't + // know that m is >= n on input to the loop. If it is, the condition return + // true zero times. What we really should return, for full generality, is + // SMAX(0, m-n). Since we cannot check this, we will instead check for a + // canonical loop form: most do-loops will have a check that dominates the + // loop, that only enters the loop if [n-1]= n. + + // Search for the check. + BasicBlock *Preheader = L->getLoopPreheader(); + BasicBlock *PreheaderDest = L->getHeader(); + if (Preheader == 0) return UnknownValue; + + BranchInst *LoopEntryPredicate = + dyn_cast(Preheader->getTerminator()); + if (!LoopEntryPredicate) return UnknownValue; + + // This might be a critical edge broken out. If the loop preheader ends in + // an unconditional branch to the loop, check to see if the preheader has a + // single predecessor, and if so, look for its terminator. + while (LoopEntryPredicate->isUnconditional()) { + PreheaderDest = Preheader; + Preheader = Preheader->getSinglePredecessor(); + if (!Preheader) return UnknownValue; // Multiple preds. + + LoopEntryPredicate = + dyn_cast(Preheader->getTerminator()); + if (!LoopEntryPredicate) return UnknownValue; + } + + // Now that we found a conditional branch that dominates the loop, check to + // see if it is the comparison we are looking for. + SetCondInst *SCI =dyn_cast(LoopEntryPredicate->getCondition()); + if (!SCI) return UnknownValue; + Value *PreCondLHS = SCI->getOperand(0); + Value *PreCondRHS = SCI->getOperand(1); + Instruction::BinaryOps Cond; + if (LoopEntryPredicate->getSuccessor(0) == PreheaderDest) + Cond = SCI->getOpcode(); + else + Cond = SCI->getInverseCondition(); + + switch (Cond) { + case Instruction::SetGT: + std::swap(PreCondLHS, PreCondRHS); + Cond = Instruction::SetLT; + // Fall Through. + case Instruction::SetLT: + if (PreCondLHS->getType()->isInteger() && + PreCondLHS->getType()->isSigned()) { + if (RHS != getSCEV(PreCondRHS)) + return UnknownValue; // Not a comparison against 'm'. + + if (SCEV::getMinusSCEV(AddRec->getOperand(0), One) + != getSCEV(PreCondLHS)) + return UnknownValue; // Not a comparison against 'n-1'. + break; + } else { + return UnknownValue; + } + default: break; + } + + //std::cerr << "Computed Loop Trip Count as: " << + // *SCEV::getMinusSCEV(RHS, AddRec->getOperand(0)) << "\n"; + return SCEV::getMinusSCEV(RHS, AddRec->getOperand(0)); + } + + return UnknownValue; +} + /// getNumIterationsInRange - Return the number of iterations of this loop that /// produce values in the specified constant range. Another way of looking at /// this is that it returns the first iteration number where the value is not in