diff --git a/include/llvm/Analysis/ScalarEvolutionExpander.h b/include/llvm/Analysis/ScalarEvolutionExpander.h index 44e8fb0a9a5..a5cc7138cac 100644 --- a/include/llvm/Analysis/ScalarEvolutionExpander.h +++ b/include/llvm/Analysis/ScalarEvolutionExpander.h @@ -78,13 +78,10 @@ namespace llvm { /// expandCodeFor - Insert code to directly compute the specified SCEV /// expression into the program. The inserted code is inserted into the /// specified block. - /// - /// If a particular value sign is required, a type may be specified for the - /// result. - Value *expandCodeFor(SCEVHandle SH, Instruction *IP, const Type *Ty = 0) { + Value *expandCodeFor(SCEVHandle SH, Instruction *IP) { // Expand the code for this SCEV. this->InsertPt = IP; - return expandInTy(SH, Ty); + return expand(SH); } /// InsertCastOfTo - Insert a cast of V to the specified type, doing what @@ -107,25 +104,6 @@ namespace llvm { return V; } - Value *expandInTy(SCEV *S, const Type *Ty) { - Value *V = expand(S); - if (Ty && V->getType() != Ty) { - if (isa(Ty) && V->getType()->isInteger()) - return InsertCastOfTo(Instruction::IntToPtr, V, Ty); - else if (Ty->isInteger() && isa(V->getType())) - return InsertCastOfTo(Instruction::PtrToInt, V, Ty); - else if (Ty->getPrimitiveSizeInBits() == - V->getType()->getPrimitiveSizeInBits()) - return InsertCastOfTo(Instruction::BitCast, V, Ty); - else if (Ty->getPrimitiveSizeInBits() > - V->getType()->getPrimitiveSizeInBits()) - return InsertCastOfTo(Instruction::ZExt, V, Ty); - else - return InsertCastOfTo(Instruction::Trunc, V, Ty); - } - return V; - } - Value *visitConstant(SCEVConstant *S) { return S->getValue(); } @@ -136,17 +114,21 @@ namespace llvm { } Value *visitZeroExtendExpr(SCEVZeroExtendExpr *S) { - Value *V = expandInTy(S->getOperand(), S->getType()); + Value *V = expand(S->getOperand()); return CastInst::createZExtOrBitCast(V, S->getType(), "tmp.", InsertPt); } + Value *visitSignExtendExpr(SCEVSignExtendExpr *S) { + Value *V = expand(S->getOperand()); + return CastInst::createSExtOrBitCast(V, S->getType(), "tmp.", InsertPt); + } + Value *visitAddExpr(SCEVAddExpr *S) { - const Type *Ty = S->getType(); - Value *V = expandInTy(S->getOperand(S->getNumOperands()-1), Ty); + Value *V = expand(S->getOperand(S->getNumOperands()-1)); // Emit a bunch of add instructions for (int i = S->getNumOperands()-2; i >= 0; --i) - V = InsertBinop(Instruction::Add, V, expandInTy(S->getOperand(i), Ty), + V = InsertBinop(Instruction::Add, V, expand(S->getOperand(i)), InsertPt); return V; } @@ -154,9 +136,8 @@ namespace llvm { Value *visitMulExpr(SCEVMulExpr *S); Value *visitSDivExpr(SCEVSDivExpr *S) { - const Type *Ty = S->getType(); - Value *LHS = expandInTy(S->getLHS(), Ty); - Value *RHS = expandInTy(S->getRHS(), Ty); + Value *LHS = expand(S->getLHS()); + Value *RHS = expand(S->getRHS()); return InsertBinop(Instruction::SDiv, LHS, RHS, InsertPt); } diff --git a/include/llvm/Analysis/ScalarEvolutionExpressions.h b/include/llvm/Analysis/ScalarEvolutionExpressions.h index af795377c2b..dd6871fdd18 100644 --- a/include/llvm/Analysis/ScalarEvolutionExpressions.h +++ b/include/llvm/Analysis/ScalarEvolutionExpressions.h @@ -24,8 +24,8 @@ namespace llvm { enum SCEVTypes { // These should be ordered in terms of increasing complexity to make the // folders simpler. - scConstant, scTruncate, scZeroExtend, scAddExpr, scMulExpr, scSDivExpr, - scAddRecExpr, scUnknown, scCouldNotCompute + scConstant, scTruncate, scZeroExtend, scSignExtend, scAddExpr, scMulExpr, + scSDivExpr, scAddRecExpr, scUnknown, scCouldNotCompute }; //===--------------------------------------------------------------------===// @@ -166,6 +166,53 @@ namespace llvm { } }; + //===--------------------------------------------------------------------===// + /// SCEVSignExtendExpr - This class represents a sign extension of a small + /// integer value to a larger integer value. + /// + class SCEVSignExtendExpr : public SCEV { + SCEVHandle Op; + const Type *Ty; + SCEVSignExtendExpr(const SCEVHandle &op, const Type *ty); + virtual ~SCEVSignExtendExpr(); + public: + /// get method - This just gets and returns a new SCEVSignExtend object + /// + static SCEVHandle get(const SCEVHandle &Op, const Type *Ty); + + const SCEVHandle &getOperand() const { return Op; } + virtual const Type *getType() const { return Ty; } + + virtual bool isLoopInvariant(const Loop *L) const { + return Op->isLoopInvariant(L); + } + + virtual bool hasComputableLoopEvolution(const Loop *L) const { + return Op->hasComputableLoopEvolution(L); + } + + /// getValueRange - Return the tightest constant bounds that this value is + /// known to have. This method is only valid on integer SCEV objects. + virtual ConstantRange getValueRange() const; + + SCEVHandle replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, + const SCEVHandle &Conc) const { + SCEVHandle H = Op->replaceSymbolicValuesWithConcrete(Sym, Conc); + if (H == Op) + return this; + return get(H, Ty); + } + + virtual void print(std::ostream &OS) const; + void print(std::ostream *OS) const { if (OS) print(*OS); } + + /// Methods for support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const SCEVSignExtendExpr *S) { return true; } + static inline bool classof(const SCEV *S) { + return S->getSCEVType() == scSignExtend; + } + }; + //===--------------------------------------------------------------------===// /// SCEVCommutativeExpr - This node is the base class for n'ary commutative @@ -503,6 +550,8 @@ namespace llvm { return ((SC*)this)->visitTruncateExpr((SCEVTruncateExpr*)S); case scZeroExtend: return ((SC*)this)->visitZeroExtendExpr((SCEVZeroExtendExpr*)S); + case scSignExtend: + return ((SC*)this)->visitSignExtendExpr((SCEVSignExtendExpr*)S); case scAddExpr: return ((SC*)this)->visitAddExpr((SCEVAddExpr*)S); case scMulExpr: diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index bf67fd3fffc..3ae65286fa7 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -245,6 +245,32 @@ void SCEVZeroExtendExpr::print(std::ostream &OS) const { OS << "(zeroextend " << *Op << " to " << *Ty << ")"; } +// SCEVSignExtends - Only allow the creation of one SCEVSignExtendExpr for any +// particular input. Don't use a SCEVHandle here, or else the object will never +// be deleted! +static ManagedStatic, + SCEVSignExtendExpr*> > SCEVSignExtends; + +SCEVSignExtendExpr::SCEVSignExtendExpr(const SCEVHandle &op, const Type *ty) + : SCEV(scSignExtend), Op(op), Ty(ty) { + assert(Op->getType()->isInteger() && Ty->isInteger() && + "Cannot sign extend non-integer value!"); + assert(Op->getType()->getPrimitiveSizeInBits() < Ty->getPrimitiveSizeInBits() + && "This is not an extending conversion!"); +} + +SCEVSignExtendExpr::~SCEVSignExtendExpr() { + SCEVSignExtends->erase(std::make_pair(Op, Ty)); +} + +ConstantRange SCEVSignExtendExpr::getValueRange() const { + return getOperand()->getValueRange().signExtend(getBitWidth()); +} + +void SCEVSignExtendExpr::print(std::ostream &OS) const { + OS << "(signextend " << *Op << " to " << *Ty << ")"; +} + // SCEVCommExprs - Only allow the creation of one SCEVCommutativeExpr for any // particular input. Don't use a SCEVHandle here, or else the object will never // be deleted! @@ -588,6 +614,21 @@ SCEVHandle SCEVZeroExtendExpr::get(const SCEVHandle &Op, const Type *Ty) { return Result; } +SCEVHandle SCEVSignExtendExpr::get(const SCEVHandle &Op, const Type *Ty) { + if (SCEVConstant *SC = dyn_cast(Op)) + return SCEVUnknown::get( + ConstantExpr::getSExt(SC->getValue(), Ty)); + + // FIXME: If the input value is a chrec scev, and we can prove that the value + // did not overflow the old, smaller, value, we can sign extend all of the + // operands (often constants). This would allow analysis of something like + // this: for (signed char X = 0; X < 100; ++X) { int Y = X; } + + SCEVSignExtendExpr *&Result = (*SCEVSignExtends)[std::make_pair(Op, Ty)]; + if (Result == 0) Result = new SCEVSignExtendExpr(Op, Ty); + return Result; +} + // get - Get a canonical add expression, or something simpler if possible. SCEVHandle SCEVAddExpr::get(std::vector &Ops) { assert(!Ops.empty() && "Cannot get empty add!"); @@ -1370,6 +1411,9 @@ static APInt GetConstantFactor(SCEVHandle S) { if (SCEVZeroExtendExpr *E = dyn_cast(S)) return GetConstantFactor(E->getOperand()).zext( cast(E->getType())->getBitWidth()); + if (SCEVSignExtendExpr *E = dyn_cast(S)) + return GetConstantFactor(E->getOperand()).sext( + cast(E->getType())->getBitWidth()); if (SCEVAddExpr *A = dyn_cast(S)) { // The result is the min of all operands. @@ -1470,6 +1514,9 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) { case Instruction::ZExt: return SCEVZeroExtendExpr::get(getSCEV(I->getOperand(0)), I->getType()); + case Instruction::SExt: + return SCEVSignExtendExpr::get(getSCEV(I->getOperand(0)), I->getType()); + case Instruction::BitCast: // BitCasts are no-op casts so we just eliminate the cast. if (I->getType()->isInteger() && diff --git a/lib/Analysis/ScalarEvolutionExpander.cpp b/lib/Analysis/ScalarEvolutionExpander.cpp index c88c7811954..c8c683cb3f6 100644 --- a/lib/Analysis/ScalarEvolutionExpander.cpp +++ b/lib/Analysis/ScalarEvolutionExpander.cpp @@ -93,18 +93,17 @@ Value *SCEVExpander::InsertBinop(Instruction::BinaryOps Opcode, Value *LHS, } Value *SCEVExpander::visitMulExpr(SCEVMulExpr *S) { - const Type *Ty = S->getType(); int FirstOp = 0; // Set if we should emit a subtract. if (SCEVConstant *SC = dyn_cast(S->getOperand(0))) if (SC->getValue()->isAllOnesValue()) FirstOp = 1; int i = S->getNumOperands()-2; - Value *V = expandInTy(S->getOperand(i+1), Ty); + Value *V = expand(S->getOperand(i+1)); // Emit a bunch of multiply instructions for (; i >= FirstOp; --i) - V = InsertBinop(Instruction::Mul, V, expandInTy(S->getOperand(i), Ty), + V = InsertBinop(Instruction::Mul, V, expand(S->getOperand(i)), InsertPt); // -1 * ... ---> 0 - ... if (FirstOp == 1) @@ -122,10 +121,10 @@ Value *SCEVExpander::visitAddRecExpr(SCEVAddRecExpr *S) { // {X,+,F} --> X + {0,+,F} if (!isa(S->getStart()) || !cast(S->getStart())->getValue()->isZero()) { - Value *Start = expandInTy(S->getStart(), Ty); + Value *Start = expand(S->getStart()); std::vector NewOps(S->op_begin(), S->op_end()); NewOps[0] = SCEVUnknown::getIntegerSCEV(0, Ty); - Value *Rest = expandInTy(SCEVAddRecExpr::get(NewOps, L), Ty); + Value *Rest = expand(SCEVAddRecExpr::get(NewOps, L)); // FIXME: look for an existing add to use. return InsertBinop(Instruction::Add, Rest, Start, InsertPt); @@ -164,7 +163,7 @@ Value *SCEVExpander::visitAddRecExpr(SCEVAddRecExpr *S) { // If this is a simple linear addrec, emit it now as a special case. if (S->getNumOperands() == 2) { // {0,+,F} --> i*F - Value *F = expandInTy(S->getOperand(1), Ty); + Value *F = expand(S->getOperand(1)); // IF the step is by one, just return the inserted IV. if (ConstantInt *CI = dyn_cast(F)) @@ -201,5 +200,5 @@ Value *SCEVExpander::visitAddRecExpr(SCEVAddRecExpr *S) { SCEVHandle V = S->evaluateAtIteration(IH); //cerr << "Evaluated: " << *this << "\n to: " << *V << "\n"; - return expandInTy(V, Ty); + return expand(V); } diff --git a/lib/Transforms/Scalar/IndVarSimplify.cpp b/lib/Transforms/Scalar/IndVarSimplify.cpp index 8042d62581b..5965d1a8855 100644 --- a/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -277,8 +277,7 @@ Instruction *IndVarSimplify::LinearFunctionTestReplace(Loop *L, // Expand the code for the iteration count into the preheader of the loop. BasicBlock *Preheader = L->getLoopPreheader(); - Value *ExitCnt = RW.expandCodeFor(TripCount, Preheader->getTerminator(), - IndVar->getType()); + Value *ExitCnt = RW.expandCodeFor(TripCount, Preheader->getTerminator()); // Insert a new icmp_ne or icmp_eq instruction before the branch. ICmpInst::Predicate Opcode; @@ -383,7 +382,7 @@ void IndVarSimplify::RewriteLoopExitValues(Loop *L) { // just reuse it. Value *&ExitVal = ExitValues[Inst]; if (!ExitVal) - ExitVal = Rewriter.expandCodeFor(ExitValue, InsertPt,Inst->getType()); + ExitVal = Rewriter.expandCodeFor(ExitValue, InsertPt); DOUT << "INDVARS: RLEV: AfterLoopVal = " << *ExitVal << " LoopVal = " << *Inst << "\n"; @@ -519,9 +518,12 @@ bool IndVarSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { Changed = true; DOUT << "INDVARS: New CanIV: " << *IndVar; - if (!isa(IterationCount)) + if (!isa(IterationCount)) { + if (IterationCount->getType() != LargestType) + IterationCount = SCEVZeroExtendExpr::get(IterationCount, LargestType); if (Instruction *DI = LinearFunctionTestReplace(L, IterationCount,Rewriter)) DeadInsts.insert(DI); + } // Now that we have a canonical induction variable, we can rewrite any // recurrences in terms of the induction variable. Start with the auxillary @@ -555,8 +557,7 @@ bool IndVarSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { std::map InsertedSizes; while (!IndVars.empty()) { PHINode *PN = IndVars.back().first; - Value *NewVal = Rewriter.expandCodeFor(IndVars.back().second, InsertPt, - PN->getType()); + Value *NewVal = Rewriter.expandCodeFor(IndVars.back().second, InsertPt); DOUT << "INDVARS: Rewrote IV '" << *IndVars.back().second << "' " << *PN << " into = " << *NewVal << "\n"; NewVal->takeName(PN); diff --git a/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/lib/Transforms/Scalar/LoopStrengthReduce.cpp index bd7d1d92493..0c4807d31aa 100644 --- a/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -555,8 +555,7 @@ Value *BasedUser::InsertCodeForBaseAtPosition(const SCEVHandle &NewBase, // If there is no immediate value, skip the next part. if (SCEVConstant *SC = dyn_cast(Imm)) if (SC->getValue()->isZero()) - return Rewriter.expandCodeFor(NewBase, BaseInsertPt, - OperandValToReplace->getType()); + return Rewriter.expandCodeFor(NewBase, BaseInsertPt); Value *Base = Rewriter.expandCodeFor(NewBase, BaseInsertPt); @@ -567,8 +566,7 @@ Value *BasedUser::InsertCodeForBaseAtPosition(const SCEVHandle &NewBase, // Always emit the immediate (if non-zero) into the same block as the user. SCEVHandle NewValSCEV = SCEVAddExpr::get(SCEVUnknown::get(Base), Imm); - return Rewriter.expandCodeFor(NewValSCEV, IP, - OperandValToReplace->getType()); + return Rewriter.expandCodeFor(NewValSCEV, IP); } @@ -598,6 +596,11 @@ void BasedUser::RewriteInstructionToUseNewBase(const SCEVHandle &NewBase, } } Value *NewVal = InsertCodeForBaseAtPosition(NewBase, Rewriter, InsertPt, L); + // Adjust the type back to match the Inst. + if (isa(OperandValToReplace->getType())) { + NewVal = new IntToPtrInst(NewVal, OperandValToReplace->getType(), "cast", + InsertPt); + } // Replace the use of the operand Value with the new Phi we just created. Inst->replaceUsesOfWith(OperandValToReplace, NewVal); DOUT << " CHANGED: IMM =" << *Imm; @@ -644,6 +647,11 @@ void BasedUser::RewriteInstructionToUseNewBase(const SCEVHandle &NewBase, // Insert the code into the end of the predecessor block. Instruction *InsertPt = PN->getIncomingBlock(i)->getTerminator(); Code = InsertCodeForBaseAtPosition(NewBase, Rewriter, InsertPt, L); + + // Adjust the type back to match the PHI. + if (isa(PN->getType())) { + Code = new IntToPtrInst(Code, PN->getType(), "cast", InsertPt); + } } // Replace the use of the operand Value with the new Phi we just created. @@ -1112,8 +1120,7 @@ void LoopStrengthReduce::StrengthReduceStridedIVUsers(const SCEVHandle &Stride, // Emit the initial base value into the loop preheader. Value *CommonBaseV - = PreheaderRewriter.expandCodeFor(CommonExprs, PreInsertPt, - ReplacedTy); + = PreheaderRewriter.expandCodeFor(CommonExprs, PreInsertPt); if (RewriteFactor == 0) { // Create a new Phi for this base, and stick it in the loop header. @@ -1131,8 +1138,7 @@ void LoopStrengthReduce::StrengthReduceStridedIVUsers(const SCEVHandle &Stride, IncAmount = SCEV::getNegativeSCEV(Stride); // Insert the stride into the preheader. - Value *StrideV = PreheaderRewriter.expandCodeFor(IncAmount, PreInsertPt, - ReplacedTy); + Value *StrideV = PreheaderRewriter.expandCodeFor(IncAmount, PreInsertPt); if (!isa(StrideV)) ++NumVariable; // Emit the increment of the base value before the terminator of the loop @@ -1142,8 +1148,7 @@ void LoopStrengthReduce::StrengthReduceStridedIVUsers(const SCEVHandle &Stride, IncExp = SCEV::getNegativeSCEV(IncExp); IncExp = SCEVAddExpr::get(SCEVUnknown::get(NewPHI), IncExp); - IncV = Rewriter.expandCodeFor(IncExp, LatchBlock->getTerminator(), - ReplacedTy); + IncV = Rewriter.expandCodeFor(IncExp, LatchBlock->getTerminator()); IncV->setName(NewPHI->getName()+".inc"); NewPHI->addIncoming(IncV, LatchBlock); @@ -1199,8 +1204,7 @@ void LoopStrengthReduce::StrengthReduceStridedIVUsers(const SCEVHandle &Stride, SCEVHandle Base = UsersToProcess.back().Base; // Emit the code for Base into the preheader. - Value *BaseV = PreheaderRewriter.expandCodeFor(Base, PreInsertPt, - ReplacedTy); + Value *BaseV = PreheaderRewriter.expandCodeFor(Base, PreInsertPt); DOUT << " INSERTING code for BASE = " << *Base << ":"; if (BaseV->hasName())