Make use of @llvm.assume in ValueTracking (computeKnownBits, etc.)

This change, which allows @llvm.assume to be used from within computeKnownBits
(and other associated functions in ValueTracking), adds some (optional)
parameters to computeKnownBits and friends. These functions now (optionally)
take a "context" instruction pointer, an AssumptionTracker pointer, and also a
DomTree pointer, and most of the changes are just to pass this new information
when it is easily available from InstSimplify, InstCombine, etc.

As explained below, the significant conceptual change is that known properties
of a value might depend on the control-flow location of the use (because we
care that the @llvm.assume dominates the use because assumptions have
control-flow dependencies). This means that, when we ask if bits are known in a
value, we might get different answers for different uses.

The significant changes are all in ValueTracking. Two main changes: First, as
with the rest of the code, new parameters need to be passed around. To make
this easier, I grouped them into a structure, and I made internal static
versions of the relevant functions that take this structure as a parameter. The
new code does as you might expect, it looks for @llvm.assume calls that make
use of the value we're trying to learn something about (often indirectly),
attempts to pattern match that expression, and uses the result if successful.
By making use of the AssumptionTracker, the process of finding @llvm.assume
calls is not expensive.

Part of the structure being passed around inside ValueTracking is a set of
already-considered @llvm.assume calls. This is to prevent a query using, for
example, the assume(a == b), to recurse on itself. The context and DT params
are used to find applicable assumptions. An assumption needs to dominate the
context instruction, or come after it deterministically. In this latter case we
only handle the specific case where both the assumption and the context
instruction are in the same block, and we need to exclude assumptions from
being used to simplify their own ephemeral values (those which contribute only
to the assumption) because otherwise the assumption would prove its feeding
comparison trivial and would be removed.

This commit adds the plumbing and the logic for a simple masked-bit propagation
(just enough to write a regression test). Future commits add more patterns
(and, correspondingly, more regression tests).

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@217342 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Hal Finkel
2014-09-07 18:57:58 +00:00
parent 22f8dcb2b5
commit 851b04c920
48 changed files with 1457 additions and 545 deletions

View File

@@ -58,8 +58,8 @@ static Type *reduceToSingleValueType(Type *T) {
}
Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) {
unsigned DstAlign = getKnownAlignment(MI->getArgOperand(0), DL);
unsigned SrcAlign = getKnownAlignment(MI->getArgOperand(1), DL);
unsigned DstAlign = getKnownAlignment(MI->getArgOperand(0), DL, AT, MI, DT);
unsigned SrcAlign = getKnownAlignment(MI->getArgOperand(1), DL, AT, MI, DT);
unsigned MinAlign = std::min(DstAlign, SrcAlign);
unsigned CopyAlign = MI->getAlignment();
@@ -154,7 +154,7 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) {
}
Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) {
unsigned Alignment = getKnownAlignment(MI->getDest(), DL);
unsigned Alignment = getKnownAlignment(MI->getDest(), DL, AT, MI, DT);
if (MI->getAlignment() < Alignment) {
MI->setAlignment(ConstantInt::get(MI->getAlignmentType(),
Alignment, false));
@@ -322,7 +322,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
uint32_t BitWidth = IT->getBitWidth();
APInt KnownZero(BitWidth, 0);
APInt KnownOne(BitWidth, 0);
computeKnownBits(II->getArgOperand(0), KnownZero, KnownOne);
computeKnownBits(II->getArgOperand(0), KnownZero, KnownOne, 0, II);
unsigned TrailingZeros = KnownOne.countTrailingZeros();
APInt Mask(APInt::getLowBitsSet(BitWidth, TrailingZeros));
if ((Mask & KnownZero) == Mask)
@@ -340,7 +340,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
uint32_t BitWidth = IT->getBitWidth();
APInt KnownZero(BitWidth, 0);
APInt KnownOne(BitWidth, 0);
computeKnownBits(II->getArgOperand(0), KnownZero, KnownOne);
computeKnownBits(II->getArgOperand(0), KnownZero, KnownOne, 0, II);
unsigned LeadingZeros = KnownOne.countLeadingZeros();
APInt Mask(APInt::getHighBitsSet(BitWidth, LeadingZeros));
if ((Mask & KnownZero) == Mask)
@@ -355,14 +355,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
uint32_t BitWidth = IT->getBitWidth();
APInt LHSKnownZero(BitWidth, 0);
APInt LHSKnownOne(BitWidth, 0);
computeKnownBits(LHS, LHSKnownZero, LHSKnownOne);
computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, II);
bool LHSKnownNegative = LHSKnownOne[BitWidth - 1];
bool LHSKnownPositive = LHSKnownZero[BitWidth - 1];
if (LHSKnownNegative || LHSKnownPositive) {
APInt RHSKnownZero(BitWidth, 0);
APInt RHSKnownOne(BitWidth, 0);
computeKnownBits(RHS, RHSKnownZero, RHSKnownOne);
computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, II);
bool RHSKnownNegative = RHSKnownOne[BitWidth - 1];
bool RHSKnownPositive = RHSKnownZero[BitWidth - 1];
if (LHSKnownNegative && RHSKnownNegative) {
@@ -426,7 +426,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
// can prove that it will never overflow.
if (II->getIntrinsicID() == Intrinsic::sadd_with_overflow) {
Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1);
if (WillNotOverflowSignedAdd(LHS, RHS)) {
if (WillNotOverflowSignedAdd(LHS, RHS, II)) {
Value *Add = Builder->CreateNSWAdd(LHS, RHS);
Add->takeName(&CI);
Constant *V[] = {UndefValue::get(Add->getType()), Builder->getFalse()};
@@ -464,10 +464,10 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
APInt LHSKnownZero(BitWidth, 0);
APInt LHSKnownOne(BitWidth, 0);
computeKnownBits(LHS, LHSKnownZero, LHSKnownOne);
computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, II);
APInt RHSKnownZero(BitWidth, 0);
APInt RHSKnownOne(BitWidth, 0);
computeKnownBits(RHS, RHSKnownZero, RHSKnownOne);
computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, II);
// Get the largest possible values for each operand.
APInt LHSMax = ~LHSKnownZero;
@@ -521,7 +521,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
case Intrinsic::ppc_altivec_lvx:
case Intrinsic::ppc_altivec_lvxl:
// Turn PPC lvx -> load if the pointer is known aligned.
if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL) >= 16) {
if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16,
DL, AT, II, DT) >= 16) {
Value *Ptr = Builder->CreateBitCast(II->getArgOperand(0),
PointerType::getUnqual(II->getType()));
return new LoadInst(Ptr);
@@ -530,7 +531,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
case Intrinsic::ppc_altivec_stvx:
case Intrinsic::ppc_altivec_stvxl:
// Turn stvx -> store if the pointer is known aligned.
if (getOrEnforceKnownAlignment(II->getArgOperand(1), 16, DL) >= 16) {
if (getOrEnforceKnownAlignment(II->getArgOperand(1), 16,
DL, AT, II, DT) >= 16) {
Type *OpPtrTy =
PointerType::getUnqual(II->getArgOperand(0)->getType());
Value *Ptr = Builder->CreateBitCast(II->getArgOperand(1), OpPtrTy);
@@ -541,7 +543,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
case Intrinsic::x86_sse2_storeu_pd:
case Intrinsic::x86_sse2_storeu_dq:
// Turn X86 storeu -> store if the pointer is known aligned.
if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL) >= 16) {
if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16,
DL, AT, II, DT) >= 16) {
Type *OpPtrTy =
PointerType::getUnqual(II->getArgOperand(1)->getType());
Value *Ptr = Builder->CreateBitCast(II->getArgOperand(0), OpPtrTy);
@@ -886,7 +889,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
case Intrinsic::arm_neon_vst2lane:
case Intrinsic::arm_neon_vst3lane:
case Intrinsic::arm_neon_vst4lane: {
unsigned MemAlign = getKnownAlignment(II->getArgOperand(0), DL);
unsigned MemAlign = getKnownAlignment(II->getArgOperand(0), DL, AT, II, DT);
unsigned AlignArg = II->getNumArgOperands() - 1;
ConstantInt *IntrAlign = dyn_cast<ConstantInt>(II->getArgOperand(AlignArg));
if (IntrAlign && IntrAlign->getZExtValue() < MemAlign) {