diff --git a/lib/Transforms/Scalar/InstructionCombining.cpp b/lib/Transforms/Scalar/InstructionCombining.cpp index d041083f040..be754d13c7a 100644 --- a/lib/Transforms/Scalar/InstructionCombining.cpp +++ b/lib/Transforms/Scalar/InstructionCombining.cpp @@ -2091,6 +2091,54 @@ Instruction *InstCombiner::visitCastInst(CastInst &CI) { return 0; } +/// GetSelectFoldableOperands - We want to turn code that looks like this: +/// %C = or %A, %B +/// %D = select %cond, %C, %A +/// into: +/// %C = select %cond, %B, 0 +/// %D = or %A, %C +/// +/// Assuming that the specified instruction is an operand to the select, return +/// a bitmask indicating which operands of this instruction are foldable if they +/// equal the other incoming value of the select. +/// +static unsigned GetSelectFoldableOperands(Instruction *I) { + switch (I->getOpcode()) { + case Instruction::Add: + case Instruction::Mul: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + return 3; // Can fold through either operand. + case Instruction::Sub: // Can only fold on the amount subtracted. + case Instruction::Shl: // Can only fold on the shift amount. + case Instruction::Shr: + return 1; + default: + return 0; // Cannot fold + } +} + +/// GetSelectFoldableConstant - For the same transformation as the previous +/// function, return the identity constant that goes into the select. +static Constant *GetSelectFoldableConstant(Instruction *I) { + switch (I->getOpcode()) { + default: assert(0 && "This cannot happen!"); abort(); + case Instruction::Add: + case Instruction::Sub: + case Instruction::Or: + case Instruction::Xor: + return Constant::getNullValue(I->getType()); + case Instruction::Shl: + case Instruction::Shr: + return Constant::getNullValue(Type::UByteTy); + case Instruction::And: + return ConstantInt::getAllOnesValue(I->getType()); + case Instruction::Mul: + return ConstantInt::get(I->getType(), 1); + } +} + Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); @@ -2150,6 +2198,66 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } } + // See if we can fold the select into one of our operands. + if (SI.getType()->isInteger()) { + // See the comment above GetSelectFoldableOperands for a description of the + // transformation we are doing here. + if (Instruction *TVI = dyn_cast(TrueVal)) + if (TVI->hasOneUse() && TVI->getNumOperands() == 2 && + !isa(FalseVal)) + if (unsigned SFO = GetSelectFoldableOperands(TVI)) { + unsigned OpToFold = 0; + if ((SFO & 1) && FalseVal == TVI->getOperand(0)) { + OpToFold = 1; + } else if ((SFO & 2) && FalseVal == TVI->getOperand(1)) { + OpToFold = 2; + } + + if (OpToFold) { + Constant *C = GetSelectFoldableConstant(TVI); + std::string Name = TVI->getName(); TVI->setName(""); + Instruction *NewSel = + new SelectInst(SI.getCondition(), TVI->getOperand(2-OpToFold), C, + Name); + InsertNewInstBefore(NewSel, SI); + if (BinaryOperator *BO = dyn_cast(TVI)) + return BinaryOperator::create(BO->getOpcode(), FalseVal, NewSel); + else if (ShiftInst *SI = dyn_cast(TVI)) + return new ShiftInst(SI->getOpcode(), FalseVal, NewSel); + else { + assert(0 && "Unknown instruction!!"); + } + } + } + + if (Instruction *FVI = dyn_cast(FalseVal)) + if (FVI->hasOneUse() && FVI->getNumOperands() == 2 && + !isa(TrueVal)) + if (unsigned SFO = GetSelectFoldableOperands(FVI)) { + unsigned OpToFold = 0; + if ((SFO & 1) && TrueVal == FVI->getOperand(0)) { + OpToFold = 1; + } else if ((SFO & 2) && TrueVal == FVI->getOperand(1)) { + OpToFold = 2; + } + + if (OpToFold) { + Constant *C = GetSelectFoldableConstant(FVI); + std::string Name = FVI->getName(); FVI->setName(""); + Instruction *NewSel = + new SelectInst(SI.getCondition(), C, FVI->getOperand(2-OpToFold), + Name); + InsertNewInstBefore(NewSel, SI); + if (BinaryOperator *BO = dyn_cast(FVI)) + return BinaryOperator::create(BO->getOpcode(), TrueVal, NewSel); + else if (ShiftInst *SI = dyn_cast(FVI)) + return new ShiftInst(SI->getOpcode(), TrueVal, NewSel); + else { + assert(0 && "Unknown instruction!!"); + } + } + } + } return 0; }