A bunch of cleanups and simplifications using the new PatternMatch predicates

and generally tidying things up.  Only very trivial functionality changes
like now doing (-1 - A) -> (~A) for vectors too.

 InstCombineAddSub.cpp |  296 +++++++++++++++++++++-----------------------------
 1 file changed, 126 insertions(+), 170 deletions(-)



git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@125264 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Chris Lattner
2011-02-10 05:14:58 +00:00
parent a81556fb52
commit b9b9044600

View File

@@ -95,35 +95,26 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
if (Value *V = SimplifyUsingDistributiveLaws(I)) if (Value *V = SimplifyUsingDistributiveLaws(I))
return ReplaceInstUsesWith(I, V); return ReplaceInstUsesWith(I, V);
if (Constant *RHSC = dyn_cast<Constant>(RHS)) { if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
if (ConstantInt *CI = dyn_cast<ConstantInt>(RHSC)) { // X + (signbit) --> X ^ signbit
// X + (signbit) --> X ^ signbit const APInt &Val = CI->getValue();
const APInt& Val = CI->getValue(); if (Val.isSignBit())
uint32_t BitWidth = Val.getBitWidth(); return BinaryOperator::CreateXor(LHS, RHS);
if (Val == APInt::getSignBit(BitWidth))
return BinaryOperator::CreateXor(LHS, RHS);
// See if SimplifyDemandedBits can simplify this. This handles stuff like
// (X & 254)+1 -> (X&254)|1
if (SimplifyDemandedInstructionBits(I))
return &I;
// zext(bool) + C -> bool ? C + 1 : C
if (ZExtInst *ZI = dyn_cast<ZExtInst>(LHS))
if (ZI->getSrcTy() == Type::getInt1Ty(I.getContext()))
return SelectInst::Create(ZI->getOperand(0), AddOne(CI), CI);
}
if (isa<PHINode>(LHS))
if (Instruction *NV = FoldOpIntoPhi(I))
return NV;
ConstantInt *XorRHS = 0; // See if SimplifyDemandedBits can simplify this. This handles stuff like
Value *XorLHS = 0; // (X & 254)+1 -> (X&254)|1
if (isa<ConstantInt>(RHSC) && if (SimplifyDemandedInstructionBits(I))
match(LHS, m_Xor(m_Value(XorLHS), m_ConstantInt(XorRHS)))) { return &I;
// zext(bool) + C -> bool ? C + 1 : C
if (ZExtInst *ZI = dyn_cast<ZExtInst>(LHS))
if (ZI->getSrcTy()->isIntegerTy(1))
return SelectInst::Create(ZI->getOperand(0), AddOne(CI), CI);
Value *XorLHS = 0; ConstantInt *XorRHS = 0;
if (match(LHS, m_Xor(m_Value(XorLHS), m_ConstantInt(XorRHS)))) {
uint32_t TySizeBits = I.getType()->getScalarSizeInBits(); uint32_t TySizeBits = I.getType()->getScalarSizeInBits();
const APInt& RHSVal = cast<ConstantInt>(RHSC)->getValue(); const APInt &RHSVal = CI->getValue();
unsigned ExtendAmt = 0; unsigned ExtendAmt = 0;
// If we have ADD(XOR(AND(X, 0xFF), 0x80), 0xF..F80), it's a sext. // If we have ADD(XOR(AND(X, 0xFF), 0x80), 0xF..F80), it's a sext.
// If we have ADD(XOR(AND(X, 0xFF), 0xF..F80), 0x80), it's a sext. // If we have ADD(XOR(AND(X, 0xFF), 0xF..F80), 0x80), it's a sext.
@@ -133,13 +124,13 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
else if (XorRHS->getValue().isPowerOf2()) else if (XorRHS->getValue().isPowerOf2())
ExtendAmt = TySizeBits - XorRHS->getValue().logBase2() - 1; ExtendAmt = TySizeBits - XorRHS->getValue().logBase2() - 1;
} }
if (ExtendAmt) { if (ExtendAmt) {
APInt Mask = APInt::getHighBitsSet(TySizeBits, ExtendAmt); APInt Mask = APInt::getHighBitsSet(TySizeBits, ExtendAmt);
if (!MaskedValueIsZero(XorLHS, Mask)) if (!MaskedValueIsZero(XorLHS, Mask))
ExtendAmt = 0; ExtendAmt = 0;
} }
if (ExtendAmt) { if (ExtendAmt) {
Constant *ShAmt = ConstantInt::get(I.getType(), ExtendAmt); Constant *ShAmt = ConstantInt::get(I.getType(), ExtendAmt);
Value *NewShl = Builder->CreateShl(XorLHS, ShAmt, "sext"); Value *NewShl = Builder->CreateShl(XorLHS, ShAmt, "sext");
@@ -148,23 +139,23 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
} }
} }
if (isa<Constant>(RHS) && isa<PHINode>(LHS))
if (Instruction *NV = FoldOpIntoPhi(I))
return NV;
if (I.getType()->isIntegerTy(1)) if (I.getType()->isIntegerTy(1))
return BinaryOperator::CreateXor(LHS, RHS); return BinaryOperator::CreateXor(LHS, RHS);
if (I.getType()->isIntegerTy()) { // X + X --> X << 1
// X + X --> X << 1 if (LHS == RHS && I.getType()->isIntegerTy())
if (LHS == RHS) return BinaryOperator::CreateShl(LHS, ConstantInt::get(I.getType(), 1));
return BinaryOperator::CreateShl(LHS, ConstantInt::get(I.getType(), 1));
}
// -A + B --> B - A // -A + B --> B - A
// -A + -B --> -(A + B) // -A + -B --> -(A + B)
if (Value *LHSV = dyn_castNegVal(LHS)) { if (Value *LHSV = dyn_castNegVal(LHS)) {
if (LHS->getType()->isIntOrIntVectorTy()) { if (Value *RHSV = dyn_castNegVal(RHS)) {
if (Value *RHSV = dyn_castNegVal(RHS)) { Value *NewAdd = Builder->CreateAdd(LHSV, RHSV, "sum");
Value *NewAdd = Builder->CreateAdd(LHSV, RHSV, "sum"); return BinaryOperator::CreateNeg(NewAdd);
return BinaryOperator::CreateNeg(NewAdd);
}
} }
return BinaryOperator::CreateSub(RHS, LHSV); return BinaryOperator::CreateSub(RHS, LHSV);
@@ -209,7 +200,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
} }
// W*X + Y*Z --> W * (X+Z) iff W == Y // W*X + Y*Z --> W * (X+Z) iff W == Y
if (I.getType()->isIntOrIntVectorTy()) { {
Value *W, *X, *Y, *Z; Value *W, *X, *Y, *Z;
if (match(LHS, m_Mul(m_Value(W), m_Value(X))) && if (match(LHS, m_Mul(m_Value(W), m_Value(X))) &&
match(RHS, m_Mul(m_Value(Y), m_Value(Z)))) { match(RHS, m_Mul(m_Value(Y), m_Value(Z)))) {
@@ -238,24 +229,22 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
// (X & FF00) + xx00 -> (X+xx00) & FF00 // (X & FF00) + xx00 -> (X+xx00) & FF00
if (LHS->hasOneUse() && if (LHS->hasOneUse() &&
match(LHS, m_And(m_Value(X), m_ConstantInt(C2)))) { match(LHS, m_And(m_Value(X), m_ConstantInt(C2))) &&
Constant *Anded = ConstantExpr::getAnd(CRHS, C2); CRHS->getValue() == (CRHS->getValue() & C2->getValue())) {
if (Anded == CRHS) { // See if all bits from the first bit set in the Add RHS up are included
// See if all bits from the first bit set in the Add RHS up are included // in the mask. First, get the rightmost bit.
// in the mask. First, get the rightmost bit. const APInt &AddRHSV = CRHS->getValue();
const APInt &AddRHSV = CRHS->getValue();
// Form a mask of all bits from the lowest bit added through the top.
APInt AddRHSHighBits(~((AddRHSV & -AddRHSV)-1));
// Form a mask of all bits from the lowest bit added through the top. // See if the and mask includes all of these bits.
APInt AddRHSHighBits(~((AddRHSV & -AddRHSV)-1)); APInt AddRHSHighBitsAnd(AddRHSHighBits & C2->getValue());
// See if the and mask includes all of these bits. if (AddRHSHighBits == AddRHSHighBitsAnd) {
APInt AddRHSHighBitsAnd(AddRHSHighBits & C2->getValue()); // Okay, the xform is safe. Insert the new add pronto.
Value *NewAdd = Builder->CreateAdd(X, CRHS, LHS->getName());
if (AddRHSHighBits == AddRHSHighBitsAnd) { return BinaryOperator::CreateAnd(NewAdd, C2);
// Okay, the xform is safe. Insert the new add pronto.
Value *NewAdd = Builder->CreateAdd(X, CRHS, LHS->getName());
return BinaryOperator::CreateAnd(NewAdd, C2);
}
} }
} }
@@ -280,12 +269,11 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
// Can we fold the add into the argument of the select? // Can we fold the add into the argument of the select?
// We check both true and false select arguments for a matching subtract. // We check both true and false select arguments for a matching subtract.
if (match(FV, m_Zero()) && if (match(FV, m_Zero()) && match(TV, m_Sub(m_Value(N), m_Specific(A))))
match(TV, m_Sub(m_Value(N), m_Specific(A))))
// Fold the add into the true select value. // Fold the add into the true select value.
return SelectInst::Create(SI->getCondition(), N, A); return SelectInst::Create(SI->getCondition(), N, A);
if (match(TV, m_Zero()) &&
match(FV, m_Sub(m_Value(N), m_Specific(A)))) if (match(TV, m_Zero()) && match(FV, m_Sub(m_Value(N), m_Specific(A))))
// Fold the add into the false select value. // Fold the add into the false select value.
return SelectInst::Create(SI->getCondition(), A, N); return SelectInst::Create(SI->getCondition(), A, N);
} }
@@ -550,12 +538,12 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
if (I.getType()->isIntegerTy(1)) if (I.getType()->isIntegerTy(1))
return BinaryOperator::CreateXor(Op0, Op1); return BinaryOperator::CreateXor(Op0, Op1);
// Replace (-1 - A) with (~A).
if (match(Op0, m_AllOnes()))
return BinaryOperator::CreateNot(Op1);
if (ConstantInt *C = dyn_cast<ConstantInt>(Op0)) { if (ConstantInt *C = dyn_cast<ConstantInt>(Op0)) {
// Replace (-1 - A) with (~A).
if (C->isAllOnesValue())
return BinaryOperator::CreateNot(Op1);
// C - ~X == X + (1+C) // C - ~X == X + (1+C)
Value *X = 0; Value *X = 0;
if (match(Op1, m_Not(m_Value(X)))) if (match(Op1, m_Not(m_Value(X))))
@@ -564,29 +552,16 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
// -(X >>u 31) -> (X >>s 31) // -(X >>u 31) -> (X >>s 31)
// -(X >>s 31) -> (X >>u 31) // -(X >>s 31) -> (X >>u 31)
if (C->isZero()) { if (C->isZero()) {
if (BinaryOperator *SI = dyn_cast<BinaryOperator>(Op1)) { Value *X; ConstantInt *CI;
if (SI->getOpcode() == Instruction::LShr) { if (match(Op1, m_LShr(m_Value(X), m_ConstantInt(CI))) &&
if (ConstantInt *CU = dyn_cast<ConstantInt>(SI->getOperand(1))) { // Verify we are shifting out everything but the sign bit.
// Check to see if we are shifting out everything but the sign bit. CI->getValue() == I.getType()->getPrimitiveSizeInBits()-1)
if (CU->getLimitedValue(SI->getType()->getPrimitiveSizeInBits()) == return BinaryOperator::CreateAShr(X, CI);
SI->getType()->getPrimitiveSizeInBits()-1) {
// Ok, the transformation is safe. Insert AShr. if (match(Op1, m_AShr(m_Value(X), m_ConstantInt(CI))) &&
return BinaryOperator::Create(Instruction::AShr, // Verify we are shifting out everything but the sign bit.
SI->getOperand(0), CU, SI->getName()); CI->getValue() == I.getType()->getPrimitiveSizeInBits()-1)
} return BinaryOperator::CreateLShr(X, CI);
}
} else if (SI->getOpcode() == Instruction::AShr) {
if (ConstantInt *CU = dyn_cast<ConstantInt>(SI->getOperand(1))) {
// Check to see if we are shifting out everything but the sign bit.
if (CU->getLimitedValue(SI->getType()->getPrimitiveSizeInBits()) ==
SI->getType()->getPrimitiveSizeInBits()-1) {
// Ok, the transformation is safe. Insert LShr.
return BinaryOperator::CreateLShr(
SI->getOperand(0), CU, SI->getName());
}
}
}
}
} }
// Try to fold constant sub into select arguments. // Try to fold constant sub into select arguments.
@@ -596,99 +571,80 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
// C - zext(bool) -> bool ? C - 1 : C // C - zext(bool) -> bool ? C - 1 : C
if (ZExtInst *ZI = dyn_cast<ZExtInst>(Op1)) if (ZExtInst *ZI = dyn_cast<ZExtInst>(Op1))
if (ZI->getSrcTy() == Type::getInt1Ty(I.getContext())) if (ZI->getSrcTy()->isIntegerTy(1))
return SelectInst::Create(ZI->getOperand(0), SubOne(C), C); return SelectInst::Create(ZI->getOperand(0), SubOne(C), C);
// C-(X+C2) --> (C-C2)-X
ConstantInt *C2;
if (match(Op1, m_Add(m_Value(X), m_ConstantInt(C2))))
return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X);
} }
if (BinaryOperator *Op1I = dyn_cast<BinaryOperator>(Op1)) {
if (Op1I->getOpcode() == Instruction::Add) { { Value *Y;
if (Op1I->getOperand(0) == Op0) // X-(X+Y) == -Y // X-(X+Y) == -Y X-(Y+X) == -Y
return BinaryOperator::CreateNeg(Op1I->getOperand(1), if (match(Op1, m_Add(m_Specific(Op0), m_Value(Y))) ||
I.getName()); match(Op1, m_Add(m_Value(Y), m_Specific(Op0))))
else if (Op1I->getOperand(1) == Op0) // X-(Y+X) == -Y return BinaryOperator::CreateNeg(Y);
return BinaryOperator::CreateNeg(Op1I->getOperand(0),
I.getName()); // (X-Y)-X == -Y
else if (ConstantInt *CI1 = dyn_cast<ConstantInt>(I.getOperand(0))) { if (match(Op0, m_Sub(m_Specific(Op1), m_Value(Y))))
if (ConstantInt *CI2 = dyn_cast<ConstantInt>(Op1I->getOperand(1))) return BinaryOperator::CreateNeg(Y);
// C1-(X+C2) --> (C1-C2)-X }
return BinaryOperator::CreateSub(
ConstantExpr::getSub(CI1, CI2), Op1I->getOperand(0)); if (Op1->hasOneUse()) {
} Value *X = 0, *Y = 0, *Z = 0;
Constant *C = 0;
ConstantInt *CI = 0;
// (X - (Y - Z)) --> (X + (Z - Y)).
if (match(Op1, m_Sub(m_Value(Y), m_Value(Z))))
return BinaryOperator::CreateAdd(Op0,
Builder->CreateSub(Z, Y, Op1->getName()));
// (X - (X & Y)) --> (X & ~Y)
//
if (match(Op1, m_And(m_Value(Y), m_Specific(Op0))) ||
match(Op1, m_And(m_Specific(Op0), m_Value(Y))))
return BinaryOperator::CreateAnd(Op0,
Builder->CreateNot(Y, Y->getName() + ".not"));
// 0 - (X sdiv C) -> (X sdiv -C)
if (match(Op1, m_SDiv(m_Value(X), m_Constant(C))) &&
match(Op0, m_Zero()))
return BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(C));
// 0 - (X << Y) -> (-X << Y) when X is freely negatable.
if (match(Op1, m_Shl(m_Value(X), m_Value(Y))) && match(Op0, m_Zero()))
if (Value *XNeg = dyn_castNegVal(X))
return BinaryOperator::CreateShl(XNeg, Y);
// X - X*C --> X * (1-C)
if (match(Op1, m_Mul(m_Specific(Op0), m_ConstantInt(CI)))) {
Constant *CP1 = ConstantExpr::getSub(ConstantInt::get(I.getType(),1), CI);
return BinaryOperator::CreateMul(Op0, CP1);
} }
if (Op1I->hasOneUse()) { // X - X<<C --> X * (1-(1<<C))
// Replace (x - (y - z)) with (x + (z - y)) if the (y - z) subexpression if (match(Op1, m_Shl(m_Specific(Op0), m_ConstantInt(CI)))) {
// is not used by anyone else... Constant *One = ConstantInt::get(I.getType(), 1);
// C = ConstantExpr::getSub(One, ConstantExpr::getShl(One, CI));
if (Op1I->getOpcode() == Instruction::Sub) { return BinaryOperator::CreateMul(Op0, C);
// Swap the two operands of the subexpr... }
Value *IIOp0 = Op1I->getOperand(0), *IIOp1 = Op1I->getOperand(1);
Op1I->setOperand(0, IIOp1); // X - A*-B -> X + A*B
Op1I->setOperand(1, IIOp0); // X - -A*B -> X + A*B
Value *A, *B;
// Create the new top level add instruction... if (match(Op1, m_Mul(m_Value(A), m_Neg(m_Value(B)))) ||
return BinaryOperator::CreateAdd(Op0, Op1); match(Op1, m_Mul(m_Neg(m_Value(A)), m_Value(B))))
} return BinaryOperator::CreateAdd(Op0, Builder->CreateMul(A, B));
// Replace (A - (A & B)) with (A & ~B) if this is the only use of (A&B)...
//
if (Op1I->getOpcode() == Instruction::And &&
(Op1I->getOperand(0) == Op0 || Op1I->getOperand(1) == Op0)) {
Value *OtherOp = Op1I->getOperand(Op1I->getOperand(0) == Op0);
Value *NewNot = Builder->CreateNot(OtherOp, "B.not");
return BinaryOperator::CreateAnd(Op0, NewNot);
}
// 0 - (X sdiv C) -> (X sdiv -C)
if (Op1I->getOpcode() == Instruction::SDiv)
if (ConstantInt *CSI = dyn_cast<ConstantInt>(Op0))
if (CSI->isZero())
if (Constant *DivRHS = dyn_cast<Constant>(Op1I->getOperand(1)))
return BinaryOperator::CreateSDiv(Op1I->getOperand(0),
ConstantExpr::getNeg(DivRHS));
// 0 - (C << X) -> (-C << X)
if (Op1I->getOpcode() == Instruction::Shl)
if (ConstantInt *CSI = dyn_cast<ConstantInt>(Op0))
if (CSI->isZero())
if (Value *ShlLHSNeg = dyn_castNegVal(Op1I->getOperand(0)))
return BinaryOperator::CreateShl(ShlLHSNeg, Op1I->getOperand(1));
// X - X*C --> X * (1-C)
ConstantInt *C2 = 0;
if (dyn_castFoldableMul(Op1I, C2) == Op0) {
Constant *CP1 =
ConstantExpr::getSub(ConstantInt::get(I.getType(), 1),
C2);
return BinaryOperator::CreateMul(Op0, CP1);
}
// X - A*-B -> X + A*B
// X - -A*B -> X + A*B
Value *A, *B;
if (match(Op1I, m_Mul(m_Value(A), m_Neg(m_Value(B)))) ||
match(Op1I, m_Mul(m_Neg(m_Value(A)), m_Value(B)))) {
Value *NewMul = Builder->CreateMul(A, B);
return BinaryOperator::CreateAdd(Op0, NewMul);
}
// X - A*Cst -> X + A*-Cst // X - A*CI -> X + A*-CI
// X - Cst*A -> X + A*-Cst // X - CI*A -> X + A*-CI
ConstantInt *BCst; if (match(Op1, m_Mul(m_Value(A), m_ConstantInt(CI))) ||
if (match(Op1I, m_Mul(m_Value(A), m_ConstantInt(BCst))) || match(Op1, m_Mul(m_ConstantInt(CI), m_Value(A)))) {
match(Op1I, m_Mul(m_ConstantInt(BCst), m_Value(A)))) { Value *NewMul = Builder->CreateMul(A, ConstantExpr::getNeg(CI));
Value *NewMul = Builder->CreateMul(A, ConstantExpr::getNeg(BCst)); return BinaryOperator::CreateAdd(Op0, NewMul);
return BinaryOperator::CreateAdd(Op0, NewMul);
}
}
}
if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) {
if (Op0I->getOpcode() == Instruction::Sub) {
if (Op0I->getOperand(0) == Op1) // (X-Y)-X == -Y
return BinaryOperator::CreateNeg(Op0I->getOperand(1),
I.getName());
} }
} }