From 48dd644109d97a76288f0b5045f6aa6a3c075732 Mon Sep 17 00:00:00 2001 From: Nick Lewycky Date: Tue, 2 Dec 2008 08:05:48 +0000 Subject: [PATCH] Add a new SCEV representing signed division. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@60407 91177308-0d34-0410-b5e6-96231b3b80d8 --- include/llvm/Analysis/ScalarEvolution.h | 1 + .../llvm/Analysis/ScalarEvolutionExpander.h | 2 + .../Analysis/ScalarEvolutionExpressions.h | 53 ++++++++++++- lib/Analysis/ScalarEvolution.cpp | 77 ++++++++++++++++--- lib/Analysis/ScalarEvolutionExpander.cpp | 9 +++ 5 files changed, 131 insertions(+), 11 deletions(-) diff --git a/include/llvm/Analysis/ScalarEvolution.h b/include/llvm/Analysis/ScalarEvolution.h index e16e990bad2..f524daabf25 100644 --- a/include/llvm/Analysis/ScalarEvolution.h +++ b/include/llvm/Analysis/ScalarEvolution.h @@ -225,6 +225,7 @@ namespace llvm { return getMulExpr(Ops); } SCEVHandle getUDivExpr(const SCEVHandle &LHS, const SCEVHandle &RHS); + SCEVHandle getSDivExpr(const SCEVHandle &LHS, const SCEVHandle &RHS); SCEVHandle getAddRecExpr(const SCEVHandle &Start, const SCEVHandle &Step, const Loop *L); SCEVHandle getAddRecExpr(std::vector &Operands, diff --git a/include/llvm/Analysis/ScalarEvolutionExpander.h b/include/llvm/Analysis/ScalarEvolutionExpander.h index cd075ef643a..7ecf5332edd 100644 --- a/include/llvm/Analysis/ScalarEvolutionExpander.h +++ b/include/llvm/Analysis/ScalarEvolutionExpander.h @@ -104,6 +104,8 @@ namespace llvm { Value *visitUDivExpr(SCEVUDivExpr *S); + Value *visitSDivExpr(SCEVSDivExpr *S); + Value *visitAddRecExpr(SCEVAddRecExpr *S); Value *visitSMaxExpr(SCEVSMaxExpr *S); diff --git a/include/llvm/Analysis/ScalarEvolutionExpressions.h b/include/llvm/Analysis/ScalarEvolutionExpressions.h index 652a99d0fca..bedd075a107 100644 --- a/include/llvm/Analysis/ScalarEvolutionExpressions.h +++ b/include/llvm/Analysis/ScalarEvolutionExpressions.h @@ -25,7 +25,7 @@ namespace llvm { // These should be ordered in terms of increasing complexity to make the // folders simpler. scConstant, scTruncate, scZeroExtend, scSignExtend, scAddExpr, scMulExpr, - scUDivExpr, scAddRecExpr, scUMaxExpr, scSMaxExpr, scUnknown, + scUDivExpr, scSDivExpr, scAddRecExpr, scUMaxExpr, scSMaxExpr, scUnknown, scCouldNotCompute }; @@ -357,6 +357,55 @@ namespace llvm { }; + //===--------------------------------------------------------------------===// + /// SCEVSDivExpr - This class represents a binary signed division operation. + /// + class SCEVSDivExpr : public SCEV { + friend class ScalarEvolution; + + SCEVHandle LHS, RHS; + SCEVSDivExpr(const SCEVHandle &lhs, const SCEVHandle &rhs) + : SCEV(scSDivExpr), LHS(lhs), RHS(rhs) {} + + virtual ~SCEVSDivExpr(); + public: + const SCEVHandle &getLHS() const { return LHS; } + const SCEVHandle &getRHS() const { return RHS; } + + virtual bool isLoopInvariant(const Loop *L) const { + return LHS->isLoopInvariant(L) && RHS->isLoopInvariant(L); + } + + virtual bool hasComputableLoopEvolution(const Loop *L) const { + return LHS->hasComputableLoopEvolution(L) && + RHS->hasComputableLoopEvolution(L); + } + + SCEVHandle replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, + const SCEVHandle &Conc, + ScalarEvolution &SE) const { + SCEVHandle L = LHS->replaceSymbolicValuesWithConcrete(Sym, Conc, SE); + SCEVHandle R = RHS->replaceSymbolicValuesWithConcrete(Sym, Conc, SE); + if (L == LHS && R == RHS) + return this; + else + return SE.getSDivExpr(L, R); + } + + + virtual const Type *getType() const; + + void print(std::ostream &OS) const; + void print(std::ostream *OS) const { if (OS) print(*OS); } + + /// Methods for support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const SCEVSDivExpr *S) { return true; } + static inline bool classof(const SCEV *S) { + return S->getSCEVType() == scSDivExpr; + } + }; + + //===--------------------------------------------------------------------===// /// SCEVAddRecExpr - This node represents a polynomial recurrence on the trip /// count of the specified loop. @@ -550,6 +599,8 @@ namespace llvm { return ((SC*)this)->visitMulExpr((SCEVMulExpr*)S); case scUDivExpr: return ((SC*)this)->visitUDivExpr((SCEVUDivExpr*)S); + case scSDivExpr: + return ((SC*)this)->visitSDivExpr((SCEVSDivExpr*)S); case scAddRecExpr: return ((SC*)this)->visitAddRecExpr((SCEVAddRecExpr*)S); case scSMaxExpr: diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 8fb46dd883b..e82f5a4c503 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -324,6 +324,26 @@ const Type *SCEVUDivExpr::getType() const { return LHS->getType(); } + +// SCEVSDivs - Only allow the creation of one SCEVSDivExpr for any particular +// input. Don't use a SCEVHandle here, or else the object will never be +// deleted! +static ManagedStatic, + SCEVSDivExpr*> > SCEVSDivs; + +SCEVSDivExpr::~SCEVSDivExpr() { + SCEVSDivs->erase(std::make_pair(LHS, RHS)); +} + +void SCEVSDivExpr::print(std::ostream &OS) const { + OS << "(" << *LHS << " /s " << *RHS << ")"; +} + +const Type *SCEVSDivExpr::getType() const { + return LHS->getType(); +} + + // SCEVAddRecExprs - Only allow the creation of one SCEVAddRecExpr for any // particular input. Don't use a SCEVHandle here, or else the object will never // be deleted! @@ -1109,9 +1129,12 @@ SCEVHandle ScalarEvolution::getMulExpr(std::vector &Ops) { } SCEVHandle ScalarEvolution::getUDivExpr(const SCEVHandle &LHS, const SCEVHandle &RHS) { + if (LHS == RHS) + return getIntegerSCEV(1, LHS->getType()); // X udiv X --> 1 + if (SCEVConstant *RHSC = dyn_cast(RHS)) { if (RHSC->getValue()->equalsInt(1)) - return LHS; // X udiv 1 --> x + return LHS; // X udiv 1 --> X if (SCEVConstant *LHSC = dyn_cast(LHS)) { Constant *LHSCV = LHSC->getValue(); @@ -1120,13 +1143,34 @@ SCEVHandle ScalarEvolution::getUDivExpr(const SCEVHandle &LHS, const SCEVHandle } } - // FIXME: implement folding of (X*4)/4 when we know X*4 doesn't overflow. - SCEVUDivExpr *&Result = (*SCEVUDivs)[std::make_pair(LHS, RHS)]; if (Result == 0) Result = new SCEVUDivExpr(LHS, RHS); return Result; } +SCEVHandle ScalarEvolution::getSDivExpr(const SCEVHandle &LHS, const SCEVHandle &RHS) { + if (LHS == RHS) + return getIntegerSCEV(1, LHS->getType()); // X sdiv X --> 1 + + if (SCEVConstant *RHSC = dyn_cast(RHS)) { + if (RHSC->getValue()->equalsInt(1)) + return LHS; // X sdiv 1 --> X + + if (RHSC->getValue()->isAllOnesValue()) + return getNegativeSCEV(LHS); // X sdiv -1 --> -X + + if (SCEVConstant *LHSC = dyn_cast(LHS)) { + Constant *LHSCV = LHSC->getValue(); + Constant *RHSCV = RHSC->getValue(); + return getUnknown(ConstantExpr::getSDiv(LHSCV, RHSCV)); + } + } + + SCEVSDivExpr *&Result = (*SCEVSDivs)[std::make_pair(LHS, RHS)]; + if (Result == 0) Result = new SCEVSDivExpr(LHS, RHS); + return Result; +} + /// SCEVAddRecExpr::get - Get a add recurrence expression for the /// specified loop. Simplify the expression as much as possible. @@ -1732,7 +1776,7 @@ static uint32_t GetMinTrailingZeros(SCEVHandle S) { return MinOpRes; } - // SCEVUDivExpr, SCEVUnknown + // SCEVUDivExpr, SCEVSDivExpr, SCEVUnknown return 0; } @@ -1762,6 +1806,9 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) { case Instruction::UDiv: return SE.getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1))); + case Instruction::SDiv: + return SE.getSDivExpr(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); case Instruction::Sub: return SE.getMinusSCEV(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1))); @@ -1805,7 +1852,7 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) { break; case Instruction::LShr: - // Turn logical shift right of a constant into a unsigned divide. + // Turn logical shift right of a constant into an unsigned divide. if (ConstantInt *SA = dyn_cast(U->getOperand(1))) { uint32_t BitWidth = cast(V->getType())->getBitWidth(); Constant *X = ConstantInt::get( @@ -2505,16 +2552,26 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { return Comm; } - if (SCEVUDivExpr *Div = dyn_cast(V)) { - SCEVHandle LHS = getSCEVAtScope(Div->getLHS(), L); + if (SCEVUDivExpr *UDiv = dyn_cast(V)) { + SCEVHandle LHS = getSCEVAtScope(UDiv->getLHS(), L); if (LHS == UnknownValue) return LHS; - SCEVHandle RHS = getSCEVAtScope(Div->getRHS(), L); + SCEVHandle RHS = getSCEVAtScope(UDiv->getRHS(), L); if (RHS == UnknownValue) return RHS; - if (LHS == Div->getLHS() && RHS == Div->getRHS()) - return Div; // must be loop invariant + if (LHS == UDiv->getLHS() && RHS == UDiv->getRHS()) + return UDiv; // must be loop invariant return SE.getUDivExpr(LHS, RHS); } + if (SCEVSDivExpr *SDiv = dyn_cast(V)) { + SCEVHandle LHS = getSCEVAtScope(SDiv->getLHS(), L); + if (LHS == UnknownValue) return LHS; + SCEVHandle RHS = getSCEVAtScope(SDiv->getRHS(), L); + if (RHS == UnknownValue) return RHS; + if (LHS == SDiv->getLHS() && RHS == SDiv->getRHS()) + return SDiv; // must be loop invariant + return SE.getSDivExpr(LHS, RHS); + } + // If this is a loop recurrence for a loop that does not contain L, then we // are dealing with the final value computed by the loop. if (SCEVAddRecExpr *AddRec = dyn_cast(V)) { diff --git a/lib/Analysis/ScalarEvolutionExpander.cpp b/lib/Analysis/ScalarEvolutionExpander.cpp index 30df087cef3..211f013c25c 100644 --- a/lib/Analysis/ScalarEvolutionExpander.cpp +++ b/lib/Analysis/ScalarEvolutionExpander.cpp @@ -143,6 +143,15 @@ Value *SCEVExpander::visitUDivExpr(SCEVUDivExpr *S) { return InsertBinop(Instruction::UDiv, LHS, RHS, InsertPt); } +Value *SCEVExpander::visitSDivExpr(SCEVSDivExpr *S) { + // Do not fold sdiv into ashr, unless you know that LHS is positive. On + // negative values, it rounds the wrong way. + + Value *LHS = expand(S->getLHS()); + Value *RHS = expand(S->getRHS()); + return InsertBinop(Instruction::SDiv, LHS, RHS, InsertPt); +} + Value *SCEVExpander::visitAddRecExpr(SCEVAddRecExpr *S) { const Type *Ty = S->getType(); const Loop *L = S->getLoop();