[X86][SSE] Vectorized i64 uniform constant SRA shifts

This patch adds vectorization support for uniform constant i64 arithmetic shift right operators.

Differential Revision: http://reviews.llvm.org/D9645

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@241514 91177308-0d34-0410-b5e6-96231b3b80d8
Simon Pilgrim 2015-07-06 22:35:19 +00:00
parent 610992baca
commit 6970be03d1
7 changed files with 106 additions and 81 deletions
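
For reference, a minimal C++ intrinsics sketch of what the new lowering produces for a splat-constant i64 ashr with a shift amount below 32 (the case exercised by the updated tests below). The helper names and the template are illustrative only and are not part of the patch; plain SSE2 reaches the same result with pshufd/punpckldq instead of a blend, and shift amounts of 32 or more are instead handled with two psrad shifts (by 31 and by amount-32) plus a shuffle, as in the ArithmeticShiftRight64 lambda.

#include <immintrin.h>

// SSE4.1 sketch: per-lane i64 arithmetic shift right by a constant 0 < N < 32.
template <int N>
__m128i ashr_epi64_const(__m128i v) {
  static_assert(N > 0 && N < 32, "sketch only covers small shift amounts");
  __m128i hi = _mm_srai_epi32(v, N);    // psrad: correct upper 32 bits of each i64 lane
  __m128i lo = _mm_srli_epi64(v, N);    // psrlq: correct lower 32 bits of each i64 lane
  return _mm_blend_epi16(lo, hi, 0xCC); // pblendw: upper dwords from hi, lower from lo
}

// AVX2 sketch: the same idea for v4i64.
template <int N>
__m256i ashr_epi64_const(__m256i v) {
  __m256i hi = _mm256_srai_epi32(v, N);
  __m256i lo = _mm256_srli_epi64(v, N);
  return _mm256_blend_epi32(lo, hi, 0xAA); // vpblendd: odd dwords (upper halves) from hi
}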

@@ -1032,6 +1032,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::SHL, MVT::v2i64, Custom);
setOperationAction(ISD::SHL, MVT::v4i32, Custom);
setOperationAction(ISD::SRA, MVT::v2i64, Custom);
setOperationAction(ISD::SRA, MVT::v4i32, Custom);
}
@@ -1211,6 +1212,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::SHL, MVT::v4i64, Custom);
setOperationAction(ISD::SHL, MVT::v8i32, Custom);
setOperationAction(ISD::SRA, MVT::v4i64, Custom);
setOperationAction(ISD::SRA, MVT::v8i32, Custom);
// Custom lower several nodes for 256-bit types.
@@ -16948,6 +16950,38 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
unsigned X86Opc = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHLI :
(Op.getOpcode() == ISD::SRL) ? X86ISD::VSRLI : X86ISD::VSRAI;
auto ArithmeticShiftRight64 = [&](uint64_t ShiftAmt) {
assert((VT == MVT::v2i64 || VT == MVT::v4i64) && "Unexpected SRA type");
MVT ExVT = MVT::getVectorVT(MVT::i32, VT.getVectorNumElements() * 2);
SDValue Ex = DAG.getBitcast(ExVT, R);
if (ShiftAmt >= 32) {
// Splat sign to upper i32 dst, and SRA upper i32 src to lower i32.
SDValue Upper =
getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, Ex, 31, DAG);
SDValue Lower = getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, Ex,
ShiftAmt - 32, DAG);
if (VT == MVT::v2i64)
Ex = DAG.getVectorShuffle(ExVT, dl, Upper, Lower, {5, 1, 7, 3});
if (VT == MVT::v4i64)
Ex = DAG.getVectorShuffle(ExVT, dl, Upper, Lower,
{9, 1, 11, 3, 13, 5, 15, 7});
} else {
// SRA upper i32, SRL whole i64 and select lower i32.
SDValue Upper = getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, Ex,
ShiftAmt, DAG);
SDValue Lower =
getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, R, ShiftAmt, DAG);
Lower = DAG.getBitcast(ExVT, Lower);
if (VT == MVT::v2i64)
Ex = DAG.getVectorShuffle(ExVT, dl, Upper, Lower, {4, 1, 6, 3});
if (VT == MVT::v4i64)
Ex = DAG.getVectorShuffle(ExVT, dl, Upper, Lower,
{8, 1, 10, 3, 12, 5, 14, 7});
}
return DAG.getBitcast(VT, Ex);
};
// Optimize shl/srl/sra with constant shift amount.
if (auto *BVAmt = dyn_cast<BuildVectorSDNode>(Amt)) {
if (auto *ShiftConst = BVAmt->getConstantSplatNode()) {
@@ -16956,6 +16990,11 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
if (SupportedVectorShiftWithImm(VT, Subtarget, Op.getOpcode()))
return getTargetVShiftByConstNode(X86Opc, dl, VT, R, ShiftAmt, DAG);
// i64 SRA needs to be performed as partial shifts.
if ((VT == MVT::v2i64 || (Subtarget->hasInt256() && VT == MVT::v4i64)) &&
Op.getOpcode() == ISD::SRA)
return ArithmeticShiftRight64(ShiftAmt);
if (VT == MVT::v16i8 || (Subtarget->hasInt256() && VT == MVT::v32i8)) {
unsigned NumElts = VT.getVectorNumElements();
MVT ShiftVT = MVT::getVectorVT(MVT::i16, NumElts / 2);
@@ -17039,7 +17078,12 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
if (ShAmt != ShiftAmt)
return SDValue();
}
return getTargetVShiftByConstNode(X86Opc, dl, VT, R, ShiftAmt, DAG);
if (SupportedVectorShiftWithImm(VT, Subtarget, Op.getOpcode()))
return getTargetVShiftByConstNode(X86Opc, dl, VT, R, ShiftAmt, DAG);
if (Op.getOpcode() == ISD::SRA)
return ArithmeticShiftRight64(ShiftAmt);
}
return SDValue();
@@ -17121,7 +17165,9 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG,
if (Vals[j] != Amt.getOperand(i + j))
return SDValue();
}
return DAG.getNode(X86OpcV, dl, VT, R, Op.getOperand(1));
if (SupportedVectorShiftWithBaseAmnt(VT, Subtarget, Op.getOpcode()))
return DAG.getNode(X86OpcV, dl, VT, R, Op.getOperand(1));
}
return SDValue();
}

@@ -117,6 +117,8 @@ unsigned X86TTIImpl::getArithmeticInstrCost(
static const CostTblEntry<MVT::SimpleValueType>
AVX2UniformConstCostTable[] = {
{ ISD::SRA, MVT::v4i64, 4 }, // 2 x psrad + shuffle.
{ ISD::SDIV, MVT::v16i16, 6 }, // vpmulhw sequence
{ ISD::UDIV, MVT::v16i16, 6 }, // vpmulhuw sequence
{ ISD::SDIV, MVT::v8i32, 15 }, // vpmuldq sequence
@@ -211,6 +213,7 @@ unsigned X86TTIImpl::getArithmeticInstrCost(
{ ISD::SRA, MVT::v16i8, 4 }, // psrlw, pand, pxor, psubb.
{ ISD::SRA, MVT::v8i16, 1 }, // psraw.
{ ISD::SRA, MVT::v4i32, 1 }, // psrad.
{ ISD::SRA, MVT::v2i64, 4 }, // 2 x psrad + shuffle.
{ ISD::SDIV, MVT::v8i16, 6 }, // pmulhw sequence
{ ISD::UDIV, MVT::v8i16, 6 }, // pmulhuw sequence

@@ -247,9 +247,9 @@ entry:
define %shifttypec @shift2i16const(%shifttypec %a, %shifttypec %b) {
entry:
; SSE2: shift2i16const
; SSE2: cost of 20 {{.*}} ashr
; SSE2: cost of 4 {{.*}} ashr
; SSE2-CODEGEN: shift2i16const
; SSE2-CODEGEN: sarq $
; SSE2-CODEGEN: psrad $3
%0 = ashr %shifttypec %a , <i16 3, i16 3>
ret %shifttypec %0
@@ -320,9 +320,9 @@ entry:
define %shifttypec2i32 @shift2i32c(%shifttypec2i32 %a, %shifttypec2i32 %b) {
entry:
; SSE2: shift2i32c
; SSE2: cost of 20 {{.*}} ashr
; SSE2: cost of 4 {{.*}} ashr
; SSE2-CODEGEN: shift2i32c
; SSE2-CODEGEN: sarq $3
; SSE2-CODEGEN: psrad $3
%0 = ashr %shifttypec2i32 %a , <i32 3, i32 3>
ret %shifttypec2i32 %0
@@ -391,9 +391,9 @@ entry:
define %shifttypec2i64 @shift2i64c(%shifttypec2i64 %a, %shifttypec2i64 %b) {
entry:
; SSE2: shift2i64c
; SSE2: cost of 20 {{.*}} ashr
; SSE2: cost of 4 {{.*}} ashr
; SSE2-CODEGEN: shift2i64c
; SSE2-CODEGEN: sarq $3
; SSE2-CODEGEN: psrad $3
%0 = ashr %shifttypec2i64 %a , <i64 3, i64 3>
ret %shifttypec2i64 %0
@@ -403,9 +403,9 @@ entry:
define %shifttypec4i64 @shift4i64c(%shifttypec4i64 %a, %shifttypec4i64 %b) {
entry:
; SSE2: shift4i64c
; SSE2: cost of 40 {{.*}} ashr
; SSE2: cost of 8 {{.*}} ashr
; SSE2-CODEGEN: shift4i64c
; SSE2-CODEGEN: sarq $3
; SSE2-CODEGEN: psrad $3
%0 = ashr %shifttypec4i64 %a , <i64 3, i64 3, i64 3, i64 3>
ret %shifttypec4i64 %0
@@ -415,9 +415,9 @@ entry:
define %shifttypec8i64 @shift8i64c(%shifttypec8i64 %a, %shifttypec8i64 %b) {
entry:
; SSE2: shift8i64c
; SSE2: cost of 80 {{.*}} ashr
; SSE2: cost of 16 {{.*}} ashr
; SSE2-CODEGEN: shift8i64c
; SSE2-CODEGEN: sarq $3
; SSE2-CODEGEN: psrad $3
%0 = ashr %shifttypec8i64 %a , <i64 3, i64 3, i64 3, i64 3,
i64 3, i64 3, i64 3, i64 3>
@@ -428,9 +428,9 @@ entry:
define %shifttypec16i64 @shift16i64c(%shifttypec16i64 %a, %shifttypec16i64 %b) {
entry:
; SSE2: shift16i64c
; SSE2: cost of 160 {{.*}} ashr
; SSE2: cost of 32 {{.*}} ashr
; SSE2-CODEGEN: shift16i64c
; SSE2-CODEGEN: sarq $3
; SSE2-CODEGEN: psrad $3
%0 = ashr %shifttypec16i64 %a , <i64 3, i64 3, i64 3, i64 3,
i64 3, i64 3, i64 3, i64 3,
@@ -443,9 +443,9 @@ entry:
define %shifttypec32i64 @shift32i64c(%shifttypec32i64 %a, %shifttypec32i64 %b) {
entry:
; SSE2: shift32i64c
; SSE2: cost of 320 {{.*}} ashr
; SSE2: cost of 64 {{.*}} ashr
; SSE2-CODEGEN: shift32i64c
; SSE2-CODEGEN: sarq $3
; SSE2-CODEGEN: psrad $3
%0 = ashr %shifttypec32i64 %a ,<i64 3, i64 3, i64 3, i64 3,
i64 3, i64 3, i64 3, i64 3,
@@ -462,9 +462,9 @@ entry:
define %shifttypec2i8 @shift2i8c(%shifttypec2i8 %a, %shifttypec2i8 %b) {
entry:
; SSE2: shift2i8c
; SSE2: cost of 20 {{.*}} ashr
; SSE2: cost of 4 {{.*}} ashr
; SSE2-CODEGEN: shift2i8c
; SSE2-CODEGEN: sarq $3
; SSE2-CODEGEN: psrad $3
%0 = ashr %shifttypec2i8 %a , <i8 3, i8 3>
ret %shifttypec2i8 %0

@@ -954,38 +954,35 @@ define <16 x i8> @constant_shift_v16i8(<16 x i8> %a) {
define <2 x i64> @splatconstant_shift_v2i64(<2 x i64> %a) {
; SSE2-LABEL: splatconstant_shift_v2i64:
; SSE2: # BB#0:
; SSE2-NEXT: movd %xmm0, %rax
; SSE2-NEXT: sarq $7, %rax
; SSE2-NEXT: movd %rax, %xmm1
; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm0[2,3,0,1]
; SSE2-NEXT: movd %xmm0, %rax
; SSE2-NEXT: sarq $7, %rax
; SSE2-NEXT: movd %rax, %xmm0
; SSE2-NEXT: punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm0[0]
; SSE2-NEXT: movdqa %xmm1, %xmm0
; SSE2-NEXT: movdqa %xmm0, %xmm1
; SSE2-NEXT: psrad $7, %xmm1
; SSE2-NEXT: pshufd {{.*#+}} xmm1 = xmm1[1,3,2,3]
; SSE2-NEXT: psrlq $7, %xmm0
; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
; SSE2-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
; SSE2-NEXT: retq
;
; SSE41-LABEL: splatconstant_shift_v2i64:
; SSE41: # BB#0:
; SSE41-NEXT: pextrq $1, %xmm0, %rax
; SSE41-NEXT: sarq $7, %rax
; SSE41-NEXT: movd %rax, %xmm1
; SSE41-NEXT: movd %xmm0, %rax
; SSE41-NEXT: sarq $7, %rax
; SSE41-NEXT: movd %rax, %xmm0
; SSE41-NEXT: punpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm1[0]
; SSE41-NEXT: movdqa %xmm0, %xmm1
; SSE41-NEXT: psrad $7, %xmm1
; SSE41-NEXT: psrlq $7, %xmm0
; SSE41-NEXT: pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5],xmm1[6,7]
; SSE41-NEXT: retq
;
; AVX-LABEL: splatconstant_shift_v2i64:
; AVX: # BB#0:
; AVX-NEXT: vpextrq $1, %xmm0, %rax
; AVX-NEXT: sarq $7, %rax
; AVX-NEXT: vmovq %rax, %xmm1
; AVX-NEXT: vmovq %xmm0, %rax
; AVX-NEXT: sarq $7, %rax
; AVX-NEXT: vmovq %rax, %xmm0
; AVX-NEXT: vpunpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm1[0]
; AVX-NEXT: retq
; AVX1-LABEL: splatconstant_shift_v2i64:
; AVX1: # BB#0:
; AVX1-NEXT: vpsrad $7, %xmm0, %xmm1
; AVX1-NEXT: vpsrlq $7, %xmm0, %xmm0
; AVX1-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5],xmm1[6,7]
; AVX1-NEXT: retq
;
; AVX2-LABEL: splatconstant_shift_v2i64:
; AVX2: # BB#0:
; AVX2-NEXT: vpsrad $7, %xmm0, %xmm1
; AVX2-NEXT: vpsrlq $7, %xmm0, %xmm0
; AVX2-NEXT: vpblendd {{.*#+}} xmm0 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
; AVX2-NEXT: retq
%shift = ashr <2 x i64> %a, <i64 7, i64 7>
ret <2 x i64> %shift
}

@@ -663,41 +663,20 @@ define <4 x i64> @splatconstant_shift_v4i64(<4 x i64> %a) {
; AVX1-LABEL: splatconstant_shift_v4i64:
; AVX1: # BB#0:
; AVX1-NEXT: vextractf128 $1, %ymm0, %xmm1
; AVX1-NEXT: vpextrq $1, %xmm1, %rax
; AVX1-NEXT: sarq $7, %rax
; AVX1-NEXT: vmovq %rax, %xmm2
; AVX1-NEXT: vmovq %xmm1, %rax
; AVX1-NEXT: sarq $7, %rax
; AVX1-NEXT: vmovq %rax, %xmm1
; AVX1-NEXT: vpunpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm2[0]
; AVX1-NEXT: vpextrq $1, %xmm0, %rax
; AVX1-NEXT: sarq $7, %rax
; AVX1-NEXT: vmovq %rax, %xmm2
; AVX1-NEXT: vmovq %xmm0, %rax
; AVX1-NEXT: sarq $7, %rax
; AVX1-NEXT: vmovq %rax, %xmm0
; AVX1-NEXT: vpunpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm2[0]
; AVX1-NEXT: vinsertf128 $1, %xmm1, %ymm0, %ymm0
; AVX1-NEXT: vpsrad $7, %xmm1, %xmm2
; AVX1-NEXT: vpsrlq $7, %xmm1, %xmm1
; AVX1-NEXT: vpblendw {{.*#+}} xmm1 = xmm1[0,1],xmm2[2,3],xmm1[4,5],xmm2[6,7]
; AVX1-NEXT: vpsrad $7, %xmm0, %xmm2
; AVX1-NEXT: vpsrlq $7, %xmm0, %xmm0
; AVX1-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm2[2,3],xmm0[4,5],xmm2[6,7]
; AVX1-NEXT: vinsertf128 $1, %xmm1, %ymm0, %ymm0
; AVX1-NEXT: retq
;
; AVX2-LABEL: splatconstant_shift_v4i64:
; AVX2: # BB#0:
; AVX2-NEXT: vextracti128 $1, %ymm0, %xmm1
; AVX2-NEXT: vpextrq $1, %xmm1, %rax
; AVX2-NEXT: sarq $7, %rax
; AVX2-NEXT: vmovq %rax, %xmm2
; AVX2-NEXT: vmovq %xmm1, %rax
; AVX2-NEXT: sarq $7, %rax
; AVX2-NEXT: vmovq %rax, %xmm1
; AVX2-NEXT: vpunpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm2[0]
; AVX2-NEXT: vpextrq $1, %xmm0, %rax
; AVX2-NEXT: sarq $7, %rax
; AVX2-NEXT: vmovq %rax, %xmm2
; AVX2-NEXT: vmovq %xmm0, %rax
; AVX2-NEXT: sarq $7, %rax
; AVX2-NEXT: vmovq %rax, %xmm0
; AVX2-NEXT: vpunpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm2[0]
; AVX2-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm0
; AVX2-NEXT: vpsrad $7, %ymm0, %ymm1
; AVX2-NEXT: vpsrlq $7, %ymm0, %ymm0
; AVX2-NEXT: vpblendd {{.*#+}} ymm0 = ymm0[0],ymm1[1],ymm0[2],ymm1[3],ymm0[4],ymm1[5],ymm0[6],ymm1[7]
; AVX2-NEXT: retq
%shift = ashr <4 x i64> %a, <i64 7, i64 7, i64 7, i64 7>
ret <4 x i64> %shift

@@ -3,13 +3,12 @@
; test vector shifts converted to proper SSE2 vector shifts when the shift
; amounts are the same.
; Note that x86 does have ashr
; Note that x86 does have ashr
; shift1a can't use a packed shift
define void @shift1a(<2 x i64> %val, <2 x i64>* %dst) nounwind {
entry:
; CHECK-LABEL: shift1a:
; CHECK: sarl
; CHECK: psrad $31
%ashr = ashr <2 x i64> %val, < i64 32, i64 32 >
store <2 x i64> %ashr, <2 x i64>* %dst
ret void

@@ -1,8 +1,9 @@
; RUN: llc < %s -march=x86 -mattr=+sse4.2 | FileCheck %s
; CHECK: {{cwtl|movswl}}
; CHECK: {{cwtl|movswl}}
; CHECK: psllq $48, %xmm0
; CHECK: psrad $16, %xmm0
; CHECK: pshufd {{.*#+}} xmm0 = xmm0[1,3,2,3]
; sign extension v2i32 to v2i16
; sign extension v2i16 to v2i32
define void @convert(<2 x i32>* %dst.addr, <2 x i16> %src) nounwind {
entry: