diff --git a/lib/Target/ARM/ARMISelLowering.cpp b/lib/Target/ARM/ARMISelLowering.cpp index 190ca076dae..fd9abb74cfe 100644 --- a/lib/Target/ARM/ARMISelLowering.cpp +++ b/lib/Target/ARM/ARMISelLowering.cpp @@ -6983,20 +6983,25 @@ static inline bool isZeroOrAllOnes(SDValue N, bool AllOnes) { // Combine a constant select operand into its use: // -// (add (select cc, 0, c), x) -> (select cc, x, (add, x, c)) -// (sub x, (select cc, 0, c)) -> (select cc, x, (sub, x, c)) +// (add (select cc, 0, c), x) -> (select cc, x, (add, x, c)) +// (sub x, (select cc, 0, c)) -> (select cc, x, (sub, x, c)) +// (and (select cc, -1, c), x) -> (select cc, x, (and, x, c)) [AllOnes=1] +// (or (select cc, 0, c), x) -> (select cc, x, (or, x, c)) +// (xor (select cc, 0, c), x) -> (select cc, x, (xor, x, c)) // // The transform is rejected if the select doesn't have a constant operand that -// is null. +// is null, or all ones when AllOnes is set. // // @param N The node to transform. // @param Slct The N operand that is a select. // @param OtherOp The other N operand (x above). // @param DCI Context. +// @param AllOnes Require the select constant to be all ones instead of null. // @returns The new node, or SDValue() on failure. static SDValue combineSelectAndUse(SDNode *N, SDValue Slct, SDValue OtherOp, - TargetLowering::DAGCombinerInfo &DCI) { + TargetLowering::DAGCombinerInfo &DCI, + bool AllOnes = false) { SelectionDAG &DAG = DCI.DAG; const TargetLowering &TLI = DAG.getTargetLoweringInfo(); EVT VT = N->getValueType(0); @@ -7016,12 +7021,9 @@ SDValue combineSelectAndUse(SDNode *N, SDValue Slct, SDValue OtherOp, bool DoXform = false; bool InvCC = false; - assert ((Opc == ISD::ADD || (Opc == ISD::SUB && Slct == N->getOperand(1))) && - "Bad input!"); - - if (isZeroOrAllOnes(LHS, false)) { + if (isZeroOrAllOnes(LHS, AllOnes)) { DoXform = true; - } else if (CC != ISD::SETCC_INVALID && isZeroOrAllOnes(RHS, false)) { + } else if (CC != ISD::SETCC_INVALID && isZeroOrAllOnes(RHS, AllOnes)) { std::swap(LHS, RHS); SDValue Op0 = Slct.getOperand(0); EVT OpVT = isSlctCC ? Op0.getValueType() : Op0.getOperand(0).getValueType(); @@ -7050,6 +7052,25 @@ SDValue combineSelectAndUse(SDNode *N, SDValue Slct, SDValue OtherOp, CCOp, OtherOp, Result); } +// Attempt combineSelectAndUse on each operand of a commutative operator N. +static +SDValue combineSelectAndUseCommutative(SDNode *N, bool AllOnes, + TargetLowering::DAGCombinerInfo &DCI) { + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + if (N0.getOpcode() == ISD::SELECT && N0.getNode()->hasOneUse()) { + SDValue Result = combineSelectAndUse(N, N0, N1, DCI, AllOnes); + if (Result.getNode()) + return Result; + } + if (N1.getOpcode() == ISD::SELECT && N1.getNode()->hasOneUse()) { + SDValue Result = combineSelectAndUse(N, N1, N0, DCI, AllOnes); + if (Result.getNode()) + return Result; + } + return SDValue(); +} + // AddCombineToVPADDL- For pair-wise add on neon, use the vpaddl instruction // (only after legalization). static SDValue AddCombineToVPADDL(SDNode *N, SDValue N0, SDValue N1, @@ -7382,6 +7403,10 @@ static SDValue PerformANDCombine(SDNode *N, } if (!Subtarget->isThumb1Only()) { + // fold (and (select cc, -1, c), x) -> (select cc, x, (and, x, c)) + SDValue Result = combineSelectAndUseCommutative(N, true, DCI); + if (Result.getNode()) + return Result; // (and x, (cmov -1, y, cond)) => (and.cond x, y) SDValue CAND = formConditionalOp(N, DAG, true); if (CAND.getNode()) @@ -7425,6 +7450,10 @@ static SDValue PerformORCombine(SDNode *N, } if (!Subtarget->isThumb1Only()) { + // fold (or (select cc, 0, c), x) -> (select cc, x, (or, x, c)) + SDValue Result = combineSelectAndUseCommutative(N, false, DCI); + if (Result.getNode()) + return Result; // (or x, (cmov 0, y, cond)) => (or.cond x, y) SDValue COR = formConditionalOp(N, DAG, true); if (COR.getNode()) @@ -7593,6 +7622,10 @@ static SDValue PerformXORCombine(SDNode *N, return SDValue(); if (!Subtarget->isThumb1Only()) { + // fold (xor (select cc, 0, c), x) -> (select cc, x, (xor, x, c)) + SDValue Result = combineSelectAndUseCommutative(N, false, DCI); + if (Result.getNode()) + return Result; // (xor x, (cmov 0, y, cond)) => (xor.cond x, y) SDValue CXOR = formConditionalOp(N, DAG, true); if (CXOR.getNode()) diff --git a/test/CodeGen/ARM/select_xform.ll b/test/CodeGen/ARM/select_xform.ll index 7f653d55733..c4b07326dd6 100644 --- a/test/CodeGen/ARM/select_xform.ll +++ b/test/CodeGen/ARM/select_xform.ll @@ -33,12 +33,12 @@ define i32 @t2(i32 %a, i32 %b, i32 %c, i32 %d) nounwind { define i32 @t3(i32 %a, i32 %b, i32 %x, i32 %y) nounwind { ; ARM: t3: -; ARM: mvnlt r2, #0 -; ARM: and r0, r2, r3 +; ARM: andge r3, r3, r2 +; ARM: mov r0, r3 ; T2: t3: -; T2: movlt.w r2, #-1 -; T2: and.w r0, r2, r3 +; T2: andge.w r3, r3, r2 +; T2: mov r0, r3 %cond = icmp slt i32 %a, %b %z = select i1 %cond, i32 -1, i32 %x %s = and i32 %z, %y @@ -47,12 +47,12 @@ define i32 @t3(i32 %a, i32 %b, i32 %x, i32 %y) nounwind { define i32 @t4(i32 %a, i32 %b, i32 %x, i32 %y) nounwind { ; ARM: t4: -; ARM: movlt r2, #0 -; ARM: orr r0, r2, r3 +; ARM: orrge r3, r3, r2 +; ARM: mov r0, r3 ; T2: t4: -; T2: movlt r2, #0 -; T2: orr.w r0, r2, r3 +; T2: orrge.w r3, r3, r2 +; T2: mov r0, r3 %cond = icmp slt i32 %a, %b %z = select i1 %cond, i32 0, i32 %x %s = or i32 %z, %y