diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index 75a0b43bb69..2a451ce5f12 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -278,6 +278,49 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal, return 0; } +/// SimplifyWithOpReplaced - See if V simplifies when its operand Op is +/// replaced with RepOp. +static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, + const TargetData *TD) { + // Trivial replacement. + if (V == Op) + return RepOp; + + Instruction *I = dyn_cast(V); + if (!I) + return 0; + + // If this is a binary operator, try to simplify it with the replaced op. + if (BinaryOperator *B = dyn_cast(I)) { + if (B->getOperand(0) == Op) + return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), TD); + if (B->getOperand(1) == Op) + return SimplifyBinOp(B->getOpcode(), B->getOperand(0), RepOp, TD); + } + + // If all operands are constant after substituting Op for RepOp then we can + // constant fold the instruction. + if (Constant *CRepOp = dyn_cast(RepOp)) { + // Build a list of all constant operands. + SmallVector ConstOps; + for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { + if (I->getOperand(i) == Op) + ConstOps.push_back(CRepOp); + else if (Constant *COp = dyn_cast(I->getOperand(i))) + ConstOps.push_back(COp); + else + break; + } + + // All operands were constants, fold it. + if (ConstOps.size() == I->getNumOperands()) + return ConstantFoldInstOperands(I->getOpcode(), I->getType(), + ConstOps.data(), ConstOps.size(), TD); + } + + return 0; +} + /// visitSelectInstWithICmp - Visit a SelectInst that has an /// ICmpInst as its first operand. /// @@ -416,25 +459,21 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, } } - if (CmpLHS == TrueVal && CmpRHS == FalseVal) { - // Transform (X == Y) ? X : Y -> Y - if (Pred == ICmpInst::ICMP_EQ) + // If we have an equality comparison then we know the value in one of the + // arms of the select. See if substituting this value into the arm and + // simplifying the result yields the same value as the other arm. + if (Pred == ICmpInst::ICMP_EQ) { + if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, TD) == TrueVal || + SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, TD) == TrueVal) return ReplaceInstUsesWith(SI, FalseVal); - // Transform (X != Y) ? X : Y -> X - if (Pred == ICmpInst::ICMP_NE) + } else if (Pred == ICmpInst::ICMP_NE) { + if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, TD) == FalseVal || + SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, TD) == FalseVal) return ReplaceInstUsesWith(SI, TrueVal); - /// NOTE: if we wanted to, this is where to detect integer MIN/MAX - - } else if (CmpLHS == FalseVal && CmpRHS == TrueVal) { - // Transform (X == Y) ? Y : X -> X - if (Pred == ICmpInst::ICMP_EQ) - return ReplaceInstUsesWith(SI, FalseVal); - // Transform (X != Y) ? Y : X -> Y - if (Pred == ICmpInst::ICMP_NE) - return ReplaceInstUsesWith(SI, TrueVal); - /// NOTE: if we wanted to, this is where to detect integer MIN/MAX } + // NOTE: if we wanted to, this is where to detect integer MIN/MAX + if (isa(CmpRHS)) { if (CmpLHS == TrueVal && Pred == ICmpInst::ICMP_EQ) { // Transform (X == C) ? X : Y -> (X == C) ? C : Y diff --git a/test/Transforms/InstCombine/select.ll b/test/Transforms/InstCombine/select.ll index 39259078b87..379228512cd 100644 --- a/test/Transforms/InstCombine/select.ll +++ b/test/Transforms/InstCombine/select.ll @@ -749,3 +749,43 @@ define i1 @test55(i1 %X, i32 %Y, i32 %Z) { ; CHECK: icmp eq ; CHECK: ret i1 } + +define i32 @test56(i16 %x) nounwind { + %tobool = icmp eq i16 %x, 0 + %conv = zext i16 %x to i32 + %cond = select i1 %tobool, i32 0, i32 %conv + ret i32 %cond +; CHECK: @test56 +; CHECK-NEXT: zext +; CHECK-NEXT: ret +} + +define i32 @test57(i32 %x, i32 %y) nounwind { + %and = and i32 %x, %y + %tobool = icmp eq i32 %x, 0 + %.and = select i1 %tobool, i32 0, i32 %and + ret i32 %.and +; CHECK: @test57 +; CHECK-NEXT: and i32 %x, %y +; CHECK-NEXT: ret +} + +define i32 @test58(i16 %x) nounwind { + %tobool = icmp ne i16 %x, 1 + %conv = zext i16 %x to i32 + %cond = select i1 %tobool, i32 %conv, i32 1 + ret i32 %cond +; CHECK: @test58 +; CHECK-NEXT: zext +; CHECK-NEXT: ret +} + +define i32 @test59(i32 %x, i32 %y) nounwind { + %and = and i32 %x, %y + %tobool = icmp ne i32 %x, %y + %.and = select i1 %tobool, i32 %and, i32 %y + ret i32 %.and +; CHECK: @test59 +; CHECK-NEXT: and i32 %x, %y +; CHECK-NEXT: ret +}