diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index ac137e23f8e..6210f0fcbc6 100644 --- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -2312,21 +2312,17 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { // fold (shl x, (trunc (and y, c))) -> (shl x, (and (trunc y), c)) // iff (trunc c) == c if (N1.getOpcode() == ISD::TRUNCATE && - N1.getOperand(0).getOpcode() == ISD::AND) { + N1.getOperand(0).getOpcode() == ISD::AND && + N1.hasOneUse() && N1.getOperand(0).hasOneUse()) { SDValue N101 = N1.getOperand(0).getOperand(1); - ConstantSDNode *N101C = dyn_cast(N101); - if (N101C) { + if (ConstantSDNode *N101C = dyn_cast(N101)) { MVT TruncVT = N1.getValueType(); - unsigned TruncBitSize = TruncVT.getSizeInBits(); - APInt ShAmt = N101C->getAPIntValue(); - if (ShAmt.trunc(TruncBitSize).getZExtValue() == N101C->getZExtValue()) { - SDValue N100 = N1.getOperand(0).getOperand(0); - return DAG.getNode(ISD::SHL, VT, N0, - DAG.getNode(ISD::AND, TruncVT, - DAG.getNode(ISD::TRUNCATE, TruncVT, N100), - DAG.getConstant(N101C->getZExtValue(), - TruncVT))); - } + SDValue N100 = N1.getOperand(0).getOperand(0); + return DAG.getNode(ISD::SHL, VT, N0, + DAG.getNode(ISD::AND, TruncVT, + DAG.getNode(ISD::TRUNCATE, TruncVT, N100), + DAG.getConstant(N101C->getZExtValue(), + TruncVT))); } } @@ -2444,21 +2440,17 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { // fold (sra x, (trunc (and y, c))) -> (sra x, (and (trunc y), c)) // iff (trunc c) == c if (N1.getOpcode() == ISD::TRUNCATE && - N1.getOperand(0).getOpcode() == ISD::AND) { + N1.getOperand(0).getOpcode() == ISD::AND && + N1.hasOneUse() && N1.getOperand(0).hasOneUse()) { SDValue N101 = N1.getOperand(0).getOperand(1); - ConstantSDNode *N101C = dyn_cast(N101); - if (N101C) { + if (ConstantSDNode *N101C = dyn_cast(N101)) { MVT TruncVT = N1.getValueType(); - unsigned TruncBitSize = TruncVT.getSizeInBits(); - APInt ShAmt = N101C->getAPIntValue(); - if (ShAmt.trunc(TruncBitSize).getZExtValue() == N101C->getZExtValue()) { - SDValue N100 = N1.getOperand(0).getOperand(0); - return DAG.getNode(ISD::SRA, VT, N0, - DAG.getNode(ISD::AND, TruncVT, - DAG.getNode(ISD::TRUNCATE, TruncVT, N100), - DAG.getConstant(N101C->getZExtValue(), - TruncVT))); - } + SDValue N100 = N1.getOperand(0).getOperand(0); + return DAG.getNode(ISD::SRA, VT, N0, + DAG.getNode(ISD::AND, TruncVT, + DAG.getNode(ISD::TRUNCATE, TruncVT, N100), + DAG.getConstant(N101C->getZExtValue(), + TruncVT))); } } @@ -2565,21 +2557,17 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { // fold (srl x, (trunc (and y, c))) -> (srl x, (and (trunc y), c)) // iff (trunc c) == c if (N1.getOpcode() == ISD::TRUNCATE && - N1.getOperand(0).getOpcode() == ISD::AND) { + N1.getOperand(0).getOpcode() == ISD::AND && + N1.hasOneUse() && N1.getOperand(0).hasOneUse()) { SDValue N101 = N1.getOperand(0).getOperand(1); - ConstantSDNode *N101C = dyn_cast(N101); - if (N101C) { + if (ConstantSDNode *N101C = dyn_cast(N101)) { MVT TruncVT = N1.getValueType(); - unsigned TruncBitSize = TruncVT.getSizeInBits(); - APInt ShAmt = N101C->getAPIntValue(); - if (ShAmt.trunc(TruncBitSize).getZExtValue() == N101C->getZExtValue()) { - SDValue N100 = N1.getOperand(0).getOperand(0); - return DAG.getNode(ISD::SRL, VT, N0, - DAG.getNode(ISD::AND, TruncVT, - DAG.getNode(ISD::TRUNCATE, TruncVT, N100), - DAG.getConstant(N101C->getZExtValue(), - TruncVT))); - } + SDValue N100 = N1.getOperand(0).getOperand(0); + return DAG.getNode(ISD::SRL, VT, N0, + DAG.getNode(ISD::AND, TruncVT, + DAG.getNode(ISD::TRUNCATE, TruncVT, N100), + DAG.getConstant(N101C->getZExtValue(), + TruncVT))); } }