From 94bdb453a40e53e4318380b5a262ce3c324d10ce Mon Sep 17 00:00:00 2001
From: Matt Arsenault
Date: Mon, 17 Mar 2014 18:58:01 +0000
Subject: [PATCH] Make DAGCombiner work on vector bitshifts with constant splat vectors.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@204071 91177308-0d34-0410-b5e6-96231b3b80d8
---
 include/llvm/CodeGen/SelectionDAGNodes.h   |   5 +
 lib/CodeGen/SelectionDAG/DAGCombiner.cpp   | 298 +++++++++++----------
 lib/CodeGen/SelectionDAG/SelectionDAG.cpp  |  13 +
 test/CodeGen/AArch64/neon-shl-ashr-lshr.ll |   6 +-
 test/CodeGen/X86/avx2-vector-shifts.ll     |  27 +-
 test/CodeGen/X86/sse2-vector-shifts.ll     | 152 +++++++++--
 6 files changed, 339 insertions(+), 162 deletions(-)

diff --git a/include/llvm/CodeGen/SelectionDAGNodes.h b/include/llvm/CodeGen/SelectionDAGNodes.h
index 0b18d1d358c..a6c72ca2d1c 100644
--- a/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1522,6 +1522,11 @@ public:
                        unsigned MinSplatBits = 0,
                        bool isBigEndian = false) const;
 
+  /// isConstantSplat - Simpler form of isConstantSplat. Get the constant splat
+  /// when you only care about the value. Returns nullptr if this isn't a
+  /// constant splat vector.
+  ConstantSDNode *isConstantSplat() const;
+
   bool isConstant() const;
 
   static inline bool classof(const SDNode *N) {
diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 842d5a3ee73..c45d6a1a790 100644
--- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -280,7 +280,7 @@ namespace {
     SDValue XformToShuffleWithZero(SDNode *N);
     SDValue ReassociateOps(unsigned Opc, SDLoc DL, SDValue LHS, SDValue RHS);
 
-    SDValue visitShiftByConstant(SDNode *N, unsigned Amt);
+    SDValue visitShiftByConstant(SDNode *N, ConstantSDNode *Amt);
 
     bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS);
     SDValue SimplifyBinOpWithSameOpcodeHands(SDNode *N);
@@ -634,7 +634,23 @@ static bool isOneUseSetCC(SDValue N) {
   return false;
 }
 
-// \brief Returns the SDNode if it is a constant BuildVector or constant int.
+/// isConstantSplatVector - Returns true if N is a BUILD_VECTOR node whose
+/// elements are all the same constant or undefined.
+static bool isConstantSplatVector(SDNode *N, APInt& SplatValue) {
+  BuildVectorSDNode *C = dyn_cast<BuildVectorSDNode>(N);
+  if (!C)
+    return false;
+
+  APInt SplatUndef;
+  unsigned SplatBitSize;
+  bool HasAnyUndefs;
+  EVT EltVT = N->getValueType(0).getVectorElementType();
+  return (C->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
+                             HasAnyUndefs) &&
+          EltVT.getSizeInBits() >= SplatBitSize);
+}
+
+// \brief Returns the SDNode if it is a constant BuildVector or constant.
 static SDNode *isConstantBuildVectorOrConstantInt(SDValue N) {
   if (isa<ConstantSDNode>(N))
     return N.getNode();
@@ -644,6 +660,18 @@ static SDNode *isConstantBuildVectorOrConstantInt(SDValue N) {
   return NULL;
 }
 
+// \brief Returns the SDNode if it is a constant splat BuildVector or constant
+// int.
+static ConstantSDNode *isConstOrConstSplat(SDValue N) {
+  if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N))
+    return CN;
+
+  if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N))
+    return BV->isConstantSplat();
+
+  return nullptr;
+}
+
 SDValue DAGCombiner::ReassociateOps(unsigned Opc, SDLoc DL,
                                     SDValue N0, SDValue N1) {
   EVT VT = N0.getValueType();
@@ -1830,22 +1858,6 @@ SDValue DAGCombiner::visitSUBE(SDNode *N) {
   return SDValue();
 }
 
-/// isConstantSplatVector - Returns true if N is a BUILD_VECTOR node whose
-/// elements are all the same constant or undefined.
-static bool isConstantSplatVector(SDNode *N, APInt& SplatValue) {
-  BuildVectorSDNode *C = dyn_cast<BuildVectorSDNode>(N);
-  if (!C)
-    return false;
-
-  APInt SplatUndef;
-  unsigned SplatBitSize;
-  bool HasAnyUndefs;
-  EVT EltVT = N->getValueType(0).getVectorElementType();
-  return (C->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
-                             HasAnyUndefs) &&
-          EltVT.getSizeInBits() >= SplatBitSize);
-}
-
 SDValue DAGCombiner::visitMUL(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
@@ -3805,11 +3817,9 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
 
 /// visitShiftByConstant - Handle transforms common to the three shifts, when
 /// the shift amount is a constant.
-SDValue DAGCombiner::visitShiftByConstant(SDNode *N, unsigned Amt) {
-  assert(isa<ConstantSDNode>(N->getOperand(1)) &&
-         "Expected an ConstantSDNode operand.");
+SDValue DAGCombiner::visitShiftByConstant(SDNode *N, ConstantSDNode *Amt) {
   // We can't and shouldn't fold opaque constants.
-  if (cast<ConstantSDNode>(N->getOperand(1))->isOpaque())
+  if (Amt->isOpaque())
     return SDValue();
 
   SDNode *LHS = N->getOperand(0).getNode();
@@ -3888,11 +3898,11 @@ SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) {
   if (N->hasOneUse() && N->getOperand(0).hasOneUse()) {
     SDValue N01 = N->getOperand(0).getOperand(1);
 
-    if (ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N01)) {
+    if (ConstantSDNode *N01C = isConstOrConstSplat(N01)) {
       EVT TruncVT = N->getValueType(0);
       SDValue N00 = N->getOperand(0).getOperand(0);
       APInt TruncC = N01C->getAPIntValue();
-      TruncC = TruncC.trunc(TruncVT.getScalarType().getSizeInBits());
+      TruncC = TruncC.trunc(TruncVT.getScalarSizeInBits());
       return DAG.getNode(ISD::AND, SDLoc(N), TruncVT,
                          DAG.getNode(ISD::TRUNCATE, SDLoc(N), TruncVT, N00),
@@ -3921,7 +3931,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
   EVT VT = N0.getValueType();
-  unsigned OpSizeInBits = VT.getScalarType().getSizeInBits();
+  unsigned OpSizeInBits = VT.getScalarSizeInBits();
 
   // fold vector ops
   if (VT.isVector()) {
@@ -3931,18 +3941,21 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
     BuildVectorSDNode *N1CV = dyn_cast<BuildVectorSDNode>(N1);
     // If setcc produces all-one true value then:
     // (shl (and (setcc) N01CV) N1CV) -> (and (setcc) N01CV<<N1CV)
-    if (N1CV && N1CV->isConstant() &&
-        TLI.getBooleanContents(true) ==
-        TargetLowering::ZeroOrNegativeOneBooleanContent &&
-        N0.getOpcode() == ISD::AND) {
-      SDValue N00 = N0->getOperand(0);
-      SDValue N01 = N0->getOperand(1);
-      BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(N01);
-
-      if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC) {
-        SDValue C = DAG.FoldConstantArithmetic(ISD::SHL, VT, N01CV, N1CV);
-        if (C.getNode())
-          return DAG.getNode(ISD::AND, SDLoc(N), VT, N00, C);
+    if (N1CV && N1CV->isConstant()) {
+      if (N0.getOpcode() == ISD::AND &&
+          TLI.getBooleanContents(true) ==
+          TargetLowering::ZeroOrNegativeOneBooleanContent) {
+        SDValue N00 = N0->getOperand(0);
+        SDValue N01 = N0->getOperand(1);
+        BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(N01);
+
+        if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC) {
+          SDValue C = DAG.FoldConstantArithmetic(ISD::SHL, VT, N01CV, N1CV);
+          if (C.getNode())
+            return DAG.getNode(ISD::AND, SDLoc(N), VT, N00, C);
+        }
+      } else {
+        N1C = isConstOrConstSplat(N1);
+      }
     }
   }
 
@@ -3978,14 +3991,15 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
     return SDValue(N, 0);
 
   // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
-  if (N1C && N0.getOpcode() == ISD::SHL &&
-      N0.getOperand(1).getOpcode() == ISD::Constant) {
-    uint64_t c1 = cast<ConstantSDNode>(N0.getOperand(1))->getZExtValue();
-    uint64_t c2 = 
N1C->getZExtValue(); - if (c1 + c2 >= OpSizeInBits) - return DAG.getConstant(0, VT); - return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0.getOperand(0), - DAG.getConstant(c1 + c2, N1.getValueType())); + if (N1C && N0.getOpcode() == ISD::SHL) { + if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) { + uint64_t c1 = N0C1->getZExtValue(); + uint64_t c2 = N1C->getZExtValue(); + if (c1 + c2 >= OpSizeInBits) + return DAG.getConstant(0, VT); + return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0.getOperand(0), + DAG.getConstant(c1 + c2, N1.getValueType())); + } } // fold (shl (ext (shl x, c1)), c2) -> (ext (shl x, (add c1, c2))) @@ -3996,20 +4010,21 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { if (N1C && (N0.getOpcode() == ISD::ZERO_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND || N0.getOpcode() == ISD::SIGN_EXTEND) && - N0.getOperand(0).getOpcode() == ISD::SHL && - isa(N0.getOperand(0)->getOperand(1))) { - uint64_t c1 = - cast(N0.getOperand(0)->getOperand(1))->getZExtValue(); - uint64_t c2 = N1C->getZExtValue(); - EVT InnerShiftVT = N0.getOperand(0).getValueType(); - uint64_t InnerShiftSize = InnerShiftVT.getScalarType().getSizeInBits(); - if (c2 >= OpSizeInBits - InnerShiftSize) { - if (c1 + c2 >= OpSizeInBits) - return DAG.getConstant(0, VT); - return DAG.getNode(ISD::SHL, SDLoc(N0), VT, - DAG.getNode(N0.getOpcode(), SDLoc(N0), VT, - N0.getOperand(0)->getOperand(0)), - DAG.getConstant(c1 + c2, N1.getValueType())); + N0.getOperand(0).getOpcode() == ISD::SHL) { + SDValue N0Op0 = N0.getOperand(0); + if (ConstantSDNode *N0Op0C1 = isConstOrConstSplat(N0Op0.getOperand(1))) { + uint64_t c1 = N0Op0C1->getZExtValue(); + uint64_t c2 = N1C->getZExtValue(); + EVT InnerShiftVT = N0Op0.getValueType(); + uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits(); + if (c2 >= OpSizeInBits - InnerShiftSize) { + if (c1 + c2 >= OpSizeInBits) + return DAG.getConstant(0, VT); + return DAG.getNode(ISD::SHL, SDLoc(N0), VT, + DAG.getNode(N0.getOpcode(), SDLoc(N0), VT, + N0Op0->getOperand(0)), + DAG.getConstant(c1 + c2, N1.getValueType())); + } } } @@ -4017,19 +4032,20 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { // Only fold this if the inner zext has no other uses to avoid increasing // the total number of instructions. 
if (N1C && N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() && - N0.getOperand(0).getOpcode() == ISD::SRL && - isa(N0.getOperand(0)->getOperand(1))) { - uint64_t c1 = - cast(N0.getOperand(0)->getOperand(1))->getZExtValue(); - if (c1 < VT.getSizeInBits()) { - uint64_t c2 = N1C->getZExtValue(); - if (c1 == c2) { - SDValue NewOp0 = N0.getOperand(0); - EVT CountVT = NewOp0.getOperand(1).getValueType(); - SDValue NewSHL = DAG.getNode(ISD::SHL, SDLoc(N), NewOp0.getValueType(), - NewOp0, DAG.getConstant(c2, CountVT)); - AddToWorkList(NewSHL.getNode()); - return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N0), VT, NewSHL); + N0.getOperand(0).getOpcode() == ISD::SRL) { + SDValue N0Op0 = N0.getOperand(0); + if (ConstantSDNode *N0Op0C1 = isConstOrConstSplat(N0Op0.getOperand(1))) { + uint64_t c1 = N0Op0C1->getZExtValue(); + if (c1 < VT.getScalarSizeInBits()) { + uint64_t c2 = N1C->getZExtValue(); + if (c1 == c2) { + SDValue NewOp0 = N0.getOperand(0); + EVT CountVT = NewOp0.getOperand(1).getValueType(); + SDValue NewSHL = DAG.getNode(ISD::SHL, SDLoc(N), NewOp0.getValueType(), + NewOp0, DAG.getConstant(c2, CountVT)); + AddToWorkList(NewSHL.getNode()); + return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N0), VT, NewSHL); + } } } } @@ -4038,40 +4054,39 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { // (and (srl x, (sub c1, c2), MASK) // Only fold this if the inner shift has no other uses -- if it does, folding // this will increase the total number of instructions. - if (N1C && N0.getOpcode() == ISD::SRL && N0.hasOneUse() && - N0.getOperand(1).getOpcode() == ISD::Constant) { - uint64_t c1 = cast(N0.getOperand(1))->getZExtValue(); - if (c1 < VT.getSizeInBits()) { - uint64_t c2 = N1C->getZExtValue(); - APInt Mask = APInt::getHighBitsSet(VT.getSizeInBits(), - VT.getSizeInBits() - c1); - SDValue Shift; - if (c2 > c1) { - Mask = Mask.shl(c2-c1); - Shift = DAG.getNode(ISD::SHL, SDLoc(N), VT, N0.getOperand(0), - DAG.getConstant(c2-c1, N1.getValueType())); - } else { - Mask = Mask.lshr(c1-c2); - Shift = DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), - DAG.getConstant(c1-c2, N1.getValueType())); + if (N1C && N0.getOpcode() == ISD::SRL && N0.hasOneUse()) { + if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) { + uint64_t c1 = N0C1->getZExtValue(); + if (c1 < OpSizeInBits) { + uint64_t c2 = N1C->getZExtValue(); + APInt Mask = APInt::getHighBitsSet(OpSizeInBits, OpSizeInBits - c1); + SDValue Shift; + if (c2 > c1) { + Mask = Mask.shl(c2 - c1); + Shift = DAG.getNode(ISD::SHL, SDLoc(N), VT, N0.getOperand(0), + DAG.getConstant(c2 - c1, N1.getValueType())); + } else { + Mask = Mask.lshr(c1 - c2); + Shift = DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), + DAG.getConstant(c1 - c2, N1.getValueType())); + } + return DAG.getNode(ISD::AND, SDLoc(N0), VT, Shift, + DAG.getConstant(Mask, VT)); } - return DAG.getNode(ISD::AND, SDLoc(N0), VT, Shift, - DAG.getConstant(Mask, VT)); } } // fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1)) if (N1C && N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(1)) { + unsigned BitSize = VT.getScalarSizeInBits(); SDValue HiBitsMask = - DAG.getConstant(APInt::getHighBitsSet(VT.getSizeInBits(), - VT.getSizeInBits() - - N1C->getZExtValue()), - VT); + DAG.getConstant(APInt::getHighBitsSet(BitSize, + BitSize - N1C->getZExtValue()), VT); return DAG.getNode(ISD::AND, SDLoc(N), VT, N0.getOperand(0), HiBitsMask); } if (N1C) { - SDValue NewSHL = visitShiftByConstant(N, N1C->getZExtValue()); + SDValue NewSHL = visitShiftByConstant(N, N1C); if (NewSHL.getNode()) return NewSHL; } @@ -4091,6 
+4106,8 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { if (VT.isVector()) { SDValue FoldedVOp = SimplifyVBinOp(N); if (FoldedVOp.getNode()) return FoldedVOp; + + N1C = isConstOrConstSplat(N1); } // fold (sra c1, c2) -> (sra c1, c2) @@ -4124,11 +4141,12 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2)) if (N1C && N0.getOpcode() == ISD::SRA) { - if (ConstantSDNode *C1 = dyn_cast(N0.getOperand(1))) { + if (ConstantSDNode *C1 = isConstOrConstSplat(N0.getOperand(1))) { unsigned Sum = N1C->getZExtValue() + C1->getZExtValue(); - if (Sum >= OpSizeInBits) Sum = OpSizeInBits-1; + if (Sum >= OpSizeInBits) + Sum = OpSizeInBits - 1; return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0.getOperand(0), - DAG.getConstant(Sum, N1C->getValueType(0))); + DAG.getConstant(Sum, N1.getValueType())); } } @@ -4137,14 +4155,17 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { // result_size - n != m. // If truncate is free for the target sext(shl) is likely to result in better // code. - if (N0.getOpcode() == ISD::SHL) { + if (N0.getOpcode() == ISD::SHL && N1C) { // Get the two constanst of the shifts, CN0 = m, CN = n. - const ConstantSDNode *N01C = dyn_cast(N0.getOperand(1)); - if (N01C && N1C) { + const ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1)); + if (N01C) { + LLVMContext &Ctx = *DAG.getContext(); // Determine what the truncate's result bitsize and type would be. - EVT TruncVT = - EVT::getIntegerVT(*DAG.getContext(), - OpSizeInBits - N1C->getZExtValue()); + EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - N1C->getZExtValue()); + + if (VT.isVector()) + TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorNumElements()); + // Determine the residual right-shift amount. signed ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue(); @@ -4177,26 +4198,27 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0, NewOp1); } - // fold (sra (trunc (sr x, c1)), c2) -> (trunc (sra x, c1+c2)) + // fold (sra (trunc (srl x, c1)), c2) -> (trunc (sra x, c1 + c2)) // if c1 is equal to the number of bits the trunc removes if (N0.getOpcode() == ISD::TRUNCATE && (N0.getOperand(0).getOpcode() == ISD::SRL || N0.getOperand(0).getOpcode() == ISD::SRA) && N0.getOperand(0).hasOneUse() && N0.getOperand(0).getOperand(1).hasOneUse() && - N1C && isa(N0.getOperand(0).getOperand(1))) { - EVT LargeVT = N0.getOperand(0).getValueType(); - ConstantSDNode *LargeShiftAmt = - cast(N0.getOperand(0).getOperand(1)); + N1C) { + SDValue N0Op0 = N0.getOperand(0); + if (ConstantSDNode *LargeShift = isConstOrConstSplat(N0Op0.getOperand(1))) { + unsigned LargeShiftVal = LargeShift->getZExtValue(); + EVT LargeVT = N0Op0.getValueType(); - if (LargeVT.getScalarType().getSizeInBits() - OpSizeInBits == - LargeShiftAmt->getZExtValue()) { - SDValue Amt = - DAG.getConstant(LargeShiftAmt->getZExtValue() + N1C->getZExtValue(), - getShiftAmountTy(N0.getOperand(0).getOperand(0).getValueType())); - SDValue SRA = DAG.getNode(ISD::SRA, SDLoc(N), LargeVT, - N0.getOperand(0).getOperand(0), Amt); - return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, SRA); + if (LargeVT.getScalarSizeInBits() - OpSizeInBits == LargeShiftVal) { + SDValue Amt = + DAG.getConstant(LargeShiftVal + N1C->getZExtValue(), + getShiftAmountTy(N0Op0.getOperand(0).getValueType())); + SDValue SRA = DAG.getNode(ISD::SRA, SDLoc(N), LargeVT, + N0Op0.getOperand(0), Amt); + return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, SRA); + } } } @@ -4210,7 +4232,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { return DAG.getNode(ISD::SRL, 
SDLoc(N), VT, N0, N1); if (N1C) { - SDValue NewSRA = visitShiftByConstant(N, N1C->getZExtValue()); + SDValue NewSRA = visitShiftByConstant(N, N1C); if (NewSRA.getNode()) return NewSRA; } @@ -4230,6 +4252,8 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { if (VT.isVector()) { SDValue FoldedVOp = SimplifyVBinOp(N); if (FoldedVOp.getNode()) return FoldedVOp; + + N1C = isConstOrConstSplat(N1); } // fold (srl c1, c2) -> c1 >>u c2 @@ -4250,14 +4274,15 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { return DAG.getConstant(0, VT); // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2)) - if (N1C && N0.getOpcode() == ISD::SRL && - N0.getOperand(1).getOpcode() == ISD::Constant) { - uint64_t c1 = cast(N0.getOperand(1))->getZExtValue(); - uint64_t c2 = N1C->getZExtValue(); - if (c1 + c2 >= OpSizeInBits) - return DAG.getConstant(0, VT); - return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), - DAG.getConstant(c1 + c2, N1.getValueType())); + if (N1C && N0.getOpcode() == ISD::SRL) { + if (ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1))) { + uint64_t c1 = N01C->getZExtValue(); + uint64_t c2 = N1C->getZExtValue(); + if (c1 + c2 >= OpSizeInBits) + return DAG.getConstant(0, VT); + return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), + DAG.getConstant(c1 + c2, N1.getValueType())); + } } // fold (srl (trunc (srl x, c1)), c2) -> 0 or (trunc (srl x, (add c1, c2))) @@ -4282,18 +4307,21 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { } // fold (srl (shl x, c), c) -> (and x, cst2) - if (N1C && N0.getOpcode() == ISD::SHL && N0.getOperand(1) == N1 && - N0.getValueSizeInBits() <= 64) { - uint64_t ShAmt = N1C->getZExtValue()+64-N0.getValueSizeInBits(); - return DAG.getNode(ISD::AND, SDLoc(N), VT, N0.getOperand(0), - DAG.getConstant(~0ULL >> ShAmt, VT)); + if (N1C && N0.getOpcode() == ISD::SHL && N0.getOperand(1) == N1) { + unsigned BitSize = N0.getScalarValueSizeInBits(); + if (BitSize <= 64) { + uint64_t ShAmt = N1C->getZExtValue() + 64 - BitSize; + return DAG.getNode(ISD::AND, SDLoc(N), VT, N0.getOperand(0), + DAG.getConstant(~0ULL >> ShAmt, VT)); + } } // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask) if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) { // Shifting in all undef bits? EVT SmallVT = N0.getOperand(0).getValueType(); - if (N1C->getZExtValue() >= SmallVT.getSizeInBits()) + unsigned BitSize = SmallVT.getScalarSizeInBits(); + if (N1C->getZExtValue() >= BitSize) return DAG.getUNDEF(VT); if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, SmallVT)) { @@ -4302,7 +4330,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { N0.getOperand(0), DAG.getConstant(ShiftAmt, getShiftAmountTy(SmallVT))); AddToWorkList(SmallShift.getNode()); - APInt Mask = APInt::getAllOnesValue(VT.getSizeInBits()).lshr(ShiftAmt); + APInt Mask = APInt::getAllOnesValue(OpSizeInBits).lshr(ShiftAmt); return DAG.getNode(ISD::AND, SDLoc(N), VT, DAG.getNode(ISD::ANY_EXTEND, SDLoc(N), VT, SmallShift), DAG.getConstant(Mask, VT)); @@ -4311,14 +4339,14 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { // fold (srl (sra X, Y), 31) -> (srl X, 31). This srl only looks at the sign // bit, which is unmodified by sra. - if (N1C && N1C->getZExtValue() + 1 == VT.getSizeInBits()) { + if (N1C && N1C->getZExtValue() + 1 == OpSizeInBits) { if (N0.getOpcode() == ISD::SRA) return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), N1); } // fold (srl (ctlz x), "5") -> x iff x has one bit set (the low bit). 
if (N1C && N0.getOpcode() == ISD::CTLZ && - N1C->getAPIntValue() == Log2_32(VT.getSizeInBits())) { + N1C->getAPIntValue() == Log2_32(OpSizeInBits)) { APInt KnownZero, KnownOne; DAG.ComputeMaskedBits(N0.getOperand(0), KnownZero, KnownOne); @@ -4365,7 +4393,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { return SDValue(N, 0); if (N1C) { - SDValue NewSRL = visitShiftByConstant(N, N1C->getZExtValue()); + SDValue NewSRL = visitShiftByConstant(N, N1C); if (NewSRL.getNode()) return NewSRL; } diff --git a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 43a02fe9c7e..df8d423ab22 100644 --- a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -6573,6 +6573,19 @@ bool BuildVectorSDNode::isConstantSplat(APInt &SplatValue, return true; } +ConstantSDNode *BuildVectorSDNode::isConstantSplat() const { + SDValue Op0 = getOperand(0); + for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { + SDValue Opi = getOperand(i); + unsigned Opc = Opi.getOpcode(); + if ((Opc != ISD::UNDEF && Opc != ISD::Constant && Opc != ISD::ConstantFP) || + Opi != Op0) + return nullptr; + } + + return cast(Op0); +} + bool BuildVectorSDNode::isConstant() const { for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { unsigned Opc = getOperand(i).getOpcode(); diff --git a/test/CodeGen/AArch64/neon-shl-ashr-lshr.ll b/test/CodeGen/AArch64/neon-shl-ashr-lshr.ll index bd52fbde42b..0b520d7ac84 100644 --- a/test/CodeGen/AArch64/neon-shl-ashr-lshr.ll +++ b/test/CodeGen/AArch64/neon-shl-ashr-lshr.ll @@ -186,14 +186,14 @@ define <2 x i64> @ashr.v2i64(<2 x i64> %a, <2 x i64> %b) { define <1 x i64> @shl.v1i64.0(<1 x i64> %a) { ; CHECK-LABEL: shl.v1i64.0: -; CHECK: shl d{{[0-9]+}}, d{{[0-9]+}}, #0 +; CHECK-NOT: shl d{{[0-9]+}}, d{{[0-9]+}}, #0 %c = shl <1 x i64> %a, zeroinitializer ret <1 x i64> %c } define <2 x i32> @shl.v2i32.0(<2 x i32> %a) { ; CHECK-LABEL: shl.v2i32.0: -; CHECK: shl v{{[0-9]+}}.2s, v{{[0-9]+}}.2s, #0 +; CHECK-NOT: shl v{{[0-9]+}}.2s, v{{[0-9]+}}.2s, #0 %c = shl <2 x i32> %a, zeroinitializer ret <2 x i32> %c } @@ -285,7 +285,7 @@ define <1 x i16> @shl.v1i16.imm(<1 x i16> %a) { define <1 x i32> @shl.v1i32.imm(<1 x i32> %a) { ; CHECK-LABEL: shl.v1i32.imm: -; CHECK: shl v{{[0-9]+}}.2s, v{{[0-9]+}}.2s, #0 +; CHECK-NOT: shl v{{[0-9]+}}.2s, v{{[0-9]+}}.2s, #0 %c = shl <1 x i32> %a, zeroinitializer ret <1 x i32> %c } diff --git a/test/CodeGen/X86/avx2-vector-shifts.ll b/test/CodeGen/X86/avx2-vector-shifts.ll index 4868e4b4797..4ae2905ef22 100644 --- a/test/CodeGen/X86/avx2-vector-shifts.ll +++ b/test/CodeGen/X86/avx2-vector-shifts.ll @@ -9,7 +9,7 @@ entry: } ; CHECK-LABEL: test_sllw_1: -; CHECK: vpsllw $0, %ymm0, %ymm0 +; CHECK-NOT: vpsllw $0, %ymm0, %ymm0 ; CHECK: ret define <16 x i16> @test_sllw_2(<16 x i16> %InVec) { @@ -39,7 +39,7 @@ entry: } ; CHECK-LABEL: test_slld_1: -; CHECK: vpslld $0, %ymm0, %ymm0 +; CHECK-NOT: vpslld $0, %ymm0, %ymm0 ; CHECK: ret define <8 x i32> @test_slld_2(<8 x i32> %InVec) { @@ -69,7 +69,7 @@ entry: } ; CHECK-LABEL: test_sllq_1: -; CHECK: vpsllq $0, %ymm0, %ymm0 +; CHECK-NOT: vpsllq $0, %ymm0, %ymm0 ; CHECK: ret define <4 x i64> @test_sllq_2(<4 x i64> %InVec) { @@ -101,7 +101,7 @@ entry: } ; CHECK-LABEL: test_sraw_1: -; CHECK: vpsraw $0, %ymm0, %ymm0 +; CHECK-NOT: vpsraw $0, %ymm0, %ymm0 ; CHECK: ret define <16 x i16> @test_sraw_2(<16 x i16> %InVec) { @@ -131,7 +131,7 @@ entry: } ; CHECK-LABEL: test_srad_1: -; CHECK: vpsrad $0, %ymm0, %ymm0 +; CHECK-NOT: vpsrad $0, %ymm0, %ymm0 ; CHECK: ret define <8 x i32> 
@test_srad_2(<8 x i32> %InVec) { @@ -163,7 +163,7 @@ entry: } ; CHECK-LABEL: test_srlw_1: -; CHECK: vpsrlw $0, %ymm0, %ymm0 +; CHECK-NOT: vpsrlw $0, %ymm0, %ymm0 ; CHECK: ret define <16 x i16> @test_srlw_2(<16 x i16> %InVec) { @@ -193,7 +193,7 @@ entry: } ; CHECK-LABEL: test_srld_1: -; CHECK: vpsrld $0, %ymm0, %ymm0 +; CHECK-NOT: vpsrld $0, %ymm0, %ymm0 ; CHECK: ret define <8 x i32> @test_srld_2(<8 x i32> %InVec) { @@ -223,7 +223,7 @@ entry: } ; CHECK-LABEL: test_srlq_1: -; CHECK: vpsrlq $0, %ymm0, %ymm0 +; CHECK-NOT: vpsrlq $0, %ymm0, %ymm0 ; CHECK: ret define <4 x i64> @test_srlq_2(<4 x i64> %InVec) { @@ -245,3 +245,14 @@ entry: ; CHECK-LABEL: test_srlq_3: ; CHECK: vpsrlq $63, %ymm0, %ymm0 ; CHECK: ret + +; CHECK-LABEL: @srl_trunc_and_v4i64 +; CHECK: vpand +; CHECK-NEXT: vpsrlvd +; CHECK: ret +define <4 x i32> @srl_trunc_and_v4i64(<4 x i32> %x, <4 x i64> %y) nounwind { + %and = and <4 x i64> %y, + %trunc = trunc <4 x i64> %and to <4 x i32> + %sra = lshr <4 x i32> %x, %trunc + ret <4 x i32> %sra +} diff --git a/test/CodeGen/X86/sse2-vector-shifts.ll b/test/CodeGen/X86/sse2-vector-shifts.ll index 47a01ff2583..7c8d5e57889 100644 --- a/test/CodeGen/X86/sse2-vector-shifts.ll +++ b/test/CodeGen/X86/sse2-vector-shifts.ll @@ -9,8 +9,8 @@ entry: } ; CHECK-LABEL: test_sllw_1: -; CHECK: psllw $0, %xmm0 -; CHECK-NEXT: ret +; CHECK-NOT: psllw $0, %xmm0 +; CHECK: ret define <8 x i16> @test_sllw_2(<8 x i16> %InVec) { entry: @@ -39,8 +39,8 @@ entry: } ; CHECK-LABEL: test_slld_1: -; CHECK: pslld $0, %xmm0 -; CHECK-NEXT: ret +; CHECK-NOT: pslld $0, %xmm0 +; CHECK: ret define <4 x i32> @test_slld_2(<4 x i32> %InVec) { entry: @@ -69,8 +69,8 @@ entry: } ; CHECK-LABEL: test_sllq_1: -; CHECK: psllq $0, %xmm0 -; CHECK-NEXT: ret +; CHECK-NOT: psllq $0, %xmm0 +; CHECK: ret define <2 x i64> @test_sllq_2(<2 x i64> %InVec) { entry: @@ -101,8 +101,8 @@ entry: } ; CHECK-LABEL: test_sraw_1: -; CHECK: psraw $0, %xmm0 -; CHECK-NEXT: ret +; CHECK-NOT: psraw $0, %xmm0 +; CHECK: ret define <8 x i16> @test_sraw_2(<8 x i16> %InVec) { entry: @@ -131,8 +131,8 @@ entry: } ; CHECK-LABEL: test_srad_1: -; CHECK: psrad $0, %xmm0 -; CHECK-NEXT: ret +; CHECK-NOT: psrad $0, %xmm0 +; CHECK: ret define <4 x i32> @test_srad_2(<4 x i32> %InVec) { entry: @@ -163,8 +163,8 @@ entry: } ; CHECK-LABEL: test_srlw_1: -; CHECK: psrlw $0, %xmm0 -; CHECK-NEXT: ret +; CHECK-NOT: psrlw $0, %xmm0 +; CHECK: ret define <8 x i16> @test_srlw_2(<8 x i16> %InVec) { entry: @@ -193,8 +193,8 @@ entry: } ; CHECK-LABEL: test_srld_1: -; CHECK: psrld $0, %xmm0 -; CHECK-NEXT: ret +; CHECK-NOT: psrld $0, %xmm0 +; CHECK: ret define <4 x i32> @test_srld_2(<4 x i32> %InVec) { entry: @@ -223,8 +223,8 @@ entry: } ; CHECK-LABEL: test_srlq_1: -; CHECK: psrlq $0, %xmm0 -; CHECK-NEXT: ret +; CHECK-NOT: psrlq $0, %xmm0 +; CHECK: ret define <2 x i64> @test_srlq_2(<2 x i64> %InVec) { entry: @@ -245,3 +245,123 @@ entry: ; CHECK-LABEL: test_srlq_3: ; CHECK: psrlq $63, %xmm0 ; CHECK-NEXT: ret + + +; CHECK-LABEL: sra_sra_v4i32: +; CHECK: psrad $6, %xmm0 +; CHECK-NEXT: retq +define <4 x i32> @sra_sra_v4i32(<4 x i32> %x) nounwind { + %sra0 = ashr <4 x i32> %x, + %sra1 = ashr <4 x i32> %sra0, + ret <4 x i32> %sra1 +} + +; CHECK-LABEL: @srl_srl_v4i32 +; CHECK: psrld $6, %xmm0 +; CHECK-NEXT: ret +define <4 x i32> @srl_srl_v4i32(<4 x i32> %x) nounwind { + %srl0 = lshr <4 x i32> %x, + %srl1 = lshr <4 x i32> %srl0, + ret <4 x i32> %srl1 +} + +; CHECK-LABEL: @srl_shl_v4i32 +; CHECK: andps +; CHECK-NEXT: retq +define <4 x i32> @srl_shl_v4i32(<4 x i32> %x) nounwind { + %srl0 = shl <4 x i32> %x, 
+ %srl1 = lshr <4 x i32> %srl0, + ret <4 x i32> %srl1 +} + +; CHECK-LABEL: @srl_sra_31_v4i32 +; CHECK: psrld $31, %xmm0 +; CHECK-NEXT: ret +define <4 x i32> @srl_sra_31_v4i32(<4 x i32> %x, <4 x i32> %y) nounwind { + %sra = ashr <4 x i32> %x, %y + %srl1 = lshr <4 x i32> %sra, + ret <4 x i32> %srl1 +} + +; CHECK-LABEL: @shl_shl_v4i32 +; CHECK: pslld $6, %xmm0 +; CHECK-NEXT: ret +define <4 x i32> @shl_shl_v4i32(<4 x i32> %x) nounwind { + %shl0 = shl <4 x i32> %x, + %shl1 = shl <4 x i32> %shl0, + ret <4 x i32> %shl1 +} + +; CHECK-LABEL: @shl_sra_v4i32 +; CHECK: andps +; CHECK-NEXT: ret +define <4 x i32> @shl_sra_v4i32(<4 x i32> %x) nounwind { + %shl0 = ashr <4 x i32> %x, + %shl1 = shl <4 x i32> %shl0, + ret <4 x i32> %shl1 +} + +; CHECK-LABEL: @shl_srl_v4i32 +; CHECK: pslld $3, %xmm0 +; CHECK-NEXT: pand +; CHECK-NEXT: ret +define <4 x i32> @shl_srl_v4i32(<4 x i32> %x) nounwind { + %shl0 = lshr <4 x i32> %x, + %shl1 = shl <4 x i32> %shl0, + ret <4 x i32> %shl1 +} + +; CHECK-LABEL: @shl_zext_srl_v4i32 +; CHECK: andps +; CHECK-NEXT: ret +define <4 x i32> @shl_zext_srl_v4i32(<4 x i16> %x) nounwind { + %srl = lshr <4 x i16> %x, + %zext = zext <4 x i16> %srl to <4 x i32> + %shl = shl <4 x i32> %zext, + ret <4 x i32> %shl +} + +; CHECK: @sra_trunc_srl_v4i32 +; CHECK: psrad $19, %xmm0 +; CHECK-NEXT: retq +define <4 x i16> @sra_trunc_srl_v4i32(<4 x i32> %x) nounwind { + %srl = lshr <4 x i32> %x, + %trunc = trunc <4 x i32> %srl to <4 x i16> + %sra = ashr <4 x i16> %trunc, + ret <4 x i16> %sra +} + +; CHECK-LABEL: @shl_zext_shl_v4i32 +; CHECK: pand +; CHECK-NEXT: pslld $19, %xmm0 +; CHECK-NEXT: ret +define <4 x i32> @shl_zext_shl_v4i32(<4 x i16> %x) nounwind { + %shl0 = shl <4 x i16> %x, + %ext = zext <4 x i16> %shl0 to <4 x i32> + %shl1 = shl <4 x i32> %ext, + ret <4 x i32> %shl1 +} + +; CHECK-LABEL: @sra_v4i32 +; CHECK: psrad $3, %xmm0 +; CHECK-NEXT: ret +define <4 x i32> @sra_v4i32(<4 x i32> %x) nounwind { + %sra = ashr <4 x i32> %x, + ret <4 x i32> %sra +} + +; CHECK-LABEL: @srl_v4i32 +; CHECK: psrld $3, %xmm0 +; CHECK-NEXT: ret +define <4 x i32> @srl_v4i32(<4 x i32> %x) nounwind { + %sra = lshr <4 x i32> %x, + ret <4 x i32> %sra +} + +; CHECK-LABEL: @shl_v4i32 +; CHECK: pslld $3, %xmm0 +; CHECK-NEXT: ret +define <4 x i32> @shl_v4i32(<4 x i32> %x) nounwind { + %sra = shl <4 x i32> %x, + ret <4 x i32> %sra +}
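
Illustrative follow-on example (not part of the committed patch): the srl+srl combine above now also fires when both shift amounts are constant splat vectors of a wider element type. In the style of the sse2-vector-shifts.ll tests, and assuming the same SSE2-capable x86-64 RUN configuration as the rest of that file, the two splat lshr operations below should fold into a single shift by 5, which SSE2 lowers to one psrlw instruction:

; CHECK-LABEL: srl_srl_v8i16:
; CHECK: psrlw $5, %xmm0
; CHECK-NEXT: ret
define <8 x i16> @srl_srl_v8i16(<8 x i16> %x) nounwind {
  %s0 = lshr <8 x i16> %x, <i16 2, i16 2, i16 2, i16 2, i16 2, i16 2, i16 2, i16 2>
  %s1 = lshr <8 x i16> %s0, <i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3>
  ret <8 x i16> %s1
}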