diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 2b5ca520c21..1a5ef7ef958 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -432,9 +432,84 @@ namespace { /// SCEVComplexityCompare - Return true if the complexity of the LHS is less /// than the complexity of the RHS. This comparator is used to canonicalize /// expressions. - struct VISIBILITY_HIDDEN SCEVComplexityCompare { + class VISIBILITY_HIDDEN SCEVComplexityCompare { + LoopInfo *LI; + public: + explicit SCEVComplexityCompare(LoopInfo *li) : LI(li) {} + bool operator()(const SCEV *LHS, const SCEV *RHS) const { - return LHS->getSCEVType() < RHS->getSCEVType(); + // Primarily, sort the SCEVs by their getSCEVType(). + if (LHS->getSCEVType() != RHS->getSCEVType()) + return LHS->getSCEVType() < RHS->getSCEVType(); + + // Aside from the getSCEVType() ordering, the particular ordering + // isn't very important except that it's beneficial to be consistent, + // so that (a + b) and (b + a) don't end up as different expressions. + + // Sort SCEVUnknown values with some loose heuristics. TODO: This is + // not as complete as it could be. + if (const SCEVUnknown *LU = dyn_cast(LHS)) { + const SCEVUnknown *RU = cast(RHS); + + // Compare getValueID values. + if (LU->getValue()->getValueID() != RU->getValue()->getValueID()) + return LU->getValue()->getValueID() < RU->getValue()->getValueID(); + + // Sort arguments by their position. + if (const Argument *LA = dyn_cast(LU->getValue())) { + const Argument *RA = cast(RU->getValue()); + return LA->getArgNo() < RA->getArgNo(); + } + + // For instructions, compare their loop depth, and their opcode. + // This is pretty loose. + if (Instruction *LV = dyn_cast(LU->getValue())) { + Instruction *RV = cast(RU->getValue()); + + // Compare loop depths. + if (LI->getLoopDepth(LV->getParent()) != + LI->getLoopDepth(RV->getParent())) + return LI->getLoopDepth(LV->getParent()) < + LI->getLoopDepth(RV->getParent()); + + // Compare opcodes. + if (LV->getOpcode() != RV->getOpcode()) + return LV->getOpcode() < RV->getOpcode(); + + // Compare the number of operands. + if (LV->getNumOperands() != RV->getNumOperands()) + return LV->getNumOperands() < RV->getNumOperands(); + } + + return false; + } + + // Constant sorting doesn't matter since they'll be folded. + if (isa(LHS)) + return false; + + // Lexicographically compare n-ary expressions. + if (const SCEVNAryExpr *LC = dyn_cast(LHS)) { + const SCEVNAryExpr *RC = cast(RHS); + for (unsigned i = 0, e = LC->getNumOperands(); i != e; ++i) { + if (i >= RC->getNumOperands()) + return false; + if (operator()(LC->getOperand(i), RC->getOperand(i))) + return true; + if (operator()(RC->getOperand(i), LC->getOperand(i))) + return false; + } + return LC->getNumOperands() < RC->getNumOperands(); + } + + // Compare cast expressions by operand. + if (const SCEVCastExpr *LC = dyn_cast(LHS)) { + const SCEVCastExpr *RC = cast(RHS); + return operator()(LC->getOperand(), RC->getOperand()); + } + + assert(0 && "Unknown SCEV kind!"); + return false; } }; } @@ -449,18 +524,19 @@ namespace { /// this to depend on where the addresses of various SCEV objects happened to /// land in memory. /// -static void GroupByComplexity(std::vector &Ops) { +static void GroupByComplexity(std::vector &Ops, + LoopInfo *LI) { if (Ops.size() < 2) return; // Noop if (Ops.size() == 2) { // This is the common case, which also happens to be trivially simple. // Special case it. - if (SCEVComplexityCompare()(Ops[1], Ops[0])) + if (SCEVComplexityCompare(LI)(Ops[1], Ops[0])) std::swap(Ops[0], Ops[1]); return; } // Do the rough sort by complexity. - std::stable_sort(Ops.begin(), Ops.end(), SCEVComplexityCompare()); + std::stable_sort(Ops.begin(), Ops.end(), SCEVComplexityCompare(LI)); // Now that we are sorted by complexity, group elements of the same // complexity. Note that this is, at worst, N^2, but the vector is likely to @@ -833,7 +909,7 @@ SCEVHandle ScalarEvolution::getAddExpr(std::vector &Ops) { if (Ops.size() == 1) return Ops[0]; // Sort by complexity, this groups all similar expression types together. - GroupByComplexity(Ops); + GroupByComplexity(Ops, LI); // If there are any constants, fold them together. unsigned Idx = 0; @@ -1058,7 +1134,7 @@ SCEVHandle ScalarEvolution::getMulExpr(std::vector &Ops) { assert(!Ops.empty() && "Cannot get empty mul!"); // Sort by complexity, this groups all similar expression types together. - GroupByComplexity(Ops); + GroupByComplexity(Ops, LI); // If there are any constants, fold them together. unsigned Idx = 0; @@ -1292,7 +1368,7 @@ SCEVHandle ScalarEvolution::getSMaxExpr(std::vector Ops) { if (Ops.size() == 1) return Ops[0]; // Sort by complexity, this groups all similar expression types together. - GroupByComplexity(Ops); + GroupByComplexity(Ops, LI); // If there are any constants, fold them together. unsigned Idx = 0; @@ -1372,7 +1448,7 @@ SCEVHandle ScalarEvolution::getUMaxExpr(std::vector Ops) { if (Ops.size() == 1) return Ops[0]; // Sort by complexity, this groups all similar expression types together. - GroupByComplexity(Ops); + GroupByComplexity(Ops, LI); // If there are any constants, fold them together. unsigned Idx = 0;