diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 6eb07938894..21f36b25673 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -523,24 +523,34 @@ namespace { public: explicit SCEVComplexityCompare(const LoopInfo *li) : LI(li) {} + // Return true or false if LHS is less than, or at least RHS, respectively. bool operator()(const SCEV *LHS, const SCEV *RHS) const { + return compare(LHS, RHS) < 0; + } + + // Return negative, zero, or positive, if LHS is less than, equal to, or + // greater than RHS, respectively. A three-way result allows recursive + // comparisons to be more efficient. + int compare(const SCEV *LHS, const SCEV *RHS) const { // Fast-path: SCEVs are uniqued so we can do a quick equality check. if (LHS == RHS) - return false; + return 0; // Primarily, sort the SCEVs by their getSCEVType(). unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType(); if (LType != RType) - return LType < RType; + return (int)LType - (int)RType; // Aside from the getSCEVType() ordering, the particular ordering // isn't very important except that it's beneficial to be consistent, // so that (a + b) and (b + a) don't end up as different expressions. - - // Sort SCEVUnknown values with some loose heuristics. TODO: This is - // not as complete as it could be. - if (const SCEVUnknown *LU = dyn_cast(LHS)) { + switch (LType) { + case scUnknown: { + const SCEVUnknown *LU = cast(LHS); const SCEVUnknown *RU = cast(RHS); + + // Sort SCEVUnknown values with some loose heuristics. TODO: This is + // not as complete as it could be. const Value *LV = LU->getValue(), *RV = RU->getValue(); // Order pointer values after integer values. This helps SCEVExpander @@ -548,22 +558,23 @@ namespace { bool LIsPointer = LV->getType()->isPointerTy(), RIsPointer = RV->getType()->isPointerTy(); if (LIsPointer != RIsPointer) - return RIsPointer; + return (int)LIsPointer - (int)RIsPointer; // Compare getValueID values. unsigned LID = LV->getValueID(), RID = RV->getValueID(); if (LID != RID) - return LID < RID; + return (int)LID - (int)RID; // Sort arguments by their position. if (const Argument *LA = dyn_cast(LV)) { const Argument *RA = cast(RV); - return LA->getArgNo() < RA->getArgNo(); + unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo(); + return (int)LArgNo - (int)RArgNo; } - // For instructions, compare their loop depth, and their opcode. - // This is pretty loose. + // For instructions, compare their loop depth, and their operand + // count. This is pretty loose. if (const Instruction *LInst = dyn_cast(LV)) { const Instruction *RInst = cast(RV); @@ -574,82 +585,105 @@ namespace { unsigned LDepth = LI->getLoopDepth(LParent), RDepth = LI->getLoopDepth(RParent); if (LDepth != RDepth) - return LDepth < RDepth; + return (int)LDepth - (int)RDepth; } // Compare the number of operands. unsigned LNumOps = LInst->getNumOperands(), RNumOps = RInst->getNumOperands(); - if (LNumOps != RNumOps) - return LNumOps < RNumOps; + return (int)LNumOps - (int)RNumOps; } - return false; + return 0; } - // Compare constant values. - if (const SCEVConstant *LC = dyn_cast(LHS)) { + case scConstant: { + const SCEVConstant *LC = cast(LHS); const SCEVConstant *RC = cast(RHS); + + // Compare constant values. const APInt &LA = LC->getValue()->getValue(); const APInt &RA = RC->getValue()->getValue(); unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth(); if (LBitWidth != RBitWidth) - return LBitWidth < RBitWidth; - return LA.ult(RA); + return (int)LBitWidth - (int)RBitWidth; + return LA.ult(RA) ? -1 : 1; } - // Compare addrec loop depths. - if (const SCEVAddRecExpr *LA = dyn_cast(LHS)) { + case scAddRecExpr: { + const SCEVAddRecExpr *LA = cast(LHS); const SCEVAddRecExpr *RA = cast(RHS); + + // Compare addrec loop depths. const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop(); if (LLoop != RLoop) { unsigned LDepth = LLoop->getLoopDepth(), RDepth = RLoop->getLoopDepth(); if (LDepth != RDepth) - return LDepth < RDepth; + return (int)LDepth - (int)RDepth; } + + // Addrec complexity grows with operand count. + unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands(); + if (LNumOps != RNumOps) + return (int)LNumOps - (int)RNumOps; + + // Lexicographically compare. + for (unsigned i = 0; i != LNumOps; ++i) { + long X = compare(LA->getOperand(i), RA->getOperand(i)); + if (X != 0) + return X; + } + + return 0; } - // Lexicographically compare n-ary expressions. - if (const SCEVNAryExpr *LC = dyn_cast(LHS)) { + case scAddExpr: + case scMulExpr: + case scSMaxExpr: + case scUMaxExpr: { + const SCEVNAryExpr *LC = cast(LHS); const SCEVNAryExpr *RC = cast(RHS); + + // Lexicographically compare n-ary expressions. unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands(); for (unsigned i = 0; i != LNumOps; ++i) { if (i >= RNumOps) - return false; - const SCEV *LOp = LC->getOperand(i), *ROp = RC->getOperand(i); - if (operator()(LOp, ROp)) - return true; - if (operator()(ROp, LOp)) - return false; + return 1; + long X = compare(LC->getOperand(i), RC->getOperand(i)); + if (X != 0) + return X; } - return LNumOps < RNumOps; + return (int)LNumOps - (int)RNumOps; } - // Lexicographically compare udiv expressions. - if (const SCEVUDivExpr *LC = dyn_cast(LHS)) { + case scUDivExpr: { + const SCEVUDivExpr *LC = cast(LHS); const SCEVUDivExpr *RC = cast(RHS); - const SCEV *LL = LC->getLHS(), *LR = LC->getRHS(), - *RL = RC->getLHS(), *RR = RC->getRHS(); - if (operator()(LL, RL)) - return true; - if (operator()(RL, LL)) - return false; - if (operator()(LR, RR)) - return true; - if (operator()(RR, LR)) - return false; - return false; + + // Lexicographically compare udiv expressions. + long X = compare(LC->getLHS(), RC->getLHS()); + if (X != 0) + return X; + return compare(LC->getRHS(), RC->getRHS()); } - // Compare cast expressions by operand. - if (const SCEVCastExpr *LC = dyn_cast(LHS)) { + case scTruncate: + case scZeroExtend: + case scSignExtend: { + const SCEVCastExpr *LC = cast(LHS); const SCEVCastExpr *RC = cast(RHS); - return operator()(LC->getOperand(), RC->getOperand()); + + // Compare cast expressions by operand. + return compare(LC->getOperand(), RC->getOperand()); + } + + default: + break; } llvm_unreachable("Unknown SCEV kind!"); - return false; + return 0; } }; }