From b9b9044600a472d5f8750f43fd884e32e2afe4cc Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Thu, 10 Feb 2011 05:14:58 +0000 Subject: [PATCH] A bunch of cleanups and simplifications using the new PatternMatch predicates and generally tidying things up. Only very trivial functionality changes like now doing (-1 - A) -> (~A) for vectors too. InstCombineAddSub.cpp | 296 +++++++++++++++++++++----------------------------- 1 file changed, 126 insertions(+), 170 deletions(-) git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@125264 91177308-0d34-0410-b5e6-96231b3b80d8 --- .../InstCombine/InstCombineAddSub.cpp | 304 ++++++++---------- 1 file changed, 130 insertions(+), 174 deletions(-) diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp index e5d9a8b9618..91116ca27ad 100644 --- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -95,35 +95,26 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (Value *V = SimplifyUsingDistributiveLaws(I)) return ReplaceInstUsesWith(I, V); - if (Constant *RHSC = dyn_cast(RHS)) { - if (ConstantInt *CI = dyn_cast(RHSC)) { - // X + (signbit) --> X ^ signbit - const APInt& Val = CI->getValue(); - uint32_t BitWidth = Val.getBitWidth(); - if (Val == APInt::getSignBit(BitWidth)) - return BinaryOperator::CreateXor(LHS, RHS); - - // See if SimplifyDemandedBits can simplify this. This handles stuff like - // (X & 254)+1 -> (X&254)|1 - if (SimplifyDemandedInstructionBits(I)) - return &I; - - // zext(bool) + C -> bool ? C + 1 : C - if (ZExtInst *ZI = dyn_cast(LHS)) - if (ZI->getSrcTy() == Type::getInt1Ty(I.getContext())) - return SelectInst::Create(ZI->getOperand(0), AddOne(CI), CI); - } - - if (isa(LHS)) - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; + if (ConstantInt *CI = dyn_cast(RHS)) { + // X + (signbit) --> X ^ signbit + const APInt &Val = CI->getValue(); + if (Val.isSignBit()) + return BinaryOperator::CreateXor(LHS, RHS); - ConstantInt *XorRHS = 0; - Value *XorLHS = 0; - if (isa(RHSC) && - match(LHS, m_Xor(m_Value(XorLHS), m_ConstantInt(XorRHS)))) { + // See if SimplifyDemandedBits can simplify this. This handles stuff like + // (X & 254)+1 -> (X&254)|1 + if (SimplifyDemandedInstructionBits(I)) + return &I; + + // zext(bool) + C -> bool ? C + 1 : C + if (ZExtInst *ZI = dyn_cast(LHS)) + if (ZI->getSrcTy()->isIntegerTy(1)) + return SelectInst::Create(ZI->getOperand(0), AddOne(CI), CI); + + Value *XorLHS = 0; ConstantInt *XorRHS = 0; + if (match(LHS, m_Xor(m_Value(XorLHS), m_ConstantInt(XorRHS)))) { uint32_t TySizeBits = I.getType()->getScalarSizeInBits(); - const APInt& RHSVal = cast(RHSC)->getValue(); + const APInt &RHSVal = CI->getValue(); unsigned ExtendAmt = 0; // If we have ADD(XOR(AND(X, 0xFF), 0x80), 0xF..F80), it's a sext. // If we have ADD(XOR(AND(X, 0xFF), 0xF..F80), 0x80), it's a sext. @@ -133,13 +124,13 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { else if (XorRHS->getValue().isPowerOf2()) ExtendAmt = TySizeBits - XorRHS->getValue().logBase2() - 1; } - + if (ExtendAmt) { APInt Mask = APInt::getHighBitsSet(TySizeBits, ExtendAmt); if (!MaskedValueIsZero(XorLHS, Mask)) ExtendAmt = 0; } - + if (ExtendAmt) { Constant *ShAmt = ConstantInt::get(I.getType(), ExtendAmt); Value *NewShl = Builder->CreateShl(XorLHS, ShAmt, "sext"); @@ -148,23 +139,23 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { } } + if (isa(RHS) && isa(LHS)) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + if (I.getType()->isIntegerTy(1)) return BinaryOperator::CreateXor(LHS, RHS); - if (I.getType()->isIntegerTy()) { - // X + X --> X << 1 - if (LHS == RHS) - return BinaryOperator::CreateShl(LHS, ConstantInt::get(I.getType(), 1)); - } + // X + X --> X << 1 + if (LHS == RHS && I.getType()->isIntegerTy()) + return BinaryOperator::CreateShl(LHS, ConstantInt::get(I.getType(), 1)); // -A + B --> B - A // -A + -B --> -(A + B) if (Value *LHSV = dyn_castNegVal(LHS)) { - if (LHS->getType()->isIntOrIntVectorTy()) { - if (Value *RHSV = dyn_castNegVal(RHS)) { - Value *NewAdd = Builder->CreateAdd(LHSV, RHSV, "sum"); - return BinaryOperator::CreateNeg(NewAdd); - } + if (Value *RHSV = dyn_castNegVal(RHS)) { + Value *NewAdd = Builder->CreateAdd(LHSV, RHSV, "sum"); + return BinaryOperator::CreateNeg(NewAdd); } return BinaryOperator::CreateSub(RHS, LHSV); @@ -209,7 +200,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { } // W*X + Y*Z --> W * (X+Z) iff W == Y - if (I.getType()->isIntOrIntVectorTy()) { + { Value *W, *X, *Y, *Z; if (match(LHS, m_Mul(m_Value(W), m_Value(X))) && match(RHS, m_Mul(m_Value(Y), m_Value(Z)))) { @@ -238,24 +229,22 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // (X & FF00) + xx00 -> (X+xx00) & FF00 if (LHS->hasOneUse() && - match(LHS, m_And(m_Value(X), m_ConstantInt(C2)))) { - Constant *Anded = ConstantExpr::getAnd(CRHS, C2); - if (Anded == CRHS) { - // See if all bits from the first bit set in the Add RHS up are included - // in the mask. First, get the rightmost bit. - const APInt &AddRHSV = CRHS->getValue(); + match(LHS, m_And(m_Value(X), m_ConstantInt(C2))) && + CRHS->getValue() == (CRHS->getValue() & C2->getValue())) { + // See if all bits from the first bit set in the Add RHS up are included + // in the mask. First, get the rightmost bit. + const APInt &AddRHSV = CRHS->getValue(); + + // Form a mask of all bits from the lowest bit added through the top. + APInt AddRHSHighBits(~((AddRHSV & -AddRHSV)-1)); - // Form a mask of all bits from the lowest bit added through the top. - APInt AddRHSHighBits(~((AddRHSV & -AddRHSV)-1)); + // See if the and mask includes all of these bits. + APInt AddRHSHighBitsAnd(AddRHSHighBits & C2->getValue()); - // See if the and mask includes all of these bits. - APInt AddRHSHighBitsAnd(AddRHSHighBits & C2->getValue()); - - if (AddRHSHighBits == AddRHSHighBitsAnd) { - // Okay, the xform is safe. Insert the new add pronto. - Value *NewAdd = Builder->CreateAdd(X, CRHS, LHS->getName()); - return BinaryOperator::CreateAnd(NewAdd, C2); - } + if (AddRHSHighBits == AddRHSHighBitsAnd) { + // Okay, the xform is safe. Insert the new add pronto. + Value *NewAdd = Builder->CreateAdd(X, CRHS, LHS->getName()); + return BinaryOperator::CreateAnd(NewAdd, C2); } } @@ -280,12 +269,11 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // Can we fold the add into the argument of the select? // We check both true and false select arguments for a matching subtract. - if (match(FV, m_Zero()) && - match(TV, m_Sub(m_Value(N), m_Specific(A)))) + if (match(FV, m_Zero()) && match(TV, m_Sub(m_Value(N), m_Specific(A)))) // Fold the add into the true select value. return SelectInst::Create(SI->getCondition(), N, A); - if (match(TV, m_Zero()) && - match(FV, m_Sub(m_Value(N), m_Specific(A)))) + + if (match(TV, m_Zero()) && match(FV, m_Sub(m_Value(N), m_Specific(A)))) // Fold the add into the false select value. return SelectInst::Create(SI->getCondition(), A, N); } @@ -550,12 +538,12 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { if (I.getType()->isIntegerTy(1)) return BinaryOperator::CreateXor(Op0, Op1); + + // Replace (-1 - A) with (~A). + if (match(Op0, m_AllOnes())) + return BinaryOperator::CreateNot(Op1); if (ConstantInt *C = dyn_cast(Op0)) { - // Replace (-1 - A) with (~A). - if (C->isAllOnesValue()) - return BinaryOperator::CreateNot(Op1); - // C - ~X == X + (1+C) Value *X = 0; if (match(Op1, m_Not(m_Value(X)))) @@ -564,29 +552,16 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { // -(X >>u 31) -> (X >>s 31) // -(X >>s 31) -> (X >>u 31) if (C->isZero()) { - if (BinaryOperator *SI = dyn_cast(Op1)) { - if (SI->getOpcode() == Instruction::LShr) { - if (ConstantInt *CU = dyn_cast(SI->getOperand(1))) { - // Check to see if we are shifting out everything but the sign bit. - if (CU->getLimitedValue(SI->getType()->getPrimitiveSizeInBits()) == - SI->getType()->getPrimitiveSizeInBits()-1) { - // Ok, the transformation is safe. Insert AShr. - return BinaryOperator::Create(Instruction::AShr, - SI->getOperand(0), CU, SI->getName()); - } - } - } else if (SI->getOpcode() == Instruction::AShr) { - if (ConstantInt *CU = dyn_cast(SI->getOperand(1))) { - // Check to see if we are shifting out everything but the sign bit. - if (CU->getLimitedValue(SI->getType()->getPrimitiveSizeInBits()) == - SI->getType()->getPrimitiveSizeInBits()-1) { - // Ok, the transformation is safe. Insert LShr. - return BinaryOperator::CreateLShr( - SI->getOperand(0), CU, SI->getName()); - } - } - } - } + Value *X; ConstantInt *CI; + if (match(Op1, m_LShr(m_Value(X), m_ConstantInt(CI))) && + // Verify we are shifting out everything but the sign bit. + CI->getValue() == I.getType()->getPrimitiveSizeInBits()-1) + return BinaryOperator::CreateAShr(X, CI); + + if (match(Op1, m_AShr(m_Value(X), m_ConstantInt(CI))) && + // Verify we are shifting out everything but the sign bit. + CI->getValue() == I.getType()->getPrimitiveSizeInBits()-1) + return BinaryOperator::CreateLShr(X, CI); } // Try to fold constant sub into select arguments. @@ -596,99 +571,80 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { // C - zext(bool) -> bool ? C - 1 : C if (ZExtInst *ZI = dyn_cast(Op1)) - if (ZI->getSrcTy() == Type::getInt1Ty(I.getContext())) + if (ZI->getSrcTy()->isIntegerTy(1)) return SelectInst::Create(ZI->getOperand(0), SubOne(C), C); + + // C-(X+C2) --> (C-C2)-X + ConstantInt *C2; + if (match(Op1, m_Add(m_Value(X), m_ConstantInt(C2)))) + return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X); } - if (BinaryOperator *Op1I = dyn_cast(Op1)) { - if (Op1I->getOpcode() == Instruction::Add) { - if (Op1I->getOperand(0) == Op0) // X-(X+Y) == -Y - return BinaryOperator::CreateNeg(Op1I->getOperand(1), - I.getName()); - else if (Op1I->getOperand(1) == Op0) // X-(Y+X) == -Y - return BinaryOperator::CreateNeg(Op1I->getOperand(0), - I.getName()); - else if (ConstantInt *CI1 = dyn_cast(I.getOperand(0))) { - if (ConstantInt *CI2 = dyn_cast(Op1I->getOperand(1))) - // C1-(X+C2) --> (C1-C2)-X - return BinaryOperator::CreateSub( - ConstantExpr::getSub(CI1, CI2), Op1I->getOperand(0)); - } + + { Value *Y; + // X-(X+Y) == -Y X-(Y+X) == -Y + if (match(Op1, m_Add(m_Specific(Op0), m_Value(Y))) || + match(Op1, m_Add(m_Value(Y), m_Specific(Op0)))) + return BinaryOperator::CreateNeg(Y); + + // (X-Y)-X == -Y + if (match(Op0, m_Sub(m_Specific(Op1), m_Value(Y)))) + return BinaryOperator::CreateNeg(Y); + } + + if (Op1->hasOneUse()) { + Value *X = 0, *Y = 0, *Z = 0; + Constant *C = 0; + ConstantInt *CI = 0; + + // (X - (Y - Z)) --> (X + (Z - Y)). + if (match(Op1, m_Sub(m_Value(Y), m_Value(Z)))) + return BinaryOperator::CreateAdd(Op0, + Builder->CreateSub(Z, Y, Op1->getName())); + + // (X - (X & Y)) --> (X & ~Y) + // + if (match(Op1, m_And(m_Value(Y), m_Specific(Op0))) || + match(Op1, m_And(m_Specific(Op0), m_Value(Y)))) + return BinaryOperator::CreateAnd(Op0, + Builder->CreateNot(Y, Y->getName() + ".not")); + + // 0 - (X sdiv C) -> (X sdiv -C) + if (match(Op1, m_SDiv(m_Value(X), m_Constant(C))) && + match(Op0, m_Zero())) + return BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(C)); + + // 0 - (X << Y) -> (-X << Y) when X is freely negatable. + if (match(Op1, m_Shl(m_Value(X), m_Value(Y))) && match(Op0, m_Zero())) + if (Value *XNeg = dyn_castNegVal(X)) + return BinaryOperator::CreateShl(XNeg, Y); + + // X - X*C --> X * (1-C) + if (match(Op1, m_Mul(m_Specific(Op0), m_ConstantInt(CI)))) { + Constant *CP1 = ConstantExpr::getSub(ConstantInt::get(I.getType(),1), CI); + return BinaryOperator::CreateMul(Op0, CP1); } - if (Op1I->hasOneUse()) { - // Replace (x - (y - z)) with (x + (z - y)) if the (y - z) subexpression - // is not used by anyone else... - // - if (Op1I->getOpcode() == Instruction::Sub) { - // Swap the two operands of the subexpr... - Value *IIOp0 = Op1I->getOperand(0), *IIOp1 = Op1I->getOperand(1); - Op1I->setOperand(0, IIOp1); - Op1I->setOperand(1, IIOp0); - - // Create the new top level add instruction... - return BinaryOperator::CreateAdd(Op0, Op1); - } - - // Replace (A - (A & B)) with (A & ~B) if this is the only use of (A&B)... - // - if (Op1I->getOpcode() == Instruction::And && - (Op1I->getOperand(0) == Op0 || Op1I->getOperand(1) == Op0)) { - Value *OtherOp = Op1I->getOperand(Op1I->getOperand(0) == Op0); - - Value *NewNot = Builder->CreateNot(OtherOp, "B.not"); - return BinaryOperator::CreateAnd(Op0, NewNot); - } - - // 0 - (X sdiv C) -> (X sdiv -C) - if (Op1I->getOpcode() == Instruction::SDiv) - if (ConstantInt *CSI = dyn_cast(Op0)) - if (CSI->isZero()) - if (Constant *DivRHS = dyn_cast(Op1I->getOperand(1))) - return BinaryOperator::CreateSDiv(Op1I->getOperand(0), - ConstantExpr::getNeg(DivRHS)); - - // 0 - (C << X) -> (-C << X) - if (Op1I->getOpcode() == Instruction::Shl) - if (ConstantInt *CSI = dyn_cast(Op0)) - if (CSI->isZero()) - if (Value *ShlLHSNeg = dyn_castNegVal(Op1I->getOperand(0))) - return BinaryOperator::CreateShl(ShlLHSNeg, Op1I->getOperand(1)); - - // X - X*C --> X * (1-C) - ConstantInt *C2 = 0; - if (dyn_castFoldableMul(Op1I, C2) == Op0) { - Constant *CP1 = - ConstantExpr::getSub(ConstantInt::get(I.getType(), 1), - C2); - return BinaryOperator::CreateMul(Op0, CP1); - } - - // X - A*-B -> X + A*B - // X - -A*B -> X + A*B - Value *A, *B; - if (match(Op1I, m_Mul(m_Value(A), m_Neg(m_Value(B)))) || - match(Op1I, m_Mul(m_Neg(m_Value(A)), m_Value(B)))) { - Value *NewMul = Builder->CreateMul(A, B); - return BinaryOperator::CreateAdd(Op0, NewMul); - } + // X - X< X * (1-(1< X + A*B + // X - -A*B -> X + A*B + Value *A, *B; + if (match(Op1, m_Mul(m_Value(A), m_Neg(m_Value(B)))) || + match(Op1, m_Mul(m_Neg(m_Value(A)), m_Value(B)))) + return BinaryOperator::CreateAdd(Op0, Builder->CreateMul(A, B)); - // X - A*Cst -> X + A*-Cst - // X - Cst*A -> X + A*-Cst - ConstantInt *BCst; - if (match(Op1I, m_Mul(m_Value(A), m_ConstantInt(BCst))) || - match(Op1I, m_Mul(m_ConstantInt(BCst), m_Value(A)))) { - Value *NewMul = Builder->CreateMul(A, ConstantExpr::getNeg(BCst)); - return BinaryOperator::CreateAdd(Op0, NewMul); - } - } - } - - if (BinaryOperator *Op0I = dyn_cast(Op0)) { - if (Op0I->getOpcode() == Instruction::Sub) { - if (Op0I->getOperand(0) == Op1) // (X-Y)-X == -Y - return BinaryOperator::CreateNeg(Op0I->getOperand(1), - I.getName()); + // X - A*CI -> X + A*-CI + // X - CI*A -> X + A*-CI + if (match(Op1, m_Mul(m_Value(A), m_ConstantInt(CI))) || + match(Op1, m_Mul(m_ConstantInt(CI), m_Value(A)))) { + Value *NewMul = Builder->CreateMul(A, ConstantExpr::getNeg(CI)); + return BinaryOperator::CreateAdd(Op0, NewMul); } }