diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 90c4f5758fe..3ad542cd97c 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -748,13 +748,8 @@ SCEVHandle ScalarEvolution::getTruncateExpr(const SCEVHandle &Op, if (const SCEVAddRecExpr *AddRec = dyn_cast(Op)) { std::vector Operands; for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) - // FIXME: This should allow truncation of other expression types! - if (isa(AddRec->getOperand(i))) - Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty)); - else - break; - if (Operands.size() == AddRec->getNumOperands()) - return getAddRecExpr(Operands, AddRec->getLoop()); + Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty)); + return getAddRecExpr(Operands, AddRec->getLoop()); } SCEVTruncateExpr *&Result = (*SCEVTruncates)[std::make_pair(Op, Ty)]; @@ -966,7 +961,66 @@ SCEVHandle ScalarEvolution::getAddExpr(std::vector &Ops) { return getAddExpr(Ops); } - // Now we know the first non-constant operand. Skip past any cast SCEVs. + // Check for truncates. If all the operands are truncated from the same + // type, see if factoring out the truncate would permit the result to be + // folded. eg., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n) + // if the contents of the resulting outer trunc fold to something simple. + for (; Idx < Ops.size() && isa(Ops[Idx]); ++Idx) { + const SCEVTruncateExpr *Trunc = cast(Ops[Idx]); + const Type *DstType = Trunc->getType(); + const Type *SrcType = Trunc->getOperand()->getType(); + std::vector LargeOps; + bool Ok = true; + // Check all the operands to see if they can be represented in the + // source type of the truncate. + for (unsigned i = 0, e = Ops.size(); i != e; ++i) { + if (const SCEVTruncateExpr *T = dyn_cast(Ops[i])) { + if (T->getOperand()->getType() != SrcType) { + Ok = false; + break; + } + LargeOps.push_back(T->getOperand()); + } else if (const SCEVConstant *C = dyn_cast(Ops[i])) { + // This could be either sign or zero extension, but sign extension + // is much more likely to be foldable here. + LargeOps.push_back(getSignExtendExpr(C, SrcType)); + } else if (const SCEVMulExpr *M = dyn_cast(Ops[i])) { + std::vector LargeMulOps; + for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) { + if (const SCEVTruncateExpr *T = + dyn_cast(M->getOperand(j))) { + if (T->getOperand()->getType() != SrcType) { + Ok = false; + break; + } + LargeMulOps.push_back(T->getOperand()); + } else if (const SCEVConstant *C = + dyn_cast(M->getOperand(j))) { + // This could be either sign or zero extension, but sign extension + // is much more likely to be foldable here. + LargeMulOps.push_back(getSignExtendExpr(C, SrcType)); + } else { + Ok = false; + break; + } + } + if (Ok) + LargeOps.push_back(getMulExpr(LargeMulOps)); + } else { + Ok = false; + break; + } + } + if (Ok) { + // Evaluate the expression in the larger type. + SCEVHandle Fold = getAddExpr(LargeOps); + // If it folds to something simple, use it. Otherwise, don't. + if (isa(Fold) || isa(Fold)) + return getTruncateExpr(Fold, DstType); + } + } + + // Skip past any other cast SCEVs. while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr) ++Idx; diff --git a/test/Analysis/ScalarEvolution/trip-count4.ll b/test/Analysis/ScalarEvolution/trip-count4.ll new file mode 100644 index 00000000000..a61d5da57ea --- /dev/null +++ b/test/Analysis/ScalarEvolution/trip-count4.ll @@ -0,0 +1,24 @@ +; RUN: llvm-as < %s | opt -analyze -scalar-evolution -disable-output \ +; RUN: | grep {sext.*trunc.*Exits: 11} + +; ScalarEvolution should be able to compute a loop exit value for %indvar.i8. + +define void @another_count_down_signed(double* %d, i64 %n) nounwind { +entry: + br label %loop + +loop: ; preds = %loop, %entry + %indvar = phi i64 [ %n, %entry ], [ %indvar.next, %loop ] ; [#uses=4] + %s0 = shl i64 %indvar, 8 ; [#uses=1] + %indvar.i8 = ashr i64 %s0, 8 ; [#uses=1] + %t0 = getelementptr double* %d, i64 %indvar.i8 ; [#uses=2] + %t1 = load double* %t0 ; [#uses=1] + %t2 = mul double %t1, 1.000000e-01 ; [#uses=1] + store double %t2, double* %t0 + %indvar.next = sub i64 %indvar, 1 ; [#uses=2] + %exitcond = icmp eq i64 %indvar.next, 10 ; [#uses=1] + br i1 %exitcond, label %return, label %loop + +return: ; preds = %loop + ret void +}