diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 7f1900d88c3..8650154c1a9 100644 --- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -348,6 +348,14 @@ namespace { /// \return True if some memory operations were changed. bool MergeConsecutiveStores(StoreSDNode *N); + /// \brief Try to transform a truncation where C is a constant: + /// (trunc (and X, C)) -> (and (trunc X), (trunc C)) + /// + /// \p N needs to be a truncation and its first operand an AND. Other + /// requirements are checked by the function (e.g. that trunc is + /// single-use) and if missed an empty SDValue is returned. + SDValue distributeTruncateThroughAnd(SDNode *N); + public: DAGCombiner(SelectionDAG &D, AliasAnalysis &A, CodeGenOpt::Level OL) : DAG(D), TLI(D.getTargetLoweringInfo()), Level(BeforeLegalizeTypes), @@ -3806,6 +3814,28 @@ SDValue DAGCombiner::visitShiftByConstant(SDNode *N, unsigned Amt) { return DAG.getNode(LHS->getOpcode(), SDLoc(N), VT, NewShift, NewRHS); } +SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) { + assert(N->getOpcode() == ISD::TRUNCATE); + assert(N->getOperand(0).getOpcode() == ISD::AND); + + // (truncate:TruncVT (and N00, N01C)) -> (and (truncate:TruncVT N00), TruncC) + if (N->hasOneUse() && N->getOperand(0).hasOneUse()) { + SDValue N01 = N->getOperand(0).getOperand(1); + + if (ConstantSDNode *N01C = dyn_cast(N01)) { + EVT TruncVT = N->getValueType(0); + SDValue N00 = N->getOperand(0).getOperand(0); + APInt TruncC = N01C->getAPIntValue(); + TruncC = TruncC.trunc(TruncVT.getScalarType().getSizeInBits()); + + return DAG.getNode(ISD::AND, SDLoc(N), TruncVT, + DAG.getNode(ISD::TRUNCATE, SDLoc(N), TruncVT, N00), + DAG.getConstant(TruncC, TruncVT)); + } + } + + return SDValue(); +} SDValue DAGCombiner::visitSHL(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -3859,21 +3889,10 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { return DAG.getConstant(0, VT); // fold (shl x, (trunc (and y, c))) -> (shl x, (and (trunc y), (trunc c))). if (N1.getOpcode() == ISD::TRUNCATE && - N1.getOperand(0).getOpcode() == ISD::AND && - N1.hasOneUse() && N1.getOperand(0).hasOneUse()) { - SDValue N101 = N1.getOperand(0).getOperand(1); - if (ConstantSDNode *N101C = dyn_cast(N101)) { - EVT TruncVT = N1.getValueType(); - SDValue N100 = N1.getOperand(0).getOperand(0); - APInt TruncC = N101C->getAPIntValue(); - TruncC = TruncC.trunc(TruncVT.getSizeInBits()); - return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, - DAG.getNode(ISD::AND, SDLoc(N), TruncVT, - DAG.getNode(ISD::TRUNCATE, - SDLoc(N), - TruncVT, N100), - DAG.getConstant(TruncC, TruncVT))); - } + N1.getOperand(0).getOpcode() == ISD::AND) { + SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()); + if (NewOp1.getNode()) + return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, NewOp1); } if (N1C && SimplifyDemandedBits(SDValue(N, 0))) @@ -4073,22 +4092,10 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { // fold (sra x, (trunc (and y, c))) -> (sra x, (and (trunc y), (trunc c))). if (N1.getOpcode() == ISD::TRUNCATE && - N1.getOperand(0).getOpcode() == ISD::AND && - N1.hasOneUse() && N1.getOperand(0).hasOneUse()) { - SDValue N101 = N1.getOperand(0).getOperand(1); - if (ConstantSDNode *N101C = dyn_cast(N101)) { - EVT TruncVT = N1.getValueType(); - SDValue N100 = N1.getOperand(0).getOperand(0); - APInt TruncC = N101C->getAPIntValue(); - TruncC = TruncC.trunc(TruncVT.getScalarType().getSizeInBits()); - return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0, - DAG.getNode(ISD::AND, SDLoc(N), - TruncVT, - DAG.getNode(ISD::TRUNCATE, - SDLoc(N), - TruncVT, N100), - DAG.getConstant(TruncC, TruncVT))); - } + N1.getOperand(0).getOpcode() == ISD::AND) { + SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()); + if (NewOp1.getNode()) + return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0, NewOp1); } // fold (sra (trunc (sr x, c1)), c2) -> (trunc (sra x, c1+c2)) @@ -4267,22 +4274,10 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { // fold (srl x, (trunc (and y, c))) -> (srl x, (and (trunc y), (trunc c))). if (N1.getOpcode() == ISD::TRUNCATE && - N1.getOperand(0).getOpcode() == ISD::AND && - N1.hasOneUse() && N1.getOperand(0).hasOneUse()) { - SDValue N101 = N1.getOperand(0).getOperand(1); - if (ConstantSDNode *N101C = dyn_cast(N101)) { - EVT TruncVT = N1.getValueType(); - SDValue N100 = N1.getOperand(0).getOperand(0); - APInt TruncC = N101C->getAPIntValue(); - TruncC = TruncC.trunc(TruncVT.getSizeInBits()); - return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, - DAG.getNode(ISD::AND, SDLoc(N), - TruncVT, - DAG.getNode(ISD::TRUNCATE, - SDLoc(N), - TruncVT, N100), - DAG.getConstant(TruncC, TruncVT))); - } + N1.getOperand(0).getOpcode() == ISD::AND) { + SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()); + if (NewOp1.getNode()) + return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, NewOp1); } // fold operands of srl based on knowledge that the low bits are not