Teach ScalarEvolution to exploit min and max expressions when proving

isKnownPredicate.

The motivation for this change is to optimize away checks in loops
like this:

    limit = min(t, len)
    for (i = 0 to limit)
      if (i >= len || i < 0) throw_array_of_of_bounds();
      a[i] = ...

Differential Revision: http://reviews.llvm.org/D6635



git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@224285 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Sanjoy Das
2014-12-15 22:50:15 +00:00
parent 2f7e202f27
commit 574e01c32e
2 changed files with 546 additions and 8 deletions

View File

@ -6886,6 +6886,85 @@ bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
getNotSCEV(FoundLHS));
}
/// If Expr computes ~A, return A else return nullptr
static const SCEV *MatchNotExpr(const SCEV *Expr) {
const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
if (!Add || Add->getNumOperands() != 2) return nullptr;
const SCEVConstant *AddLHS = dyn_cast<SCEVConstant>(Add->getOperand(0));
if (!(AddLHS && AddLHS->getValue()->getValue().isAllOnesValue()))
return nullptr;
const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
if (!AddRHS || AddRHS->getNumOperands() != 2) return nullptr;
const SCEVConstant *MulLHS = dyn_cast<SCEVConstant>(AddRHS->getOperand(0));
if (!(MulLHS && MulLHS->getValue()->getValue().isAllOnesValue()))
return nullptr;
return AddRHS->getOperand(1);
}
/// Is MaybeMaxExpr an SMax or UMax of Candidate and some other values?
template<typename MaxExprType>
static bool IsMaxConsistingOf(const SCEV *MaybeMaxExpr,
const SCEV *Candidate) {
const MaxExprType *MaxExpr = dyn_cast<MaxExprType>(MaybeMaxExpr);
if (!MaxExpr) return false;
auto It = std::find(MaxExpr->op_begin(), MaxExpr->op_end(), Candidate);
return It != MaxExpr->op_end();
}
/// Is MaybeMinExpr an SMin or UMin of Candidate and some other values?
template<typename MaxExprType>
static bool IsMinConsistingOf(ScalarEvolution &SE,
const SCEV *MaybeMinExpr,
const SCEV *Candidate) {
const SCEV *MaybeMaxExpr = MatchNotExpr(MaybeMinExpr);
if (!MaybeMaxExpr)
return false;
return IsMaxConsistingOf<MaxExprType>(MaybeMaxExpr, SE.getNotSCEV(Candidate));
}
/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
/// expression?
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS) {
switch (Pred) {
default:
return false;
case ICmpInst::ICMP_SGE:
std::swap(LHS, RHS);
// fall through
case ICmpInst::ICMP_SLE:
return
// min(A, ...) <= A
IsMinConsistingOf<SCEVSMaxExpr>(SE, LHS, RHS) ||
// A <= max(A, ...)
IsMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
case ICmpInst::ICMP_UGE:
std::swap(LHS, RHS);
// fall through
case ICmpInst::ICMP_ULE:
return
// min(A, ...) <= A
IsMinConsistingOf<SCEVUMaxExpr>(SE, LHS, RHS) ||
// A <= max(A, ...)
IsMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
}
llvm_unreachable("covered switch fell through?!");
}
/// isImpliedCondOperandsHelper - Test whether the condition described by
/// Pred, LHS, and RHS is true whenever the condition described by Pred,
/// FoundLHS, and FoundRHS is true.
@ -6894,6 +6973,12 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS,
const SCEV *FoundLHS,
const SCEV *FoundRHS) {
auto IsKnownPredicateFull =
[this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
return isKnownPredicateWithRanges(Pred, LHS, RHS) ||
IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS);
};
switch (Pred) {
default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
case ICmpInst::ICMP_EQ:
@ -6903,26 +6988,26 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
break;
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLE:
if (isKnownPredicateWithRanges(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
isKnownPredicateWithRanges(ICmpInst::ICMP_SGE, RHS, FoundRHS))
if (IsKnownPredicateFull(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
IsKnownPredicateFull(ICmpInst::ICMP_SGE, RHS, FoundRHS))
return true;
break;
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_SGE:
if (isKnownPredicateWithRanges(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
isKnownPredicateWithRanges(ICmpInst::ICMP_SLE, RHS, FoundRHS))
if (IsKnownPredicateFull(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
IsKnownPredicateFull(ICmpInst::ICMP_SLE, RHS, FoundRHS))
return true;
break;
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_ULE:
if (isKnownPredicateWithRanges(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
isKnownPredicateWithRanges(ICmpInst::ICMP_UGE, RHS, FoundRHS))
if (IsKnownPredicateFull(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
IsKnownPredicateFull(ICmpInst::ICMP_UGE, RHS, FoundRHS))
return true;
break;
case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_UGE:
if (isKnownPredicateWithRanges(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
isKnownPredicateWithRanges(ICmpInst::ICMP_ULE, RHS, FoundRHS))
if (IsKnownPredicateFull(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
IsKnownPredicateFull(ICmpInst::ICMP_ULE, RHS, FoundRHS))
return true;
break;
}