diff --git a/include/llvm/Analysis/ScalarEvolution.h b/include/llvm/Analysis/ScalarEvolution.h
index e53a8bffd4e..80809da81ae 100644
--- a/include/llvm/Analysis/ScalarEvolution.h
+++ b/include/llvm/Analysis/ScalarEvolution.h
@@ -624,6 +624,7 @@ namespace llvm {
       return getMulExpr(Ops, Flags);
     }
     const SCEV *getUDivExpr(const SCEV *LHS, const SCEV *RHS);
+    const SCEV *getUDivExactExpr(const SCEV *LHS, const SCEV *RHS);
     const SCEV *getAddRecExpr(const SCEV *Start, const SCEV *Step,
                               const Loop *L, SCEV::NoWrapFlags Flags);
     const SCEV *getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp
index b65d99e4d67..e8ee46dff8c 100644
--- a/lib/Analysis/ScalarEvolution.cpp
+++ b/lib/Analysis/ScalarEvolution.cpp
@@ -321,7 +321,7 @@ const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
   return S;
 }
 
-const SCEV *ScalarEvolution::getConstant(const APInt& Val) {
+const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
   return getConstant(ConstantInt::get(getContext(), Val));
 }
 
@@ -2239,6 +2239,75 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
   return S;
 }
 
+static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
+  APInt A = C1->getValue()->getValue().abs();
+  APInt B = C2->getValue()->getValue().abs();
+  uint32_t ABW = A.getBitWidth();
+  uint32_t BBW = B.getBitWidth();
+
+  if (ABW > BBW)
+    B = B.zext(ABW);
+  else if (ABW < BBW)
+    A = A.zext(BBW);
+
+  return APIntOps::GreatestCommonDivisor(A, B);
+}
+
+/// getUDivExactExpr - Get a canonical unsigned division expression, or
+/// something simpler if possible. There is no representation for an exact udiv
+/// in SCEV IR, but we can attempt to remove factors from the LHS and RHS.
+/// We can't do this when it's not exact because the udiv may be clearing bits.
+const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
+                                              const SCEV *RHS) {
+  // TODO: we could try to find factors in all sorts of things, but for now we
+  // just deal with u/exact (multiply, constant). See SCEVDivision towards the
+  // end of this file for inspiration.
+
+  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
+  if (!Mul)
+    return getUDivExpr(LHS, RHS);
+
+  if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
+    // If the mulexpr multiplies by a constant, then that constant must be the
+    // first element of the mulexpr.
+    if (const SCEVConstant *LHSCst =
+            dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
+      if (LHSCst == RHSCst) {
+        SmallVector<const SCEV *, 2> Operands;
+        Operands.append(Mul->op_begin() + 1, Mul->op_end());
+        return getMulExpr(Operands);
+      }
+
+      // We can't just assume that LHSCst divides RHSCst cleanly, it could be
+      // that there's a factor provided by one of the other terms. We need to
+      // check.
+      APInt Factor = gcd(LHSCst, RHSCst);
+      if (!Factor.isIntN(1)) {
+        LHSCst = cast<SCEVConstant>(
+            getConstant(LHSCst->getValue()->getValue().udiv(Factor)));
+        RHSCst = cast<SCEVConstant>(
+            getConstant(RHSCst->getValue()->getValue().udiv(Factor)));
+        SmallVector<const SCEV *, 2> Operands;
+        Operands.push_back(LHSCst);
+        Operands.append(Mul->op_begin() + 1, Mul->op_end());
+        LHS = getMulExpr(Operands);
+        RHS = RHSCst;
+        Mul = cast<SCEVMulExpr>(LHS);
+      }
+    }
+  }
+
+  for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
+    if (Mul->getOperand(i) == RHS) {
+      SmallVector<const SCEV *, 2> Operands;
+      Operands.append(Mul->op_begin(), Mul->op_begin() + i);
+      Operands.append(Mul->op_begin() + i + 1, Mul->op_end());
+      return getMulExpr(Operands);
+    }
+  }
+
+  return getUDivExpr(LHS, RHS);
+}
 
 /// getAddRecExpr - Get an add recurrence expression for the specified loop.
 /// Simplify the expression as much as possible.
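
[Reviewer note, not part of the patch] As a sanity check on the cancellation
getUDivExactExpr performs, here is a standalone C++17 sketch of the same
arithmetic on plain 64-bit integers. The helper name is illustrative, and it
assumes the products involved do not wrap:

#include <cassert>
#include <cstdint>
#include <numeric> // std::gcd

// Simplify (C * X) /u D assuming the division is exact (D divides C * X).
// Dividing C and D by G = gcd(C, D) cannot change the quotient; this is the
// constant-reduction step of getUDivExactExpr modeled without APInt.
uint64_t udivExactReduced(uint64_t C, uint64_t X, uint64_t D) {
  uint64_t G = std::gcd(C, D);
  return (C / G * X) / (D / G);
}

int main() {
  // (6 * 10) /u 4 reduces by gcd(6, 4) == 2 to (3 * 10) /u 2 == 15.
  assert(udivExactReduced(6, 10, 4) == (6 * 10) / 4);
  // When the divisor matches a multiplicand outright it cancels entirely,
  // as in the matching-operand loop above: (5 * 7) /u 7 == 5.
  assert(udivExactReduced(5, 7, 7) == 5);
  // Why exactness matters: a non-exact udiv clears low bits, so no factor
  // can be cancelled against it; 15 /u 8 == 1, and 1 * 8 != 15.
  assert((15u / 8) * 8 != 15u);
  return 0;
}
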
@@ -3689,17 +3758,24 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
       // Use ComputeMaskedBits to compute what ShrinkDemandedConstant
       // knew about to reconstruct a low-bits mask value.
       unsigned LZ = A.countLeadingZeros();
+      unsigned TZ = A.countTrailingZeros();
       unsigned BitWidth = A.getBitWidth();
       APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
       ComputeMaskedBits(U->getOperand(0), KnownZero, KnownOne, TD);
 
-      APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ);
-
-      if (LZ != 0 && !((~A & ~KnownZero) & EffectiveMask))
-        return
-          getZeroExtendExpr(getTruncateExpr(getSCEV(U->getOperand(0)),
-                              IntegerType::get(getContext(), BitWidth - LZ)),
-                            U->getType());
+      APInt EffectiveMask =
+          APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
+      if ((LZ != 0 || TZ != 0) && !((~A & ~KnownZero) & EffectiveMask)) {
+        const SCEV *MulCount = getConstant(
+            ConstantInt::get(getContext(), APInt::getOneBitSet(BitWidth, TZ)));
+        return getMulExpr(
+            getZeroExtendExpr(
+                getTruncateExpr(
+                    getUDivExactExpr(getSCEV(U->getOperand(0)), MulCount),
+                    IntegerType::get(getContext(), BitWidth - LZ - TZ)),
+                U->getType()),
+            MulCount);
+      }
     }
     break;
 
@@ -6692,20 +6768,6 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
   return SE.getCouldNotCompute();
 }
 
-static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
-  APInt A = C1->getValue()->getValue().abs();
-  APInt B = C2->getValue()->getValue().abs();
-  uint32_t ABW = A.getBitWidth();
-  uint32_t BBW = B.getBitWidth();
-
-  if (ABW > BBW)
-    B = B.zext(ABW);
-  else if (ABW < BBW)
-    A = A.zext(BBW);
-
-  return APIntOps::GreatestCommonDivisor(A, B);
-}
-
 static const APInt srem(const SCEVConstant *C1, const SCEVConstant *C2) {
   APInt A = C1->getValue()->getValue();
   APInt B = C2->getValue()->getValue();
diff --git a/test/Analysis/ScalarEvolution/and-xor.ll b/test/Analysis/ScalarEvolution/and-xor.ll
index 404ab91e269..ad636da4d4d 100644
--- a/test/Analysis/ScalarEvolution/and-xor.ll
+++ b/test/Analysis/ScalarEvolution/and-xor.ll
@@ -1,11 +1,27 @@
 ; RUN: opt < %s -scalar-evolution -analyze | FileCheck %s
 
+; CHECK-LABEL: @test1
 ; CHECK: --> (zext
 ; CHECK: --> (zext
 ; CHECK-NOT: --> (zext
 
-define i32 @foo(i32 %x) {
+define i32 @test1(i32 %x) {
   %n = and i32 %x, 255
   %y = xor i32 %n, 255
   ret i32 %y
 }
+
+; ScalarEvolution shouldn't try to analyze %z into something like
+;   --> (zext i4 (-1 + (-1 * (trunc i64 (8 * %x) to i4))) to i64)
+; or
+;   --> (8 * (zext i1 (trunc i64 ((8 * %x) /u 8) to i1) to i64))
+
+; CHECK-LABEL: @test2
+; CHECK: --> (8 * (zext i1 (trunc i64 %x to i1) to i64))
+
+define i64 @test2(i64 %x) {
+  %a = shl i64 %x, 3
+  %t = and i64 %a, 8
+  %z = xor i64 %t, 8
+  ret i64 %z
+}
diff --git a/test/Analysis/ScalarEvolution/fold.ll b/test/Analysis/ScalarEvolution/fold.ll
index 57006dd9bb4..84b657050c5 100644
--- a/test/Analysis/ScalarEvolution/fold.ll
+++ b/test/Analysis/ScalarEvolution/fold.ll
@@ -60,3 +60,20 @@ loop:
 exit:
   ret void
 }
+
+define void @test5(i32 %i) {
+; CHECK-LABEL: @test5
+  %A = and i32 %i, 1
+; CHECK: --> (zext i1 (trunc i32 %i to i1) to i32)
+  %B = and i32 %i, 2
+; CHECK: --> (2 * (zext i1 (trunc i32 (%i /u 2) to i1) to i32))
+  %C = and i32 %i, 63
+; CHECK: --> (zext i6 (trunc i32 %i to i6) to i32)
+  %D = and i32 %i, 126
+; CHECK: --> (2 * (zext i6 (trunc i32 (%i /u 2) to i6) to i32))
+  %E = and i32 %i, 64
+; CHECK: --> (64 * (zext i1 (trunc i32 (%i /u 64) to i1) to i32))
+  %F = and i32 %i, -2147483648
+; CHECK: --> (-2147483648 * (%i /u -2147483648))
+  ret void
+}
diff --git a/test/Analysis/ScalarEvolution/nsw-offset.ll b/test/Analysis/ScalarEvolution/nsw-offset.ll
index 8969a5ad4ce..88cdcf23d9e 100644
--- a/test/Analysis/ScalarEvolution/nsw-offset.ll
+++ b/test/Analysis/ScalarEvolution/nsw-offset.ll
@@ -73,5 +73,5 @@ return: ; preds = %bb1.return_crit_edg
   ret void
 }
 
-; CHECK: Loop %bb: backedge-taken count is ((-1 + %n) /u 2)
+; CHECK: Loop %bb: backedge-taken count is ((-1 + (2 * (%no /u 2))) /u 2)
 ; CHECK: Loop %bb: max backedge-taken count is 1073741822
diff --git a/test/Analysis/ScalarEvolution/xor-and.ll b/test/Analysis/ScalarEvolution/xor-and.ll
deleted file mode 100644
index 2616ea928a4..00000000000
--- a/test/Analysis/ScalarEvolution/xor-and.ll
+++ /dev/null
@@ -1,13 +0,0 @@
-; RUN: opt < %s -scalar-evolution -analyze | FileCheck %s
-
-; ScalarEvolution shouldn't try to analyze %z into something like
-;   --> (zext i4 (-1 + (-1 * (trunc i64 (8 * %x) to i4))) to i64)
-
-; CHECK: --> (zext i4 (-8 + (trunc i64 (8 * %x) to i4)) to i64)
-
-define i64 @foo(i64 %x) {
-  %a = shl i64 %x, 3
-  %t = and i64 %a, 8
-  %z = xor i64 %t, 8
-  ret i64 %z
-}
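
[Reviewer note, not part of the patch] The identity that the new createSCEV
lowering and the @test5 CHECK lines rely on can be checked in isolation: for a
mask that is a contiguous run of ones with TZ trailing and LZ leading zero
bits, (x & Mask) equals MulCount * zext(trunc(x /u MulCount)), where
MulCount == 1 << TZ and the truncation is to BitWidth - LZ - TZ bits. A
standalone sketch over uint32_t (the helper name is made up):

#include <cassert>
#include <cstdint>

// Model of the rewritten lowering: keep the MidBits bits of X that sit
// above TZ trailing zeros, expressed as udiv, trunc/zext, then mul.
uint32_t maskViaMulTruncZext(uint32_t X, unsigned LZ, unsigned TZ) {
  uint32_t MulCount = 1u << TZ;            // APInt::getOneBitSet(BitWidth, TZ)
  unsigned MidBits = 32 - LZ - TZ;         // width of the truncated type
  uint32_t Low = (MidBits == 32) ? ~0u : ((1u << MidBits) - 1);
  uint32_t Trunc = (X / MulCount) & Low;   // trunc(x /u MulCount), then zext
  return Trunc * MulCount;                 // MulCount * (...)
}

int main() {
  // %D = and i32 %i, 126 : mask 0b01111110, so LZ == 25 and TZ == 1.
  for (uint32_t I : {0u, 1u, 127u, 128u, 255u, 0xDEADBEEFu})
    assert(maskViaMulTruncZext(I, 25, 1) == (I & 126u));
  // %E = and i32 %i, 64 : a single bit, so LZ == 25 and TZ == 6.
  for (uint32_t I : {0u, 63u, 64u, 65u, 0xFFFFFFFFu})
    assert(maskViaMulTruncZext(I, 25, 6) == (I & 64u));
  return 0;
}

Note that the patch's guard, !((~A & ~KnownZero) & EffectiveMask), is slightly
more general than a contiguous-ones mask: interior zero bits of A are
tolerated when ComputeMaskedBits already proves them zero in the operand.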