diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index 4c65a6ccd7e..6b4251df61a 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -16720,18 +16720,28 @@ static SDValue getTargetVShiftNode(unsigned Opc, SDLoc dl, MVT VT,
     case X86ISD::VSRAI: Opc = X86ISD::VSRA; break;
   }
 
-  // Need to build a vector containing shift amount.
-  // SSE/AVX packed shifts only use the lower 64-bit of the shift count.
-  SmallVector<SDValue, 4> ShOps;
-  ShOps.push_back(ShAmt);
-  if (SVT == MVT::i32) {
-    ShOps.push_back(DAG.getConstant(0, SVT));
+  const X86Subtarget &Subtarget =
+      DAG.getTarget().getSubtarget<X86Subtarget>();
+  if (Subtarget.hasSSE41() && ShAmt.getOpcode() == ISD::ZERO_EXTEND &&
+      ShAmt.getOperand(0).getSimpleValueType() == MVT::i16) {
+    // Let the shuffle legalizer expand this shift amount node.
+    SDValue Op0 = ShAmt.getOperand(0);
+    Op0 = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(Op0), MVT::v8i16, Op0);
+    ShAmt = getShuffleVectorZeroOrUndef(Op0, 0, true, &Subtarget, DAG);
+  } else {
+    // Need to build a vector containing shift amount.
+    // SSE/AVX packed shifts only use the lower 64-bit of the shift count.
+    SmallVector<SDValue, 4> ShOps;
+    ShOps.push_back(ShAmt);
+    if (SVT == MVT::i32) {
+      ShOps.push_back(DAG.getConstant(0, SVT));
+      ShOps.push_back(DAG.getUNDEF(SVT));
+    }
     ShOps.push_back(DAG.getUNDEF(SVT));
-  }
-  ShOps.push_back(DAG.getUNDEF(SVT));
 
-  MVT BVT = SVT == MVT::i32 ? MVT::v4i32 : MVT::v2i64;
-  ShAmt = DAG.getNode(ISD::BUILD_VECTOR, dl, BVT, ShOps);
+    MVT BVT = SVT == MVT::i32 ? MVT::v4i32 : MVT::v2i64;
+    ShAmt = DAG.getNode(ISD::BUILD_VECTOR, dl, BVT, ShOps);
+  }
 
   // The return type has to be a 128-bit type with the same element
   // type as the input type.
diff --git a/test/CodeGen/X86/lower-vec-shift-2.ll b/test/CodeGen/X86/lower-vec-shift-2.ll
index 90505b6dd8f..770775d3242 100644
--- a/test/CodeGen/X86/lower-vec-shift-2.ll
+++ b/test/CodeGen/X86/lower-vec-shift-2.ll
@@ -11,9 +11,8 @@ define <8 x i16> @test1(<8 x i16> %A, <8 x i16> %B) {
 ; SSE2-NEXT: retq
 ; AVX-LABEL: test1:
 ; AVX: # BB#0
-; AVX-NEXT: vmovd %xmm1, %eax
-; AVX-NEXT: movzwl %ax, %eax
-; AVX-NEXT: vmovd %eax, %xmm1
+; AVX-NEXT: vpxor %xmm2, %xmm2, %xmm2
+; AVX-NEXT: vpblendw {{.*#+}} xmm1 = xmm1[0],xmm2[1,2,3,4,5,6,7]
 ; AVX-NEXT: vpsllw %xmm1, %xmm0, %xmm0
 ; AVX-NEXT: retq
 entry:
@@ -66,9 +65,8 @@ define <8 x i16> @test4(<8 x i16> %A, <8 x i16> %B) {
 ; SSE2-NEXT: retq
 ; AVX-LABEL: test4:
 ; AVX: # BB#0
-; AVX-NEXT: vmovd %xmm1, %eax
-; AVX-NEXT: movzwl %ax, %eax
-; AVX-NEXT: vmovd %eax, %xmm1
+; AVX-NEXT: vpxor %xmm2, %xmm2, %xmm2
+; AVX-NEXT: vpblendw {{.*#+}} xmm1 = xmm1[0],xmm2[1,2,3,4,5,6,7]
 ; AVX-NEXT: vpsrlw %xmm1, %xmm0, %xmm0
 ; AVX-NEXT: retq
 entry:
@@ -121,9 +119,8 @@ define <8 x i16> @test7(<8 x i16> %A, <8 x i16> %B) {
 ; SSE2-NEXT: retq
 ; AVX-LABEL: test7:
 ; AVX: # BB#0
-; AVX-NEXT: vmovd %xmm1, %eax
-; AVX-NEXT: movzwl %ax, %eax
-; AVX-NEXT: vmovd %eax, %xmm1
+; AVX-NEXT: vpxor %xmm2, %xmm2, %xmm2
+; AVX-NEXT: vpblendw {{.*#+}} xmm1 = xmm1[0],xmm2[1,2,3,4,5,6,7]
 ; AVX-NEXT: vpsraw %xmm1, %xmm0, %xmm0
 ; AVX-NEXT: retq
 entry:
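
Note: the AVX check lines above come from splat-shift tests of roughly the following shape (a minimal sketch modeled on test1; the value names and exact IR in the checked-in test may differ). The splatted lane 0 of %B reaches getTargetVShiftNode as a zero-extended i16 shift amount, so on SSE4.1 targets it is now zeroed in-register with vpxor + vpblendw instead of being bounced through a GPR with vmovd/movzwl/vmovd.

define <8 x i16> @splat_shl(<8 x i16> %A, <8 x i16> %B) {
entry:
  ; Broadcast lane 0 of %B so every lane is shifted by the same i16 amount.
  %amt = shufflevector <8 x i16> %B, <8 x i16> undef, <8 x i32> zeroinitializer
  ; Lowered through X86ISD::VSHL; the scalar shift amount takes the new
  ; SSE4.1 blend path added in this patch.
  %shl = shl <8 x i16> %A, %amt
  ret <8 x i16> %shl
}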