diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index dea4a4616f3..e8918f4c34d 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -10952,6 +10952,30 @@ static SDValue LowerVACOPY(SDValue Op, const X86Subtarget *Subtarget,
                        MachinePointerInfo(DstSV), MachinePointerInfo(SrcSV));
 }
 
+// getTargetVShiftByConstNode - Handle vector element shifts where the shift
+// amount is a constant. Takes immediate version of shift as input. A shift
+// amount >= the element width is clamped to (width - 1) for VSRAI; for VSHLI
+// and VSRLI the result is the constant 0 of type VT.
+static SDValue getTargetVShiftByConstNode(unsigned Opc, SDLoc dl, EVT VT,
+                                          SDValue SrcOp, uint64_t ShiftAmt,
+                                          SelectionDAG &DAG) {
+  // Validate the opcode first, so that an unknown opcode cannot leave through
+  // the zero-constant early return below without tripping the assert.
+  assert((Opc == X86ISD::VSHLI || Opc == X86ISD::VSRLI || Opc == X86ISD::VSRAI)
+         && "Unknown target vector shift-by-constant node");
+
+  // Check for ShiftAmt >= element width
+  const unsigned EltBits = VT.getVectorElementType().getSizeInBits();
+  if (ShiftAmt >= EltBits) {
+    if (Opc != X86ISD::VSRAI)
+      return DAG.getConstant(0, VT);
+    ShiftAmt = EltBits - 1;
+  }
+
+  // The (possibly clamped) shift amount is passed as an i8 constant operand.
+  return DAG.getNode(Opc, dl, VT, SrcOp, DAG.getConstant(ShiftAmt, MVT::i8));
+}
+
 // getTargetVShiftNode - Handle vector element shifts where the shift amount
 // may or may not be a constant. Takes immediate version of shift as input.
 static SDValue getTargetVShiftNode(unsigned Opc, SDLoc dl, EVT VT,
@@ -10959,18 +10983,10 @@ static SDValue getTargetVShiftNode(unsigned Opc, SDLoc dl, EVT VT,
                                    SelectionDAG &DAG) {
   assert(ShAmt.getValueType() == MVT::i32 && "ShAmt is not i32");
 
-  if (isa<ConstantSDNode>(ShAmt)) {
-    // Constant may be a TargetConstant. Use a regular constant.
-    uint32_t ShiftAmt = cast<ConstantSDNode>(ShAmt)->getZExtValue();
-    switch (Opc) {
-      default: llvm_unreachable("Unknown target vector shift node");
-      case X86ISD::VSHLI:
-      case X86ISD::VSRLI:
-      case X86ISD::VSRAI:
-        return DAG.getNode(Opc, dl, VT, SrcOp,
-                           DAG.getConstant(ShiftAmt, MVT::i32));
-    }
-  }
+  // Catch shift-by-constant.
+ if (ConstantSDNode *CShAmt = dyn_cast(ShAmt)) + return getTargetVShiftByConstNode(Opc, dl, VT, SrcOp, + CShAmt->getZExtValue(), DAG); // Change opcode to non-immediate version switch (Opc) { @@ -12416,10 +12428,8 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget *Subtarget, // AhiBlo = psllqi(AhiBlo, 32); // return AloBlo + AloBhi + AhiBlo; - SDValue ShAmt = DAG.getConstant(32, MVT::i32); - - SDValue Ahi = DAG.getNode(X86ISD::VSRLI, dl, VT, A, ShAmt); - SDValue Bhi = DAG.getNode(X86ISD::VSRLI, dl, VT, B, ShAmt); + SDValue Ahi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, A, 32, DAG); + SDValue Bhi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, B, 32, DAG); // Bit cast to 32-bit vectors for MULUDQ EVT MulVT = (VT == MVT::v2i64) ? MVT::v4i32 : @@ -12433,8 +12443,8 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget *Subtarget, SDValue AloBhi = DAG.getNode(X86ISD::PMULUDQ, dl, VT, A, Bhi); SDValue AhiBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Ahi, B); - AloBhi = DAG.getNode(X86ISD::VSHLI, dl, VT, AloBhi, ShAmt); - AhiBlo = DAG.getNode(X86ISD::VSHLI, dl, VT, AhiBlo, ShAmt); + AloBhi = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, AloBhi, 32, DAG); + AhiBlo = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, AhiBlo, 32, DAG); SDValue Res = DAG.getNode(ISD::ADD, dl, VT, AloBlo, AloBhi); return DAG.getNode(ISD::ADD, dl, VT, Res, AhiBlo); @@ -12462,7 +12472,7 @@ static SDValue LowerSDIV(SDValue Op, SelectionDAG &DAG) { if ((SplatValue != 0) && (SplatValue.isPowerOf2() || (-SplatValue).isPowerOf2())) { - unsigned lg2 = SplatValue.countTrailingZeros(); + unsigned Lg2 = SplatValue.countTrailingZeros(); // Splat the sign bit. SmallVector Sz(NumElts, DAG.getConstant(EltTy.getSizeInBits() - 1, @@ -12472,13 +12482,13 @@ static SDValue LowerSDIV(SDValue Op, SelectionDAG &DAG) { NumElts)); // Add (N0 < 0) ? 
abs2 - 1 : 0; SmallVector Amt(NumElts, - DAG.getConstant(EltTy.getSizeInBits() - lg2, + DAG.getConstant(EltTy.getSizeInBits() - Lg2, EltTy)); SDValue SRL = DAG.getNode(ISD::SRL, dl, VT, SGN, DAG.getNode(ISD::BUILD_VECTOR, dl, VT, &Amt[0], NumElts)); SDValue ADD = DAG.getNode(ISD::ADD, dl, VT, N0, SRL); - SmallVector Lg2Amt(NumElts, DAG.getConstant(lg2, EltTy)); + SmallVector Lg2Amt(NumElts, DAG.getConstant(Lg2, EltTy)); SDValue SRA = DAG.getNode(ISD::SRA, dl, VT, ADD, DAG.getNode(ISD::BUILD_VECTOR, dl, VT, &Lg2Amt[0], NumElts)); @@ -12514,21 +12524,22 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG, (Subtarget->hasAVX512() && (VT == MVT::v8i64 || VT == MVT::v16i32))) { if (Op.getOpcode() == ISD::SHL) - return DAG.getNode(X86ISD::VSHLI, dl, VT, R, - DAG.getConstant(ShiftAmt, MVT::i32)); + return getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, R, ShiftAmt, + DAG); if (Op.getOpcode() == ISD::SRL) - return DAG.getNode(X86ISD::VSRLI, dl, VT, R, - DAG.getConstant(ShiftAmt, MVT::i32)); + return getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, R, ShiftAmt, + DAG); if (Op.getOpcode() == ISD::SRA && VT != MVT::v2i64 && VT != MVT::v4i64) - return DAG.getNode(X86ISD::VSRAI, dl, VT, R, - DAG.getConstant(ShiftAmt, MVT::i32)); + return getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, R, ShiftAmt, + DAG); } if (VT == MVT::v16i8) { if (Op.getOpcode() == ISD::SHL) { // Make a large shift. - SDValue SHL = DAG.getNode(X86ISD::VSHLI, dl, MVT::v8i16, R, - DAG.getConstant(ShiftAmt, MVT::i32)); + SDValue SHL = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, + MVT::v8i16, R, ShiftAmt, + DAG); SHL = DAG.getNode(ISD::BITCAST, dl, VT, SHL); // Zero out the rightmost bits. SmallVector V(16, @@ -12539,8 +12550,9 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG, } if (Op.getOpcode() == ISD::SRL) { // Make a large shift. 
- SDValue SRL = DAG.getNode(X86ISD::VSRLI, dl, MVT::v8i16, R, - DAG.getConstant(ShiftAmt, MVT::i32)); + SDValue SRL = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, + MVT::v8i16, R, ShiftAmt, + DAG); SRL = DAG.getNode(ISD::BITCAST, dl, VT, SRL); // Zero out the leftmost bits. SmallVector V(16, @@ -12571,8 +12583,9 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG, if (Subtarget->hasInt256() && VT == MVT::v32i8) { if (Op.getOpcode() == ISD::SHL) { // Make a large shift. - SDValue SHL = DAG.getNode(X86ISD::VSHLI, dl, MVT::v16i16, R, - DAG.getConstant(ShiftAmt, MVT::i32)); + SDValue SHL = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, + MVT::v16i16, R, ShiftAmt, + DAG); SHL = DAG.getNode(ISD::BITCAST, dl, VT, SHL); // Zero out the rightmost bits. SmallVector V(32, @@ -12583,8 +12596,9 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG, } if (Op.getOpcode() == ISD::SRL) { // Make a large shift. - SDValue SRL = DAG.getNode(X86ISD::VSRLI, dl, MVT::v16i16, R, - DAG.getConstant(ShiftAmt, MVT::i32)); + SDValue SRL = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, + MVT::v16i16, R, ShiftAmt, + DAG); SRL = DAG.getNode(ISD::BITCAST, dl, VT, SRL); // Zero out the leftmost bits. 
SmallVector V(32, @@ -12649,14 +12663,14 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG, default: llvm_unreachable("Unknown shift opcode!"); case ISD::SHL: - return DAG.getNode(X86ISD::VSHLI, dl, VT, R, - DAG.getConstant(ShiftAmt, MVT::i32)); + return getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, R, ShiftAmt, + DAG); case ISD::SRL: - return DAG.getNode(X86ISD::VSRLI, dl, VT, R, - DAG.getConstant(ShiftAmt, MVT::i32)); + return getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, R, ShiftAmt, + DAG); case ISD::SRA: - return DAG.getNode(X86ISD::VSRAI, dl, VT, R, - DAG.getConstant(ShiftAmt, MVT::i32)); + return getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, R, ShiftAmt, + DAG); } } @@ -12869,8 +12883,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget* Subtarget, // r = VSELECT(r, psllw(r & (char16)15, 4), a); SDValue M = DAG.getNode(ISD::AND, dl, VT, R, CM1); - M = getTargetVShiftNode(X86ISD::VSHLI, dl, MVT::v8i16, M, - DAG.getConstant(4, MVT::i32), DAG); + M = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, MVT::v8i16, M, 4, DAG); M = DAG.getNode(ISD::BITCAST, dl, VT, M); R = DAG.getNode(ISD::VSELECT, dl, VT, OpVSel, M, R); @@ -12881,8 +12894,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget* Subtarget, // r = VSELECT(r, psllw(r & (char16)63, 2), a); M = DAG.getNode(ISD::AND, dl, VT, R, CM2); - M = getTargetVShiftNode(X86ISD::VSHLI, dl, MVT::v8i16, M, - DAG.getConstant(2, MVT::i32), DAG); + M = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, MVT::v8i16, M, 2, DAG); M = DAG.getNode(ISD::BITCAST, dl, VT, M); R = DAG.getNode(ISD::VSELECT, dl, VT, OpVSel, M, R); @@ -13025,7 +13037,6 @@ SDValue X86TargetLowering::LowerSIGN_EXTEND_INREG(SDValue Op, unsigned BitsDiff = VT.getScalarType().getSizeInBits() - ExtraVT.getScalarType().getSizeInBits(); - SDValue ShAmt = DAG.getConstant(BitsDiff, MVT::i32); switch (VT.getSimpleVT().SimpleTy) { default: return SDValue(); @@ -13075,8 +13086,10 @@ SDValue 
X86TargetLowering::LowerSIGN_EXTEND_INREG(SDValue Op, } // If the above didn't work, then just use Shift-Left + Shift-Right. - Tmp1 = getTargetVShiftNode(X86ISD::VSHLI, dl, VT, Op0, ShAmt, DAG); - return getTargetVShiftNode(X86ISD::VSRAI, dl, VT, Tmp1, ShAmt, DAG); + Tmp1 = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, Op0, BitsDiff, + DAG); + return getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, Tmp1, BitsDiff, + DAG); } } } diff --git a/lib/Target/X86/X86InstrAVX512.td b/lib/Target/X86/X86InstrAVX512.td index 3fd725cd95f..05e346dec5a 100644 --- a/lib/Target/X86/X86InstrAVX512.td +++ b/lib/Target/X86/X86InstrAVX512.td @@ -1845,22 +1845,22 @@ multiclass avx512_shift_rmi opc, Format ImmFormR, Format ImmFormM, ValueType vt, X86MemOperand x86memop, PatFrag mem_frag, RegisterClass KRC> { def ri : AVX512BIi8, EVEX_4V; def rik : AVX512BIi8, EVEX_4V, EVEX_K; def mi: AVX512BIi8, EVEX_4V; + (i8 imm:$src2)))], SSE_INTSHIFT_ITINS_P.rm>, EVEX_4V; def mik: AVX512BIi8, EVEX_4V, EVEX_K; diff --git a/lib/Target/X86/X86InstrSSE.td b/lib/Target/X86/X86InstrSSE.td index c2f319704d2..f1bb9f84a72 100644 --- a/lib/Target/X86/X86InstrSSE.td +++ b/lib/Target/X86/X86InstrSSE.td @@ -3744,11 +3744,11 @@ multiclass PDI_binop_rmi opc, bits<8> opc2, Format ImmForm, (bc_frag (memopv2i64 addr:$src2)))))], itins.rm>, Sched<[WriteVecShiftLd, ReadAfterLd]>; def ri : PDIi8, + [(set RC:$dst, (DstVT (OpNode2 RC:$src1, (i8 imm:$src2))))], itins.ri>, Sched<[WriteVecShift]>; } @@ -5064,12 +5064,12 @@ multiclass SS3I_unop_rm_int_y opc, string OpcodeStr, // Helper fragments to match sext vXi1 to vXiY. 
def v16i1sextv16i8 : PatLeaf<(v16i8 (X86pcmpgt (bc_v16i8 (v4i32 immAllZerosV)), VR128:$src))>; -def v8i1sextv8i16 : PatLeaf<(v8i16 (X86vsrai VR128:$src, (i32 15)))>; -def v4i1sextv4i32 : PatLeaf<(v4i32 (X86vsrai VR128:$src, (i32 31)))>; +def v8i1sextv8i16 : PatLeaf<(v8i16 (X86vsrai VR128:$src, (i8 15)))>; +def v4i1sextv4i32 : PatLeaf<(v4i32 (X86vsrai VR128:$src, (i8 31)))>; def v32i1sextv32i8 : PatLeaf<(v32i8 (X86pcmpgt (bc_v32i8 (v8i32 immAllZerosV)), VR256:$src))>; -def v16i1sextv16i16: PatLeaf<(v16i16 (X86vsrai VR256:$src, (i32 15)))>; -def v8i1sextv8i32 : PatLeaf<(v8i32 (X86vsrai VR256:$src, (i32 31)))>; +def v16i1sextv16i16: PatLeaf<(v16i16 (X86vsrai VR256:$src, (i8 15)))>; +def v8i1sextv8i32 : PatLeaf<(v8i32 (X86vsrai VR256:$src, (i8 31)))>; let Predicates = [HasAVX] in { defm VPABSB : SS3I_unop_rm_int<0x1C, "vpabsb", diff --git a/test/CodeGen/X86/avx2-vector-shifts.ll b/test/CodeGen/X86/avx2-vector-shifts.ll index a978d93fc55..5592e6c8a5f 100644 --- a/test/CodeGen/X86/avx2-vector-shifts.ll +++ b/test/CodeGen/X86/avx2-vector-shifts.ll @@ -121,7 +121,7 @@ entry: } ; CHECK-LABEL: test_sraw_3: -; CHECK: vpsraw $16, %ymm0, %ymm0 +; CHECK: vpsraw $15, %ymm0, %ymm0 ; CHECK: ret define <8 x i32> @test_srad_1(<8 x i32> %InVec) { @@ -151,7 +151,7 @@ entry: } ; CHECK-LABEL: test_srad_3: -; CHECK: vpsrad $32, %ymm0, %ymm0 +; CHECK: vpsrad $31, %ymm0, %ymm0 ; CHECK: ret ; SSE Logical Shift Right diff --git a/test/CodeGen/X86/sse2-vector-shifts.ll b/test/CodeGen/X86/sse2-vector-shifts.ll index e2d612567a1..462def980a9 100644 --- a/test/CodeGen/X86/sse2-vector-shifts.ll +++ b/test/CodeGen/X86/sse2-vector-shifts.ll @@ -121,7 +121,7 @@ entry: } ; CHECK-LABEL: test_sraw_3: -; CHECK: psraw $16, %xmm0 +; CHECK: psraw $15, %xmm0 ; CHECK-NEXT: ret define <4 x i32> @test_srad_1(<4 x i32> %InVec) { @@ -151,7 +151,7 @@ entry: } ; CHECK-LABEL: test_srad_3: -; CHECK: psrad $32, %xmm0 +; CHECK: psrad $31, %xmm0 ; CHECK-NEXT: ret ; SSE Logical Shift Right