diff --git a/lib/Transforms/Scalar/InstructionCombining.cpp b/lib/Transforms/Scalar/InstructionCombining.cpp index 1e2a4647cc5..2a03bc1461d 100644 --- a/lib/Transforms/Scalar/InstructionCombining.cpp +++ b/lib/Transforms/Scalar/InstructionCombining.cpp @@ -222,6 +222,9 @@ namespace { Instruction *OptAndOp(Instruction *Op, ConstantIntegral *OpRHS, ConstantIntegral *AndRHS, BinaryOperator &TheAnd); + + Value *FoldLogicalPlusAnd(Value *LHS, Value *RHS, ConstantIntegral *Mask, + bool isSub, Instruction &I); Instruction *InsertRangeTest(Value *V, Constant *Lo, Constant *Hi, bool Inside, Instruction &IB); @@ -1570,6 +1573,46 @@ Instruction *InstCombiner::InsertRangeTest(Value *V, Constant *Lo, Constant *Hi, return new SetCondInst(Instruction::SetGT, OffsetVal, AddCST); } +/// FoldLogicalPlusAnd - We know that Mask is of the form 0+1+, and that this is +/// part of an expression (LHS +/- RHS) & Mask, where isSub determines whether +/// the operator is a sub. If we can fold one of the following xforms: +/// +/// ((A & N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == Mask +/// ((A | N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == 0 +/// ((A ^ N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == 0 +/// +/// return (A +/- B). +/// +Value *InstCombiner::FoldLogicalPlusAnd(Value *LHS, Value *RHS, + ConstantIntegral *Mask, bool isSub, + Instruction &I) { + Instruction *LHSI = dyn_cast(LHS); + if (!LHSI || LHSI->getNumOperands() != 2 || + !isa(LHSI->getOperand(1))) return 0; + + ConstantInt *N = cast(LHSI->getOperand(1)); + + switch (LHSI->getOpcode()) { + default: return 0; + case Instruction::And: + if (ConstantExpr::getAnd(N, Mask) == Mask) + break; + return 0; + case Instruction::Or: + case Instruction::Xor: + if (ConstantExpr::getAnd(N, Mask)->isNullValue()) + break; + return 0; + } + + Instruction *New; + if (isSub) + New = BinaryOperator::createSub(LHSI->getOperand(0), RHS, "fold"); + else + New = BinaryOperator::createAdd(LHSI->getOperand(0), RHS, "fold"); + return InsertNewInstBefore(New, I); +} + Instruction *InstCombiner::visitAnd(BinaryOperator &I) { bool Changed = SimplifyCommutative(I); @@ -1640,6 +1683,29 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { MaskedValueIsZero(Op0RHS, AndRHS)) return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); break; + case Instruction::Add: + // If the AndRHS is a power of two minus one (0+1+). + if ((AndRHS->getRawValue() & AndRHS->getRawValue()+1) == 0) { + // ((A & N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == AndRHS. + // ((A | N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == 0 + // ((A ^ N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == 0 + if (Value *V = FoldLogicalPlusAnd(Op0LHS, Op0RHS, AndRHS, false, I)) + return BinaryOperator::createAnd(V, AndRHS); + if (Value *V = FoldLogicalPlusAnd(Op0RHS, Op0LHS, AndRHS, false, I)) + return BinaryOperator::createAnd(V, AndRHS); // Add commutes + } + break; + + case Instruction::Sub: + // If the AndRHS is a power of two minus one (0+1+). + if ((AndRHS->getRawValue() & AndRHS->getRawValue()+1) == 0) { + // ((A & N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == AndRHS. + // ((A | N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == 0 + // ((A ^ N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == 0 + if (Value *V = FoldLogicalPlusAnd(Op0LHS, Op0RHS, AndRHS, true, I)) + return BinaryOperator::createAnd(V, AndRHS); + } + break; } if (ConstantInt *Op0CI = dyn_cast(Op0I->getOperand(1)))