From f8c4fd600578266b679856bb530a3b2774b3a2ce Mon Sep 17 00:00:00 2001 From: Johannes Doerfert Date: Mon, 9 Feb 2015 12:34:23 +0000 Subject: [PATCH] Allow ScalarEvolution to catch more min/max cases For the attached test case, different types are used in the ICmpInst and SelectInst that represent the min/max expressions. However, if the ICmpInst type is smaller, a comparison with the sign/zero-extended operands would have yielded the same result. This situation might arise after the instruction combination pass was applied. Differential Revision: http://reviews.llvm.org/D7338 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@228572 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Analysis/ScalarEvolution.cpp | 48 +++++++++-------- .../Analysis/ScalarEvolution/min-max-exprs.ll | 53 +++++++++++++++++++ 2 files changed, 78 insertions(+), 23 deletions(-) create mode 100644 test/Analysis/ScalarEvolution/min-max-exprs.ll diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 8f2e4f2ba28..649c6e0d575 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -4297,9 +4297,10 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { case ICmpInst::ICMP_SGE: // a >s b ? a+x : b+x -> smax(a, b)+x // a >s b ? b+x : a+x -> smin(a, b)+x - if (LHS->getType() == U->getType()) { - const SCEV *LS = getSCEV(LHS); - const SCEV *RS = getSCEV(RHS); + if (getTypeSizeInBits(LHS->getType()) <= + getTypeSizeInBits(U->getType())) { + const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), U->getType()); + const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), U->getType()); const SCEV *LA = getSCEV(U->getOperand(1)); const SCEV *RA = getSCEV(U->getOperand(2)); const SCEV *LDiff = getMinusSCEV(LA, LS); @@ -4320,9 +4321,10 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { case ICmpInst::ICMP_UGE: // a >u b ? a+x : b+x -> umax(a, b)+x // a >u b ? 
b+x : a+x -> umin(a, b)+x - if (LHS->getType() == U->getType()) { - const SCEV *LS = getSCEV(LHS); - const SCEV *RS = getSCEV(RHS); + if (getTypeSizeInBits(LHS->getType()) <= + getTypeSizeInBits(U->getType())) { + const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType()); + const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), U->getType()); const SCEV *LA = getSCEV(U->getOperand(1)); const SCEV *RA = getSCEV(U->getOperand(2)); const SCEV *LDiff = getMinusSCEV(LA, LS); @@ -4337,11 +4339,11 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { break; case ICmpInst::ICMP_NE: // n != 0 ? n+x : 1+x -> umax(n, 1)+x - if (LHS->getType() == U->getType() && - isa(RHS) && - cast(RHS)->isZero()) { - const SCEV *One = getConstant(LHS->getType(), 1); - const SCEV *LS = getSCEV(LHS); + if (getTypeSizeInBits(LHS->getType()) <= + getTypeSizeInBits(U->getType()) && + isa(RHS) && cast(RHS)->isZero()) { + const SCEV *One = getConstant(U->getType(), 1); + const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType()); const SCEV *LA = getSCEV(U->getOperand(1)); const SCEV *RA = getSCEV(U->getOperand(2)); const SCEV *LDiff = getMinusSCEV(LA, LS); @@ -4352,11 +4354,11 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { break; case ICmpInst::ICMP_EQ: // n == 0 ? 
1+x : n+x -> umax(n, 1)+x - if (LHS->getType() == U->getType() && - isa(RHS) && - cast(RHS)->isZero()) { - const SCEV *One = getConstant(LHS->getType(), 1); - const SCEV *LS = getSCEV(LHS); + if (getTypeSizeInBits(LHS->getType()) <= + getTypeSizeInBits(U->getType()) && + isa(RHS) && cast(RHS)->isZero()) { + const SCEV *One = getConstant(U->getType(), 1); + const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType()); const SCEV *LA = getSCEV(U->getOperand(1)); const SCEV *RA = getSCEV(U->getOperand(2)); const SCEV *LDiff = getMinusSCEV(LA, One); @@ -7028,8 +7030,8 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, return false; } -// Verify if an linear IV with positive stride can overflow when in a -// less-than comparison, knowing the invariant term of the comparison, the +// Verify if an linear IV with positive stride can overflow when in a +// less-than comparison, knowing the invariant term of the comparison, the // stride and the knowledge of NSW/NUW flags on the recurrence. bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, bool IsSigned, bool NoWrap) { @@ -7057,7 +7059,7 @@ bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, return (MaxValue - MaxStrideMinusOne).ult(MaxRHS); } -// Verify if an linear IV with negative stride can overflow when in a +// Verify if an linear IV with negative stride can overflow when in a // greater-than comparison, knowing the invariant term of the comparison, // the stride and the knowledge of NSW/NUW flags on the recurrence. bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, @@ -7088,7 +7090,7 @@ bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, // Compute the backedge taken count knowing the interval difference, the // stride and presence of the equality in the comparison. 
-const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, +const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, bool Equality) { const SCEV *One = getConstant(Step->getType(), 1); Delta = Equality ? getAddExpr(Delta, Step) @@ -7128,7 +7130,7 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, // Avoid proven overflow cases: this will ensure that the backedge taken count // will not generate any unsigned overflow. Relaxed no-overflow conditions - // exploit NoWrapFlags, allowing to optimize in presence of undefined + // exploit NoWrapFlags, allowing to optimize in presence of undefined // behaviors like the case of C language. if (!Stride->isOne() && doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap)) return getCouldNotCompute(); @@ -7208,7 +7210,7 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, // Avoid proven overflow cases: this will ensure that the backedge taken count // will not generate any unsigned overflow. Relaxed no-overflow conditions - // exploit NoWrapFlags, allowing to optimize in presence of undefined + // exploit NoWrapFlags, allowing to optimize in presence of undefined // behaviors like the case of C language. 
if (!Stride->isOne() && doesIVOverflowOnGT(RHS, Stride, IsSigned, NoWrap)) return getCouldNotCompute(); @@ -7256,7 +7258,7 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, if (isa(BECount)) MaxBECount = BECount; else - MaxBECount = computeBECount(getConstant(MaxStart - MinEnd), + MaxBECount = computeBECount(getConstant(MaxStart - MinEnd), getConstant(MinStride), false); if (isa(MaxBECount)) diff --git a/test/Analysis/ScalarEvolution/min-max-exprs.ll b/test/Analysis/ScalarEvolution/min-max-exprs.ll new file mode 100644 index 00000000000..3e0a35dd829 --- /dev/null +++ b/test/Analysis/ScalarEvolution/min-max-exprs.ll @@ -0,0 +1,53 @@ +; RUN: opt -scalar-evolution -analyze < %s | FileCheck %s +; +; This checks if the min and max expressions are properly recognized by +; ScalarEvolution even though the ICmpInst and SelectInst have different +; types. +; +; #define max(a, b) (a > b ? a : b) +; #define min(a, b) (a < b ? a : b) +; +; void f(int *A, int N) { +; for (int i = 0; i < N; i++) { +; A[max(0, i - 3)] = A[min(N, i + 3)] * 2; +; } +; } +; +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" + +define void @f(i32* %A, i32 %N) { +bb: + br label %bb1 + +bb1: ; preds = %bb2, %bb + %i.0 = phi i32 [ 0, %bb ], [ %tmp23, %bb2 ] + %i.0.1 = sext i32 %i.0 to i64 + %tmp = icmp slt i32 %i.0, %N + br i1 %tmp, label %bb2, label %bb24 + +bb2: ; preds = %bb1 + %tmp3 = add nuw nsw i32 %i.0, 3 + %tmp4 = icmp slt i32 %tmp3, %N + %tmp5 = sext i32 %tmp3 to i64 + %tmp6 = sext i32 %N to i64 + %tmp9 = select i1 %tmp4, i64 %tmp5, i64 %tmp6 +; min(N, i+3) +; CHECK: select i1 %tmp4, i64 %tmp5, i64 %tmp6 +; CHECK-NEXT: --> (-1 + (-1 * ((-1 + (-1 * (sext i32 {3,+,1}<%bb1> to i64))) smax (-1 + (-1 * (sext i32 %N to i64)))))) + %tmp11 = getelementptr inbounds i32* %A, i64 %tmp9 + %tmp12 = load i32* %tmp11, align 4 + %tmp13 = shl nsw i32 %tmp12, 1 + %tmp14 = icmp sge i32 3, %i.0 + %tmp17 = add nsw i64 %i.0.1, -3 + %tmp19 = select i1 %tmp14, i64 0, i64 %tmp17 
+; max(0, i - 3) +; CHECK: select i1 %tmp14, i64 0, i64 %tmp17 +; CHECK-NEXT: --> (-3 + (3 smax {0,+,1}<%bb1>)) + %tmp21 = getelementptr inbounds i32* %A, i64 %tmp19 + store i32 %tmp13, i32* %tmp21, align 4 + %tmp23 = add nuw nsw i32 %i.0, 1 + br label %bb1 + +bb24: ; preds = %bb1 + ret void +}