diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index a06cfc101b3..0a0238ab800 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -675,62 +675,6 @@ static void GroupByComplexity(SmallVectorImpl &Ops, } } -static const APInt srem(const SCEVConstant *C1, const SCEVConstant *C2) { - APInt A = C1->getValue()->getValue(); - APInt B = C2->getValue()->getValue(); - uint32_t ABW = A.getBitWidth(); - uint32_t BBW = B.getBitWidth(); - - if (ABW > BBW) - B = B.sext(ABW); - else if (ABW < BBW) - A = A.sext(BBW); - - return APIntOps::srem(A, B); -} - -static const APInt sdiv(const SCEVConstant *C1, const SCEVConstant *C2) { - APInt A = C1->getValue()->getValue(); - APInt B = C2->getValue()->getValue(); - uint32_t ABW = A.getBitWidth(); - uint32_t BBW = B.getBitWidth(); - - if (ABW > BBW) - B = B.sext(ABW); - else if (ABW < BBW) - A = A.sext(BBW); - - return APIntOps::sdiv(A, B); -} - -static const APInt urem(const SCEVConstant *C1, const SCEVConstant *C2) { - APInt A = C1->getValue()->getValue(); - APInt B = C2->getValue()->getValue(); - uint32_t ABW = A.getBitWidth(); - uint32_t BBW = B.getBitWidth(); - - if (ABW > BBW) - B = B.zext(ABW); - else if (ABW < BBW) - A = A.zext(BBW); - - return APIntOps::urem(A, B); -} - -static const APInt udiv(const SCEVConstant *C1, const SCEVConstant *C2) { - APInt A = C1->getValue()->getValue(); - APInt B = C2->getValue()->getValue(); - uint32_t ABW = A.getBitWidth(); - uint32_t BBW = B.getBitWidth(); - - if (ABW > BBW) - B = B.zext(ABW); - else if (ABW < BBW) - A = A.zext(BBW); - - return APIntOps::udiv(A, B); -} - namespace { struct FindSCEVSize { int Size; @@ -757,8 +701,7 @@ static inline int sizeOfSCEV(const SCEV *S) { namespace { -template -struct SCEVDivision : public SCEVVisitor { +struct SCEVDivision : public SCEVVisitor { public: // Computes the Quotient and Remainder of the division of Numerator by // Denominator. @@ -767,7 +710,7 @@ public: const SCEV **Remainder) { assert(Numerator && Denominator && "Uninitialized SCEV"); - Derived D(SE, Numerator, Denominator); + SCEVDivision D(SE, Numerator, Denominator); // Check for the trivial case here to avoid having to check for it in the // rest of the code. @@ -819,6 +762,27 @@ public: void visitUnknown(const SCEVUnknown *Numerator) {} void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {} + void visitConstant(const SCEVConstant *Numerator) { + if (const SCEVConstant *D = dyn_cast(Denominator)) { + APInt NumeratorVal = Numerator->getValue()->getValue(); + APInt DenominatorVal = D->getValue()->getValue(); + uint32_t NumeratorBW = NumeratorVal.getBitWidth(); + uint32_t DenominatorBW = DenominatorVal.getBitWidth(); + + if (NumeratorBW > DenominatorBW) + DenominatorVal = DenominatorVal.sext(NumeratorBW); + else if (NumeratorBW < DenominatorBW) + NumeratorVal = NumeratorVal.sext(DenominatorBW); + + APInt QuotientVal(NumeratorVal.getBitWidth(), 0); + APInt RemainderVal(NumeratorVal.getBitWidth(), 0); + APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal); + Quotient = SE.getConstant(QuotientVal); + Remainder = SE.getConstant(RemainderVal); + return; + } + } + void visitAddRecExpr(const SCEVAddRecExpr *Numerator) { const SCEV *StartQ, *StartR, *StepQ, *StepR; assert(Numerator->isAffine() && "Numerator should be affine"); @@ -956,37 +920,6 @@ private: ScalarEvolution &SE; const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One; - - friend struct SCEVSDivision; - friend struct SCEVUDivision; -}; - -struct SCEVSDivision : public SCEVDivision { - SCEVSDivision(ScalarEvolution &S, const SCEV *Numerator, - const SCEV *Denominator) - : SCEVDivision(S, Numerator, Denominator) {} - - void visitConstant(const SCEVConstant *Numerator) { - if (const SCEVConstant *D = dyn_cast(Denominator)) { - Quotient = SE.getConstant(sdiv(Numerator, D)); - Remainder = SE.getConstant(srem(Numerator, D)); - return; - } - } -}; - -struct SCEVUDivision : public SCEVDivision { - SCEVUDivision(ScalarEvolution &S, const SCEV *Numerator, - const SCEV *Denominator) - : SCEVDivision(S, Numerator, Denominator) {} - - void visitConstant(const SCEVConstant *Numerator) { - if (const SCEVConstant *D = dyn_cast(Denominator)) { - Quotient = SE.getConstant(udiv(Numerator, D)); - Remainder = SE.getConstant(urem(Numerator, D)); - return; - } - } }; } @@ -7478,7 +7411,7 @@ static bool findArrayDimensionsRec(ScalarEvolution &SE, for (const SCEV *&Term : Terms) { // Normalize the terms before the next call to findArrayDimensionsRec. const SCEV *Q, *R; - SCEVSDivision::divide(SE, Term, Step, &Q, &R); + SCEVDivision::divide(SE, Term, Step, &Q, &R); // Bail out when GCD does not evenly divide one of the terms. if (!R->isZero()) @@ -7615,7 +7548,7 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl &Terms, // Divide all terms by the element size. for (const SCEV *&Term : Terms) { const SCEV *Q, *R; - SCEVSDivision::divide(SE, Term, ElementSize, &Q, &R); + SCEVDivision::divide(SE, Term, ElementSize, &Q, &R); Term = Q; } @@ -7662,7 +7595,7 @@ void SCEVAddRecExpr::computeAccessFunctions( int Last = Sizes.size() - 1; for (int i = Last; i >= 0; i--) { const SCEV *Q, *R; - SCEVSDivision::divide(SE, Res, Sizes[i], &Q, &R); + SCEVDivision::divide(SE, Res, Sizes[i], &Q, &R); DEBUG({ dbgs() << "Res: " << *Res << "\n";