diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index a2ae6095f50..8edd5a4df6c 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -806,6 +806,9 @@ X86TargetLowering::X86TargetLowering(X86TargetMachine &TM)
   setTargetDAGCombine(ISD::VECTOR_SHUFFLE);
   setTargetDAGCombine(ISD::BUILD_VECTOR);
   setTargetDAGCombine(ISD::SELECT);
+  setTargetDAGCombine(ISD::SHL);
+  setTargetDAGCombine(ISD::SRA);
+  setTargetDAGCombine(ISD::SRL);
   setTargetDAGCombine(ISD::STORE);
 
   computeRegisterProperties();
@@ -7654,6 +7657,93 @@ static SDValue PerformSELECTCombine(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+/// PerformShiftCombine - Transforms vector shift nodes to use vector shifts
+/// when possible.
+static SDValue PerformShiftCombine(SDNode* N, SelectionDAG &DAG,
+                                   const X86Subtarget *Subtarget) {
+  // On X86 with SSE2 support, we can transform this to a vector shift if
+  // all elements are shifted by the same amount. We can't do this in legalize
+  // because a constant vector is typically transformed to a constant pool
+  // so we have no knowledge of the shift amount.
+  MVT VT = N->getValueType(0);
+  if (Subtarget->hasSSE2() &&
+      (VT == MVT::v2i64 || VT == MVT::v4i32 || VT == MVT::v8i16)) {
+    SDValue ValOp = N->getOperand(0);
+    SDValue ShAmtOp = N->getOperand(1);
+    unsigned NumElts = VT.getVectorNumElements();
+
+    if (ShAmtOp.getOpcode() == ISD::BUILD_VECTOR) {
+      unsigned i = 0;
+      SDValue BaseShAmt;
+      for (; i != NumElts; ++i) {
+        SDValue Arg = ShAmtOp.getOperand(i);
+        if (Arg.getOpcode() == ISD::UNDEF) continue;
+        BaseShAmt = Arg;
+        break;
+      }
+      for (; i != NumElts; ++i) {
+        SDValue Arg = ShAmtOp.getOperand(i);
+        if (Arg.getOpcode() == ISD::UNDEF) continue;
+        if (Arg != BaseShAmt) {
+          return SDValue();
+        }
+      }
+
+      MVT EltVT = VT.getVectorElementType();
+      if (EltVT.bitsGT(MVT::i32))
+        BaseShAmt = DAG.getNode(ISD::TRUNCATE, MVT::i32, BaseShAmt);
+      else if (EltVT.bitsLT(MVT::i32))
+        BaseShAmt = DAG.getNode(ISD::ANY_EXTEND, MVT::i32, BaseShAmt);
+
+      // The shift amount is identical so we can do a vector shift.
+      switch (N->getOpcode()) {
+      default:
+        assert(0 && "Unknown shift opcode!");
+        break;
+      case ISD::SHL:
+        if (VT == MVT::v2i64)
+          return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, VT,
+                             DAG.getConstant(Intrinsic::x86_sse2_pslli_q, MVT::i32),
+                             ValOp, BaseShAmt);
+        else if (VT == MVT::v4i32)
+          return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, VT,
+                             DAG.getConstant(Intrinsic::x86_sse2_pslli_d, MVT::i32),
+                             ValOp, BaseShAmt);
+        else if (VT == MVT::v8i16)
+          return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, VT,
+                             DAG.getConstant(Intrinsic::x86_sse2_pslli_w, MVT::i32),
+                             ValOp, BaseShAmt);
+        break;
+      case ISD::SRA:
+        if (VT == MVT::v4i32)
+          return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, VT,
+                             DAG.getConstant(Intrinsic::x86_sse2_psrai_d, MVT::i32),
+                             ValOp, BaseShAmt);
+        else if (VT == MVT::v8i16)
+          return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, VT,
+                             DAG.getConstant(Intrinsic::x86_sse2_psrai_w, MVT::i32),
+                             ValOp, BaseShAmt);
+        break;
+      case ISD::SRL:
+        if (VT == MVT::v2i64)
+          return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, VT,
+                             DAG.getConstant(Intrinsic::x86_sse2_psrli_q, MVT::i32),
+                             ValOp, BaseShAmt);
+        else if (VT == MVT::v4i32)
+          return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, VT,
+                             DAG.getConstant(Intrinsic::x86_sse2_psrli_d, MVT::i32),
+                             ValOp, BaseShAmt);
+        else if (VT == MVT::v8i16)
+          return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, VT,
+                             DAG.getConstant(Intrinsic::x86_sse2_psrli_w, MVT::i32),
+                             ValOp, BaseShAmt);
+        break;
+      }
+    }
+  }
+  return SDValue();
+}
+
 /// PerformSTORECombine - Do target-specific dag combines on STORE nodes.
 static SDValue PerformSTORECombine(SDNode *N, SelectionDAG &DAG,
                                    const X86Subtarget *Subtarget) {
@@ -7782,6 +7872,9 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::BUILD_VECTOR:
     return PerformBuildVectorCombine(N, DAG, Subtarget, *this);
   case ISD::SELECT:     return PerformSELECTCombine(N, DAG, Subtarget);
+  case ISD::SHL:
+  case ISD::SRA:
+  case ISD::SRL:        return PerformShiftCombine(N, DAG, Subtarget);
   case ISD::STORE:      return PerformSTORECombine(N, DAG, Subtarget);
   case X86ISD::FXOR:
   case X86ISD::FOR:     return PerformFORCombine(N, DAG);
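
For context on the kind of source this combine catches, here is a minimal sketch, not part of the patch, of C code that produces such a node: a generic vector shift whose shift-amount operand is a BUILD_VECTOR splatting a single value. The typedef and function name are invented for illustration, and the expected codegen is an assumption rather than verified output:

    /* Requires the GCC/Clang vector extension and SSE2 (-msse2). */
    typedef short v8i16 __attribute__((vector_size(16)));

    /* Every 16-bit lane is shifted by the same amount, so the SHL node's
       shift-amount operand is a splatted BUILD_VECTOR. With this combine,
       the node should be rewritten to the x86_sse2_pslli_w intrinsic form
       built above and lower to a single psllw instruction. */
    v8i16 shl_by_3(v8i16 v) {
      return v << 3;
    }

Without the combine, per the comment in PerformShiftCombine, the splatted shift-amount vector would typically be materialized from the constant pool during legalization, by which point the uniform amount is no longer visible.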